diff --git a/examples/echo/README.md b/examples/echo/README.md new file mode 100644 index 0000000..b00a2ac --- /dev/null +++ b/examples/echo/README.md @@ -0,0 +1,40 @@ +# Echo Live Runner Demo + +This example demonstrates: + +* Runner registration +* Video input - taken from a local file +* Video output - echoed to output with blur applied +* Parameter updates - adjust the amount of blur + +Start go-livepeer: + +```sh +./livepeer -orchestrator -useLiveRunners -serviceAddr localhost:8935 -v 99 -orchSecret abcdef +``` + +Start the runner: + +```sh +uv run examples/echo/runner.py --orchestrator https://localhost:8935 --orchSecret abcdef +``` + +Run the client with a local sample input (`~/samples/bbb_720p.mp4`): + +```sh +uv run examples/echo/client.py --blur ~/samples/bbb_720p.mp4 +``` + +The resulting file is stored at echo-out.ts. To use a different file +or redirect to stdout for live playback: + +```sh +uv run client.py --blur --output - ~/samples/bbb_720p.mp4 | ffplay - +``` + +The client discovers the `livepeer-sample/echo` runner automatically. To use a +different orchestrator or discovery endpoint: + +```sh +uv run examples/echo/client.py --discovery http://localhost:8935/discovery --blur ~/samples/bbb_720p.mp4 +``` diff --git a/examples/echo/client.py b/examples/echo/client.py new file mode 100755 index 0000000..0b91738 --- /dev/null +++ b/examples/echo/client.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import sys +import time +from contextlib import nullcontext, suppress +from pathlib import Path + +import av + +from livepeer_gateway.errors import LivepeerGatewayError +from livepeer_gateway.live_runner import stop_runner_session +from livepeer_gateway.media_output import MediaOutput +from livepeer_gateway.media_publish import MediaPublish +from livepeer_gateway.http import post_json +from livepeer_gateway.selection import reserve_session + +DEFAULT_DISCOVERY = "http://localhost:8935/discovery" +ECHO_APP_ID = "livepeer-sample/echo" +DEFAULT_OUTPUT = "echo-out.ts" +BLUR_UPDATE_INTERVAL_S = 0.01 +MAX_BLUR_RADIUS = 100 + + +def _log(*args: object) -> None: + print(*args, file=sys.stderr) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the proxied echo Live Runner demo.") + parser.add_argument("input") + parser.add_argument("--discovery", default=DEFAULT_DISCOVERY) + parser.add_argument("--output", default=DEFAULT_OUTPUT) + parser.add_argument("--radius", type=int, default=75) + parser.add_argument("--max-frames", type=int, default=0, help="Stop after this many input video frames (0 = full file).") + parser.add_argument("--blur", action="store_true", help="Sweep blur radius while publishing the sample.") + return parser.parse_args() + + +def _channel_url(echo_response: dict[str, object], name: str) -> str: + url = echo_response.get(name) + if not isinstance(url, str) or not url: + raise LivepeerGatewayError(f"echo response missing {name!r} url") + return url + + +async def _publish_video( + input_path: Path, + publish_url: str, + *, + max_frames: int = 0, + app_url: str = "", + blur: bool = False, +) -> None: + input_ = av.open(str(input_path)) + try: + if not input_.streams.video: + raise LivepeerGatewayError(f"No video stream found in input file: {input_path}") + publisher = MediaPublish(publish_url) + prev_pts_time: float | None = None + prev_wall: float | None = None + next_update_pts_time: float | None = None + blur_radius = 0 + blur_direction = 1 + + try: + for index, frame in enumerate(input_.decode(video=0), start=1): + if max_frames > 0 and index > max_frames: + break + current_pts_time = None + if frame.pts is not None and frame.time_base is not None: + current_pts_time = float(frame.pts * frame.time_base) + if next_update_pts_time is None: + next_update_pts_time = current_pts_time + + while ( + blur + and app_url + and current_pts_time is not None + and next_update_pts_time is not None + and current_pts_time >= next_update_pts_time + ): + await post_json(f"{app_url.rstrip('/')}/update", {"mode": "blur", "radius": blur_radius}) + if blur_radius == MAX_BLUR_RADIUS: + blur_direction = -1 + elif blur_radius == 0: + blur_direction = 1 + blur_radius += blur_direction + next_update_pts_time += BLUR_UPDATE_INTERVAL_S + + if ( + prev_pts_time is not None + and prev_wall is not None + and current_pts_time is not None + ): + delta_s = current_pts_time - prev_pts_time + elapsed_s = time.monotonic() - prev_wall + sleep_s = max(0.0, delta_s - elapsed_s) + if sleep_s > 0: + await asyncio.sleep(sleep_s) + + if current_pts_time is not None: + prev_pts_time = current_pts_time + prev_wall = time.monotonic() + + await publisher.write_frame(frame) + finally: + await publisher.close() + finally: + input_.close() + + +async def main() -> None: + args = _parse_args() + input_path = Path(args.input).expanduser() + output_stdout = args.output.strip().lower() in {"-", "stdout"} + output_path = None if output_stdout else Path(args.output).expanduser() + if not input_path.exists(): + raise SystemExit(f"input file does not exist: {input_path}") + + session = None + + try: + session = await reserve_session(discovery_url=args.discovery, app=ECHO_APP_ID) + _log("runner_url:", session.runner.url if session.runner is not None else session.runner_url) + _log("session_id:", session.session_id) + _log("app_url:", session.app_url) + + echo = await post_json(f"{session.app_url.rstrip('/')}/echo", {"radius": args.radius}) + in_url = _channel_url(echo, "in") + out_url = _channel_url(echo, "out") + _log("in:", in_url) + _log("out:", out_url) + + with nullcontext(sys.stdout.buffer) if output_stdout else output_path.open("wb") as fh: + def _write_chunk(chunk: bytes) -> None: + fh.write(chunk) + if output_stdout: + fh.flush() + + async with MediaOutput(out_url, on_bytes=_write_chunk): + await _publish_video( + input_path, + in_url, + max_frames=max(0, args.max_frames), + app_url=session.app_url, + blur=args.blur, + ) + _log("publish complete; waiting for output to drain...") + fh.flush() + except LivepeerGatewayError as exc: + raise SystemExit(f"ERROR: {exc}") from exc + finally: + if session is not None: + with suppress(Exception): + await stop_runner_session(session) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/echo/runner.py b/examples/echo/runner.py new file mode 100755 index 0000000..0d0e98a --- /dev/null +++ b/examples/echo/runner.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + +import av +from aiohttp import web + +from livepeer_gateway.live_runner import create_trickle_channels, register_runner +from livepeer_gateway.media_decode import AudioDecodedMediaFrame, VideoDecodedMediaFrame +from livepeer_gateway.media_output import MediaOutput +from livepeer_gateway.media_publish import MediaPublish + +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8989 +MODES = frozenset({"echo", "gray", "invert", "blur"}) + +state: "EchoSession | None" = None + + +@dataclass +class ModeState: + mode: str = "echo" + radius: int = 7 + + +@dataclass +class EchoSession: + session_id: str + in_url: str + out_url: str + mode: ModeState + output: MediaOutput + publisher: MediaPublish + + def to_json(self) -> dict[str, Any]: + data = { + "session": self.session_id, + "in": self.in_url, + "out": self.out_url, + "mode": self.mode.mode, + } + if self.mode.mode == "blur": + data["radius"] = self.mode.radius + return data + + +async def _close_pipeline() -> None: + global state + if state is None: + return + current = state + state = None + with suppress(Exception): + await current.publisher.close() + with suppress(Exception): + await current.output.close() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Live Runner echo app demo.") + parser.add_argument("--orchestrator", default="http://localhost:8935") + parser.add_argument("--orchSecret", default="abcdef") + parser.add_argument("--runner-url", default=f"http://{DEFAULT_HOST}:{DEFAULT_PORT}") + return parser.parse_args() + + +def _session_id(request: web.Request) -> str: + session_id = request.headers.get("Livepeer-Session-Id", "").strip() + if not session_id: + raise web.HTTPBadRequest(text="missing Livepeer-Session-Id header") + return session_id + + +def _parse_mode(payload: dict[str, Any]) -> ModeState: + mode = str(payload.get("mode", "echo")).strip().lower() + if mode not in MODES: + raise web.HTTPBadRequest(text=f"mode must be one of {sorted(MODES)}") + radius = payload.get("radius", 7) + try: + radius_int = int(radius) + except (TypeError, ValueError) as exc: + raise web.HTTPBadRequest(text="radius must be an integer") from exc + return ModeState(mode=mode, radius=max(1, min(99, radius_int))) + + +def _odd_kernel(radius: int) -> int: + kernel = max(1, int(radius)) + if kernel % 2 == 0: + kernel += 1 + return min(kernel, 99) + + +def _transform_frame( + decoded: AudioDecodedMediaFrame | VideoDecodedMediaFrame, + mode: ModeState, +) -> av.VideoFrame | None: + if decoded.kind != "video": + return None + + frame = decoded.frame + if mode.mode == "echo": + return frame + + import cv2 + + img = frame.to_ndarray(format="bgr24") + if mode.mode == "gray": + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) + elif mode.mode == "invert": + img = 255 - img + elif mode.mode == "blur": + kernel = _odd_kernel(mode.radius) + img = cv2.GaussianBlur(img, (kernel, kernel), 0) + + out = av.VideoFrame.from_ndarray(img, format="bgr24") + out.pts = frame.pts + out.time_base = frame.time_base + return out + + +async def _handle_echo(request: web.Request) -> web.Response: + global state + session_id = _session_id(request) + + if state is not None: + if state.session_id != session_id: + raise web.HTTPConflict(text="echo runner already has an active session") + return web.json_response(state.to_json()) + + channels = await create_trickle_channels( + request, + [ + {"name": "in", "mime_type": "video/mp2t"}, + {"name": "out", "mime_type": "video/mp2t"}, + ], + ) + by_name = {channel["name"]: channel for channel in channels} + if "in" not in by_name or "out" not in by_name: + raise web.HTTPInternalServerError(text="orchestrator did not return in/out channels") + + # for production apps, handle errors + mode = _parse_mode(json.loads(await request.read())) + publisher = MediaPublish(by_name["out"]["url"]) + + async def _on_frame(decoded) -> None: + frame = _transform_frame(decoded, mode) + if frame is not None: + await publisher.write_frame(frame) + + output = MediaOutput(by_name["in"]["url"], on_frame=_on_frame) + + state = EchoSession( + session_id=session_id, + in_url=by_name["in"]["url"], + out_url=by_name["out"]["url"], + mode=mode, + output=output, + publisher=publisher, + ) + for task in output.callback_tasks(): + task.add_done_callback(lambda _task: asyncio.create_task(_close_pipeline())) + print(f"started echo session {session_id}") + return web.json_response(state.to_json()) + + +async def _handle_update(request: web.Request) -> web.Response: + session_id = _session_id(request) + if state is None: + raise web.HTTPNotFound(text="echo session not started") + if state.session_id != session_id: + raise web.HTTPConflict(text="echo runner has a different active session") + + # for production apps, handle errors + mode = _parse_mode(json.loads(await request.read())) + state.mode.mode = mode.mode + state.mode.radius = mode.radius + return web.json_response(state.to_json()) + + +async def _on_cleanup(app: web.Application) -> None: + await _close_pipeline() + + +async def _on_startup(app: web.Application) -> None: + args = _parse_args() + registration = await register_runner( + args.orchestrator, + secret=args.orchSecret, + runner_url=args.runner_url, + app="livepeer-sample/echo", + ) + print( + f"runner_id={registration.runner_id} orchestrator={registration.orchestrator_url}" + ) + + +def main() -> None: + app = web.Application() + app.router.add_post("/echo", _handle_echo) + app.router.add_post("/update", _handle_update) + app.on_startup.append(_on_startup) + app.on_cleanup.append(_on_cleanup) + web.run_app(app, host=DEFAULT_HOST, port=DEFAULT_PORT) + + +if __name__ == "__main__": + main() diff --git a/examples/get_orchestrator_info.py b/examples/get_orchestrator_info.py index a1430ad..1ba239a 100644 --- a/examples/get_orchestrator_info.py +++ b/examples/get_orchestrator_info.py @@ -10,7 +10,8 @@ get_per_capability_map, ) from livepeer_gateway import get_orch_info -from livepeer_gateway.orchestrator import LivepeerGatewayError, discover_orchestrators +from livepeer_gateway.discovery import discover_orchestrators +from livepeer_gateway.errors import LivepeerGatewayError from livepeer_gateway.token import parse_token def _parse_args() -> argparse.Namespace: diff --git a/examples/ping-pong/README.md b/examples/ping-pong/README.md new file mode 100644 index 0000000..e09ba38 --- /dev/null +++ b/examples/ping-pong/README.md @@ -0,0 +1,38 @@ +# Ping/Pong Websocket Runner Demo + +This example demonstrates runtime registration for a single-shot websocket +runner. The runner exposes a websocket endpoint: + +- `/ws` receives `{"ping": }` and responds with + `{"pong": , "delta_ms": }`. + +When the websocket closes, the single-shot workload is over and the runner +handler releases its per-connection state. + +Start go-livepeer: + +```sh +./livepeer -orchestrator -useLiveRunners -serviceAddr localhost:8935 -v 99 -orchSecret abcdef +``` + +Start the runner: + +```sh +uv run runner.py --orchestrator http://localhost:8935 --orchSecret abcdef +``` + +Run the client: + +```sh +uv run client.py +``` + +The client discovers the `livepeer-sample/ping-pong` runner, connects +to its proxied websocket URL, sends one ping every second, and prints both the +receiver-side delta and the client round-trip time. + +To send a fixed number of pings: + +```sh +uv run client.py --count 10 +``` diff --git a/examples/ping-pong/client.py b/examples/ping-pong/client.py new file mode 100644 index 0000000..a497b60 --- /dev/null +++ b/examples/ping-pong/client.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import time + +import aiohttp + +from livepeer_gateway.errors import LivepeerGatewayError +from livepeer_gateway.selection import runner_selector + +DEFAULT_DISCOVERY = "http://localhost:8935/discovery" +APP_ID = "livepeer-sample/ping-pong" + + +def _log(*args: object) -> None: + print(*args, file=sys.stderr) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the websocket ping/pong Live Runner demo.") + parser.add_argument("--discovery", default=DEFAULT_DISCOVERY) + parser.add_argument("--count", type=int, default=10, help="Stop after this many pings (0 = until closed).") + return parser.parse_args() + +async def _select_runner(discovery_url: str) -> str: + cursor = await runner_selector(discovery_url=discovery_url, app=APP_ID) + for candidate in cursor.candidates: + return candidate.url + raise LivepeerGatewayError(f"no websocket runner discovered for app {APP_ID!r}") + + +async def _run_client(url: str, *, count: int) -> None: + async with aiohttp.ClientSession() as session: + async with session.ws_connect(url) as ws: + _log("connected:", url) + sent = 0 + while count <= 0 or sent < count: + ping = time.time() + await ws.send_json({"ping": ping}) + sent += 1 + + msg = json.loads((await ws.receive()).data) + received_at = time.time() + receiver_delta_ms = float(msg.get("delta_ms", -1)) + round_trip_ms = (received_at - ping) * 1000.0 + print( + "ping-pong receiver_delta_ms={:.2f} round_trip_ms={:.2f}".format( + receiver_delta_ms, + round_trip_ms, + ) + ) + + elapsed = time.time() - ping + await asyncio.sleep(max(0.0, 1.0 - elapsed)) + + +async def main() -> None: + args = _parse_args() + app_url = await _select_runner(args.discovery) + await _run_client(app_url + "/ws", count=max(0, args.count)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/ping-pong/runner.py b/examples/ping-pong/runner.py new file mode 100644 index 0000000..365fdf2 --- /dev/null +++ b/examples/ping-pong/runner.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import time +from contextlib import suppress + +from aiohttp import web + +from livepeer_gateway.live_runner import LiveRunnerRegistration, register_runner + +DEFAULT_HOST = "127.0.0.1" +DEFAULT_PORT = 8991 +APP_ID = "livepeer-sample/ping-pong" + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Live Runner websocket ping/pong demo.") + parser.add_argument("--orchestrator", default="http://localhost:8935") + parser.add_argument("--orchSecret", default="abcdef") + parser.add_argument("--runner-url", default=f"http://{DEFAULT_HOST}:{DEFAULT_PORT}") + return parser.parse_args() + + +def _pong_response(payload: str, *, now: float | None = None) -> dict[str, float]: + try: + data = json.loads(payload) + except json.JSONDecodeError as exc: + raise ValueError("message must be JSON") from exc + if not isinstance(data, dict): + raise ValueError("message must be a JSON object") + + ping = data.get("ping") + if isinstance(ping, bool) or not isinstance(ping, (int, float)): + raise ValueError("message must include numeric ping") + + received_at = time.time() if now is None else now + return { + "pong": float(ping), + "delta_ms": max(0.0, (received_at - float(ping)) * 1000.0), + } + + +async def _handle_ws(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + print("websocket session opened") + + try: + async for msg in ws: + if msg.type != web.WSMsgType.TEXT: + continue + try: + response = _pong_response(msg.data) + except ValueError as exc: + await ws.send_json({"error": str(exc)}) + continue + await ws.send_json(response) + finally: + print("websocket session closed") + + return ws + + +async def _on_startup(app: web.Application) -> None: + args = _parse_args() + registration = await register_runner( + args.orchestrator, + secret=args.orchSecret, + runner_url=args.runner_url, + app=APP_ID, + mode="single-shot", + ) + app["registration"] = registration + print( + f"runner_id={registration.runner_id} orchestrator={registration.orchestrator_url}" + ) + + +async def _on_cleanup(app: web.Application) -> None: + registration = app.get("registration") + if isinstance(registration, LiveRunnerRegistration): + with suppress(Exception): + await registration.close() + + +def main() -> None: + app = web.Application() + app.router.add_get("/ws", _handle_ws) + app.on_startup.append(_on_startup) + app.on_cleanup.append(_on_cleanup) + web.run_app(app, host=DEFAULT_HOST, port=DEFAULT_PORT) + + +if __name__ == "__main__": + main() diff --git a/examples/text/README.md b/examples/text/README.md new file mode 100644 index 0000000..7720e9f --- /dev/null +++ b/examples/text/README.md @@ -0,0 +1,39 @@ +# Text Stream Demo + +Single-shot text streaming on Livpeeer with static configuration. + +This demo exposes a tiny aiohttp runner with two streaming endpoints: + +- `/text` streams a story as `text/plain`, one character at a time. +- `/sse` streams a story as Server-Sent Events, one line per event. + +The app also exposes `/healthz` for go-livepeer runner health checks. + +Start go-livepeer with a static runner config: + +```sh +# assumes `livepeer` is somewhere in your PATH +livepeer -config go-livepeer.conf +``` + +Start the runner app: + +```sh +uv run runner.py +``` + +Call the endpoints through go-livepeer: + +```sh +# Plain text stream. `-N` disables curl output buffering +curl -N http://localhost:8935/apps/story-runner/app/text + +# SSE stream. Each story line is emitted as one `data:` event +curl -N http://localhost:8935/apps/story-runner/app/sse +``` + +Verify the runner registration with go-livepeer: + +``` +curl http://localhost:8935/discovery | jq +``` diff --git a/examples/text/go-livepeer.conf b/examples/text/go-livepeer.conf new file mode 100644 index 0000000..88cc50b --- /dev/null +++ b/examples/text/go-livepeer.conf @@ -0,0 +1,10 @@ +# go-livepeer config files use the same keys as CLI flags in `key value` form. +orchestrator true +useLiveRunners true +httpAddr http://localhost:8935 +serviceAddr http://localhost:8935 +orchSecret abcdef +v 99 + +# Point this at the static runner config below instead of registering from Python. +liveRunnerConfig runners.json diff --git a/examples/text/runner.py b/examples/text/runner.py new file mode 100644 index 0000000..4c7d795 --- /dev/null +++ b/examples/text/runner.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import asyncio + +from aiohttp import web + + +async def _handle_sse(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse( + status=200, + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + await response.prepare(request) + + with open("story.txt", encoding="utf-8", errors="replace") as lines: + for line in lines: + await response.write(f"data: {line.rstrip('\n')}\n\n".encode("utf-8")) + await asyncio.sleep(0.5) + + await response.write_eof() + return response + + +async def _handle_text(request: web.Request) -> web.StreamResponse: + response = web.StreamResponse( + status=200, + headers={ + "Content-Type": "text/plain; charset=utf-8", + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + }, + ) + await response.prepare(request) + + with open("story.txt", encoding="utf-8", errors="replace") as story: + while char := story.read(1): + await response.write(char.encode("utf-8")) + await asyncio.sleep(0.02) + + await response.write_eof() + return response + + +async def _handle_health(_: web.Request) -> web.Response: + return web.json_response({"ok": True}) + + +def main() -> None: + app = web.Application() + app.router.add_get("/sse", _handle_sse) + app.router.add_get("/text", _handle_text) + app.router.add_get("/healthz", _handle_health) + web.run_app(app, host="127.0.0.1", port=8990) + + +if __name__ == "__main__": + main() diff --git a/examples/text/runners.json b/examples/text/runners.json new file mode 100644 index 0000000..8466e07 --- /dev/null +++ b/examples/text/runners.json @@ -0,0 +1,13 @@ +{ + "runners": [ + { + "label": "story-runner", + "app": "livepeer/read-story", + "runner_url": "http://127.0.0.1:8990", + "health_url": "/healthz", + "routing": "label", + "capacity": 10, + "mode": "single-shot" + } + ] +} diff --git a/examples/text/story.txt b/examples/text/story.txt new file mode 100644 index 0000000..09464d9 --- /dev/null +++ b/examples/text/story.txt @@ -0,0 +1,112 @@ +"The Open Window" +Saki (1914) + +‘My Aunt will be down presently, Mr Nuttel,’ said a +self-possessed young lady of fifteen. ‘In the meantime, +you must put up with me.’ + +Framton Nuttel tried to make pleasant conversation +while waiting for the Aunt. Privately, he doubted more +than ever whether these formal visits on total strangers +would help the nerve cure which he was supposed to be +undergoing in this rural retreat. + +‘I’ll just give you letters to all the people I know +there,’ his sister had said. ‘Otherwise you’ll bury +yourself and not speak to a soul and your nerves will be +worse than ever from moping.’ + +‘Do you know many people around here?’ asked the +niece. + +‘Hardly a soul. My sister gave me letters of +introduction to some people here.’ + +‘Then you know practically nothing about my Aunt?’ +continued the self-possessed young lady. + +‘Only her name and address,’ admitted the caller. + +‘Her great tragedy happened just three years ago,’ +said the child. + +‘Her tragedy?’ asked Framton. Somehow, in this restful +spot, tragedies seemed out of place. + +‘You may wonder why we keep that window open so late +in the year,’ said the niece, indicating a large French +window that opened on a lawn. ‘Out through that window, +three years ago to a day, her husband and her two young +brothers went off for their day’s shooting. In crossing +the moor, they were engulfed in a treacherous bog. +Their bodies were never recovered.’ + +Here the child’s voice faltered. ‘Poor Aunt always +thinks that they’ll come back someday, they and the +little brown spaniel that was lost with them, and walk +in the window. That is why it is kept open every +evening till dusk. She has often told me how they went +out, her husband with his white waterproof coat over +his arm. You know, sometimes on still evenings like +this I get a creepy feeling that they will all walk in +through that window —’ + +She broke off with a little shudder. It was a relief +to Framton when the aunt bustled into the room with a +whirl of apologies for keeping him waiting. + +‘I hope you don’t mind the open window,’ she said. +‘My husband and brothers will be home directly from +shooting and they always come in this way.’ + +She rattled on cheerfully about the prospects for duck +shooting in the winter. Framton made a desperate effort +to tum the talk to a less ghastly topic, conscious that +his hostess was giving him only a fragment of her +attention, and that her eyes were constantly straying +past him to the open window. It was certainly an +unfortunate coincidence that he should have paid his +visit on this tragic anniversary. + +‘The doctors ordered me a complete rest from mental +excitement and physical exercise,’ announced Framton, +who imagined that everyone — even a complete stranger — +was interested in his illness. + +‘Oh?’ said Mrs Sappleton, vaguely. Then she suddenly +brightened into attention — but not to what Framton was +saying. + +‘Here they are at last!’ she cried. ‘In time for tea, +and muddy up to the eyes.’ + +Framton shivered slightly and turned towards the niece +with a look intended to convey sympathetic +understanding. The child was staring through the open +window with dazed horror in her eyes. Framton swung +round and looked in the same direction. + +In the deepening twilight three figures were walking +noiselessly across the lawn, a tired brown spaniel +close at their heels. They all carried guns, and one +had a white coat over his shoulders. + +Framton grabbed his stick; the hall door and the gravel +drive were dimly noted stages in his headlong retreat. + +‘Here we are, my dear,’ said the bearer of the white +mackintosh. + +‘Who was that who bolted out as we came up?’ + +‘An extraordinary man, a Mr Nuttel,’ said Mrs +Sappleton, ‘who could only talk about his illness, and +dashed off without a word of apology when you arrived. +One would think he had seen a ghost.’ + +‘I expect it was the spaniel,’ said the niece calmly. +‘He told me he had a horror of dogs. He was once hunted +into a cemetery on the banks of the Ganges by a pack of +stray dogs and had to spend the night in a newly-dug +grave with the creatures snarling and foaming above +him. Enough to make anyone lose his nerve.’ diff --git a/pyproject.toml b/pyproject.toml index afa0751..51a7ec3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,6 @@ examples = [ [tool.hatch.build.targets.wheel] packages = ["src/livepeer_gateway", "src/net"] + +[tool.ruff.lint.per-file-ignores] +"src/livepeer_gateway/lp_rpc_pb2_grpc.py" = ["F401"] diff --git a/src/livepeer_gateway/__init__.py b/src/livepeer_gateway/__init__.py index 3a82a4d..3baf9f5 100644 --- a/src/livepeer_gateway/__init__.py +++ b/src/livepeer_gateway/__init__.py @@ -1,5 +1,5 @@ from .capabilities import CapabilityId, build_capabilities -from .channel_reader import ChannelReader, JSONLReader +from .channel_reader import ChannelEventCallback, ChannelReader, JSONLReader from .channel_writer import ChannelWriter, JSONLWriter from .control import Control, ControlConfig, ControlMode from .byoc import ( @@ -15,7 +15,7 @@ wait_for_training, list_capabilities, ) -from .errors import LivepeerGatewayError, NoOrchestratorAvailableError, PaymentError +from .errors import LivepeerHTTPError, LivepeerGatewayError, NoOrchestratorAvailableError, NoRunnerAvailableError, PaymentError from .events import Events from .media_publish import ( AudioOutputConfig, @@ -32,14 +32,41 @@ DemuxedMediaPacket, VideoDecodedMediaFrame, ) -from .media_output import MediaOutput, MediaOutputStats -from .errors import OrchestratorRejection +from .media_output import ( + MediaBytesCallback, + MediaFrameCallback, + MediaOutput, + MediaOutputStats, + MediaPacketCallback, +) +from .errors import OrchestratorRejection, RunnerRejection from .lv2v import LiveVideoToVideo, StartJobRequest, start_lv2v +from .live_runner import ( + LiveRunnerCallResult, + LiveRunnerGPU, + LiveRunnerInstance, + LiveRunnerPriceInfo, + LiveRunnerRegistration, + LiveRunnerSession, + LiveRunnerSessionCallback, + LiveRunnerSessionEvent, + call_runner, + create_trickle_channels, + register_runner, + remove_trickle_channels, + stop_runner_session, +) +from .discovery import discover_orchestrators, discover_runners from .orch_info import get_orch_info -from .orchestrator import discover_orchestrators -from .remote_signer import PaymentSession +from .remote_signer import LivePaymentSession, PaymentSession from .scope import start_scope -from .selection import SelectionCursor, orchestrator_selector +from .selection import ( + RunnerSelectionCursor, + SelectionCursor, + orchestrator_selector, + runner_selector, + reserve_session, +) from .token import parse_token from .trickle_publisher import ( TricklePublishError, @@ -55,15 +82,34 @@ "Control", "ControlConfig", "ControlMode", + "ByocJobRequest", + "ByocJobResponse", + "ByocTrainingRequest", + "ByocTrainingResponse", + "ByocTrainingStatus", "ChannelWriter", "CapabilityId", "build_capabilities", "discover_orchestrators", + "discover_runners", + "get_training_status", "get_orch_info", "LiveVideoToVideo", + "LiveRunnerCallResult", + "LiveRunnerGPU", + "LiveRunnerInstance", + "LiveRunnerPriceInfo", + "LiveRunnerRegistration", + "LiveRunnerSession", + "LiveRunnerSessionCallback", + "LiveRunnerSessionEvent", + "LivePaymentSession", "LivepeerGatewayError", + "LivepeerHTTPError", "NoOrchestratorAvailableError", + "NoRunnerAvailableError", "OrchestratorRejection", + "RunnerRejection", "PaymentError", "MediaPublish", "MediaPublishConfig", @@ -74,20 +120,35 @@ "AudioOutputConfig", "MediaOutput", "MediaOutputStats", + "MediaBytesCallback", + "MediaFrameCallback", + "MediaPacketCallback", "AudioDecodedMediaFrame", "DecodedMediaFrame", "DemuxedMediaPacket", + "ChannelEventCallback", "ChannelReader", "JSONLReader", "JSONLWriter", "Events", "PaymentSession", "parse_token", + "RunnerSelectionCursor", "SelectionCursor", "orchestrator_selector", + "runner_selector", + "reserve_session", "StartJobRequest", + "call_runner", + "create_trickle_channels", + "register_runner", + "remove_trickle_channels", + "refresh_training_payment", "start_lv2v", "start_scope", + "stop_runner_session", + "submit_byoc_job", + "submit_training_job", "TricklePublishError", "TricklePublisher", "TricklePublisherStats", @@ -98,4 +159,6 @@ "TrickleSubscriber", "TrickleSubscriberStats", "VideoDecodedMediaFrame", + "wait_for_training", + "list_capabilities", ] diff --git a/src/livepeer_gateway/async_cache.py b/src/livepeer_gateway/async_cache.py new file mode 100644 index 0000000..b0acac4 --- /dev/null +++ b/src/livepeer_gateway/async_cache.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from collections import OrderedDict +from functools import wraps +from typing import Any, Awaitable, Callable, TypeVar + +_T = TypeVar("_T") + + +def async_lru_cache( + maxsize: int, +) -> Callable[[Callable[..., Awaitable[_T]]], Callable[..., Awaitable[_T]]]: + def decorator(func: Callable[..., Awaitable[_T]]) -> Callable[..., Awaitable[_T]]: + cache: OrderedDict[tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]], _T] = OrderedDict() + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> _T: + key = (args, tuple(sorted(kwargs.items()))) + cached = cache.get(key) + if cached is not None: + cache.move_to_end(key) + return cached + + value = await func(*args, **kwargs) + cache[key] = value + cache.move_to_end(key) + if len(cache) > maxsize: + cache.popitem(last=False) + return value + + wrapper.cache_clear = cache.clear # type: ignore[attr-defined] + return wrapper + + return decorator diff --git a/src/livepeer_gateway/channel_reader.py b/src/livepeer_gateway/channel_reader.py index 7103076..cf528d5 100644 --- a/src/livepeer_gateway/channel_reader.py +++ b/src/livepeer_gateway/channel_reader.py @@ -1,15 +1,245 @@ from __future__ import annotations +import asyncio +import inspect import json -from typing import Any, AsyncIterator +import logging +from typing import Any, AsyncIterator, Awaitable, Callable, Optional from .errors import LivepeerGatewayError +from .segment_reader import SegmentReader from .trickle_subscriber import TrickleSubscriber +_LOG = logging.getLogger(__name__) -class ChannelReader: - def __init__(self, events_url: str) -> None: +ChannelEventCallback = Callable[[dict[str, Any]], None | Awaitable[None]] +""" +Callback invoked for each decoded channel event. + +Callbacks may be synchronous or asynchronous. Async callback results are awaited +before the next event is delivered. +""" + + +async def _maybe_await(value: object) -> None: + if inspect.isawaitable(value): + await value + + +class _ChannelReaderCallback: + def _init_callback( + self, + events_url: str, + *, + start_seq: int = -2, + max_retries: int = 5, + max_event_bytes: int = 1_048_576, + on_event: Optional[ChannelEventCallback] = None, + ) -> None: self.events_url = events_url + self.start_seq = start_seq + self.max_retries = max_retries + self.max_event_bytes = max_event_bytes + self.on_event = on_event + self._event_callback_task: Optional[asyncio.Task[None]] = None + self._callback_error: Optional[BaseException] = None + if self.on_event is not None: + self.start_callback() + + def __call__( + self, + *, + start_seq: int = -2, + max_retries: int = 5, + max_event_bytes: int = 1_048_576, + ) -> AsyncIterator[dict[str, Any]]: + raise NotImplementedError + + def start_callback( + self, + ) -> Optional[asyncio.Task[None]]: + """ + Start the configured event callback consumer. + + This is idempotent. If called without a running event loop, no task is + started and callers may retry later from async code. + + Callback consumption uses the start_seq, max_retries, and + max_event_bytes values supplied to the reader constructor. Those + constructor values do not affect explicit iterator calls via __call__. + """ + if self.on_event is None: + return None + if self._event_callback_task is not None and not self._event_callback_task.done(): + return self._event_callback_task + try: + loop = asyncio.get_running_loop() + except RuntimeError: + _LOG.warning( + "No running event loop; %s callback not started. " + "Call reader.start_callback() from async code or use async with the reader.", + type(self).__name__, + ) + return None + + task = loop.create_task( + self._run_event_callback_loop( + self.on_event, + start_seq=self.start_seq, + max_retries=self.max_retries, + max_event_bytes=self.max_event_bytes, + ), + name=f"{type(self).__name__}.on_event", + ) + self._callback_error = None + task.add_done_callback(self._record_callback_task_result) + self._event_callback_task = task + return task + + def callback_task(self) -> Optional[asyncio.Task[None]]: + """ + Return the active or completed callback task, if one has been created. + """ + return self._event_callback_task + + async def wait_callback(self, timeout: Optional[float] = None) -> object: + """ + Wait for the configured event callback consumer to finish. + + Raises the first callback error, matching close(). + """ + task = self.callback_task() + if task is None: + return None + try: + result = await asyncio.wait_for(task, timeout=timeout) + except asyncio.CancelledError: + raise + except BaseException as exc: + self._record_callback_error(exc) + raise + return result + + def _record_callback_error(self, error: BaseException) -> None: + if isinstance(error, asyncio.CancelledError): + return + if self._callback_error is None: + self._callback_error = error + + async def _run_event_callback_loop( + self, + callback: ChannelEventCallback, + *, + start_seq: int, + max_retries: int, + max_event_bytes: int, + ) -> None: + async for event in self( + start_seq=start_seq, + max_retries=max_retries, + max_event_bytes=max_event_bytes, + ): + await _maybe_await(callback(event)) + + def _record_callback_task_result(self, task: asyncio.Task[None]) -> None: + try: + exc = task.exception() + except asyncio.CancelledError: + return + if exc is None: + return + self._record_callback_error(exc) + _LOG.error( + "%s callback task failed", + type(self).__name__, + exc_info=(type(exc), exc, exc.__traceback__), + ) + + async def close( + self, + *, + wait_callback: bool = True, + timeout: Optional[float] = 10.0, + ) -> None: + """ + Stop callback consumption and surface callback errors. + + If wait_callback is true, close waits up to timeout for the callback + task to finish naturally before cancelling it. Any callback exception is + raised from close(), matching wait_callback(). + """ + task = self.callback_task() + if task is not None: + if wait_callback and (timeout is None or timeout > 0): + try: + await self.wait_callback(timeout=timeout) + except asyncio.TimeoutError: + pass + if not task.done(): + task.cancel() + (result,) = await asyncio.gather(task, return_exceptions=True) + if isinstance(result, BaseException): + self._record_callback_error(result) + if self._callback_error is not None: + raise self._callback_error + + async def __aenter__(self): + self.start_callback() + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self.close() + + +class ChannelReader(_ChannelReaderCallback): + """ + Read a trickle channel containing one JSON object per segment. + + Iterator usage is lazy and configured per call: + + async for event in ChannelReader(url)(start_seq=-2): + ... + + Callback usage is configured on the instance: + + reader = ChannelReader(url, start_seq=-2, on_event=handle_event) + reader.start_callback() + + The constructor's start_seq, max_retries, and max_event_bytes values apply + only to callback consumption. Explicit calls to reader(...) keep their own + arguments and defaults. + """ + + def __init__( + self, + events_url: str, + *, + start_seq: int = -2, + max_retries: int = 5, + max_event_bytes: int = 1_048_576, + on_event: Optional[ChannelEventCallback] = None, + ) -> None: + """ + Create a JSON channel reader. + + Args: + events_url: Trickle subscribe URL. + start_seq: Initial server sequence for callback consumption only. + max_retries: Retry count for callback consumption only. + max_event_bytes: Per-segment byte limit for callback consumption only. + on_event: Optional callback invoked for each decoded JSON object. + + If on_event is provided while an event loop is running, callback + consumption starts immediately. If no loop is running, call + start_callback() later from async code or use async with the reader. + """ + self._init_callback( + events_url, + start_seq=start_seq, + max_retries=max_retries, + max_event_bytes=max_event_bytes, + on_event=on_event, + ) def __call__( self, @@ -24,6 +254,9 @@ def __call__( Each yielded item is a decoded JSON object (dict). The underlying network subscription starts lazily on first iteration. + These arguments configure this iterator only. They do not change the + instance settings used by callback consumption. + max_event_bytes applies per segment (per JSON message), not across the entire stream. """ @@ -82,9 +315,55 @@ async def _iter() -> AsyncIterator[dict[str, Any]]: return _iter() -class JSONLReader: - def __init__(self, events_url: str) -> None: - self.events_url = events_url +class JSONLReader(_ChannelReaderCallback): + """ + Read a trickle channel containing newline-delimited JSON objects. + + Iterator usage is lazy and configured per call: + + async for event in JSONLReader(url)(start_seq=-2): + ... + + Callback usage is configured on the instance: + + reader = JSONLReader(url, start_seq=-2, on_event=handle_event) + reader.start_callback() + + The constructor's start_seq, max_retries, and max_event_bytes values apply + only to callback consumption. Explicit calls to reader(...) keep their own + arguments and defaults. + """ + + def __init__( + self, + events_url: str, + *, + start_seq: int = -2, + max_retries: int = 5, + max_event_bytes: int = 1_048_576, + on_event: Optional[ChannelEventCallback] = None, + ) -> None: + """ + Create a JSONL channel reader. + + Args: + events_url: Trickle subscribe URL. + start_seq: Initial server sequence for callback consumption only. + max_retries: Retry count for callback consumption only. + max_event_bytes: Per-segment byte limit for callback consumption only. + on_event: Optional callback invoked for each decoded JSON object. + + If on_event is provided while an event loop is running, callback + consumption starts immediately. If no loop is running, call + start_callback() later from async code or use async with the reader. + """ + self._init_callback( + events_url, + start_seq=start_seq, + max_retries=max_retries, + max_event_bytes=max_event_bytes, + on_event=on_event, + ) def __call__( self, @@ -99,6 +378,9 @@ def __call__( Events are yielded incrementally as newline-terminated lines arrive, without buffering the entire segment in memory first. max_event_bytes applies per segment, not across the entire stream. + + These arguments configure this iterator only. They do not change the + instance settings used by callback consumption. """ url = self.events_url @@ -173,4 +455,3 @@ async def _iter() -> AsyncIterator[dict[str, Any]]: ) from e return _iter() - diff --git a/src/livepeer_gateway/discovery.py b/src/livepeer_gateway/discovery.py new file mode 100644 index 0000000..43f1cad --- /dev/null +++ b/src/livepeer_gateway/discovery.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Optional, Sequence +from urllib.parse import parse_qsl, quote, urlencode, urlparse, urlunparse + +from . import lp_rpc_pb2 +from .capabilities import capabilities_to_query +from .errors import LivepeerGatewayError +from .remote_signer import RemoteSignerError +from .http import _http_origin, _parse_http_url, get_json, get_json_sync + +_LOG = logging.getLogger(__name__) + +FilterValue = str | Sequence[str] +_RUNNER_DISCOVERY_BATCH_SIZE = 5 + + +def _normalize_filter_values(value: Optional[FilterValue]) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + values = [value] + else: + values = list(value) + return [item.strip() for item in values if isinstance(item, str) and item.strip()] + + +def _append_query_values(url: str, values: Sequence[tuple[str, str]]) -> str: + if not values: + return url + + parsed = urlparse(url) + query_pairs = parse_qsl(parsed.query, keep_blank_values=True) + query_pairs.extend(values) + query = urlencode(query_pairs, doseq=True, quote_via=quote, safe="/") + return urlunparse(parsed._replace(query=query)) + + +def _append_caps(url: str, capabilities: Optional[lp_rpc_pb2.Capabilities]) -> str: + """ + Append repeated `caps` query parameters to a URL. + + Existing query params are preserved. Capability values keep `/` unescaped. + """ + if capabilities is None: + return url + return _append_query_values(url, [("caps", cap) for cap in capabilities_to_query(capabilities)]) + + +def _append_runner_filters( + url: str, + *, + app: Optional[FilterValue] = None, + gpu: Optional[FilterValue] = None, +) -> str: + values: list[tuple[str, str]] = [] + values.extend(("app", item) for item in _normalize_filter_values(app)) + values.extend(("gpu", item) for item in _normalize_filter_values(gpu)) + return _append_query_values(url, values) + + +def discover_orchestrators( + orchestrators: Optional[Sequence[str] | str] = None, + *, + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + discovery_url: Optional[str] = None, + discovery_headers: Optional[dict[str, str]] = None, + capabilities: Optional[lp_rpc_pb2.Capabilities] = None, +) -> list[str]: + """ + Discover orchestrators and return a list of addresses. + + This discovery can happen via the following parameters in priority order (highest first): + - orchestrators: list or comma-delimited string + (empty/whitespace-only input falls through) + - discovery_url: use this discovery endpoint + - signer_url: use signer-provided discovery service + """ + if orchestrators is not None: + if isinstance(orchestrators, str): + orch_list = [orch.strip() for orch in orchestrators.split(",")] + else: + try: + orch_list = list(orchestrators) + except TypeError as e: + raise LivepeerGatewayError( + "discover_orchestrators requires a list of orchestrator URLs or a comma-delimited string" + ) from e + orch_list = [orch.strip() for orch in orch_list if isinstance(orch, str) and orch.strip()] + if orch_list: + return orch_list + + if discovery_url: + discovery_endpoint = _parse_http_url(discovery_url).geturl() + request_headers = discovery_headers + elif signer_url: + discovery_endpoint = f"{_http_origin(signer_url)}/discover-orchestrators" + request_headers = signer_headers + else: + _LOG.debug("discover_orchestrators failed: no discovery inputs") + raise LivepeerGatewayError("discover_orchestrators requires discovery_url or signer_url") + + if capabilities is not None: + discovery_endpoint = _append_caps(discovery_endpoint, capabilities) + + try: + _LOG.debug("discover_orchestrators running discovery: %s", discovery_endpoint) + data = get_json_sync(discovery_endpoint, headers=request_headers) + except LivepeerGatewayError as e: + _LOG.debug("discover_orchestrators discovery failed: %s", e) + raise RemoteSignerError( + discovery_endpoint, + str(e), + cause=e.__cause__ or e, + ) from None + + if not isinstance(data, list): + _LOG.debug( + "discover_orchestrators discovery response not list: type=%s", + type(data).__name__, + ) + raise RemoteSignerError( + discovery_endpoint, + f"Discovery response must be a JSON list, got {type(data).__name__}", + cause=None, + ) from None + + _LOG.debug("discover_orchestrators discovery response: %s", data) + + orch_list = [] + for item in data: + if not isinstance(item, dict): + continue + address = item.get("address") + if isinstance(address, str) and address.strip(): + orch_list.append(address.strip()) + _LOG.debug("discover_orchestrators discovered %d orchestrators", len(orch_list)) + + return orch_list + + +async def discover_runners( + *, + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + discovery_url: Optional[str] = None, + discovery_headers: Optional[dict[str, str]] = None, + app: Optional[FilterValue] = None, + gpu: Optional[FilterValue] = None, +) -> list[dict[str, Any]]: + """ + Discover live runners and return discovery entries. + + Filters are composed as OR within each field and AND across fields. + For example, app=["a", "b"], gpu=["H100", "L40S"] matches + (app=a OR app=b) AND (gpu=H100 OR gpu=L40S). + """ + if discovery_url: + discovery_endpoint = _parse_http_url(discovery_url).geturl() + request_headers = discovery_headers + elif signer_url: + discovery_endpoint = f"{_http_origin(signer_url)}/discover-orchestrators" + request_headers = signer_headers + else: + _LOG.debug("discover_runners failed: no discovery inputs") + raise LivepeerGatewayError("discover_runners requires discovery_url or signer_url") + + app_filters = _normalize_filter_values(app) + gpu_filters = _normalize_filter_values(gpu) + discovery_endpoint = _append_runner_filters(discovery_endpoint, app=app_filters, gpu=gpu_filters) + + try: + _LOG.debug("discover_runners running discovery: %s", discovery_endpoint) + data = await get_json(discovery_endpoint, headers=request_headers) + except LivepeerGatewayError as e: + _LOG.debug("discover_runners discovery failed: %s", e) + raise RemoteSignerError( + discovery_endpoint, + str(e), + cause=e.__cause__ or e, + ) from None + + if not isinstance(data, list): + _LOG.debug( + "discover_runners discovery response not list: type=%s", + type(data).__name__, + ) + raise RemoteSignerError( + discovery_endpoint, + f"Discovery response must be a JSON list, got {type(data).__name__}", + cause=None, + ) from None + + entries = _filter_runner_discovery_entries(data, app_filters=app_filters, gpu_filters=gpu_filters) + _LOG.debug("discover_runners discovered %d orchestrator entries", len(entries)) + return entries + + +async def discover_orchestrator_runners( + orchestrators: Optional[Sequence[str] | str], + *, + app: Optional[FilterValue] = None, + gpu: Optional[FilterValue] = None, + batch_size: int = _RUNNER_DISCOVERY_BATCH_SIZE, +) -> list[dict[str, Any]]: + first_error: Exception | None = None + urls = orchestrator_discovery_urls(orchestrators) + for batch_start in range(0, len(urls), batch_size): + batch = urls[batch_start : batch_start + batch_size] + results = await asyncio.gather( + *(discover_runners(discovery_url=discovery_url, app=app, gpu=gpu) for discovery_url in batch), + return_exceptions=True, + ) + for discovery_url, result in zip(batch, results): + if isinstance(result, Exception): + if first_error is None: + first_error = result + _LOG.debug("discover_orchestrator_runners failed: %s (%s)", discovery_url, result) + continue + if result: + return result + + if first_error is not None: + raise first_error + return [] + + +def orchestrator_discovery_urls(orchestrators: Optional[Sequence[str] | str]) -> list[str]: + if orchestrators is None: + return [] + if isinstance(orchestrators, str): + candidates = [item.strip() for item in orchestrators.split(",")] + else: + try: + candidates = [item.strip() for item in orchestrators if isinstance(item, str)] + except TypeError as e: + raise LivepeerGatewayError( + "orchestrator_discovery_urls requires a list of orchestrator URLs or a comma-delimited string" + ) from e + + urls = [] + for candidate in candidates: + if not candidate: + continue + try: + parsed = _parse_http_url(candidate, context="orchestrator URL") + except ValueError as e: + raise LivepeerGatewayError(f"Invalid orchestrator URL: {candidate!r}") from e + base_path = parsed.path.rstrip("/") + discovery_path = f"{base_path}/discovery" if base_path else "/discovery" + urls.append(parsed._replace(path=discovery_path, query="", fragment="").geturl()) + return urls + + +def _filter_runner_discovery_entries( + data: Sequence[Any], + *, + app_filters: Sequence[str], + gpu_filters: Sequence[str], +) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + for item in data: + if not isinstance(item, dict): + continue + runners = item.get("runners") + if not isinstance(runners, list): + continue + + matched_runners = [] + for runner in runners: + if not isinstance(runner, dict): + continue + if not _valid_runner(runner): + continue + if not _runner_matches_filters(runner, app_filters=app_filters, gpu_filters=gpu_filters): + continue + matched_runners.append(runner) + + if matched_runners: + entry = dict(item) + entry["runners"] = matched_runners + entries.append(entry) + return entries + + +def _valid_runner(runner: dict[str, Any]) -> bool: + url = runner.get("url") + app = runner.get("app") + return isinstance(url, str) and bool(url.strip()) and isinstance(app, str) and bool(app.strip()) + + +def _runner_matches_filters( + runner: dict[str, Any], + *, + app_filters: Sequence[str], + gpu_filters: Sequence[str], +) -> bool: + app = runner.get("app") + app_value = app.strip() if isinstance(app, str) else "" + if app_filters and app_value not in app_filters: + return False + if gpu_filters and _runner_gpu_name(runner) not in gpu_filters: + return False + return True + + +def _runner_gpu_name(runner: dict[str, Any]) -> str: + gpu = runner.get("gpu") + if isinstance(gpu, dict): + name = gpu.get("name") + if isinstance(name, str): + return name.strip() + return "" diff --git a/src/livepeer_gateway/errors.py b/src/livepeer_gateway/errors.py index 14fb8b4..d967e8b 100644 --- a/src/livepeer_gateway/errors.py +++ b/src/livepeer_gateway/errors.py @@ -1,13 +1,22 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Optional +from dataclasses import dataclass class LivepeerGatewayError(RuntimeError): """Base error for the library.""" +class LivepeerHTTPError(LivepeerGatewayError): + """Raised when an HTTP endpoint returns a non-success status.""" + + def __init__(self, status_code: int, url: str, body: str = "", message: str | None = None) -> None: + self.status_code = int(status_code) + self.url = url + self.body = body + super().__init__(message or f"HTTP {status_code} from endpoint (url={url})") + + @dataclass class OrchestratorRejection: """Records a single orchestrator that was tried and rejected.""" @@ -15,6 +24,13 @@ class OrchestratorRejection: reason: str +@dataclass +class RunnerRejection: + """Records a single runner that was tried and rejected.""" + url: str + reason: str + + class NoOrchestratorAvailableError(LivepeerGatewayError): """Raised when no orchestrator could be selected.""" @@ -23,9 +39,26 @@ def __init__(self, message: str, rejections: list[OrchestratorRejection] | None self.rejections: list[OrchestratorRejection] = rejections or [] +class NoRunnerAvailableError(LivepeerGatewayError): + """Raised when no runner could be selected.""" + + def __init__(self, message: str, rejections: list[RunnerRejection] | None = None) -> None: + super().__init__(message) + self.rejections: list[RunnerRejection] = rejections or [] + + class SignerRefreshRequired(LivepeerGatewayError): """Raised when the remote signer returns HTTP 480 and a refresh is required.""" + def __init__( + self, + message: str, + *, + orchestrator_url: str | None = None, + ) -> None: + super().__init__(message) + self.orchestrator_url = orchestrator_url + class SkipPaymentCycle(LivepeerGatewayError): """Raised when the signer returns HTTP 482 to skip a payment cycle.""" diff --git a/src/livepeer_gateway/http.py b/src/livepeer_gateway/http.py new file mode 100644 index 0000000..0b3566d --- /dev/null +++ b/src/livepeer_gateway/http.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import asyncio +import json +import ssl +from typing import Any, Optional +from urllib.error import HTTPError, URLError +from urllib.parse import ParseResult, urlparse +from urllib.request import Request, urlopen + +import aiohttp + +from .errors import ( + LivepeerHTTPError, + LivepeerGatewayError, + SignerRefreshRequired, + SkipPaymentCycle, +) + +_REFRESH_SESSION_ORCHESTRATOR_URL_HEADER = "Livepeer-Orchestrator-URL" + + +def _truncate(s: str, max_len: int = 2000) -> str: + if len(s) <= max_len: + return s + return s[:max_len] + f"...(+{len(s) - max_len} chars)" + + +def _http_error_body(e: HTTPError) -> str: + """ + Best-effort read of an HTTPError response body for debugging. + """ + try: + b = e.read() + if not b: + return "" + if isinstance(b, bytes): + return b.decode("utf-8", errors="replace") + return str(b) + except Exception: + return "" + + +def _extract_error_message_from_body(body: str) -> str: + """ + Best-effort extraction of a useful error message from an HTTP error body. + + If the body is JSON and matches {"error": {"message": "..."}}, return that message. + Otherwise return the full body. + + Always truncates the returned value for readability. + """ + s = body.strip() + if not s: + return "" + + try: + data = json.loads(s) + except Exception: + return _truncate(body) + + if isinstance(data, dict): + err = data.get("error") + if isinstance(err, dict): + msg = err.get("message") + if isinstance(msg, str) and msg: + return _truncate(msg) + + return _truncate(body) + + +def _extract_error_message(e: HTTPError) -> str: + """ + Best-effort extraction of a useful error message from an HTTPError body. + """ + return _extract_error_message_from_body(_http_error_body(e)) + + +def _header_value(headers: dict[str, str], name: str) -> Optional[str]: + needle = name.lower() + for key, value in headers.items(): + if key.lower() == needle and isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def _json_request_parts( + url: str, + *, + method: Optional[str] = None, + payload: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, +) -> tuple[str, dict[str, str], Optional[bytes]]: + req_headers: dict[str, str] = { + "Accept": "application/json", + "User-Agent": "livepeer-python-gateway/0.1", + } + body: Optional[bytes] = None + if payload is not None: + req_headers["Content-Type"] = "application/json" + body = json.dumps(payload).encode("utf-8") + if headers: + req_headers.update(headers) + + resolved_method = method.upper() if method else ("POST" if payload is not None else "GET") + return resolved_method, req_headers, body + + +def _raise_http_json_error( + status: int, + url: str, + body: str = "", + headers: Optional[dict[str, str]] = None, +) -> None: + message = _extract_error_message_from_body(body) + body_part = f"; body={message!r}" if message else "" + if status == 480: + raise SignerRefreshRequired( + f"Signer returned HTTP 480 (refresh session required) (url={url}){body_part}", + orchestrator_url=_header_value(headers or {}, _REFRESH_SESSION_ORCHESTRATOR_URL_HEADER), + ) + if status == 482: + raise SkipPaymentCycle( + f"Signer returned HTTP 482 (skip payment cycle) (url={url}){body_part}" + ) + raise LivepeerHTTPError( + status, + url, + body, + f"HTTP {status} from endpoint (url={url}){body_part}", + ) + + +def _ensure_json_object(data: Any, *, url: str) -> dict[str, Any]: + if not isinstance(data, dict): + raise LivepeerGatewayError( + f"HTTP JSON error: expected JSON object, got {type(data).__name__} (url={url})" + ) + return data + + +def request_json_sync( + url: str, + *, + method: Optional[str] = None, + payload: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> Any: + """ + Make a JSON HTTP request and parse the JSON response. + + If method is None, defaults to POST when payload is provided, otherwise GET. + + Raises LivepeerGatewayError on HTTP/network/JSON parsing errors. + """ + resolved_method, req_headers, body = _json_request_parts( + url, + method=method, + payload=payload, + headers=headers, + ) + req = Request(url, data=body, headers=req_headers, method=resolved_method) + + # Always ignore HTTPS certificate validation (matches our gRPC behavior). + ssl_ctx = ssl._create_unverified_context() + + try: + with urlopen(req, timeout=timeout, context=ssl_ctx) as resp: + raw = resp.read().decode("utf-8") + data: Any = json.loads(raw) + except HTTPError as e: + raw_body = _http_error_body(e) + body_text = _extract_error_message_from_body(raw_body) + body_part = f"; body={body_text!r}" if body_text else "" + if e.code == 480: + raise SignerRefreshRequired( + f"Signer returned HTTP 480 (refresh session required) (url={url}){body_part}", + orchestrator_url=_header_value( + dict(e.headers.items()), + _REFRESH_SESSION_ORCHESTRATOR_URL_HEADER, + ), + ) from e + if e.code == 482: + raise SkipPaymentCycle( + f"Signer returned HTTP 482 (skip payment cycle) (url={url}){body_part}" + ) from e + raise LivepeerHTTPError( + e.code, + url, + raw_body, + f"HTTP {e.code} from endpoint (url={url}){body_part}", + ) from e + except ConnectionRefusedError as e: + raise LivepeerGatewayError( + f"HTTP JSON error: connection refused (is the server running? is the host/port correct?) (url={url})" + ) from e + except URLError as e: + raise LivepeerGatewayError( + f"HTTP JSON error: failed to reach endpoint: {getattr(e, 'reason', e)} (url={url})" + ) from e + except json.JSONDecodeError as e: + raise LivepeerGatewayError(f"HTTP JSON error: endpoint did not return valid JSON: {e} (url={url})") from e + except Exception as e: + raise LivepeerGatewayError( + f"HTTP JSON error: unexpected error: {e.__class__.__name__}: {e} (url={url})" + ) from e + + return data + + +def post_json_sync( + url: str, + payload: dict[str, Any], + *, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> dict[str, Any]: + """ + POST JSON to `url` and parse a JSON object response. + """ + data = request_json_sync( + url, + payload=payload, + headers=headers, + timeout=timeout, + ) + return _ensure_json_object(data, url=url) + + +def get_json_sync( + url: str, + *, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> Any: + """ + GET JSON from `url` and parse the response. + """ + return request_json_sync(url, headers=headers, timeout=timeout) + + +async def request_json( + url: str, + *, + method: Optional[str] = None, + payload: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> Any: + """ + Make an async JSON HTTP request and parse the JSON response. + + If method is None, defaults to POST when payload is provided, otherwise GET. + + Raises LivepeerGatewayError on HTTP/network/JSON parsing errors. + """ + resolved_method, req_headers, body = _json_request_parts( + url, + method=method, + payload=payload, + headers=headers, + ) + + try: + client_timeout = aiohttp.ClientTimeout(total=timeout) + connector = aiohttp.TCPConnector(ssl=False) + async with aiohttp.ClientSession(timeout=client_timeout, connector=connector) as session: + async with session.request(resolved_method, url, data=body, headers=req_headers) as resp: + raw = await resp.text() + if resp.status >= 400: + _raise_http_json_error(resp.status, url, raw, dict(resp.headers.items())) + data: Any = json.loads(raw) + except (SignerRefreshRequired, SkipPaymentCycle, LivepeerGatewayError): + raise + except json.JSONDecodeError as e: + raise LivepeerGatewayError(f"HTTP JSON error: endpoint did not return valid JSON: {e} (url={url})") from e + except ConnectionRefusedError as e: + raise LivepeerGatewayError( + f"HTTP JSON error: connection refused (is the server running? is the host/port correct?) (url={url})" + ) from e + except getattr(aiohttp, "ClientConnectorError", ()) as e: + os_error = getattr(e, "os_error", None) + if isinstance(os_error, ConnectionRefusedError): + raise LivepeerGatewayError( + f"HTTP JSON error: connection refused (is the server running? is the host/port correct?) (url={url})" + ) from e + raise LivepeerGatewayError( + f"HTTP JSON error: failed to reach endpoint: {getattr(e, 'message', e)} (url={url})" + ) from e + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise LivepeerGatewayError( + f"HTTP JSON error: failed to reach endpoint: {getattr(e, 'message', e)} (url={url})" + ) from e + except Exception as e: + raise LivepeerGatewayError( + f"HTTP JSON error: unexpected error: {e.__class__.__name__}: {e} (url={url})" + ) from e + + return data + + +async def post_json( + url: str, + payload: dict[str, Any], + *, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> dict[str, Any]: + """ + POST JSON to `url` and parse a JSON object response. + """ + data = await request_json( + url, + payload=payload, + headers=headers, + timeout=timeout, + ) + return _ensure_json_object(data, url=url) + + +async def get_json( + url: str, + *, + headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, +) -> Any: + """ + GET JSON from `url` and parse the response. + """ + return await request_json(url, headers=headers, timeout=timeout) + + +def _parse_http_url(url: str, *, context: str = "URL") -> ParseResult: + """ + Normalize a URL for HTTP(S) endpoints. + + Accepts: + - "host:port" (implicitly https://host:port) + - "http://host:port[/...]" + - "https://host:port[/...]" + """ + url = url.strip() + normalized = url if "://" in url else f"https://{url}" + parsed = urlparse(normalized) + if parsed.scheme not in ("http", "https"): + raise ValueError(f"Only http:// or https:// {context}s are supported (got {parsed.scheme!r})") + if not parsed.netloc: + raise ValueError(f"Invalid {context}: {url!r}") + return parsed + + +def _http_origin(url: str) -> str: + """ + Normalize a URL (possibly with a path) into a scheme:// origin (scheme + host:port). + + Accepts: + - "host:port" (implicitly https://host:port) + - "http://host:port[/...]" (path/query/fragment are ignored) + - "https://host:port[/...]" (path/query/fragment are ignored) + """ + parsed = _parse_http_url(url) + return f"{parsed.scheme}://{parsed.netloc}" diff --git a/src/livepeer_gateway/live_runner.py b/src/livepeer_gateway/live_runner.py new file mode 100644 index 0000000..fdd70b8 --- /dev/null +++ b/src/livepeer_gateway/live_runner.py @@ -0,0 +1,1025 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import logging +import os +import re +import shutil +import subprocess +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Literal, Optional, Protocol, TypedDict, cast +from urllib.parse import quote, urlparse, urlunparse + +import aiohttp + +from .channel_reader import ChannelReader +from .errors import LivepeerGatewayError, LivepeerHTTPError, SignerRefreshRequired +from .http import post_json, request_json +from .remote_signer import ( + GetPaymentResponse, + LivePaymentSession, + _freeze_headers, + get_signer_info, +) + +_LOG = logging.getLogger(__name__) + +_DEFAULT_HEARTBEAT_INTERVAL_S = 5.0 +_LIVE_RUNNER_PAYER_ADDRESS_HEADER = "Livepeer-Payer-Address" +_LIVE_RUNNER_MODES = frozenset({"persistent", "single-shot"}) + +# golang format duration, eg "10s" +_DURATION_RE = re.compile(r"^\s*(?P[0-9]+(?:\.[0-9]+)?)(?Pns|us|\u00b5s|ms|s|m|h)\s*$") + + +class LiveRunnerTrickleChannelRequest(TypedDict): + name: str + mime_type: str + + +class LiveRunnerTrickleChannel(TypedDict): + name: str + channel_name: str + url: str + mime_type: str + + +class LiveRunnerSessionHeaders(Protocol): + def get(self, key: str, default: str = "") -> str: ... + + +class LiveRunnerSessionRequest(Protocol): + headers: LiveRunnerSessionHeaders + + +@dataclass(frozen=True) +class LiveRunnerSessionEvent: + session_id: str + event: Literal["reserved", "released"] + timestamp: Optional[str] + raw: dict[str, Any] + + +LiveRunnerSessionCallback = Callable[[LiveRunnerSessionEvent], None | Awaitable[None]] + + +@dataclass(frozen=True) +class LiveRunnerInstance: + """A normalized live runner discovered from an orchestrator entry.""" + + url: str + app: str + runner_id: str + mode: str + orchestrator_url: str + raw: dict[str, Any] + + +@dataclass(frozen=True) +class LiveRunnerSession: + session_id: str + app_url: str + runner_url: str + runner: Optional[LiveRunnerInstance] = None + + +@dataclass(frozen=True) +class LiveRunnerCallResult: + data: dict[str, Any] + runner_url: str + runner: Optional[LiveRunnerInstance] = None + session_id: str = "" + payment_session: Optional[LivePaymentSession] = field( + default=None, + repr=False, + compare=False, + ) + + +@dataclass(frozen=True) +class LiveRunnerGPU: + id: str = "" + name: str = "" + vram_mb: int = 0 + + def to_json(self) -> dict[str, Any]: + data: dict[str, Any] = {} + if self.id: + data["id"] = self.id + if self.name: + data["name"] = self.name + if self.vram_mb > 0: + data["vram_mb"] = self.vram_mb + return data + + +@dataclass(frozen=True) +class LiveRunnerPriceInfo: + price_per_unit: int + pixels_per_unit: int + unit: str = "USD" + + def to_json(self) -> dict[str, Any]: + return { + "price_per_unit": self.price_per_unit, + "pixels_per_unit": self.pixels_per_unit, + "unit": self.unit, + } + + +class LiveRunnerRegistration: + def __init__( + self, + *, + orchestrator_url: str, + secret: str, + runner_url: str, + app: str, + price_info: LiveRunnerPriceInfo, + runner_id: str = "", + mode: str = "persistent", + label: str = "", + version: str = "", + status: str = "ready", + capacity: int = 1, + gpu: Optional[LiveRunnerGPU] = None, + timeout: float = 5.0, + heartbeat_interval_s: Optional[float] = None, + unregister_on_close: bool = True, + on_session_reserve: Optional[LiveRunnerSessionCallback] = None, + on_session_release: Optional[LiveRunnerSessionCallback] = None, + ) -> None: + self.orchestrator_url = _normalize_http_base(orchestrator_url) + self.runner_id = runner_id + self.heartbeat_interval_s = heartbeat_interval_s or _DEFAULT_HEARTBEAT_INTERVAL_S + self.heartbeat_ttl_s: Optional[float] = None + + self._bootstrap_secret = secret + self._heartbeat_secret: Optional[str] = None + self._runner_url = runner_url + self._app = app + self._mode = _normalize_runner_mode(mode) + self._price_info = price_info + self._label = label + self._version = version + self._status = status + self._capacity = capacity + self._gpu = gpu + self._timeout = timeout + self._heartbeat_interval_override = heartbeat_interval_s + self._unregister_on_close = unregister_on_close + self._on_session_reserve = on_session_reserve + self._on_session_release = on_session_release + self._active_session_ids: list[str] = [] + self.o2r_channel: Optional[LiveRunnerTrickleChannel] = None + self._o2r_reader: Optional[ChannelReader] = None + self._closed = False + self._task: Optional[asyncio.Task[None]] = None + self._o2r_task: Optional[asyncio.Task[None]] = None + + async def start(self) -> "LiveRunnerRegistration": + await self._send_heartbeat() + self._task = asyncio.create_task(self._heartbeat_loop()) + return self + + @property + def active_session_ids(self) -> tuple[str, ...]: + # Return an immutable snapshot; internal storage stays list-backed to preserve reservation order. + return tuple(self._active_session_ids) + + async def close(self) -> None: + self._closed = True + o2r_reader = self._o2r_reader + self._o2r_reader = None + self._o2r_task = None + if o2r_reader is not None: + try: + await o2r_reader.close(wait_callback=False) + except asyncio.CancelledError: + raise + except Exception: + _LOG.exception("Live runner O2R reader failed during shutdown") + + task = self._task + self._task = None + if task is not None and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + _LOG.exception("Live runner heartbeat task failed during shutdown") + + if self._unregister_on_close and self.runner_id: + secret = self._heartbeat_secret + if not secret: + _LOG.warning("Skipping live runner unregister without heartbeat secret") + return + try: + await _post_empty( + _join_endpoint(self.orchestrator_url, f"/runners/{quote(self.runner_id, safe='')}/unregister"), + {"Authorization": secret}, + self._timeout, + ) + except Exception: + _LOG.debug("Live runner unregister failed", exc_info=True) + + async def __aenter__(self) -> "LiveRunnerRegistration": + return self + + async def __aexit__(self, exc_type: object, exc: object, tb: object) -> None: + await self.close() + + async def create_trickle_channels( + self, + session: str | LiveRunnerSessionRequest, + channels: list[LiveRunnerTrickleChannelRequest], + *, + session_token: str = "", + ) -> list[LiveRunnerTrickleChannel]: + """Create channels for a live runner app session. + + This is intended for apps running behind the orchestrator's live-runner + proxy, not end-user clients. Apps should normally pass the incoming + request so the orchestrator-provided session headers are used. + """ + return await create_trickle_channels( + session, + channels, + orchestrator_url=self.orchestrator_url, + runner_id=self.runner_id, + session_token=session_token, + timeout=self._timeout, + ) + + async def remove_trickle_channels( + self, + session: str | LiveRunnerSessionRequest, + channels: list[str], + *, + session_token: str = "", + ) -> list[str]: + """Remove channels for a live runner app session. + + This is intended for apps running behind the orchestrator's live-runner + proxy, not end-user clients. Apps should normally pass the incoming + request so the orchestrator-provided session headers are used. + """ + return await remove_trickle_channels( + session, + channels, + orchestrator_url=self.orchestrator_url, + runner_id=self.runner_id, + session_token=session_token, + timeout=self._timeout, + ) + + def _payload(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "runner_url": self._runner_url, + "app": self._app, + "mode": self._mode, + "capacity": self._capacity, + "price_info": self._price_info.to_json(), + "session_ids": list(self.active_session_ids), + } + if self.runner_id: + payload["runner_id"] = self.runner_id + if self._label: + payload["label"] = self._label + if self._version: + payload["version"] = self._version + if self._status: + payload["status"] = self._status + if self._gpu is not None: + gpu = self._gpu.to_json() + if gpu: + payload["gpu"] = gpu + return payload + + async def _heartbeat_loop(self) -> None: + while not self._closed: + await asyncio.sleep(self.heartbeat_interval_s) + if self._closed: + return + try: + await self._send_heartbeat() + except LivepeerGatewayError as exc: + _LOG.warning("Live runner heartbeat failed; retrying on next interval: %s", exc) + except Exception: + _LOG.warning("Live runner heartbeat failed; retrying on next interval", exc_info=True) + + async def _send_heartbeat(self) -> None: + is_initial_heartbeat = self._heartbeat_secret is None + auth = self._heartbeat_secret or self._bootstrap_secret + try: + data = await self._post_heartbeat(auth) + except LivepeerGatewayError as exc: + if is_initial_heartbeat or not _is_invalid_authorization_error(exc): + raise + _LOG.info("Live runner heartbeat authorization expired; resetting heartbeat auth") + self._heartbeat_secret = None + is_initial_heartbeat = True + data = await self._post_heartbeat(self._bootstrap_secret) + + runner_id = data.get("runner_id") + if not isinstance(runner_id, str) or not runner_id.strip(): + raise LivepeerGatewayError("Live runner heartbeat response missing runner_id") + self.runner_id = runner_id.strip() + + orchestrator = data.get("orchestrator") + if isinstance(orchestrator, str) and orchestrator.strip(): + self.orchestrator_url = _normalize_http_base(orchestrator) + + if self._heartbeat_interval_override is None: + self.heartbeat_interval_s = _parse_go_duration_s( + data.get("heartbeat_interval"), + default=_DEFAULT_HEARTBEAT_INTERVAL_S, + ) + self.heartbeat_ttl_s = _parse_go_duration_s(data.get("heartbeat_ttl"), default=None) + + heartbeat_secret = data.get("heartbeat_secret") + if isinstance(heartbeat_secret, str) and heartbeat_secret.strip(): + self._heartbeat_secret = heartbeat_secret.strip() + elif is_initial_heartbeat: + raise LivepeerGatewayError("Live runner heartbeat response missing heartbeat_secret") + + if is_initial_heartbeat: + self._start_o2r(data.get("o2r")) + + async def _post_heartbeat(self, auth: str) -> dict[str, Any]: + return await post_json( + _join_endpoint(self.orchestrator_url, "/runners/heartbeat"), + self._payload(), + headers={"Authorization": auth}, + timeout=self._timeout, + ) + + def _start_o2r(self, value: object) -> None: + if self._closed or self._o2r_reader is not None: + return + if not _is_trickle_channel_response(value): + if value is not None: + _LOG.warning("Ignoring malformed live runner O2R channel: %r", value) + return + channel = cast(LiveRunnerTrickleChannel, value) + url = channel.get("url", "").strip() + if not url: + return + self.o2r_channel = channel + reader = ChannelReader(url, start_seq=0, on_event=self._handle_o2r_message) + self._o2r_reader = reader + self._o2r_task = reader.callback_task() + + async def _handle_o2r_message(self, message: dict[str, Any]) -> None: + if message.get("keep") == "alive": + return + + event = message.get("event") + session_id = message.get("session") + if event not in ("reserved", "released") or not isinstance(session_id, str) or not session_id.strip(): + _LOG.warning("Ignoring unknown live runner O2R message: %r", message) + return + + session_id = session_id.strip() + typed_event = cast(Literal["reserved", "released"], event) + if typed_event == "reserved": + self._reserve_session_id(session_id) + else: + self._release_session_id(session_id) + + timestamp = message.get("timestamp") + callback = self._on_session_reserve if typed_event == "reserved" else self._on_session_release + if callback is None: + return + + session_event = LiveRunnerSessionEvent( + session_id=session_id, + event=typed_event, + timestamp=timestamp if isinstance(timestamp, str) else None, + raw=message, + ) + try: + result = callback(session_event) + if inspect.isawaitable(result): + await result + except Exception: + _LOG.exception("Live runner %s callback failed for session %s", typed_event, session_id) + + def _reserve_session_id(self, session_id: str) -> None: + if session_id not in self._active_session_ids: + self._active_session_ids.append(session_id) + + def _release_session_id(self, session_id: str) -> None: + try: + self._active_session_ids.remove(session_id) + except ValueError: + pass + + +async def register_runner( + orchestrator_url: str, + *, + secret: str, + runner_url: str, + app: str, + price_per_unit: int = 0, + pixels_per_unit: int = 1, + price_unit: str = "USD", + runner_id: str = "", + mode: str = "persistent", + label: str = "", + version: str = "", + status: str = "ready", + capacity: int = 1, + gpu: Optional[LiveRunnerGPU] = None, + auto_detect_gpu: bool = True, + timeout: float = 5.0, + heartbeat_interval_s: Optional[float] = None, + unregister_on_close: bool = True, + on_session_reserve: Optional[LiveRunnerSessionCallback] = None, + on_session_release: Optional[LiveRunnerSessionCallback] = None, +) -> LiveRunnerRegistration: + if gpu is None and auto_detect_gpu: + gpu = detect_process_gpu() + + registration = LiveRunnerRegistration( + orchestrator_url=orchestrator_url, + secret=secret, + runner_url=runner_url, + app=app, + price_info=LiveRunnerPriceInfo(price_per_unit, pixels_per_unit, price_unit), + runner_id=runner_id, + mode=mode, + label=label, + version=version, + status=status, + capacity=capacity, + gpu=gpu, + timeout=timeout, + heartbeat_interval_s=heartbeat_interval_s, + unregister_on_close=unregister_on_close, + on_session_reserve=on_session_reserve, + on_session_release=on_session_release, + ) + return await registration.start() + + +async def create_trickle_channels( + session: str | LiveRunnerSessionRequest, + channels: list[LiveRunnerTrickleChannelRequest], + *, + orchestrator_url: str = "", + runner_id: str = "", + session_token: str = "", + timeout: float = 5.0, +) -> list[LiveRunnerTrickleChannel]: + """Create trickle channels for a live runner app session.""" + runner, session_id, token, control_url = _resolve_session_credentials( + session, + runner_id=runner_id, + session_token=session_token, + ) + _validate_trickle_channel_requests(channels) + data = await post_json( + _trickle_channels_endpoint(orchestrator_url, runner, session_id, control_url), + {"channels": channels}, + headers={"Livepeer-Session-Token": token}, + timeout=timeout, + ) + response_channels = data.get("channels") + if not isinstance(response_channels, list) or not all( + _is_trickle_channel_response(channel) for channel in response_channels + ): + raise LivepeerGatewayError("Live runner trickle channel create response missing channels") + return cast(list[LiveRunnerTrickleChannel], response_channels) + + +async def remove_trickle_channels( + session: str | LiveRunnerSessionRequest, + channels: list[str], + *, + orchestrator_url: str = "", + runner_id: str = "", + session_token: str = "", + timeout: float = 5.0, +) -> list[str]: + """Remove trickle channels for a live runner app session.""" + runner, session_id, token, control_url = _resolve_session_credentials( + session, + runner_id=runner_id, + session_token=session_token, + ) + data = await request_json( + _trickle_channels_endpoint(orchestrator_url, runner, session_id, control_url), + method="DELETE", + payload={"channels": channels}, + headers={"Livepeer-Session-Token": token}, + timeout=timeout, + ) + if not isinstance(data, dict): + raise LivepeerGatewayError( + f"Live runner trickle channel remove expected JSON object, got {type(data).__name__}" + ) + deleted = data.get("deleted") + if not isinstance(deleted, list) or not all(isinstance(channel, str) for channel in deleted): + raise LivepeerGatewayError("Live runner trickle channel remove response missing deleted") + return deleted + + +async def call_runner( + runner_url: str = "", + *, + runner: Optional[LiveRunnerInstance] = None, + payload: Optional[dict[str, Any]] = None, + method: str = "POST", + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, + max_payment_challenge_retries: int = 3, +) -> LiveRunnerCallResult: + runner_url = runner_url.strip() or (runner.url.strip() if runner is not None else "") + if not runner_url: + raise LivepeerGatewayError("Live runner call requires runner_url") + request_payload = payload or {} + payer_address = "" + if signer_url: + signer = await get_signer_info(signer_url, _freeze_headers(signer_headers)) + payer_address = cast(str, signer.address) + challenge: Optional[_RunnerPaymentChallenge] = None + attempts = (max(0, int(max_payment_challenge_retries)) + 1) * 2 + for attempt in range(attempts): + payment_session: Optional[LivePaymentSession] = None + session_id = "" + request_headers: dict[str, str] = {} + if signer_url: + request_headers[_LIVE_RUNNER_PAYER_ADDRESS_HEADER] = payer_address + # Pending challenge means payment is needed. + if challenge is not None: + try: + payment_session, payment = await _get_runner_payment( + challenge, + signer_url=signer_url or "", + signer_headers=signer_headers, + ) + except SignerRefreshRequired as e: + if attempt + 1 >= attempts: + raise + # Could happen if embedded payment params expire; just retry in this case. + _LOG.info( + "Live runner reservation payment challenge needs refresh; retrying with a fresh challenge: %s", + e, + ) + challenge = None + continue + request_headers["Livepeer-Payment"] = payment.payment + request_headers["Livepeer-Segment"] = payment.seg_creds or "" + session_id = challenge.manifest_id + + try: + request_kwargs: dict[str, Any] = {"timeout": timeout} + if request_headers: + request_kwargs["headers"] = request_headers + data = await request_json( + runner_url, + method=method, + payload=request_payload, + **request_kwargs, + ) + if not isinstance(data, dict): + raise LivepeerGatewayError( + f"Live runner call expected JSON object, got {type(data).__name__}" + ) + return LiveRunnerCallResult( + data, + runner_url=runner_url, + runner=runner, + session_id=( + session_id + or (data["session_id"].strip() if isinstance(data.get("session_id"), str) else "") + ), + payment_session=payment_session, + ) + except LivepeerHTTPError as e: + if e.status_code != 402: + raise + if not signer_url: + raise LivepeerGatewayError("Live runner paid call requires signer_url") from e + challenge = _parse_runner_payment_challenge(e) + continue + + raise LivepeerGatewayError("Live runner call exhausted payment challenge retries") + + +@dataclass(frozen=True) +class _RunnerPaymentChallenge: + payment_params: str + orchestrator_url: str + manifest_id: str + + +def _parse_runner_payment_challenge(error: LivepeerHTTPError) -> _RunnerPaymentChallenge: + try: + data = json.loads(error.body) + except json.JSONDecodeError as e: + raise LivepeerGatewayError("Live runner payment challenge response was not valid JSON") from e + if not isinstance(data, dict): + raise LivepeerGatewayError("Live runner payment challenge response must be a JSON object") + + payment_params = data.get("payment_params") + orchestrator_url = data.get("orchestrator") + manifest_id = data.get("manifest_id") + if not isinstance(payment_params, str) or not payment_params: + raise LivepeerGatewayError("Live runner payment challenge missing payment_params") + if not isinstance(orchestrator_url, str) or not orchestrator_url: + raise LivepeerGatewayError("Live runner payment challenge missing orchestrator") + if not isinstance(manifest_id, str) or not manifest_id: + raise LivepeerGatewayError("Live runner payment challenge missing manifest_id") + + return _RunnerPaymentChallenge( + payment_params=payment_params, + orchestrator_url=orchestrator_url, + manifest_id=manifest_id, + ) + + +async def _get_runner_payment( + challenge: _RunnerPaymentChallenge, + *, + signer_url: str, + signer_headers: Optional[dict[str, str]], +) -> tuple[LivePaymentSession, GetPaymentResponse]: + session = LivePaymentSession( + signer_url=signer_url, + signer_headers=signer_headers, + type="lv2v", + payment_params=challenge.payment_params, + manifest_id=challenge.manifest_id, + orchestrator_url=challenge.orchestrator_url, + ) + payment = await session.get_payment() + if not payment.payment: + raise LivepeerGatewayError("Live runner payment response missing payment") + if not payment.seg_creds: + raise LivepeerGatewayError("Live runner payment response missing segCreds") + return session, payment + + +def _live_runner_session_from_json( + data: dict[str, Any], + *, + runner_url: str, + runner: Optional[LiveRunnerInstance], +) -> LiveRunnerSession: + session_id = data.get("session_id") + app_url = data.get("app_url") + if not isinstance(session_id, str) or not session_id.strip(): + raise LivepeerGatewayError("Live runner session reserve response missing session_id") + if not isinstance(app_url, str) or not app_url.strip(): + raise LivepeerGatewayError("Live runner session reserve response missing app_url") + return LiveRunnerSession( + session_id=session_id.strip(), + app_url=app_url.strip(), + runner_url=runner_url, + runner=runner, + ) + +async def stop_runner_session( + session: LiveRunnerSession | LiveRunnerSessionRequest, + *, + timeout: float = 5.0, +) -> None: + if isinstance(session, LiveRunnerSession): + runner_url = session.runner_url.strip() + session_id = session.session_id.strip() + if not runner_url: + raise LivepeerGatewayError("Live runner session stop requires runner_url") + if not session_id: + raise LivepeerGatewayError("Live runner session stop requires session_id") + url = _join_endpoint(runner_url, f"/{quote(session_id, safe='')}/stop") + else: + headers = getattr(session, "headers", None) + get = getattr(headers, "get", None) + control_url = get("Livepeer-Session-Control", "") if callable(get) else "" + if not isinstance(control_url, str) or not control_url.strip(): + raise LivepeerGatewayError("Live runner session stop requires session_control") + url = _join_endpoint(control_url, "stop") + await _post_empty( + url, + {}, + timeout, + ) + + +def detect_process_gpu() -> Optional[LiveRunnerGPU]: + for detector in (_detect_gpu_pynvml, _detect_gpu_torch, _detect_gpu_nvidia_smi): + try: + gpu = detector() + except Exception: + _LOG.debug("GPU auto-discovery detector failed: %s", detector.__name__, exc_info=True) + continue + if gpu is not None: + return gpu + return None + + +def _normalize_http_base(url: str) -> str: + url = url.strip() + normalized = url if "://" in url else f"https://{url}" + parsed = urlparse(normalized) + if parsed.scheme not in ("http", "https") or not parsed.netloc: + raise LivepeerGatewayError(f"Invalid orchestrator URL: {url!r}") + path = parsed.path.rstrip("/") + return urlunparse((parsed.scheme, parsed.netloc, path, "", parsed.query, "")) + + +def _join_endpoint(base_url: str, suffix: str) -> str: + parsed = urlparse(_normalize_http_base(base_url)) + suffix_path = suffix if suffix.startswith("/") else f"/{suffix}" + path = f"{parsed.path.rstrip('/')}{suffix_path}" + return urlunparse((parsed.scheme, parsed.netloc, path, "", parsed.query, "")) + + +def _trickle_channels_endpoint( + orchestrator_url: str, + runner_id: str, + session_id: str, + control_url: str = "", +) -> str: + if control_url: + return _join_endpoint(control_url, "channels") + if not orchestrator_url: + raise LivepeerGatewayError("Live runner trickle channel request requires session_control") + if not runner_id: + raise LivepeerGatewayError("Live runner trickle channel request requires runner_id") + return _join_endpoint( + orchestrator_url, + ( + f"/runner/{quote(runner_id, safe='')}" + f"/session/{quote(session_id, safe='')}" + "/channels" + ), + ) + + +def _parse_go_duration_s(value: object, *, default: Optional[float]) -> Optional[float]: + if not isinstance(value, str) or not value.strip(): + return default + match = _DURATION_RE.match(value) + if not match: + return default + number = float(match.group("value")) + unit = match.group("unit") + scale = { + "ns": 1e-9, + "us": 1e-6, + "\u00b5s": 1e-6, + "ms": 1e-3, + "s": 1.0, + "m": 60.0, + "h": 3600.0, + }[unit] + return number * scale + + +def _normalize_runner_mode(mode: str) -> str: + normalized = mode.strip() + if normalized not in _LIVE_RUNNER_MODES: + raise ValueError(f"live runner mode must be one of {sorted(_LIVE_RUNNER_MODES)}") + return normalized + + +def _is_invalid_authorization_error(exc: LivepeerGatewayError) -> bool: + return isinstance(exc, LivepeerHTTPError) and exc.status_code == 401 + + +def _resolve_session_credentials( + session: str | LiveRunnerSessionRequest, + *, + runner_id: str = "", + session_token: str = "", +) -> tuple[str, str, str, str]: + runner = runner_id.strip() + session_id = "" + token = session_token.strip() + control_url = "" + + if isinstance(session, str): + session_id = session.strip() + else: + headers = getattr(session, "headers", None) + if headers is not None: + get = getattr(headers, "get", None) + if callable(get): + runner_value = get("Livepeer-Runner-Route", "") + session_id_value = get("Livepeer-Session-Id", "") + token_value = get("Livepeer-Session-Token", "") + control_value = get("Livepeer-Session-Control", "") + if not runner and isinstance(runner_value, str): + runner = runner_value.strip() + if isinstance(session_id_value, str): + session_id = session_id_value.strip() + if not token and isinstance(token_value, str): + token = token_value.strip() + if isinstance(control_value, str): + control_url = control_value.strip() + + if not session_id: + raise LivepeerGatewayError("Live runner trickle channel request requires session_id") + if not token: + raise LivepeerGatewayError("Live runner trickle channel request requires session_token") + return runner, session_id, token, control_url + + +def _validate_trickle_channel_requests(channels: list[LiveRunnerTrickleChannelRequest]) -> None: + for channel in channels: + if not isinstance(channel, dict): + raise TypeError(f"trickle channel must be dict, got {type(channel).__name__}") + if not isinstance(channel.get("name"), str): + raise TypeError("trickle channel name must be str") + if not isinstance(channel.get("mime_type"), str): + raise TypeError("trickle channel mime_type must be str") + + +def _is_trickle_channel_response(value: object) -> bool: + if not isinstance(value, dict): + return False + return all( + isinstance(value.get(key), str) + for key in ("name", "channel_name", "url", "mime_type") + ) + + +async def _post_empty(url: str, headers: dict[str, str], timeout: float) -> None: + try: + client_timeout = aiohttp.ClientTimeout(total=timeout) + connector = aiohttp.TCPConnector(ssl=False) + async with aiohttp.ClientSession(timeout=client_timeout, connector=connector) as session: + async with session.post(url, data=b"", headers=headers) as resp: + body = await resp.text() + if resp.status >= 400: + raise LivepeerGatewayError( + f"HTTP empty POST error: HTTP {resp.status}; body={body!r}" + ) + except LivepeerGatewayError: + raise + except getattr(aiohttp, "ClientConnectorError", ()) as e: + raise LivepeerGatewayError(f"HTTP empty POST error: {getattr(e, 'message', e)}") from e + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise LivepeerGatewayError(f"HTTP empty POST error: {getattr(e, 'message', e)}") from e + + +def _detect_gpu_pynvml() -> Optional[LiveRunnerGPU]: + try: + import pynvml # type: ignore[import-not-found] + except Exception: + return None + + pynvml.nvmlInit() + try: + index = _pynvml_process_device_index(pynvml) + if index is None: + index = _first_visible_cuda_index() + if index is None: + return None + handle = pynvml.nvmlDeviceGetHandleByIndex(index) + uuid = _decode_maybe_bytes(pynvml.nvmlDeviceGetUUID(handle)) + name = _decode_maybe_bytes(pynvml.nvmlDeviceGetName(handle)) + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + return LiveRunnerGPU(id=uuid, name=name, vram_mb=int(getattr(mem, "total", 0)) // (1024 * 1024)) + finally: + try: + pynvml.nvmlShutdown() + except Exception: + pass + + +def _pynvml_process_device_index(pynvml: Any) -> Optional[int]: + pid = os.getpid() + count = int(pynvml.nvmlDeviceGetCount()) + for index in range(count): + handle = pynvml.nvmlDeviceGetHandleByIndex(index) + processes: list[Any] = [] + for name in ("nvmlDeviceGetComputeRunningProcesses_v2", "nvmlDeviceGetComputeRunningProcesses"): + fn = getattr(pynvml, name, None) + if fn is None: + continue + try: + processes = list(fn(handle)) + break + except Exception: + continue + if any(int(getattr(proc, "pid", -1)) == pid for proc in processes): + return index + return None + + +def _detect_gpu_torch() -> Optional[LiveRunnerGPU]: + try: + import torch # type: ignore[import-not-found] + except Exception: + return None + try: + if not torch.cuda.is_available(): + return None + index = int(torch.cuda.current_device()) + props = torch.cuda.get_device_properties(index) + name = str(getattr(props, "name", "") or torch.cuda.get_device_name(index)) + total = int(getattr(props, "total_memory", 0) or 0) + return LiveRunnerGPU(id=str(index), name=name, vram_mb=total // (1024 * 1024)) + except Exception: + _LOG.debug("torch.cuda GPU discovery failed", exc_info=True) + return None + + +def _detect_gpu_nvidia_smi() -> Optional[LiveRunnerGPU]: + if shutil.which("nvidia-smi") is None: + return None + uuid = _nvidia_smi_process_gpu_uuid() + rows = _nvidia_smi_gpu_rows() + if not rows: + return None + if uuid: + for row in rows: + if row.get("uuid") == uuid: + return _gpu_from_nvidia_smi_row(row) + index = _first_visible_cuda_index() + if index is not None: + for row in rows: + if row.get("index") == str(index): + return _gpu_from_nvidia_smi_row(row) + return _gpu_from_nvidia_smi_row(rows[0]) + + +def _nvidia_smi_process_gpu_uuid() -> str: + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-compute-apps=pid,gpu_uuid", + "--format=csv,noheader,nounits", + ], + text=True, + stderr=subprocess.DEVNULL, + timeout=2.0, + ) + except Exception: + return "" + pid = str(os.getpid()) + for line in output.splitlines(): + parts = [part.strip() for part in line.split(",")] + if len(parts) >= 2 and parts[0] == pid: + return parts[1] + return "" + + +def _nvidia_smi_gpu_rows() -> list[dict[str, str]]: + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=index,uuid,name,memory.total", + "--format=csv,noheader,nounits", + ], + text=True, + stderr=subprocess.DEVNULL, + timeout=2.0, + ) + except Exception: + return [] + rows = [] + for line in output.splitlines(): + parts = [part.strip() for part in line.split(",", maxsplit=3)] + if len(parts) != 4: + continue + rows.append({"index": parts[0], "uuid": parts[1], "name": parts[2], "vram_mb": parts[3]}) + return rows + + +def _gpu_from_nvidia_smi_row(row: dict[str, str]) -> LiveRunnerGPU: + try: + vram_mb = int(float(row.get("vram_mb", "0"))) + except ValueError: + vram_mb = 0 + return LiveRunnerGPU(id=row.get("uuid", ""), name=row.get("name", ""), vram_mb=vram_mb) + + +def _first_visible_cuda_index() -> Optional[int]: + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() + if not visible: + return 0 + first = visible.split(",")[0].strip() + if not first or first == "-1": + return None + if first.isdigit(): + return int(first) + return 0 + + +def _decode_maybe_bytes(value: object) -> str: + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace") + return str(value or "") diff --git a/src/livepeer_gateway/lv2v.py b/src/livepeer_gateway/lv2v.py index 5948b03..4e5388d 100644 --- a/src/livepeer_gateway/lv2v.py +++ b/src/livepeer_gateway/lv2v.py @@ -17,9 +17,14 @@ SkipPaymentCycle, ) from .events import Events -from .media_output import LagPolicy, MediaOutput +from .media_output import ( + LagPolicy, + MediaFrameCallback, + MediaOutput, + MediaPacketCallback, +) from .media_publish import MediaPublish, MediaPublishConfig -from .orchestrator import _http_origin, post_json +from .http import _http_origin, post_json_sync from .selection import orchestrator_selector from .remote_signer import PaymentSession from .token import parse_token @@ -116,6 +121,8 @@ def media_output( chunk_size: int = 64 * 1024, max_segments: int = 5, on_lag: LagPolicy = LagPolicy.LATEST, + on_frame: Optional[MediaFrameCallback] = None, + on_packet: Optional[MediaPacketCallback] = None, ) -> MediaOutput: """ Convenience helper to create a `MediaOutput` for this job. @@ -134,6 +141,8 @@ def media_output( chunk_size=chunk_size, max_segments=max_segments, on_lag=on_lag, + on_frame=on_frame, + on_packet=on_packet, ) @property @@ -364,7 +373,7 @@ def start_lv2v( base = _http_origin(info.transcoder) url = f"{base}/live-video-to-video" - data = post_json(url, req.to_json(), headers=headers, timeout=timeout) + data = post_json_sync(url, req.to_json(), headers=headers, timeout=timeout) job = LiveVideoToVideo.from_json( data, signer_url=resolved_signer_url, diff --git a/src/livepeer_gateway/media_output.py b/src/livepeer_gateway/media_output.py index a9107b8..1f820a8 100644 --- a/src/livepeer_gateway/media_output.py +++ b/src/livepeer_gateway/media_output.py @@ -1,16 +1,17 @@ -from __future__ import annotations - """ Helpers for consuming trickle media outputs as segments, bytes, or frames. """ +from __future__ import annotations + import asyncio from dataclasses import dataclass +import inspect import logging import time from enum import Enum from contextlib import suppress -from typing import AsyncIterator, Collection, Optional +from typing import AsyncIterator, Awaitable, Callable, Collection, Optional from .errors import LivepeerGatewayError from .media_decode import ( @@ -31,6 +32,14 @@ _LOG = logging.getLogger(__name__) _DEFAULT_ACCEPTED_CONTENT_TYPES = frozenset({"video/mp2t", "audio/mp2t"}) +MediaFrameCallback = Callable[ + [AudioDecodedMediaFrame | VideoDecodedMediaFrame], + None | Awaitable[None], +] +MediaPacketCallback = Callable[[DemuxedMediaPacket], None | Awaitable[None]] +MediaBytesCallback = Callable[[bytes], None | Awaitable[None]] + + class LagPolicy(Enum): """ Policy for handling consumers that fall behind the segment window. @@ -136,6 +145,9 @@ def __init__( max_segments: int = 5, on_lag: LagPolicy = LagPolicy.LATEST, accepted_content_types: Collection[str] = _DEFAULT_ACCEPTED_CONTENT_TYPES, + on_bytes: Optional[MediaBytesCallback] = None, + on_frame: Optional[MediaFrameCallback] = None, + on_packet: Optional[MediaPacketCallback] = None, ) -> None: if max_segments < 1: raise ValueError("max_segments must be >= 1") @@ -148,6 +160,9 @@ def __init__( self.max_segments = max_segments self.on_lag = on_lag self.accepted_content_types = _normalize_accepted_content_types(accepted_content_types) + self.on_bytes = on_bytes + self.on_frame = on_frame + self.on_packet = on_packet self._sub: Optional[TrickleSubscriber] = None self._segments: list[SegmentReader] = [] @@ -158,6 +173,10 @@ def __init__( self._started_at = time.time() self._processor: Optional[MpegTsDecoder | MpegTsPacketDemuxer] = None self._last_decoder_stats: Optional[DecoderQueueStats] = None + self._bytes_callback_task: Optional[asyncio.Task[None]] = None + self._frame_callback_task: Optional[asyncio.Task[None]] = None + self._packet_callback_task: Optional[asyncio.Task[None]] = None + self._callback_errors: list[BaseException] = [] self._stats: dict[str, int] = { "segments_consumed": 0, "bytes_read": 0, @@ -176,6 +195,114 @@ def __init__( "packet_errors": 0, "decode_errors": 0, } + if self.on_bytes is not None or self.on_frame is not None or self.on_packet is not None: + self.start_callbacks() + + def start_callbacks(self) -> list[asyncio.Task[None]]: + """ + Start configured frame/packet callback consumers. + + This is idempotent. If called without a running event loop, no tasks are + started and callers may retry later from async code. + """ + if self.on_bytes is None and self.on_frame is None and self.on_packet is None: + return [] + try: + loop = asyncio.get_running_loop() + except RuntimeError: + _LOG.warning( + "No running event loop; MediaOutput callbacks not started. " + "Call start_callbacks() from async code or use async with MediaOutput(...)." + ) + return [] + + started: list[asyncio.Task[None]] = [] + if self.on_bytes is not None and self._bytes_callback_task is None: + task = loop.create_task( + self._run_bytes_callback_loop(self.on_bytes), + name="MediaOutput.on_bytes", + ) + task.add_done_callback(self._record_callback_task_result) + self._bytes_callback_task = task + started.append(task) + if self.on_frame is not None and self._frame_callback_task is None: + task = loop.create_task( + self._run_frame_callback_loop(self.on_frame), + name="MediaOutput.on_frame", + ) + task.add_done_callback(self._record_callback_task_result) + self._frame_callback_task = task + started.append(task) + if self.on_packet is not None and self._packet_callback_task is None: + task = loop.create_task( + self._run_packet_callback_loop(self.on_packet), + name="MediaOutput.on_packet", + ) + task.add_done_callback(self._record_callback_task_result) + self._packet_callback_task = task + started.append(task) + return started + + def callback_tasks(self) -> tuple[asyncio.Task[None], ...]: + tasks = [] + if self._bytes_callback_task is not None: + tasks.append(self._bytes_callback_task) + if self._frame_callback_task is not None: + tasks.append(self._frame_callback_task) + if self._packet_callback_task is not None: + tasks.append(self._packet_callback_task) + return tuple(tasks) + + async def wait_callbacks(self, timeout: Optional[float] = None) -> tuple[object, ...]: + """ + Wait for configured callback consumers to finish. + + Raises the first callback error, matching close(). + """ + callback_tasks = self.callback_tasks() + if not callback_tasks: + return () + results = await asyncio.wait_for( + asyncio.gather(*callback_tasks, return_exceptions=True), + timeout=timeout, + ) + self._collect_callback_errors(results) + if self._callback_errors: + raise self._callback_errors[0] + return tuple(results) + + def _collect_callback_errors(self, results: Collection[object]) -> None: + for result in results: + if isinstance(result, BaseException) and not isinstance( + result, asyncio.CancelledError + ): + if not any(error is result for error in self._callback_errors): + self._callback_errors.append(result) + + async def _run_bytes_callback_loop(self, callback: MediaBytesCallback) -> None: + async for chunk in self.bytes(): + await _maybe_await(callback(chunk)) + + async def _run_frame_callback_loop(self, callback: MediaFrameCallback) -> None: + async for frame in self.frames(): + await _maybe_await(callback(frame)) + + async def _run_packet_callback_loop(self, callback: MediaPacketCallback) -> None: + async for packet in self.packets(): + await _maybe_await(callback(packet)) + + def _record_callback_task_result(self, task: asyncio.Task[None]) -> None: + try: + exc = task.exception() + except asyncio.CancelledError: + return + if exc is None: + return + self._callback_errors.append(exc) + _LOG.error( + "MediaOutput callback task failed", + exc_info=(type(exc), exc, exc.__traceback__), + ) def segments( self, @@ -419,13 +546,28 @@ async def _next_segment( return self._segments[relative] return None - async def close(self) -> None: + async def close(self, *, wait_callbacks: bool = True, timeout: Optional[float] = 10.0) -> None: + callback_tasks = self.callback_tasks() + if callback_tasks: + if wait_callbacks and (timeout is None or timeout > 0): + try: + await self.wait_callbacks(timeout=timeout) + except asyncio.TimeoutError: + pass + for task in callback_tasks: + if not task.done(): + task.cancel() + results = await asyncio.gather(*callback_tasks, return_exceptions=True) + self._collect_callback_errors(results) for segment in self._segments: await segment.close() if self._sub is not None: await self._sub.close() + if self._callback_errors: + raise self._callback_errors[0] async def __aenter__(self) -> "MediaOutput": + self.start_callbacks() return self async def __aexit__(self, exc_type, exc_value, traceback) -> None: @@ -483,3 +625,8 @@ def _require_content_type(value: Optional[str], accepted: frozenset[str]) -> Non raise LivepeerGatewayError( f"Expected Content-Type in {sorted(accepted)!r}, got {value!r}" ) + + +async def _maybe_await(value: None | Awaitable[None]) -> None: + if inspect.isawaitable(value): + await value diff --git a/src/livepeer_gateway/orch_info.py b/src/livepeer_gateway/orch_info.py index fe621e8..bca7721 100644 --- a/src/livepeer_gateway/orch_info.py +++ b/src/livepeer_gateway/orch_info.py @@ -15,7 +15,7 @@ from . import lp_rpc_pb2 from . import lp_rpc_pb2_grpc from .errors import LivepeerGatewayError -from .remote_signer import _freeze_headers, get_orch_info_sig +from .remote_signer import _freeze_headers, _hex_to_bytes, get_orch_info_sig _LOG = logging.getLogger(__name__) @@ -125,9 +125,19 @@ def get_orch_info( cause=e, ) from None + try: + address = _hex_to_bytes(signer.address, expected_len=20) if signer.address else b"" + sig = _hex_to_bytes(signer.sig) if signer.sig else b"" + except ValueError as e: + raise OrchestratorRpcError( + orch_url, + f"invalid signer material: {e}", + cause=e, + ) from None + request = lp_rpc_pb2.OrchestratorRequest( - address=signer.address, - sig=signer.sig, + address=address, + sig=sig, ignoreCapacityCheck=True, ) if capabilities is not None: diff --git a/src/livepeer_gateway/orchestrator.py b/src/livepeer_gateway/orchestrator.py index 844cd7b..0c5ae69 100644 --- a/src/livepeer_gateway/orchestrator.py +++ b/src/livepeer_gateway/orchestrator.py @@ -1,310 +1,48 @@ from __future__ import annotations -import json -import logging -import ssl -from functools import lru_cache -from typing import Any, Optional, Sequence -from urllib.parse import ParseResult, parse_qsl, quote, urlencode, urlparse, urlunparse -from urllib.error import URLError, HTTPError -from urllib.request import Request, urlopen - -from . import lp_rpc_pb2 -from .capabilities import capabilities_to_query - +from .discovery import _append_caps, discover_orchestrators from .errors import ( LivepeerGatewayError, SignerRefreshRequired, SkipPaymentCycle, ) -from .remote_signer import RemoteSignerError - -_LOG = logging.getLogger(__name__) - -def _truncate(s: str, max_len: int = 2000) -> str: - if len(s) <= max_len: - return s - return s[:max_len] + f"...(+{len(s) - max_len} chars)" - -def _http_error_body(e: HTTPError) -> str: - """ - Best-effort read of an HTTPError response body for debugging. - """ - try: - b = e.read() - if not b: - return "" - if isinstance(b, bytes): - return b.decode("utf-8", errors="replace") - return str(b) - except Exception: - return "" - -def _extract_error_message(e: HTTPError) -> str: - """ - Best-effort extraction of a useful error message from an HTTPError body. - - If the body is JSON and matches {"error": {"message": "..."}}, return that message. - Otherwise return the full body. - - Always truncates the returned value for readability. - """ - body = _http_error_body(e) - s = body.strip() - if not s: - return "" - - try: - data = json.loads(s) - except Exception: - return _truncate(body) - - if isinstance(data, dict): - err = data.get("error") - if isinstance(err, dict): - msg = err.get("message") - if isinstance(msg, str) and msg: - return _truncate(msg) - - return _truncate(body) - - -def request_json( - url: str, - *, - method: Optional[str] = None, - payload: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - timeout: float = 5.0, -) -> Any: - """ - Make a JSON HTTP request and parse the JSON response. - - If method is None, defaults to POST when payload is provided, otherwise GET. - - Raises LivepeerGatewayError on HTTP/network/JSON parsing errors. - """ - req_headers: dict[str, str] = { - "Accept": "application/json", - "User-Agent": "livepeer-python-gateway/0.1", - } - body: Optional[bytes] = None - if payload is not None: - req_headers["Content-Type"] = "application/json" - body = json.dumps(payload).encode("utf-8") - if headers: - req_headers.update(headers) - - resolved_method = method.upper() if method else ("POST" if payload is not None else "GET") - req = Request(url, data=body, headers=req_headers, method=resolved_method) - - # Always ignore HTTPS certificate validation (matches our gRPC behavior). - ssl_ctx = ssl._create_unverified_context() - - try: - with urlopen(req, timeout=timeout, context=ssl_ctx) as resp: - raw = resp.read().decode("utf-8") - data: Any = json.loads(raw) - except HTTPError as e: - body = _extract_error_message(e) - body_part = f"; body={body!r}" if body else "" - if e.code == 480: - raise SignerRefreshRequired( - f"Signer returned HTTP 480 (refresh session required) (url={url}){body_part}" - ) from e - if e.code == 482: - raise SkipPaymentCycle( - f"Signer returned HTTP 482 (skip payment cycle) (url={url}){body_part}" - ) from e - raise LivepeerGatewayError( - f"HTTP JSON error: HTTP {e.code} from endpoint (url={url}){body_part}" - ) from e - except ConnectionRefusedError as e: - raise LivepeerGatewayError( - f"HTTP JSON error: connection refused (is the server running? is the host/port correct?) (url={url})" - ) from e - except URLError as e: - raise LivepeerGatewayError( - f"HTTP JSON error: failed to reach endpoint: {getattr(e, 'reason', e)} (url={url})" - ) from e - except json.JSONDecodeError as e: - raise LivepeerGatewayError(f"HTTP JSON error: endpoint did not return valid JSON: {e} (url={url})") from e - except Exception as e: - raise LivepeerGatewayError( - f"HTTP JSON error: unexpected error: {e.__class__.__name__}: {e} (url={url})" - ) from e - - return data - - -def post_json( - url: str, - payload: dict[str, Any], - *, - headers: Optional[dict[str, str]] = None, - timeout: float = 5.0, -) -> dict[str, Any]: - """ - POST JSON to `url` and parse a JSON object response. - """ - data = request_json( - url, - payload=payload, - headers=headers, - timeout=timeout, - ) - if not isinstance(data, dict): - raise LivepeerGatewayError( - f"HTTP JSON error: expected JSON object, got {type(data).__name__} (url={url})" - ) - return data - - -def get_json( - url: str, - *, - headers: Optional[dict[str, str]] = None, - timeout: float = 5.0, -) -> Any: - """ - GET JSON from `url` and parse the response. - """ - return request_json(url, headers=headers, timeout=timeout) - -def _parse_http_url(url: str, *, context: str = "URL") -> ParseResult: - """ - Normalize a URL for HTTP(S) endpoints. - - Accepts: - - "host:port" (implicitly https://host:port) - - "http://host:port[/...]" - - "https://host:port[/...]" - """ - url = url.strip() - normalized = url if "://" in url else f"https://{url}" - parsed = urlparse(normalized) - if parsed.scheme not in ("http", "https"): - raise ValueError(f"Only http:// or https:// {context}s are supported (got {parsed.scheme!r})") - if not parsed.netloc: - raise ValueError(f"Invalid {context}: {url!r}") - return parsed - - -def _http_origin(url: str) -> str: - """ - Normalize a URL (possibly with a path) into a scheme:// origin (scheme + host:port). - - Accepts: - - "host:port" (implicitly https://host:port) - - "http://host:port[/...]" (path/query/fragment are ignored) - - "https://host:port[/...]" (path/query/fragment are ignored) - """ - parsed = _parse_http_url(url) - return f"{parsed.scheme}://{parsed.netloc}" - - -def _append_caps(url: str, capabilities: Optional[lp_rpc_pb2.Capabilities]) -> str: - """ - Append repeated `caps` query parameters to a URL. - - Existing query params are preserved. Capability values keep `/` unescaped. - - Example output: - https://example.com/discover-orchestrators?x=1&caps=live-video-to-video/streamdiffusion-sdxl-v2v&caps=text-to-image/sdxl - """ - if capabilities is None: - return url - - caps = capabilities_to_query(capabilities) - if not caps: - return url - - parsed = urlparse(url) - query_pairs = parse_qsl(parsed.query, keep_blank_values=True) - query_pairs.extend(("caps", cap) for cap in caps) - query = urlencode(query_pairs, doseq=True, quote_via=quote, safe="/") - return urlunparse(parsed._replace(query=query)) - - -def discover_orchestrators( - orchestrators: Optional[Sequence[str] | str] = None, - *, - signer_url: Optional[str] = None, - signer_headers: Optional[dict[str, str]] = None, - discovery_url: Optional[str] = None, - discovery_headers: Optional[dict[str, str]] = None, - capabilities: Optional[lp_rpc_pb2.Capabilities] = None, -) -> list[str]: - """ - Discover orchestrators and return a list of addresses. - - This discovery can happen via the following parameters in priority order (highest first): - - orchestrators: list or comma-delimited string - (empty/whitespace-only input falls through) - - discovery_url: use this discovery endpoint - - signer_url: use signer-provided discovery service - """ - if orchestrators is not None: - if isinstance(orchestrators, str): - orch_list = [orch.strip() for orch in orchestrators.split(",")] - else: - try: - orch_list = list(orchestrators) - except TypeError as e: - raise LivepeerGatewayError( - "discover_orchestrators requires a list of orchestrator URLs or a comma-delimited string" - ) from e - orch_list = [orch.strip() for orch in orch_list if isinstance(orch, str) and orch.strip()] - if orch_list: - return orch_list - - if discovery_url: - discovery_endpoint = _parse_http_url(discovery_url).geturl() - request_headers = discovery_headers - elif signer_url: - discovery_endpoint = f"{_http_origin(signer_url)}/discover-orchestrators" - request_headers = signer_headers - else: - _LOG.debug("discover_orchestrators failed: no discovery inputs") - raise LivepeerGatewayError("discover_orchestrators requires discovery_url or signer_url") - - if capabilities is not None: - discovery_endpoint = _append_caps(discovery_endpoint, capabilities) - - try: - _LOG.debug("discover_orchestrators running discovery: %s", discovery_endpoint) - data = get_json(discovery_endpoint, headers=request_headers) - except LivepeerGatewayError as e: - _LOG.debug("discover_orchestrators discovery failed: %s", e) - raise RemoteSignerError( - discovery_endpoint, - str(e), - cause=e.__cause__ or e, - ) from None - - if not isinstance(data, list): - _LOG.debug( - "discover_orchestrators discovery response not list: type=%s", - type(data).__name__, - ) - raise RemoteSignerError( - discovery_endpoint, - f"Discovery response must be a JSON list, got {type(data).__name__}", - cause=None, - ) from None - - _LOG.debug("discover_orchestrators discovery response: %s", data) - - orch_list = [] - for item in data: - if not isinstance(item, dict): - continue - address = item.get("address") - if isinstance(address, str) and address.strip(): - orch_list.append(address.strip()) - _LOG.debug("discover_orchestrators discovered %d orchestrators", len(orch_list)) - - return orch_list - - +from .http import ( + _extract_error_message, + _extract_error_message_from_body, + _http_error_body, + _http_origin, + _json_request_parts, + _parse_http_url, + _raise_http_json_error, + _truncate, + get_json_sync, + post_json_sync, + request_json_sync, +) +# Compatibility aliases for the original synchronous helpers. +request_json = request_json_sync +post_json = post_json_sync +get_json = get_json_sync + +__all__ = [ + "LivepeerGatewayError", + "SignerRefreshRequired", + "SkipPaymentCycle", + "_append_caps", + "_extract_error_message", + "_extract_error_message_from_body", + "_http_error_body", + "_http_origin", + "_json_request_parts", + "_parse_http_url", + "_raise_http_json_error", + "_truncate", + "discover_orchestrators", + "get_json", + "get_json_sync", + "post_json", + "post_json_sync", + "request_json", + "request_json_sync", +] diff --git a/src/livepeer_gateway/remote_signer.py b/src/livepeer_gateway/remote_signer.py index d8287c7..372338e 100644 --- a/src/livepeer_gateway/remote_signer.py +++ b/src/livepeer_gateway/remote_signer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import base64 import json import logging @@ -11,8 +12,11 @@ from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen +import aiohttp + from . import lp_rpc_pb2 -from .errors import LivepeerGatewayError, PaymentError, SignerRefreshRequired, SkipPaymentCycle +from .async_cache import async_lru_cache +from .errors import LivepeerGatewayError, PaymentError, SignerRefreshRequired _LOG = logging.getLogger(__name__) @dataclass(frozen=True) @@ -25,13 +29,11 @@ class GetPaymentResponse: class SignerMaterial: """ Material returned by the remote signer. - address: 20-byte broadcaster ETH address - sig: signature bytes (length depends on scheme; commonly 65 bytes for ECDSA) - address_hex: original hex string from signer (preserves EIP-55 checksum casing) + address: opaque broadcaster address string. + sig: opaque signature string. """ - address: bytes - sig: bytes - address_hex: str = "" + address: Optional[str] + sig: Optional[str] @dataclass @@ -70,6 +72,35 @@ def _hex_to_bytes(s: str, *, expected_len: Optional[int] = None) -> bytes: return b +def _signer_material_from_json( + data: dict[str, Any], + signer_url: str, +) -> SignerMaterial: + if "address" not in data or "signature" not in data: + raise RemoteSignerError( + signer_url, + f"Remote signer JSON must contain 'address' and 'signature': {data!r}", + cause=None, + ) from None + + address = data["address"] + sig = data["signature"] + if not isinstance(address, str) or not address: + raise RemoteSignerError( + signer_url, + f"Remote signer 'address' must be a non-empty string: {address!r}", + cause=None, + ) from None + if not isinstance(sig, str) or not sig: + raise RemoteSignerError( + signer_url, + f"Remote signer 'signature' must be a non-empty string: {sig!r}", + cause=None, + ) from None + + return SignerMaterial(address=address, sig=sig) + + @lru_cache(maxsize=None) def get_orch_info_sig( signer_url: str, @@ -80,7 +111,7 @@ def get_orch_info_sig( Fetch signer material exactly once per (signer_url, headers) combination for the lifetime of the process. Subsequent calls return cached data. """ - from .orchestrator import _extract_error_message, _http_origin, post_json + from .http import _extract_error_message, _http_origin, post_json_sync as post_json # check for offchain mode if not signer_url: @@ -95,23 +126,12 @@ def get_orch_info_sig( # Some signers accept/expect POST with an empty JSON object. data = post_json(signer_url, {}, headers=headers, timeout=5.0) - # Expected response shape (example): - # { - # "address": "0x0123...abcd", # 20-byte ETH address hex - # "signature": "0x..." # signature hex - # } - if "address" not in data or "signature" not in data: - raise RemoteSignerError( - signer_url, - f"Remote signer JSON must contain 'address' and 'signature': {data!r}", - cause=None, - ) from None - - address_hex_str = str(data["address"]) - address = _hex_to_bytes(address_hex_str, expected_len=20) - sig = _hex_to_bytes(str(data["signature"])) # signature length may vary + signer = _signer_material_from_json(data, signer_url) except LivepeerGatewayError as e: + if isinstance(e, RemoteSignerError): + raise + # post_json wraps the underlying exception as __cause__; convert back into # a signer-specific error message. cause = e.__cause__ or e @@ -152,7 +172,172 @@ def get_orch_info_sig( cause=cause if isinstance(cause, BaseException) else e, ) from None - return SignerMaterial(address=address, sig=sig, address_hex=address_hex_str) + return signer + + +@async_lru_cache(maxsize=128) +async def get_signer_info( + signer_url: str, + # frozenset instead of dict because cache keys require hashable arguments. + _signer_headers: Optional[frozenset[tuple[str, str]]] = None, +) -> SignerMaterial: + """ + Async-native version of get_orch_info_sig for callers that should not block + the event loop or use gRPC. + """ + from .http import _http_origin, post_json + + if not signer_url: + return SignerMaterial(address=None, sig=None) + + url = f"{_http_origin(signer_url)}/sign-orchestrator-info" + headers = dict(_signer_headers) if _signer_headers else None + data = await post_json(url, {}, headers=headers, timeout=5.0) + return _signer_material_from_json(data, url) + + +class LivePaymentSession: + def __init__( + self, + signer_url: Optional[str], + *, + signer_headers: Optional[dict[str, str]] = None, + type: str, + payment_params: str, + manifest_id: str, + orchestrator_url: Optional[str] = None, + max_refresh_retries: int = 3, + ) -> None: + self._signer_url = signer_url + self._signer_headers = _freeze_headers(signer_headers) + self._type = type + self._payment_params = payment_params + self._manifest_id = manifest_id + self._max_refresh_retries = max(0, int(max_refresh_retries)) + self._state: Optional[dict[str, Any]] = None + self._orchestrator_url = orchestrator_url + + async def get_payment(self) -> GetPaymentResponse: + if not self._signer_url: + return GetPaymentResponse(payment="", seg_creds=None) + + attempts = 0 + while True: + try: + return await self._payment_request() + except SignerRefreshRequired as e: + if attempts >= self._max_refresh_retries: + raise PaymentError( + f"Signer refresh required after {attempts} retries: {e}" + ) from e + orchestrator_url = e.orchestrator_url + if not orchestrator_url: + raise PaymentError( + "Signer refresh response missing Livepeer-Orchestrator-URL header" + ) from e + await self._refresh_payment_params(orchestrator_url) + attempts += 1 + + async def send_payment(self, orchestrator_url: Optional[str] = None) -> None: + if not self._signer_url: + return + + target = orchestrator_url or self._orchestrator_url + if not target: + raise PaymentError("orchestrator_url is required before sending payment") + + from .http import _extract_error_message_from_body, _http_origin + + payment = await self.get_payment() + url = f"{_http_origin(target)}/payment" + headers = { + "Livepeer-Payment": payment.payment, + "Livepeer-Segment": payment.seg_creds, + } + try: + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, data=b"", headers=headers) as resp: + body = await resp.text() + if resp.status >= 400: + message = _extract_error_message_from_body(body) + body_part = f"; body={message!r}" if message else "" + raise PaymentError( + f"HTTP payment error: HTTP {resp.status} from endpoint (url={url}){body_part}" + ) + except PaymentError: + raise + except getattr(aiohttp, "ClientConnectorError", ()) as e: + raise PaymentError( + f"HTTP payment error: failed to reach endpoint: {getattr(e, 'message', e)} (url={url})" + ) from e + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise PaymentError( + f"HTTP payment error: failed to reach endpoint: {getattr(e, 'message', e)} (url={url})" + ) from e + + async def _payment_request(self) -> GetPaymentResponse: + from .http import _http_origin, post_json + + url = f"{_http_origin(self._signer_url)}/generate-live-payment" + payload: dict[str, Any] = { + "orchestrator": self._payment_params, + "type": self._type, + "ManifestID": self._manifest_id, + } + if self._state is not None: + payload["state"] = self._state + + headers = dict(self._signer_headers) if self._signer_headers else None + data = await post_json(url, payload, headers=headers) + payment = data.get("payment") + if not isinstance(payment, str) or not payment: + raise PaymentError( + f"GetPayment error: missing/invalid 'payment' in response (url={url})" + ) + + seg_creds = data.get("segCreds") + if seg_creds is not None and not isinstance(seg_creds, str): + raise PaymentError( + f"GetPayment error: invalid 'segCreds' in response (url={url})" + ) + + state = data.get("state") + if not isinstance(state, dict): + raise PaymentError( + f"Remote signer response missing 'state' object (url={url})" + ) + + self._state = state + return GetPaymentResponse(payment=payment, seg_creds=seg_creds) + + async def _refresh_payment_params(self, orchestrator_url: str) -> None: + from .http import _http_origin, post_json + + signer = await get_signer_info(self._signer_url or "", self._signer_headers) + if not signer.address: + raise PaymentError("Cannot refresh payment without signer address") + + url = f"{_http_origin(orchestrator_url)}/refresh-payment" + data = await post_json( + url, + { + "sender": signer.address, + "manifest_id": self._manifest_id, + }, + ) + payment_params = data.get("payment_params") + if not isinstance(payment_params, str) or not payment_params: + raise PaymentError( + f"RefreshPayment error: missing/invalid 'payment_params' in response (url={url})" + ) + self._payment_params = payment_params + refreshed_orchestrator_url = data.get("orchestrator") + self._orchestrator_url = ( + refreshed_orchestrator_url + if isinstance(refreshed_orchestrator_url, str) and refreshed_orchestrator_url.strip() + else orchestrator_url + ) class PaymentSession: @@ -204,7 +389,7 @@ def get_payment(self) -> GetPaymentResponse: return GetPaymentResponse(seg_creds=seg, payment="") def _payment_request() -> GetPaymentResponse: - from .orchestrator import _http_origin, post_json + from .http import _http_origin, post_json_sync as post_json base = _http_origin(self._signer_url) url = f"{base}/generate-live-payment" @@ -275,7 +460,7 @@ def send_payment(self) -> None: Generate a payment (via get_payment) and forward it to the orchestrator via POST {orch}/payment. """ - from .orchestrator import _extract_error_message, _http_origin + from .http import _extract_error_message, _http_origin p = self.get_payment() if not self._info.transcoder: diff --git a/src/livepeer_gateway/scope.py b/src/livepeer_gateway/scope.py index 8a498fe..d8bbeea 100644 --- a/src/livepeer_gateway/scope.py +++ b/src/livepeer_gateway/scope.py @@ -3,68 +3,46 @@ import logging from typing import Any, Optional, Sequence -from .capabilities import CapabilityId, build_capabilities -from .control import ControlConfig, ControlMode -from .errors import LivepeerGatewayError, NoOrchestratorAvailableError, OrchestratorRejection +from .errors import LivepeerGatewayError, NoRunnerAvailableError +from .http import post_json from .lv2v import LiveVideoToVideo, StartJobRequest -from .orchestrator import _http_origin, post_json -from .remote_signer import PaymentSession -from .selection import orchestrator_selector +from .selection import runner_selector from .token import parse_token +_SCOPE_RUNNER_APP = "live-video-to-video/scope" _LOG = logging.getLogger(__name__) -def start_scope( +async def start_scope( orch_url: Optional[Sequence[str] | str], req: StartJobRequest, *, - start_payments: bool = True, token: Optional[str] = None, signer_url: Optional[str] = None, signer_headers: Optional[dict[str, str]] = None, discovery_url: Optional[str] = None, discovery_headers: Optional[dict[str, str]] = None, - control_config: Optional[ControlConfig] = None, - use_tofu: bool = True, timeout: float = 5.0, ) -> LiveVideoToVideo: """ - Start a scope job. + Start a Scope job through a live runner. - Selects an orchestrator with Scope capability and calls - POST {info.transcoder}/scope with JSON body. - - If ``start_payments`` is true and the call happens within a running - asyncio event loop, a background task is automatically started to - send per-segment payments. Otherwise a warning is logged and - payments can be started later via ``job.start_payment_sender()``. + Scope is treated as a single-shot live runner app. The request body is sent + to a discovered ``live-video-to-video/scope`` runner and any paid runner + challenge is handled by the live-runner payment flow. Optional ``token`` can be provided as a base64-encoded JSON object. Token values take precedence over explicit keyword arguments. Explicit keyword arguments are used only for fields missing in the token. - Orchestrator selection/discovery precedence (highest -> lowest): - 1) token ``orchestrators`` value - 2) explicit ``orch_url`` list + Runner discovery precedence (highest -> lowest): + 1) token ``orchestrators`` value, converted by appending ``/discovery`` + 2) explicit ``orch_url`` value, converted by appending ``/discovery`` 3) token ``discovery`` value 4) explicit ``discovery_url`` argument 5) remote signer discovery endpoint derived from the resolved signer URL - ``timeout`` controls only the initial HTTP POST to - ``/scope`` after an orchestrator has been selected. - Discovery and ``GetOrchestrator`` calls use their own timeouts. - - ``use_tofu`` controls TLS mode for ``GetOrchestrator``: - - True: trust-on-first-use certificate pinning - - False: default gRPC/system CA roots - - ``control_config`` controls control-channel behavior. Use - ``ControlConfig(mode=ControlMode.DISABLED)`` to disable keepalives. - - ``model_id`` is ignored for now; internally this is hard-coded to "scope". """ - token_data: Optional[dict[str, Any]] = None if token is not None: token_data = parse_token(token) @@ -89,67 +67,63 @@ def start_scope( if resolved_discovery_headers is None: resolved_discovery_headers = discovery_headers - capabilities = build_capabilities(CapabilityId.LIVE_VIDEO_TO_VIDEO, "scope") - # Orchestrator discovery precedence after token-first field resolution: - # token orchestrators -> explicit orch_url -> token discovery -> - # explicit discovery_url -> signer_url - cursor = orchestrator_selector( - resolved_orch_url, + body = req.to_json() + result = await _select_scope_runner( + body=body, signer_url=resolved_signer_url, signer_headers=resolved_signer_headers, discovery_url=resolved_discovery_url, discovery_headers=resolved_discovery_headers, - capabilities=capabilities, - use_tofu=use_tofu, + orch_url=resolved_orch_url, + timeout=timeout, ) - start_rejections: list[OrchestratorRejection] = [] - while True: - try: - selected_url, info = cursor.next() - except NoOrchestratorAvailableError as e: - all_rejections = list(e.rejections) + start_rejections - if all_rejections: - raise NoOrchestratorAvailableError( - f"All orchestrators failed ({len(all_rejections)} tried)", - rejections=all_rejections, - ) from None - raise - - try: - session = PaymentSession( - resolved_signer_url, - info, - signer_headers=resolved_signer_headers, - type="lv2v", - capabilities=capabilities, - use_tofu=use_tofu, - ) - p = session.get_payment() - headers: dict[str, str] = { - "Livepeer-Payment": p.payment, - "Livepeer-Segment": p.seg_creds, - } - - base = _http_origin(info.transcoder) - url = f"{base}/scope" - payload = req.to_json() - payload.setdefault("model_id", "scope") - data = post_json(url, payload, headers=headers, timeout=timeout) - job = LiveVideoToVideo.from_json( - data, - signer_url=resolved_signer_url, - orchestrator_info=info, - payment_session=session, - ) - if not job.manifest_id: - raise LivepeerGatewayError("LiveVideoToVideo response missing manifest_id") - session.set_manifest_id(job.manifest_id) - return job - except LivepeerGatewayError as e: - _LOG.debug( - "start_scope candidate failed, trying fallback if available: %s (%s)", - selected_url, - str(e), - ) - start_rejections.append(OrchestratorRejection(url=selected_url, reason=str(e))) + data = result.data + if not _is_serverless_runner(result.runner): + app_url = data.get("app_url") + if not isinstance(app_url, str) or not app_url.strip(): + raise LivepeerGatewayError("Scope runner response missing app_url") + data = await post_json(f"{app_url.strip().rstrip('/')}/scope", body, timeout=timeout) + + job = LiveVideoToVideo.from_json( + data, + signer_url=resolved_signer_url, + payment_session=result.payment_session, + ) + if not job.manifest_id: + raise LivepeerGatewayError("Scope response missing manifest_id") + return job + + +async def _select_scope_runner( + *, + body: dict[str, Any], + signer_url: Optional[str], + signer_headers: Optional[dict[str, str]], + discovery_url: Optional[str], + discovery_headers: Optional[dict[str, str]], + orch_url: Optional[Sequence[str] | str], + timeout: float, +): + cursor = await runner_selector( + body=body, + signer_url=signer_url, + signer_headers=signer_headers, + orchestrators=orch_url, + discovery_url=discovery_url, + discovery_headers=discovery_headers, + app=_SCOPE_RUNNER_APP, + timeout=timeout, + ) + try: + return await cursor.next() + except NoRunnerAvailableError as e: + for rejection in e.rejections: + _LOG.info("scope runner rejected: %s: %s", rejection.url, rejection.reason) + raise + + +def _is_serverless_runner(runner: object) -> bool: + raw = getattr(runner, "raw", None) + version = raw.get("version") if isinstance(raw, dict) else None + return isinstance(version, str) and version.startswith("serverless") diff --git a/src/livepeer_gateway/selection.py b/src/livepeer_gateway/selection.py index d3ab01c..dd87d4c 100644 --- a/src/livepeer_gateway/selection.py +++ b/src/livepeer_gateway/selection.py @@ -2,12 +2,24 @@ import logging from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional, Sequence, Tuple +from typing import Any, Optional, Sequence, Tuple from . import lp_rpc_pb2 -from .errors import NoOrchestratorAvailableError, OrchestratorRejection +from .discovery import ( + FilterValue, + discover_orchestrator_runners, + discover_orchestrators, + discover_runners, +) +from .errors import ( + LivepeerGatewayError, + NoOrchestratorAvailableError, + NoRunnerAvailableError, + OrchestratorRejection, + RunnerRejection, +) +from .live_runner import LiveRunnerCallResult, LiveRunnerInstance, LiveRunnerSession, call_runner from .orch_info import get_orch_info -from .orchestrator import discover_orchestrators _LOG = logging.getLogger(__name__) @@ -135,3 +147,185 @@ def orchestrator_selector( capabilities=capabilities, use_tofu=use_tofu, ) + + +class RunnerSelectionCursor: + """ + Stateful selector that advances through live runners sequentially. + + Runner attempts are intentionally not parallelized: selecting a persistent + runner reserves capacity, and selecting a single-shot runner may perform + the caller's actual app operation. + """ + + def __init__( + self, + candidates: Sequence[LiveRunnerInstance], + *, + body: Optional[dict[str, Any]] = None, + method: str = "POST", + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + timeout: float = 5.0, + ) -> None: + self._candidates = list(candidates) + self._body = dict(body or {}) + self._method = method + self._signer_url = signer_url + self._signer_headers = signer_headers + self._timeout = timeout + self._next_index = 0 + self.rejections: list[RunnerRejection] = [] + + @property + def candidates(self) -> tuple[LiveRunnerInstance, ...]: + return tuple(self._candidates) + + async def next(self) -> LiveRunnerCallResult: + while self._next_index < len(self._candidates): + runner = self._candidates[self._next_index] + self._next_index += 1 + try: + kwargs: dict[str, Any] = { + "runner": runner, + "payload": self._body, + "method": self._method, + "timeout": self._timeout, + } + if self._signer_url is not None: + kwargs["signer_url"] = self._signer_url + if self._signer_headers is not None: + kwargs["signer_headers"] = self._signer_headers + result = await call_runner(**kwargs) + except Exception as e: + reason = str(e) + _LOG.debug( + "select_runner candidate failed: %s (%s)", + runner.url, + reason, + ) + self.rejections.append(RunnerRejection(url=runner.url, reason=reason)) + continue + + _LOG.debug("select_runner selected: %s", runner.url) + return result + + _LOG.debug( + "select_runner failed: all %d runners rejected", + len(self._candidates), + ) + raise NoRunnerAvailableError( + f"All runners failed ({len(self.rejections)} tried)", + rejections=list(self.rejections), + ) + + +async def runner_selector( + *, + body: Optional[dict[str, Any]] = None, + method: str = "POST", + orchestrators: Optional[Sequence[str] | str] = None, + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + discovery_url: Optional[str] = None, + discovery_headers: Optional[dict[str, str]] = None, + app: Optional[FilterValue] = None, + gpu: Optional[FilterValue] = None, + timeout: float = 5.0, +) -> RunnerSelectionCursor: + if orchestrators is not None: + entries = await discover_orchestrator_runners( + orchestrators, + app=app, + gpu=gpu, + ) + else: + entries = await discover_runners( + signer_url=signer_url, + signer_headers=signer_headers, + discovery_url=discovery_url, + discovery_headers=discovery_headers, + app=app, + gpu=gpu, + ) + + candidates = _runner_candidates_from_discovery(entries) + + if not candidates: + _LOG.debug("select_runner failed: empty runner list") + raise NoRunnerAvailableError("No runners available to select") + + return RunnerSelectionCursor( + candidates, + body=body, + method=method, + signer_url=signer_url, + signer_headers=signer_headers, + timeout=timeout, + ) + + +async def reserve_session( + *, + signer_url: Optional[str] = None, + signer_headers: Optional[dict[str, str]] = None, + discovery_url: Optional[str] = None, + discovery_headers: Optional[dict[str, str]] = None, + app: Optional[FilterValue] = None, + gpu: Optional[FilterValue] = None, + timeout: float = 5.0, +) -> LiveRunnerSession: + cursor = await runner_selector( + signer_url=signer_url, + signer_headers=signer_headers, + discovery_url=discovery_url, + discovery_headers=discovery_headers, + app=app, + gpu=gpu, + timeout=timeout, + ) + result = await cursor.next() + session_id = result.data.get("session_id") + app_url = result.data.get("app_url") + if not isinstance(session_id, str) or not session_id.strip(): + raise LivepeerGatewayError("runner session response missing session_id") + if not isinstance(app_url, str) or not app_url.strip(): + raise LivepeerGatewayError("runner session response missing app_url") + return LiveRunnerSession( + session_id=session_id.strip(), + app_url=app_url.strip(), + runner_url=result.runner_url, + runner=result.runner, + ) + + +def _runner_candidates_from_discovery(entries: Sequence[dict[str, Any]]) -> list[LiveRunnerInstance]: + candidates: list[LiveRunnerInstance] = [] + for entry in entries: + orchestrator_url = _string_value(entry.get("address")) + runners = entry.get("runners") + if not isinstance(runners, list): + continue + + for runner in runners: + if not isinstance(runner, dict): + continue + url = _string_value(runner.get("url")) + app = _string_value(runner.get("app")) + if not url or not app: + continue + candidates.append( + LiveRunnerInstance( + url=url, + app=app, + runner_id=_string_value(runner.get("runner_id")), + mode=_string_value(runner.get("mode")), + orchestrator_url=orchestrator_url, + raw=dict(runner), + ) + ) + return candidates + + +def _string_value(value: object) -> str: + return value.strip() if isinstance(value, str) else ""