diff --git a/lark_oapi/ws/client.py b/lark_oapi/ws/client.py index a7545d05..a7d2c63a 100644 --- a/lark_oapi/ws/client.py +++ b/lark_oapi/ws/client.py @@ -147,6 +147,13 @@ def __init__(self, self._ping_interval: int = 120 self._cache: ExpiringCache = ExpiringCache(clear_interval=30) self._lock = asyncio.Lock() + # Background task handles so ``stop()`` can cancel them on graceful + # shutdown. Without these references, ``loop.close()`` after stop + # would report "Task was destroyed but it is pending!" for the + # ping/receive loops and they'd leak across reconnect cycles. + self._ping_task: Optional[asyncio.Task] = None + self._receive_message_task: Optional[asyncio.Task] = None + self._main_task: Optional[asyncio.Task] = None # Observer hooks for higher-level wrappers (e.g. FeishuChannel) to # react to reconnect lifecycle. ``on_reconnecting`` fires when the # client decides a connection was lost and starts retrying; @@ -170,8 +177,18 @@ def start(self) -> None: else: raise e - loop.create_task(self._ping_loop()) - loop.run_until_complete(_select()) + self._ping_task = loop.create_task(self._ping_loop()) + # ``_main_task`` blocks the foreground until ``stop()`` cancels it. + # We keep the reference so ``stop()`` can release ``start()``. + # Backward compat: still uses the module-level ``_select`` coroutine + # for callers/tests that monkeypatch it. + self._main_task = loop.create_task(_select()) + try: + loop.run_until_complete(self._main_task) + except asyncio.CancelledError: + # Graceful shutdown via ``stop()`` — return cleanly so the + # caller's ``start()`` thread can exit. + pass async def _ping_loop(self): while True: @@ -203,7 +220,9 @@ async def _connect(self) -> None: self._service_id = service_id logger.info(self._fmt_log("connected to {}", conn_url)) - loop.create_task(self._receive_message_loop()) + # Save handle so ``stop()`` can cancel this background task and + # avoid "Task was destroyed but it is pending!" on loop close. + self._receive_message_task = loop.create_task(self._receive_message_loop()) except InvalidHandshake as e: _parse_ws_conn_exception(e) finally: @@ -391,6 +410,45 @@ async def _disconnect(self): self._service_id = "" self._lock.release() + async def stop(self) -> None: + """Gracefully stop the WebSocket client. + + Disables auto-reconnect, closes the WebSocket connection, cancels the + internal background tasks (ping loop, receive loop), and releases the + blocking ``start()`` call so its thread can exit cleanly. + + Designed for cross-thread use — when ``start()`` is running on a + dedicated worker thread (the common pattern, since it blocks), call + ``stop()`` from another thread via:: + + future = asyncio.run_coroutine_threadsafe(client.stop(), client_loop) + future.result(timeout=5) + + where ``client_loop`` is the event loop ``start()`` is running on + (typically the loop created in the worker thread). + + After ``stop()`` returns, ``start()`` will exit and the worker thread + can be joined. Subsequent calls to ``stop()`` are safe (idempotent). + """ + # Prevent the receive loop's exception handler from auto-reconnecting + # after we close the WebSocket below. + self._auto_reconnect = False + + # Cancel background tasks first so they don't observe the disconnect + # as a network failure and try to reconnect. + for task in (self._ping_task, self._receive_message_task): + if task is not None and not task.done(): + task.cancel() + + # Close the underlying WebSocket. Safe to call multiple times — + # ``_disconnect`` early-returns when ``self._conn is None``. + await self._disconnect() + + # Release the foreground ``start()`` blocker. Cancelling this task + # makes ``loop.run_until_complete(self._main_task)`` return. + if self._main_task is not None and not self._main_task.done(): + self._main_task.cancel() + async def _write_message(self, data: bytes): async with self._lock: if self._conn is None: diff --git a/lark_oapi/ws/tests/test_stop.py b/lark_oapi/ws/tests/test_stop.py new file mode 100644 index 00000000..1c36d043 --- /dev/null +++ b/lark_oapi/ws/tests/test_stop.py @@ -0,0 +1,167 @@ +"""Tests for ``Client.stop()`` public shutdown method. + +These tests use a mock WebSocket connection and verify that ``stop()``: + 1. Sets ``_auto_reconnect=False`` so the receive loop's exception handler + doesn't re-establish the connection after we close it. + 2. Cancels the ping loop, receive loop, and main ``_select`` tasks so + ``loop.close()`` doesn't report "Task was destroyed but it is pending!". + 3. Closes the underlying WebSocket via ``_disconnect()``. + 4. Releases the blocking ``start()`` call by cancelling the main task. + 5. Is idempotent — safe to call when already stopped. +""" +import asyncio +import threading +from types import SimpleNamespace + +import pytest + +from lark_oapi.ws import client as ws_client + + +class _FakeConn: + """Minimal WebSocketClientProtocol stub.""" + + def __init__(self): + self.closed = False + self._closed_event = asyncio.Event() + + async def close(self): + self.closed = True + self._closed_event.set() + + async def recv(self): + # Block forever until close() is called; mimics live WS behavior + # where recv() raises ConnectionClosed after the conn shuts down. + await self._closed_event.wait() + raise ws_client.ConnectionClosedException("closed") + + +@pytest.mark.asyncio +async def test_stop_disables_auto_reconnect(): + """``stop()`` must flip ``_auto_reconnect`` off so the receive loop's + ``except`` handler doesn't reconnect after we close the WebSocket.""" + client = ws_client.Client("app_id", "app_secret", auto_reconnect=True) + assert client._auto_reconnect is True + + await client.stop() + + assert client._auto_reconnect is False + + +@pytest.mark.asyncio +async def test_stop_closes_websocket(monkeypatch): + """``stop()`` must call ``_disconnect()`` to close the underlying WS.""" + client = ws_client.Client("app_id", "app_secret") + fake_conn = _FakeConn() + client._conn = fake_conn + + await client.stop() + + assert fake_conn.closed is True + assert client._conn is None + + +@pytest.mark.asyncio +async def test_stop_cancels_background_tasks(): + """``stop()`` must cancel ping_task, receive_message_task, and main_task + so ``loop.close()`` doesn't warn about pending tasks.""" + client = ws_client.Client("app_id", "app_secret") + + async def _forever(): + while True: + await asyncio.sleep(3600) + + # Plant fake tasks pretending start() already populated them. + client._ping_task = asyncio.create_task(_forever()) + client._receive_message_task = asyncio.create_task(_forever()) + client._main_task = asyncio.create_task(_forever()) + + await client.stop() + # Give cancellation a chance to propagate. + await asyncio.sleep(0) + + assert client._ping_task.cancelled() or client._ping_task.done() + assert client._receive_message_task.cancelled() or client._receive_message_task.done() + assert client._main_task.cancelled() or client._main_task.done() + + +@pytest.mark.asyncio +async def test_stop_is_idempotent(): + """Calling ``stop()`` twice should not raise — useful for cleanup paths + that may run multiple times (signal handlers, finally blocks, etc).""" + client = ws_client.Client("app_id", "app_secret") + # No conn, no tasks — simulate "never started" state. + await client.stop() + await client.stop() # second call must not raise + + assert client._auto_reconnect is False + + +@pytest.mark.asyncio +async def test_stop_handles_already_done_tasks(): + """If background tasks already completed (e.g. WS dropped on its own), + ``stop()`` should skip cancelling them without error.""" + client = ws_client.Client("app_id", "app_secret") + + async def _quick(): + return + + client._ping_task = asyncio.create_task(_quick()) + client._main_task = asyncio.create_task(_quick()) + # Let them finish. + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert client._ping_task.done() + assert client._main_task.done() + + # Should not raise even though tasks are already done. + await client.stop() + + +def test_stop_from_another_thread(): + """Real-world use case: ``start()`` runs on a worker thread (it blocks + forever), so ``stop()`` must be called via ``run_coroutine_threadsafe`` + from another thread. Verify the cross-thread pattern works. + """ + # Loop A: where the SDK "runs" (where stop() will be scheduled to) + sdk_loop = asyncio.new_event_loop() + ready = threading.Event() + + def _run_sdk_loop(): + asyncio.set_event_loop(sdk_loop) + ready.set() + sdk_loop.run_forever() + + sdk_thread = threading.Thread(target=_run_sdk_loop, daemon=True) + sdk_thread.start() + ready.wait(timeout=2) + + try: + # Build client + plant a fake "blocking main task" on the sdk loop + # to simulate start() being parked on _select(). + client = ws_client.Client("app_id", "app_secret") + + async def _build_main_task(): + async def _block(): + while True: + await asyncio.sleep(3600) + client._main_task = asyncio.create_task(_block()) + return client._main_task + + main_task_future = asyncio.run_coroutine_threadsafe(_build_main_task(), sdk_loop) + main_task_future.result(timeout=2) + + # Now call stop() from the test thread, scheduled onto sdk_loop. + stop_future = asyncio.run_coroutine_threadsafe(client.stop(), sdk_loop) + stop_future.result(timeout=2) + + # The blocking main task should be cancelled. + async def _check(): + return client._main_task.cancelled() or client._main_task.done() + + check_future = asyncio.run_coroutine_threadsafe(_check(), sdk_loop) + assert check_future.result(timeout=2) is True + finally: + sdk_loop.call_soon_threadsafe(sdk_loop.stop) + sdk_thread.join(timeout=2)