Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions lark_oapi/ws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ def __init__(self,
self._ping_interval: int = 120
self._cache: ExpiringCache = ExpiringCache(clear_interval=30)
self._lock = asyncio.Lock()
# Background task handles so ``stop()`` can cancel them on graceful
# shutdown. Without these references, ``loop.close()`` after stop
# would report "Task was destroyed but it is pending!" for the
# ping/receive loops and they'd leak across reconnect cycles.
self._ping_task: Optional[asyncio.Task] = None
self._receive_message_task: Optional[asyncio.Task] = None
self._main_task: Optional[asyncio.Task] = None
# Observer hooks for higher-level wrappers (e.g. FeishuChannel) to
# react to reconnect lifecycle. ``on_reconnecting`` fires when the
# client decides a connection was lost and starts retrying;
Expand All @@ -170,8 +177,18 @@ def start(self) -> None:
else:
raise e

loop.create_task(self._ping_loop())
loop.run_until_complete(_select())
self._ping_task = loop.create_task(self._ping_loop())
# ``_main_task`` blocks the foreground until ``stop()`` cancels it.
# We keep the reference so ``stop()`` can release ``start()``.
# Backward compat: still uses the module-level ``_select`` coroutine
# for callers/tests that monkeypatch it.
self._main_task = loop.create_task(_select())
try:
loop.run_until_complete(self._main_task)
except asyncio.CancelledError:
# Graceful shutdown via ``stop()`` — return cleanly so the
# caller's ``start()`` thread can exit.
pass

async def _ping_loop(self):
while True:
Expand Down Expand Up @@ -203,7 +220,9 @@ async def _connect(self) -> None:
self._service_id = service_id

logger.info(self._fmt_log("connected to {}", conn_url))
loop.create_task(self._receive_message_loop())
# Save handle so ``stop()`` can cancel this background task and
# avoid "Task was destroyed but it is pending!" on loop close.
self._receive_message_task = loop.create_task(self._receive_message_loop())
except InvalidHandshake as e:
_parse_ws_conn_exception(e)
finally:
Expand Down Expand Up @@ -391,6 +410,45 @@ async def _disconnect(self):
self._service_id = ""
self._lock.release()

async def stop(self) -> None:
    """Gracefully stop the WebSocket client.

    Disables auto-reconnect, cancels the internal background tasks (ping
    loop, receive loop), closes the WebSocket connection, and finally
    releases the blocking ``start()`` call so its thread can exit cleanly.

    Designed for cross-thread use: when ``start()`` is running on a
    dedicated worker thread (the common pattern, since it blocks), call
    ``stop()`` from another thread via::

        future = asyncio.run_coroutine_threadsafe(client.stop(), client_loop)
        future.result(timeout=5)

    where ``client_loop`` is the event loop ``start()`` is running on
    (typically the loop created in the worker thread).

    Calling ``stop()`` again after the tasks are done is safe: tasks that
    are ``None`` or already finished are skipped.

    Raises:
        Whatever ``_disconnect()`` raises. The main task is still
        cancelled in that case, so ``start()`` is released even when
        closing the connection fails.
    """
    # Prevent the receive loop's exception handler from auto-reconnecting
    # after the WebSocket is closed below.
    self._auto_reconnect = False

    # Cancel background tasks first so they don't observe the disconnect
    # as a network failure and try to reconnect.
    for task in (self._ping_task, self._receive_message_task):
        if task is not None and not task.done():
            task.cancel()

    try:
        # Close the underlying WebSocket.
        await self._disconnect()
    finally:
        # Release the foreground ``start()`` blocker. This runs in a
        # ``finally`` so a failing ``_disconnect()`` cannot leave
        # ``loop.run_until_complete(self._main_task)`` parked forever.
        main_task = self._main_task
        if main_task is not None and not main_task.done():
            main_task.cancel()

async def _write_message(self, data: bytes):
async with self._lock:
if self._conn is None:
Expand Down
167 changes: 167 additions & 0 deletions lark_oapi/ws/tests/test_stop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""Tests for ``Client.stop()`` public shutdown method.

These tests use a mock WebSocket connection and verify that ``stop()``:
1. Sets ``_auto_reconnect=False`` so the receive loop's exception handler
doesn't re-establish the connection after we close it.
2. Cancels the ping loop, receive loop, and main ``_select`` tasks so
``loop.close()`` doesn't report "Task was destroyed but it is pending!".
3. Closes the underlying WebSocket via ``_disconnect()``.
4. Releases the blocking ``start()`` call by cancelling the main task.
5. Is idempotent — safe to call when already stopped.
"""
import asyncio
import threading
from types import SimpleNamespace

import pytest

from lark_oapi.ws import client as ws_client


class _FakeConn:
"""Minimal WebSocketClientProtocol stub."""

def __init__(self):
self.closed = False
self._closed_event = asyncio.Event()

async def close(self):
self.closed = True
self._closed_event.set()

async def recv(self):
# Block forever until close() is called; mimics live WS behavior
# where recv() raises ConnectionClosed after the conn shuts down.
await self._closed_event.wait()
raise ws_client.ConnectionClosedException("closed")


@pytest.mark.asyncio
async def test_stop_disables_auto_reconnect():
    """``stop()`` switches ``_auto_reconnect`` from True to False."""
    ws = ws_client.Client("app_id", "app_secret", auto_reconnect=True)

    # Precondition: the constructor argument was honoured.
    assert ws._auto_reconnect is True

    await ws.stop()

    # Postcondition: the flag is off after shutdown.
    assert ws._auto_reconnect is False


@pytest.mark.asyncio
async def test_stop_closes_websocket(monkeypatch):
    """``stop()`` closes the attached connection and clears ``_conn``."""
    ws = ws_client.Client("app_id", "app_secret")
    # Attach a stub connection directly instead of performing a handshake.
    stub = _FakeConn()
    ws._conn = stub

    await ws.stop()

    # The stub saw ``close()`` and the client dropped its reference.
    assert stub.closed is True
    assert ws._conn is None


@pytest.mark.asyncio
async def test_stop_cancels_background_tasks():
    """``stop()`` leaves the ping, receive and main tasks finished."""
    task_slots = ("_ping_task", "_receive_message_task", "_main_task")
    ws = ws_client.Client("app_id", "app_secret")

    async def _park():
        # Never completes on its own; only cancellation can end it.
        while True:
            await asyncio.sleep(3600)

    # Populate the task handles as if ``start()`` had already run.
    for slot in task_slots:
        setattr(ws, slot, asyncio.create_task(_park()))

    await ws.stop()
    # Yield once so the event loop can deliver the cancellations.
    await asyncio.sleep(0)

    for slot in task_slots:
        task = getattr(ws, slot)
        assert task.cancelled() or task.done()


@pytest.mark.asyncio
async def test_stop_is_idempotent():
    """Two consecutive ``stop()`` calls on a fresh client must not raise.

    Cleanup paths (signal handlers, ``finally`` blocks, etc.) may end up
    invoking ``stop()`` more than once.
    """
    ws = ws_client.Client("app_id", "app_secret")

    # The client has no connection and no tasks, as if never started.
    # The second iteration is the one that proves idempotence.
    for _ in range(2):
        await ws.stop()

    assert ws._auto_reconnect is False


@pytest.mark.asyncio
async def test_stop_handles_already_done_tasks():
    """``stop()`` tolerates task handles that have already completed."""
    ws = ws_client.Client("app_id", "app_secret")

    async def _finish_immediately():
        return

    ws._ping_task = asyncio.create_task(_finish_immediately())
    ws._main_task = asyncio.create_task(_finish_immediately())

    # Two loop turns give both tasks time to run to completion.
    for _ in range(2):
        await asyncio.sleep(0)

    for finished in (ws._ping_task, ws._main_task):
        assert finished.done()

    # Must not raise even though there is nothing left to cancel.
    await ws.stop()


def test_stop_from_another_thread():
    """``stop()`` works when scheduled onto the SDK loop from another thread.

    Real-world use case: ``start()`` runs on a worker thread (it blocks),
    so ``stop()`` is invoked via ``asyncio.run_coroutine_threadsafe`` from
    a different thread. This test runs an event loop on a helper thread,
    plants a never-ending ``_main_task`` on it, calls ``stop()`` from the
    test thread, and checks that the planted task ended up finished.
    """
    # Loop on which the SDK "runs" and onto which stop() is scheduled.
    sdk_loop = asyncio.new_event_loop()
    ready = threading.Event()

    def _run_sdk_loop():
        asyncio.set_event_loop(sdk_loop)
        ready.set()
        sdk_loop.run_forever()

    sdk_thread = threading.Thread(target=_run_sdk_loop, daemon=True)
    sdk_thread.start()

    try:
        # Fail fast with a clear message instead of a later, unrelated
        # future timeout if the helper thread never came up.
        assert ready.wait(timeout=2), "SDK loop thread did not start"

        client = ws_client.Client("app_id", "app_secret")

        async def _build_main_task():
            # Simulates start() being parked on its blocking main task.
            async def _block():
                while True:
                    await asyncio.sleep(3600)

            client._main_task = asyncio.create_task(_block())
            return client._main_task

        main_task_future = asyncio.run_coroutine_threadsafe(
            _build_main_task(), sdk_loop
        )
        main_task_future.result(timeout=2)

        # Call stop() from the test thread, scheduled onto sdk_loop.
        stop_future = asyncio.run_coroutine_threadsafe(client.stop(), sdk_loop)
        stop_future.result(timeout=2)

        # Inspect the task on its own loop; it should be finished now.
        async def _check():
            return client._main_task.cancelled() or client._main_task.done()

        check_future = asyncio.run_coroutine_threadsafe(_check(), sdk_loop)
        assert check_future.result(timeout=2) is True
    finally:
        sdk_loop.call_soon_threadsafe(sdk_loop.stop)
        sdk_thread.join(timeout=2)
        # Close the loop to release its resources. A loop that is still
        # running cannot be closed, so only do it once the thread exited.
        if not sdk_thread.is_alive():
            sdk_loop.close()