diff --git a/livekit-rtc/livekit/rtc/__init__.py b/livekit-rtc/livekit/rtc/__init__.py index cc9c1b1e..35aa5241 100644 --- a/livekit-rtc/livekit/rtc/__init__.py +++ b/livekit-rtc/livekit/rtc/__init__.py @@ -39,6 +39,7 @@ ) from ._proto.video_frame_pb2 import VideoBufferType, VideoCodec, VideoRotation from .audio_frame import AudioFrame +from .audio_ring_buffer import AudioRingBuffer from .audio_source import AudioSource from .audio_stream import AudioFrameEvent, AudioStream, NoiseCancellationOptions from .audio_filter import AudioFilter @@ -137,6 +138,7 @@ "VideoRotation", "stats", "AudioFrame", + "AudioRingBuffer", "AudioSource", "AudioStream", "NoiseCancellationOptions", diff --git a/livekit-rtc/livekit/rtc/audio_ring_buffer.py b/livekit-rtc/livekit/rtc/audio_ring_buffer.py new file mode 100644 index 00000000..a2532208 --- /dev/null +++ b/livekit-rtc/livekit/rtc/audio_ring_buffer.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import threading + +from .audio_frame import AudioFrame + + +class AudioRingBuffer: + """Pre-allocated circular buffer for raw PCM audio data. + + Stores int16 PCM samples in a fixed-size bytearray. Push is zero-allocation. + """ + + def __init__(self, max_duration: float, sample_rate: int, num_channels: int) -> None: + self._sample_rate = sample_rate + self._num_channels = num_channels + self._bytes_per_second = sample_rate * num_channels * 2 # int16 + self._max_bytes = int(max_duration * self._bytes_per_second) + if self._max_bytes <= 0: + raise ValueError("max_duration must be positive") + + self._buf = bytearray(self._max_bytes) + self._write_pos = 0 + self._size = 0 + self._lock = threading.Lock() + + @property + def duration(self) -> float: + with self._lock: + return self._size / self._bytes_per_second + + @property + def max_duration(self) -> float: + return self._max_bytes / self._bytes_per_second + + def push(self, frame: AudioFrame) -> None: + data = frame.data.cast("b") + n = len(data) + if n == 0: + return + + with self._lock: + if n >= self._max_bytes: + # frame larger than buffer — keep only the tail + self._buf[:] = data[n - self._max_bytes :] + self._write_pos = 0 + self._size = self._max_bytes + return + + end = self._write_pos + n + if end <= self._max_bytes: + self._buf[self._write_pos : end] = data + else: + first = self._max_bytes - self._write_pos + self._buf[self._write_pos : self._max_bytes] = data[:first] + self._buf[: n - first] = data[first:] + + self._write_pos = end % self._max_bytes + self._size = min(self._size + n, self._max_bytes) + + def capture(self) -> bytes: + """Snapshot the buffer contents and reset. Returns raw PCM bytes.""" + with self._lock: + if self._size == 0: + return b"" + + read_pos = (self._write_pos - self._size) % self._max_bytes + if read_pos + self._size <= self._max_bytes: + data = bytes(self._buf[read_pos : read_pos + self._size]) + else: + first = self._max_bytes - read_pos + data = bytes(self._buf[read_pos:]) + bytes(self._buf[: self._size - first]) + + self._write_pos = 0 + self._size = 0 + return data + + def clear(self) -> None: + with self._lock: + self._write_pos = 0 + self._size = 0 diff --git a/livekit-rtc/livekit/rtc/audio_source.py b/livekit-rtc/livekit/rtc/audio_source.py index 63cc1a5d..c592ed77 100644 --- a/livekit-rtc/livekit/rtc/audio_source.py +++ b/livekit-rtc/livekit/rtc/audio_source.py @@ -16,12 +16,16 @@ import time import asyncio +from typing import TYPE_CHECKING from ._ffi_client import FfiHandle, FfiClient from ._proto import audio_frame_pb2 as proto_audio_frame from ._proto import ffi_pb2 as proto_ffi from .audio_frame import AudioFrame +if TYPE_CHECKING: + from .audio_ring_buffer import AudioRingBuffer + class AudioSource: """ @@ -69,6 +73,7 @@ def __init__( self._q_size = 0.0 self._join_handle: asyncio.TimerHandle | None = None self._join_fut: asyncio.Future[None] | None = None + self._preconnect_buffer: AudioRingBuffer | None = None @property def sample_rate(self) -> int: @@ -119,6 +124,9 @@ async def capture_frame(self, frame: AudioFrame) -> None: if frame.samples_per_channel == 0 or self._ffi_handle.disposed: return + if self._preconnect_buffer is not None: + self._preconnect_buffer.push(frame) + now = time.monotonic() elapsed = 0.0 if self._last_capture == 0.0 else now - self._last_capture self._q_size += frame.samples_per_channel / self.sample_rate - elapsed @@ -162,6 +170,9 @@ async def wait_for_playout(self) -> None: await asyncio.shield(self._join_fut) + def _set_preconnect_buffer(self, buf: AudioRingBuffer | None) -> None: + self._preconnect_buffer = buf + def _release_waiter(self) -> None: if self._join_fut is None: return # could be None when clear_queue is called diff --git a/livekit-rtc/livekit/rtc/participant.py b/livekit-rtc/livekit/rtc/participant.py index 150477d5..54d7163c 100644 --- a/livekit-rtc/livekit/rtc/participant.py +++ b/livekit-rtc/livekit/rtc/participant.py @@ -20,7 +20,7 @@ import os import mimetypes import aiofiles -from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast, TypeVar +from typing import TYPE_CHECKING, List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast, TypeVar from abc import abstractmethod, ABC from ._ffi_client import FfiClient, FfiHandle @@ -36,7 +36,7 @@ ParticipantTrackPermission, ) from ._utils import BroadcastQueue -from .track import LocalTrack +from .track import LocalAudioTrack, LocalTrack from .track_publication import ( LocalTrackPublication, RemoteTrackPublication, @@ -57,6 +57,9 @@ from .data_track import LocalDataTrack from ._proto import data_track_pb2 as proto_data_track +if TYPE_CHECKING: + from .room import Room + class PublishTrackError(Exception): def __init__(self, message: str) -> None: @@ -189,6 +192,7 @@ def __init__( self._room_queue = room_queue self._track_publications: dict[str, LocalTrackPublication] = {} # type: ignore self._rpc_handlers: Dict[str, RpcHandler] = {} + self._room: Room | None = None @property def track_publications(self) -> Mapping[str, LocalTrackPublication]: @@ -728,7 +732,11 @@ async def publish_data_track( return LocalDataTrack(cb.publish_data_track.track) async def publish_track( - self, track: LocalTrack, options: TrackPublishOptions = TrackPublishOptions() + self, + track: LocalTrack, + options: TrackPublishOptions = TrackPublishOptions(), + *, + preconnect_buffer_auto_send_to: str | None = None, ) -> LocalTrackPublication: """ Publish a local track to the room. @@ -736,6 +744,8 @@ async def publish_track( Args: track (LocalTrack): The track to publish. options (TrackPublishOptions, optional): Options for publishing the track. + preconnect_buffer_auto_send_to (str, optional): If set, automatically sends the + preconnect buffer when a participant with this identity becomes active. Returns: LocalTrackPublication: The publication of the published track. @@ -763,11 +773,48 @@ async def publish_track( track._info.sid = track_publication.sid self._track_publications[track_publication.sid] = track_publication + if isinstance(track, LocalAudioTrack): + track._participant = self + track._publication_sid = track_publication.sid + + if preconnect_buffer_auto_send_to: + if track.has_preconnect_buffer: + self._setup_preconnect_auto_send( + track, preconnect_buffer_auto_send_to + ) + else: + logger.warning( + "preconnect_buffer_auto_send_to set but no preconnect buffer " + "is active — call track.start_preconnect_buffer() first" + ) + queue.task_done() return track_publication finally: self._room_queue.unsubscribe(queue) + def _setup_preconnect_auto_send( + self, track: LocalAudioTrack, target_identity: str + ) -> None: + room = self._room + if room is None: + return + + async def _on_participant_active(participant: RemoteParticipant) -> None: + if participant.identity != target_identity: + return + if not track.has_preconnect_buffer: + return + room.off("participant_active", _on_participant_active) + try: + await track.send_preconnect_buffer( + destination_identity=participant.identity + ) + except Exception: + logger.exception("failed to auto-send preconnect buffer") + + room.on("participant_active", _on_participant_active) + async def unpublish_track(self, track_sid: str) -> None: """ Unpublish a track from the room. diff --git a/livekit-rtc/livekit/rtc/room.py b/livekit-rtc/livekit/rtc/room.py index 7fd778ae..f1ff3ae8 100644 --- a/livekit-rtc/livekit/rtc/room.py +++ b/livekit-rtc/livekit/rtc/room.py @@ -525,6 +525,7 @@ def on_participant_connected(participant): self._local_participant = LocalParticipant( self._room_queue, cb.connect.result.local_participant ) + self._local_participant._room = self for pt in cb.connect.result.participants: rp = self._create_remote_participant(pt.participant) diff --git a/livekit-rtc/livekit/rtc/track.py b/livekit-rtc/livekit/rtc/track.py index 8a6fe692..11b76f26 100644 --- a/livekit-rtc/livekit/rtc/track.py +++ b/livekit-rtc/livekit/rtc/track.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import asyncio from typing import TYPE_CHECKING, List, Union + from ._ffi_client import FfiHandle, FfiClient from ._proto import ffi_pb2 as proto_ffi from ._proto import track_pb2 as proto_track from ._proto import stats_pb2 as proto_stats if TYPE_CHECKING: + from .audio_ring_buffer import AudioRingBuffer from .audio_source import AudioSource + from .participant import LocalParticipant from .video_source import VideoSource +PRE_CONNECT_AUDIO_BUFFER_TOPIC = "lk.agent.pre-connect-audio-buffer" + class Track: def __init__(self, owned_info: proto_track.OwnedTrack): @@ -68,26 +76,80 @@ async def get_stats(self) -> List[proto_stats.RtcStats]: class LocalAudioTrack(Track): - def __init__(self, info: proto_track.OwnedTrack): + def __init__(self, info: proto_track.OwnedTrack, source: AudioSource | None = None): super().__init__(info) + self._source = source + self._preconnect_buffer: AudioRingBuffer | None = None + self._participant: LocalParticipant | None = None + self._publication_sid: str | None = None + self._send_lock = asyncio.Lock() @staticmethod - def create_audio_track(name: str, source: "AudioSource") -> "LocalAudioTrack": + def create_audio_track(name: str, source: AudioSource) -> LocalAudioTrack: req = proto_ffi.FfiRequest() req.create_audio_track.name = name req.create_audio_track.source_handle = source._ffi_handle.handle resp = FfiClient.instance.request(req) - return LocalAudioTrack(resp.create_audio_track.track) + return LocalAudioTrack(resp.create_audio_track.track, source=source) - def mute(self): + @property + def has_preconnect_buffer(self) -> bool: + return self._preconnect_buffer is not None + + def start_preconnect_buffer(self, *, max_duration: float = 10.0) -> None: + if self._source is None: + raise RuntimeError("track has no audio source") + + from .audio_ring_buffer import AudioRingBuffer + + self._preconnect_buffer = AudioRingBuffer( + max_duration=max_duration, + sample_rate=self._source.sample_rate, + num_channels=self._source.num_channels, + ) + self._source._set_preconnect_buffer(self._preconnect_buffer) + + def stop_preconnect_buffer(self) -> None: + if self._source is not None: + self._source._set_preconnect_buffer(None) + self._preconnect_buffer = None + + async def send_preconnect_buffer(self, *, destination_identity: str) -> None: + if self._participant is None: + raise RuntimeError("track is not published") + if self._preconnect_buffer is None: + raise RuntimeError("preconnect buffer is not active") + + async with self._send_lock: + data = self._preconnect_buffer.capture() + if not data: + return + + assert self._source is not None + writer = await self._participant.stream_bytes( + "preconnect-buffer", + topic=PRE_CONNECT_AUDIO_BUFFER_TOPIC, + mime_type="application/octet-stream", + destination_identities=[destination_identity], + attributes={ + "trackId": self._publication_sid or self.sid, + "sampleRate": str(self._source.sample_rate), + "channels": str(self._source.num_channels), + }, + ) + + await writer.write(data) + await writer.aclose() + + def mute(self) -> None: req = proto_ffi.FfiRequest() req.local_track_mute.track_handle = self._ffi_handle.handle req.local_track_mute.mute = True FfiClient.instance.request(req) self._info.muted = True - def unmute(self): + def unmute(self) -> None: req = proto_ffi.FfiRequest() req.local_track_mute.track_handle = self._ffi_handle.handle req.local_track_mute.mute = False