diff --git a/.gitignore b/.gitignore
index 3641c6c5a..b38375f3b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -224,6 +224,12 @@ controlnet_test_*
demo/realtime-img2img/uploads/
.cgw.conf
+# Local-only git workflow tooling (cgw scripts, hooks, example β never committed)
+scripts/git/
+hooks/cc-block-dangerous-git.sh
+.githooks/
+cgw.conf.example
+
# Local Claude / session state (per-user, never committed)
.claude/
@@ -261,3 +267,6 @@ SESSION_LOG.md
# Profiling/audit CSV exports (Nsight summaries, kernel stats β generated artifacts)
audit_reports/
+
+# Quality harness run outputs (generated; goldens/ is committed, outputs/ is not)
+tests/quality/outputs/
diff --git a/_plans/2026-05-17_controlnet-ipc-emitter-fix.md b/_plans/2026-05-17_controlnet-ipc-emitter-fix.md
new file mode 100644
index 000000000..2a2791f39
--- /dev/null
+++ b/_plans/2026-05-17_controlnet-ipc-emitter-fix.md
@@ -0,0 +1,89 @@
+# ControlNet CUDA IPC β CUDA Graph Capture Conflict (session 2026-05-17)
+
+> **RESOLVED 2026-05-18** β Hypothesis A confirmed. Fix applied and committed. See `_plans/2026-05-18_controlnet-ipc-stream-capture-fix.md`.
+
+> Continuation of `_plans/2026-05-17_controlnet-zero-copy.md`. Emitter fixed so activation survives stream restart. New error class observed: TRT CN engine fails with `cudaErrorStreamCaptureInvalidated (901)` when IPC import runs inside the graph-capture window.
+
+## π‘ Session state (2026-05-17 end of session)
+
+- β
Patches 1-5 to `StreamDiffusionTD/td_manager.py` intact
+- β
Emitter patch applied to `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` (CN block after `use_cuda_ipc_input` + `cuda_ipc_control_shm_name` inside `td_settings`)
+- β
Activation marker confirmed: `CUDA IPC control ready (zero-copy GPU): shm=StreamDiffusionTD_512-512_control_ipc`
+- β
CN importer auto-detected `(512, 512, 4) uint8` β correct for TD canny TOP
+- β TRT CN engine forward fails β `cudaErrorStreamCaptureInvalidated (901)`
+- β Emitter patch and td_config.yaml NOT committed (waiting on error resolution)
+
+## Error (23:42:28-29)
+
+```
+[E] IExecutionContext::enqueueV3: Error Code 1: Myelin (Platform Cuda error)
+
+streamdiffusion.modules.controlnet_module - ERROR - controlnet forward failed:
+ CUDA ERROR: cudaErrorStreamCaptureInvalidated (901)
+ call_summary: cond_shape=(2, 77, 2048), img_shape=(2, 3, 512, 512), scale=0.6, is_sdxl=True, is_trt=True
+
+Traceback:
+ controlnet_module.py:488 _unet_hook: down_samples, mid_sample = cn(...)
+ controlnet_engine.py:135 __call__: outputs = self.engine.infer(...)
+ utilities.py:1028 infer: self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream.ptr))
+RuntimeError: CUDA ERROR: cudaErrorStreamCaptureInvalidated (901)
+
+TouchDesignerManager - ERROR - Error updating parameters:
+ CUDA error: operation would make the legacy stream depend on a capturing blocking stream
+```
+
+## Root-cause hypothesis
+
+The TRT CN engine captures a CUDA graph on its own stream (`cudaStreamBeginCapture` β¦ `cudaStreamEndCapture` at `utilities.py:1028`). Our `_get_control_frame_cuda_ipc()` calls:
+
+```python
+gpu_frame = self._cuda_ipc_control_importer.get_frame(stream=torch.cuda.current_stream())
+```
+
+`get_frame()` issues `cudaStreamWaitEvent` against the IPC slot's event. This touches the stream **during or adjacent to TRT's capture window**, which:
+
+- Either drags the legacy/null stream into a dependency with the capturing stream (hypothesis A)
+- Or records an event on the IPC stream that the capturing stream can't reference (hypothesis B)
+- Or invalidates the capture from a previous call, and `cudaStreamEndCapture` returns 901 (hypothesis C)
+
+The **input** IPC importer uses the same code path and never errors β suggesting timing is the differentiator. Input is fetched before any TRT capture starts; CN frame is fetched after `update_control_image()` and inside the hook that triggers the CN engine capture.
+
+## Files to read at next session start
+
+| Order | File | Location | Why |
+|---|---|---|---|
+| 1 | `cuda_ipc_importer.py` | `src/streamdiffusion/_compat/cuda_ipc/` | `get_frame()` stream-wait implementation; any `cudaStreamIsCapturing` guard |
+| 2 | `utilities.py` | `src/streamdiffusion/acceleration/tensorrt/` | Lines 1000-1035: `infer()` capture begin/end, which stream |
+| 3 | `controlnet_engine.py` | `src/streamdiffusion/acceleration/tensorrt/runtime_engines/` | Lines 120-140: when capture begins relative to input setup |
+| 4 | `controlnet_module.py` | `src/streamdiffusion/modules/` | Lines 470-500: `_unet_hook` β timing of CN forward vs control-image update |
+| 5 | `td_manager.py` | `StreamDiffusionTD/` | Lines 875-921: `_process_controlnet_frame` β call ordering |
+
+## Candidate fixes (verify hypothesis before choosing)
+
+- **a) Dedicated import stream** β pass a non-`current_stream()` argument to `get_frame()`, one that is never captured. Sync to engine stream once after. Low risk if importer signature supports it.
+- **b) Capture-mode guard** β before `cudaStreamWaitEvent`, check `cudaStreamIsCapturing(stream)`. If capturing, use `cudaEventWaitExternal` flag or wait on a side-channel stream and pass result through an explicit event.
+- **c) Reorder fetch before capture window** β pull CN frame at the top of the per-frame loop (before the diffusion step), cache the tensor, hand it to the orchestrator. The `process_tensor` branch already accepts a pre-fetched CUDA tensor.
+- **d) Disable CUDA graph capture for CN engine only** β `CUDALINK_USE_GRAPHS=0` or per-engine flag in engine config. Temporary workaround; measure perf cost.
+
+Options (a) and (c) are the cleanest structural fixes.
+
+## Quick-revert if CN is needed immediately
+
+Set `Usecudaipccontrolnet` TD COMP par to `False` (if par exists on the COMP), or comment out the two emitter lines:
+
+```python
+# yaml_content += f'use_cuda_ipc_controlnet: {str(use_ipc_controlnet).lower()}\n'
+# yaml_content += f" cuda_ipc_control_shm_name: '{stream_name}_control_ipc'\n"
+```
+
+Note: reverting to legacy path also requires re-adding a legacy CN SHM Out TOP in the .toe (was removed when the CUDA-Link Sender was added).
+
+## Commit (deferred)
+
+After the stream-capture conflict is resolved and live verification passes:
+
+```powershell
+./scripts/git/commit_enhanced.sh --no-venv "feat: emit ControlNet CUDA IPC activation keys in stream-start YAML"
+```
+
+Branch: `feat/cuda-ipc-output`, PR target: `SDTD_031_dev`.
diff --git a/_plans/2026-05-17_controlnet-zero-copy.md b/_plans/2026-05-17_controlnet-zero-copy.md
new file mode 100644
index 000000000..7f9f76d6a
--- /dev/null
+++ b/_plans/2026-05-17_controlnet-zero-copy.md
@@ -0,0 +1,340 @@
+# True zero-copy GPU input for ControlNet β close the last per-frame CPU detour
+
+> **Hand-off from `2026-05-17_zero-copy-gpu-input.md` (PR'd as `02911e5`, both Phase 1 + Phase 2 stream-sync hardening landed cleanly).** Main input is fully zero-copy. ControlNet input still takes the CPU detour every frame it's active. Same recipe applies β different config wiring, different format target.
+
+## Context
+
+After commit `02911e5` (zero-copy GPU input v1 + Phase 2 stream-sync hardening), the main img2img input is true zero-copy end-to-end. But when ControlNet is enabled, every frame still pays the old CPU cost on a parallel SHM channel:
+
+| Direction | Channel | Transport | Payload handling |
+|---|---|---|---|
+| SD β TD (output) | `cuda_ipc_shm_name` | CUDAIPCExporter | **GPU end-to-end** (zero-copy) |
+| TD β SD (main input) | `cuda_ipc_input_shm_name` | CUDAIPCImporter (Phase 2 GPU-fence) | **GPU end-to-end** (zero-copy) |
+| TD β SD (controlnet) | `control_mem_name` = ` -cn` | **Legacy SharedMemory mmap** | **HWC uint8 β CPU float-cast β PIL roundtrip inside orchestrator** |
+| TD β SD (ipadapter) | `ipadapter_mem_name` = ` -ip` | Legacy SHM mmap | Out of scope (OSC-triggered, preprocessor PIL detour β see Out of scope) |
+| SD β TD (CN preprocessed) | `-cn-processed` | Legacy SHM mmap | Out of scope (separate refactor) |
+
+When ControlNet is active, `_process_controlnet_frame` runs **every frame** in the streaming loop (`td_manager.py:818`). For multi-CN setups the wrapper-side loop fires N updates per frame.
+
+### What's already in place (verified by MCP search + Reads)
+
+- **Tensor fast-path in orchestrator exists**: `PreprocessingOrchestrator.prepare_control_image` at `src/streamdiffusion/preprocessing/preprocessing_orchestrator.py:268-281` detects `isinstance(control_image, torch.Tensor)` and routes through `_process_tensor_input` (lines 636-656). For preprocessor-less inputs (passthrough), the path is: `unsqueeze(0) if dim==3 β .to(device=self.device, dtype=self.dtype)`. **GPU-only, no CPU roundtrip.** It expects **NCHW float [0,1]**.
+- **Single-index entry from TD**: `ControlNetModule.update_control_image_efficient` at `src/streamdiffusion/modules/controlnet_module.py:152-203` calls `process_sync(image, preprocessors, scales, W, H, index)` per CN. `td_manager.py:841-842` loops over CN indices passing the same `control_frame`. Building one GPU tensor and reusing it across that loop is correct.
+- **Importer API is identical to input direction**: `CUDAIPCImporter(shm_name=..., debug=False)` + `is_ready()` + `get_frame(stream=...)`. We already use it for input at `td_manager.py:677`. Same constructor, same lifecycle. Phase 2 stream-sync (`stream=torch.cuda.current_stream()`) applies the same way.
+- **TD-side resize already happens**: `td_manager.py:824-832` reads `(height, width, 3)` from config β TD must already downsample CN frames to model resolution before SHM write. The new CUDA-Link Sender comp inherits the same constraint β no new size negotiation needed.
+
+### What the current code does (the waste)
+
+`StreamDiffusionTD/td_manager.py:818-842` (`_process_controlnet_frame`, hot path):
+
+```python
+control_frame = np.ndarray((height, width, 3), dtype=np.uint8, buffer=self.control_memory.buf)
+if control_frame.dtype == np.uint8:
+ control_frame = control_frame.astype(np.float32) / 255.0 # β WASTE: CPU rescale
+# ... per-CN loop:
+for cn_idx in range(num_controlnets):
+ self.wrapper.update_control_image(cn_idx, control_frame) # numpy HWC float [0,1]
+```
+
+Then inside the orchestrator (current path for numpy input), `_convert_to_tensor` does:
+
+```python
+control_image = (control_image * 255).astype(np.uint8) # β WASTE: round-trip up
+control_image = Image.fromarray(control_image) # β WASTE: PIL allocation
+control_tensor = self._cached_transform(control_image).unsqueeze(0) # β WASTE: ToTensor() rescales + H2D
+control_tensor.to(device=self.device, dtype=self.dtype)
+```
+
+Net effect per CN per frame: `mmap read β CPU rescale β PIL allocation β ToTensor (rescale + H2D) β device cast`. With N CNs the orchestrator deduplicates input via `_last_input_frame is` identity check (`preprocessing_orchestrator.py:298-304`) so the H2D itself runs once, but the **per-frame CPU work** is unavoidable on the current path.
+
+## Approach
+
+Mirror the main-input plan exactly:
+- New SHM channel `cuda_ipc_control_shm_name` (default `_control_ipc`)
+- New gating flag `use_cuda_ipc_controlnet` (default `False` β fully backward compatible)
+- New importer field `self._cuda_ipc_control_importer`, lazy-initialized on first frame
+- New reader `_get_control_frame_cuda_ipc()` that returns **NCHW float32 [0,1]** GPU tensor (NOT [-1,1] like the main input β different format because the orchestrator's tensor passthrough path doesn't re-normalize)
+- `_process_controlnet_frame` tries the IPC path when gated, falls back to legacy SHM otherwise
+
+The TD-side .toe edits (adding a new CUDA-Link Sender comp publishing to `_control_ipc`) are manual work for the user β the plan calls out the requirement but doesn't deliver TD network changes (the .toe is a binary file, not Scripts/).
+
+### The GPU transform (single chained op)
+
+```python
+# gpu_frame: HWC float32 BGRA on GPU, range [0,1] from Importer
+# target: NCHW float32 RGB [0,1] on GPU (orchestrator handles dtype cast)
+nchw = (
+ gpu_frame[..., [2, 1, 0]] # HWC float32 RGB [0,1] (drop alpha + BGRβRGB)
+ .permute(2, 0, 1) # CHW float32 RGB [0,1]
+ .unsqueeze(0) # NCHW (N=1)
+ .contiguous() # contiguous strides
+)
+```
+
+**Key difference from main input**: no `mul(2).sub_(1)` scale to [-1,1], no `.to(dtype=...)`. The orchestrator's `_process_tensor_input` does the dtype cast itself at line 647/656 (`return ...to(device=self.device, dtype=self.dtype)`). Sending float32 keeps the contract simple β let the orchestrator decide when to downcast.
+
+### Capability gating risk (called out for verification)
+
+If any CN preprocessor lacks `process_tensor` (`preprocessing_orchestrator.py:641`), the fallback at line 664-665 forces `.cpu()` and a PIL roundtrip β **the zero-copy gain evaporates for that CN**. Most pure-passthrough CNs and several built-in preprocessors implement `process_tensor`; some image-analysis ones may not. Verification step (below) includes checking which preprocessors are configured.
+
+## Code changes
+
+### Patch 1 β `StreamDiffusionTD/td_manager.py:60-65` β config + importer field
+
+```python
+self.use_cuda_ipc_output = self.config.get('use_cuda_ipc_output', False)
+self.use_cuda_ipc_input = self.config.get('use_cuda_ipc_input', False)
+self.use_cuda_ipc_controlnet = self.config.get('use_cuda_ipc_controlnet', False) # NEW
+self.cuda_ipc_input_shm_name = self.td_settings.get('cuda_ipc_input_shm_name')
+self.cuda_ipc_control_shm_name = self.td_settings.get('cuda_ipc_control_shm_name') # NEW
+self._cuda_ipc_importer = None # lazy-init on first frame
+self._cuda_ipc_control_importer = None # NEW: lazy-init on first CN frame
+```
+
+### Patch 2 β `StreamDiffusionTD/td_manager.py:392-410` β cleanup
+
+Mirror the existing `_cuda_ipc_importer` cleanup block:
+
+```python
+if self._cuda_ipc_control_importer is not None:
+ try:
+ self._cuda_ipc_control_importer.cleanup()
+ except Exception:
+ pass
+ self._cuda_ipc_control_importer = None
+```
+
+### Patch 3 β `StreamDiffusionTD/td_manager.py:705` β new probe helper
+
+After existing `_probe_ipc_input_shm` (lines 705-717), add a sibling:
+
+```python
+def _probe_ipc_control_shm(self) -> bool:
+ """Return True iff TD has created the CN IPC SharedMemory."""
+ if not self.cuda_ipc_control_shm_name:
+ return False
+ try:
+ from multiprocessing.shared_memory import SharedMemory
+ shm = SharedMemory(name=self.cuda_ipc_control_shm_name)
+ shm.close()
+ return True
+ except (FileNotFoundError, Exception):
+ return False
+```
+
+### Patch 4 β `StreamDiffusionTD/td_manager.py` β new method `_get_control_frame_cuda_ipc`
+
+Insert as a peer to `_get_input_frame_cuda_ipc` (currently around line 667-703):
+
+```python
+def _get_control_frame_cuda_ipc(self) -> Optional["torch.Tensor"]:
+ """Read one CN frame from TD's CUDA IPC channel and return a GPU torch.Tensor
+ matching the orchestrator's tensor passthrough contract: NCHW float32 RGB [0,1] on CUDA.
+ Returns None if importer not ready (caller falls back to legacy SHM path).
+ """
+ if self._cuda_ipc_control_importer is None:
+ if not self._probe_ipc_control_shm():
+ return None
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter
+ try:
+ self._cuda_ipc_control_importer = CUDAIPCImporter(
+ shm_name=self.cuda_ipc_control_shm_name,
+ debug=False,
+ )
+ except Exception as e:
+ logger.warning(f"CUDAIPCImporter (control) init failed: {e}")
+ self._cuda_ipc_control_importer = None
+ return None
+ if not self._cuda_ipc_control_importer.is_ready():
+ self._cuda_ipc_control_importer = None
+ return None
+ logger.info(f"CUDA IPC control ready (zero-copy GPU): shm={self.cuda_ipc_control_shm_name}")
+
+ gpu_frame = self._cuda_ipc_control_importer.get_frame(stream=torch.cuda.current_stream())
+ if gpu_frame is None:
+ return None
+
+ # Orchestrator's _process_tensor_input handles dtype cast β we just normalize layout and channels.
+ return (
+ gpu_frame[..., [2, 1, 0]] # HWC float32 RGB [0,1] (drop alpha, BGRβRGB)
+ .permute(2, 0, 1) # CHW float32 RGB [0,1]
+ .unsqueeze(0) # NCHW (N=1)
+ .contiguous()
+ )
+```
+
+### Patch 5 β `StreamDiffusionTD/td_manager.py:818-842` β branch in `_process_controlnet_frame`
+
+Modify the body to try the IPC path when gated, fall back to legacy SHM otherwise:
+
+```python
+def _process_controlnet_frame(self) -> None:
+ """Process ControlNet frame data (per-frame updates)"""
+ if not self.config.get('use_controlnet', False):
+ return
+
+ control_frame = None
+
+ # Fast path: CUDA IPC (zero-copy GPU tensor) if gated and TD emitter is up
+ if self.use_cuda_ipc_controlnet:
+ control_frame = self._get_control_frame_cuda_ipc()
+
+ # Legacy fallback: SHM mmap β numpy HWC β CPU float-cast
+ if control_frame is None:
+ if not self.control_memory:
+ return
+ try:
+ width = self.config['width']
+ height = self.config['height']
+ control_frame = np.ndarray((height, width, 3), dtype=np.uint8, buffer=self.control_memory.buf)
+ control_frame = control_frame.astype(np.float32) / 255.0
+ except Exception as e:
+ logger.error(f"Error reading ControlNet SHM: {e}")
+ return
+
+ try:
+ # Update ControlNet image for all active CNs (each runs its own preprocessor)
+ cn_module = getattr(self.wrapper.stream, '_controlnet_module', None) if hasattr(self.wrapper, 'stream') else None
+ num_controlnets = len(cn_module.controlnets) if cn_module is not None else 1
+ for cn_idx in range(num_controlnets):
+ self.wrapper.update_control_image(cn_idx, control_frame)
+
+ # Send the processed image back to TD (unchanged β out of scope for this plan)
+ try:
+ if (hasattr(self.wrapper, 'stream') and
+ hasattr(self.wrapper.stream, '_controlnet_module') and
+ self.wrapper.stream._controlnet_module is not None):
+ controlnet_module = self.wrapper.stream._controlnet_module
+ if (hasattr(controlnet_module, 'controlnet_images') and
+ len(controlnet_module.controlnet_images) > 0 and
+ controlnet_module.controlnet_images[0] is not None):
+ processed_tensor = controlnet_module.controlnet_images[0]
+ self._send_processed_controlnet_frame(processed_tensor)
+ except Exception as processed_error:
+ logger.debug(f"Could not extract processed ControlNet image: {processed_error}")
+
+ except Exception as e:
+ logger.error(f"Error processing ControlNet frame: {e}")
+```
+
+### Patch 6 β `StreamDiffusionTD/td_config.yaml` β sample config (documentation)
+
+Add commented-out reference values for the user to copy when they wire up the TD-side Sender:
+
+```yaml
+# CUDA IPC: ControlNet (set to true once TD has a CUDA-Link Sender publishing to this name)
+use_cuda_ipc_controlnet: false
+td_settings:
+ cuda_ipc_input_shm_name: StreamDiffusionTD_512-512_input_ipc
+ # cuda_ipc_control_shm_name: StreamDiffusionTD_512-512_control_ipc # NEW (commented-out by default)
+```
+
+### TD-side manual work (called out, NOT in this PR)
+
+For `use_cuda_ipc_controlnet=true` to actually work, the user must:
+
+1. In the .toe network, add a second `CUDA-Link` Sender comp parallel to the existing input one
+2. Wire it to the CN preview TOP (the same source that currently feeds the `-cn` SHM)
+3. Set its shm name to match `cuda_ipc_control_shm_name` (default `StreamDiffusionTD_512-512_control_ipc`)
+4. Flip `use_cuda_ipc_controlnet: true` in `td_config.yaml`
+
+If the user skips these steps and flips the flag, the SD side will log "CUDAIPCImporter (control) init failed" and gracefully fall back to legacy SHM β no breakage.
+
+### What we explicitly do NOT touch
+
+- **IPAdapter path** β OSC-triggered (`td_manager.py:882` early-exit on `ipadapter_update_requested`), and `IPAdapterEmbeddingPreprocessor._process_tensor_core` (`processors/ipadapter_embedding.py:55-59`) forces a PIL roundtrip anyway. Transport-only zero-copy would buy ~nothing. Deferred until either cadence changes OR preprocessor is refactored.
+- **CN preprocessed return path** (`_send_processed_controlnet_frame` + `-cn-processed` SHM) β separate SDβTD direction, separate refactor, separate PR.
+- **`PreprocessingOrchestrator._process_tensor_input` PIL fallback** at line 664-665 β if a preprocessor lacks `process_tensor`, that branch defeats the win. Out of scope: changing the orchestrator. Mitigation: documented as a verification step below.
+- **Preprocessor `process_tensor` implementations** β adding GPU paths to preprocessors that lack one. Per-preprocessor refactor, separate work.
+- **`_compat/cuda_ipc/`** β no changes; the existing API is sufficient.
+- **`.toe` network edits** β manual user work, documented above but not part of this PR.
+
+## Verification
+
+After applying all six patches in the running SD venv (no rebuild needed β pure Python; Scripts/ edits live-reload per `[[project_scripts_dir_purpose]]`):
+
+### 1. Smoke test β import contract
+
+```powershell
+venv\Scripts\python -c "from StreamDiffusionTD.td_manager import TouchDesignerManager; print('OK')"
+```
+
+Must print `OK`. Any `SyntaxError`/`NameError`/`ImportError` means a patch is wrong β stop and re-read.
+
+### 2. Backward-compatibility test β legacy SHM still works
+
+With `use_cuda_ipc_controlnet: false` (default) and CN enabled, relaunch the .toe. Expect:
+- Legacy `-cn` SHM still serves frames (no regression vs baseline)
+- No new log markers
+- Same FPS / quality as before this PR
+
+### 3. CUDA IPC opt-in test (requires TD-side Sender setup)
+
+After user adds a CUDA-Link Sender comp publishing to `StreamDiffusionTD_512-512_control_ipc` and flips `use_cuda_ipc_controlnet: true`:
+- SD log shows `CUDA IPC control ready (zero-copy GPU): shm=StreamDiffusionTD_512-512_control_ipc`
+- ControlNet preview in TD shows correct colors (BGRβRGB shuffle on GPU instead of CPU)
+- No `_get_control_frame_cuda_ipc:` errors in log
+- If Sender comp missing or wrong name: clean fallback to legacy SHM with one-time `init failed` warning, no crash
+
+### 4. Preprocessor capability check
+
+Before claiming the win, verify the active CN preprocessors actually have `process_tensor`:
+
+```powershell
+venv\Scripts\python -c "
+from streamdiffusion.preprocessing.processors import REGISTRY
+for name, cls in REGISTRY.items():
+ has = hasattr(cls, 'process_tensor')
+ print(f'{name}: process_tensor={has}')
+"
+```
+
+For any preprocessor used in the user's config WITHOUT `process_tensor`, the zero-copy gain evaporates (CPU PIL fallback at `preprocessing_orchestrator.py:664-665`). That's a follow-up item, not a blocker.
+
+### 5. Performance verification
+
+With CN active and `use_cuda_ipc_controlnet=true`, compare against the legacy-SHM baseline:
+- **Steady-state `total_time`**: expected ~0.3-0.8ms lower (CN's CPU rescale + PIL roundtrip + H2D eliminated)
+- **`total_time` jitter**: should tighten when CN active (one fewer per-frame CPU detour)
+- **CN preview latency**: subjective TD-side check β visible reduction in N-CN-multi setups
+
+Optional `nsys` check: should show zero `cudaMemcpyAsync HtoD` calls between consecutive `cudaGraphLaunch`es originating from the CN code path.
+
+## Commit
+
+Per `[[feedback_pr_branch_convention]]`, branch stays at `feat/cuda-ipc-output` (current head: `02911e5`), PR target `SDTD_031_dev`.
+
+```powershell
+./scripts/git/commit_enhanced.sh --no-venv `
+ "feat: zero-copy GPU input for ControlNet via CUDA IPC (transport parity with main input)"
+```
+
+Then save the plan as a project file per `[[feedback_save_plans_as_project_files]]`:
+- Copy this file to `StreamDiffusion/_plans/2026-05-17_controlnet-zero-copy.md`
+
+Note: `StreamDiffusionTD/td_manager.py` is gitignored (lives in companion `dotsimulate/StreamDiffusionTD` repo). Only the plan file and `src/`-side touches (if any) are committable in this repo. The `td_manager.py` patches go through the companion repo per `[[project_td_release_flow]]`.
+
+## Critical files
+
+| File | Lines | Change |
+|---|---|---|
+| `StreamDiffusionTD/td_manager.py` | 60-65 | Patch 1 β add `use_cuda_ipc_controlnet`, `cuda_ipc_control_shm_name`, `_cuda_ipc_control_importer` |
+| `StreamDiffusionTD/td_manager.py` | 392-410 | Patch 2 β cleanup `_cuda_ipc_control_importer` |
+| `StreamDiffusionTD/td_manager.py` | 705-717 | Patch 3 β add `_probe_ipc_control_shm` helper |
+| `StreamDiffusionTD/td_manager.py` | (new method near 667-703) | Patch 4 β add `_get_control_frame_cuda_ipc` |
+| `StreamDiffusionTD/td_manager.py` | 818-842 | Patch 5 β branch `_process_controlnet_frame` on IPC vs SHM |
+| `StreamDiffusionTD/td_config.yaml` | (config schema) | Patch 6 β sample config keys (commented) |
+
+Reused unchanged (verified):
+
+- `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:903` β `get_frame(stream=...)` already supports GPU-side fence
+- `src/streamdiffusion/preprocessing/preprocessing_orchestrator.py:268-281, 636-656` β tensor fast-path already exists
+- `src/streamdiffusion/modules/controlnet_module.py:152-203` β `update_control_image_efficient` accepts tensors via per-index passthrough
+- `src/streamdiffusion/wrapper.py:2390-2399` β `update_control_image` already forwards tensors unchanged
+
+## Out of scope (documented for future work)
+
+- **IPAdapter zero-copy** β see Approach. Two reasons it's deferred: (a) OSC-triggered cadence vs per-frame, so the optimization saves a few CPU ms per *user trigger* not per frame; (b) `IPAdapterEmbeddingPreprocessor._process_tensor_core` actively defeats the tensor fast path by converting back to PIL for the CLIP image processor. Worth doing only after the preprocessor is refactored to keep tensors on GPU.
+- **CN return path zero-copy** (`_send_processed_controlnet_frame` β `-cn-processed` SHM) β separate SDβTD direction; mirror the existing main-output `CUDAIPCExporter` pattern.
+- **`PreprocessingOrchestrator._process_tensor_input` PIL fallback** β when a preprocessor lacks `process_tensor`, line 664-665 falls back to `.cpu()`. Not blocking this PR, but the per-preprocessor `process_tensor` work is a real follow-up.
+- **TD-side .toe Sender comp wiring** β manual user work (the .toe is a binary file, edited in TouchDesigner not git).
diff --git a/_plans/2026-05-17_cuda-ipc-input-direction.md b/_plans/2026-05-17_cuda-ipc-input-direction.md
new file mode 100644
index 000000000..2433f1f73
--- /dev/null
+++ b/_plans/2026-05-17_cuda-ipc-input-direction.md
@@ -0,0 +1,361 @@
+# Wire CUDA IPC input direction (TD β SD), fix FileNotFoundError on legacy CPU SHM
+
+## Context
+
+After committing the output-direction IPC (commit `4c2a742` on `feat/cuda-ipc-output`), launching the .toe now crashes SD with:
+
+```
+FileNotFoundError: [WinError 2] The system cannot find the file specified: 'StreamDiffusionTD_512-512'
+ at td_manager.py:331 β self.input_memory = shared_memory.SharedMemory(name=self.input_mem_name)
+```
+
+The TD textport log shows TD has switched its **input** to a cuda-link **Sender**:
+
+- `[CUDAIPCExtension:Sender] Created new SharedMemory: StreamDiffusionTD_512-512_input_ipc (433 bytes)`
+- TD now writes input frames as zero-copy GPU IPC (3 slots, 4 MB each, 512Γ512 **float32 4ch**, ~313β624 Β΅s/frame)
+- TD no longer creates the legacy CPU SharedMemory `StreamDiffusionTD_512-512`, so SD's open call fails
+
+The deferred input direction is now required. SD must read input via `CUDAIPCImporter` (vendored at `_compat/cuda_ipc/cuda_ipc_importer.py`) when the toggle is on, and skip the legacy CPU SHM open.
+
+User's authoritative YAML already has the SHM-name reserved:
+
+```yaml
+td_settings:
+ input_mem_name: 'StreamDiffusionTD_512-512' # legacy (unused when IPC input on)
+ cuda_ipc_input_shm_name: 'StreamDiffusionTD_512-512_input_ipc' # TD Sender writes here
+```
+
+Missing: a top-level `use_cuda_ipc_input: true` toggle and the SD-side Importer wiring.
+
+## Target end state
+
+- New top-level YAML toggle `use_cuda_ipc_input: true|false` (parallel to `use_cuda_ipc_output`).
+- When `use_cuda_ipc_input: true`, SD skips the legacy CPU SHM input open and reads frames via a lazy-initialized `CUDAIPCImporter` bound to `td_settings.cuda_ipc_input_shm_name`.
+- First-connect noise (the `traceback.print_exc()` at `cuda_ipc_importer.py:810` when SHM is missing) is sidestepped by a cheap pre-probe via `multiprocessing.shared_memory.SharedMemory(name=...)` β Importer is only constructed once the probe confirms TD's Sender has created the SHM header.
+- `_get_input_frame` returns a numpy HWC uint8 RGB array compatible with the existing streaming-loop contract (L513β514 then does `astype(float32) / 255.0`). TD's wire is HWC float32 BGRA β convert on GPU (drop alpha, swap BβR, scale to [0,255] uint8) β `.cpu().numpy()`. Keeps the streaming-loop contract unchanged; defers full zero-copy GPU pipeline to a future round.
+- Crash is gone. With `use_cuda_ipc_input: true`, SD reads TD's GPU IPC frames and feeds them through the existing img2img path.
+- A second commit lands on `feat/cuda-ipc-output` adding only the SD-side input wiring (TD-side files remain gitignored).
+
+## Execution (in order)
+
+All paths relative to `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/`.
+
+### Step 1 β Add `use_cuda_ipc_input` toggle to YAML
+
+**`StreamDiffusionTD/td_config.yaml`** (gitignored β user must edit, or we set it):
+Add at top level, next to `use_cuda_ipc_output`:
+
+```yaml
+use_cuda_ipc_input: true
+```
+
+**`configs/td_config.yaml.example`** (tracked):
+Add the same key with `false` default, next to `use_cuda_ipc_output: false`:
+
+```yaml
+# CUDA IPC zero-copy GPU-to-GPU output (SDβTD via cuda-link)
+use_cuda_ipc_output: false
+cuda_ipc_shm_name: 'StreamDiffusionTD_512-512_output_ipc'
+cuda_ipc_num_slots: 3
+output_type: 'np'
+
+# CUDA IPC zero-copy GPU-to-GPU input (TDβSD via cuda-link)
+# When true, SD reads input frames from td_settings.cuda_ipc_input_shm_name
+# instead of the legacy CPU SharedMemory at td_settings.input_mem_name.
+use_cuda_ipc_input: false
+```
+
+### Step 2 β Wire toggle + Importer state in `td_manager.py.__init__`
+
+`StreamDiffusionTD/td_manager.py` (gitignored, runtime fix). Near the existing `self.use_cuda_ipc_output` at L62:
+
+```python
+self.use_cuda_ipc_output = self.config.get('use_cuda_ipc_output', False)
+self.use_cuda_ipc_input = self.config.get('use_cuda_ipc_input', False)
+self.cuda_ipc_input_shm_name = self.td_settings.get('cuda_ipc_input_shm_name')
+self._cuda_ipc_importer = None # lazy-init on first frame
+```
+
+### Step 3 β Skip legacy CPU SHM open when IPC input is on
+
+`_initialize_memory_interfaces` around L331 β wrap the input SHM open in the same guard pattern used for the output side:
+
+```python
+# Input memory (from TouchDesigner) β skip when CUDA IPC input is active
+if not self.use_cuda_ipc_input:
+ self.input_memory = shared_memory.SharedMemory(name=self.input_mem_name)
+ logger.debug(f"Connected to input SharedMemory: {self.input_mem_name}")
+else:
+ self.input_memory = None
+ logger.debug(f"CUDA IPC input active; legacy SharedMemory skipped (will read {self.cuda_ipc_input_shm_name})")
+```
+
+This single guard fixes the `FileNotFoundError` crash.
+
+### Step 4 β Add IPC-aware fast-path in `_get_input_frame`
+
+Replace the body of `_get_input_frame` (L628β644) with a branch-on-toggle:
+
+```python
+def _get_input_frame(self) -> Optional[np.ndarray]:
+ """Get input frame from TouchDesigner (platform-specific)"""
+ try:
+ if self.use_cuda_ipc_input:
+ return self._get_input_frame_cuda_ipc()
+ if self.is_macos and self.syphon_handler:
+ return self.syphon_handler.capture_input_frame()
+ if self.input_memory:
+ width = self.config['width']
+ height = self.config['height']
+ frame = np.ndarray((height, width, 3), dtype=np.uint8, buffer=self.input_memory.buf)
+ return frame.copy()
+ return None
+ except Exception as e:
+ if self.debug_mode:
+ logger.debug(f"_get_input_frame: {e}")
+ return None
+```
+
+Add the new IPC helper alongside (right after `_get_input_frame`):
+
+```python
+def _get_input_frame_cuda_ipc(self) -> Optional[np.ndarray]:
+ """Read one frame from TD's CUDAIPCExporter (Sender). Returns HWC uint8 RGB,
+ matching the legacy CPU SHM contract so the streaming loop is unchanged."""
+ # Lazy-construct the Importer once TD's Sender SHM exists.
+ if self._cuda_ipc_importer is None:
+ if not self._probe_ipc_input_shm():
+ return None # TD Sender not active yet β retry next tick
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter
+ try:
+ self._cuda_ipc_importer = CUDAIPCImporter(
+ shm_name=self.cuda_ipc_input_shm_name,
+ debug=False,
+ )
+ except Exception as e:
+ logger.warning(f"CUDAIPCImporter init failed: {e}")
+ self._cuda_ipc_importer = None
+ return None
+ if not self._cuda_ipc_importer.is_ready():
+ # init silently failed (e.g. magic mismatch); drop and retry next tick
+ self._cuda_ipc_importer = None
+ return None
+ logger.info(f"CUDA IPC input ready: shm={self.cuda_ipc_input_shm_name}")
+
+ # TD wire: HWC float32 BGRA on GPU. Convert to HWC uint8 RGB to match
+ # streaming-loop contract (L513β514 expects uint8 β float32 / 255.0).
+ gpu_frame = self._cuda_ipc_importer.get_frame() # zero-copy torch.Tensor on GPU
+ if gpu_frame is None:
+ return None
+ rgb = gpu_frame[..., [2, 1, 0]].contiguous() # BGRA β RGB (drop alpha)
+ rgb_u8 = (rgb.clamp(0, 1) * 255).to(torch.uint8) # float [0,1] β uint8 [0,255]
+ return rgb_u8.cpu().numpy() # D2H to match existing contract
+```
+
+### Step 5 β Add cheap SHM-existence probe
+
+Add as a sibling method (suppresses the noisy `traceback.print_exc()` at `cuda_ipc_importer.py:810` by only constructing the Importer when the SHM segment exists):
+
+```python
+def _probe_ipc_input_shm(self) -> bool:
+ """Return True iff TD has created the input IPC SharedMemory."""
+ if not self.cuda_ipc_input_shm_name:
+ return False
+ try:
+ from multiprocessing.shared_memory import SharedMemory
+ shm = SharedMemory(name=self.cuda_ipc_input_shm_name)
+ shm.close()
+ return True
+ except FileNotFoundError:
+ return False
+ except Exception:
+ return False
+```
+
+This is the proven solution to the first-connect noise that broke the abandoned input-perspective plan.
+
+### Step 6 β Importer cleanup
+
+In `_cleanup_memory_interfaces` (around L387, after the existing `cleanup_cuda_ipc` for the Exporter):
+
+```python
+if self._cuda_ipc_importer is not None:
+ try:
+ self._cuda_ipc_importer.cleanup()
+ except Exception:
+ pass
+ self._cuda_ipc_importer = None
+```
+
+### Step 7 β Runtime verification
+
+1. Set `use_cuda_ipc_input: true` (and `use_cuda_ipc_output: true`) in `StreamDiffusionTD/td_config.yaml`.
+2. Launch .toe. Confirm no `FileNotFoundError` on startup.
+3. Watch for `CUDA IPC input ready: shm=StreamDiffusionTD_512-512_input_ipc` log line on first frame.
+4. Confirm SD output appears in TD's Receiver COMP (round-trip works: TD Sender β SD Importer β wrapper β SD Exporter β TD Receiver).
+5. Toggle `use_cuda_ipc_input: false`, restart, confirm legacy CPU SHM path still works (regression β assumes TD COMP is switched back to non-Sender mode). If TD is still in Sender mode the legacy path will fail to open SHM β that's expected; document as "TD-side mode must match SD-side toggle."
+
+### Step 8 β Commit on `feat/cuda-ipc-output`
+
+```bash
+git add configs/td_config.yaml.example
+./scripts/git/commit_enhanced.sh --no-venv --skip-lint \
+ "feat: add CUDA IPC input direction via cuda-link (TD->SD zero-copy GPU transport)"
+```
+
+Only `configs/td_config.yaml.example` is tracked; the runtime files (`td_manager.py`, `td_config.yaml`) are gitignored. The commit is small by design β the IPC import wiring lives in the .tox binary (synced into the gitignored `StreamDiffusionTD/` dir per [[project_scripts_dir_purpose]]).
+
+## Critical files & key references
+
+**Tracked (committed)**:
+
+- `configs/td_config.yaml.example` β add `use_cuda_ipc_input: false` toggle alongside existing `use_cuda_ipc_output`
+
+**Gitignored (runtime fix only)**:
+
+- `StreamDiffusionTD/td_config.yaml` β add `use_cuda_ipc_input: true` for the user's session
+- `StreamDiffusionTD/td_manager.py`:
+ - L62-area: add `use_cuda_ipc_input`, `cuda_ipc_input_shm_name`, `_cuda_ipc_importer` slots
+ - L331: guard legacy CPU SHM input open with `if not self.use_cuda_ipc_input:`
+ - L387-area (`_cleanup_memory_interfaces`): tear down Importer
+ - L628-644 (`_get_input_frame`): IPC fast-path branch
+ - Add new methods: `_get_input_frame_cuda_ipc`, `_probe_ipc_input_shm`
+
+**Reused verbatim (no edits)**:
+
+- `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py`:
+ - ctor L495-515 (`shm_name`, `shape=None`, `dtype=None`, `debug`, `timeout_ms=5000`, `device=0` β auto-detects from SHM metadata)
+ - `_initialize()` L783-811 (non-blocking; noisy `traceback.print_exc()` at L810 β sidestepped by Step 5 probe)
+ - `get_frame()` L903 (returns zero-copy torch.Tensor on GPU, HWC, dtype from metadata)
+ - `cleanup()` L1237-1266 (idempotent)
+ - `is_ready()` L1272 (post-init state check)
+- TD's wire format (from textport log): **HWC float32 4ch BGRA** at 512Γ512, 3 slots
+
+## MCP verification (against current index)
+
+Verified via `mcp__code-search__search_code` before finalizing this plan (per [[feedback_verify_plan_with_mcp]]):
+
+- `CUDAIPCImporter.__init__` β confirmed at `cuda_ipc_importer.py:495-563`. Signature: `(shm_name="cudalink_output_ipc", shape=None, dtype=None, debug=False, timeout_ms=5000.0, device=0)`. Auto-detects shape/dtype from SHM metadata, so Step 4 passing only `shm_name` + `debug` is sufficient.
+- `CUDAIPCImporter._initialize` β confirmed at `cuda_ipc_importer.py:783-811`. Non-blocking, single-shot. `traceback.print_exc()` at L810 is the noise the Step 5 probe sidesteps.
+- `CUDAIPCImporter._open_and_validate_shm` β confirmed at `cuda_ipc_importer.py:628-685`. Catches `FileNotFoundError` and re-raises after logging; the probe avoids triggering this path entirely.
+- `CUDAIPCImporter.get_frame` β confirmed at `cuda_ipc_importer.py:903-988`. Returns zero-copy `torch.Tensor` on GPU with shape/dtype from SHM metadata (matches Step 4's GPU-side BGRAβRGBβuint8 conversion).
+- `CUDAIPCImporter.cleanup` β confirmed at `cuda_ipc_importer.py:1237-1266`. Idempotent (Step 6 cleanup is safe to call unconditionally).
+- `CUDAIPCImporter.is_ready` β confirmed at `cuda_ipc_importer.py:1272`. Post-init state check (used in Step 4 to drop a half-initialized importer).
+- `_get_input_frame` β confirmed at `StreamDiffusionTD/td_manager.py:628-644` (authoritative target). Returns HWC uint8 RGB numpy from `self.input_memory.buf` β Step 4's replacement preserves this contract.
+
+**Prior art note** β `Scripts/streamdiffusionTD__Text__td_manager__td.py:717-738` contains an existing `_try_construct_ipc_importer` from earlier abandoned input-direction work, using an `_ipc_importer_cls` indirection + `is_ready()` guard. The Step 4/5 implementation borrows the same probe-then-construct + `is_ready` pattern but does not depend on Scripts/ (which may be stale per [[project_scripts_dir_purpose]]).
+
+## Round 2 β Emitter fix (TD COMP overwrites yaml on launch)
+
+### Symptom
+
+After the SD-side wiring committed at `72dc7cc`, launching the .toe still crashes with the same `FileNotFoundError: ... 'StreamDiffusionTD_512-512'` at `td_manager.py:336` (the guarded SHM open). The guard is in place β but `self.use_cuda_ipc_input` evaluates to `False`.
+
+### Root cause
+
+The TD .tox COMP **regenerates `StreamDiffusionTD/td_config.yaml` on every launch** via its YAML emitter at:
+
+```
+D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py:3740-3774
+```
+
+(Scripts/ is at the parent dir of the repo, per [[project_scripts_dir_purpose]] β edits sync into the running .tox immediately.) The emitter currently writes these IPC keys:
+
+- L3754 `use_cuda_ipc_output: {use_ipc}` β driven by `Usecudaipcoutput` par, default `True`
+- L3755 `cuda_ipc_shm_name: '{stream_name}_output_ipc'`
+- L3756 `cuda_ipc_num_slots: 3`
+- L3768 `cuda_ipc_input_shm_name: '{stream_name}_input_ipc'`
+
+Missing: **no `use_cuda_ipc_input` write**. The yaml I edited manually had the line stripped on .toe launch. `self.config.get('use_cuda_ipc_input', False)` β `False` β guard inactive β crash.
+
+The TD COMP is **hardwired to Sender mode** for the input direction (textport log confirms `[CUDAIPCExtension:Sender] Created new SharedMemory: StreamDiffusionTD_512-512_input_ipc`). So the emitter should mirror that and write `use_cuda_ipc_input: true`.
+
+### Fix
+
+One symmetric block added to the emitter, mirroring the existing output-direction pattern at L3747-3757. Insert immediately after L3757 (`output_type: 'np'`), before L3759 (`# TouchDesigner specific settings`):
+
+```python
+# Emit CUDA IPC INPUT setting β enable by default; override via Usecudaipcinput par if present
+use_ipc_input = True
+try:
+ use_ipc_input = bool(self.ownerComp.par.Usecudaipcinput.eval())
+except AttributeError:
+ pass
+yaml_content += '\n# CUDA IPC zero-copy GPU-to-GPU input (TDβSD via cuda-link)\n'
+yaml_content += f'use_cuda_ipc_input: {str(use_ipc_input).lower()}\n'
+```
+
+Defaults to `True` (matches the .tox's hardwired Sender). If the user later adds a `Usecudaipcinput` parameter to the COMP, it'll be respected β symmetric with `Usecudaipcoutput`.
+
+### Verification
+
+1. Launch .toe. Confirm `td_config.yaml` is rewritten and now contains `use_cuda_ipc_input: true` near the existing `use_cuda_ipc_output: true`.
+2. Confirm no `FileNotFoundError` on startup.
+3. Watch for `CUDA IPC input ready: shm=StreamDiffusionTD_512-512_input_ipc` log on first frame.
+4. Confirm round-trip: TD Sender β SD Importer β wrapper β SD Exporter β TD Receiver.
+
+### Out of scope (Round 2)
+
+- Independent TD parameter `Usecudaipcinput` β defer until user wants to toggle input direction separately from the hardwired Sender. Current default-`True` falls through if the parameter is absent, so adding it later is non-breaking.
+- Committing the emitter change β Scripts/ is outside the git repo (parent dir per [[project_scripts_dir_purpose]]), so this is a runtime sync only. No git artifact.
+
+### Critical files (Round 2)
+
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py:3757-3759` β insert the 6-line emit block
+
+## Out of scope (Round 1)
+
+- True zero-copy GPU input path (skip the `.cpu().numpy()` D2H by feeding `wrapper.img2img` a GPU tensor directly). Requires touching the streaming loop's uint8βfloat conversion at L513-514 and the wrapper's img2img preprocessing β defer to a follow-up.
+- Symphonous handling when TD toggles Sender mode at runtime (current behavior: SD detects via probe on next tick, lazily reinitializes; teardown on TD side currently triggers `SlotState.SHUTDOWN` in `_try_acquire`, importer auto-cleans).
+- Pushing the branch / opening a PR β user-driven per [[feedback_pr_branch_convention]].
+
+---
+
+## Next session β log review handoff
+
+**Status going in**: Both code changes are applied (SD-side committed at `72dc7cc` on branch `feat/cuda-ipc-output`; TD-side emitter edit at `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` post-L3757 is **uncommitted** since Scripts/ lives outside the git repo). The user is expected to relaunch the .toe between sessions, then return with two fresh logs.
+
+### What to ask the user for
+
+1. **SD cmd log** β full stdout from launching the .toe (the same channel that previously showed the `FileNotFoundError: ... 'StreamDiffusionTD_512-512'` crash).
+2. **TD textport log** β full TD console output (the `[CUDAIPCExtension:Sender]` lines and any new `[CUDAIPCExtension:Receiver]` lines for the output direction).
+
+### Success criteria β what the logs MUST show
+
+**SD cmd log (must see)**:
+
+- β
`td_config.yaml` printout near startup includes `use_cuda_ipc_input: true` (between `output_type: 'np'` and `# TouchDesigner specific settings`). If this line is missing, the emitter edit didn't sync into the .tox β check whether TD was restarted (the COMP re-emits on init, not on file change).
+- β
NO `FileNotFoundError: ... 'StreamDiffusionTD_512-512'` traceback.
+- β
One-shot `CUDA IPC input ready: shm=StreamDiffusionTD_512-512_input_ipc` log line on first frame received.
+- β
Streaming loop continues β frame timing logs, no repeated `_get_input_frame` exceptions.
+
+**TD textport log (must see)**:
+
+- β
Sender side already known-good: `Created new SharedMemory: StreamDiffusionTD_512-512_input_ipc` + `FIRST FRAME: ...` + steady-state `Frame N: slot X, ... GPU memcpy=...us`.
+- β
Receiver side activates: look for `[CUDAIPCExtension:Receiver] ...` lines confirming the output direction round-trips. (Round 1 already wired SDβTD output, so this should match commit `4c2a742`'s behavior.)
+
+### Common failure modes to triage (in order of likelihood)
+
+1. **Emitter didn't sync** β yaml printout still missing `use_cuda_ipc_input`. Confirm Scripts/ edit landed (`grep -n use_cuda_ipc_input D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py`). If present in Scripts/ but not in regenerated yaml, the .toe may need a full close+reopen (not just re-init) to pick up Scripts/ changes β verify per [[project_scripts_dir_purpose]].
+2. **Importer init fails** β see `CUDAIPCImporter init failed: ...` warn in SD log. Likely cause: SHM magic mismatch or shape mismatch between vendored `cuda_ipc_importer.py` (`PROTOCOL_MAGIC = 0x43495044` "CIPD") and the TD-side cuda-link v1.4.1 emitter. Verify the magic with `grep -n PROTOCOL_MAGIC D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py` and confirm TD Sender wrote the same.
+3. **Probe never succeeds** β SD silently returns None every frame, never logs "CUDA IPC input ready". Means `_probe_ipc_input_shm` keeps hitting `FileNotFoundError`. Could be a Windows SHM name-collision quirk (the legacy CPU SHM `StreamDiffusionTD_512-512` and IPC SHM `StreamDiffusionTD_512-512_input_ipc` are different names β should not collide, but worth verifying TD Sender textport log shows the `_input_ipc` SHM actually created).
+4. **BGRAβRGB conversion artifacts** β frames flow but output looks wrong (color-swapped or alpha leak). The Round 1 conversion at `td_manager.py:687-688` (`gpu_frame[..., [2, 1, 0]].contiguous()` + `clamp(0,1) * 255 β uint8`) assumes TD writes float32 BGRA in [0,1]. Textport log confirms `512x512 float32 4ch` β but if values are out of [0,1] range (TD's sRGB pipeline can produce >1 in HDR), the `clamp(0, 1)` would crush highlights. Worth visual-comparison test if user reports off colors.
+
+### Next agent's first move
+
+```
+1. Read the SD log + TD log the user pastes
+2. Grep yaml printout for `use_cuda_ipc_input` to confirm emitter fix landed
+3. Grep for `CUDA IPC input ready` + `FileNotFoundError`
+4. If happy path β commit the emitter change is NOT possible (Scripts/ is outside the repo); instead acknowledge round-trip success and ask user whether to merge `feat/cuda-ipc-output` β `SDTD_031_dev` per [[feedback_pr_branch_convention]]
+5. If failure β triage per "Common failure modes" above before touching code
+```
+
+### Reference: what's committed vs uncommitted
+
+| File | Tracked? | Committed? | Notes |
+|---|---|---|---|
+| `configs/td_config.yaml.example` | yes | `72dc7cc` | adds `use_cuda_ipc_input: false` default |
+| `StreamDiffusionTD/td_config.yaml` | gitignored | no | regenerated by TD on launch; manual edits clobbered |
+| `StreamDiffusionTD/td_manager.py` | gitignored | no | runtime fix: guard, helpers, cleanup, probe |
+| `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` | **outside repo** | n/a | Round 2 emitter fix (1 block post-L3757); no git artifact possible |
diff --git a/_plans/2026-05-17_diagnose-cudaruntimetypes-import.md b/_plans/2026-05-17_diagnose-cudaruntimetypes-import.md
new file mode 100644
index 000000000..6c311e36d
--- /dev/null
+++ b/_plans/2026-05-17_diagnose-cudaruntimetypes-import.md
@@ -0,0 +1,184 @@
+# Diagnose & fix `_get_input_frame: No module named 'CUDARuntimeTypes'`
+
+> **Hand-off from `cozy-snacking-wilkinson.md` Round 2.** Both Round 1 (SD-side) and Round 2 (TD emitter) landed. The `FileNotFoundError` is gone. A new failure replaced it: the SD-side Importer construction throws `ModuleNotFoundError: No module named 'CUDARuntimeTypes'` on every frame, so no input ever reaches the wrapper.
+
+## Context
+
+**What the user observed** (SD cmd log + TD textport, 2026-05-17 21:13):
+
+- β
Round 1 guard worked: `CUDA IPC input active; legacy SharedMemory skipped (will read StreamDiffusionTD_512-512_input_ipc)` β no `FileNotFoundError`.
+- β Round 2 success criterion missing: no `CUDA IPC input ready: shm=...` log line β Importer was never successfully constructed.
+- β 13 Γ `TouchDesignerManager - DEBUG - _get_input_frame: No module named 'CUDARuntimeTypes'` at startup, all at 21:13:03 (~77ms/attempt), then SD log goes silent while TD Sender continues writing frames at full rate (TD log shows Frame 97 β Frame 2716 timing).
+
+The error message format `_get_input_frame: ` matches the OUTER catch at `td_manager.py:_get_input_frame` (debug-only):
+
+```python
+except Exception as e:
+ if self.debug_mode:
+ logger.debug(f"_get_input_frame: {e}")
+ return None
+```
+
+That catch sees the exception raised by the `from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter` statement inside `_get_input_frame_cuda_ipc` (Step 4 of cozy-snacking) β it sits OUTSIDE the inner try/except, so it propagates up.
+
+## Root cause (verified)
+
+The vendored `_compat/cuda_ipc/` package has **two files with broken absolute imports** that worked in the upstream `cuda_link` package context (where `cuda_link` is pip-installed) AND in TouchDesigner's flat namespace (where `CUDARuntimeTypes` is a top-level module), but fail in SD's venv where **neither** is available:
+
+### File 1 β `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py:24-54`
+
+```python
+try:
+ from cuda_link.cuda_runtime_types import ( # β NOT installed in SD venv
+ CUDAError, CUDAEvent_t, CUDAGraph_t, CUDAGraphExec_t, CUDAGraphNode_t,
+ CUDAStream_t, cudaIpcEventHandle_t, cudaIpcMemHandle_t,
+ cudaMemcpy3DParms, cudaPointerAttributes,
+ )
+except ImportError:
+ from CUDARuntimeTypes import ( # β TD-only top-level module β also missing
+ CUDAError, CUDAEvent_t, ...
+ )
+
+try:
+ from cuda_link.cuda_graphs import CUDAGraphsMixin # β same problem
+except ImportError:
+ from CUDAGraphs import CUDAGraphsMixin # β same problem
+```
+
+### File 2 β `src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py:18-41`
+
+Same try/except pattern: tries `cuda_link.cuda_runtime_types`, falls back to top-level `CUDARuntimeTypes`. Both fail in SD.
+
+### Import chain that triggers the failure
+
+1. `_get_input_frame_cuda_ipc` runs `from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter`.
+2. `_compat/cuda_ipc/__init__.py:11` does `from .cuda_ipc_wrapper import CUDARuntimeAPI, get_cuda_runtime`.
+3. `cuda_ipc_wrapper.py:24-49` hits the broken pair β raises `ModuleNotFoundError: No module named 'CUDARuntimeTypes'`.
+4. (Even if the wrapper import were skipped, `cuda_ipc_importer.py:108` does `from .cuda_ipc_wrapper import CUDARuntimeAPI, get_cuda_runtime` directly β same failure.)
+5. ImportError propagates β outer catch in `_get_input_frame` β DEBUG log β returns None β loop retries forever.
+
+### Why the sibling vendored files already exist
+
+`_compat/cuda_ipc/cuda_runtime_types.py` is present and exports every symbol the wrapper/graphs files need β verified via grep:
+
+- Classes: `cudaPos`, `cudaMemcpy3DParms`, `cudaIpcMemHandle_t`, `cudaIpcEventHandle_t`, `cudaPointerAttributes`, `CUDAError`, `cudaPitchedPtr`, `cudaExtent`
+- Aliases: `CUDAEvent_t`, `CUDAStream_t`, `CUDAGraph_t`, `CUDAGraphExec_t`, `CUDAGraphNode_t`, `CUDART_GRAPHS_MIN_VERSION`
+
+And `_compat/cuda_ipc/cuda_graphs.py` is present as a sibling for `CUDAGraphsMixin`.
+
+The sibling-relative import pattern is **already established in this package**:
+
+- `cuda_ipc_importer.py:108-109` β `from .cuda_ipc_wrapper import ...` + `from .cuda_runtime_types import ...`
+- `cuda_ipc_exporter.py:61-62` β same pattern
+
+So `cuda_ipc_wrapper.py` and `cuda_graphs.py` are the **only two stragglers** still using the broken `cuda_link.X` / `CUDARuntimeTypes` pattern.
+
+### Why cozy-snacking missed this
+
+That plan's "MCP verification" section read file contents (ctor signatures, line ranges). It did not execute the import chain. `from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter` was never actually attempted in SD's venv before the commit landed β the static-content check looked clean.
+
+> Process refinement: future plans relying on a `from import ` line should verify the chain by running `python -c "from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter"` in SD's venv during Phase 3 review β not just by reading line ranges. (Memory update candidate for `[[feedback_verify_plan_with_mcp]]`.)
+
+## Fix β replace broken imports with relative ones
+
+Match the convention already used by `cuda_ipc_importer.py:108-109` and `cuda_ipc_exporter.py:61-62`.
+
+### Patch 1 β `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py`
+
+Replace **lines 24-54** with:
+
+```python
+from .cuda_runtime_types import ( # noqa: E402
+ CUDAError,
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaIpcEventHandle_t,
+ cudaIpcMemHandle_t,
+ cudaMemcpy3DParms,
+ cudaPointerAttributes,
+)
+from .cuda_graphs import CUDAGraphsMixin # noqa: E402
+```
+
+### Patch 2 β `src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py`
+
+Replace **lines 18-41** with:
+
+```python
+from .cuda_runtime_types import ( # noqa: E402
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaExtent,
+ cudaMemcpy3DParms,
+ cudaPitchedPtr,
+ cudaPos,
+)
+```
+
+**No symbol changes.** All names already exist in `_compat/cuda_ipc/cuda_runtime_types.py` (verified above). The diff is `-26 +13` lines total.
+
+### Why drop the try/except entirely (Option A) vs. add a third fallback (Option B)
+
+- **Option A (recommended)**: hard-replace with relative imports. The `_compat/cuda_ipc/` directory is a sealed in-tree copy used only from SD's venv β its sibling files are the canonical source of these symbols here. Matches what `cuda_ipc_importer.py` / `cuda_ipc_exporter.py` already do.
+- **Option B**: keep the try/except chain and add a third `from .cuda_runtime_types import ...` fallback. Preserves byte-similarity to upstream `cuda_link`, useful only if someone ever drops the upstream pip package into SD's venv. Adds 8 lines of dead defensive code for a path nobody uses today.
+
+Recommendation: **Option A**. If the user later wants to install `cuda_link` as a real pip dep, the `_compat` copy can be deleted entirely at that point β the relative-import version doesn't need to coexist with the pip version.
+
+## Verification
+
+After applying both patches in the running SD venv (no rebuild needed β pure Python):
+
+1. **Smoke-test the import chain** (Bash/PowerShell):
+
+ ```powershell
+ python -c "from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter; print('OK', CUDAIPCImporter)"
+ ```
+
+ Must print `OK `. If it prints any traceback, the patch is wrong β stop and re-read the broken file.
+
+2. **Relaunch the .toe.** Inspect SD cmd log for the Round 2 success criteria (from cozy-snacking lines 326-330):
+ - β
`td_config.yaml` printout contains `use_cuda_ipc_input: true`.
+ - β
NO `FileNotFoundError`.
+ - β
NO `_get_input_frame: No module named 'CUDARuntimeTypes'`.
+ - β
One-shot `CUDA IPC input ready: shm=StreamDiffusionTD_512-512_input_ipc` on the first frame.
+ - β
TD textport `[CUDAIPCExtension:Receiver]` lines continue to confirm the output direction round-trips (this was already working in Round 1 commit `4c2a742`).
+
+3. **Round-trip visual check**: TD's Receiver COMP should show the SD-processed frames moving (not a frozen first frame). If the BGRAβRGB conversion at `td_manager.py:687-688` is off, colors will be swapped β that's a separate Round-3 fix, not part of this diagnosis.
+
+## Commit
+
+Both files are tracked (`src/streamdiffusion/_compat/cuda_ipc/`), so this lands as a normal commit on `feat/cuda-ipc-output`, on top of `72dc7cc`.
+
+```powershell
+./scripts/git/commit_enhanced.sh --no-venv `
+ "fix: use relative imports in vendored _compat/cuda_ipc (CUDARuntimeTypes missing in SD venv)"
+```
+
+(Per `[[feedback_pr_branch_convention]]`, branch stays at `feat/cuda-ipc-output`; PR target is `SDTD_031_dev`.)
+
+## Critical files (this diagnosis only)
+
+| File | Lines | Change |
+|---|---|---|
+| `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py` | 24-54 | replace 2Γ try/except with 2Γ relative import (`-26 +13`) |
+| `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py` | 18-41 | replace 1Γ try/except with 1Γ relative import (`-23 +10`) |
+
+Reused verbatim (no edits):
+
+- `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_runtime_types.py` β already exports all required symbols
+- `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/__init__.py` β import order is fine
+- `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py` β already uses `.cuda_ipc_wrapper` / `.cuda_runtime_types` correctly
+- `StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_exporter.py` β same
+
+## Out of scope
+
+- BGRAβRGB conversion correctness (`td_manager.py:687-688` `clamp(0,1)*255 β uint8`). Round-3 if user reports off colors.
+- True zero-copy GPU input (skip the `.cpu().numpy()` D2H). Deferred per cozy-snacking Round 1 "Out of scope".
+- Installing `cuda_link` as a real pip package and removing the vendored copy. Larger refactor; not needed for the fix.
+- Saving this plan into `StreamDiffusion/_plans/2026-05-17_diagnose-cudaruntimetypes-import.md` per `[[feedback_save_plans_as_project_files]]` β will copy on exit from plan mode (plan-mode editor only permits the assigned file in `~/.claude/plans/`).
diff --git a/_plans/2026-05-17_redo-cuda-ipc-output-direction.md b/_plans/2026-05-17_redo-cuda-ipc-output-direction.md
new file mode 100644
index 000000000..e030446ed
--- /dev/null
+++ b/_plans/2026-05-17_redo-cuda-ipc-output-direction.md
@@ -0,0 +1,269 @@
+# Redo CUDA IPC integration β rename SD-side to cuda-link vocabulary, wire output direction
+
+## Context
+
+Phase 1 (restore) is complete. The working tree is back to a clean post-Quantization state with `_compat/cuda_ipc/` and `_compat/td_exporter/` re-vendored as pristine copies from `F:\RD_PROJECTS\COMPONENTS\cuda-link` (commit `92989fc`, version `1.4.1`). No IPC glue remains in `wrapper.py`, `config.py`, `td_config.yaml.example`, `StreamDiffusionTD/td_manager.py`, or `StreamDiffusionTD/td_config.yaml`. The abandoned plan that mapped cuda-link SenderβSD "input" / ReceiverβSD "output" is archived at `_plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md`.
+
+Phase 2 (this plan) redoes the integration with cuda-link's vocabulary kept verbatim, renaming SD-side names to match. The previous attempt's input-direction (Importer) first-connect noise at `cuda_ipc_importer.py:808` is sidestepped for this round by scoping to the output direction only (SD's Exporter), which was the proven 16-25 FPS path in abandoned Phase 2.1.
+
+User decisions captured this session:
+- **Scope**: Rename + output IPC only. Input direction (Importer) deferred to a follow-up.
+- **Naming**: `td_export_shm_name` / `td_import_shm_name` (TD's perspective β mirrors what user types into TD's `Ipcmemname` param).
+- **Migration**: Hard rename, no backwards-compat shim. Users update their deployed `td_config.yaml` manually.
+- **SHM name format (TD-side, already in place)**: `parent.SDTD.par.Streamoutname + '_input_ipc'` for the SHM SD's Exporter writes into (TD reads as input); `parent.SDTD.par.Streamoutname + '_output_ipc'` for the SHM TD writes into (SD's Importer would read, out of scope this round). The TD StreamDiffusionExt computes these β SD just receives the final strings verbatim.
+
+## Target end state
+
+- Two new YAML keys: `td_settings.td_export_shm_name` (TD writes, SD reads β out-of-scope wiring) and `td_settings.td_import_shm_name` (SD writes, TD reads β CUDA IPC wired this round). The old `input_mem_name` / `output_mem_name` keys are gone from both `td_config.yaml` and `td_config.yaml.example`.
+- `TouchDesignerManager.__init__` takes `td_export_shm_name` / `td_import_shm_name` positional params. All internal SHM-name references renamed.
+- `StreamDiffusionWrapper` gains an opt-in CUDA IPC fast-path in `postprocess_image` (or its caller) for the SDβTD direction. Falls through to legacy CPU SHM behavior when IPC is disabled.
+- A new `use_cuda_ipc` YAML flag controls the IPC fast-path (default `false` for safe rollout).
+- `td_main.py` drops the `_{int(time.time())}` uniquifier on the output name β `Streamoutname`-based naming is already per-COMP unique.
+- StreamDiffusionExt YAML emitter inside the .toe binary writes the two new keys with the IPC-suffix values (user edits inside TouchDesigner β out of repo).
+- Smoke test passes: `from streamdiffusion.wrapper import StreamDiffusionWrapper` (assuming pre-existing `controlnet_aux` is installed, unrelated). End-to-end TD test: open `.toe`, configure `Streamoutname`, see frames flow at β₯16 FPS with `use_cuda_ipc: true` in `td_config.yaml`.
+
+## Execution (in order)
+
+All paths relative to `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/` unless absolute.
+
+### Step 1 β Rename SD-side YAML keys (mechanical)
+
+**`StreamDiffusionTD/td_config.yaml`** (current L73-74) and **`configs/td_config.yaml.example`** (current L95-96): replace
+
+```yaml
+ input_mem_name: 'StreamDiffusionTD_512-512'
+ output_mem_name: 'StreamDiffusionTD_512-512_out'
+```
+
+with
+
+```yaml
+ # cuda-link CUDA IPC shared-memory names (auto-emitted by TD StreamDiffusionExt)
+ # td_export_shm_name: TD writes here (Sender/ExportBuffer); SD's CUDAIPCImporter reads. Not wired this round.
+ # td_import_shm_name: SD writes here (CUDAIPCExporter); TD reads (Receiver/ImportBuffer).
+ td_export_shm_name: 'StreamDiffusionTD_512-512_output_ipc'
+ td_import_shm_name: 'StreamDiffusionTD_512-512_input_ipc'
+
+ # CUDA IPC output toggle (false = legacy CPU multiprocessing.shared_memory)
+ use_cuda_ipc: false
+```
+
+### Step 2 β Rename `TouchDesignerManager` ctor + internal refs
+
+Edit `StreamDiffusionTD/td_manager.py` (authoritative β copied directly from `.toe` binary into this dir per [[project_scripts_dir_purpose]]; gitignored). Scripts/ is the OLD stale sync dir β do **not** touch.
+
+- **L40**: signature change
+ - Before: `def __init__(self, config, input_mem_name: str, output_mem_name: str, debug_mode=False, osc_reporter=None)`
+ - After: `def __init__(self, config, td_export_shm_name: str, td_import_shm_name: str, debug_mode=False, osc_reporter=None)`
+- **L41-42**: `self.input_mem_name = input_mem_name` β `self.td_export_shm_name = td_export_shm_name`; same for `td_import_shm_name`.
+- **L98-100**: derived names β replace `self.input_mem_name` with `self.td_export_shm_name` (the suffixed `-cn`/`-cn-processed`/`-ip` names tag onto whatever-data-flows-into-SD, which is TD's export side).
+- **L314-322** (macOS Syphon block): `sender_name=self.td_import_shm_name, input_name=self.td_export_shm_name` (Syphon perspective is unaffected by the rename β it's just two names; updating identifiers).
+- **L330-346**: CPU SHM connect/create β `self.td_export_shm_name` replaces `self.input_mem_name`; `self.td_import_shm_name` replaces `self.output_mem_name`.
+- **L358**: `control_processed_mem_name = self.td_import_shm_name + '-cn-processed'`.
+- **L534**: `_send_output_frame(output_image)` call site β no change, only the buffer name's identifier changes.
+- **L643-740** (`_send_output_frame`): inspect for any `self.output_mem_name` references β rename to `self.td_import_shm_name`.
+- **L673** (`buffer=self.output_memory.buf`): no rename needed (attribute name `output_memory` is a Python attr, not a SHM identifier β keep as-is; it's the CPU SHM handle that may be `None` when IPC active).
+
+Verification: `grep -nE "(input_mem_name|output_mem_name)" StreamDiffusionTD/td_manager.py` returns 0 hits.
+
+### Step 3 β Drop timestamp uniquifier in `td_main.py`
+
+Edit `StreamDiffusionTD/td_main.py` L399-415:
+
+```python
+# Before
+input_mem = td_settings.get('input_mem_name', 'input_stream')
+base_output_name = td_settings.get('output_mem_name', 'sd_to_td')
+output_mem = f"{base_output_name}_{int(time.time())}"
+# ...
+self.manager = TouchDesignerManager(yaml_config, input_mem, output_mem, ...)
+
+# After
+td_export_shm_name = td_settings['td_export_shm_name'] # required, no fallback per hard-rename
+td_import_shm_name = td_settings['td_import_shm_name']
+# ...
+self.manager = TouchDesignerManager(
+ yaml_config,
+ td_export_shm_name,
+ td_import_shm_name,
+ debug_mode=debug_mode,
+ osc_reporter=osc_reporter,
+)
+```
+
+- **L450**: log line β `print(f"\033[38;5;80mMemory: \033[37m{td_export_shm_name} <- TD | SD -> {td_import_shm_name}\033[0m")` (direction arrows make the new naming readable).
+- **L462**: `self.osc_reporter.send_output_name(self.manager.td_import_shm_name)` (the OSC reporter announces the SHM name where TD should read).
+
+The timestamp uniquifier is dropped because `Streamoutname` (TD-side, e.g. `StreamDiffusionTD_512-512`) is already per-COMP unique. If a user has two COMPs at the same resolution they need different `Streamoutname` values β that's a TD-side configuration concern, not a Python-side uniqueness trick.
+
+### Step 4 β Wire SDβTD CUDA IPC output direction
+
+#### 4a β `src/streamdiffusion/wrapper.py`
+
+Add ctor kwargs (near existing ones around L82-160):
+
+```python
+use_cuda_ipc: bool = False,
+cuda_ipc_shm_name: str | None = None,
+cuda_ipc_num_slots: int = 2,
+```
+
+Add instance slots in `__init__` (alongside existing output-type init around L312):
+
+```python
+self.use_cuda_ipc = use_cuda_ipc
+self._cuda_ipc_shm_name = cuda_ipc_shm_name
+self._cuda_ipc_num_slots = cuda_ipc_num_slots
+self._cuda_ipc_exporter = None # lazy-init on first frame
+```
+
+Add a fast-path inside `postprocess_image` (currently at `wrapper.py:894-948`), early-exit when IPC is active. The function signature stays the same; the new branch happens before the existing `output_type == "pil"|"pt"|"np"|"latent"` dispatch:
+
+```python
+def postprocess_image(self, image_tensor, output_type="pil"):
+ # CUDA IPC fast-path: zero-copy BGRA export to TD via _compat.cuda_ipc.
+ # Skips D2H, CPU repack, and CPU SHM write. Returns None to signal "frame
+ # consumed by IPC" so the caller's CPU SHM write path is skipped.
+ if self.use_cuda_ipc and self._cuda_ipc_shm_name:
+ bgra = self._ipc_pack_rgba(image_tensor) # HWC uint8 BGRA on GPU
+ exporter = self._lazy_init_ipc_exporter(bgra.shape[0], bgra.shape[1])
+ exporter.export_frame(bgra.data_ptr(), bgra.numel())
+ return None
+
+ # ... existing dispatch unchanged ...
+```
+
+Add helpers:
+
+```python
+def _ipc_pack_rgba(self, image_tensor):
+ # Convert pipeline output to HWC uint8 BGRA on GPU. cuda-link expects BGRA
+ # per CUDAIPCExporter docstring (cuda_ipc_exporter.py:11-22).
+ # image_tensor is NCHW float [0,1] for SDXL pipelines β see uses at
+ # wrapper.py:787, 812, 853.
+ if image_tensor.dim() == 4:
+ image_tensor = image_tensor[0] # CHW
+ x = (image_tensor.clamp(0, 1) * 255).to(torch.uint8) # CHW uint8
+ rgb = x.permute(1, 2, 0).contiguous() # HWC RGB
+ # BGRA = swap RβB, append alpha=255
+ bgra = torch.cat([
+ rgb[..., 2:3], rgb[..., 1:2], rgb[..., 0:1],
+ torch.full(rgb.shape[:-1] + (1,), 255, dtype=torch.uint8, device=rgb.device),
+ ], dim=-1).contiguous()
+ return bgra
+
+def _lazy_init_ipc_exporter(self, height, width):
+ if self._cuda_ipc_exporter is not None:
+ return self._cuda_ipc_exporter
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCExporter
+ exporter = CUDAIPCExporter(
+ shm_name=self._cuda_ipc_shm_name,
+ height=height, width=width,
+ channels=4, dtype="uint8",
+ num_slots=self._cuda_ipc_num_slots,
+ debug=False,
+ )
+ exporter.initialize() # required per cuda_ipc_exporter.py:338
+ self._cuda_ipc_exporter = exporter
+ return exporter
+```
+
+Add cleanup in the existing `cleanup`/`__del__` path (wrapper has a teardown method around L2658 in the abandoned-plan diff β confirm exact location during implementation):
+
+```python
+if self._cuda_ipc_exporter is not None:
+ self._cuda_ipc_exporter.cleanup()
+ self._cuda_ipc_exporter = None
+```
+
+#### 4b β `src/streamdiffusion/config.py`
+
+Add three new `param_map` entries so YAML keys reach the wrapper ctor:
+
+```python
+"use_cuda_ipc": "use_cuda_ipc",
+"cuda_ipc_shm_name": "cuda_ipc_shm_name",
+"cuda_ipc_num_slots": "cuda_ipc_num_slots",
+```
+
+(Exact dict location TBD during implementation β search `param_map` in `config.py`.)
+
+#### 4c β `StreamDiffusionTD/td_manager.py` β pass YAML through
+
+In the `create_wrapper_from_config` call site (L72) the config dict already flows to the wrapper through `config.py`'s param_map. The only TD-side concern is:
+
+- Read `self.td_settings.get('use_cuda_ipc', False)` once and stash on `self.use_cuda_ipc`.
+- When `use_cuda_ipc` is true, inject `cuda_ipc_shm_name = self.td_import_shm_name` into the config dict before passing to `create_wrapper_from_config`. This keeps the wrapper agnostic of the TD-perspective YAML naming.
+- Skip the CPU SHM connect/create for the **output direction** (L335-346) when IPC is on β the SHM doesn't need to exist on the Python side. Input-side CPU SHM (L330) still opens.
+
+#### 4d β Output frame send path
+
+In `_send_output_frame` (L643-740), early-return when `self.use_cuda_ipc` is true:
+
+```python
+def _send_output_frame(self, output_image):
+ if self.use_cuda_ipc:
+ # Frame was already exported via wrapper.postprocess_image IPC fast-path;
+ # output_image is None.
+ return
+ # ... existing CPU SHM write path unchanged ...
+```
+
+This requires `postprocess_image` returning `None` to propagate up to `_send_output_frame`'s caller (currently the streaming loop at L534) without breaking. Check `wrapper.__call__` or whichever method calls `postprocess_image` for the chain.
+
+### Step 5 β StreamDiffusionExt YAML emitter (manual TD edit)
+
+Out of repo. Inside `StreamDiffusionTD_dev.toe`, the StreamDiffusionExt extension's YAML-emitter code (previously at extracted L3754-3768 in the old Scripts/-synced copy) must be updated to emit:
+
+```python
+yaml_lines.append(f" td_export_shm_name: '{parent.SDTD.par.Streamoutname}_output_ipc'")
+yaml_lines.append(f" td_import_shm_name: '{parent.SDTD.par.Streamoutname}_input_ipc'")
+yaml_lines.append(f" use_cuda_ipc: {bool(parent.SDTD.par.Usecudaipc)}") # new toggle param on SDTD COMP
+```
+
+The user has to make this edit directly inside TouchDesigner. The CUDAIPCExtension COMPs inside the .toe also need their `Ipcmemname` params bound to the same `Streamoutname + '_input_ipc'` (Receiver-mode COMP) so SD and TD agree on the SHM identity.
+
+### Step 6 β Verification
+
+```bash
+# Static rename completeness
+grep -rn "input_mem_name\|output_mem_name" StreamDiffusion/StreamDiffusionTD/ src/streamdiffusion/ configs/ 2>&1
+# β zero hits
+
+# Import smoke (pre-existing controlnet_aux missing is unrelated; wrapper module itself loads)
+cd StreamDiffusion && python -c "from streamdiffusion.wrapper import StreamDiffusionWrapper; print('import OK')"
+
+# IPC class loadability
+python -c "from streamdiffusion._compat.cuda_ipc import CUDAIPCExporter; print(CUDAIPCExporter.__init__.__doc__[:200])"
+```
+
+End-to-end TD test:
+1. Open `StreamDiffusionTD_dev.toe`. Set `parent.SDTD.par.Streamoutname = "StreamDiffusionTD_512-512"` and `Usecudaipc = True`.
+2. Update StreamDiffusionExt YAML emitter per Step 5; re-emit `td_config.yaml`.
+3. Restart Python pipeline. Verify console shows `Memory: <- TD | SD -> _input_ipc`.
+4. Trigger a frame from TD's input. Confirm SD's output appears in TD's Receiver COMP at β₯16 FPS (matching abandoned Phase 2.1's 16-25 FPS).
+5. Hot-disable IPC: set `use_cuda_ipc: false` in YAML, restart, confirm legacy CPU SHM path still works (regression check).
+
+## Critical files & key references
+
+**Modified (this plan)**:
+- `StreamDiffusionTD/td_config.yaml` L73-74 β YAML keys rename + IPC toggle add
+- `configs/td_config.yaml.example` L95-96 β same
+- `StreamDiffusionTD/td_manager.py` L40, L41-42, L98-100, L314-322, L330-346, L358, L643-740 β ctor + internal refs + output send early-return
+- `StreamDiffusionTD/td_main.py` L399-418, L450, L462 β drop timestamp uniquifier, update kwargs + log
+- `src/streamdiffusion/wrapper.py` L82-160 (ctor kwargs), L312 (instance slots), L894 (`postprocess_image` IPC fast-path), L2658-area (cleanup) β IPC wiring
+- `src/streamdiffusion/config.py` β three new `param_map` entries
+- StreamDiffusionExt YAML emitter (in `.toe` binary, out of repo) β user edits in TD
+
+**Reused verbatim (no edits)**:
+- `src/streamdiffusion/_compat/cuda_ipc/` β pristine cuda-link `92989fc`. Integration surface: `CUDAIPCExporter` at `cuda_ipc_exporter.py:191` (ctor L208-218, `initialize()` L338, `export_frame(gpu_ptr, size)` L705, `cleanup()` L957). BGRA HWC uint8 wire format per docstring L11-22.
+- `src/streamdiffusion/_compat/td_exporter/` β pristine cuda-link TD-side. `TDSender.py:70` (`_EXPORT_BUFFER_NAME = "ExportBuffer"`), `CUDAIPCExtension.py:98` (`Ipcmemname` default `cudalink_output_ipc`), `TDReceiver.py:171-190` (`RetryState` β relevant for Step 5 user-side TD config).
+
+**Out of scope (next round)**:
+- Input direction (TDβSD): `CUDAIPCImporter` wiring inside `_get_input_frame`. Requires solving the first-connect `traceback.print_exc()` noise at `cuda_ipc_importer.py:808` via SHM-existence probe before construction (the abandoned-Phase-2.4 approach). The YAML key `td_export_shm_name` is reserved for this.
+- Migration of deployed user `td_config.yaml` files. Per user choice, manual update.
+- `controlnet_aux` import error (pre-existing missing dependency unrelated to this work).
+
+## Memory note
+
+Per [[feedback_save_plans_as_project_files]], after ExitPlanMode and on user approval, copy this plan to `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/_plans/2026-05-17_redo-cuda-ipc-output-direction.md` as the project-tracked copy.
diff --git a/_plans/2026-05-17_restore-to-pre-cuda-ipc.md b/_plans/2026-05-17_restore-to-pre-cuda-ipc.md
new file mode 100644
index 000000000..a5bb2e2fe
--- /dev/null
+++ b/_plans/2026-05-17_restore-to-pre-cuda-ipc.md
@@ -0,0 +1,156 @@
+# Restore branch to "Quantization landed, CUDA IPC just started"
+
+## Context
+
+The current branch `feat/quantization-robustness` (HEAD `1a8065f`, 9 ahead of `origin/SDTD_031_dev`) carries an in-progress CUDA IPC integration in the working tree. The previous attempt (`_plans/drifting-twirling-tulip.md`) adopted a **StreamDiffusion-perspective** naming convention: cuda-link's `Sender`/`Receiver` modes were aliased to SD-side `input`/`output` directions, and YAML keys like `use_cuda_ipc_output` / `cuda_ipc_input_shm_name` were introduced. Phases 2.1β2.3 landed (PythonβTD output working at 16β25 FPS), but Phase 2.4 (input direction first-connect retry) ran into a `traceback.print_exc()` noise issue in upstream `cuda_ipc_importer.py:809` that the manager-side probe was working around.
+
+User has decided to **abandon this perspective** and redo the integration with cuda-link's vocabulary kept verbatim (canonical names per cuda-link 1.4.1 @ `92989fc`: `ExportBuffer`/`ImportBuffer` TOPs, `Sender`/`Receiver` modes, `CUDAIPCExporter`/`CUDAIPCImporter` classes, `export_frame`/`import_frame` methods, `Ipcmemname` param). The rename direction flips: instead of mapping cuda-link β SD-side names, rename SD-side names β cuda-link.
+
+This plan covers **only the restore**. The naming-flip redo is a separate task.
+
+## Target end state
+
+- HEAD unchanged at `1a8065f` (post-Quantization, robustness fixes preserved).
+- `git status` shows **clean working tree** for tracked files. Three diffs reverted: `configs/td_config.yaml.example`, `src/streamdiffusion/config.py`, `src/streamdiffusion/wrapper.py`.
+- Untracked vendored sources **preserved as-is** (per user choice "Keep vendored, drop only glue"):
+ - `src/streamdiffusion/_compat/cuda_ipc/` (10 files, cuda-link @ `92989fc`, VENDORED_VERSION.txt intact)
+ - `src/streamdiffusion/_compat/td_exporter/` (TD-side vendored, no `__init__.py` by design)
+- Scripts/ and StreamDiffusionTD/ files cleaned of all IPC integration glue (TouchDesignerManager.ipc_input_importer, `_try_construct_ipc_importer`, cuda_ipc_input_shm_name handling, ExtensionExt YAML emit lines).
+- Safety net in place: `git stash` entry + archive branch `archive/cuda-ipc-input-perspective` pointing at current HEAD with all WIP tree captured.
+- Abandoned plan archived under `_plans/archive/` with header note.
+
+## Execution (in order)
+
+All paths relative to `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/` unless absolute.
+
+### Step 1 β Safety net first
+
+```bash
+# (1a) Create archive branch tag at current HEAD (no checkout β stays on feat/quantization-robustness)
+git branch archive/cuda-ipc-input-perspective HEAD
+
+# (1b) Stash everything including untracked, with a descriptive message
+git stash push -u -m "cuda-ipc-input-perspective WIP 2026-05-17 β before naming-flip restore"
+
+# (1c) Verify stash captured all of: 3 tracked diffs + _compat/cuda_ipc/ + _compat/td_exporter/ + _plans/
+git stash show -u stash@{0} --stat
+```
+
+Result: stash entry exists, archive branch points at `1a8065f`, working tree is **already clean** for tracked files post-stash. The untracked dirs are also in the stash and will be gone from the working tree.
+
+### Step 2 β Restore vendored sources from the stash (cherry-pick the keepers)
+
+`git stash apply` would re-apply everything (including the glue we want gone). Instead, restore only the vendored sources directly from the stash:
+
+```bash
+# Restore the two vendored dirs from stash@{0} (untracked portion)
+git checkout stash@{0}^3 -- src/streamdiffusion/_compat/cuda_ipc/
+git checkout stash@{0}^3 -- src/streamdiffusion/_compat/td_exporter/
+```
+
+Note: `stash@{0}^3` is the untracked-files commit inside a `git stash -u` stash. Verify with `git log --oneline stash@{0}^3` before running β if the stash layout differs (some git versions), use `git stash show -u --name-only stash@{0}` and then `git restore --source=stash@{0} --staged --worktree -- ` instead.
+
+Verify VENDORED_VERSION.txt files still show `head_commit: 92989fc` and `vendored: 2026-05-17`.
+
+### Step 3 β Surgical removal of TD-side IPC glue
+
+Per user choice. Edit both `Scripts/` (source-of-truth, feeds .tox at runtime per [[project_scripts_dir_purpose]]) **and** `StreamDiffusionTD/` (gitignored mirror β same bytes, fed back into the .tox export).
+
+**Critical**: `diff -q` confirmed Scripts/ and StreamDiffusionTD/ copies are byte-identical (48713 bytes). Edit Scripts/ first, then copy to StreamDiffusionTD/.
+
+#### Edit set A β `Scripts/streamdiffusionTD__Text__td_manager__td.py` (48713 bytes)
+
+Remove these exact regions (line numbers as of current working tree):
+- **L84, L88**: instance-attr declarations `self.ipc_input_importer = None`, `self._ipc_importer_cls = None`, and any sibling pending-name attr (`_pending_ipc_input_name`, `_ipc_input_connected_logged`) β sweep the `__init__` body for any `ipc_`/`_ipc_` attr and remove.
+- **L336β359**: the entire `if ipc_input_name := self.td_settings.get('cuda_ipc_input_shm_name'):` block inside `_initialize_memory_interfaces` (the probe-construct-fallback logic). Restore the pre-existing CPU SHM fallback to be the only input path.
+- **L424β429**: the `if self.ipc_input_importer is not None: ... cleanup()` block in the cleanup path.
+- **L672β683**: the IPC retry probe + `get_frame_numpy()` path inside `_get_input_frame`. Restore plain CPU SHM read.
+- **L703β740**: the `_try_construct_ipc_importer` method definition + any `_probe_ipc_shm_exists` helper. Delete wholesale.
+
+Verification grep after edits: `grep -nE "(ipc_input|cuda_ipc_input|_try_construct_ipc|_probe_ipc_shm|CUDAIPCImporter)" Scripts/streamdiffusionTD__Text__td_manager__td.py` must return zero hits.
+
+#### Edit set B β `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py`
+
+Per Explore findings:
+- **L3754βL3756**: the three writes for `use_cuda_ipc_output`, `cuda_ipc_shm_name`, `cuda_ipc_num_slots` in the YAML emitter. Remove all three.
+- **L3768**: the `cuda_ipc_input_shm_name` write. Remove. (L3767 `input_mem_name` and L3769 `output_mem_name` are pre-existing β keep them.)
+
+Verification grep: `grep -nE "(use_cuda_ipc_output|cuda_ipc_shm_name|cuda_ipc_num_slots|cuda_ipc_input_shm_name)" Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` must return zero hits.
+
+#### Edit set C β Sync to StreamDiffusionTD/
+
+```bash
+# After Scripts/ edits are complete:
+cp Scripts/streamdiffusionTD__Text__td_manager__td.py StreamDiffusion/StreamDiffusionTD/td_manager.py
+cp Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py StreamDiffusion/StreamDiffusionTD/StreamDiffusionExt.py # verify exact filename
+# Path-from-parent for the cp source: D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/
+# Path-from-parent for the cp dest: D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/StreamDiffusionTD/
+```
+
+Note: the StreamDiffusionTD/ filenames may differ from Scripts/ names (Scripts/ uses TD's `______td.py` convention; StreamDiffusionTD/ uses bare module names). Confirm exact mapping via `ls StreamDiffusion/StreamDiffusionTD/*.py` before cp.
+
+### Step 4 β Archive the abandoned plan
+
+```bash
+mkdir -p _plans/archive
+git mv _plans/drifting-twirling-tulip.md _plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md 2>/dev/null || mv _plans/drifting-twirling-tulip.md _plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md
+```
+
+Prepend this header to the archived file:
+
+```markdown
+> **ABANDONED 2026-05-17.** This plan used SD-perspective naming (cuda-link SenderβSD input, ReceiverβSD output, YAML `use_cuda_ipc_output` etc.). The integration is being redone with cuda-link's vocabulary kept verbatim and SD-side names renamed instead. Kept as historical reference for the working Phase 2.1 BGRA repack approach and the Phase 2.4 SHM-probe-before-construct trick.
+
+```
+
+### Step 5 β Final verification
+
+```bash
+git status # β only _plans/archive/ should appear (untracked)
+git diff --stat # β empty (no tracked-file diffs)
+ls src/streamdiffusion/_compat/cuda_ipc/ # β 10 files + VENDORED_VERSION.txt
+ls src/streamdiffusion/_compat/td_exporter/ # β vendored TD scripts still present
+grep -rn "cuda_ipc_input_shm_name\|ipc_input_importer\|_try_construct_ipc_importer\|use_cuda_ipc_output" Scripts/ StreamDiffusion/StreamDiffusionTD/ 2>&1
+# β zero hits
+
+# Confirm restored state is loadable
+cd StreamDiffusion && python -c "from streamdiffusion.wrapper import StreamDiffusionWrapper; print('import OK')"
+# β "import OK" (the wrapper.py revert dropped the IPC code paths cleanly)
+```
+
+Then a TD smoke test: open `StreamDiffusionTD_dev.toe` and confirm the legacy CPU SHM path (`input_mem_name`/`output_mem_name`) still drives a frame end-to-end. No CUDA IPC paths active.
+
+## Critical files & key references
+
+**Modified (revert via stash mechanism):**
+- `src/streamdiffusion/wrapper.py` β L134-136 ctor kwargs, L320-321 + L336-338 instance slots, L931-947 IPC fast-path inside `postprocess_image`, L995-1009 `_ipc_pack_rgba`, L1011-1040 `_lazy_init_ipc_exporter`, L2658-2666 cleanup. All from the WIP diff.
+- `src/streamdiffusion/config.py` β L160-162 three new `param_map` entries.
+- `configs/td_config.yaml.example` β L85-92 new YAML keys.
+
+**Kept as-is (vendored, preserved):**
+- `src/streamdiffusion/_compat/cuda_ipc/` β cuda-link Python sources, untouched upstream copy. The `CUDAIPCExporter` (`cuda_ipc_exporter.py:191-1163`) and `CUDAIPCImporter` (`cuda_ipc_importer.py:479-`) classes will be the integration surface in the redo.
+- `src/streamdiffusion/_compat/td_exporter/` β cuda-link TD-side vendored scripts. `TDSender.py:70` defines `_EXPORT_BUFFER_NAME = "ExportBuffer"`, `CUDAIPCExtension.py:169` defines `def import_frame(self, import_buffer: TOP)`. These names are canonical and will be adopted on the SD side in the redo.
+
+**Edited surgically (per Step 3):**
+- `Scripts/streamdiffusionTD__Text__td_manager__td.py` (== `StreamDiffusionTD/td_manager.py`, gitignored mirror)
+- `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` (== gitignored mirror)
+
+**Archived:**
+- `_plans/drifting-twirling-tulip.md` β `_plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md`
+
+## Next-task sketch (NOT executed in this plan β for continuity only)
+
+The redo will rename SD-side concepts to cuda-link vocabulary. Concrete mapping derived from MCP search of canonical names (cuda-link `92989fc`, README L61, `CUDAIPCExtension.py` L98, `TDSender.py:70`, `TDReceiver.py:313`):
+
+| Current SD-side name | Cuda-link canonical | Blast radius |
+|---|---|---|
+| `td_settings.input_mem_name` (YAML) | `td_settings.export_buffer_shm_name` (or just `Ipcmemname` per TD COMP convention) | HIGH β in every user's local `td_config.yaml` |
+| `td_settings.output_mem_name` | `td_settings.import_buffer_shm_name` | HIGH β same |
+| `TouchDesignerManager.__init__(input_mem_name, output_mem_name, ...)` | `TouchDesignerManager.__init__(export_shm_name, import_shm_name, ...)` | MED β only `td_main.py:399-402` calls positionally |
+| `image` param on `wrapper.__call__/img2img` | **KEEP** β `image` is a general PyTorch convention, not a TD-IPC concept | n/a |
+| `image_tensor` locals | **KEEP** β same reasoning | n/a |
+| `_process_skip_diffusion` `preprocessor_input/output` locals | **KEEP** β these are preprocessor I/O, not TD-IPC | n/a |
+
+The interesting rename surface is the **TDβSD transport boundary** (memory names, mode enum, helper class names) β not the internal PyTorch tensor flow. Most of the wrapper.py rename pressure dissolves once we see that `image`/`image_tensor` are PyTorch idiom, orthogonal to cuda-link.
+
+Memory note per [[feedback_save_plans_as_project_files]]: after ExitPlanMode and on user approval, copy this plan to `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/_plans/2026-05-17_restore-to-pre-cuda-ipc.md` as the project-tracked copy.
diff --git a/_plans/2026-05-17_zero-copy-gpu-input.md b/_plans/2026-05-17_zero-copy-gpu-input.md
new file mode 100644
index 000000000..6ca3f9048
--- /dev/null
+++ b/_plans/2026-05-17_zero-copy-gpu-input.md
@@ -0,0 +1,256 @@
+# True zero-copy GPU input β close the input/output asymmetry
+
+> **Hand-off from `wiggly-finding-puzzle.md` (CUDARuntimeTypes import fix, commit `eecb9f5`).** Input direction transport now works end-to-end (logs at 21:38:15 confirm `CUDA IPC input ready: shm=StreamDiffusionTD_512-512_input_ipc`). But the input *payload* still detours through CPU. Output direction is already true zero-copy. This plan brings input to parity.
+
+## Context
+
+After commit `eecb9f5` (CUDARuntimeTypes import fix), both directions of the TD β SD bridge use the same vendored `_compat/cuda_ipc/` package as IPC transport. But the data flow is asymmetric:
+
+| Direction | Transport | Payload handling | Notes |
+|---|---|---|---|
+| SD β TD (output) | `CUDAIPCExporter` | **Stays on GPU end-to-end** | `wrapper._ipc_pack_rgba` builds BGRA on GPU β `exporter.export_frame(data_ptr, numel)` |
+| TD β SD (input) | `CUDAIPCImporter` | **D2H β numpy β CPU preprocess β H2D** | `_get_input_frame_cuda_ipc` does `.cpu().numpy()`, then loop does CPU float-cast, then `VaeImageProcessor.preprocess` runs on CPU, then pinned-buffer H2D |
+
+The input direction throws away the zero-copy GPU tensor the Importer already hands out, then re-materializes it via a CPU round-trip β pure waste.
+
+### What's already in place (verified by Phase 1 Explore)
+
+- **`CUDAIPCImporter.get_frame()`** (`src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:903`) already returns a **persistent, zero-copy `torch.Tensor` on CUDA** built in `TorchBuffers.build` (lines 282-316) via `__cuda_array_interface__` + `torch.as_tensor`. Shape on the TD wire: **HWC float32 BGRA, range [0,1]**. Optional `stream=` kwarg adds `cudaStreamWaitEvent` so the consumer's stream waits GPU-side for the producer's IPC event β no CPU sync.
+- **Pipeline fast-path already exists**: `StreamDiffusion.__call__` at `src/streamdiffusion/pipeline.py:1024-1039` skips `image_processor.preprocess` and the H2D staging copy when `x` is a CUDA tensor with `dtype == self.dtype` and `shape[-2:] == (self.height, self.width)`. Hit those three conditions exactly and the pipeline runs zero-copy. Every example script (`examples/benchmark/single.py:107-119`, `examples/img2img/single.py`, `examples/screen/main.py`, `examples/vid2vid/main.py`) already uses this pattern: `image_tensor = stream.preprocess_image(...); stream(image=image_tensor)`.
+- **`wrapper.img2img`** (`src/streamdiffusion/wrapper.py:856-860`) already passes a `torch.Tensor` straight through to `self.stream(image)` β only `str`/`Image.Image` inputs go through `preprocess_image`. No wrapper-side changes needed.
+- **GPU BGRAβRGB indexer is already inline** at `td_manager.py:687`: `gpu_frame[..., [2, 1, 0]].contiguous()`. We keep that step; we just stop scaling/casting/D2H afterward.
+- **ControlNet & IPAdapter use independent SHM streams** (`self.control_memory`, `self.ipadapter_memory`) β they don't read the main input frame, so the input zero-copy plan has **no coupling** to those features. They keep their CPU paths unchanged.
+
+### What the current code does (the waste)
+
+`td_manager.py:684-689` (current `_get_input_frame_cuda_ipc` tail):
+
+```python
+gpu_frame = self._cuda_ipc_importer.get_frame() # zero-copy torch.Tensor on GPU [HWC float32 BGRA, [0,1]]
+if gpu_frame is None:
+ return None
+rgb = gpu_frame[..., [2, 1, 0]].contiguous() # BGRA β RGB (drop alpha) β KEEP this step on GPU
+rgb_u8 = (rgb.clamp(0, 1) * 255).to(torch.uint8) # β WASTE: rescale up to uint8 just to rescale back down
+return rgb_u8.cpu().numpy() # β WASTE: D2H + numpy (~0.5-1ms per frame, syncs the stream)
+```
+
+Followed by `td_manager.py:527-529`:
+
+```python
+if input_image.dtype == np.uint8:
+ input_image = input_image.astype(np.float32) / 255.0 # β WASTE: CPU rescale back to [0,1]
+```
+
+Followed by `pipeline.py:1034-1039`:
+
+```python
+_raw = self.image_processor.preprocess(x, self.height, self.width) # β WASTE: CPU resize/normalize on numpy
+self._input_staging.copy_(_raw) # β WASTE: CPUβpinned copy
+x = self._input_staging.to(device=self.device, non_blocking=True) # β WASTE: H2D copy of data we already had on GPU
+```
+
+Net effect: ~3 redundant rescales, 1 D2H, 1 H2D, 1 stream-blocking sync, per frame. The user's success logs already show occasional `total_time` jitter spikes (298β1133β709Β΅s) and 100Γ-outlier memcpy spikes (~612Β΅s at Frames 1746, 2231) which are consistent with the D2H sync interacting with WDDM scheduling β eliminating this path should shrink both the steady-state median and the tail.
+
+## Approach
+
+Replace the D2H tail of `_get_input_frame_cuda_ipc` with a small GPU-only transform that **lands the tensor exactly in the shape/dtype/range the pipeline fast-path expects**, then route it through `_streaming_loop` as a `torch.Tensor` instead of `np.ndarray`. Mirror exactly what the output direction already does.
+
+### The GPU transform (single chained op, all in-flight on the current CUDA stream)
+
+```python
+# gpu_frame: HWC float32 BGRA on GPU, range [0,1] β from Importer.get_frame()
+# target: NCHW self.dtype on GPU, range [-1,1] β pipeline fast-path expects this exactly
+nchw = (
+ gpu_frame[..., [2, 1, 0]] # HWC float32 RGB [0,1] (drop alpha, channel swap β free view+gather)
+ .mul(2.0).sub_(1.0) # HWC float32 RGB [-1,1] (matches VaeImageProcessor.normalize: img*2-1)
+ .permute(2, 0, 1) # CHW float32 RGB [-1,1] (free view)
+ .unsqueeze(0) # NCHW float32 RGB [-1,1] (N=1) (free view)
+ .to(dtype=self.wrapper.stream.dtype, non_blocking=True)
+ .contiguous() # required: permute leaves strides non-contiguous; pipeline downstream expects contiguous NCHW
+)
+```
+
+Cost: one channel-gather kernel + one elementwise scale + one dtype cast, all bandwidth-bound on a 512Γ512Γ3 tensor (~1.5MB at fp16) β **well under 100Β΅s** on the user's RTX 4090, vs the ~0.5-1ms the D2H currently burns. Plus we delete the `astype(float32)/255.0` CPU op and the entire `VaeImageProcessor.preprocess` + pinned-buffer H2D path.
+
+### Optional stream-sync hardening
+
+`CUDAIPCImporter.get_frame(stream=...)` accepts a CUDA stream and inserts `cudaStreamWaitEvent` so the consumer's stream waits GPU-side for the producer's IPC event. Pass `torch.cuda.current_stream()._as_parameter_` so the wait happens inside the pipeline's own stream β eliminates the small CPU-side wait the Importer currently does internally. **Defer to Phase 2** if the v1 path lands cleanly; v1 should work without it.
+
+## Code changes
+
+### Patch 1 β `StreamDiffusionTD/td_manager.py:661-689` β rewrite `_get_input_frame_cuda_ipc` tail
+
+Replace lines 661-689 with:
+
+```python
+def _get_input_frame_cuda_ipc(self) -> Optional["torch.Tensor"]:
+ """Read one frame from TD's CUDAIPCExporter and return a GPU torch.Tensor
+ matching the pipeline's fast-path: NCHW, self.stream.dtype, range [-1,1],
+ on CUDA, shape (1, 3, height, width). Bypasses image_processor.preprocess
+ and pinned-buffer H2D entirely."""
+ if self._cuda_ipc_importer is None:
+ if not self._probe_ipc_input_shm():
+ return None
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter
+ try:
+ self._cuda_ipc_importer = CUDAIPCImporter(
+ shm_name=self.cuda_ipc_input_shm_name,
+ debug=False,
+ )
+ except Exception as e:
+ logger.warning(f"CUDAIPCImporter init failed: {e}")
+ self._cuda_ipc_importer = None
+ return None
+ if not self._cuda_ipc_importer.is_ready():
+ self._cuda_ipc_importer = None
+ return None
+ logger.info(f"CUDA IPC input ready (zero-copy GPU): shm={self.cuda_ipc_input_shm_name}")
+
+ gpu_frame = self._cuda_ipc_importer.get_frame() # HWC float32 BGRA on GPU, [0,1]
+ if gpu_frame is None:
+ return None
+
+ target_dtype = self.wrapper.stream.dtype
+ nchw = (
+ gpu_frame[..., [2, 1, 0]] # HWC RGB float32 [0,1]
+ .mul(2.0).sub_(1.0) # HWC RGB float32 [-1,1]
+ .permute(2, 0, 1) # CHW RGB float32 [-1,1]
+ .unsqueeze(0) # NCHW (N=1)
+ .to(dtype=target_dtype, non_blocking=True)
+ .contiguous()
+ )
+ return nchw
+```
+
+Key shifts vs current code:
+- Return type: `Optional[np.ndarray]` β `Optional[torch.Tensor]`
+- Drops `clamp(0,1)*255 β uint8 β .cpu().numpy()` (3 ops, 1 stream-sync, 1 D2H)
+- Adds `mul(2).sub_(1) β permute β unsqueeze β to(dtype)` (4 ops, all in-flight on GPU)
+- Log line gains "(zero-copy GPU)" to make the success criterion trivially greppable
+
+### Patch 2 β `StreamDiffusionTD/td_manager.py:643-659` β update dispatcher return type
+
+The legacy SHM path still returns numpy uint8 RGB. The dispatcher needs to allow both. Change the return annotation only:
+
+```python
+def _get_input_frame(self) -> Optional[Union["np.ndarray", "torch.Tensor"]]:
+ """Get input frame from TouchDesigner (platform-specific).
+
+ Returns:
+ torch.Tensor (NCHW, self.stream.dtype, [-1,1] on CUDA) β CUDA IPC path
+ np.ndarray (HWC uint8 RGB) β legacy SHM / Syphon paths
+ """
+ # body unchanged
+```
+
+(Imports: `Union` is already in scope at top of file β verify in patch.)
+
+### Patch 3 β `StreamDiffusionTD/td_manager.py:521-549` β branch on input type in `_streaming_loop`
+
+Replace lines 527-529 (the uint8βfloat32 CPU cast) with a type guard so the CPU path runs only for numpy inputs:
+
+```python
+# img2img mode: get input frame and process
+input_image = self._get_input_frame()
+if input_image is None:
+ time.sleep(0.001)
+ continue
+
+# CUDA IPC fast-path returns a ready-to-consume GPU tensor; legacy SHM path returns
+# HWC uint8 RGB numpy which still needs the CPU float-cast.
+if isinstance(input_image, np.ndarray) and input_image.dtype == np.uint8:
+ input_image = input_image.astype(np.float32) / 255.0
+```
+
+No other changes in this block. `self.wrapper.img2img(input_image)` already passes `torch.Tensor` straight through to `self.stream(image)` (`wrapper.py:856-860`), which hits the pipeline fast-path because we constructed the tensor to satisfy all three checks at `pipeline.py:1028-1033`.
+
+### What we explicitly do NOT touch
+
+- `wrapper.img2img` β already correct; tensor passthrough is its existing behavior
+- `wrapper.preprocess_image` β only used by examples, not by the TD streaming loop
+- `pipeline.__call__` fast-path β already exists and is the contract we satisfy
+- `_input_staging` allocation in `pipeline.prepare()` β stays allocated but unused on the IPC path; ~1.5MB pinned, not worth conditional logic
+- `_process_controlnet_frame` / `_process_ipadapter_frame` β independent SHM streams, no coupling
+- `_send_output_frame` / `postprocess_image` β output direction already zero-copy
+- `_compat/cuda_ipc/` β no changes; the existing API is sufficient
+
+## Verification
+
+After applying all three patches in the running SD venv (no rebuild needed β pure Python; `Scripts/` edits live-reload per `[[project_scripts_dir_purpose]]`):
+
+### 1. Smoke test the type contract
+
+Before launching the .toe, run this in SD's venv to confirm the new return-type plumbing parses:
+
+```powershell
+venv\Scripts\python -c "from StreamDiffusionTD.td_manager import TouchDesignerManager; print('OK')"
+```
+
+Must print `OK`. Any `SyntaxError` / `NameError` / `ImportError` means a patch is wrong β stop and re-read the file.
+
+### 2. Functional test β relaunch the .toe
+
+Watch SD cmd log for:
+- β
`CUDA IPC input ready (zero-copy GPU): shm=StreamDiffusionTD_512-512_input_ipc` (new log marker proving Patch 1 took effect)
+- β
No `_get_input_frame:` debug exceptions
+- β
Frames stream at β₯ the prior baseline (~20-26 FPS in success logs from 2026-05-17 21:38)
+- β
Clean shutdown on OSC `/stop` (no leaked tensor refs / IPC handle errors)
+
+### 3. Visual round-trip check (TD)
+
+TD Receiver COMP should show:
+- Correct colors (BGRAβRGB shuffle still happens, just on GPU now β verify no R/B swap)
+- No tone shift or banding (`mul(2).sub_(1)` followed by VAE encode is the same arithmetic the CPU path performed via `astype(float32)/255 β VaeImageProcessor.normalize`, so output should be visually identical)
+- No flicker / dropped frames
+
+### 4. Performance verification
+
+In SD log, compare these metrics against the 2026-05-17 21:38 baseline:
+- **Steady-state `total_time`**: expected ~0.5-1ms lower (D2H + CPU rescale + H2D eliminated)
+- **`total_time` jitter** (max/min spread): expected meaningfully tighter, since the stream-blocking `.cpu()` sync is gone
+- **GPU memcpy spikes** (the ~612Β΅s Frame 1746 / 2231 outliers): may or may not disappear β those might be unrelated allocator/WDDM events, but at minimum we've removed one possible source
+
+Optional deeper check with `nsys`:
+
+```powershell
+nsys profile -o input_zerocopy_after --trace cuda,nvtx --capture-range cudaProfilerApi `
+ venv\Scripts\python -m StreamDiffusionTD.main_sdtd ...
+```
+
+Then `nsys analyze input_zerocopy_after.nsys-rep` β the input direction should show **zero `cudaMemcpyAsync DtoH` calls** between consecutive `cudaGraphLaunch` calls.
+
+## Commit
+
+Per `[[feedback_pr_branch_convention]]`, branch stays at `feat/cuda-ipc-output` (current head: `eecb9f5`), PR target `SDTD_031_dev`.
+
+```powershell
+./scripts/git/commit_enhanced.sh --no-venv `
+ "feat: true zero-copy GPU input via CUDA IPC (close input/output asymmetry)"
+```
+
+Then save the plan as a project file per `[[feedback_save_plans_as_project_files]]`:
+- Copy this file to `StreamDiffusion/_plans/2026-05-17_zero-copy-gpu-input.md`
+
+## Critical files
+
+| File | Lines | Change |
+|---|---|---|
+| `StreamDiffusionTD/td_manager.py` | 661-689 | Rewrite `_get_input_frame_cuda_ipc` body β drop D2H, add GPU NCHW transform |
+| `StreamDiffusionTD/td_manager.py` | 643-659 | Update `_get_input_frame` return-type annotation only |
+| `StreamDiffusionTD/td_manager.py` | 521-549 | Branch CPU float-cast on `isinstance(np.ndarray)` |
+
+Reused unchanged (verified by Phase 1 Explore):
+
+- `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:903` β `get_frame()` already returns zero-copy GPU torch.Tensor
+- `src/streamdiffusion/wrapper.py:834-873` β `img2img` already passes `torch.Tensor` straight through
+- `src/streamdiffusion/pipeline.py:1024-1039` β fast-path already exists for tensors matching `is_cuda + dtype + (H,W)`
+- `src/streamdiffusion/wrapper.py:921-925, 967-975` β output direction reference pattern (`_ipc_pack_rgba` + `export_frame(data_ptr, β¦)`)
+
+## Out of scope
+
+- **Stream-sync hardening** via `get_frame(stream=current_stream)`. Defer to a v2 if v1 lands stable. The Importer's internal CPU wait is microseconds and not the dominant cost; we can come back for the last 5%.
+- **`_input_staging` deallocation**. Stays as ~1.5MB pinned dead weight on the IPC path; not worth a conditional in `pipeline.prepare()`.
+- **ControlNet / IPAdapter zero-copy**. Independent SHM streams, independent CPU rescales. Would be a similar 3-patch refactor per feature; not blocking the main img2img path.
+- **GPU memcpy spike investigation** (~612Β΅s at Frames 1746 / 2231). If they persist after this change, that's a separate diagnosis (likely WDDM preemption or allocator pressure, not input-path related).
+- **BGRA β RGB correctness audit**. The shuffle is byte-identical to the prior code; this plan only changes when (`.cpu()`) and how (`.mul().sub_().permute().unsqueeze().to(dtype)`) we use it.
+- **A shared `_ipc_unpack_input` utility** in `wrapper.py` to mirror `_ipc_pack_rgba`. Could DRY the two directions; defer until a second consumer appears.
diff --git a/_plans/2026-05-18_controlnet-ipc-stream-capture-fix.md b/_plans/2026-05-18_controlnet-ipc-stream-capture-fix.md
new file mode 100644
index 000000000..9dc3491e4
--- /dev/null
+++ b/_plans/2026-05-18_controlnet-ipc-stream-capture-fix.md
@@ -0,0 +1,56 @@
+# ControlNet CUDA IPC β TRT graph-capture conflict: fix + verification
+
+> **RESOLVED 2026-05-18** β v4 fix applied and verified (cold-start with CN scale=0.577).
+> Predecessor: `_plans/2026-05-17_controlnet-ipc-emitter-fix.md` (hypothesis A confirmed, v1 fix failed, v2/v3 partial, v4 final).
+
+## Root cause
+
+TRT's internal `genericReformat::copyPackedRunKernel` β invoked at the CN engine's input boundary to reformat the `controlnet_cond` tensor β submits work to the legacy/NULL CUDA stream during `execute_async_v3`. When the CN engine's polygraphy stream is in CUDA-graph capture mode (`cudaStreamBeginCapture β¦ cudaStreamEndCapture`), that legacy-stream submission violates the capture rules:
+
+> `operation would make the legacy stream depend on a capturing blocking stream`
+> `cudaErrorStreamCaptureInvalidated (901)`
+
+The polygraphy `Stream` class (venv `polygraphy/cuda/cuda.py:111`) is created via `cudaStreamCreate` with no flags β **blocking** by default β implicitly synchronizes with legacy. Any GPU op submitted to legacy during the capture window β whether from user code or TRT internals β invalidates the capture.
+
+## Why earlier fixes failed
+
+| Fix | What it did | Why it failed |
+|---|---|---|
+| v1 | dedicated non-blocking import stream + `wait_stream` bridge | `wait_stream` re-coupled legacy to the pending IPC event β same 901 |
+| v2 | `get_frame()` with no `stream=` arg β CPU `cudaEventQuery` poll | Fixes warm-activation (OSC enable mid-stream). Fails cold-start: IPC tensor transforms still queue on legacy before capture β but that's not the real problem |
+| Stage A | `CUDALINK_USE_GRAPHS=0` | Disproved; exporter graphs are irrelevant |
+| v3 | `torch.cuda.current_stream().synchronize()` before `cudaStreamBeginCapture` | Drains legacy pre-capture. Fails because `genericReformat::copyPackedRunKernel` runs **inside** the capture window β pre-capture drain cannot prevent it |
+
+## Fix (v4)
+
+**`wrapper.py:2208` β `use_cuda_graph=False` for ControlNet engines.**
+
+`use_cuda_graph=True` was hard-coded when constructing every CN TRT engine, regardless of input tensor format or graph-capture compatibility. Setting it to `False` keeps the CN engine in TRT-accelerated mode but skips the CUDA-graph wrapping:
+
+- No `cudaStreamBeginCapture` is ever called on the CN engine stream.
+- `genericReformat::copyPackedRunKernel`'s legacy-stream submission is harmless.
+- CN inference retains all TRT kernel/tactic optimizations.
+- Cost: CN loses the WDDM batch-submission savings (~hundreds of Β΅s per forward on Windows WDDM). Measured impact: steady-state FPS β 18-25 vs 19-28 with graph capture.
+
+### Changes committed
+
+| File | Change |
+|---|---|
+| `src/streamdiffusion/wrapper.py:2208` | `use_cuda_graph=True` β `use_cuda_graph=False` + inline comment |
+| `src/streamdiffusion/acceleration/tensorrt/utilities.py:1018-1022` | Defensive: `torch.cuda.current_stream().synchronize()` before `cudaStreamBeginCapture`, gated on first capture per engine. Addresses the broader polygraphy-blocking-stream structural issue for future TRT engines. |
+| `src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_exporter.py:593` | `mode=0 β mode=1` (ThreadLocal capture hardening) β committed `07045be` |
+| `src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py:46-47` | Docstring correction β committed `07045be` |
+| `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` | YAML emitter emits `use_cuda_ipc_controlnet` + `cuda_ipc_control_shm_name` β committed `07045be` |
+| `StreamDiffusionTD/td_manager.py` | v2 runtime fix: `get_frame()` with no `stream=` arg (gitignored; live via `Scripts/` sync) |
+
+## Verification (live, 2026-05-18)
+
+Cold-start `.toe` with `controlnet_scale: 0.577`, `use_cuda_ipc_controlnet: true`:
+- `CUDA IPC control ready (zero-copy GPU): shm=StreamDiffusionTD_512-512_control_ipc` β
+- No `[E] IExecutionContext::enqueueV3`, no `901`, no "legacy stream depend on capturing blocking stream" β
+- CN scale applies immediately from frame 1 β
+- Steady-state FPS sustained β
+
+## Deferred follow-up
+
+Investigate whether the CN engine's `controlnet_cond` input tensor can be produced in a format that avoids the `genericReformat` boundary (explicit `TensorIOFormat` constraints at build time, or providing the tensor already in CHW float32 on the engine stream). If so, `use_cuda_graph=True` for CN could be safely re-enabled.
diff --git a/_plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md b/_plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md
new file mode 100644
index 000000000..9e416353f
--- /dev/null
+++ b/_plans/archive/2026-05-17_cuda-ipc-input-perspective_ABANDONED.md
@@ -0,0 +1,728 @@
+> **ABANDONED 2026-05-17.** This plan used SD-perspective naming (cuda-link SenderβSD input, ReceiverβSD output, YAML `use_cuda_ipc_output` etc.). The integration is being redone with cuda-link's vocabulary kept verbatim and SD-side names renamed instead. Kept as historical reference for the working Phase 2.1 BGRA repack approach and the Phase 2.4 SHM-probe-before-construct trick.
+
+# Replace SDTD shmem* COMPs with upstream cuda-link `CUDAIPCLink_v1.4.1.tox`
+
+## Status (2026-05-17 16:22, Phase 2.4 round 2 β probe-before-construct)
+
+**Phase 2.1 OK** β PythonβTD output transport (cuda-link Receiver) confirmed working at 16-25 FPS.
+**Phase 2.2 OK** β BGRA byte-swap in `wrapper.py:_ipc_pack_rgba` landed; colors correct.
+**Phase 2.3 OK** β `consume_pending_resolution()` added to `Scripts/shmem__Text__output_callbacks__td.py`.
+**Phase 2.4 round 1** (logger silencing + retry) β landed but **insufficient**. The error log lines from `cuda_ipc_importer` are silenced correctly, BUT `cuda_ipc_importer.py:809` calls `traceback.print_exc()` which writes to `sys.stderr` directly, bypassing the `logging` module entirely. `logger.setLevel(CRITICAL)` cannot suppress raw stderr writes.
+
+Moving to **Phase 2.4 round 2: probe `SharedMemory` existence ourselves BEFORE invoking the importer's `_initialize()`** β that keeps the failing code path unreached during the normal startup race so `traceback.print_exc()` never fires.
+
+---
+
+## Phase 2.4 round 2 β Probe SHM existence before invoking importer
+
+### What round 1 achieved + what it missed
+
+Round 1 (logger silencing + 1/sec retry of `_initialize()`) landed and is **partially working**:
+
+- β
`cuda_ipc_importer - ERROR - SharedMemory ... not found` lines: **suppressed** (logger filter at CRITICAL works).
+- β
`CUDA IPC input configured (waiting for TD Sender): ...` log appears as planned.
+- β Raw Python traceback still printed on cold-start + on every retry tick.
+
+### Root cause of the leftover noise
+
+`StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:807-810`:
+
+```python
+except (OSError, RuntimeError, ValueError, struct.error, IndexError) as e:
+ logger.error("Initialization failed: %s", e)
+ traceback.print_exc() # β writes directly to sys.stderr, bypassing logging
+ return False
+```
+
+`traceback.print_exc()` is **not a logging call** β it writes formatted frames straight to `sys.stderr`. `logger.setLevel(CRITICAL)` filters records inside the `logging` module; it has zero effect on direct stderr writes. There is no public knob in `cuda_ipc_importer` that disables this behavior, and we don't modify `_compat/cuda_ipc/` (preserves upstream sync β same rationale documented earlier in this plan).
+
+### Asymmetric upstream β verified
+
+Searched `cuda_ipc_importer.py` for retry primitives:
+- `_wait_for_slot()` (line 851) β for waiting on frame slots **after** connect (CPU poll on `query_event`)
+- `_reinitialize()` (line 1163) β re-opens IPC handles, **requires `self._conn` non-None** (line 1166), unusable pre-connect
+- No `reconnect`, `request_immediate_reconnect`, `retry`, `wait_for_*` first-connect primitive.
+
+Upstream cuda-link's Python side appears designed primarily for Python-as-Sender; Python-as-Receiver (our case for TDβPython input) lacks first-connect retry. **The manager must drive retry itself.**
+
+### Design β probe-before-construct
+
+The importer's `__init__` eagerly calls `_initialize()` and the failing branch unconditionally calls `traceback.print_exc()`. So **we must avoid invoking the constructor (and `_initialize()`) until the SHM actually exists.**
+
+Probe cost: a single `multiprocessing.shared_memory.SharedMemory(name=...).close()` round-trip. On a missing SHM that raises `FileNotFoundError` immediately (`OpenFileMapping β ERROR_FILE_NOT_FOUND`, microseconds). Since the probe happens in our manager code, **we own the try/except** β no `traceback.print_exc()` ever runs.
+
+Lifecycle:
+
+```
+manager __init__: _pending_ipc_input_name = None
+ ipc_input_importer = None
+
+_initialize_memory_interfaces:
+ if td_settings has cuda_ipc_input_shm_name:
+ _pending_ipc_input_name = name
+ if _probe_ipc_shm_exists(name): β silent SHM open/close probe
+ _try_construct_ipc_importer() β only if probe says yes
+ else:
+ log "configured (waiting for TD Sender)"
+ if _pending_ipc_input_name is None:
+ open CPU SHM input fallback β gate flipped from `if importer is None`
+
+_get_input_frame:
+ if _pending_ipc_input_name and ipc_input_importer is None:
+ every β₯1s:
+ if _probe_ipc_shm_exists(name):
+ _try_construct_ipc_importer()
+ if importer ready: one-shot "connected" log
+ if ipc_input_importer: β get_frame_numpy() + alpha strip
+ elif input_memory: β CPU SHM read (legacy fallback path; only fires when IPC not configured)
+ return None
+```
+
+### Edit set β `Scripts/streamdiffusionTD__Text__td_manager__td.py`
+
+**Note:** Round 1 already applied three edits to this file. Round 2 **supersedes Edits 2 and 3** with the probe-first versions below. Edit 1 (slot declarations) stays; one slot is added.
+
+**Edit 1 β extend slot declarations (~line 83-86):**
+
+```python
+self.syphon_handler = None
+self.ipc_input_importer = None # CUDAIPCImporter when cuda_ipc_input_shm_name is in td_settings
+self._ipc_input_last_retry = 0.0 # monotonic timestamp of last reconnect attempt
+self._ipc_input_connected_logged = False # one-shot log when sender first comes online
+self._pending_ipc_input_name: Optional[str] = None # set when IPC configured but not yet connected
+self._ipc_importer_cls = None # cached CUDAIPCImporter class, populated lazily
+```
+
+**Edit 2 β replace round-1 IPC construction block in `_initialize_memory_interfaces()`:**
+
+Replace the current `if ipc_input_name:` block AND change the CPU SHM fallback gate two lines below from `if self.ipc_input_importer is None:` to `if self._pending_ipc_input_name is None:`.
+
+New block:
+
+```python
+ipc_input_name = self.td_settings.get('cuda_ipc_input_shm_name')
+if ipc_input_name:
+ try:
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter
+ self._ipc_importer_cls = CUDAIPCImporter
+ self._pending_ipc_input_name = ipc_input_name
+ # Probe SHM existence ourselves before letting the importer call _initialize().
+ # _initialize()'s OSError catch unconditionally runs traceback.print_exc() (writes
+ # to sys.stderr, bypassing the logging module) β so we must not invoke it on a
+ # missing SHM. _probe_ipc_shm_exists() catches FileNotFoundError silently in OUR code.
+ if self._probe_ipc_shm_exists(ipc_input_name):
+ self._try_construct_ipc_importer()
+ if self.ipc_input_importer is not None and self.ipc_input_importer.is_ready():
+ logger.info(f"CUDA IPC input connected: {ipc_input_name}")
+ self._ipc_input_connected_logged = True
+ if self.ipc_input_importer is None:
+ logger.info(
+ f"CUDA IPC input configured (waiting for TD Sender): {ipc_input_name}"
+ )
+ except Exception as e:
+ logger.warning(f"CUDA IPC input setup failed, falling back to CPU SHM: {e}")
+ self._pending_ipc_input_name = None
+ self.ipc_input_importer = None
+
+if self._pending_ipc_input_name is None:
+ # CPU SHM fallback β only when IPC is genuinely not configured
+ self.input_memory = shared_memory.SharedMemory(name=self.input_mem_name)
+ logger.debug(f"Connected to input SharedMemory: {self.input_mem_name}")
+```
+
+**Edit 3 β replace round-1 retry logic in `_get_input_frame()`:**
+
+```python
+# Lazy/throttled importer construction β fires when TD Sender comes online after Python startup
+if self._pending_ipc_input_name is not None and self.ipc_input_importer is None:
+ now = time.monotonic()
+ if now - self._ipc_input_last_retry >= 1.0:
+ self._ipc_input_last_retry = now
+ if self._probe_ipc_shm_exists(self._pending_ipc_input_name):
+ self._try_construct_ipc_importer()
+ if self.ipc_input_importer is not None and not self._ipc_input_connected_logged:
+ logger.info("CUDA IPC input connected (TD Sender came online)")
+ self._ipc_input_connected_logged = True
+
+if self.ipc_input_importer is not None:
+ frame = self.ipc_input_importer.get_frame_numpy()
+ if frame is not None and frame.ndim == 3 and frame.shape[2] == 4:
+ frame = frame[:, :, :3] # strip alpha; downstream pipeline expects RGB
+ return frame
+```
+
+If IPC is configured but not yet connected, this branch falls through; the existing `if self.input_memory:` check fires only when CPU SHM was opened (i.e., IPC was never configured). When IPC is configured but waiting for Sender, all three conditions are false β method returns `None`, pipeline pauses for that frame.
+
+**Edit 4 β add two helper methods on the manager (place near other internal helpers):**
+
+```python
+def _probe_ipc_shm_exists(self, name: str) -> bool:
+ """Return True iff the named Windows SharedMemory currently exists.
+
+ Cheap probe (~Β΅s on Windows). Used to gate CUDAIPCImporter construction:
+ its __init__ eagerly calls _initialize(), whose except-OSError branch unconditionally
+ invokes traceback.print_exc() β which writes to sys.stderr and cannot be silenced by
+ logger configuration. Probing here keeps that code path unreached during the normal
+ startup race with TD's Sender activation.
+ """
+ try:
+ shm = shared_memory.SharedMemory(name=name)
+ shm.close()
+ return True
+ except FileNotFoundError:
+ return False
+ except Exception:
+ # Any other failure (permission, etc.) β treat as not-yet-available; retry will pick up.
+ return False
+
+def _try_construct_ipc_importer(self) -> None:
+ """Construct CUDAIPCImporter (which calls _initialize() eagerly).
+
+ Caller must have verified SHM existence via _probe_ipc_shm_exists first to avoid the
+ traceback-printing failure path. On unexpected init failure (e.g., bad magic bytes,
+ version mismatch), null out the importer so retry can probe again.
+ """
+ try:
+ self.ipc_input_importer = self._ipc_importer_cls(
+ shm_name=self._pending_ipc_input_name,
+ device=torch.cuda.current_device(),
+ timeout_ms=500.0,
+ )
+ if not self.ipc_input_importer.is_ready():
+ logger.warning(
+ f"CUDA IPC input '{self._pending_ipc_input_name}' opened but importer "
+ f"not ready (protocol mismatch?). Will retry."
+ )
+ self.ipc_input_importer = None
+ except Exception as e:
+ logger.warning(f"CUDA IPC input construction failed: {e}")
+ self.ipc_input_importer = None
+```
+
+### Why not modify `_compat/cuda_ipc/cuda_ipc_importer.py`?
+
+Same rationale as round 1 β preserve upstream-sync compatibility. The cleanest upstream fix would be either (a) removing the `traceback.print_exc()` on line 809 (it's redundant with `logger.error` plus `logger.exception()` would do the right thing) or (b) gating it behind a debug flag. Both belong in an upstream PR, not as local fork divergence.
+
+### Why throttle to 1/sec?
+
+`shared_memory.SharedMemory(name=...)` on Windows is a single `OpenFileMapping` syscall β microseconds whether it succeeds or fails. 1/sec keeps the syscall rate trivially low while giving sub-second user-perceived activation latency (toggle Active=On in TD β frames flow within β€1s).
+
+### Verification
+
+1. **Cold start** (Python before TD Sender activated):
+ - Save edits in `Scripts/streamdiffusionTD__Text__td_manager__td.py`.
+ - Trigger TD **Writeconfigs** (propagates edits into runtime `StreamDiffusionTD/td_manager.py`).
+ - Restart Python with TD's `shmem_out` Sender **deactivated** (`Active = Off`).
+ - **Expected logs:** `CUDA IPC input configured (waiting for TD Sender): StreamDiffusionTD_512-512_input_ipc`. **No** `Traceback`, no `cuda_ipc_importer - ERROR/WARNING/INFO` lines on startup. Streaming loop proceeds; `_get_input_frame` returns None per frame.
+
+2. **Late-activation** (TD Sender comes online after Python startup):
+ - Toggle `shmem_out.par.Active = On` in TD.
+ - **Within β€1s expect:** `CUDA IPC input connected (TD Sender came online)` (single line). Frames flow.
+
+3. **Warm restart** (TD Sender already online when Python starts):
+ - With Active=On in TD, restart Python.
+ - **Expected:** `CUDA IPC input connected: StreamDiffusionTD_512-512_input_ipc` immediately at init. No "waiting" message. No traceback.
+
+4. **Bounce test** (TD Sender OffβOnβOffβOn while Python connected): Out of scope β Phase 2.5. The probe is only consulted when importer is None; once connected, no probe runs.
+
+### Rollback
+
+Revert Edits 2, 3, and 4 in `Scripts/streamdiffusionTD__Text__td_manager__td.py` (Edit 1's slot declarations are harmless if left). To restore round-1 behavior instead, re-apply round-1 Edit 2/3 from this plan's git history.
+
+### Open questions (defer to Phase 2.5)
+
+1. **Reconnect after mid-stream disconnect.** If TD's Sender deactivates while Python is connected, `is_ready()` may still report True until something in the IPC layer notices. Need to detect stale handles and trigger `cleanup()` + null-out so the retry loop re-probes.
+2. **Upstream contribution.** Removing `traceback.print_exc()` from `cuda_ipc_importer.py:809` (or gating it on a debug flag) would let downstreams use simple `_initialize()` retry without manager-side probing. File against upstream cuda-link.
+
+---
+
+## Phase 2.3 β Frozen frame on new shmem Receiver (current blocker)
+
+### Symptom
+
+After Phase 2.2 wrapper.py edit + Python restart: the new `shmem` Receiver's `output` Script TOP no longer animates. User saw the receiver path firing in Phase 2.1 (FPS, copyCUDAMemory logs); now it appears static.
+
+User hypothesis was "Execute DAT uses old Mode names 'Sender'/'Receiver'" β **rejected**: `'Sender'`/`'Receiver'` (capitalized) ARE the canonical cuda-link names (`CUDAIPCExtension._mode`, set by `_normalize_mode` in `CUDAIPCExtension.py`). Both `Scripts/shmem__Execute__execute__td.py:29,31` and `Scripts/shmem_out__Execute__execute__td.py:29,31` use these correctly.
+
+### Root cause: Scripts/ filename-collision overwrote new shmem's DATs
+
+When the user duplicated `shmem_out` β renamed to `shmem`, TD's Scripts/ filename-convention auto-sync (where DAT contents are paired with `Scripts/______td.py`) loaded the pre-existing OLD fork files for the `shmem__*` namespace into the new COMP's DATs. The new COMP did NOT keep the working content from `shmem_out__*`.
+
+Evidence:
+- `Scripts/shmem__Execute__execute__td.py` (4548 B, mtime May 17 09:34) β OLD fork content with DIAG-EXEC instrumentation, references `parent().par.Play.eval()` (line 89) and `parent().par.Mode.eval() == 'receive'` (line 67).
+- `Scripts/shmem_out__Execute__execute__td.py` (2163 B, mtime May 17 00:30) β Clean cuda-link variant, no DIAG, smaller.
+- `Scripts/shmem__Text__output_callbacks__td.py` (1142 B, mtime May 17 **12:41**) β modified TODAY, most recent. OLD fork content with DIAG-COOK instrumentation. **Missing `consume_pending_resolution()` call.**
+- `Scripts/shmem_out__Text__output_callbacks__td.py` (1142 B, mtime Apr 22) β unrelated/legacy CPU-SHM output callbacks, untouched.
+
+### Why "frozen frame" β the resolution-update bug
+
+The OLD `shmem__Text__output_callbacks__td.py:12-19` onCook:
+
+```python
+def onCook(scriptOp):
+ try:
+ cuda_ext = parent().ext.CUDAIPCExtension
+ if cuda_ext is not None and cuda_ext.mode == 'Receiver' and cuda_ext.is_active():
+ # NOTE: resolution updates moved to Execute DAT to avoid re-cook race.
+ cuda_ext.import_frame(scriptOp)
+ return
+ except Exception as e:
+ ... # dedup-log DIAG
+```
+
+**No `consume_pending_resolution()` call.** The comment "resolution updates moved to Execute DAT" assumes `modoutsidecook` is enabled on the Script TOP (TD 2025+). But the new `shmem`'s `output` Script TOP probably has `modoutsidecook` OFF (TD 2023 default).
+
+Cross-reference Execute DAT path (`Scripts/shmem__Execute__execute__td.py:34-38`):
+
+```python
+if hasattr(import_buffer.par, 'modoutsidecook') and import_buffer.par.modoutsidecook.eval():
+ cuda_ext.import_frame(import_buffer)
+ cuda_ext.update_receiver_resolution(import_buffer) # only fires when modoutsidecook=True
+else:
+ import_buffer.cook(force=True) # triggers onCook β but onCook NEVER updates resolution
+```
+
+When `modoutsidecook` is OFF:
+1. Execute DAT calls `import_buffer.cook(force=True)` β triggers `output.onCook`.
+2. `onCook` calls `import_frame(scriptOp)` directly. Skips resolution update.
+3. `TDReceiver.update_receiver_resolution` (`TDReceiver.py:523-548`) is **never** called.
+4. The `output` Script TOP stays at default resolution (typically 1Γ1 or whatever it was duplicated as).
+5. `import_frame` copies 512Γ512Γ4 GPU bytes into a TOP buffer that is **not** 512Γ512 β silent buffer-size mismatch, TD shows a stale frame.
+
+Canonical `script_top_callbacks.py:31-43` (in `_compat/td_exporter/`) handles this path correctly β it calls `consume_pending_resolution()` and writes `outputresolution=9, resolutionw/h` to the scriptTop BEFORE `import_frame`.
+
+### Edit
+
+**Single edit:** patch `Scripts/shmem__Text__output_callbacks__td.py` to call `consume_pending_resolution()` inside the Receiver branch BEFORE `import_frame`. Preserve the existing DIAG-COOK instrumentation.
+
+Replace the body of `onCook(scriptOp)` (lines 12-35) with:
+
+```python
+def onCook(scriptOp):
+ # CUDA IPC Receiver path: GPU-to-GPU frame import via copyCUDAMemory
+ try:
+ cuda_ext = parent().ext.CUDAIPCExtension
+ if cuda_ext is not None and cuda_ext.mode == 'Receiver' and cuda_ext.is_active():
+ # Apply pending resolution update (TD 2023 path β fires when modoutsidecook is OFF).
+ # When modoutsidecook is ON, Execute DAT already called update_receiver_resolution
+ # and this returns None β harmless.
+ pending = cuda_ext.consume_pending_resolution()
+ if pending is not None:
+ width, height = pending
+ try:
+ scriptOp.par.outputresolution = 9 # Custom Resolution
+ scriptOp.par.resolutionw = width
+ scriptOp.par.resolutionh = height
+ cuda_ext._log(
+ f"[shmem onCook] Set output resolution to {width}x{height}",
+ force=True,
+ )
+ except (AttributeError, RuntimeError) as e:
+ cuda_ext._log(f"[shmem onCook] Could not set resolution: {e}", force=True)
+ cuda_ext.import_frame(scriptOp)
+ return
+ except Exception as e:
+ # DIAG (Round-6, temporary): dedup-log exceptions from import_frame.
+ key = type(e).__name__ + ':' + str(e)
+ if not hasattr(onCook, '_diag_seen_errors'):
+ onCook._diag_seen_errors = {}
+ seen = onCook._diag_seen_errors.get(key, 0)
+ onCook._diag_seen_errors[key] = seen + 1
+ if seen == 0 or seen == 60 or seen == 600:
+ try:
+ parent().ext.CUDAIPCExtension._log(
+ f"[DIAG-COOK] onCook raised ({seen + 1}x): {key}",
+ force=True,
+ )
+ except Exception:
+ print(f"[DIAG-COOK] onCook raised ({seen + 1}x): {key} (ext unavailable)")
+```
+
+API references verified in `CUDAIPCExtension.py:284-291` β `consume_pending_resolution()` returns `(width, height)` tuple if pending update is set, `None` otherwise. Safe to call repeatedly.
+
+### Alternative considered (rejected)
+
+**Enable `modoutsidecook` on the new shmem's `output` Script TOP** β would let the Execute DAT's `update_receiver_resolution` path drive resolution. Rejected because:
+- TD 2023 doesn't have `modoutsidecook`. Forces a TD-version requirement.
+- Existing `shmem_in_cn_processed` and other Receiver siblings in Phase 2.4+ may share the same TOP topology and would each need the same param flip. Fixing onCook covers all of them via the Scripts/ shared file.
+- The canonical `_compat/td_exporter/script_top_callbacks.py` template uses `consume_pending_resolution()` in onCook β aligning with upstream is preferable.
+
+### Why not also rewrite Execute DAT?
+
+The current `Scripts/shmem__Execute__execute__td.py` has dead-code SHM references (`parent().par.Play.eval()`, `parent().par.Mode.eval()`). These would throw `AttributeError` on the new cuda-link COMP if `Play`/`Mode` parameters don't exist β BUT they're inside a `try/except: pass` block (lines 51-87), AND the user's Phase 2.1 logs confirmed the Receiver branch DID fire (`copyCUDAMemory=0.95ms`), so the Execute DAT IS reaching its Receiver path.
+
+Deferred to Phase 2.5 (cleanup) β the dead-code SHM references are noise but not load-bearing for current symptoms. Fixing them is part of `Phase 3 (cleanup)`.
+
+### Verification (Phase 2.3)
+
+1. Save edit. Edit is to `Scripts/shmem__Text__output_callbacks__td.py` only.
+2. TD picks up the change automatically (Scripts/ sync). No Writeconfigs needed (output_callbacks is a Text DAT, not the runtime `td_manager.py`).
+3. Toggle the new `shmem.par.Active = Off β On` (forces Receiver re-init β `needs_resolution_update = True`).
+4. **Textport log to look for:** `[shmem onCook] Set output resolution to 512x512` (single fire, first frame after Activate).
+5. Visual check: frame updates (no longer static); colors look correct (red is red, blue is blue) β this validates Phase 2.2 BGRA fix simultaneously.
+6. `Receiver's output Script TOP info` shows resolution `512Γ512` (was likely 1Γ1 or duplicated-source size before).
+
+### Rollback
+
+Revert `Scripts/shmem__Text__output_callbacks__td.py` to the prior 1142-byte version (current mtime May 17 12:41). No state to clean β `consume_pending_resolution` is idempotent.
+
+### Open question (defer)
+
+Whether the Scripts/ filename-collision also overrode `Scripts/shmem__Execute__execute__td.py` with OLD content vs the cuda-link variant. mtimes suggest yes (May 17 09:34 β recent). But the functional cuda-link branch in lines 51-72 of that file is intact, so it's currently working. Leave the file as-is for Phase 2.3 β revisit during Phase 3 cleanup.
+
+---
+
+## Phase 2.2 β Fix output channel order (RGBA β BGRA)
+
+### Root cause
+
+Three pieces of evidence converge on RβB swap, not dtype:
+
+1. `wrapper.py:995-1002` β `_ipc_pack_rgba` copies `rgb_nhwc` (R,G,B from `_denormalize_on_gpu` β uint8) verbatim into channels `[0,1,2]` of the IPC buffer. Byte order on the wire: `R,G,B,A`.
+2. `cuda_ipc_exporter.py:19` (docstring) β "output_tensor: (H, W, 4) uint8 **BGRA** on GPU". The exporter's contract is BGRA.
+3. `TDReceiver.py:800-801, 483-487` β receiver builds a uint8 RGBA8 TOP via `copyCUDAMemory(addr, size, CUDAMemoryShape(dtype=uint8, ...))`. TD interprets the bytes positionally: byte[0]βR-display, byte[1]βG-display, byte[2]βB-display. Wrapper writes B-source data into byte[2], so blue from Python is shown as red, and vice versa.
+
+The legacy fork SHM path went through `copyNumpyArray()` (CPU SHM read by `script_top_callbacks.py`), which TD's TOP also displays as RGBA β but the legacy producer was sending the SAME RGB byte order. So why did the OLD path look right? **Because the legacy producer was `Scripts/shmem__Text__SharedMemEXT__td.py:280-297`'s `sendData()`, which TD called from the TD-side; TD's input TOP gave numpyArray in TD's native RGBA byte order to begin with.** The new path bypasses that β it sends from the Python wrapper, which assumes RGB throughout.
+
+### Edit
+
+Single edit to `StreamDiffusion/src/streamdiffusion/wrapper.py:995-1002` (`_ipc_pack_rgba`):
+
+**Before:**
+```python
+def _ipc_pack_rgba(self, rgb_nhwc: torch.Tensor) -> torch.Tensor:
+ """Pad (B,H,W,3) uint8 β (B,H,W,4) uint8 with opaque alpha on GPU, reusing a persistent buffer."""
+ B, H, W, _ = rgb_nhwc.shape
+ if self._ipc_rgba_buf is None or self._ipc_rgba_buf.shape != (B, H, W, 4):
+ self._ipc_rgba_buf = torch.empty((B, H, W, 4), dtype=torch.uint8, device=rgb_nhwc.device)
+ self._ipc_rgba_buf[..., 3] = 255 # alpha channel set once; reused across frames
+ self._ipc_rgba_buf[..., :3].copy_(rgb_nhwc)
+ return self._ipc_rgba_buf
+```
+
+**After:**
+```python
+def _ipc_pack_rgba(self, rgb_nhwc: torch.Tensor) -> torch.Tensor:
+ """Pad (B,H,W,3) uint8 RGB β (B,H,W,4) uint8 BGRA on GPU for cuda-link IPC transport.
+
+ cuda-link's wire contract is BGRA (cuda_ipc_exporter.py:19). TD interprets the raw GPU
+ bytes positionally as RGBA8 in the Script TOP, so we must swap RβB at pack time to
+ keep colors correct in TD.
+ """
+ B, H, W, _ = rgb_nhwc.shape
+ if self._ipc_rgba_buf is None or self._ipc_rgba_buf.shape != (B, H, W, 4):
+ self._ipc_rgba_buf = torch.empty((B, H, W, 4), dtype=torch.uint8, device=rgb_nhwc.device)
+ self._ipc_rgba_buf[..., 3] = 255 # alpha channel set once; reused across frames
+ # BGRA byte order: byte[0]=B, byte[1]=G, byte[2]=R, byte[3]=A
+ self._ipc_rgba_buf[..., 0].copy_(rgb_nhwc[..., 2]) # B β source channel 2 (B)
+ self._ipc_rgba_buf[..., 1].copy_(rgb_nhwc[..., 1]) # G β source channel 1 (G)
+ self._ipc_rgba_buf[..., 2].copy_(rgb_nhwc[..., 0]) # R β source channel 0 (R)
+ return self._ipc_rgba_buf
+```
+
+Three separate `.copy_()` calls instead of a single `[..., :3].copy_()` β each is a contiguous channel-wise GPU memcpy and adds negligible overhead (single-digit Β΅s at 512Γ512). No new allocations, persistent buffer preserved.
+
+### Why not swap on the TD side (Receiver) instead?
+
+The wrapper's output is the producer's responsibility β the IPC contract (per `cuda_ipc_exporter.py` docstring) is BGRA. Fixing the producer aligns the code with the documented contract and avoids touching the vendored upstream `td_exporter/` code. Receiver-side fix would mean editing TouchDesigner Script TOP callbacks per-component.
+
+### Input-direction symmetry concern (Phase 1 β note, don't fix here)
+
+Phase 1's input direction (TDβPython via `shmem_out` Sender) also goes through a uint8 RGBA byte stream, but the alpha-strip path in `Scripts/streamdiffusionTD__Text__td_manager__td.py:_get_input_frame` does `frame[:, :, :3]` β it takes channels `[0,1,2]` as RGB. If TD's cuda-link Sender writes those bytes as BGRA, the SD pipeline receives BGR-as-RGB β input is silently color-swapped too.
+
+Two possibilities:
+1. The OLD fork SHM input path was equally swapped, and SDXL-Turbo's prompt conditioning is tolerant enough that nobody noticed; the regression doesn't manifest visually.
+2. The TD-side Sender's TOPβGPU export does an implicit swizzle, so the byte stream reaching Python is already RGB.
+
+Defer investigation until after Phase 2.2 output fix lands. If input also looks color-shifted after the fix, mirror the swap in `_get_input_frame` (swap channels 0β2 before returning).
+
+### Validation (Phase 2.2)
+
+1. Save edit in `wrapper.py`. No TD-side change needed (Receiver stays as-is).
+2. Restart Python process. Confirm exporter still logs `dtype=uint8 (kind=1 bits=8 flags=0x0000)` β metadata unchanged.
+3. Visual check: blue objects in the SD output now appear blue in the TD viewer (not orange/red). Skin tones look correct.
+4. Sanity-check perf: `Frame N: avg cudaMemory` line shows same ~120 Β΅s as before β three split copies should not slow the per-frame budget.
+5. OffβOnβOffβOn cycle parity preserved.
+
+### Rollback
+
+Revert the 3-line copy back to `self._ipc_rgba_buf[..., :3].copy_(rgb_nhwc)`. No state to clean up; persistent buffer is dtype/shape-compatible.
+
+---
+
+## Phase 2.1 β Swap `shmem` (PythonβTD output direction)
+
+### Approach: duplicate the working `shmem_out`
+
+Cleanest path β preserves the extension wiring, internal Text DAT sync to `_compat/td_exporter/`, and all parameter overrides.
+
+**Procedure (in TD):**
+
+1. **Archive the old fork `shmem`**: rename existing `shmem` β `shmem_old` (don't delete yet β keep as rollback).
+2. **Duplicate `shmem_out`**: copy-paste β rename copy to `shmem`.
+3. **Flip parameters** on the new `shmem`:
+ - `Mode`: `Sender` β `Receiver`
+ - `Ipcmemname`: `parent.SDTD.par.Streamoutname + '_input_ipc'` β `parent.SDTD.par.Streamoutname + '_output_ipc'`
+ - Resolves to `StreamDiffusionTD_512-512_output_ipc` β symmetric with `_input_ipc` and matches the wrapper's canonical default `"streamdiffusion_output_ipc"` (`wrapper.py:135`).
+ - **Replaces the old fork expression** `op('stream_osc_data')['output-name',1]` (which resolved to a randomized per-process name like `StreamDiffusionTD_512-512_out_1779045216`). That randomized scheme is dead with the cuda-link swap β the new IPC name is deterministic.
+ - `Numslots`: keep 3.
+4. **Re-wire I/O**:
+ - Sender's input wire is gone (no source TOP feeds a Receiver).
+ - New Receiver exposes its output via internal Script TOP `output`. Re-wire whatever currently consumes the old `shmem` output to consume the new `shmem`'s output instead.
+5. **Activate**: toggle new `shmem.par.Active = On`. Watch Textport for `[CUDAIPCExtension:Receiver]` init log, then `import_frame` lines.
+
+**One Python-side edit** β rename the exporter SHM to `_output_ipc` for symmetry with `_input_ipc`:
+
+Edit `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py:3755`:
+```python
+# Before:
+yaml_content += f"cuda_ipc_shm_name: '{stream_name}_ipc'\n"
+# After:
+yaml_content += f"cuda_ipc_shm_name: '{stream_name}_output_ipc'\n"
+```
+
+Propagation:
+1. TD Writeconfigs writes new YAML value `cuda_ipc_shm_name: 'StreamDiffusionTD_512-512_output_ipc'`
+2. Python loads it via `config.py:161` and `wrapper.py:135,337`
+3. `CUDAIPCExporter` at `wrapper.py:1011` now writes to `..._output_ipc`
+4. New `shmem` Receiver reads from `..._output_ipc` β names match.
+
+No edit to `wrapper.py` or `config.py` needed β they already use `cuda_ipc_shm_name` as a config-driven string. The wrapper's hardcoded default `"streamdiffusion_output_ipc"` (`wrapper.py:135`) is overridden by the YAML-driven value.
+
+### `ImageChanged()` β pre-existing regression (note, don't fix here)
+
+Research finding (Phase 1 explore agent): `ImageChanged()` is only fired by the dead-code SHM path in `Scripts/shmem__Text__SharedMemEXT__td.py:412-420` (`_trigger_change_callback`). The active CUDA IPC path **never** fires it. This means feedback-safe mode (`StreamDiffusionExt:372-374`, `is_feedback_safe and is_stream_active`) has been silently broken since the Round-3 SHM-path removal.
+
+**Phase 2.1 does NOT make this worse** β it just continues not firing it. If feedback-safe is in use, the fix is to add a 5-line onCook hook to the new `shmem`'s `output` Script TOP that calls `parent.SDTD.ImageChanged()` (optionally with a hash-detect guard to match the original change-detected firing model, not per-frame). Deferred to a separate ticket.
+
+### Validation (Step 2.1.5)
+
+1. **YAML emit propagated**: after Writeconfigs, `StreamDiffusion/StreamDiffusionTD/td_config.yaml:72` shows `cuda_ipc_shm_name: 'StreamDiffusionTD_512-512_output_ipc'` (line number may shift; key value is what matters).
+2. **Python exporter uses new name**: on `start_streaming()`, Textport / Python logs show `CUDAIPCExporter` initialized with `StreamDiffusionTD_512-512_output_ipc`.
+3. Toggle `shmem.par.Active = On` after Python is streaming. Textport shows `[CUDAIPCExtension:Receiver] Receiver initialized` and `import_frame` lines with non-zero `copyCUDAMemory` time (already observed in Phase 1 logs from the fork Receiver β confirm same with the upstream version).
+4. Visual check: the downstream node consuming `shmem`'s output displays the StreamDiffusion frames in real time.
+5. Run 10+ min β verify no growth in receiver slot count, no re-init storms, no error 201 on the Receiver side.
+6. OffβOnβOffβOn cycle parity with `shmem_out` β clean teardown.
+
+### Rollback
+
+If validation fails: disable new `shmem.par.Active`, rename it `shmem_new`, rename `shmem_old` β `shmem`, re-wire downstream to the old COMP's output. Python side is unaffected β it'll keep writing to the same `_ipc` SHM and the fork Receiver picks it up.
+
+---
+
+## Phase 1 β historical record
+
+**TD side working.** Latest run confirms the new `shmem_out` (cuda-link v1.4.1 Sender) initializes cleanly:
+
+```
+[CUDAIPCExtension:Sender] Created new SharedMemory: StreamDiffusionTD_512-512_input_ipc (433 bytes)
+[CUDAIPCExtension:Sender] Wrote all IPC handles v1 to SharedMemory (433 bytes total)
+[CUDAIPCExtension:Sender] Wrote metadata: 512x512x4, kind=1 bits=8 flags=0x0000, size=1048576B
+[CUDAIPCExtension:Sender] Initialization complete - ready for zero-copy GPU transfer
+```
+
+- **Phase 1 Steps 1.1, 1.2, 1.3 β DONE.** COMP swapped, source TOP wired, `Ipcmemname = parent.SDTD.par.Streamoutname + '_input_ipc'` β `StreamDiffusionTD_512-512_input_ipc`.
+- **YAML emit DONE.** `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py:3768` already emits `cuda_ipc_input_shm_name: '{stream_name}_input_ipc'`. Verified live in `StreamDiffusionTD/td_config.yaml:83`.
+- **Vendored cuda-link td_exporter/ in place** β `_compat/td_exporter/` synced to v1.4.1 (commit `92989fc`).
+
+**Python side still on CPU SHM β current blocker.** Run fails with:
+
+```
+TouchDesignerManager - ERROR - Failed to initialize SharedMemory:
+ [WinError 2] The system cannot find the file specified: 'StreamDiffusionTD_512-512'
+```
+
+The Python process tries to open the **old** CPU SHM name (`StreamDiffusionTD_512-512`), which the new cuda-link COMP no longer creates. Only `StreamDiffusionTD_512-512_input_ipc` exists now (the IPC control-packet SHM, 433 bytes β far smaller than the expected RGB frame buffer).
+
+**Root cause of earlier wasted edit:** I previously edited `StreamDiffusion/StreamDiffusionTD/td_manager.py` directly. That file is a **build target** β TD's Writeconfigs (`StreamDiffusionExt:2602-2632`) overwrites it from the Text DAT `streamdiffusionTD/td_manager`. Canonical source is `Scripts/streamdiffusionTD__Text__td_manager__td.py`.
+
+---
+
+## Context
+
+**Why this change.** Previous debugging (Round 1-9 in `structured-purring-kurzweil.md`) chased a sticky `CUDA error 201 INVALID_CONTEXT` in the SDTD-bundled CUDA IPC sender at `TDSender.py:843`. The SDTD code is a fork of upstream `cuda-link`. Upstream ships a packaged `CUDAIPCLink_v1.4.1.tox` that speaks the same wire protocol (`PROTOCOL_MAGIC = 0x43495044`). Swapping the TD-side .tox lets us stop debugging the fork and inherit upstream fixes.
+
+**Scope:** Proof-of-concept first β swap `shmem_out` only (TDβPython direction). Keep Python-side `_compat/cuda_ipc/` fork (wire-protocol compatible). Eventually swap remaining 4 `shmem*` siblings, each in its own phase.
+
+---
+
+## Phase 1 remaining work β finish Python integration
+
+### Edit location convention (MUST FOLLOW)
+
+| File modified by user | Runtime file (auto-generated) | Direction |
+|---|---|---|
+| `Scripts/streamdiffusionTD__Text__td_manager__td.py` | `StreamDiffusion/StreamDiffusionTD/td_manager.py` | Edit Scripts/, TD's Writeconfigs writes runtime file |
+| `Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py` | (lives inside .toe β synced to/from Text DAT) | Edit Scripts/, instantly synced |
+
+**Never edit `StreamDiffusion/StreamDiffusionTD/td_manager.py` directly** β it will be clobbered on next Writeconfigs.
+
+### Edit set β `Scripts/streamdiffusionTD__Text__td_manager__td.py`
+
+Apply four small edits, each minimal:
+
+**Edit 1 β declare the importer slot (after line 83):**
+
+```python
+self.syphon_handler = None
+self.ipc_input_importer = None # CUDAIPCImporter when cuda_ipc_input_shm_name is in td_settings
+```
+
+**Edit 2 β `_initialize_memory_interfaces()` (line 327-332):** wrap CPU SHM open behind an IPC-importer-first check.
+
+Replace:
+
+```python
+else:
+ # Initialize SharedMemory (same pattern as your current version)
+ try:
+ # Input memory (from TouchDesigner)
+ self.input_memory = shared_memory.SharedMemory(name=self.input_mem_name)
+ logger.debug(f"Connected to input SharedMemory: {self.input_mem_name}")
+```
+
+With:
+
+```python
+else:
+ # Initialize SharedMemory (same pattern as your current version)
+ try:
+ # Input: prefer CUDA IPC (cuda-link shmem_out COMP) over CPU SHM
+ ipc_input_name = self.td_settings.get('cuda_ipc_input_shm_name')
+ if ipc_input_name:
+ try:
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCImporter
+ self.ipc_input_importer = CUDAIPCImporter(
+ shm_name=ipc_input_name,
+ device=torch.cuda.current_device(),
+ timeout_ms=500.0,
+ )
+ logger.info(f"CUDA IPC input configured: {ipc_input_name}")
+ except Exception as e:
+ logger.warning(f"CUDA IPC input init failed, falling back to CPU SHM: {e}")
+
+ if self.ipc_input_importer is None:
+ # Input memory (from TouchDesigner)
+ self.input_memory = shared_memory.SharedMemory(name=self.input_mem_name)
+ logger.debug(f"Connected to input SharedMemory: {self.input_mem_name}")
+```
+
+Rest of the function (output_memory, control_memory, ipadapter_memory) is unchanged β they remain on CPU SHM in Phase 1.
+
+**Edit 3 β `_get_input_frame()` (line 627-653):** try the IPC importer first.
+
+Replace whole body:
+
+```python
+def _get_input_frame(self) -> Optional[np.ndarray]:
+ """Get input frame from TouchDesigner (platform-specific)"""
+ try:
+ if self.is_macos and self.syphon_handler:
+ return self.syphon_handler.capture_input_frame()
+
+ if self.ipc_input_importer is not None:
+ frame = self.ipc_input_importer.get_frame_numpy()
+ if frame is not None and frame.ndim == 3 and frame.shape[2] == 4:
+ frame = frame[:, :, :3] # strip alpha; downstream pipeline expects RGB
+ return frame
+
+ if self.input_memory:
+ width = self.config['width']
+ height = self.config['height']
+ frame = np.ndarray((height, width, 3), dtype=np.uint8, buffer=self.input_memory.buf)
+ return frame.copy()
+
+ return None
+ except Exception as e:
+ logger.error(f"Error getting input frame: {e}")
+ return None
+```
+
+**Edit 4 β `_cleanup_memory_interfaces()` (after the syphon_handler block, before `if self.input_memory:`):**
+
+```python
+if self.ipc_input_importer is not None:
+ try:
+ self.ipc_input_importer.cleanup()
+ except Exception:
+ pass
+ self.ipc_input_importer = None
+```
+
+### Reused infrastructure (no new code needed)
+
+- `streamdiffusion._compat.cuda_ipc.CUDAIPCImporter` β already exported from `__init__.py:11` and `:27`. Constructor signature `(shm_name, shape=None, dtype=None, debug=False, timeout_ms=5000.0, device=0)`. Lazy init on first `get_frame_numpy()` call. Methods: `get_frame_numpy() -> np.ndarray | None`, `cleanup() -> None`.
+- `torch.cuda.current_device()` for the device arg β `torch` is already imported at `td_manager.py:17`.
+- `self.td_settings` β already populated at `td_manager.py:61` from YAML's `td_settings` block.
+
+### After edits β propagation flow
+
+1. Save `Scripts/streamdiffusionTD__Text__td_manager__td.py`. TD picks up the change in its Text DAT immediately.
+2. Trigger **Writeconfigs** in TD (the action that runs `StreamDiffusionExt:2602-2632`). This rewrites `StreamDiffusion/StreamDiffusionTD/td_manager.py` from the Text DAT.
+3. Restart the StreamDiffusion Python process. On `start_streaming()`, `_initialize_memory_interfaces()` reads `cuda_ipc_input_shm_name` from `td_settings`, opens the IPC SHM `StreamDiffusionTD_512-512_input_ipc`, and the per-frame loop calls `ipc_input_importer.get_frame_numpy()`.
+
+---
+
+## Validation (Step 1.5)
+
+After Writeconfigs + Python restart, verify in this order:
+
+1. **Init log line present:** `TouchDesignerManager - INFO - CUDA IPC input configured: StreamDiffusionTD_512-512_input_ipc`. No `[WinError 2]` for the old name.
+2. **First-frame open:** `CUDAIPCImporter` log line shows it opened the IPC handles for slots 0/1/2 (matches TD's `Wrote slot 0/1/2 handles` lines).
+3. **Frame flow:** StreamDiffusion produces output frames (Textport shows non-zero output FPS; not stuck at "waiting for sender").
+4. **Shape sanity:** TD writes 512Γ512Γ4 (RGBA uint8, 1,048,576 B) per metadata line β `get_frame_numpy()` returns `(512, 512, 4)`, alpha is stripped to `(512, 512, 3)` by the Edit 3 guard, matches the pipeline's expected RGB shape.
+5. **Stress test:** Run 10+ min at 512Γ512. Watch GPU mem (should stay flat β ring buffer is 3 Γ 1 MB), no re-init storms, no error 201.
+6. **Rollback path:** If anything fails, set `shmem_out`'s `Active = Off` in TD. The Python side will log the IPC init failure and fall through to CPU SHM (which will still fail with WinError 2 because nothing creates it now β so the practical rollback is restoring the previous `.toe` from the pre-swap backup).
+
+---
+
+## Phase 2 (deferred) β swap remaining shmem* siblings
+
+| Phase | COMP | Mode | Notes |
+|---|---|---|---|
+| 2.1 | `shmem` | Receiver | PythonβTD output; `Ipcmemname = {Streamoutname}_output_ipc` β symmetric with `_input_ipc`, requires one-line YAML emit edit |
+| 2.2 | `shmem_out_cn` | Sender | TDβPython ControlNet input; `Ipcmemname = {Streamoutname}_cn_ipc` |
+| 2.3 | `shmem_out_out_ip` | Sender | TDβPython IPAdapter input; `Ipcmemname = {Streamoutname}_ip_ipc` |
+| 2.4 | `shmem_in_cn_processed` | Receiver | PythonβTD ControlNet processed output |
+
+Each requires a parallel slot/init/cleanup edit in `td_manager.py` for its Python-side counterpart, plus a YAML emit key.
+
+---
+
+## Phase 3 (deferred) β cleanup
+
+After all 5 siblings swapped:
+
+- Delete `Scripts/shmem*__Text__SharedMemEXT__td.py` (5 files)
+- Delete `Scripts/shmem*__Execute__execute__td.py` (5 files)
+- Delete `Scripts/shmem*__Text__dot_lop_utils__td.py`, `__Text__dot_chat_util__td.py`, `__Text__output_callbacks__td.py`, `__ParExecute__*__td.py`
+- Leave `_compat/cuda_ipc/` and `_compat/td_exporter/` untouched
+
+---
+
+## Critical files
+
+**Edit (Scripts/ β canonical source):**
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/streamdiffusionTD__Text__td_manager__td.py` β 4 edits listed above
+
+**Already edited / verified correct:**
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/Scripts/StreamDiffusionTD__Text__StreamDiffusionExt__td.py:3768` β emits `cuda_ipc_input_shm_name`
+- TD COMP `shmem_out` β `Ipcmemname = parent.SDTD.par.Streamoutname + '_input_ipc'`
+
+**Read-only references:**
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:494` β `CUDAIPCImporter.__init__`
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:989` β `get_frame_numpy()`
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py:1236` β `cleanup()`
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/src/streamdiffusion/_compat/cuda_ipc/__init__.py:11,27` β `CUDAIPCImporter` re-export
+
+**Do NOT edit (runtime targets β clobbered by TD on Writeconfigs):**
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/StreamDiffusionTD/td_manager.py`
+- `D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/StreamDiffusionTD/td_config.yaml`
+
+---
+
+## Save plan to project repo
+
+Per memory `feedback_save_plans_as_project_files`, copy this plan to:
+`D:/dev/SD_3_0_1/test_Install_dev/StreamDiffusion/StreamDiffusion/_plans/drifting-twirling-tulip.md`
+(after exiting plan mode β copy not possible in plan mode).
diff --git a/configs/td_config.yaml.example b/configs/td_config.yaml.example
index 48a620a26..fb82cd785 100644
--- a/configs/td_config.yaml.example
+++ b/configs/td_config.yaml.example
@@ -85,14 +85,27 @@ use_ipadapter: false
+# CUDA IPC zero-copy GPU-to-GPU output (SDβTD via cuda-link)
+use_cuda_ipc_output: false # set true to enable
+cuda_ipc_shm_name: 'StreamDiffusionTD_512-512_output_ipc'
+cuda_ipc_num_slots: 3
+output_type: 'np'
+
+# CUDA IPC zero-copy GPU-to-GPU input (TDβSD via cuda-link)
+# When true, SD reads input frames from td_settings.cuda_ipc_input_shm_name
+# instead of the legacy CPU SharedMemory at td_settings.input_mem_name.
+use_cuda_ipc_input: false
+
# TouchDesigner specific settings
td_settings:
# OSC communication
osc_receive_port: 8576
osc_transmit_port: 8588
- # Memory interface
+ # Legacy CPU SharedMemory names (used when use_cuda_ipc_output is false)
input_mem_name: 'StreamDiffusionTD_512-512'
+ # Reserved for future TDβSD IPC input direction (not wired yet)
+ cuda_ipc_input_shm_name: 'StreamDiffusionTD_512-512_input_ipc'
output_mem_name: 'StreamDiffusionTD_512-512_out'
# Debug settings
diff --git a/demo/realtime-img2img/app_config.py b/demo/realtime-img2img/app_config.py
index 9c21e58ce..252a992a5 100644
--- a/demo/realtime-img2img/app_config.py
+++ b/demo/realtime-img2img/app_config.py
@@ -1,47 +1,52 @@
"""
Application configuration and settings for realtime-img2img
"""
-import yaml
+
import logging
from pathlib import Path
+import yaml
+
+
def load_controlnet_registry():
"""Load ControlNet registry from config file"""
try:
registry_path = Path(__file__).parent / "controlnet_registry.yaml"
- with open(registry_path, 'r') as f:
+ with open(registry_path, "r") as f:
config_data = yaml.safe_load(f)
-
+
# Extract the available_controlnets section
- return config_data.get('available_controlnets', {})
+ return config_data.get("available_controlnets", {})
except Exception as e:
logging.exception(f"load_controlnet_registry: Failed to load ControlNet registry: {e}")
# Fallback to empty registry
return {}
+
def load_default_settings():
"""Load default settings from YAML config file"""
try:
registry_path = Path(__file__).parent / "controlnet_registry.yaml"
- with open(registry_path, 'r') as f:
+ with open(registry_path, "r") as f:
config_data = yaml.safe_load(f)
-
- return config_data.get('defaults', {})
+
+ return config_data.get("defaults", {})
except Exception as e:
logging.exception(f"load_default_settings: Failed to load default settings: {e}")
# Fallback to hardcoded defaults
return {
- 'guidance_scale': 1.1,
- 'delta': 0.7,
- 'num_inference_steps': 50,
- 'seed': 2,
- 't_index_list': [35, 45],
- 'ipadapter_scale': 1.0,
- 'normalize_prompt_weights': True,
- 'normalize_seed_weights': True,
- 'prompt': "Portrait of The Joker halloween costume, face painting, with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece"
+ "guidance_scale": 1.1,
+ "delta": 0.7,
+ "num_inference_steps": 50,
+ "seed": 2,
+ "t_index_list": [35, 45],
+ "ipadapter_scale": 1.0,
+ "normalize_prompt_weights": True,
+ "normalize_seed_weights": True,
+ "prompt": "Portrait of The Joker halloween costume, face painting, with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece",
}
+
# Load configuration at module level
AVAILABLE_CONTROLNETS = load_controlnet_registry()
DEFAULT_SETTINGS = load_default_settings()
diff --git a/demo/realtime-img2img/config.py b/demo/realtime-img2img/config.py
index 56d774048..48e31e18f 100644
--- a/demo/realtime-img2img/config.py
+++ b/demo/realtime-img2img/config.py
@@ -1,6 +1,6 @@
-from typing import NamedTuple
import argparse
import os
+from typing import NamedTuple
class Args(NamedTuple):
@@ -45,9 +45,7 @@ def pretty_print(self):
parser.add_argument("--host", type=str, default=default_host, help="Host address")
parser.add_argument("--port", type=int, default=default_port, help="Port number")
parser.add_argument("--reload", action="store_true", help="Reload code on change")
-parser.add_argument(
- "--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img"
-)
+parser.add_argument("--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img")
parser.add_argument(
"--max-queue-size",
dest="max_queue_size",
diff --git a/demo/realtime-img2img/connection_manager.py b/demo/realtime-img2img/connection_manager.py
index ae2072198..e35c09df7 100644
--- a/demo/realtime-img2img/connection_manager.py
+++ b/demo/realtime-img2img/connection_manager.py
@@ -1,10 +1,12 @@
+import asyncio
+import logging
+from types import SimpleNamespace
from typing import Dict, Union
from uuid import UUID
-import asyncio
+
from fastapi import WebSocket
from starlette.websockets import WebSocketState
-import logging
-from types import SimpleNamespace
+
Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]]
@@ -20,9 +22,7 @@ def __init__(self):
self.active_connections: Connections = {}
self.latest_data: Dict[UUID, SimpleNamespace] = {} # Store latest parameters for HTTP streaming
- async def connect(
- self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
- ):
+ async def connect(self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0):
await websocket.accept()
user_count = self.get_user_count()
print(f"User count: {user_count}")
@@ -61,7 +61,7 @@ async def get_latest_data(self, user_id: UUID) -> SimpleNamespace:
return await queue.get()
except asyncio.QueueEmpty:
return None
-
+
def get_latest_data_sync(self, user_id: UUID) -> SimpleNamespace:
"""Get the latest data without consuming it from the queue (for HTTP streaming)"""
return self.latest_data.get(user_id)
diff --git a/demo/realtime-img2img/img2img.py b/demo/realtime-img2img/img2img.py
index 06067b342..a1e6ada8c 100644
--- a/demo/realtime-img2img/img2img.py
+++ b/demo/realtime-img2img/img2img.py
@@ -1,6 +1,7 @@
-import sys
-import os
import logging
+import os
+import sys
+
sys.path.append(
os.path.join(
@@ -12,10 +13,11 @@
# Config system functions are now used only in main.py
+
import torch
-from pydantic import BaseModel, Field
from PIL import Image
-from typing import Optional
+from pydantic import BaseModel, Field
+
# Default values for pipeline parameters
default_negative_prompt = "black and white, blurry, low resolution, pixelated, pixel art, low quality, low fidelity"
@@ -78,26 +80,22 @@ class InputParams(BaseModel):
"768x512 (3:2)",
"896x640 (7:5)",
"1024x768 (4:3)",
- "1024x576 (16:9)"
- ]
- )
- width: int = Field(
- 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
- )
- height: int = Field(
- 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+ "1024x576 (16:9)",
+ ],
)
+ width: int = Field(512, min=2, max=15, title="Width", disabled=True, hide=True, id="width")
+ height: int = Field(512, min=2, max=15, title="Height", disabled=True, hide=True, id="height")
-#TODO update naming convention to reflect the controlnet agnostic nature of the config system (pipeline_config instead of controlnet_config for example)
+ # TODO update naming convention to reflect the controlnet agnostic nature of the config system (pipeline_config instead of controlnet_config for example)
def __init__(self, wrapper, config):
"""
Initialize Pipeline with pre-created wrapper and config.
-
+
Args:
wrapper: Pre-created StreamDiffusionWrapper instance
config: Configuration dictionary used to create the wrapper
"""
-
+
# IPAdapter state tracking for optimization
self._last_ipadapter_source_type = None
self._last_ipadapter_source_data = None
@@ -106,16 +104,16 @@ def __init__(self, wrapper, config):
self.stream = wrapper
self.config = config
self.use_config = True
-
+
# Extract pipeline configuration from config
- self.pipeline_mode = self.config.get('mode', 'img2img')
- self.has_controlnet = 'controlnets' in self.config and len(self.config['controlnets']) > 0
- self.has_ipadapter = 'ipadapters' in self.config and len(self.config['ipadapters']) > 0
-
+ self.pipeline_mode = self.config.get("mode", "img2img")
+ self.has_controlnet = "controlnets" in self.config and len(self.config["controlnets"]) > 0
+ self.has_ipadapter = "ipadapters" in self.config and len(self.config["ipadapters"]) > 0
+
# Store config values for later use
- self.negative_prompt = self.config.get('negative_prompt', default_negative_prompt)
- self.guidance_scale = self.config.get('guidance_scale', 1.2)
- self.num_inference_steps = self.config.get('num_inference_steps', 50)
+ self.negative_prompt = self.config.get("negative_prompt", default_negative_prompt)
+ self.guidance_scale = self.config.get("guidance_scale", 1.2)
+ self.num_inference_steps = self.config.get("num_inference_steps", 50)
# Update input_mode based on pipeline mode
self.info = self.Info()
@@ -129,23 +127,23 @@ def __init__(self, wrapper, config):
self.guidance_scale = 1.1
self.num_inference_steps = 50
self.negative_prompt = default_negative_prompt
-
+
# Store output type for frame conversion - always force "pt" for optimal performance
self.output_type = "pt"
def predict(self, params: "Pipeline.InputParams") -> Image.Image:
# Get input manager if available (passed from websocket handler)
- input_manager = getattr(params, 'input_manager', None)
-
+ input_manager = getattr(params, "input_manager", None)
+
# Handle different modes
if self.pipeline_mode == "txt2img":
# Text-to-image mode
-
+
# Handle ControlNet updates if enabled
if self.has_controlnet:
try:
stream_state = self.stream.get_stream_state()
- current_cfg = stream_state.get('controlnet_config', [])
+ current_cfg = stream_state.get("controlnet_config", [])
except Exception:
current_cfg = []
if current_cfg:
@@ -154,11 +152,11 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
control_image = self._get_controlnet_input(input_manager, i, params.image)
if control_image is not None:
self.stream.update_control_image(index=i, image=control_image)
-
+
# Handle IPAdapter updates if enabled
if self.has_ipadapter:
self._update_ipadapter_style_image(input_manager)
-
+
# Generate output based on what's enabled
if self.has_controlnet and not self.has_ipadapter:
# ControlNet only: use base input for generation
@@ -176,12 +174,12 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
output_image = self.stream()
else:
# Image-to-image mode: use original logic
-
+
# Handle ControlNet updates if enabled
if self.has_controlnet:
try:
stream_state = self.stream.get_stream_state()
- current_cfg = stream_state.get('controlnet_config', [])
+ current_cfg = stream_state.get("controlnet_config", [])
except Exception:
current_cfg = []
if current_cfg:
@@ -190,11 +188,11 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
control_image = self._get_controlnet_input(input_manager, i, params.image)
if control_image is not None:
self.stream.update_control_image(index=i, image=control_image)
-
+
# Handle IPAdapter updates if enabled
if self.has_ipadapter:
self._update_ipadapter_style_image(input_manager)
-
+
# Generate output based on what's enabled
if self.has_controlnet or self.has_ipadapter:
# ControlNet and/or IPAdapter: use base input for img2img
@@ -216,150 +214,153 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
def _get_controlnet_input(self, input_manager, index: int, fallback_image):
"""
Get input image for a specific ControlNet index.
-
+
Args:
input_manager: InputSourceManager instance (can be None)
index: ControlNet index
fallback_image: Fallback image if no specific source is configured
-
+
Returns:
Input image for the ControlNet or fallback
"""
if input_manager:
- frame = input_manager.get_frame('controlnet', index)
+ frame = input_manager.get_frame("controlnet", index)
if frame is not None:
return frame
-
+
# Fallback to main image input
return fallback_image
-
+
def _get_ipadapter_input(self, input_manager):
"""
Get input image for IPAdapter.
-
+
Args:
input_manager: InputSourceManager instance (can be None)
-
+
Returns:
Input image for IPAdapter or None
"""
if input_manager:
- return input_manager.get_frame('ipadapter')
+ return input_manager.get_frame("ipadapter")
return None
-
+
def _update_ipadapter_style_image(self, input_manager):
"""
Update IPAdapter style image from InputSourceManager.
Only updates when source actually changes to avoid unnecessary processing.
-
+
Args:
input_manager: InputSourceManager instance (can be None)
"""
if not input_manager or not self.has_ipadapter:
return
-
+
try:
# Get current source info to check if it changed
- source_info = input_manager.get_source_info('ipadapter')
- current_source_type = source_info.get('source_type')
- current_source_data = source_info.get('source_data')
- is_stream = source_info.get('is_stream', False)
-
+ source_info = input_manager.get_source_info("ipadapter")
+ current_source_type = source_info.get("source_type")
+ current_source_data = source_info.get("source_data")
+ is_stream = source_info.get("is_stream", False)
+
# Check if source changed (for static images, only update when source changes)
source_changed = (
- current_source_type != self._last_ipadapter_source_type or
- current_source_data != self._last_ipadapter_source_data
+ current_source_type != self._last_ipadapter_source_type
+ or current_source_data != self._last_ipadapter_source_data
)
-
+
# For streaming sources (webcam/video), always get fresh frame
# For static sources (uploaded image), only update when source changes
should_update = is_stream or source_changed
-
+
if not should_update:
return # No update needed - static source unchanged
-
+
# Get IPAdapter style image from input source manager
- ipadapter_frame = input_manager.get_frame('ipadapter')
-
+ ipadapter_frame = input_manager.get_frame("ipadapter")
+
if ipadapter_frame is not None:
import torch
-
+
# Use tensor directly - update_style_image expects torch tensor
if isinstance(ipadapter_frame, torch.Tensor):
try:
# Update IPAdapter with tensor and stream configuration
self.stream.update_style_image(ipadapter_frame, is_stream=is_stream)
- self.stream.update_stream_params(ipadapter_config={'is_stream': is_stream})
-
+ self.stream.update_stream_params(ipadapter_config={"is_stream": is_stream})
+
# Force prompt re-encoding to apply new style image embeddings
# This is critical because IPAdapter embedding hook only runs during prompt encoding
try:
state = self.stream.get_stream_state()
- current_prompts = state.get('prompt_list', [])
+ current_prompts = state.get("prompt_list", [])
if current_prompts:
self.stream.update_prompt(current_prompts, prompt_interpolation_method="slerp")
except Exception as e:
- logging.exception(f"_update_ipadapter_style_image: Failed to force prompt re-encoding: {e}")
-
-
+ logging.exception(
+ f"_update_ipadapter_style_image: Failed to force prompt re-encoding: {e}"
+ )
+
# Update tracking variables only on successful update
self._last_ipadapter_source_type = current_source_type
self._last_ipadapter_source_data = current_source_data
-
+
except Exception as e:
logging.exception(f"_update_ipadapter_style_image: Failed to update IPAdapter: {e}")
else:
- logging.warning("_update_ipadapter_style_image: IPAdapter frame is not a tensor, skipping style image update")
+ logging.warning(
+ "_update_ipadapter_style_image: IPAdapter frame is not a tensor, skipping style image update"
+ )
except Exception as e:
logging.exception(f"_update_ipadapter_style_image: Error updating IPAdapter style image: {e}")
-
+
def _get_base_input(self, input_manager, fallback_image):
"""
Get input image for base pipeline.
-
+
Args:
input_manager: InputSourceManager instance (can be None)
fallback_image: Fallback image if no specific source is configured
-
+
Returns:
Input image for base pipeline or fallback
"""
if input_manager:
- frame = input_manager.get_frame('base')
+ frame = input_manager.get_frame("base")
if frame is not None:
return frame
-
+
# Fallback to main image input
return fallback_image
def update_ipadapter_config(self, scale: float = None, style_image: Image.Image = None) -> bool:
"""
Update IPAdapter configuration in real-time using direct methods
-
+
Args:
scale: New IPAdapter scale value (optional)
style_image: New style image (PIL Image, optional)
-
+
Returns:
bool: True if successful, False otherwise
"""
if not self.has_ipadapter:
return False
-
+
if scale is None and style_image is None:
return False # Nothing to update
-
+
try:
# Update scale via unified config system (no direct method needed)
if scale is not None:
- self.stream.update_stream_params(ipadapter_config={'scale': scale})
-
+ self.stream.update_stream_params(ipadapter_config={"scale": scale})
+
# Update style image via direct method
if style_image is not None:
self.stream.update_style_image(style_image)
-
+
return True
- except Exception as e:
+ except Exception:
return False
def update_ipadapter_scale(self, scale: float) -> bool:
@@ -374,21 +375,21 @@ def update_ipadapter_weight_type(self, weight_type: str) -> bool:
"""Update IPAdapter weight type in real-time"""
if not self.has_ipadapter:
return False
-
+
try:
# Use unified updater on wrapper
- if hasattr(self.stream, 'update_stream_params'):
- self.stream.update_stream_params(ipadapter_config={ 'weight_type': weight_type })
+ if hasattr(self.stream, "update_stream_params"):
+ self.stream.update_stream_params(ipadapter_config={"weight_type": weight_type})
return True
# Should not reach here in normal operation
return False
- except Exception as e:
+ except Exception:
return False
def get_ipadapter_info(self) -> dict:
"""
Get current IPAdapter information
-
+
Returns:
dict: IPAdapter information including scale, model info, etc.
"""
@@ -397,35 +398,37 @@ def get_ipadapter_info(self) -> dict:
"scale": 1.0,
"weight_type": "linear",
"model_path": None,
- "style_image_set": False
+ "style_image_set": False,
}
-
- if self.has_ipadapter and self.config and 'ipadapters' in self.config:
+
+ if self.has_ipadapter and self.config and "ipadapters" in self.config:
# Get info from first IPAdapter config
- if len(self.config['ipadapters']) > 0:
- ipadapter_config = self.config['ipadapters'][0]
- info["scale"] = ipadapter_config.get('scale', 1.0)
- info["weight_type"] = ipadapter_config.get('weight_type', 'linear')
- info["model_path"] = ipadapter_config.get('ipadapter_model_path')
- info["style_image_set"] = 'style_image' in ipadapter_config
-
+ if len(self.config["ipadapters"]) > 0:
+ ipadapter_config = self.config["ipadapters"][0]
+ info["scale"] = ipadapter_config.get("scale", 1.0)
+ info["weight_type"] = ipadapter_config.get("weight_type", "linear")
+ info["model_path"] = ipadapter_config.get("ipadapter_model_path")
+ info["style_image_set"] = "style_image" in ipadapter_config
+
# Get current runtime state from wrapper's public API
try:
- if hasattr(self.stream, 'get_stream_state'):
+ if hasattr(self.stream, "get_stream_state"):
stream_state = self.stream.get_stream_state()
- ipadapter_runtime_config = stream_state.get('ipadapter_config', {})
+ ipadapter_runtime_config = stream_state.get("ipadapter_config", {})
if ipadapter_runtime_config:
- info["scale"] = ipadapter_runtime_config.get('scale', info.get("scale", 1.0))
- info["weight_type"] = ipadapter_runtime_config.get('weight_type', info.get("weight_type", 'linear'))
+ info["scale"] = ipadapter_runtime_config.get("scale", info.get("scale", 1.0))
+ info["weight_type"] = ipadapter_runtime_config.get(
+ "weight_type", info.get("weight_type", "linear")
+ )
except Exception:
pass # Use defaults from config if wrapper method fails
-
+
return info
def update_stream_params(self, **kwargs):
"""
Update streaming parameters using the consolidated API
-
+
Args:
**kwargs: All parameters supported by StreamDiffusionWrapper.update_stream_params()
including controlnet_config, guidance_scale, delta, etc.
diff --git a/demo/realtime-img2img/input_control.py b/demo/realtime-img2img/input_control.py
index c4f41359c..be1e3fe03 100644
--- a/demo/realtime-img2img/input_control.py
+++ b/demo/realtime-img2img/input_control.py
@@ -1,48 +1,48 @@
-from abc import ABC, abstractmethod
-from typing import Dict, Any, Callable, Optional
import asyncio
+import logging
import threading
import time
-import logging
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, Optional
class InputControl(ABC):
"""Generic interface for input controls that can modify parameters"""
-
+
def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float = 1.0):
self.parameter_name = parameter_name
self.min_value = min_value
self.max_value = max_value
self.is_active = False
self.update_callback: Optional[Callable[[str, float], None]] = None
-
+
@abstractmethod
async def start(self) -> None:
"""Start the input control"""
pass
-
+
@abstractmethod
async def stop(self) -> None:
"""Stop the input control"""
pass
-
+
@abstractmethod
def get_current_value(self) -> float:
"""Get the current normalized value (0.0 to 1.0)"""
pass
-
+
def set_update_callback(self, callback: Callable[[str, float], None]) -> None:
"""Set callback for parameter updates"""
self.update_callback = callback
-
+
def normalize_value(self, raw_value: float) -> float:
"""Normalize raw input value to 0.0-1.0 range"""
return max(0.0, min(1.0, raw_value))
-
+
def scale_to_parameter(self, normalized_value: float) -> float:
"""Scale normalized value to parameter range"""
return self.min_value + (normalized_value * (self.max_value - self.min_value))
-
+
def _trigger_update(self, normalized_value: float) -> None:
"""Trigger parameter update if callback is set"""
if self.update_callback:
@@ -52,9 +52,16 @@ def _trigger_update(self, normalized_value: float) -> None:
class GamepadInput(InputControl):
"""Gamepad input control for parameter modification"""
-
- def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float = 1.0,
- gamepad_index: int = 0, axis_index: int = 0, deadzone: float = 0.1):
+
+ def __init__(
+ self,
+ parameter_name: str,
+ min_value: float = 0.0,
+ max_value: float = 1.0,
+ gamepad_index: int = 0,
+ axis_index: int = 0,
+ deadzone: float = 0.1,
+ ):
super().__init__(parameter_name, min_value, max_value)
self.gamepad_index = gamepad_index
self.axis_index = axis_index
@@ -62,78 +69,78 @@ def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float
self.current_value = 0.0
self._stop_event = threading.Event()
self._thread = None
-
+
async def start(self) -> None:
"""Start gamepad monitoring"""
if self.is_active:
return
-
+
self.is_active = True
self._stop_event.clear()
self._thread = threading.Thread(target=self._monitor_gamepad, daemon=True)
self._thread.start()
logging.info(f"GamepadInput: Started monitoring gamepad {self.gamepad_index}, axis {self.axis_index}")
-
+
async def stop(self) -> None:
"""Stop gamepad monitoring"""
if not self.is_active:
return
-
+
self.is_active = False
self._stop_event.set()
-
+
if self._thread and self._thread.is_alive():
self._thread.join(timeout=1.0)
-
+
logging.info(f"GamepadInput: Stopped monitoring gamepad {self.gamepad_index}, axis {self.axis_index}")
-
+
def get_current_value(self) -> float:
"""Get current normalized value"""
return self.current_value
-
+
def _monitor_gamepad(self) -> None:
"""Monitor gamepad input in background thread"""
try:
import pygame
-
+
# Initialize pygame for gamepad support
pygame.init()
pygame.joystick.init()
-
+
# Check if gamepad is available
if pygame.joystick.get_count() <= self.gamepad_index:
logging.error(f"GamepadInput: Gamepad {self.gamepad_index} not found")
return
-
+
# Initialize the gamepad
joystick = pygame.joystick.Joystick(self.gamepad_index)
joystick.init()
-
+
logging.info(f"GamepadInput: Connected to {joystick.get_name()}")
-
+
# Monitor gamepad input
while not self._stop_event.is_set():
pygame.event.pump()
-
+
# Get axis value
if self.axis_index < joystick.get_numaxes():
raw_value = joystick.get_axis(self.axis_index)
-
+
# Apply deadzone
if abs(raw_value) < self.deadzone:
raw_value = 0.0
-
+
# Convert from [-1, 1] to [0, 1] range
normalized_value = (raw_value + 1.0) / 2.0
-
+
# Update current value
self.current_value = normalized_value
-
+
# Trigger update callback
self._trigger_update(normalized_value)
-
+
time.sleep(0.016) # ~60 FPS polling
-
+
except ImportError:
logging.error("GamepadInput: pygame not installed. Install with: pip install pygame")
except Exception as e:
@@ -150,48 +157,48 @@ def _monitor_gamepad(self) -> None:
class InputManager:
"""Manages multiple input controls"""
-
+
def __init__(self):
self.inputs: Dict[str, InputControl] = {}
self.parameter_update_callback: Optional[Callable[[str, float], None]] = None
-
+
def add_input(self, input_id: str, input_control: InputControl) -> None:
"""Add an input control"""
input_control.set_update_callback(self._handle_parameter_update)
self.inputs[input_id] = input_control
logging.info(f"InputManager: Added input control {input_id} for parameter {input_control.parameter_name}")
-
+
def remove_input(self, input_id: str) -> None:
"""Remove an input control"""
if input_id in self.inputs:
asyncio.create_task(self.inputs[input_id].stop())
del self.inputs[input_id]
logging.info(f"InputManager: Removed input control {input_id}")
-
+
async def start_input(self, input_id: str) -> None:
"""Start a specific input control"""
if input_id in self.inputs:
await self.inputs[input_id].start()
-
+
async def stop_input(self, input_id: str) -> None:
"""Stop a specific input control"""
if input_id in self.inputs:
await self.inputs[input_id].stop()
-
+
async def start_all(self) -> None:
"""Start all input controls"""
for input_control in self.inputs.values():
await input_control.start()
-
+
async def stop_all(self) -> None:
"""Stop all input controls"""
for input_control in self.inputs.values():
await input_control.stop()
-
+
def set_parameter_update_callback(self, callback: Callable[[str, float], None]) -> None:
"""Set callback for parameter updates from any input"""
self.parameter_update_callback = callback
-
+
def get_input_status(self) -> Dict[str, Dict[str, Any]]:
"""Get status of all inputs"""
status = {}
@@ -201,11 +208,11 @@ def get_input_status(self) -> Dict[str, Dict[str, Any]]:
"is_active": input_control.is_active,
"current_value": input_control.get_current_value(),
"min_value": input_control.min_value,
- "max_value": input_control.max_value
+ "max_value": input_control.max_value,
}
return status
-
+
def _handle_parameter_update(self, parameter_name: str, value: float) -> None:
"""Handle parameter update from input controls"""
if self.parameter_update_callback:
- self.parameter_update_callback(parameter_name, value)
\ No newline at end of file
+ self.parameter_update_callback(parameter_name, value)
diff --git a/demo/realtime-img2img/input_sources.py b/demo/realtime-img2img/input_sources.py
index d539845a7..b50a55b2a 100644
--- a/demo/realtime-img2img/input_sources.py
+++ b/demo/realtime-img2img/input_sources.py
@@ -8,19 +8,19 @@
import logging
from enum import Enum
-from typing import Dict, Optional, Union, Any
from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
import torch
from PIL import Image
-import cv2
-import numpy as np
-
from util import bytes_to_pt
from utils.video_utils import VideoFrameExtractor
class InputSourceType(Enum):
"""Types of input sources available."""
+
WEBCAM = "webcam"
UPLOADED_IMAGE = "uploaded_image"
UPLOADED_VIDEO = "uploaded_video"
@@ -29,15 +29,15 @@ class InputSourceType(Enum):
class InputSource:
"""
Represents an input source for a component.
-
+
Handles different types of inputs (webcam, image, video) and provides
a unified interface to get the current frame as a tensor.
"""
-
+
def __init__(self, source_type: InputSourceType, source_data: Any = None):
"""
Initialize an input source.
-
+
Args:
source_type: Type of input source
source_data: Data for the source (PIL Image, video path, or None for webcam)
@@ -48,11 +48,11 @@ def __init__(self, source_type: InputSourceType, source_data: Any = None):
self._current_frame = None
self._video_extractor = None
self._logger = logging.getLogger(f"InputSource.{source_type.value}")
-
+
# Initialize video extractor if needed
if source_type == InputSourceType.UPLOADED_VIDEO and source_data:
self._init_video_extractor()
-
+
def _init_video_extractor(self):
"""Initialize video extractor for video input sources."""
if self.source_data and Path(self.source_data).exists():
@@ -64,11 +64,11 @@ def _init_video_extractor(self):
self._video_extractor = None
else:
self._logger.error(f"Video file not found: {self.source_data}")
-
+
def get_frame(self) -> Optional[torch.Tensor]:
"""
Get the current frame as a PyTorch tensor.
-
+
Returns:
torch.Tensor: Current frame with shape (C, H, W), values in [0, 1], dtype float32
None: If no frame is available
@@ -77,7 +77,7 @@ def get_frame(self) -> Optional[torch.Tensor]:
if self.source_type == InputSourceType.WEBCAM:
# For webcam, return cached frame (will be updated externally)
return self._current_frame
-
+
elif self.source_type == InputSourceType.UPLOADED_IMAGE:
# For static image, convert to tensor if not already done
if self._current_frame is None and self.source_data:
@@ -91,35 +91,35 @@ def get_frame(self) -> Optional[torch.Tensor]:
elif isinstance(self.source_data, bytes):
# Convert bytes to tensor using existing utility
self._current_frame = bytes_to_pt(self.source_data)
-
+
return self._current_frame
-
+
elif self.source_type == InputSourceType.UPLOADED_VIDEO:
# For video, get next frame
return self._get_video_frame()
-
+
except Exception as e:
self._logger.error(f"Error getting frame from {self.source_type.value}: {e}")
-
+
return None
-
+
def _get_video_frame(self) -> Optional[torch.Tensor]:
"""Get the next frame from video source."""
if not self._video_extractor:
return None
-
+
return self._video_extractor.get_frame()
-
+
def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]):
"""
Update the current frame for webcam sources.
-
+
Args:
frame_data: Frame data as bytes or tensor
"""
if self.source_type != InputSourceType.WEBCAM:
return
-
+
try:
if isinstance(frame_data, bytes):
self._current_frame = bytes_to_pt(frame_data)
@@ -127,7 +127,7 @@ def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]):
self._current_frame = frame_data
except Exception as e:
self._logger.error(f"Error updating webcam frame: {e}")
-
+
def cleanup(self):
"""Clean up resources."""
if self._video_extractor:
@@ -138,211 +138,224 @@ def cleanup(self):
class InputSourceManager:
"""
Manages input sources for different components in the pipeline.
-
+
Provides a centralized way to set and get input sources for:
- ControlNet instances (indexed)
- - IPAdapter
+ - IPAdapter
- Base pipeline
"""
-
+
def __init__(self):
"""Initialize the input source manager."""
self.sources = {
- 'controlnet': {}, # {index: InputSource}
- 'ipadapter': None, # Single InputSource
- 'base': None # Single InputSource for main pipeline
+ "controlnet": {}, # {index: InputSource}
+ "ipadapter": None, # Single InputSource
+ "base": None, # Single InputSource for main pipeline
}
self._logger = logging.getLogger("InputSourceManager")
-
+
# Default to webcam for base pipeline
- self.sources['base'] = InputSource(InputSourceType.WEBCAM)
-
+ self.sources["base"] = InputSource(InputSourceType.WEBCAM)
+
# Default IPAdapter to uploaded_image with default image
self._init_default_ipadapter_source()
-
+
def set_source(self, component: str, source: InputSource, index: Optional[int] = None):
"""
Set input source for a component.
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
source: InputSource instance
index: Index for ControlNet instances (required for 'controlnet')
"""
try:
- if component == 'controlnet':
+ if component == "controlnet":
if index is None:
raise ValueError("Index is required for ControlNet components")
-
+
# Clean up existing source if any
- if index in self.sources['controlnet']:
- self.sources['controlnet'][index].cleanup()
-
- self.sources['controlnet'][index] = source
+ if index in self.sources["controlnet"]:
+ self.sources["controlnet"][index].cleanup()
+
+ self.sources["controlnet"][index] = source
self._logger.info(f"Set ControlNet {index} input source to {source.source_type.value}")
-
- elif component in ['ipadapter', 'base']:
+
+ elif component in ["ipadapter", "base"]:
# Clean up existing source if any
if self.sources[component]:
self.sources[component].cleanup()
-
+
self.sources[component] = source
self._logger.info(f"Set {component} input source to {source.source_type.value}")
-
+
else:
raise ValueError(f"Unknown component: {component}")
-
+
except Exception as e:
self._logger.error(f"Error setting source for {component}: {e}")
-
+
def get_frame(self, component: str, index: Optional[int] = None) -> Optional[torch.Tensor]:
"""
Get current frame for a component.
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
index: Index for ControlNet instances (required for 'controlnet')
-
+
Returns:
torch.Tensor: Current frame or None if not available
"""
try:
- if component == 'controlnet':
+ if component == "controlnet":
if index is None:
raise ValueError("Index is required for ControlNet components")
-
+
# Ensure ControlNet is initialized with default webcam source
self._ensure_controlnet_initialized(index)
- source = self.sources['controlnet'][index]
-
+ source = self.sources["controlnet"][index]
+
frame = source.get_frame()
if frame is not None:
return frame
-
+
# If webcam source has no frame yet, fallback to base pipeline input
self._logger.debug(f"ControlNet {index} webcam has no frame yet, falling back to base")
return self._get_fallback_frame()
-
- elif component in ['ipadapter', 'base']:
+
+ elif component in ["ipadapter", "base"]:
source = self.sources[component]
if source:
frame = source.get_frame()
if frame is not None:
return frame
-
+
# Fallback to base pipeline input if not base itself
- if component != 'base':
+ if component != "base":
self._logger.debug(f"{component} has no input, falling back to base")
return self._get_fallback_frame()
-
+
except Exception as e:
self._logger.error(f"Error getting frame for {component}: {e}")
-
+
return None
-
+
def _get_fallback_frame(self) -> Optional[torch.Tensor]:
"""Get frame from base pipeline as fallback."""
- base_source = self.sources['base']
+ base_source = self.sources["base"]
if base_source:
return base_source.get_frame()
return None
-
+
def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]):
"""
Update webcam frame for all webcam sources.
-
+
Args:
frame_data: Frame data as bytes or tensor
"""
# Update base pipeline if it's webcam
- if (self.sources['base'] and
- self.sources['base'].source_type == InputSourceType.WEBCAM):
- self.sources['base'].update_webcam_frame(frame_data)
-
+ if self.sources["base"] and self.sources["base"].source_type == InputSourceType.WEBCAM:
+ self.sources["base"].update_webcam_frame(frame_data)
+
# Update ControlNet webcam sources
- for source in self.sources['controlnet'].values():
+ for source in self.sources["controlnet"].values():
if source.source_type == InputSourceType.WEBCAM:
source.update_webcam_frame(frame_data)
-
+
# Update IPAdapter if it's webcam
- if (self.sources['ipadapter'] and
- self.sources['ipadapter'].source_type == InputSourceType.WEBCAM):
- self.sources['ipadapter'].update_webcam_frame(frame_data)
-
+ if self.sources["ipadapter"] and self.sources["ipadapter"].source_type == InputSourceType.WEBCAM:
+ self.sources["ipadapter"].update_webcam_frame(frame_data)
+
def _ensure_controlnet_initialized(self, index: int):
"""
Ensure a ControlNet has a default webcam source if not already set.
-
+
Args:
index: ControlNet index
"""
- if index not in self.sources['controlnet']:
- self.sources['controlnet'][index] = InputSource(InputSourceType.WEBCAM)
+ if index not in self.sources["controlnet"]:
+ self.sources["controlnet"][index] = InputSource(InputSourceType.WEBCAM)
self._logger.info(f"_ensure_controlnet_initialized: Initialized ControlNet {index} with webcam source")
def get_source_info(self, component: str, index: Optional[int] = None) -> Dict[str, Any]:
"""
Get information about a component's input source.
-
+
Returns:
Dictionary with source type and metadata
"""
try:
- if component == 'controlnet':
+ if component == "controlnet":
if index is None:
- return {'source_type': 'error', 'source_data': 'index_required', 'is_stream': False, 'has_data': False}
-
+ return {
+ "source_type": "error",
+ "source_data": "index_required",
+ "is_stream": False,
+ "has_data": False,
+ }
+
# Ensure ControlNet is initialized with default webcam source
self._ensure_controlnet_initialized(index)
- source = self.sources['controlnet'][index]
-
- elif component in ['ipadapter', 'base']:
+ source = self.sources["controlnet"][index]
+
+ elif component in ["ipadapter", "base"]:
source = self.sources[component]
if not source:
- return {'source_type': 'none', 'source_data': None, 'is_stream': False, 'has_data': False}
+ return {"source_type": "none", "source_data": None, "is_stream": False, "has_data": False}
else:
- return {'source_type': 'unknown', 'source_data': None, 'is_stream': False, 'has_data': False}
-
+ return {"source_type": "unknown", "source_data": None, "is_stream": False, "has_data": False}
+
return {
- 'source_type': source.source_type.value,
- 'source_data': source.source_data,
- 'is_stream': source.is_stream,
- 'has_data': source.source_data is not None
+ "source_type": source.source_type.value,
+ "source_data": source.source_data,
+ "is_stream": source.is_stream,
+ "has_data": source.source_data is not None,
}
-
+
except Exception as e:
self._logger.error(f"Error getting source info for {component}: {e}")
- return {'source_type': 'error', 'source_data': None, 'is_stream': False, 'has_data': False, 'error': str(e)}
-
+ return {
+ "source_type": "error",
+ "source_data": None,
+ "is_stream": False,
+ "has_data": False,
+ "error": str(e),
+ }
+
def _init_default_ipadapter_source(self):
"""Initialize IPAdapter with default image source."""
try:
import os
+
from PIL import Image
-
+
# Try to load default image
default_image_path = os.path.join(os.path.dirname(__file__), "..", "..", "images", "inputs", "input.png")
if os.path.exists(default_image_path):
default_image = Image.open(default_image_path).convert("RGB")
- self.sources['ipadapter'] = InputSource(InputSourceType.UPLOADED_IMAGE, default_image)
+ self.sources["ipadapter"] = InputSource(InputSourceType.UPLOADED_IMAGE, default_image)
self._logger.info("_init_default_ipadapter_source: Initialized IPAdapter with default image")
else:
- self._logger.warning("_init_default_ipadapter_source: Default image not found, IPAdapter will have no source")
+ self._logger.warning(
+ "_init_default_ipadapter_source: Default image not found, IPAdapter will have no source"
+ )
except Exception as e:
self._logger.error(f"_init_default_ipadapter_source: Error loading default image: {e}")
-
+
def load_config_style_image(self, style_image_path: str, base_config_path: str = None):
"""
Load IPAdapter style image from config file path.
-
+
Args:
style_image_path: Path to style image (can be relative)
base_config_path: Base path for resolving relative paths
"""
try:
import os
+
from PIL import Image
-
+
# Handle relative paths
if not os.path.isabs(style_image_path):
if base_config_path:
@@ -355,17 +368,19 @@ def load_config_style_image(self, style_image_path: str, base_config_path: str =
if not os.path.exists(style_image_path):
self._logger.warning(f"load_config_style_image: Style image not found: {style_image_path}")
return
-
+
if os.path.exists(style_image_path):
style_image = Image.open(style_image_path).convert("RGB")
input_source = InputSource(InputSourceType.UPLOADED_IMAGE, style_image)
- self.set_source('ipadapter', input_source)
- self._logger.info(f"load_config_style_image: Loaded IPAdapter style image from config: {style_image_path}")
+ self.set_source("ipadapter", input_source)
+ self._logger.info(
+ f"load_config_style_image: Loaded IPAdapter style image from config: {style_image_path}"
+ )
else:
self._logger.warning(f"load_config_style_image: IPAdapter style image not found: {style_image_path}")
except Exception as e:
self._logger.exception(f"load_config_style_image: Error loading config style image: {e}")
-
+
def reset_to_defaults(self):
"""
Reset all input sources to their default states.
@@ -374,30 +389,30 @@ def reset_to_defaults(self):
try:
# Clean up existing sources first
self.cleanup()
-
+
# Reset to default states
self.sources = {
- 'controlnet': {}, # Empty - ControlNets will use fallback to base
- 'ipadapter': None, # Will be re-initialized
- 'base': None # Will be re-initialized
+ "controlnet": {}, # Empty - ControlNets will use fallback to base
+ "ipadapter": None, # Will be re-initialized
+ "base": None, # Will be re-initialized
}
-
+
# Re-initialize defaults
- self.sources['base'] = InputSource(InputSourceType.WEBCAM)
+ self.sources["base"] = InputSource(InputSourceType.WEBCAM)
self._init_default_ipadapter_source()
-
+
self._logger.info("reset_to_defaults: Reset all input sources to defaults")
-
+
except Exception as e:
self._logger.error(f"reset_to_defaults: Error resetting input sources: {e}")
-
+
def cleanup(self):
"""Clean up all sources."""
- for source in self.sources['controlnet'].values():
+ for source in self.sources["controlnet"].values():
source.cleanup()
-
- if self.sources['ipadapter']:
- self.sources['ipadapter'].cleanup()
-
- if self.sources['base']:
- self.sources['base'].cleanup()
+
+ if self.sources["ipadapter"]:
+ self.sources["ipadapter"].cleanup()
+
+ if self.sources["base"]:
+ self.sources["base"].cleanup()
diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py
index b1c609461..aa5fe5873 100644
--- a/demo/realtime-img2img/main.py
+++ b/demo/realtime-img2img/main.py
@@ -1,29 +1,16 @@
-from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, UploadFile, File, Response
-from fastapi.responses import StreamingResponse, JSONResponse
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-from fastapi import Request
-
-import markdown2
-
import logging
-import uuid
-import time
-from types import SimpleNamespace
-import asyncio
-import os
-import time
import mimetypes
-import torch
-import tempfile
-from pathlib import Path
-import yaml
+import time
-from config import config, Args
-from util import pil_to_frame, pt_to_frame, bytes_to_pil, bytes_to_pt
-from connection_manager import ConnectionManager, ServerFullException
+import torch
+from config import Args, config
+from connection_manager import ConnectionManager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
from img2img import Pipeline
-from input_control import InputManager, GamepadInput
+from input_control import InputManager
+
# fix mime error on windows
mimetypes.add_type("application/javascript", ".js")
@@ -33,87 +20,79 @@
# Import configuration from separate file to avoid circular imports
from app_config import AVAILABLE_CONTROLNETS, DEFAULT_SETTINGS
+
# Configure logging
def setup_logging(log_level: str = "INFO"):
"""Setup logging configuration for the application"""
# Convert string to logging level
numeric_level = getattr(logging, log_level.upper(), logging.INFO)
-
+
# Configure root logger
logging.basicConfig(
- level=numeric_level,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
- datefmt='%Y-%m-%d %H:%M:%S'
+ level=numeric_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
-
+
# Set up logger for streamdiffusion modules
- streamdiffusion_logger = logging.getLogger('streamdiffusion')
+ streamdiffusion_logger = logging.getLogger("streamdiffusion")
streamdiffusion_logger.setLevel(numeric_level)
-
+
# Set up logger for this application
- app_logger = logging.getLogger('realtime_img2img')
+ app_logger = logging.getLogger("realtime_img2img")
app_logger.setLevel(numeric_level)
-
+
return app_logger
+
# Initialize logger
logger = setup_logging(config.log_level)
# Suppress uvicorn INFO messages
if config.quiet:
- uvicorn_logger = logging.getLogger('uvicorn')
+ uvicorn_logger = logging.getLogger("uvicorn")
uvicorn_logger.setLevel(logging.WARNING)
- uvicorn_access_logger = logging.getLogger('uvicorn.access')
+ uvicorn_access_logger = logging.getLogger("uvicorn.access")
uvicorn_access_logger.setLevel(logging.WARNING)
class AppState:
"""Centralized application state management - SINGLE SOURCE OF TRUTH"""
-
+
def __init__(self):
# Pipeline state
self.pipeline_lifecycle = "stopped" # stopped, starting, running, error
self.pipeline_active = False
-
- # Configuration state
+
+ # Configuration state
self.uploaded_config = None # Raw uploaded config
- self.runtime_config = None # Runtime modifications to config
+ self.runtime_config = None # Runtime modifications to config
self.config_needs_reload = False
-
+
# Resolution state
self.current_resolution = {"width": 512, "height": 512}
-
+
# Parameter state (consolidates scattered vars from frontend)
self.pipeline_params = {}
-
+
# ControlNet state - AUTHORITATIVE SOURCE
- self.controlnet_info = {
- "enabled": False,
- "controlnets": []
- }
-
- # IPAdapter state - AUTHORITATIVE SOURCE
- self.ipadapter_info = {
- "enabled": False,
- "has_style_image": False,
- "scale": 1.0,
- "weight_type": "linear"
- }
-
+ self.controlnet_info = {"enabled": False, "controlnets": []}
+
+ # IPAdapter state - AUTHORITATIVE SOURCE
+ self.ipadapter_info = {"enabled": False, "has_style_image": False, "scale": 1.0, "weight_type": "linear"}
+
# Pipeline hooks state - AUTHORITATIVE SOURCE
self.pipeline_hooks = {
"image_preprocessing": {"enabled": False, "processors": []},
"image_postprocessing": {"enabled": False, "processors": []},
"latent_preprocessing": {"enabled": False, "processors": []},
- "latent_postprocessing": {"enabled": False, "processors": []}
+ "latent_postprocessing": {"enabled": False, "processors": []},
}
-
+
# Blending configurations
self.prompt_blending = None
self.seed_blending = None
self.normalize_prompt_weights = True
self.normalize_seed_weights = True
-
+
# Core pipeline parameters
self.guidance_scale = 1.1
self.delta = 0.7
@@ -122,98 +101,92 @@ def __init__(self):
self.t_index_list = [35, 45]
self.negative_prompt = ""
self.skip_diffusion = False
-
+
# UI state
self.fps = 0
self.queue_size = 0
self.model_id = ""
self.page_content = ""
-
+
# Input source state
self.input_sources = {
- 'controlnet': {}, # {index: source_info}
- 'ipadapter': None,
- 'base': None
+ "controlnet": {}, # {index: source_info}
+ "ipadapter": None,
+ "base": None,
}
-
+
# Debug mode state
self.debug_mode = False
self.debug_pending_frame = False # True when a frame step is requested
-
+
def populate_from_config(self, config_data):
"""Populate AppState from uploaded config - SINGLE SOURCE OF TRUTH"""
if not config_data:
return
-
+
logger.info("populate_from_config: Populating AppState from config as single source of truth")
-
+
# Store the complete uploaded config to preserve ALL parameters
self.uploaded_config = config_data
-
+
# Core parameters
- self.guidance_scale = config_data.get('guidance_scale', self.guidance_scale)
- self.delta = config_data.get('delta', self.delta)
- self.num_inference_steps = config_data.get('num_inference_steps', self.num_inference_steps)
- self.seed = config_data.get('seed', self.seed)
- self.t_index_list = config_data.get('t_index_list', self.t_index_list)
- self.negative_prompt = config_data.get('negative_prompt', self.negative_prompt)
- self.skip_diffusion = config_data.get('skip_diffusion', self.skip_diffusion)
- self.model_id = config_data.get('model_id_or_path', self.model_id)
-
+ self.guidance_scale = config_data.get("guidance_scale", self.guidance_scale)
+ self.delta = config_data.get("delta", self.delta)
+ self.num_inference_steps = config_data.get("num_inference_steps", self.num_inference_steps)
+ self.seed = config_data.get("seed", self.seed)
+ self.t_index_list = config_data.get("t_index_list", self.t_index_list)
+ self.negative_prompt = config_data.get("negative_prompt", self.negative_prompt)
+ self.skip_diffusion = config_data.get("skip_diffusion", self.skip_diffusion)
+ self.model_id = config_data.get("model_id_or_path", self.model_id)
+
# Resolution parameters
- if 'width' in config_data or 'height' in config_data:
+ if "width" in config_data or "height" in config_data:
self.current_resolution = {
- "width": config_data.get('width', self.current_resolution["width"]),
- "height": config_data.get('height', self.current_resolution["height"])
+ "width": config_data.get("width", self.current_resolution["width"]),
+ "height": config_data.get("height", self.current_resolution["height"]),
}
-
+
# Normalization settings
- self.normalize_prompt_weights = config_data.get('normalize_weights', self.normalize_prompt_weights)
- self.normalize_seed_weights = config_data.get('normalize_weights', self.normalize_seed_weights)
-
+ self.normalize_prompt_weights = config_data.get("normalize_weights", self.normalize_prompt_weights)
+ self.normalize_seed_weights = config_data.get("normalize_weights", self.normalize_seed_weights)
+
# ControlNet configuration
- if 'controlnets' in config_data:
- self.controlnet_info = {
- "enabled": True,
- "controlnets": []
- }
- for i, controlnet in enumerate(config_data['controlnets']):
+ if "controlnets" in config_data:
+ self.controlnet_info = {"enabled": True, "controlnets": []}
+ for i, controlnet in enumerate(config_data["controlnets"]):
processed = dict(controlnet)
- processed['index'] = i
- processed['name'] = controlnet.get('model_id', '')
- processed['strength'] = controlnet.get('conditioning_scale', 0.0)
+ processed["index"] = i
+ processed["name"] = controlnet.get("model_id", "")
+ processed["strength"] = controlnet.get("conditioning_scale", 0.0)
self.controlnet_info["controlnets"].append(processed)
else:
self.controlnet_info = {"enabled": False, "controlnets": []}
-
+
# IPAdapter configuration
- if config_data.get('use_ipadapter', False):
+ if config_data.get("use_ipadapter", False):
self.ipadapter_info["enabled"] = True
- ipadapters = config_data.get('ipadapters', [])
+ ipadapters = config_data.get("ipadapters", [])
if ipadapters:
first = ipadapters[0]
- self.ipadapter_info["scale"] = first.get('scale', 1.0)
- self.ipadapter_info["weight_type"] = first.get('weight_type', 'linear')
+ self.ipadapter_info["scale"] = first.get("scale", 1.0)
+ self.ipadapter_info["weight_type"] = first.get("weight_type", "linear")
# Store required model paths
- self.ipadapter_info["ipadapter_model_path"] = first.get('ipadapter_model_path')
- self.ipadapter_info["image_encoder_path"] = first.get('image_encoder_path')
- self.ipadapter_info["type"] = first.get('type', 'regular')
- self.ipadapter_info["insightface_model_name"] = first.get('insightface_model_name')
- if first.get('style_image'):
+ self.ipadapter_info["ipadapter_model_path"] = first.get("ipadapter_model_path")
+ self.ipadapter_info["image_encoder_path"] = first.get("image_encoder_path")
+ self.ipadapter_info["type"] = first.get("type", "regular")
+ self.ipadapter_info["insightface_model_name"] = first.get("insightface_model_name")
+ if first.get("style_image"):
self.ipadapter_info["has_style_image"] = True
- self.ipadapter_info["style_image_path"] = first['style_image']
+ self.ipadapter_info["style_image_path"] = first["style_image"]
else:
self.ipadapter_info = {"enabled": False, "has_style_image": False, "scale": 1.0, "weight_type": "linear"}
-
+
# Pipeline hooks configuration
for hook_type in self.pipeline_hooks.keys():
if hook_type in config_data:
hook_config = config_data[hook_type]
if isinstance(hook_config, dict):
- self.pipeline_hooks[hook_type] = {
- "enabled": hook_config.get("enabled", False),
- "processors": []
- }
+ self.pipeline_hooks[hook_type] = {"enabled": hook_config.get("enabled", False), "processors": []}
# Process processors with proper indexing
for index, processor in enumerate(hook_config.get("processors", [])):
if isinstance(processor, dict):
@@ -223,74 +196,74 @@ def populate_from_config(self, config_data):
"type": processor.get("type", "unknown"),
"enabled": processor.get("enabled", False),
"order": processor.get("order", index + 1),
- "params": processor.get("params", {})
+ "params": processor.get("params", {}),
}
self.pipeline_hooks[hook_type]["processors"].append(processed_processor)
else:
self.pipeline_hooks[hook_type] = {"enabled": False, "processors": []}
-
+
# Blending configurations
self.prompt_blending = self._normalize_prompt_config(config_data)
self.seed_blending = self._normalize_seed_config(config_data)
-
+
logger.info("populate_from_config: AppState populated successfully from config")
def _normalize_prompt_config(self, config_data):
"""Normalize prompt configuration to always return a list format"""
if not config_data:
return None
-
+
# Check for explicit prompt_blending first
- if 'prompt_blending' in config_data:
- prompt_blending = config_data['prompt_blending']
- if isinstance(prompt_blending, dict) and 'prompt_list' in prompt_blending:
- prompt_list = prompt_blending['prompt_list']
+ if "prompt_blending" in config_data:
+ prompt_blending = config_data["prompt_blending"]
+ if isinstance(prompt_blending, dict) and "prompt_list" in prompt_blending:
+ prompt_list = prompt_blending["prompt_list"]
if isinstance(prompt_list, list) and len(prompt_list) > 0:
return prompt_list
elif isinstance(prompt_blending, list) and len(prompt_blending) > 0:
return prompt_blending
-
- # Check for direct prompt_list key
- if 'prompt_list' in config_data:
- prompt_list = config_data['prompt_list']
+
+ # Check for direct prompt_list key
+ if "prompt_list" in config_data:
+ prompt_list = config_data["prompt_list"]
if isinstance(prompt_list, list) and len(prompt_list) > 0:
return prompt_list
-
+
# Check for simple prompt key and convert to list format
- if 'prompt' in config_data:
- prompt = config_data['prompt']
+ if "prompt" in config_data:
+ prompt = config_data["prompt"]
if prompt and isinstance(prompt, str):
return [(prompt, 1.0)]
-
+
return None
def _normalize_seed_config(self, config_data):
"""Normalize seed configuration to always return a list format"""
if not config_data:
return None
-
+
# Check for explicit seed_blending first
- if 'seed_blending' in config_data:
- seed_blending = config_data['seed_blending']
- if isinstance(seed_blending, dict) and 'seed_list' in seed_blending:
- seed_list = seed_blending['seed_list']
+ if "seed_blending" in config_data:
+ seed_blending = config_data["seed_blending"]
+ if isinstance(seed_blending, dict) and "seed_list" in seed_blending:
+ seed_list = seed_blending["seed_list"]
if isinstance(seed_list, list) and len(seed_list) > 0:
return seed_list
elif isinstance(seed_blending, list) and len(seed_blending) > 0:
return seed_blending
-
- # Check for direct seed_list key
- if 'seed_list' in config_data:
- seed_list = config_data['seed_list']
+
+ # Check for direct seed_list key
+ if "seed_list" in config_data:
+ seed_list = config_data["seed_list"]
if isinstance(seed_list, list) and len(seed_list) > 0:
return seed_list
-
+
# Check for simple seed key and convert to list format
- if 'seed' in config_data:
- seed = config_data['seed']
+ if "seed" in config_data:
+ seed = config_data["seed"]
if seed is not None and isinstance(seed, (int, float)):
return [(int(seed), 1.0)]
-
+
return None
def get_complete_state(self):
@@ -299,13 +272,10 @@ def get_complete_state(self):
# Pipeline state
"pipeline_active": self.pipeline_active,
"pipeline_lifecycle": self.pipeline_lifecycle,
-
# Configuration
"config_needs_reload": self.config_needs_reload,
-
# Resolution
"current_resolution": self.current_resolution,
-
# Parameters
"pipeline_params": self.pipeline_params,
"controlnet": self.controlnet_info,
@@ -314,7 +284,6 @@ def get_complete_state(self):
"seed_blending": self.seed_blending,
"normalize_prompt_weights": self.normalize_prompt_weights,
"normalize_seed_weights": self.normalize_seed_weights,
-
# Core parameters
"guidance_scale": self.guidance_scale,
"delta": self.delta,
@@ -323,27 +292,23 @@ def get_complete_state(self):
"t_index_list": self.t_index_list,
"negative_prompt": self.negative_prompt,
"skip_diffusion": self.skip_diffusion,
-
# UI state
"fps": self.fps,
"queue_size": self.queue_size,
"model_id": self.model_id,
"page_content": self.page_content,
-
# Input sources
"input_sources": self.input_sources,
-
# Debug mode state
"debug_mode": self.debug_mode,
"debug_pending_frame": self.debug_pending_frame,
-
# Pipeline hooks - AUTHORITATIVE SOURCE
"image_preprocessing": self.pipeline_hooks["image_preprocessing"],
"image_postprocessing": self.pipeline_hooks["image_postprocessing"],
"latent_preprocessing": self.pipeline_hooks["latent_preprocessing"],
"latent_postprocessing": self.pipeline_hooks["latent_postprocessing"],
}
-
+
def update_controlnet_strength(self, index: int, strength: float):
"""Update ControlNet strength in AppState - SINGLE SOURCE OF TRUTH"""
if index < len(self.controlnet_info["controlnets"]):
@@ -352,32 +317,32 @@ def update_controlnet_strength(self, index: int, strength: float):
logger.debug(f"update_controlnet_strength: Updated ControlNet {index} strength to {strength}")
else:
logger.warning(f"update_controlnet_strength: ControlNet index {index} out of range")
-
+
def add_controlnet(self, controlnet_config: dict):
"""Add ControlNet to AppState - SINGLE SOURCE OF TRUTH"""
index = len(self.controlnet_info["controlnets"])
processed = dict(controlnet_config)
- processed['index'] = index
- processed['name'] = controlnet_config.get('model_id', '')
- processed['strength'] = controlnet_config.get('conditioning_scale', 0.0)
-
+ processed["index"] = index
+ processed["name"] = controlnet_config.get("model_id", "")
+ processed["strength"] = controlnet_config.get("conditioning_scale", 0.0)
+
self.controlnet_info["controlnets"].append(processed)
self.controlnet_info["enabled"] = True
logger.debug(f"add_controlnet: Added ControlNet at index {index}")
-
+
def remove_controlnet(self, index: int):
"""Remove ControlNet from AppState - SINGLE SOURCE OF TRUTH"""
if index < len(self.controlnet_info["controlnets"]):
removed = self.controlnet_info["controlnets"].pop(index)
# Re-index remaining controlnets
for i, controlnet in enumerate(self.controlnet_info["controlnets"]):
- controlnet['index'] = i
+ controlnet["index"] = i
if not self.controlnet_info["controlnets"]:
self.controlnet_info["enabled"] = False
logger.debug(f"remove_controlnet: Removed ControlNet at index {index}")
else:
logger.warning(f"remove_controlnet: ControlNet index {index} out of range")
-
+
def update_hook_processor(self, hook_type: str, processor_index: int, updates: dict):
"""Update pipeline hook processor in AppState - SINGLE SOURCE OF TRUTH"""
if hook_type in self.pipeline_hooks:
@@ -386,10 +351,12 @@ def update_hook_processor(self, hook_type: str, processor_index: int, updates: d
processors[processor_index].update(updates)
logger.debug(f"update_hook_processor: Updated {hook_type} processor {processor_index}")
else:
- logger.warning(f"update_hook_processor: Processor index {processor_index} out of range for {hook_type}")
+ logger.warning(
+ f"update_hook_processor: Processor index {processor_index} out of range for {hook_type}"
+ )
else:
logger.warning(f"update_hook_processor: Unknown hook type {hook_type}")
-
+
def add_hook_processor(self, hook_type: str, processor_config: dict):
"""Add pipeline hook processor to AppState - SINGLE SOURCE OF TRUTH"""
if hook_type in self.pipeline_hooks:
@@ -400,14 +367,14 @@ def add_hook_processor(self, hook_type: str, processor_config: dict):
"type": processor_config.get("type", "unknown"),
"enabled": processor_config.get("enabled", True),
"order": processor_config.get("order", index + 1),
- "params": processor_config.get("params", {})
+ "params": processor_config.get("params", {}),
}
self.pipeline_hooks[hook_type]["processors"].append(processed)
self.pipeline_hooks[hook_type]["enabled"] = True
logger.debug(f"add_hook_processor: Added {hook_type} processor at index {index}")
else:
logger.warning(f"add_hook_processor: Unknown hook type {hook_type}")
-
+
def remove_hook_processor(self, hook_type: str, processor_index: int):
"""Remove pipeline hook processor from AppState - SINGLE SOURCE OF TRUTH"""
if hook_type in self.pipeline_hooks:
@@ -416,191 +383,204 @@ def remove_hook_processor(self, hook_type: str, processor_index: int):
removed = processors.pop(processor_index)
# Re-index remaining processors
for i, processor in enumerate(processors):
- processor['index'] = i
+ processor["index"] = i
if not processors:
self.pipeline_hooks[hook_type]["enabled"] = False
logger.debug(f"remove_hook_processor: Removed {hook_type} processor at index {processor_index}")
else:
- logger.warning(f"remove_hook_processor: Processor index {processor_index} out of range for {hook_type}")
+ logger.warning(
+ f"remove_hook_processor: Processor index {processor_index} out of range for {hook_type}"
+ )
else:
logger.warning(f"remove_hook_processor: Unknown hook type {hook_type}")
def update_parameter(self, parameter_name: str, value: float):
"""Update a single parameter in AppState - UNIFIED PARAMETER UPDATE"""
logger.debug(f"update_parameter: Updating {parameter_name} = {value}")
-
+
# Core pipeline parameters
- if parameter_name == 'guidance_scale':
+ if parameter_name == "guidance_scale":
self.guidance_scale = float(value)
- elif parameter_name == 'delta':
+ elif parameter_name == "delta":
self.delta = float(value)
- elif parameter_name == 'num_inference_steps':
+ elif parameter_name == "num_inference_steps":
self.num_inference_steps = int(value)
- elif parameter_name == 'seed':
+ elif parameter_name == "seed":
self.seed = int(value)
- elif parameter_name == 'negative_prompt':
+ elif parameter_name == "negative_prompt":
self.negative_prompt = str(value)
- elif parameter_name == 'skip_diffusion':
+ elif parameter_name == "skip_diffusion":
self.skip_diffusion = bool(value)
- elif parameter_name == 't_index_list':
+ elif parameter_name == "t_index_list":
if isinstance(value, list):
self.t_index_list = value
else:
logger.warning(f"update_parameter: t_index_list must be a list, got {type(value)}")
-
+
# IPAdapter parameters
- elif parameter_name == 'ipadapter_scale':
+ elif parameter_name == "ipadapter_scale":
self.ipadapter_info["scale"] = float(value)
- elif parameter_name == 'ipadapter_weight_type':
+ elif parameter_name == "ipadapter_weight_type":
# Convert numeric value to weight type string
- weight_types = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out",
- "weak input", "weak output", "weak middle", "strong middle",
- "style transfer", "composition", "strong style transfer",
- "style and composition", "style transfer precise", "composition precise"]
+ weight_types = [
+ "linear",
+ "ease in",
+ "ease out",
+ "ease in-out",
+ "reverse in-out",
+ "weak input",
+ "weak output",
+ "weak middle",
+ "strong middle",
+ "style transfer",
+ "composition",
+ "strong style transfer",
+ "style and composition",
+ "style transfer precise",
+ "composition precise",
+ ]
index = int(value) % len(weight_types)
self.ipadapter_info["weight_type"] = weight_types[index]
-
+
# ControlNet strength parameters
- elif parameter_name.startswith('controlnet_') and parameter_name.endswith('_strength'):
+ elif parameter_name.startswith("controlnet_") and parameter_name.endswith("_strength"):
import re
- match = re.match(r'controlnet_(\d+)_strength', parameter_name)
+
+ match = re.match(r"controlnet_(\d+)_strength", parameter_name)
if match:
index = int(match.group(1))
self.update_controlnet_strength(index, float(value))
-
+
# ControlNet preprocessor parameters
- elif parameter_name.startswith('controlnet_') and '_preprocessor_' in parameter_name:
+ elif parameter_name.startswith("controlnet_") and "_preprocessor_" in parameter_name:
import re
- match = re.match(r'controlnet_(\d+)_preprocessor_(.+)', parameter_name)
+
+ match = re.match(r"controlnet_(\d+)_preprocessor_(.+)", parameter_name)
if match:
controlnet_index = int(match.group(1))
param_name = match.group(2)
if controlnet_index < len(self.controlnet_info["controlnets"]):
controlnet = self.controlnet_info["controlnets"][controlnet_index]
- if 'preprocessor_params' not in controlnet:
- controlnet['preprocessor_params'] = {}
- controlnet['preprocessor_params'][param_name] = value
-
+ if "preprocessor_params" not in controlnet:
+ controlnet["preprocessor_params"] = {}
+ controlnet["preprocessor_params"][param_name] = value
+
# Prompt blending weights
- elif parameter_name.startswith('prompt_weight_'):
+ elif parameter_name.startswith("prompt_weight_"):
import re
- match = re.match(r'prompt_weight_(\d+)', parameter_name)
+
+ match = re.match(r"prompt_weight_(\d+)", parameter_name)
if match:
index = int(match.group(1))
if self.prompt_blending and index < len(self.prompt_blending):
# Update weight in prompt blending list
prompt_text = self.prompt_blending[index][0]
self.prompt_blending[index] = (prompt_text, float(value))
-
+
# Seed blending weights
- elif parameter_name.startswith('seed_weight_'):
+ elif parameter_name.startswith("seed_weight_"):
import re
- match = re.match(r'seed_weight_(\d+)', parameter_name)
+
+ match = re.match(r"seed_weight_(\d+)", parameter_name)
if match:
index = int(match.group(1))
if self.seed_blending and index < len(self.seed_blending):
# Update weight in seed blending list
seed_value = self.seed_blending[index][0]
self.seed_blending[index] = (seed_value, float(value))
-
+
else:
logger.warning(f"update_parameter: Unknown parameter {parameter_name}")
return
-
+
logger.debug(f"update_parameter: Successfully updated {parameter_name} in AppState")
def generate_pipeline_config(self):
"""Generate pipeline configuration from AppState - PRESERVES ALL ORIGINAL CONFIG"""
- logger.info("generate_pipeline_config: Generating pipeline config from AppState, preserving all original config")
-
+ logger.info(
+ "generate_pipeline_config: Generating pipeline config from AppState, preserving all original config"
+ )
+
# Start with complete original config to preserve ALL parameters
config = {}
if self.uploaded_config:
config = dict(self.uploaded_config)
-
+
# Only override runtime-changeable parameters from AppState
- config.update({
- 'guidance_scale': self.guidance_scale,
- 'delta': self.delta,
- 'num_inference_steps': self.num_inference_steps,
- 'seed': self.seed,
- 't_index_list': self.t_index_list,
- 'negative_prompt': self.negative_prompt,
- 'skip_diffusion': self.skip_diffusion,
- 'width': self.current_resolution["width"],
- 'height': self.current_resolution["height"],
- 'output_type': 'pt', # Force optimal tensor performance
- })
-
+ config.update(
+ {
+ "guidance_scale": self.guidance_scale,
+ "delta": self.delta,
+ "num_inference_steps": self.num_inference_steps,
+ "seed": self.seed,
+ "t_index_list": self.t_index_list,
+ "negative_prompt": self.negative_prompt,
+ "skip_diffusion": self.skip_diffusion,
+ "width": self.current_resolution["width"],
+ "height": self.current_resolution["height"],
+ "output_type": "pt", # Force optimal tensor performance
+ }
+ )
+
# Update ControlNet configurations with current AppState values
if self.controlnet_info["enabled"] and self.controlnet_info["controlnets"]:
- config['controlnets'] = []
+ config["controlnets"] = []
for controlnet in self.controlnet_info["controlnets"]:
cn_config = dict(controlnet)
# Ensure conditioning_scale reflects current strength
- cn_config['conditioning_scale'] = controlnet.get('strength', controlnet.get('conditioning_scale', 0.0))
- config['controlnets'].append(cn_config)
- elif 'controlnets' in config:
+ cn_config["conditioning_scale"] = controlnet.get("strength", controlnet.get("conditioning_scale", 0.0))
+ config["controlnets"].append(cn_config)
+ elif "controlnets" in config:
# Remove controlnets if disabled
- del config['controlnets']
-
+ del config["controlnets"]
+
# Update IPAdapter configurations with current AppState values
if self.ipadapter_info["enabled"]:
- config['use_ipadapter'] = True
+ config["use_ipadapter"] = True
# Preserve original ipadapters config but update runtime values
- if 'ipadapters' in config and config['ipadapters']:
+ if "ipadapters" in config and config["ipadapters"]:
# Update existing config with current values
- config['ipadapters'][0].update({
- 'scale': self.ipadapter_info["scale"],
- 'weight_type': self.ipadapter_info["weight_type"]
- })
+ config["ipadapters"][0].update(
+ {"scale": self.ipadapter_info["scale"], "weight_type": self.ipadapter_info["weight_type"]}
+ )
# Add style image if available
if self.ipadapter_info.get("has_style_image") and self.ipadapter_info.get("style_image_path"):
- config['ipadapters'][0]['style_image'] = self.ipadapter_info["style_image_path"]
- elif 'use_ipadapter' in config:
+ config["ipadapters"][0]["style_image"] = self.ipadapter_info["style_image_path"]
+ elif "use_ipadapter" in config:
# Disable IPAdapter if not enabled in AppState
- config['use_ipadapter'] = False
-
+ config["use_ipadapter"] = False
+
# Update pipeline hooks with current AppState values
for hook_type, hook_config in self.pipeline_hooks.items():
if hook_config["enabled"] and hook_config["processors"]:
- config[hook_type] = {
- "enabled": True,
- "processors": []
- }
+ config[hook_type] = {"enabled": True, "processors": []}
for processor in hook_config["processors"]:
proc_config = {
"type": processor["type"],
"enabled": processor["enabled"],
"order": processor["order"],
- "params": processor["params"]
+ "params": processor["params"],
}
config[hook_type]["processors"].append(proc_config)
elif hook_type in config:
# Disable hook if not enabled in AppState
config[hook_type] = {"enabled": False, "processors": []}
-
+
# Update blending configurations with current AppState values
if self.prompt_blending:
- config['prompt_blending'] = {
- 'prompt_list': self.prompt_blending,
- 'interpolation_method': 'slerp'
- }
- config['normalize_weights'] = self.normalize_prompt_weights
- elif 'prompt_blending' in config:
- del config['prompt_blending']
-
+ config["prompt_blending"] = {"prompt_list": self.prompt_blending, "interpolation_method": "slerp"}
+ config["normalize_weights"] = self.normalize_prompt_weights
+ elif "prompt_blending" in config:
+ del config["prompt_blending"]
+
if self.seed_blending:
- config['seed_blending'] = {
- 'seed_list': self.seed_blending,
- 'interpolation_method': 'linear'
- }
+ config["seed_blending"] = {"seed_list": self.seed_blending, "interpolation_method": "linear"}
# Note: seed normalization uses same normalize_weights key
if not self.prompt_blending: # Only set if not already set by prompt blending
- config['normalize_weights'] = self.normalize_seed_weights
- elif 'seed_blending' in config:
- del config['seed_blending']
-
+ config["normalize_weights"] = self.normalize_seed_weights
+ elif "seed_blending" in config:
+ del config["seed_blending"]
+
logger.info("generate_pipeline_config: Generated pipeline config preserving all original parameters")
return config
@@ -612,7 +592,6 @@ def update_state(self, updates):
logger.debug(f"AppState update_state: Updated {key} = {value}")
else:
logger.warning(f"AppState update_state: Unknown state key: {key}")
-
class App:
@@ -623,46 +602,50 @@ def __init__(self, config: Args):
self.conn_manager = ConnectionManager()
self.fps_counter = []
self.last_fps_update = time.time()
-
+
# Centralized state management
self.app_state = AppState()
-
+
# Initialize input manager for controller support
self.input_manager = InputManager()
# Initialize input source manager for modular input routing
from input_sources import InputSourceManager
+
self.input_source_manager = InputSourceManager()
-
+
# Preemptively initialize input sources to avoid config upload delay
self._preload_input_sources()
-
+
self.init_app()
def _preload_input_sources(self):
"""Preemptively initialize input sources and preprocessors to avoid delays during config upload"""
try:
- logger.info("_preload_input_sources: Preemptively initializing input sources and preprocessors to avoid config upload delay")
-
+ logger.info(
+ "_preload_input_sources: Preemptively initializing input sources and preprocessors to avoid config upload delay"
+ )
+
# Preload base input source
- self.input_source_manager.get_source_info('base')
-
+ self.input_source_manager.get_source_info("base")
+
# Preload IPAdapter input source
- self.input_source_manager.get_source_info('ipadapter')
-
+ self.input_source_manager.get_source_info("ipadapter")
+
# Preload potential ControlNet input sources (up to 5)
for i in range(5):
- self.input_source_manager.get_source_info('controlnet', index=i)
-
+ self.input_source_manager.get_source_info("controlnet", index=i)
+
# Preload preprocessors to trigger controlnet_aux imports
# This is what causes the delay - the first time a preprocessor is accessed,
# all the controlnet_aux modules get imported
logger.info("_preload_input_sources: Triggering preprocessor imports...")
try:
- from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class
+ from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors
+
# List all available preprocessors - this triggers the lazy imports
available = list_preprocessors()
logger.info(f"_preload_input_sources: Found {len(available)} preprocessors, loading metadata...")
-
+
# Access at least one preprocessor class to ensure all imports complete
if available:
for processor_name in available[:3]: # Load first 3 to trigger most imports
@@ -670,15 +653,15 @@ def _preload_input_sources(self):
_ = get_preprocessor_class(processor_name)
except Exception as e:
logger.debug(f"_preload_input_sources: Could not load {processor_name}: {e}")
-
+
logger.info("_preload_input_sources: Preprocessor imports completed")
except Exception as prep_error:
logger.warning(f"_preload_input_sources: Could not preload preprocessors: {prep_error}")
-
+
logger.info("_preload_input_sources: Input sources and preprocessors preloaded successfully")
except Exception as e:
logger.error(f"_preload_input_sources: Error during preload: {e}")
-
+
def cleanup(self):
"""Cleanup resources when app is shutting down"""
logger.info("App cleanup: Starting application cleanup...")
@@ -687,7 +670,7 @@ def cleanup(self):
self._cleanup_pipeline(self.pipeline)
self.pipeline = None
self.app_state.pipeline_lifecycle = "stopped"
- if hasattr(self, 'input_source_manager'):
+ if hasattr(self, "input_source_manager"):
self.input_source_manager.cleanup()
self._cleanup_temp_files()
logger.info("App cleanup: Completed application cleanup")
@@ -696,15 +679,17 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N
"""Handle parameter updates from input controls - UNIFIED THROUGH APPSTATE"""
try:
logger.debug(f"_handle_input_parameter_update: Updating {parameter_name} = {value} via AppState")
-
+
# Update AppState as single source of truth
self.app_state.update_parameter(parameter_name, value)
-
+
# Sync to pipeline if active (for real-time updates)
- if self.pipeline and hasattr(self.pipeline, 'stream'):
+ if self.pipeline and hasattr(self.pipeline, "stream"):
self._sync_appstate_to_pipeline()
else:
- logger.debug(f"_handle_input_parameter_update: No active pipeline, parameter stored in AppState for next pipeline creation")
+ logger.debug(
+ "_handle_input_parameter_update: No active pipeline, parameter stored in AppState for next pipeline creation"
+ )
except Exception as e:
logger.exception(f"_handle_input_parameter_update: Failed to update {parameter_name}: {e}")
@@ -712,36 +697,36 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N
def _update_resolution(self, width: int, height: int) -> None:
"""Update resolution by recreating pipeline with new dimensions"""
logger.info(f"_update_resolution: Updating resolution to {width}x{height}")
-
+
# Update AppState first
self.app_state.current_resolution = {"width": width, "height": height}
-
+
# If no pipeline exists, just update state (will be used when pipeline is created)
if not self.pipeline:
logger.info("_update_resolution: No pipeline exists, resolution will apply on next pipeline creation")
return
-
+
# Set pipeline lifecycle state
self.app_state.pipeline_lifecycle = "restarting"
-
+
# Store reference to old pipeline for cleanup
old_pipeline = self.pipeline
-
+
# Clear current pipeline reference before cleanup
self.pipeline = None
-
+
# Cleanup old pipeline and free VRAM
if old_pipeline:
self._cleanup_pipeline(old_pipeline)
old_pipeline = None
-
+
# Create new pipeline with new resolution
# No state restoration needed - _create_pipeline() uses AppState as single source of truth
try:
self.pipeline = self._create_pipeline()
self.app_state.pipeline_lifecycle = "running"
logger.info(f"_update_resolution: Pipeline successfully recreated with resolution {width}x{height}")
-
+
except Exception as e:
self.app_state.pipeline_lifecycle = "error"
logger.error(f"_update_resolution: Failed to recreate pipeline: {e}")
@@ -750,9 +735,9 @@ def _update_resolution(self, width: int, height: int) -> None:
def _sync_appstate_to_pipeline(self):
"""Sync AppState parameters to active pipeline for real-time updates"""
try:
- if not self.pipeline or not hasattr(self.pipeline, 'stream'):
+ if not self.pipeline or not hasattr(self.pipeline, "stream"):
return
-
+
# Core parameters
self.pipeline.update_stream_params(
guidance_scale=self.app_state.guidance_scale,
@@ -760,64 +745,57 @@ def _sync_appstate_to_pipeline(self):
num_inference_steps=self.app_state.num_inference_steps,
seed=self.app_state.seed,
negative_prompt=self.app_state.negative_prompt,
- t_index_list=self.app_state.t_index_list
+ t_index_list=self.app_state.t_index_list,
)
-
+
# IPAdapter parameters
if self.app_state.ipadapter_info["enabled"]:
- self.pipeline.update_stream_params(ipadapter_config={
- 'scale': self.app_state.ipadapter_info["scale"]
- })
- if hasattr(self.pipeline, 'update_ipadapter_weight_type'):
+ self.pipeline.update_stream_params(ipadapter_config={"scale": self.app_state.ipadapter_info["scale"]})
+ if hasattr(self.pipeline, "update_ipadapter_weight_type"):
self.pipeline.update_ipadapter_weight_type(self.app_state.ipadapter_info["weight_type"])
-
+
# ControlNet parameters
if self.app_state.controlnet_info["enabled"] and self.app_state.controlnet_info["controlnets"]:
controlnet_config = []
for cn in self.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength']
+ config_entry["conditioning_scale"] = cn["strength"]
controlnet_config.append(config_entry)
self.pipeline.update_stream_params(controlnet_config=controlnet_config)
-
+
# Prompt blending
if self.app_state.prompt_blending:
self.pipeline.update_stream_params(prompt_list=self.app_state.prompt_blending)
-
+
# Seed blending
if self.app_state.seed_blending:
self.pipeline.update_stream_params(seed_list=self.app_state.seed_blending)
-
+
logger.debug("_sync_appstate_to_pipeline: Successfully synced AppState to pipeline")
-
+
except Exception as e:
logger.exception(f"_sync_appstate_to_pipeline: Failed to sync AppState to pipeline: {e}")
-
-
-
-
def _get_controlnet_pipeline(self):
"""Get the ControlNet pipeline from the main pipeline structure"""
if not self.pipeline:
return None
-
+
stream = self.pipeline.stream
-
+
# Module-aware: module installs expose preprocessors on stream
- if hasattr(stream, 'preprocessors'):
+ if hasattr(stream, "preprocessors"):
return stream
-
+
# Check if stream has nested stream (IPAdapter wrapper)
- if hasattr(stream, 'stream') and hasattr(stream.stream, 'preprocessors'):
+ if hasattr(stream, "stream") and hasattr(stream.stream, "preprocessors"):
return stream.stream
-
+
# New module path on stream
- if hasattr(stream, '_controlnet_module'):
+ if hasattr(stream, "_controlnet_module"):
return stream._controlnet_module
return None
-
def init_app(self):
# Enhanced CORS for API-only development mode
if self.args.api_only:
@@ -844,74 +822,95 @@ def init_app(self):
# Register route modules
self._register_routes()
-
+
def _register_routes(self):
"""Register all route modules with dependency injection"""
- from routes import parameters, controlnet, ipadapter, inference, pipeline_hooks, websocket, input_sources, debug
- from routes.common.dependencies import get_app_instance as shared_get_app_instance, get_pipeline_class as shared_get_pipeline_class, get_default_settings as shared_get_default_settings, get_available_controlnets as shared_get_available_controlnets
-
+ from routes import (
+ controlnet,
+ debug,
+ inference,
+ input_sources,
+ ipadapter,
+ parameters,
+ pipeline_hooks,
+ websocket,
+ )
+ from routes.common.dependencies import get_app_instance as shared_get_app_instance
+ from routes.common.dependencies import get_available_controlnets as shared_get_available_controlnets
+ from routes.common.dependencies import get_default_settings as shared_get_default_settings
+ from routes.common.dependencies import get_pipeline_class as shared_get_pipeline_class
+
# Create dependency overrides to inject app instance and other dependencies
def get_app_instance():
return self
-
+
def get_pipeline_class():
return Pipeline
-
+
def get_default_settings():
return DEFAULT_SETTINGS
-
+
def get_available_controlnets():
return AVAILABLE_CONTROLNETS
-
+
# Include routers and set up dependency overrides on the main app
- for router_module in [parameters, controlnet, ipadapter, inference, pipeline_hooks, websocket, input_sources, debug]:
+ for router_module in [
+ parameters,
+ controlnet,
+ ipadapter,
+ inference,
+ pipeline_hooks,
+ websocket,
+ input_sources,
+ debug,
+ ]:
# Include the router
self.app.include_router(router_module.router)
-
+
# Set up dependency overrides on the main app (not individual routers)
self.app.dependency_overrides[shared_get_app_instance] = get_app_instance
self.app.dependency_overrides[shared_get_pipeline_class] = get_pipeline_class
self.app.dependency_overrides[shared_get_default_settings] = get_default_settings
self.app.dependency_overrides[shared_get_available_controlnets] = get_available_controlnets
-
+
# Set up static files if not in API-only mode
if not self.args.api_only:
self.app.mount("/", StaticFiles(directory="frontend/public", html=True), name="public")
-
def _create_pipeline(self):
"""Create pipeline using AppState as single source of truth"""
logger.info("_create_pipeline: Creating pipeline using AppState as single source of truth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16
-
+
# Generate pipeline config from AppState - SINGLE SOURCE OF TRUTH
pipeline_config = self.app_state.generate_pipeline_config()
-
+
# Load config style images into InputSourceManager before creating pipeline
self._load_config_style_images()
-
+
# Create wrapper using the unified config - THIS IS NOW THE SINGLE PLACE WHERE WRAPPER IS CREATED
from src.streamdiffusion.config import create_wrapper_from_config
-
+
# Create wrapper using the unified config
wrapper = create_wrapper_from_config(pipeline_config)
-
+
# Update args with config values before passing to Pipeline
from config import Args
+
args_dict = self.args._asdict()
- if 'acceleration' in pipeline_config:
- args_dict['acceleration'] = pipeline_config['acceleration']
- if 'engine_dir' in pipeline_config:
- args_dict['engine_dir'] = pipeline_config['engine_dir']
- if 'use_safety_checker' in pipeline_config:
- args_dict['safety_checker'] = pipeline_config['use_safety_checker']
-
+ if "acceleration" in pipeline_config:
+ args_dict["acceleration"] = pipeline_config["acceleration"]
+ if "engine_dir" in pipeline_config:
+ args_dict["engine_dir"] = pipeline_config["engine_dir"]
+ if "use_safety_checker" in pipeline_config:
+ args_dict["safety_checker"] = pipeline_config["use_safety_checker"]
+
updated_args = Args(**args_dict)
-
+
# Create Pipeline instance with the pre-created wrapper and config
pipeline = Pipeline(wrapper=wrapper, config=pipeline_config)
-
+
logger.info("_create_pipeline: Pipeline created successfully with pre-created wrapper")
return pipeline
@@ -919,24 +918,25 @@ def _load_config_style_images(self):
"""Load style images from config into InputSourceManager"""
if not self.app_state.uploaded_config:
return
-
+
try:
# Load IPAdapter style images from config
- ipadapters = self.app_state.uploaded_config.get('ipadapters', [])
+ ipadapters = self.app_state.uploaded_config.get("ipadapters", [])
if ipadapters:
first_ipadapter = ipadapters[0]
- style_image_path = first_ipadapter.get('style_image')
+ style_image_path = first_ipadapter.get("style_image")
if style_image_path:
# Use the config file path as base for relative paths
- base_config_path = getattr(self.args, 'controlnet_config', None)
+ base_config_path = getattr(self.args, "controlnet_config", None)
self.input_source_manager.load_config_style_image(style_image_path, base_config_path)
except Exception as e:
logging.exception(f"_load_config_style_images: Error loading config style images: {e}")
def _cleanup_temp_files(self):
"""Clean up any temporary config files"""
- if hasattr(self, '_temp_config_files'):
+ if hasattr(self, "_temp_config_files"):
import os
+
for temp_path in self._temp_config_files:
try:
if os.path.exists(temp_path):
@@ -945,26 +945,24 @@ def _cleanup_temp_files(self):
pass
self._temp_config_files.clear()
-
def _calculate_aspect_ratio(self, width: int, height: int) -> str:
"""Calculate and return aspect ratio as a string"""
- import math
-
+
def gcd(a, b):
while b:
a, b = b, a % b
return a
-
+
ratio_gcd = gcd(width, height)
- return f"{width//ratio_gcd}:{height//ratio_gcd}"
+ return f"{width // ratio_gcd}:{height // ratio_gcd}"
def _cleanup_pipeline(self, pipeline):
"""Properly cleanup a pipeline and free VRAM"""
if pipeline is None:
return
-
+
try:
- if hasattr(pipeline, 'cleanup'):
+ if hasattr(pipeline, "cleanup"):
pipeline.cleanup()
del pipeline
torch.cuda.empty_cache()
@@ -984,4 +982,4 @@ def _cleanup_pipeline(self, pipeline):
reload=config.reload,
ssl_certfile=config.ssl_certfile,
ssl_keyfile=config.ssl_keyfile,
- )
\ No newline at end of file
+ )
diff --git a/demo/realtime-img2img/routes/__init__.py b/demo/realtime-img2img/routes/__init__.py
index 713519c91..cd671cfb6 100644
--- a/demo/realtime-img2img/routes/__init__.py
+++ b/demo/realtime-img2img/routes/__init__.py
@@ -1,4 +1,3 @@
"""
Routes package for realtime-img2img API endpoints
"""
-
diff --git a/demo/realtime-img2img/routes/common/api_utils.py b/demo/realtime-img2img/routes/common/api_utils.py
index a1ade9fff..0a5425774 100644
--- a/demo/realtime-img2img/routes/common/api_utils.py
+++ b/demo/realtime-img2img/routes/common/api_utils.py
@@ -3,46 +3,43 @@
"""
import logging
+from typing import Any, Dict, Optional
+
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
-from typing import Any, Dict, Optional
async def handle_api_request(
- request: Request,
- operation_name: str,
- required_params: list = None,
- pipeline_required: bool = True
+ request: Request, operation_name: str, required_params: list = None, pipeline_required: bool = True
) -> Dict[str, Any]:
"""
Standard request handler for API endpoints
-
+
Args:
request: FastAPI request object
operation_name: Name of the operation for logging
required_params: List of required parameter names
pipeline_required: Whether an active pipeline is required
-
+
Returns:
Parsed JSON data from request
-
+
Raises:
HTTPException: For validation errors
"""
try:
data = await request.json()
-
+
# Check required parameters
if required_params:
missing_params = [param for param in required_params if param not in data]
if missing_params:
raise HTTPException(
- status_code=400,
- detail=f"Missing required parameters: {', '.join(missing_params)}"
+ status_code=400, detail=f"Missing required parameters: {', '.join(missing_params)}"
)
-
+
return data
-
+
except Exception as e:
logging.exception(f"{operation_name}: Failed to parse request: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
@@ -51,18 +48,15 @@ async def handle_api_request(
def create_success_response(message: str, **extra_data) -> JSONResponse:
"""
Create a standardized success response
-
+
Args:
message: Success message
**extra_data: Additional data to include in response
-
+
Returns:
JSONResponse with standardized format
"""
- response_data = {
- "status": "success",
- "message": message
- }
+ response_data = {"status": "success", "message": message}
response_data.update(extra_data)
return JSONResponse(response_data)
@@ -70,84 +64,71 @@ def create_success_response(message: str, **extra_data) -> JSONResponse:
def handle_api_error(error: Exception, operation_name: str, status_code: int = 500) -> HTTPException:
"""
Standard error handler for API endpoints
-
+
Args:
error: The caught exception
operation_name: Name of the operation for logging
status_code: HTTP status code to return
-
+
Returns:
HTTPException with standardized error message
"""
logging.error(f"{operation_name}: Failed: {error}")
- return HTTPException(
- status_code=status_code,
- detail=f"Failed to {operation_name.lower()}: {str(error)}"
- )
+ return HTTPException(status_code=status_code, detail=f"Failed to {operation_name.lower()}: {str(error)}")
def validate_pipeline(pipeline: Any, operation_name: str) -> None:
"""
Validate that pipeline exists and is initialized
-
+
Args:
pipeline: Pipeline object to validate
operation_name: Name of the operation for error messages
-
+
Raises:
HTTPException: If pipeline is not valid
"""
logging.info(f"validate_pipeline: {operation_name} - pipeline is: {pipeline is not None}")
if not pipeline:
logging.error(f"validate_pipeline: {operation_name} - Pipeline is not initialized")
- raise HTTPException(
- status_code=400,
- detail="Pipeline is not initialized"
- )
+ raise HTTPException(status_code=400, detail="Pipeline is not initialized")
def validate_feature_enabled(pipeline: Any, feature_name: str, feature_check_attr: str) -> None:
"""
Validate that a specific feature is enabled in the pipeline
-
+
Args:
pipeline: Pipeline object
feature_name: Human-readable feature name (e.g., "ControlNet", "IPAdapter")
feature_check_attr: Attribute name to check (e.g., "has_controlnet", "has_ipadapter")
-
+
Raises:
HTTPException: If feature is not enabled
"""
if not getattr(pipeline, feature_check_attr, False):
- raise HTTPException(
- status_code=400,
- detail=f"{feature_name} is not enabled"
- )
+ raise HTTPException(status_code=400, detail=f"{feature_name} is not enabled")
def validate_config_mode(pipeline: Any, config_check: Optional[str] = None) -> None:
"""
Validate that pipeline is using config mode
-
+
Args:
pipeline: Pipeline object
config_check: Optional specific config key to check for existence
-
+
Raises:
HTTPException: If not in config mode or config key missing
"""
- logging.info(f"validate_config_mode: use_config={getattr(pipeline, 'use_config', None)}, config exists={getattr(pipeline, 'config', None) is not None}")
+ logging.info(
+ f"validate_config_mode: use_config={getattr(pipeline, 'use_config', None)}, config exists={getattr(pipeline, 'config', None) is not None}"
+ )
if not (pipeline.use_config and pipeline.config):
- logging.error(f"validate_config_mode: Pipeline is not using configuration mode")
- raise HTTPException(
- status_code=400,
- detail="Pipeline is not using configuration mode"
- )
-
+ logging.error("validate_config_mode: Pipeline is not using configuration mode")
+ raise HTTPException(status_code=400, detail="Pipeline is not using configuration mode")
+
if config_check and config_check not in pipeline.config:
logging.error(f"validate_config_mode: Configuration key '{config_check}' not found in pipeline config")
logging.info(f"validate_config_mode: Available config keys: {list(pipeline.config.keys())}")
- raise HTTPException(
- status_code=400,
- detail=f"Configuration missing required section: {config_check}"
- )
\ No newline at end of file
+ raise HTTPException(status_code=400, detail=f"Configuration missing required section: {config_check}")
diff --git a/demo/realtime-img2img/routes/controlnet.py b/demo/realtime-img2img/routes/controlnet.py
index e25cbcbfc..9571d613f 100644
--- a/demo/realtime-img2img/routes/controlnet.py
+++ b/demo/realtime-img2img/routes/controlnet.py
@@ -1,19 +1,24 @@
"""
ControlNet-related endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends, UploadFile, File
-from fastapi.responses import JSONResponse
+
+import copy
import logging
+
import yaml
-import tempfile
-from pathlib import Path
-import copy
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
+from fastapi.responses import JSONResponse
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error, validate_feature_enabled, validate_config_mode
+from .common.api_utils import (
+ create_success_response,
+ handle_api_error,
+)
from .common.dependencies import get_app_instance, get_available_controlnets
+
router = APIRouter(prefix="/api", tags=["controlnet"])
+
def _ensure_runtime_controlnet_config(app_instance):
"""Ensure runtime controlnet config is initialized from uploaded config or create minimal config"""
if app_instance.app_state.runtime_config is None:
@@ -22,45 +27,45 @@ def _ensure_runtime_controlnet_config(app_instance):
app_instance.app_state.runtime_config = copy.deepcopy(app_instance.app_state.uploaded_config)
else:
# Create minimal config if no YAML exists
- app_instance.app_state.runtime_config = {'controlnets': []}
-
+ app_instance.app_state.runtime_config = {"controlnets": []}
+
# Ensure controlnets key exists in runtime config
- if 'controlnets' not in app_instance.app_state.runtime_config:
- app_instance.app_state.runtime_config['controlnets'] = []
+ if "controlnets" not in app_instance.app_state.runtime_config:
+ app_instance.app_state.runtime_config["controlnets"] = []
@router.post("/controlnet/upload-config")
async def upload_controlnet_config(file: UploadFile = File(...), app_instance=Depends(get_app_instance)):
"""Upload and load a new ControlNet YAML configuration"""
try:
- if not file.filename.endswith(('.yaml', '.yml')):
+ if not file.filename.endswith((".yaml", ".yml")):
raise HTTPException(status_code=400, detail="File must be a YAML file")
-
+
# Save uploaded file temporarily
content = await file.read()
-
+
# Parse YAML content
try:
- config_data = yaml.safe_load(content.decode('utf-8'))
+ config_data = yaml.safe_load(content.decode("utf-8"))
except yaml.YAMLError as e:
raise HTTPException(status_code=400, detail=f"Invalid YAML format: {str(e)}")
-
+
# YAML is source of truth - completely replace any runtime modifications
app_instance.app_state.uploaded_config = config_data
app_instance.app_state.runtime_config = None
app_instance.app_state.config_needs_reload = True
-
+
# SINGLE SOURCE OF TRUTH: Populate AppState from config
app_instance.app_state.populate_from_config(config_data)
-
+
# RESET ALL INPUT SOURCES TO DEFAULTS WHEN NEW CONFIG IS UPLOADED
- if hasattr(app_instance, 'input_source_manager'):
+ if hasattr(app_instance, "input_source_manager"):
try:
app_instance.input_source_manager.reset_to_defaults()
logging.info("upload_controlnet_config: Reset all input sources to defaults")
except Exception as e:
logging.exception(f"upload_controlnet_config: Failed to reset input sources: {e}")
-
+
# FORCE DESTROY ACTIVE PIPELINE TO MAKE CONFIG THE SOURCE OF TRUTH
if app_instance.pipeline:
logging.info("upload_controlnet_config: Destroying active pipeline to force config as source of truth")
@@ -69,43 +74,43 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De
app_instance.pipeline = None
app_instance._cleanup_pipeline(old_pipeline)
app_instance.app_state.pipeline_lifecycle = "stopped"
-
+
# Get config prompt if available
- config_prompt = config_data.get('prompt', None)
+ config_prompt = config_data.get("prompt", None)
# Get negative prompt if available
- config_negative_prompt = config_data.get('negative_prompt', None)
-
+ config_negative_prompt = config_data.get("negative_prompt", None)
+
# Get t_index_list from config if available
from app_config import DEFAULT_SETTINGS
- t_index_list = config_data.get('t_index_list', DEFAULT_SETTINGS.get('t_index_list', [35, 45]))
-
+
+ t_index_list = config_data.get("t_index_list", DEFAULT_SETTINGS.get("t_index_list", [35, 45]))
+
# Get acceleration from config if available
- config_acceleration = config_data.get('acceleration', app_instance.args.acceleration)
-
+ config_acceleration = config_data.get("acceleration", app_instance.args.acceleration)
+
# Get width and height from config if available
- config_width = config_data.get('width', None)
- config_height = config_data.get('height', None)
-
+ config_width = config_data.get("width", None)
+ config_height = config_data.get("height", None)
+
# Update resolution if width/height are specified in config
if config_width is not None and config_height is not None:
try:
# Validate resolution
if config_width % 64 != 0 or config_height % 64 != 0:
raise HTTPException(status_code=400, detail="Resolution must be multiples of 64")
-
+
if not (384 <= config_width <= 1024) or not (384 <= config_height <= 1024):
raise HTTPException(status_code=400, detail="Resolution must be between 384 and 1024")
-
- app_instance.app_state.current_resolution = {
- "width": int(config_width),
- "height": int(config_height)
- }
-
- logging.info(f"upload_controlnet_config: Updated resolution from config to {config_width}x{config_height}")
-
+
+ app_instance.app_state.current_resolution = {"width": int(config_width), "height": int(config_height)}
+
+ logging.info(
+ f"upload_controlnet_config: Updated resolution from config to {config_width}x{config_height}"
+ )
+
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="Invalid width/height values in config")
-
+
# Build current resolution string
current_resolution = None
if config_width and config_height:
@@ -114,13 +119,13 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De
aspect_ratio = app_instance._calculate_aspect_ratio(config_width, config_height)
if aspect_ratio:
current_resolution += f" ({aspect_ratio})"
-
+
# Build config_values for other parameters that frontend may expect
config_values = {}
for key in [
- 'use_taesd',
- 'cfg_type',
- 'safety_checker',
+ "use_taesd",
+ "cfg_type",
+ "safety_checker",
]:
if key in config_data:
config_values[key] = config_data[key]
@@ -155,46 +160,57 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De
"latent_preprocessing": app_instance.app_state.pipeline_hooks["latent_preprocessing"],
"latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"],
}
-
+
return JSONResponse(response_data)
-
+
except Exception as e:
logging.exception(f"upload_controlnet_config: Failed to upload configuration: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload configuration: {str(e)}")
+
@router.get("/controlnet/info")
async def get_controlnet_info(app_instance=Depends(get_app_instance)):
"""Get current ControlNet configuration info - SINGLE SOURCE OF TRUTH"""
return JSONResponse({"controlnet": app_instance.app_state.controlnet_info})
+
@router.get("/blending/current")
async def get_current_blending_config(app_instance=Depends(get_app_instance)):
"""Get current prompt and seed blending configurations"""
try:
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream') and hasattr(app_instance.pipeline.stream, 'get_stream_state'):
+ if (
+ app_instance.pipeline
+ and hasattr(app_instance.pipeline, "stream")
+ and hasattr(app_instance.pipeline.stream, "get_stream_state")
+ ):
state = app_instance.pipeline.stream.get_stream_state(include_caches=False)
- return JSONResponse({
- "prompt_blending": state.get("prompt_list", []),
- "seed_blending": state.get("seed_list", []),
- "normalize_prompt_weights": state.get("normalize_prompt_weights", True),
- "normalize_seed_weights": state.get("normalize_seed_weights", True),
- "has_config": app_instance.app_state.uploaded_config is not None,
- "pipeline_active": True
- })
+ return JSONResponse(
+ {
+ "prompt_blending": state.get("prompt_list", []),
+ "seed_blending": state.get("seed_list", []),
+ "normalize_prompt_weights": state.get("normalize_prompt_weights", True),
+ "normalize_seed_weights": state.get("normalize_seed_weights", True),
+ "has_config": app_instance.app_state.uploaded_config is not None,
+ "pipeline_active": True,
+ }
+ )
# Fallback to AppState when pipeline not initialized - SINGLE SOURCE OF TRUTH
- return JSONResponse({
- "prompt_blending": app_instance.app_state.prompt_blending,
- "seed_blending": app_instance.app_state.seed_blending,
- "normalize_prompt_weights": app_instance.app_state.normalize_prompt_weights,
- "normalize_seed_weights": app_instance.app_state.normalize_seed_weights,
- "has_config": app_instance.app_state.uploaded_config is not None,
- "pipeline_active": False
- })
-
+ return JSONResponse(
+ {
+ "prompt_blending": app_instance.app_state.prompt_blending,
+ "seed_blending": app_instance.app_state.seed_blending,
+ "normalize_prompt_weights": app_instance.app_state.normalize_prompt_weights,
+ "normalize_seed_weights": app_instance.app_state.normalize_seed_weights,
+ "has_config": app_instance.app_state.uploaded_config is not None,
+ "pipeline_active": False,
+ }
+ )
+
except Exception as e:
raise handle_api_error(e, "get_current_blending_config")
+
@router.post("/controlnet/update-strength")
async def update_controlnet_strength(request: Request, app_instance=Depends(get_app_instance)):
"""Update ControlNet strength in real-time"""
@@ -202,13 +218,13 @@ async def update_controlnet_strength(request: Request, app_instance=Depends(get_
data = await request.json()
controlnet_index = data.get("index")
strength = data.get("strength")
-
+
if controlnet_index is None or strength is None:
raise HTTPException(status_code=400, detail="Missing index or strength parameter")
-
+
# Update ControlNet strength in AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.update_controlnet_strength(controlnet_index, float(strength))
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -216,114 +232,122 @@ async def update_controlnet_strength(request: Request, app_instance=Depends(get_
controlnet_config = []
for cn in app_instance.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale
+ config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale
controlnet_config.append(config_entry)
app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config)
except Exception as e:
logging.exception(f"update_controlnet_strength: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
+
return create_success_response(f"Updated ControlNet {controlnet_index} strength to {strength}")
-
+
except Exception as e:
raise handle_api_error(e, "update_controlnet_strength")
+
@router.get("/controlnet/available")
-async def get_available_controlnets_endpoint(app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)):
+async def get_available_controlnets_endpoint(
+ app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)
+):
"""Get list of available ControlNets that can be added"""
try:
# Debug the dependency injection
-
+
# Detect current model architecture to filter appropriate ControlNets
model_type = "sd15" # Default fallback
-
+
# Try to determine model type from pipeline config or uploaded config
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'config') and app_instance.pipeline.config:
- model_id = app_instance.pipeline.config.get('model_id', '')
- if 'sdxl' in model_id.lower() or 'xl' in model_id.lower():
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "config") and app_instance.pipeline.config:
+ model_id = app_instance.pipeline.config.get("model_id", "")
+ if "sdxl" in model_id.lower() or "xl" in model_id.lower():
model_type = "sdxl"
elif app_instance.app_state.uploaded_config:
# If no pipeline yet, try to get model type from uploaded config
- model_id = app_instance.app_state.uploaded_config.get('model_id_or_path', '')
- if 'sdxl' in model_id.lower() or 'xl' in model_id.lower():
+ model_id = app_instance.app_state.uploaded_config.get("model_id_or_path", "")
+ if "sdxl" in model_id.lower() or "xl" in model_id.lower():
model_type = "sdxl"
-
+
# Handle case where available_controlnets dependency returns None
if available_controlnets is None:
logging.warning("get_available_controlnets: available_controlnets dependency returned None")
available = []
else:
available = available_controlnets.get(model_type, [])
-
+
# Filter out already active ControlNets
current_controlnets = []
# Check runtime config first, then fall back to uploaded config
- if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config:
- current_controlnets = [cn.get('model_id', '') for cn in app_instance.app_state.runtime_config['controlnets']]
- elif app_instance.app_state.uploaded_config and 'controlnets' in app_instance.app_state.uploaded_config:
- current_controlnets = [cn.get('model_id', '') for cn in app_instance.app_state.uploaded_config['controlnets']]
-
+ if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config:
+ current_controlnets = [
+ cn.get("model_id", "") for cn in app_instance.app_state.runtime_config["controlnets"]
+ ]
+ elif app_instance.app_state.uploaded_config and "controlnets" in app_instance.app_state.uploaded_config:
+ current_controlnets = [
+ cn.get("model_id", "") for cn in app_instance.app_state.uploaded_config["controlnets"]
+ ]
+
filtered_available = []
for cn in available:
- if cn['model_id'] not in current_controlnets:
+ if cn["model_id"] not in current_controlnets:
filtered_available.append(cn)
-
- return JSONResponse({
- "status": "success",
- "available_controlnets": filtered_available,
- "model_type": model_type
- })
-
+
+ return JSONResponse(
+ {"status": "success", "available_controlnets": filtered_available, "model_type": model_type}
+ )
+
except Exception as e:
raise handle_api_error(e, "get_available_controlnets_endpoint")
+
@router.post("/controlnet/add")
-async def add_controlnet(request: Request, app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)):
+async def add_controlnet(
+ request: Request, app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)
+):
"""Add a ControlNet from the predefined list"""
try:
data = await request.json()
controlnet_id = data.get("controlnet_id")
conditioning_scale = data.get("conditioning_scale", None)
-
+
if not controlnet_id:
raise HTTPException(status_code=400, detail="Missing controlnet_id parameter")
-
+
# Find the ControlNet definition
controlnet_def = None
for model_type, controlnets in available_controlnets.items():
for cn in controlnets:
- if cn['id'] == controlnet_id:
+ if cn["id"] == controlnet_id:
controlnet_def = cn
break
if controlnet_def:
break
-
+
if not controlnet_def:
raise HTTPException(status_code=400, detail=f"ControlNet {controlnet_id} not found in registry")
-
+
# Use provided scale or default
if conditioning_scale is None:
- conditioning_scale = controlnet_def['default_scale']
-
+ conditioning_scale = controlnet_def["default_scale"]
+
# Initialize runtime config from YAML if not already done
_ensure_runtime_controlnet_config(app_instance)
-
+
# Create new ControlNet entry
new_controlnet = {
- 'model_id': controlnet_def['model_id'],
- 'conditioning_scale': conditioning_scale,
- 'preprocessor': controlnet_def['default_preprocessor'],
- 'preprocessor_params': controlnet_def.get('preprocessor_params', {}),
- 'enabled': True
+ "model_id": controlnet_def["model_id"],
+ "conditioning_scale": conditioning_scale,
+ "preprocessor": controlnet_def["default_preprocessor"],
+ "preprocessor_params": controlnet_def.get("preprocessor_params", {}),
+ "enabled": True,
}
-
+
# Add to runtime config (not YAML)
- app_instance.app_state.runtime_config['controlnets'].append(new_controlnet)
-
+ app_instance.app_state.runtime_config["controlnets"].append(new_controlnet)
+
# Add to AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.add_controlnet(new_controlnet)
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -331,79 +355,84 @@ async def add_controlnet(request: Request, app_instance=Depends(get_app_instance
controlnet_config = []
for cn in app_instance.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale
+ config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale
controlnet_config.append(config_entry)
app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config)
except Exception as e:
logging.exception(f"add_controlnet: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
-
+
# Return updated ControlNet info immediately - SINGLE SOURCE OF TRUTH
- added_index = len(app_instance.app_state.runtime_config['controlnets']) - 1
-
- return JSONResponse({
- "status": "success",
- "message": f"Added {controlnet_def['name']}",
- "controlnet_index": added_index,
- "controlnet_info": app_instance.app_state.controlnet_info
- })
-
+ added_index = len(app_instance.app_state.runtime_config["controlnets"]) - 1
+
+ return JSONResponse(
+ {
+ "status": "success",
+ "message": f"Added {controlnet_def['name']}",
+ "controlnet_index": added_index,
+ "controlnet_info": app_instance.app_state.controlnet_info,
+ }
+ )
+
except Exception as e:
raise handle_api_error(e, "add_controlnet")
+
@router.get("/controlnet/status")
async def get_controlnet_status(app_instance=Depends(get_app_instance)):
"""Get the status of ControlNet configuration"""
try:
controlnet_pipeline = app_instance._get_controlnet_pipeline()
-
+
if not controlnet_pipeline:
- return JSONResponse({
- "status": "no_pipeline",
- "message": "No ControlNet pipeline available",
- "controlnet_count": 0
- })
-
+ return JSONResponse(
+ {"status": "no_pipeline", "message": "No ControlNet pipeline available", "controlnet_count": 0}
+ )
+
# Use AppState - SINGLE SOURCE OF TRUTH
controlnet_count = len(app_instance.app_state.controlnet_info["controlnets"])
-
- return JSONResponse({
- "status": "ready",
- "controlnet_count": controlnet_count,
- "message": f"{controlnet_count} ControlNet(s) configured" if controlnet_count > 0 else "No ControlNets configured"
- })
-
+
+ return JSONResponse(
+ {
+ "status": "ready",
+ "controlnet_count": controlnet_count,
+ "message": f"{controlnet_count} ControlNet(s) configured"
+ if controlnet_count > 0
+ else "No ControlNets configured",
+ }
+ )
+
except Exception as e:
raise handle_api_error(e, "get_controlnet_status")
+
@router.post("/controlnet/remove")
async def remove_controlnet(request: Request, app_instance=Depends(get_app_instance)):
"""Remove a ControlNet by index"""
try:
data = await request.json()
index = data.get("index")
-
+
if index is None:
raise HTTPException(status_code=400, detail="Missing index parameter")
-
+
# Initialize runtime config from YAML if not already done
_ensure_runtime_controlnet_config(app_instance)
-
- if 'controlnets' not in app_instance.app_state.runtime_config:
+
+ if "controlnets" not in app_instance.app_state.runtime_config:
raise HTTPException(status_code=400, detail="No ControlNet configuration found")
-
- controlnets = app_instance.app_state.runtime_config['controlnets']
-
+
+ controlnets = app_instance.app_state.runtime_config["controlnets"]
+
if index < 0 or index >= len(controlnets):
raise HTTPException(status_code=400, detail=f"ControlNet index {index} out of range")
-
+
removed_controlnet = controlnets.pop(index)
-
+
# Remove from AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.remove_controlnet(index)
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -411,65 +440,66 @@ async def remove_controlnet(request: Request, app_instance=Depends(get_app_insta
controlnet_config = []
for cn in app_instance.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale
+ config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale
controlnet_config.append(config_entry)
app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config)
except Exception as e:
logging.exception(f"remove_controlnet: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
-
+
# Return updated ControlNet info immediately - SINGLE SOURCE OF TRUTH
- return create_success_response(f"Removed ControlNet at index {index}", controlnet_info=app_instance.app_state.controlnet_info)
-
+ return create_success_response(
+ f"Removed ControlNet at index {index}", controlnet_info=app_instance.app_state.controlnet_info
+ )
+
except Exception as e:
raise handle_api_error(e, "remove_controlnet")
+
# Preprocessor endpoints (closely related to ControlNet)
@router.get("/preprocessors/info")
async def get_preprocessors_info(app_instance=Depends(get_app_instance)):
"""Get preprocessor information using metadata from preprocessor classes"""
try:
# Use the same processor registry as pipeline hooks
- from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class
-
+ from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors
+
available_processors = list_preprocessors()
processors_info = {}
-
+
for processor_name in available_processors:
try:
processor_class = get_preprocessor_class(processor_name)
- if hasattr(processor_class, 'get_preprocessor_metadata'):
+ if hasattr(processor_class, "get_preprocessor_metadata"):
metadata = processor_class.get_preprocessor_metadata()
processors_info[processor_name] = {
"name": metadata.get("name", processor_name),
"description": metadata.get("description", ""),
- "parameters": metadata.get("parameters", {})
+ "parameters": metadata.get("parameters", {}),
}
else:
processors_info[processor_name] = {
"name": processor_name,
"description": f"{processor_name} processor",
- "parameters": {}
+ "parameters": {},
}
except Exception as e:
logging.warning(f"get_preprocessors_info: Failed to load metadata for {processor_name}: {e}")
processors_info[processor_name] = {
"name": processor_name,
"description": f"{processor_name} processor",
- "parameters": {}
+ "parameters": {},
}
-
- return JSONResponse({
- "status": "success",
- "available": list(processors_info.keys()),
- "preprocessors": processors_info
- })
-
+
+ return JSONResponse(
+ {"status": "success", "available": list(processors_info.keys()), "preprocessors": processors_info}
+ )
+
except Exception as e:
raise handle_api_error(e, "get_preprocessors_info")
+
@router.post("/preprocessors/switch")
async def switch_preprocessor(request: Request, app_instance=Depends(get_app_instance)):
"""Switch preprocessor for a specific ControlNet"""
@@ -478,50 +508,59 @@ async def switch_preprocessor(request: Request, app_instance=Depends(get_app_ins
# Support both parameter naming conventions for compatibility
controlnet_index = data.get("controlnet_index") or data.get("processor_index")
preprocessor_name = data.get("preprocessor") or data.get("processor")
-
+
if controlnet_index is None or not preprocessor_name:
- raise HTTPException(status_code=400, detail="Missing controlnet_index/processor_index or preprocessor/processor parameter")
-
+ raise HTTPException(
+ status_code=400, detail="Missing controlnet_index/processor_index or preprocessor/processor parameter"
+ )
+
# Validate AppState has ControlNet configuration (pipeline not required)
- if not app_instance.app_state.controlnet_info["enabled"] or not app_instance.app_state.controlnet_info["controlnets"]:
- raise HTTPException(status_code=400, detail="No ControlNet configuration available. Please upload a config first.")
-
+ if (
+ not app_instance.app_state.controlnet_info["enabled"]
+ or not app_instance.app_state.controlnet_info["controlnets"]
+ ):
+ raise HTTPException(
+ status_code=400, detail="No ControlNet configuration available. Please upload a config first."
+ )
+
# Update AppState - SINGLE SOURCE OF TRUTH (works before pipeline creation)
if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]):
raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range")
-
+
# Update the preprocessor in AppState
controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index]
- old_preprocessor = controlnet.get('preprocessor', 'unknown')
- controlnet['preprocessor'] = preprocessor_name
- controlnet['preprocessor_params'] = {} # Reset parameters when switching
-
+ old_preprocessor = controlnet.get("preprocessor", "unknown")
+ controlnet["preprocessor"] = preprocessor_name
+ controlnet["preprocessor_params"] = {} # Reset parameters when switching
+
# Update runtime config to keep in sync
- if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config:
- if controlnet_index < len(app_instance.app_state.runtime_config['controlnets']):
- app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor'] = preprocessor_name
- app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'] = {}
-
+ if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config:
+ if controlnet_index < len(app_instance.app_state.runtime_config["controlnets"]):
+ app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor"] = (
+ preprocessor_name
+ )
+ app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"] = {}
+
# Update pipeline if active
if app_instance.pipeline:
try:
controlnet_config = []
for cn in app_instance.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength']
+ config_entry["conditioning_scale"] = cn["strength"]
controlnet_config.append(config_entry)
app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config)
except Exception as e:
logging.exception(f"switch_preprocessor: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
-
+
return create_success_response(f"Switched ControlNet {controlnet_index} preprocessor to {preprocessor_name}")
-
+
except Exception as e:
raise handle_api_error(e, "switch_preprocessor")
+
@router.post("/preprocessors/update-params")
async def update_preprocessor_params(request: Request, app_instance=Depends(get_app_instance)):
"""Update preprocessor parameters for a specific ControlNet"""
@@ -532,107 +571,128 @@ async def update_preprocessor_params(request: Request, app_instance=Depends(get_
except Exception as json_error:
logging.error(f"update_preprocessor_params: JSON parsing failed: {json_error}")
raise HTTPException(status_code=400, detail=f"Invalid JSON: {json_error}")
-
+
controlnet_index = data.get("controlnet_index")
params = data.get("params", {})
-
+
if controlnet_index is None:
- logging.error(f"update_preprocessor_params: Missing controlnet_index parameter")
+ logging.error("update_preprocessor_params: Missing controlnet_index parameter")
raise HTTPException(status_code=400, detail="Missing controlnet_index parameter")
-
+
# Validate AppState has ControlNet configuration (pipeline not required)
- if not app_instance.app_state.controlnet_info["enabled"] or not app_instance.app_state.controlnet_info["controlnets"]:
- logging.error(f"update_preprocessor_params: No ControlNet configuration available in AppState")
- raise HTTPException(status_code=400, detail="No ControlNet configuration available. Please upload a config first.")
-
+ if (
+ not app_instance.app_state.controlnet_info["enabled"]
+ or not app_instance.app_state.controlnet_info["controlnets"]
+ ):
+ logging.error("update_preprocessor_params: No ControlNet configuration available in AppState")
+ raise HTTPException(
+ status_code=400, detail="No ControlNet configuration available. Please upload a config first."
+ )
+
# Update AppState - SINGLE SOURCE OF TRUTH (works before pipeline creation)
if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]):
- logging.error(f"update_preprocessor_params: ControlNet index {controlnet_index} out of range (max: {len(app_instance.app_state.controlnet_info['controlnets'])-1})")
+ logging.error(
+ f"update_preprocessor_params: ControlNet index {controlnet_index} out of range (max: {len(app_instance.app_state.controlnet_info['controlnets']) - 1})"
+ )
raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range")
-
+
# Update preprocessor parameters in AppState
controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index]
- if 'preprocessor_params' not in controlnet:
- controlnet['preprocessor_params'] = {}
- controlnet['preprocessor_params'].update(params)
-
+ if "preprocessor_params" not in controlnet:
+ controlnet["preprocessor_params"] = {}
+ controlnet["preprocessor_params"].update(params)
+
# Update runtime config to keep in sync
- if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config:
- if controlnet_index < len(app_instance.app_state.runtime_config['controlnets']):
- if 'preprocessor_params' not in app_instance.app_state.runtime_config['controlnets'][controlnet_index]:
- app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'] = {}
- app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'].update(params)
-
+ if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config:
+ if controlnet_index < len(app_instance.app_state.runtime_config["controlnets"]):
+ if "preprocessor_params" not in app_instance.app_state.runtime_config["controlnets"][controlnet_index]:
+ app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"] = {}
+ app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"].update(
+ params
+ )
+
# Update pipeline if active
if app_instance.pipeline:
try:
controlnet_config = []
for cn in app_instance.app_state.controlnet_info["controlnets"]:
config_entry = dict(cn)
- config_entry['conditioning_scale'] = cn['strength']
+ config_entry["conditioning_scale"] = cn["strength"]
controlnet_config.append(config_entry)
app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config)
except Exception as e:
logging.exception(f"update_preprocessor_params: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
- logging.debug(f"update_preprocessor_params: Updated ControlNet {controlnet_index} preprocessor params: {params}")
-
- return create_success_response(f"Updated ControlNet {controlnet_index} preprocessor parameters", updated_params=params)
-
+
+ logging.debug(
+ f"update_preprocessor_params: Updated ControlNet {controlnet_index} preprocessor params: {params}"
+ )
+
+ return create_success_response(
+ f"Updated ControlNet {controlnet_index} preprocessor parameters", updated_params=params
+ )
+
except Exception as e:
logging.exception(f"update_preprocessor_params: Exception occurred: {str(e)}")
raise handle_api_error(e, "update_preprocessor_params")
+
@router.get("/preprocessors/current-params/{controlnet_index}")
async def get_current_preprocessor_params(controlnet_index: int, app_instance=Depends(get_app_instance)):
"""Get current parameter values for a specific ControlNet preprocessor"""
try:
# First try to get from uploaded config if no pipeline
if not app_instance.pipeline and app_instance.app_state.uploaded_config:
- controlnets = app_instance.app_state.uploaded_config.get('controlnets', [])
+ controlnets = app_instance.app_state.uploaded_config.get("controlnets", [])
if controlnet_index < len(controlnets):
controlnet = controlnets[controlnet_index]
- return JSONResponse({
- "status": "success",
- "controlnet_index": controlnet_index,
- "preprocessor": controlnet.get('preprocessor', 'unknown'),
- "parameters": controlnet.get('preprocessor_params', {}),
- "note": "From uploaded config"
- })
-
+ return JSONResponse(
+ {
+ "status": "success",
+ "controlnet_index": controlnet_index,
+ "preprocessor": controlnet.get("preprocessor", "unknown"),
+ "parameters": controlnet.get("preprocessor_params", {}),
+ "note": "From uploaded config",
+ }
+ )
+
# Return empty/default response if no config available
if not app_instance.pipeline:
- return JSONResponse({
- "status": "success",
- "controlnet_index": controlnet_index,
- "preprocessor": "unknown",
- "parameters": {},
- "note": "Pipeline not initialized - no config available"
- })
-
+ return JSONResponse(
+ {
+ "status": "success",
+ "controlnet_index": controlnet_index,
+ "preprocessor": "unknown",
+ "parameters": {},
+ "note": "Pipeline not initialized - no config available",
+ }
+ )
+
# Use AppState - SINGLE SOURCE OF TRUTH
if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]):
- return JSONResponse({
+ return JSONResponse(
+ {
+ "status": "success",
+ "controlnet_index": controlnet_index,
+ "preprocessor": "unknown",
+ "parameters": {},
+ "note": f"ControlNet index {controlnet_index} out of range",
+ }
+ )
+
+ controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index]
+ preprocessor = controlnet.get("preprocessor", "unknown")
+ preprocessor_params = controlnet.get("preprocessor_params", {})
+
+ return JSONResponse(
+ {
"status": "success",
"controlnet_index": controlnet_index,
- "preprocessor": "unknown",
- "parameters": {},
- "note": f"ControlNet index {controlnet_index} out of range"
- })
-
- controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index]
- preprocessor = controlnet.get('preprocessor', 'unknown')
- preprocessor_params = controlnet.get('preprocessor_params', {})
-
- return JSONResponse({
- "status": "success",
- "controlnet_index": controlnet_index,
- "preprocessor": preprocessor,
- "parameters": preprocessor_params
- })
-
+ "preprocessor": preprocessor,
+ "parameters": preprocessor_params,
+ }
+ )
+
except Exception as e:
raise handle_api_error(e, "get_current_preprocessor_params")
-
diff --git a/demo/realtime-img2img/routes/debug.py b/demo/realtime-img2img/routes/debug.py
index 5015ce378..4fcd9e3b7 100644
--- a/demo/realtime-img2img/routes/debug.py
+++ b/demo/realtime-img2img/routes/debug.py
@@ -1,75 +1,82 @@
"""
Debug mode API endpoints for realtime-img2img
"""
-from fastapi import APIRouter, HTTPException, Depends
-from pydantic import BaseModel
+
import logging
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+
from .common.dependencies import get_app_instance
+
router = APIRouter(prefix="/api/debug", tags=["debug"])
+
class DebugResponse(BaseModel):
success: bool
message: str
debug_mode: bool
debug_pending_frame: bool = False
+
@router.post("/enable", response_model=DebugResponse)
async def enable_debug_mode(app_instance=Depends(get_app_instance)):
"""Enable debug mode - pauses automatic frame processing"""
try:
app_instance.app_state.debug_mode = True
app_instance.app_state.debug_pending_frame = False
-
+
logging.info("enable_debug_mode: Debug mode enabled")
-
+
return DebugResponse(
success=True,
message="Debug mode enabled. Frame processing is now paused.",
debug_mode=True,
- debug_pending_frame=False
+ debug_pending_frame=False,
)
except Exception as e:
logging.exception(f"enable_debug_mode: Failed to enable debug mode: {e}")
raise HTTPException(status_code=500, detail=f"Failed to enable debug mode: {str(e)}")
+
@router.post("/disable", response_model=DebugResponse)
async def disable_debug_mode(app_instance=Depends(get_app_instance)):
"""Disable debug mode - resumes automatic frame processing"""
try:
app_instance.app_state.debug_mode = False
app_instance.app_state.debug_pending_frame = False
-
+
logging.info("disable_debug_mode: Debug mode disabled")
-
+
return DebugResponse(
success=True,
message="Debug mode disabled. Automatic frame processing resumed.",
debug_mode=False,
- debug_pending_frame=False
+ debug_pending_frame=False,
)
except Exception as e:
logging.exception(f"disable_debug_mode: Failed to disable debug mode: {e}")
raise HTTPException(status_code=500, detail=f"Failed to disable debug mode: {str(e)}")
+
@router.post("/step", response_model=DebugResponse)
async def step_frame(app_instance=Depends(get_app_instance)):
"""Process exactly one frame when in debug mode"""
try:
if not app_instance.app_state.debug_mode:
raise HTTPException(status_code=400, detail="Debug mode is not enabled")
-
+
# Set pending frame flag to allow one frame to be processed
app_instance.app_state.debug_pending_frame = True
-
+
logging.info("step_frame: Frame step requested")
-
+
return DebugResponse(
success=True,
message="Frame step requested. Next frame will be processed.",
debug_mode=True,
- debug_pending_frame=True
+ debug_pending_frame=True,
)
except HTTPException:
raise
@@ -77,6 +84,7 @@ async def step_frame(app_instance=Depends(get_app_instance)):
logging.exception(f"step_frame: Failed to step frame: {e}")
raise HTTPException(status_code=500, detail=f"Failed to step frame: {str(e)}")
+
@router.get("/status", response_model=DebugResponse)
async def get_debug_status(app_instance=Depends(get_app_instance)):
"""Get current debug mode status"""
@@ -85,7 +93,7 @@ async def get_debug_status(app_instance=Depends(get_app_instance)):
success=True,
message="Debug status retrieved",
debug_mode=app_instance.app_state.debug_mode,
- debug_pending_frame=app_instance.app_state.debug_pending_frame
+ debug_pending_frame=app_instance.app_state.debug_pending_frame,
)
except Exception as e:
logging.exception(f"get_debug_status: Failed to get debug status: {e}")
diff --git a/demo/realtime-img2img/routes/inference.py b/demo/realtime-img2img/routes/inference.py
index f9bce553e..f1c81dd4a 100644
--- a/demo/realtime-img2img/routes/inference.py
+++ b/demo/realtime-img2img/routes/inference.py
@@ -1,23 +1,27 @@
"""
Inference and system status endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends
-from fastapi.responses import JSONResponse, StreamingResponse
+
import logging
import uuid
+
import markdown2
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from .common.dependencies import get_app_instance, get_default_settings, get_pipeline_class
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error
-from .common.dependencies import get_app_instance, get_pipeline_class, get_default_settings
router = APIRouter(prefix="/api", tags=["inference"])
+
@router.get("/queue")
async def get_queue_size(app_instance=Depends(get_app_instance)):
"""Get current queue size"""
queue_size = app_instance.conn_manager.get_user_count()
return JSONResponse({"queue_size": queue_size})
+
@router.get("/stream/{user_id}")
async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_app_instance)):
"""Main streaming endpoint for inference"""
@@ -29,27 +33,38 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_
app_instance.pipeline = app_instance._create_pipeline()
app_instance.app_state.pipeline_lifecycle = "running"
logging.info("stream: Pipeline created successfully")
-
+
# Recreate pipeline if config changed (but not resolution - that's handled separately)
- elif app_instance.app_state.config_needs_reload or (app_instance.app_state.uploaded_config and not (app_instance.pipeline.use_config and app_instance.pipeline.config and 'controlnets' in app_instance.pipeline.config)) or (app_instance.app_state.uploaded_config and not app_instance.pipeline.use_config):
+ elif (
+ app_instance.app_state.config_needs_reload
+ or (
+ app_instance.app_state.uploaded_config
+ and not (
+ app_instance.pipeline.use_config
+ and app_instance.pipeline.config
+ and "controlnets" in app_instance.pipeline.config
+ )
+ )
+ or (app_instance.app_state.uploaded_config and not app_instance.pipeline.use_config)
+ ):
if app_instance.app_state.config_needs_reload:
logging.info("stream: Recreating pipeline with new ControlNet config...")
else:
logging.info("stream: Upgrading to ControlNet pipeline...")
-
+
app_instance.app_state.pipeline_lifecycle = "restarting"
-
+
# Properly cleanup the old pipeline before creating new one
old_pipeline = app_instance.pipeline
app_instance.pipeline = None
-
+
if old_pipeline:
app_instance._cleanup_pipeline(old_pipeline)
old_pipeline = None
-
+
# Create new pipeline
app_instance.pipeline = app_instance._create_pipeline()
-
+
app_instance.app_state.config_needs_reload = False
app_instance.app_state.pipeline_lifecycle = "running"
logging.info("stream: Pipeline recreated successfully")
@@ -59,25 +74,31 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_
# Check for acceleration changes (requires pipeline recreation)
acceleration_changed = False
- if hasattr(app_instance, 'new_acceleration') and app_instance.new_acceleration != app_instance.args.acceleration:
- logging.info(f"stream: Acceleration change detected: {app_instance.args.acceleration} -> {app_instance.new_acceleration}")
-
+ if (
+ hasattr(app_instance, "new_acceleration")
+ and app_instance.new_acceleration != app_instance.args.acceleration
+ ):
+ logging.info(
+ f"stream: Acceleration change detected: {app_instance.args.acceleration} -> {app_instance.new_acceleration}"
+ )
+
# Create new Args object with updated acceleration (NamedTuple is immutable)
from config import Args
+
args_dict = app_instance.args._asdict()
- args_dict['acceleration'] = app_instance.new_acceleration
+ args_dict["acceleration"] = app_instance.new_acceleration
app_instance.args = Args(**args_dict)
- delattr(app_instance, 'new_acceleration')
-
+ delattr(app_instance, "new_acceleration")
+
# Recreate pipeline with new acceleration
old_pipeline = app_instance.pipeline
app_instance.pipeline = None
if old_pipeline:
app_instance._cleanup_pipeline(old_pipeline)
-
+
app_instance.pipeline = app_instance._create_pipeline()
acceleration_changed = True
- logging.info(f"stream: Pipeline recreated with new acceleration")
+ logging.info("stream: Pipeline recreated with new acceleration")
# IPAdapter style images are now handled dynamically in pipeline.predict()
# No static application needed here
@@ -91,6 +112,7 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_
# Generate and stream frames using pipeline.predict() in a loop (like original)
try:
+
async def generate_frames():
try:
while True:
@@ -105,219 +127,241 @@ async def generate_frames():
else:
# Wait in debug mode without requesting frames
import asyncio
+
await asyncio.sleep(0.1) # Small delay to prevent busy waiting
continue
else:
# Normal mode - request new frame automatically
await app_instance.conn_manager.send_json(user_id, {"status": "send_frame"})
-
+
# Get the latest parameters from the WebSocket connection manager
# This consumes data from the queue after requesting a new frame
# Get latest data from the queue (blocks until new data arrives)
params = await app_instance.conn_manager.get_latest_data(user_id)
if params is None:
continue
-
+
# Attach InputSourceManager to params for modular input routing
- if hasattr(app_instance, 'input_source_manager'):
+ if hasattr(app_instance, "input_source_manager"):
params.input_manager = app_instance.input_source_manager
-
+
# Generate frame using pipeline.predict()
image = app_instance.pipeline.predict(params)
if image is None:
logging.error("stream: predict returned None image; skipping frame")
continue
-
+
# Update FPS counter
import time
+
current_time = time.time()
- if hasattr(app_instance, 'last_frame_time'):
+ if hasattr(app_instance, "last_frame_time"):
frame_time = current_time - app_instance.last_frame_time
app_instance.fps_counter.append(frame_time)
if len(app_instance.fps_counter) > 30: # Keep last 30 frames
app_instance.fps_counter.pop(0)
app_instance.last_frame_time = current_time
-
+
# Convert image to frame format for streaming
# Use appropriate frame conversion based on output type
if app_instance.pipeline.output_type == "pt":
from util import pt_to_frame
+
frame = pt_to_frame(image)
else:
from util import pil_to_frame
+
frame = pil_to_frame(image)
yield frame
-
+
except Exception as e:
logging.exception(f"stream: Error in frame generation: {e}")
return StreamingResponse(
generate_frames(),
media_type="multipart/x-mixed-replace; boundary=frame",
- headers={"Cache-Control": "no-cache, no-store, must-revalidate"}
+ headers={"Cache-Control": "no-cache, no-store, must-revalidate"},
)
-
+
except Exception as e:
raise e
-
+
except Exception as e:
logging.exception(f"stream: Error in streaming endpoint: {e}")
raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}")
+
@router.get("/state")
-async def get_app_state(app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class), default_settings=Depends(get_default_settings)):
+async def get_app_state(
+ app_instance=Depends(get_app_instance),
+ pipeline_class=Depends(get_pipeline_class),
+ default_settings=Depends(get_default_settings),
+):
"""Get complete application state - replaces /api/settings with centralized state management"""
try:
# Update app_state with current dynamic values - SINGLE SOURCE OF TRUTH
app_instance.app_state.pipeline_active = app_instance.pipeline is not None
-
+
# Update FPS from fps_counter
if len(app_instance.fps_counter) > 0:
avg_frame_time = sum(app_instance.fps_counter) / len(app_instance.fps_counter)
app_instance.app_state.fps = round(1.0 / avg_frame_time if avg_frame_time > 0 else 0, 1)
else:
app_instance.app_state.fps = 0
-
+
# Update queue size
app_instance.app_state.queue_size = app_instance.conn_manager.get_user_count()
-
+
# Update pipeline parameters schema
app_instance.app_state.pipeline_params = pipeline_class.InputParams.schema()
-
+
# Update page content
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'info'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "info"):
info = app_instance.pipeline.info
else:
info = pipeline_class.Info()
-
+
if info.page_content:
app_instance.app_state.page_content = markdown2.markdown(info.page_content)
-
+
# Get complete state
state_data = app_instance.app_state.get_complete_state()
-
+
# Add additional fields expected by frontend for backward compatibility
- state_data.update({
- "info": pipeline_class.Info.schema(),
- "input_params": app_instance.app_state.pipeline_params,
- "max_queue_size": app_instance.args.max_queue_size,
- "acceleration": app_instance.args.acceleration,
- # Add config prompt for backward compatibility
- "config_prompt": app_instance.app_state.uploaded_config.get('prompt') if app_instance.app_state.uploaded_config else None,
- # Add resolution in expected format
- "resolution": f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}",
- })
-
+ state_data.update(
+ {
+ "info": pipeline_class.Info.schema(),
+ "input_params": app_instance.app_state.pipeline_params,
+ "max_queue_size": app_instance.args.max_queue_size,
+ "acceleration": app_instance.args.acceleration,
+ # Add config prompt for backward compatibility
+ "config_prompt": app_instance.app_state.uploaded_config.get("prompt")
+ if app_instance.app_state.uploaded_config
+ else None,
+ # Add resolution in expected format
+ "resolution": f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}",
+ }
+ )
+
return JSONResponse(state_data)
-
+
except Exception as e:
logging.error(f"get_app_state: Error getting application state: {e}")
raise HTTPException(status_code=500, detail=f"Failed to get application state: {str(e)}")
+
@router.get("/settings")
-async def settings(app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class), default_settings=Depends(get_default_settings)):
+async def settings(
+ app_instance=Depends(get_app_instance),
+ pipeline_class=Depends(get_pipeline_class),
+ default_settings=Depends(get_default_settings),
+):
"""Get pipeline settings and configuration info"""
# Use Pipeline class directly for schema info (doesn't require instance)
info_schema = pipeline_class.Info.schema()
-
+
# Get info from pipeline instance if available to get correct input_mode
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'info'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "info"):
info = app_instance.pipeline.info
else:
info = pipeline_class.Info()
-
+
page_content = ""
if info.page_content:
page_content = markdown2.markdown(info.page_content)
input_params = pipeline_class.InputParams.schema()
-
+
# Add ControlNet information - SINGLE SOURCE OF TRUTH
controlnet_info = app_instance.app_state.controlnet_info
-
+
# Add IPAdapter information - SINGLE SOURCE OF TRUTH
ipadapter_info = app_instance.app_state.ipadapter_info
-
+
# Include config prompt if available, otherwise use default
config_prompt = None
- if app_instance.app_state.uploaded_config and 'prompt' in app_instance.app_state.uploaded_config:
- config_prompt = app_instance.app_state.uploaded_config['prompt']
+ if app_instance.app_state.uploaded_config and "prompt" in app_instance.app_state.uploaded_config:
+ config_prompt = app_instance.app_state.uploaded_config["prompt"]
elif not config_prompt:
- config_prompt = default_settings.get('prompt')
-
+ config_prompt = default_settings.get("prompt")
+
# Get current t_index_list from pipeline or config
current_t_index_list = None
- if app_instance.pipeline and hasattr(app_instance.pipeline.stream, 't_list'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline.stream, "t_list"):
current_t_index_list = app_instance.pipeline.stream.t_list
- elif app_instance.app_state.uploaded_config and 't_index_list' in app_instance.app_state.uploaded_config:
- current_t_index_list = app_instance.app_state.uploaded_config['t_index_list']
+ elif app_instance.app_state.uploaded_config and "t_index_list" in app_instance.app_state.uploaded_config:
+ current_t_index_list = app_instance.app_state.uploaded_config["t_index_list"]
else:
# Default values
- current_t_index_list = default_settings.get('t_index_list', [35, 45])
-
+ current_t_index_list = default_settings.get("t_index_list", [35, 45])
+
# Get current acceleration setting
current_acceleration = app_instance.args.acceleration
-
+
# Get current resolution
- current_resolution = f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}"
+ current_resolution = (
+ f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}"
+ )
# Add aspect ratio for display
- aspect_ratio = app_instance._calculate_aspect_ratio(app_instance.app_state.current_resolution['width'], app_instance.app_state.current_resolution['height'])
+ aspect_ratio = app_instance._calculate_aspect_ratio(
+ app_instance.app_state.current_resolution["width"], app_instance.app_state.current_resolution["height"]
+ )
if aspect_ratio:
current_resolution += f" ({aspect_ratio})"
- if app_instance.app_state.uploaded_config and 'acceleration' in app_instance.app_state.uploaded_config:
- current_acceleration = app_instance.app_state.uploaded_config['acceleration']
-
+ if app_instance.app_state.uploaded_config and "acceleration" in app_instance.app_state.uploaded_config:
+ current_acceleration = app_instance.app_state.uploaded_config["acceleration"]
+
# Get current streaming parameters (default values or from pipeline if available)
- current_guidance_scale = default_settings.get('guidance_scale', 1.1)
- current_delta = default_settings.get('delta', 0.7)
- current_num_inference_steps = default_settings.get('num_inference_steps', 50)
- current_seed = default_settings.get('seed', 2)
-
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ current_guidance_scale = default_settings.get("guidance_scale", 1.1)
+ current_delta = default_settings.get("delta", 0.7)
+ current_num_inference_steps = default_settings.get("num_inference_steps", 50)
+ current_seed = default_settings.get("seed", 2)
+
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
try:
state = app_instance.pipeline.stream.get_stream_state()
- current_guidance_scale = state.get('guidance_scale', current_guidance_scale)
- current_delta = state.get('delta', current_delta)
- current_num_inference_steps = state.get('num_inference_steps', current_num_inference_steps)
- current_seed = state.get('seed', current_seed)
+ current_guidance_scale = state.get("guidance_scale", current_guidance_scale)
+ current_delta = state.get("delta", current_delta)
+ current_num_inference_steps = state.get("num_inference_steps", current_num_inference_steps)
+ current_seed = state.get("seed", current_seed)
except Exception as e:
logging.warning(f"settings: Failed to get current stream parameters: {e}")
-
+
# Get negative prompt if available
- current_negative_prompt = default_settings.get('negative_prompt', '')
- if app_instance.app_state.uploaded_config and 'negative_prompt' in app_instance.app_state.uploaded_config:
- current_negative_prompt = app_instance.app_state.uploaded_config['negative_prompt']
- elif app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ current_negative_prompt = default_settings.get("negative_prompt", "")
+ if app_instance.app_state.uploaded_config and "negative_prompt" in app_instance.app_state.uploaded_config:
+ current_negative_prompt = app_instance.app_state.uploaded_config["negative_prompt"]
+ elif app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
try:
state = app_instance.pipeline.stream.get_stream_state()
- current_negative_prompt = state.get('negative_prompt', current_negative_prompt)
+ current_negative_prompt = state.get("negative_prompt", current_negative_prompt)
except Exception:
pass
-
+
# Get prompt and seed blending configuration - SINGLE SOURCE OF TRUTH
prompt_blending_config = app_instance.app_state.prompt_blending
seed_blending_config = app_instance.app_state.seed_blending
-
+
# Get normalization settings - SINGLE SOURCE OF TRUTH
normalize_prompt_weights = app_instance.app_state.normalize_prompt_weights
normalize_seed_weights = app_instance.app_state.normalize_seed_weights
-
+
# Get current skip_diffusion setting - SINGLE SOURCE OF TRUTH
current_skip_diffusion = app_instance.app_state.skip_diffusion
-
+
# Determine current model id for UI badge - SINGLE SOURCE OF TRUTH
model_id_for_ui = app_instance.app_state.model_id
-
+
# Check if pipeline is active
pipeline_active = app_instance.pipeline is not None
-
+
# Build config_values for other parameters that frontend may expect
config_values = {}
if app_instance.app_state.uploaded_config:
for key in [
- 'use_taesd',
- 'cfg_type',
- 'safety_checker',
+ "use_taesd",
+ "cfg_type",
+ "safety_checker",
]:
if key in app_instance.app_state.uploaded_config:
config_values[key] = app_instance.app_state.uploaded_config[key]
@@ -347,9 +391,10 @@ async def settings(app_instance=Depends(get_app_instance), pipeline_class=Depend
"model_id": model_id_for_ui,
"config_values": config_values,
}
-
+
return JSONResponse(response_data)
+
@router.get("/fps")
async def get_fps(app_instance=Depends(get_app_instance)):
"""Get current FPS"""
@@ -358,8 +403,5 @@ async def get_fps(app_instance=Depends(get_app_instance)):
fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
else:
fps = 0
-
- return JSONResponse({"fps": round(fps, 1)})
-
-
+ return JSONResponse({"fps": round(fps, 1)})
diff --git a/demo/realtime-img2img/routes/input_sources.py b/demo/realtime-img2img/routes/input_sources.py
index 0901c56ca..1029f2398 100644
--- a/demo/realtime-img2img/routes/input_sources.py
+++ b/demo/realtime-img2img/routes/input_sources.py
@@ -1,19 +1,21 @@
"""
Input Source Management API endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends, UploadFile, File
-from fastapi.responses import JSONResponse
+
+import io
import logging
-from pathlib import Path
-from typing import Optional, Any, Dict
import uuid
+from pathlib import Path
+from typing import Optional
+
+from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
+from input_sources import InputSource, InputSourceManager, InputSourceType
from PIL import Image
-import io
+from utils.video_utils import is_supported_video_format, validate_video_file
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error
+from .common.api_utils import create_success_response, handle_api_error, handle_api_request
from .common.dependencies import get_app_instance
-from input_sources import InputSource, InputSourceType, InputSourceManager
-from utils.video_utils import validate_video_file, is_supported_video_format
+
router = APIRouter(prefix="/api/input-sources", tags=["input-sources"])
@@ -22,7 +24,7 @@
def _get_input_source_manager(app_instance) -> InputSourceManager:
"""Get or create the input source manager for the app instance."""
- if not hasattr(app_instance, 'input_source_manager'):
+ if not hasattr(app_instance, "input_source_manager"):
app_instance.input_source_manager = InputSourceManager()
return app_instance.input_source_manager
@@ -31,7 +33,7 @@ def _get_input_source_manager(app_instance) -> InputSourceManager:
async def set_input_source(request: Request, app_instance=Depends(get_app_instance)):
"""
Set input source for a component.
-
+
Body: {
component: str, # 'controlnet', 'ipadapter', 'base'
index?: int, # Required for controlnet
@@ -40,61 +42,60 @@ async def set_input_source(request: Request, app_instance=Depends(get_app_instan
}
"""
try:
- data = await handle_api_request(request, "set_input_source",
- required_params=['component', 'source_type'],
- pipeline_required=False)
-
- component = data['component']
- source_type_str = data['source_type']
- index = data.get('index')
- source_data = data.get('source_data')
-
+ data = await handle_api_request(
+ request, "set_input_source", required_params=["component", "source_type"], pipeline_required=False
+ )
+
+ component = data["component"]
+ source_type_str = data["source_type"]
+ index = data.get("index")
+ source_data = data.get("source_data")
+
# Validate component
- if component not in ['controlnet', 'ipadapter', 'base']:
+ if component not in ["controlnet", "ipadapter", "base"]:
raise HTTPException(status_code=400, detail=f"Invalid component: {component}")
-
+
# Validate source type
try:
source_type = InputSourceType(source_type_str)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source type: {source_type_str}")
-
+
# Validate index for controlnet
- if component == 'controlnet' and index is None:
+ if component == "controlnet" and index is None:
raise HTTPException(status_code=400, detail="Index is required for ControlNet components")
-
+
# Get input source manager
manager = _get_input_source_manager(app_instance)
-
+
# Create input source
input_source = InputSource(source_type, source_data)
-
+
# Set the source
manager.set_source(component, input_source, index)
-
+
logger.info(f"set_input_source: Set {component} input source to {source_type_str}")
-
- return create_success_response({
- 'message': f'Input source set for {component}',
- 'component': component,
- 'source_type': source_type_str,
- 'index': index
- })
-
+
+ return create_success_response(
+ {
+ "message": f"Input source set for {component}",
+ "component": component,
+ "source_type": source_type_str,
+ "index": index,
+ }
+ )
+
except Exception as e:
return handle_api_error(e, "set_input_source")
@router.post("/upload-image/{component}")
async def upload_component_image(
- component: str,
- file: UploadFile = File(...),
- index: Optional[int] = None,
- app_instance=Depends(get_app_instance)
+ component: str, file: UploadFile = File(...), index: Optional[int] = None, app_instance=Depends(get_app_instance)
):
"""
Upload image for specific component.
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
file: Image file to upload
@@ -102,69 +103,72 @@ async def upload_component_image(
"""
try:
# Validate component
- if component not in ['controlnet', 'ipadapter', 'base']:
+ if component not in ["controlnet", "ipadapter", "base"]:
raise HTTPException(status_code=400, detail=f"Invalid component: {component}")
-
+
# Validate index for controlnet
- if component == 'controlnet' and index is None:
+ if component == "controlnet" and index is None:
raise HTTPException(status_code=400, detail="Index is required for ControlNet components")
-
+
# Validate file type
- if not file.content_type.startswith('image/'):
+ if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
-
+
# Generate unique filename
file_id = str(uuid.uuid4())
- file_extension = Path(file.filename).suffix if file.filename else '.jpg'
- filename = f"{component}_{index}_{file_id}{file_extension}" if index is not None else f"{component}_{file_id}{file_extension}"
-
+ file_extension = Path(file.filename).suffix if file.filename else ".jpg"
+ filename = (
+ f"{component}_{index}_{file_id}{file_extension}"
+ if index is not None
+ else f"{component}_{file_id}{file_extension}"
+ )
+
# Save file
uploads_dir = Path("uploads/images")
uploads_dir.mkdir(parents=True, exist_ok=True)
file_path = uploads_dir / filename
-
+
content = await file.read()
- with open(file_path, 'wb') as f:
+ with open(file_path, "wb") as f:
f.write(content)
-
+
# Create PIL Image for input source
try:
image = Image.open(io.BytesIO(content))
- image = image.convert('RGB') # Ensure RGB format
+ image = image.convert("RGB") # Ensure RGB format
except Exception as e:
# Clean up file if image processing fails
file_path.unlink(missing_ok=True)
raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}")
-
+
# Get input source manager and set source
manager = _get_input_source_manager(app_instance)
input_source = InputSource(InputSourceType.UPLOADED_IMAGE, image)
manager.set_source(component, input_source, index)
-
+
logger.info(f"upload_component_image: Uploaded image for {component} (index: {index})")
-
- return create_success_response({
- 'message': f'Image uploaded for {component}',
- 'component': component,
- 'index': index,
- 'filename': filename,
- 'file_path': str(file_path)
- })
-
+
+ return create_success_response(
+ {
+ "message": f"Image uploaded for {component}",
+ "component": component,
+ "index": index,
+ "filename": filename,
+ "file_path": str(file_path),
+ }
+ )
+
except Exception as e:
return handle_api_error(e, "upload_component_image")
@router.post("/upload-video/{component}")
async def upload_component_video(
- component: str,
- file: UploadFile = File(...),
- index: Optional[int] = None,
- app_instance=Depends(get_app_instance)
+ component: str, file: UploadFile = File(...), index: Optional[int] = None, app_instance=Depends(get_app_instance)
):
"""
Upload video for specific component.
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
file: Video file to upload
@@ -172,91 +176,95 @@ async def upload_component_video(
"""
try:
# Validate component
- if component not in ['controlnet', 'ipadapter', 'base']:
+ if component not in ["controlnet", "ipadapter", "base"]:
raise HTTPException(status_code=400, detail=f"Invalid component: {component}")
-
+
# Validate index for controlnet
- if component == 'controlnet' and index is None:
+ if component == "controlnet" and index is None:
raise HTTPException(status_code=400, detail="Index is required for ControlNet components")
-
+
# Validate file type
if not file.filename or not is_supported_video_format(file.filename):
raise HTTPException(status_code=400, detail="File must be a supported video format")
-
+
# Generate unique filename
file_id = str(uuid.uuid4())
file_extension = Path(file.filename).suffix
- filename = f"{component}_{index}_{file_id}{file_extension}" if index is not None else f"{component}_{file_id}{file_extension}"
-
+ filename = (
+ f"{component}_{index}_{file_id}{file_extension}"
+ if index is not None
+ else f"{component}_{file_id}{file_extension}"
+ )
+
# Save file
uploads_dir = Path("uploads/videos")
uploads_dir.mkdir(parents=True, exist_ok=True)
file_path = uploads_dir / filename
-
+
content = await file.read()
- with open(file_path, 'wb') as f:
+ with open(file_path, "wb") as f:
f.write(content)
-
+
# Validate video file
is_valid, error_msg = validate_video_file(str(file_path))
if not is_valid:
# Clean up file if validation fails
file_path.unlink(missing_ok=True)
raise HTTPException(status_code=400, detail=f"Invalid video file: {error_msg}")
-
+
# Get input source manager and set source
manager = _get_input_source_manager(app_instance)
input_source = InputSource(InputSourceType.UPLOADED_VIDEO, str(file_path))
manager.set_source(component, input_source, index)
-
+
logger.info(f"upload_component_video: Uploaded video for {component} (index: {index})")
-
- return create_success_response({
- 'message': f'Video uploaded for {component}',
- 'component': component,
- 'index': index,
- 'filename': filename,
- 'file_path': str(file_path)
- })
-
+
+ return create_success_response(
+ {
+ "message": f"Video uploaded for {component}",
+ "component": component,
+ "index": index,
+ "filename": filename,
+ "file_path": str(file_path),
+ }
+ )
+
except Exception as e:
return handle_api_error(e, "upload_component_video")
@router.get("/info/{component}")
async def get_component_source_info(
- component: str,
- index: Optional[int] = None,
- app_instance=Depends(get_app_instance)
+ component: str, index: Optional[int] = None, app_instance=Depends(get_app_instance)
):
"""
Get information about a component's input source.
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
index: Index for ControlNet (required for controlnet component)
"""
try:
# Validate component
- if component not in ['controlnet', 'ipadapter', 'base']:
+ if component not in ["controlnet", "ipadapter", "base"]:
raise HTTPException(status_code=400, detail=f"Invalid component: {component}")
-
+
# Validate index for controlnet
- if component == 'controlnet' and index is None:
+ if component == "controlnet" and index is None:
raise HTTPException(status_code=400, detail="Index is required for ControlNet components")
-
+
# Get input source manager
manager = _get_input_source_manager(app_instance)
-
+
# Get source info
source_info = manager.get_source_info(component, index)
-
+
# Make sure source_data is JSON serializable (remove PIL Image objects)
- if source_info and 'source_data' in source_info:
- source_data = source_info['source_data']
+ if source_info and "source_data" in source_info:
+ source_data = source_info["source_data"]
# If source_data is a PIL Image, just indicate it's present rather than trying to serialize it
- if hasattr(source_data, '__class__') and source_data.__class__.__name__ == 'Image':
- source_info['source_data'] = 'image_present'
+ if hasattr(source_data, "__class__") and source_data.__class__.__name__ == "Image":
+ source_info["source_data"] = "image_present"
elif isinstance(source_data, str):
# Keep strings (file paths) as-is
pass
@@ -264,16 +272,13 @@ async def get_component_source_info(
# For other non-serializable objects, convert to string representation
try:
import json
+
json.dumps(source_data) # Test if it's serializable
except (TypeError, ValueError):
- source_info['source_data'] = str(type(source_data).__name__)
-
- return create_success_response({
- 'component': component,
- 'index': index,
- 'source_info': source_info
- })
-
+ source_info["source_data"] = str(type(source_data).__name__)
+
+ return create_success_response({"component": component, "index": index, "source_info": source_info})
+
except Exception as e:
return handle_api_error(e, "get_component_source_info")
@@ -284,64 +289,60 @@ async def list_all_source_info(app_instance=Depends(get_app_instance)):
try:
# Get input source manager
manager = _get_input_source_manager(app_instance)
-
+
# Collect all source information
all_sources = {
- 'base': manager.get_source_info('base'),
- 'ipadapter': manager.get_source_info('ipadapter'),
- 'controlnets': {}
+ "base": manager.get_source_info("base"),
+ "ipadapter": manager.get_source_info("ipadapter"),
+ "controlnets": {},
}
-
+
# Get all controlnet sources
- for index, source in manager.sources['controlnet'].items():
- all_sources['controlnets'][index] = manager.get_source_info('controlnet', index)
-
- return create_success_response({
- 'sources': all_sources
- })
-
+ for index, source in manager.sources["controlnet"].items():
+ all_sources["controlnets"][index] = manager.get_source_info("controlnet", index)
+
+ return create_success_response({"sources": all_sources})
+
except Exception as e:
return handle_api_error(e, "list_all_source_info")
@router.post("/reset/{component}")
-async def reset_component_source(
- component: str,
- index: Optional[int] = None,
- app_instance=Depends(get_app_instance)
-):
+async def reset_component_source(component: str, index: Optional[int] = None, app_instance=Depends(get_app_instance)):
"""
Reset a component's input source to webcam (default).
-
+
Args:
component: Component name ('controlnet', 'ipadapter', 'base')
index: Index for ControlNet (required for controlnet component)
"""
try:
# Validate component
- if component not in ['controlnet', 'ipadapter', 'base']:
+ if component not in ["controlnet", "ipadapter", "base"]:
raise HTTPException(status_code=400, detail=f"Invalid component: {component}")
-
+
# Validate index for controlnet
- if component == 'controlnet' and index is None:
+ if component == "controlnet" and index is None:
raise HTTPException(status_code=400, detail="Index is required for ControlNet components")
-
+
# Get input source manager
manager = _get_input_source_manager(app_instance)
-
+
# Create webcam input source
webcam_source = InputSource(InputSourceType.WEBCAM)
manager.set_source(component, webcam_source, index)
-
+
logger.info(f"reset_component_source: Reset {component} to webcam (index: {index})")
-
- return create_success_response({
- 'message': f'Input source reset to webcam for {component}',
- 'component': component,
- 'index': index,
- 'source_type': 'webcam'
- })
-
+
+ return create_success_response(
+ {
+ "message": f"Input source reset to webcam for {component}",
+ "component": component,
+ "index": index,
+ "source_type": "webcam",
+ }
+ )
+
except Exception as e:
return handle_api_error(e, "reset_component_source")
@@ -355,20 +356,22 @@ async def reset_all_input_sources(app_instance=Depends(get_app_instance)):
try:
# Get input source manager
manager = _get_input_source_manager(app_instance)
-
+
# Reset all sources to defaults
manager.reset_to_defaults()
-
+
logger.info("reset_all_input_sources: Reset all input sources to defaults")
-
- return create_success_response({
- 'message': 'All input sources reset to defaults',
- 'defaults': {
- 'base': 'webcam',
- 'ipadapter': 'uploaded_image (default image)',
- 'controlnet': 'fallback to base pipeline'
+
+ return create_success_response(
+ {
+ "message": "All input sources reset to defaults",
+ "defaults": {
+ "base": "webcam",
+ "ipadapter": "uploaded_image (default image)",
+ "controlnet": "fallback to base pipeline",
+ },
}
- })
-
+ )
+
except Exception as e:
return handle_api_error(e, "reset_all_input_sources")
diff --git a/demo/realtime-img2img/routes/ipadapter.py b/demo/realtime-img2img/routes/ipadapter.py
index 59a58b42a..edf7a4519 100644
--- a/demo/realtime-img2img/routes/ipadapter.py
+++ b/demo/realtime-img2img/routes/ipadapter.py
@@ -1,107 +1,122 @@
"""
IPAdapter-related endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends
-from fastapi.responses import JSONResponse, Response
+
import logging
import os
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error, validate_pipeline, validate_feature_enabled, validate_config_mode
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import Response
+
+from .common.api_utils import (
+ create_success_response,
+ handle_api_error,
+ handle_api_request,
+ validate_config_mode,
+)
from .common.dependencies import get_app_instance
+
router = APIRouter(prefix="/api", tags=["ipadapter"])
# Legacy upload endpoint removed - use /api/input-sources/upload-image/ipadapter instead
# Legacy get uploaded image endpoint removed - use InputSourceManager instead
+
@router.get("/default-image")
async def get_default_image():
"""Get the default image (input.png)"""
try:
default_image_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "images", "inputs", "input.png")
-
+
if not os.path.exists(default_image_path):
raise HTTPException(status_code=404, detail="Default image not found")
-
+
# Read and return the default image file
with open(default_image_path, "rb") as image_file:
image_content = image_file.read()
-
- return Response(content=image_content, media_type="image/png", headers={"Cache-Control": "public, max-age=3600"})
-
+
+ return Response(
+ content=image_content, media_type="image/png", headers={"Cache-Control": "public, max-age=3600"}
+ )
+
except Exception as e:
raise handle_api_error(e, "get_default_image")
+
@router.post("/ipadapter/update-scale")
async def update_ipadapter_scale(request: Request, app_instance=Depends(get_app_instance)):
"""Update IPAdapter scale/strength in real-time"""
try:
data = await handle_api_request(request, "update_ipadapter_scale", ["scale"])
scale = data.get("scale")
-
+
# Validate AppState has IPAdapter configuration (pipeline not required)
if not app_instance.app_state.ipadapter_info["enabled"]:
- raise HTTPException(status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first.")
-
+ raise HTTPException(
+ status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first."
+ )
+
# Update AppState as single source of truth (works before pipeline creation)
app_instance.app_state.update_parameter("ipadapter_scale", float(scale))
-
+
# Sync to pipeline if active
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated IPAdapter scale to {scale}")
-
+
except Exception as e:
raise handle_api_error(e, "update_ipadapter_scale")
+
@router.post("/ipadapter/update-weight-type")
async def update_ipadapter_weight_type(request: Request, app_instance=Depends(get_app_instance)):
"""Update IPAdapter weight type in real-time"""
try:
data = await handle_api_request(request, "update_ipadapter_weight_type", ["weight_type"])
weight_type = data.get("weight_type")
-
+
# Validate AppState has IPAdapter configuration (pipeline not required)
if not app_instance.app_state.ipadapter_info["enabled"]:
- raise HTTPException(status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first.")
-
+ raise HTTPException(
+ status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first."
+ )
+
# Update AppState as single source of truth (works before pipeline creation)
app_instance.app_state.ipadapter_info["weight_type"] = weight_type
-
+
# Sync to pipeline if active
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated IPAdapter weight type to {weight_type}")
-
+
except Exception as e:
raise handle_api_error(e, "update_ipadapter_weight_type")
+
@router.post("/ipadapter/update-enabled")
async def update_ipadapter_enabled(request: Request, app_instance=Depends(get_app_instance)):
"""Enable or disable IPAdapter in real-time"""
try:
data = await handle_api_request(request, "update_ipadapter_enabled", ["enabled"])
enabled = data.get("enabled")
-
+
# Update AppState as single source of truth (works before pipeline creation)
app_instance.app_state.ipadapter_info["enabled"] = bool(enabled)
logging.info(f"update_ipadapter_enabled: Updated AppState ipadapter enabled to {enabled}")
-
+
# Sync to pipeline if active
if app_instance.pipeline:
validate_config_mode(app_instance.pipeline, "ipadapters")
-
+
# Update IPAdapter enabled state in the pipeline
- app_instance.pipeline.stream.update_stream_params(
- ipadapter_config={'enabled': bool(enabled)}
- )
- logging.info(f"update_ipadapter_enabled: Synced to active pipeline")
-
+ app_instance.pipeline.stream.update_stream_params(ipadapter_config={"enabled": bool(enabled)})
+ logging.info("update_ipadapter_enabled: Synced to active pipeline")
+
return create_success_response(f"IPAdapter {'enabled' if enabled else 'disabled'} successfully")
-
+
except Exception as e:
raise handle_api_error(e, "update_ipadapter_enabled")
-
diff --git a/demo/realtime-img2img/routes/parameters.py b/demo/realtime-img2img/routes/parameters.py
index 77f875592..27f4b0d06 100644
--- a/demo/realtime-img2img/routes/parameters.py
+++ b/demo/realtime-img2img/routes/parameters.py
@@ -1,15 +1,19 @@
"""
Parameter update endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends
-from fastapi.responses import JSONResponse
+
import logging
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from .common.api_utils import create_success_response, handle_api_error, handle_api_request
from .common.dependencies import get_app_instance
+
router = APIRouter(prefix="/api", tags=["parameters"])
+
@router.post("/params")
async def update_params(request: Request, app_instance=Depends(get_app_instance)):
"""Update multiple streaming parameters in a single unified call"""
@@ -17,7 +21,7 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance)
data = await request.json()
logging.info(f"update_params: Received data: {data}")
logging.info(f"update_params: Pipeline exists: {app_instance.pipeline is not None}")
-
+
# Allow updating resolution even when pipeline is not initialized.
# We save the new values so they take effect on the next stream start.
if "resolution" in data:
@@ -25,48 +29,44 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance)
logging.info("update_params: No pipeline exists, updating resolution directly")
else:
logging.info("update_params: Pipeline exists, resolution update may be handled differently")
-
+
if "resolution" in data:
resolution = data["resolution"]
if isinstance(resolution, dict) and "width" in resolution and "height" in resolution:
width, height = int(resolution["width"]), int(resolution["height"])
-
+
# Call the proper pipeline recreation method
app_instance._update_resolution(width, height)
-
+
message = f"Resolution updated to {width}x{height} and pipeline recreated successfully"
logging.info(f"update_params: {message}")
- return JSONResponse({
- "status": "success",
- "message": message
- })
+ return JSONResponse({"status": "success", "message": message})
elif isinstance(resolution, str):
# Handle string format like "512x768 (2:3)" or "512x768"
- resolution_part = resolution.split(' ')[0]
+ resolution_part = resolution.split(" ")[0]
logging.info(f"update_params: Parsing resolution string: {resolution} -> {resolution_part}")
try:
- width, height = map(int, resolution_part.split('x'))
+ width, height = map(int, resolution_part.split("x"))
logging.info(f"update_params: Parsed width={width}, height={height}")
-
+
# Call the proper pipeline recreation method
app_instance._update_resolution(width, height)
-
+
message = f"Resolution updated to {width}x{height} and pipeline recreated successfully"
logging.info(f"update_params: {message}")
- return JSONResponse({
- "status": "success",
- "message": message
- })
+ return JSONResponse({"status": "success", "message": message})
except ValueError:
raise HTTPException(status_code=400, detail="Invalid resolution format")
else:
- raise HTTPException(status_code=400, detail="Resolution must be {width: int, height: int} or 'widthxheight' string")
+ raise HTTPException(
+ status_code=400, detail="Resolution must be {width: int, height: int} or 'widthxheight' string"
+ )
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
# Update parameters that exist in the data
params = {}
-
+
if "guidance_scale" in data:
params["guidance_scale"] = float(data["guidance_scale"])
if "delta" in data:
@@ -86,65 +86,60 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance)
# Update AppState as single source of truth (works before pipeline creation)
for param_name, param_value in params.items():
app_instance.app_state.update_parameter(param_name, param_value)
-
+
# Sync to pipeline if active (for real-time updates)
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
- return JSONResponse({
- "status": "success",
- "message": f"Updated parameters: {list(params.keys())}",
- "updated_params": params
- })
+
+ return JSONResponse(
+ {
+ "status": "success",
+ "message": f"Updated parameters: {list(params.keys())}",
+ "updated_params": params,
+ }
+ )
else:
- return JSONResponse({
- "status": "success",
- "message": "No valid parameters provided to update"
- })
-
+ return JSONResponse({"status": "success", "message": "No valid parameters provided to update"})
+
except Exception as e:
logging.exception(f"update_params: Failed to update parameters: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update parameters: {str(e)}")
+
async def _update_single_parameter(
- request: Request,
- app_instance,
- parameter_name: str,
- value_converter: callable,
- operation_name: str
+ request: Request, app_instance, parameter_name: str, value_converter: callable, operation_name: str
):
"""Generic function to update a single parameter"""
try:
data = await handle_api_request(request, operation_name, [parameter_name])
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
value = value_converter(data[parameter_name])
-
+
# Update AppState as single source of truth (works before pipeline creation)
app_instance.app_state.update_parameter(parameter_name, value)
-
+
# Sync to pipeline if active (for real-time updates)
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated {parameter_name} to {value}", **{parameter_name: value})
-
+
except Exception as e:
raise handle_api_error(e, operation_name)
+
@router.post("/update-guidance-scale")
async def update_guidance_scale(request: Request, app_instance=Depends(get_app_instance)):
"""Update guidance scale parameter"""
- return await _update_single_parameter(
- request, app_instance, "guidance_scale", float, "update_guidance_scale"
- )
+ return await _update_single_parameter(request, app_instance, "guidance_scale", float, "update_guidance_scale")
+
@router.post("/update-delta")
async def update_delta(request: Request, app_instance=Depends(get_app_instance)):
"""Update delta parameter"""
- return await _update_single_parameter(
- request, app_instance, "delta", float, "update_delta"
- )
+ return await _update_single_parameter(request, app_instance, "delta", float, "update_delta")
+
@router.post("/update-num-inference-steps")
async def update_num_inference_steps(request: Request, app_instance=Depends(get_app_instance)):
@@ -153,32 +148,32 @@ async def update_num_inference_steps(request: Request, app_instance=Depends(get_
request, app_instance, "num_inference_steps", int, "update_num_inference_steps"
)
+
@router.post("/update-seed")
async def update_seed(request: Request, app_instance=Depends(get_app_instance)):
"""Update seed parameter"""
- return await _update_single_parameter(
- request, app_instance, "seed", int, "update_seed"
- )
+ return await _update_single_parameter(request, app_instance, "seed", int, "update_seed")
+
@router.post("/blending")
async def update_blending(request: Request, app_instance=Depends(get_app_instance)):
"""Update prompt and/or seed blending configuration in real-time"""
try:
data = await request.json()
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
params = {}
updated_types = []
-
+
# Handle prompt blending
if "prompt_list" in data:
prompt_list = data["prompt_list"]
interpolation_method = data.get("prompt_interpolation_method", "slerp")
-
+
if not isinstance(prompt_list, list):
raise HTTPException(status_code=400, detail="prompt_list must be a list")
-
+
# Validate and convert format
prompt_tuples = []
for item in prompt_list:
@@ -187,8 +182,11 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc
elif isinstance(item, dict) and "prompt" in item and "weight" in item:
prompt_tuples.append((str(item["prompt"]), float(item["weight"])))
else:
- raise HTTPException(status_code=400, detail="Each prompt item must be [prompt, weight] or {prompt: str, weight: float}")
-
+ raise HTTPException(
+ status_code=400,
+ detail="Each prompt item must be [prompt, weight] or {prompt: str, weight: float}",
+ )
+
params["prompt_list"] = prompt_tuples
params["prompt_interpolation_method"] = interpolation_method
updated_types.append("prompt")
@@ -197,10 +195,10 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc
if "seed_list" in data:
seed_list = data["seed_list"]
interpolation_method = data.get("seed_interpolation_method", "linear")
-
+
if not isinstance(seed_list, list):
raise HTTPException(status_code=400, detail="seed_list must be a list")
-
+
# Validate and convert format
seed_tuples = []
for item in seed_list:
@@ -209,8 +207,10 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc
elif isinstance(item, dict) and "seed" in item and "weight" in item:
seed_tuples.append((int(item["seed"]), float(item["weight"])))
else:
- raise HTTPException(status_code=400, detail="Each seed item must be [seed, weight] or {seed: int, weight: float}")
-
+ raise HTTPException(
+ status_code=400, detail="Each seed item must be [seed, weight] or {seed: int, weight: float}"
+ )
+
params["seed_list"] = seed_tuples
params["seed_interpolation_method"] = interpolation_method
updated_types.append("seed")
@@ -223,63 +223,64 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc
app_instance.app_state.prompt_blending = params["prompt_list"]
if "seed_list" in params:
app_instance.app_state.seed_blending = params["seed_list"]
-
+
# Sync to pipeline if active
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated {' and '.join(updated_types)} blending", updated_types=updated_types)
-
+
except Exception as e:
raise handle_api_error(e, "update_blending")
+
@router.post("/blending/update-prompt-weight")
async def update_prompt_weight(request: Request, app_instance=Depends(get_app_instance)):
"""Update a specific prompt weight in the current blending configuration"""
try:
data = await request.json()
- index = data.get('index')
- weight = data.get('weight')
-
+ index = data.get("index")
+ weight = data.get("weight")
+
if index is None or weight is None:
raise HTTPException(status_code=400, detail="Missing index or weight parameter")
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
# Update AppState as single source of truth
app_instance.app_state.update_parameter(f"prompt_weight_{index}", float(weight))
-
+
# Sync to pipeline if active
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated prompt weight {index} to {weight}")
-
+
except Exception as e:
raise handle_api_error(e, "update_prompt_weight")
-@router.post("/blending/update-seed-weight")
+
+@router.post("/blending/update-seed-weight")
async def update_seed_weight(request: Request, app_instance=Depends(get_app_instance)):
"""Update a specific seed weight in the current blending configuration"""
try:
data = await request.json()
- index = data.get('index')
- weight = data.get('weight')
-
+ index = data.get("index")
+ weight = data.get("weight")
+
if index is None or weight is None:
raise HTTPException(status_code=400, detail="Missing index or weight parameter")
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
# Update AppState as single source of truth
app_instance.app_state.update_parameter(f"seed_weight_{index}", float(weight))
-
+
# Sync to pipeline if active
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"):
app_instance._sync_appstate_to_pipeline()
-
+
return create_success_response(f"Updated seed weight {index} to {weight}")
-
+
except Exception as e:
raise handle_api_error(e, "update_seed_weight")
-
diff --git a/demo/realtime-img2img/routes/pipeline_hooks.py b/demo/realtime-img2img/routes/pipeline_hooks.py
index 2fd8f226f..0c7184ef6 100644
--- a/demo/realtime-img2img/routes/pipeline_hooks.py
+++ b/demo/realtime-img2img/routes/pipeline_hooks.py
@@ -1,14 +1,16 @@
-
"""
Pipeline hooks endpoints for realtime-img2img
"""
-from fastapi import APIRouter, Request, HTTPException, Depends
-from fastapi.responses import JSONResponse
+
import logging
-from .common.api_utils import handle_api_request, create_success_response, handle_api_error
+from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from .common.api_utils import create_success_response, handle_api_error
from .common.dependencies import get_app_instance
+
router = APIRouter(prefix="/api", tags=["pipeline-hooks"])
@@ -22,12 +24,13 @@ async def get_pipeline_hooks_info_config(app_instance=Depends(get_app_instance))
"image_preprocessing": app_instance.app_state.pipeline_hooks["image_preprocessing"],
"image_postprocessing": app_instance.app_state.pipeline_hooks["image_postprocessing"],
"latent_preprocessing": app_instance.app_state.pipeline_hooks["latent_preprocessing"],
- "latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"]
+ "latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"],
}
return JSONResponse(hooks_info)
except Exception as e:
raise handle_api_error(e, "get_pipeline_hooks_info_config")
+
# Individual hook type endpoints that frontend expects
@router.get("/pipeline-hooks/image_preprocessing/info-config")
async def get_image_preprocessing_info_config(app_instance=Depends(get_app_instance)):
@@ -35,83 +38,95 @@ async def get_image_preprocessing_info_config(app_instance=Depends(get_app_insta
try:
hook_info = app_instance.app_state.pipeline_hooks["image_preprocessing"]
return JSONResponse({"image_preprocessing": hook_info})
- except Exception as e:
+ except Exception:
return JSONResponse({"image_preprocessing": None})
+
@router.get("/pipeline-hooks/image_postprocessing/info-config")
async def get_image_postprocessing_info_config(app_instance=Depends(get_app_instance)):
"""Get image postprocessing hook configuration info - SINGLE SOURCE OF TRUTH"""
try:
hook_info = app_instance.app_state.pipeline_hooks["image_postprocessing"]
return JSONResponse({"image_postprocessing": hook_info})
- except Exception as e:
+ except Exception:
return JSONResponse({"image_postprocessing": None})
+
@router.get("/pipeline-hooks/latent_preprocessing/info-config")
async def get_latent_preprocessing_info_config(app_instance=Depends(get_app_instance)):
"""Get latent preprocessing hook configuration info - SINGLE SOURCE OF TRUTH"""
try:
hook_info = app_instance.app_state.pipeline_hooks["latent_preprocessing"]
return JSONResponse({"latent_preprocessing": hook_info})
- except Exception as e:
+ except Exception:
return JSONResponse({"latent_preprocessing": None})
+
@router.get("/pipeline-hooks/latent_postprocessing/info-config")
async def get_latent_postprocessing_info_config(app_instance=Depends(get_app_instance)):
"""Get latent postprocessing hook configuration info - SINGLE SOURCE OF TRUTH"""
try:
hook_info = app_instance.app_state.pipeline_hooks["latent_postprocessing"]
return JSONResponse({"latent_postprocessing": hook_info})
- except Exception as e:
+ except Exception:
return JSONResponse({"latent_postprocessing": None})
+
@router.get("/pipeline-hooks/{hook_type}/info")
async def get_hook_processors_info(hook_type: str, app_instance=Depends(get_app_instance)):
"""Get available processors for a specific hook type"""
try:
- if hook_type not in ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"]:
+ if hook_type not in [
+ "image_preprocessing",
+ "image_postprocessing",
+ "latent_preprocessing",
+ "latent_postprocessing",
+ ]:
raise HTTPException(status_code=400, detail=f"Invalid hook type: {hook_type}")
-
+
# Use the same processor registry as ControlNet
- from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class
-
+ from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors
+
available_processors = list_preprocessors()
processors_info = {}
-
+
for processor_name in available_processors:
try:
processor_class = get_preprocessor_class(processor_name)
- if hasattr(processor_class, 'get_preprocessor_metadata'):
+ if hasattr(processor_class, "get_preprocessor_metadata"):
metadata = processor_class.get_preprocessor_metadata()
processors_info[processor_name] = {
"name": metadata.get("name", processor_name),
"description": metadata.get("description", ""),
- "parameters": metadata.get("parameters", {})
+ "parameters": metadata.get("parameters", {}),
}
else:
processors_info[processor_name] = {
"name": processor_name,
"description": f"{processor_name} processor",
- "parameters": {}
+ "parameters": {},
}
except Exception as e:
logging.warning(f"get_hook_processors_info: Failed to load metadata for {processor_name}: {e}")
processors_info[processor_name] = {
"name": processor_name,
"description": f"{processor_name} processor",
- "parameters": {}
+ "parameters": {},
}
-
- return JSONResponse({
- "status": "success",
- "hook_type": hook_type,
- "available": list(processors_info.keys()),
- "preprocessors": processors_info
- })
-
+
+ return JSONResponse(
+ {
+ "status": "success",
+ "hook_type": hook_type,
+ "available": list(processors_info.keys()),
+ "preprocessors": processors_info,
+ }
+ )
+
except Exception as e:
raise handle_api_error(e, "get_hook_processors_info")
+
@router.post("/pipeline-hooks/{hook_type}/add")
async def add_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)):
"""Add a new processor to a hook"""
@@ -119,27 +134,28 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe
data = await request.json()
processor_type = data.get("processor_type")
processor_params = data.get("processor_params", {})
-
+
if not processor_type:
raise HTTPException(status_code=400, detail="Missing processor_type parameter")
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
- if hook_type not in ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"]:
+
+ if hook_type not in [
+ "image_preprocessing",
+ "image_postprocessing",
+ "latent_preprocessing",
+ "latent_postprocessing",
+ ]:
raise HTTPException(status_code=400, detail=f"Invalid hook type: {hook_type}")
-
+
logging.debug(f"add_hook_processor: Adding {processor_type} to {hook_type}")
-
+
# Create processor config
- new_processor = {
- "type": processor_type,
- "params": processor_params,
- "enabled": True
- }
-
+ new_processor = {"type": processor_type, "params": processor_params, "enabled": True}
+
# Add to AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.add_hook_processor(hook_type, new_processor)
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -148,7 +164,7 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe
config_entry = {
"type": processor["type"],
"params": processor["params"],
- "enabled": processor["enabled"]
+ "enabled": processor["enabled"],
}
hook_config.append(config_entry)
update_kwargs = {f"{hook_type}_config": hook_config}
@@ -157,25 +173,26 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe
logging.exception(f"add_hook_processor: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
+
logging.info(f"add_hook_processor: Successfully added {processor_type} to {hook_type}")
-
+
return create_success_response(f"Added {processor_type} processor to {hook_type}")
-
+
except Exception as e:
raise handle_api_error(e, "add_hook_processor")
+
@router.delete("/pipeline-hooks/{hook_type}/remove/{processor_index}")
async def remove_hook_processor(hook_type: str, processor_index: int, app_instance=Depends(get_app_instance)):
"""Remove a processor from a hook"""
try:
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
logging.debug(f"remove_hook_processor: Removing processor {processor_index} from {hook_type}")
-
+
# Remove from AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.remove_hook_processor(hook_type, processor_index)
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -184,7 +201,7 @@ async def remove_hook_processor(hook_type: str, processor_index: int, app_instan
config_entry = {
"type": processor["type"],
"params": processor["params"],
- "enabled": processor["enabled"]
+ "enabled": processor["enabled"],
}
hook_config.append(config_entry)
update_kwargs = {f"{hook_type}_config": hook_config}
@@ -193,14 +210,15 @@ async def remove_hook_processor(hook_type: str, processor_index: int, app_instan
logging.exception(f"remove_hook_processor: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
+
logging.info(f"remove_hook_processor: Successfully removed processor {processor_index} from {hook_type}")
-
+
return create_success_response(f"Removed processor {processor_index} from {hook_type}")
-
+
except Exception as e:
raise handle_api_error(e, "remove_hook_processor")
+
@router.post("/pipeline-hooks/{hook_type}/toggle")
async def toggle_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)):
"""Toggle a processor enabled/disabled"""
@@ -208,17 +226,19 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D
data = await request.json()
processor_index = data.get("processor_index")
enabled = data.get("enabled")
-
+
if processor_index is None or enabled is None:
raise HTTPException(status_code=400, detail="Missing processor_index or enabled parameter")
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
- logging.debug(f"toggle_hook_processor: Toggling processor {processor_index} in {hook_type} to {'enabled' if enabled else 'disabled'}")
-
+
+ logging.debug(
+ f"toggle_hook_processor: Toggling processor {processor_index} in {hook_type} to {'enabled' if enabled else 'disabled'}"
+ )
+
# Update AppState - SINGLE SOURCE OF TRUTH
app_instance.app_state.update_hook_processor(hook_type, processor_index, {"enabled": bool(enabled)})
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -227,7 +247,7 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D
config_entry = {
"type": processor["type"],
"params": processor["params"],
- "enabled": processor["enabled"]
+ "enabled": processor["enabled"],
}
hook_config.append(config_entry)
update_kwargs = {f"{hook_type}_config": hook_config}
@@ -236,14 +256,17 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D
logging.exception(f"toggle_hook_processor: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
+
logging.info(f"toggle_hook_processor: Successfully toggled processor {processor_index} in {hook_type}")
-
- return create_success_response(f"Processor {processor_index} in {hook_type} {'enabled' if enabled else 'disabled'}")
-
+
+ return create_success_response(
+ f"Processor {processor_index} in {hook_type} {'enabled' if enabled else 'disabled'}"
+ )
+
except Exception as e:
raise handle_api_error(e, "toggle_hook_processor")
+
@router.post("/pipeline-hooks/{hook_type}/switch")
async def switch_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)):
"""Switch a processor to a different type"""
@@ -252,45 +275,53 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D
processor_index = data.get("processor_index")
# Support both parameter naming conventions for compatibility
new_processor_type = data.get("processor_type") or data.get("processor")
-
+
if processor_index is None or not new_processor_type:
- raise HTTPException(status_code=400, detail="Missing processor_index or processor_type/processor parameter")
-
+ raise HTTPException(
+ status_code=400, detail="Missing processor_index or processor_type/processor parameter"
+ )
+
# Handle config-only mode when no pipeline is active
if not app_instance.pipeline:
if not app_instance.app_state.uploaded_config:
raise HTTPException(status_code=400, detail="No pipeline active and no uploaded config available")
-
- logging.info(f"switch_hook_processor: Updating config for {hook_type} processor {processor_index} to {new_processor_type}")
-
+
+ logging.info(
+ f"switch_hook_processor: Updating config for {hook_type} processor {processor_index} to {new_processor_type}"
+ )
+
# Update the uploaded config directly
hook_config = app_instance.app_state.uploaded_config.get(hook_type, {"enabled": False, "processors": []})
if processor_index >= len(hook_config.get("processors", [])):
- raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}")
-
+ raise HTTPException(
+ status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}"
+ )
+
# Update processor type in config
hook_config["processors"][processor_index]["type"] = new_processor_type
hook_config["processors"][processor_index]["params"] = {}
app_instance.app_state.uploaded_config[hook_type] = hook_config
-
+
else:
# No pipeline validation needed - AppState updates work before pipeline creation
-
- logging.debug(f"switch_hook_processor: Switching processor {processor_index} in {hook_type} to {new_processor_type}")
-
+
+ logging.debug(
+ f"switch_hook_processor: Switching processor {processor_index} in {hook_type} to {new_processor_type}"
+ )
+
# Update AppState - SINGLE SOURCE OF TRUTH
processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"]
-
+
if processor_index >= len(processors):
- raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}")
-
+ raise HTTPException(
+ status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}"
+ )
+
# Update the processor type and reset params in AppState
- app_instance.app_state.update_hook_processor(hook_type, processor_index, {
- "type": new_processor_type,
- "name": new_processor_type,
- "params": {}
- })
-
+ app_instance.app_state.update_hook_processor(
+ hook_type, processor_index, {"type": new_processor_type, "name": new_processor_type, "params": {}}
+ )
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -299,7 +330,7 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D
config_entry = {
"type": processor["type"],
"params": processor["params"],
- "enabled": processor["enabled"]
+ "enabled": processor["enabled"],
}
hook_config.append(config_entry)
update_kwargs = {f"{hook_type}_config": hook_config}
@@ -308,14 +339,17 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D
logging.exception(f"switch_hook_processor: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
- logging.info(f"switch_hook_processor: Successfully switched processor {processor_index} in {hook_type} to {new_processor_type}")
-
+
+ logging.info(
+ f"switch_hook_processor: Successfully switched processor {processor_index} in {hook_type} to {new_processor_type}"
+ )
+
return create_success_response(f"Switched processor {processor_index} in {hook_type} to {new_processor_type}")
-
+
except Exception as e:
raise handle_api_error(e, "switch_hook_processor")
+
@router.post("/pipeline-hooks/{hook_type}/update-params")
async def update_hook_processor_params(hook_type: str, request: Request, app_instance=Depends(get_app_instance)):
"""Update parameters for a specific processor"""
@@ -323,48 +357,58 @@ async def update_hook_processor_params(hook_type: str, request: Request, app_ins
logging.info(f"update_hook_processor_params: ===== STARTING {hook_type} REQUEST =====")
data = await request.json()
logging.info(f"update_hook_processor_params: Received data: {data}")
-
+
processor_index = data.get("processor_index")
processor_params = data.get("processor_params", {})
- logging.info(f"update_hook_processor_params: processor_index={processor_index}, processor_params={processor_params}")
-
+ logging.info(
+ f"update_hook_processor_params: processor_index={processor_index}, processor_params={processor_params}"
+ )
+
if processor_index is None:
- logging.error(f"update_hook_processor_params: Missing processor_index parameter")
+ logging.error("update_hook_processor_params: Missing processor_index parameter")
raise HTTPException(status_code=400, detail="Missing processor_index parameter")
-
+
# No pipeline validation needed - AppState updates work before pipeline creation
-
+
logging.debug(f"update_hook_processor_params: Updating params for processor {processor_index} in {hook_type}")
-
+
# Check if processors exist in AppState
processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"]
if not processors:
logging.error(f"update_hook_processor_params: Hook type {hook_type} not found or empty")
- raise HTTPException(status_code=400, detail=f"No processors configured for {hook_type}. Add a processor first using the 'Add {hook_type.replace('_', ' ').title()} Processor' button.")
-
+ raise HTTPException(
+ status_code=400,
+ detail=f"No processors configured for {hook_type}. Add a processor first using the 'Add {hook_type.replace('_', ' ').title()} Processor' button.",
+ )
+
if processor_index >= len(processors):
- logging.error(f"update_hook_processor_params: Processor index {processor_index} out of range for {hook_type} (max: {len(processors)-1})")
- raise HTTPException(status_code=400, detail=f"Processor index {processor_index} not found. Only {len(processors)} processors are configured for {hook_type}.")
-
+ logging.error(
+ f"update_hook_processor_params: Processor index {processor_index} out of range for {hook_type} (max: {len(processors) - 1})"
+ )
+ raise HTTPException(
+ status_code=400,
+ detail=f"Processor index {processor_index} not found. Only {len(processors)} processors are configured for {hook_type}.",
+ )
+
# Update the processor parameters in AppState - SINGLE SOURCE OF TRUTH
logging.info(f"update_hook_processor_params: Current processor config: {processors[processor_index]}")
-
+
# Handle 'enabled' field separately as it's a top-level processor field, not a parameter
updates = {}
- if 'enabled' in processor_params:
- enabled_value = processor_params.pop('enabled') # Remove from params dict
- updates['enabled'] = bool(enabled_value)
+ if "enabled" in processor_params:
+ enabled_value = processor_params.pop("enabled") # Remove from params dict
+ updates["enabled"] = bool(enabled_value)
logging.info(f"update_hook_processor_params: Updated enabled field to: {enabled_value}")
-
+
# Update remaining parameters in the params field
if processor_params: # Only update if there are remaining params
- current_params = processors[processor_index].get('params', {})
+ current_params = processors[processor_index].get("params", {})
current_params.update(processor_params)
- updates['params'] = current_params
-
+ updates["params"] = current_params
+
# Apply updates to AppState
app_instance.app_state.update_hook_processor(hook_type, processor_index, updates)
-
+
# Update pipeline if active
if app_instance.pipeline:
try:
@@ -373,29 +417,36 @@ async def update_hook_processor_params(hook_type: str, request: Request, app_ins
config_entry = {
"type": processor["type"],
"params": processor["params"],
- "enabled": processor["enabled"]
+ "enabled": processor["enabled"],
}
hook_config.append(config_entry)
update_kwargs = {f"{hook_type}_config": hook_config}
logging.info(f"update_hook_processor_params: Calling update_stream_params with: {update_kwargs}")
app_instance.pipeline.update_stream_params(**update_kwargs)
- logging.info(f"update_hook_processor_params: update_stream_params completed successfully")
+ logging.info("update_hook_processor_params: update_stream_params completed successfully")
except Exception as e:
logging.exception(f"update_hook_processor_params: Failed to update pipeline: {e}")
# Mark for reload as fallback
app_instance.app_state.config_needs_reload = True
-
- logging.info(f"update_hook_processor_params: Successfully updated params for processor {processor_index} in {hook_type}")
-
- return create_success_response(f"Updated parameters for processor {processor_index} in {hook_type}", updated_params=processor_params)
-
+
+ logging.info(
+ f"update_hook_processor_params: Successfully updated params for processor {processor_index} in {hook_type}"
+ )
+
+ return create_success_response(
+ f"Updated parameters for processor {processor_index} in {hook_type}", updated_params=processor_params
+ )
+
except Exception as e:
logging.exception(f"update_hook_processor_params: Exception occurred: {str(e)}")
logging.error(f"update_hook_processor_params: Exception type: {type(e).__name__}")
raise handle_api_error(e, "update_hook_processor_params")
+
@router.get("/pipeline-hooks/{hook_type}/current-params/{processor_index}")
-async def get_current_hook_processor_params(hook_type: str, processor_index: int, app_instance=Depends(get_app_instance)):
+async def get_current_hook_processor_params(
+ hook_type: str, processor_index: int, app_instance=Depends(get_app_instance)
+):
"""Get current parameters for a specific processor"""
try:
# First try to get from uploaded config if no pipeline
@@ -404,44 +455,50 @@ async def get_current_hook_processor_params(hook_type: str, processor_index: int
processors = hook_config.get("processors", [])
if processor_index < len(processors):
processor = processors[processor_index]
- return JSONResponse({
+ return JSONResponse(
+ {
+ "status": "success",
+ "hook_type": hook_type,
+ "processor_index": processor_index,
+ "processor_type": processor.get("type", "unknown"),
+ "parameters": processor.get("params", {}),
+ "enabled": processor.get("enabled", True),
+ "note": "From uploaded config",
+ }
+ )
+
+ # Return empty if no config available
+ if not app_instance.pipeline:
+ return JSONResponse(
+ {
"status": "success",
"hook_type": hook_type,
"processor_index": processor_index,
- "processor_type": processor.get('type', 'unknown'),
- "parameters": processor.get('params', {}),
- "enabled": processor.get('enabled', True),
- "note": "From uploaded config"
- })
-
- # Return empty if no config available
- if not app_instance.pipeline:
- return JSONResponse({
- "status": "success",
- "hook_type": hook_type,
- "processor_index": processor_index,
- "processor_type": "unknown",
- "parameters": {},
- "enabled": False,
- "note": "Pipeline not initialized - no config available"
- })
-
+ "processor_type": "unknown",
+ "parameters": {},
+ "enabled": False,
+ "note": "Pipeline not initialized - no config available",
+ }
+ )
+
# Use AppState - SINGLE SOURCE OF TRUTH
processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"]
-
+
if processor_index >= len(processors):
raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}")
-
+
processor = processors[processor_index]
-
- return JSONResponse({
- "status": "success",
- "hook_type": hook_type,
- "processor_index": processor_index,
- "processor_type": processor.get('type', 'unknown'),
- "parameters": processor.get('params', {}),
- "enabled": processor.get('enabled', True)
- })
-
+
+ return JSONResponse(
+ {
+ "status": "success",
+ "hook_type": hook_type,
+ "processor_index": processor_index,
+ "processor_type": processor.get("type", "unknown"),
+ "parameters": processor.get("params", {}),
+ "enabled": processor.get("enabled", True),
+ }
+ )
+
except Exception as e:
- raise handle_api_error(e, "get_current_hook_processor_params")
\ No newline at end of file
+ raise handle_api_error(e, "get_current_hook_processor_params")
diff --git a/demo/realtime-img2img/routes/websocket.py b/demo/realtime-img2img/routes/websocket.py
index 48f37a2d7..2a2c87a0a 100644
--- a/demo/realtime-img2img/routes/websocket.py
+++ b/demo/realtime-img2img/routes/websocket.py
@@ -1,33 +1,40 @@
"""
WebSocket endpoints for realtime-img2img
"""
-from fastapi import APIRouter, WebSocket, HTTPException, Depends
+
import logging
-import uuid
import time
+import uuid
from types import SimpleNamespace
-from util import bytes_to_pt
from connection_manager import ServerFullException
-from .common.dependencies import get_app_instance, get_pipeline_class
+from fastapi import APIRouter, Depends, HTTPException, WebSocket
from input_sources import InputSourceManager
+from util import bytes_to_pt
+
+from .common.dependencies import get_app_instance, get_pipeline_class
+
router = APIRouter(prefix="/api", tags=["websocket"])
def _get_input_source_manager(app_instance) -> InputSourceManager:
"""Get or create the input source manager for the app instance."""
- if not hasattr(app_instance, 'input_source_manager'):
+ if not hasattr(app_instance, "input_source_manager"):
app_instance.input_source_manager = InputSourceManager()
return app_instance.input_source_manager
+
@router.websocket("/ws/{user_id}")
-async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket, app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class)):
+async def websocket_endpoint(
+ user_id: uuid.UUID,
+ websocket: WebSocket,
+ app_instance=Depends(get_app_instance),
+ pipeline_class=Depends(get_pipeline_class),
+):
"""Main WebSocket endpoint for real-time communication"""
try:
- await app_instance.conn_manager.connect(
- user_id, websocket, app_instance.args.max_queue_size
- )
+ await app_instance.conn_manager.connect(user_id, websocket, app_instance.args.max_queue_size)
await handle_websocket_data(user_id, app_instance, pipeline_class)
except ServerFullException as e:
logging.exception(f"websocket_endpoint: Server Full: {e}")
@@ -35,6 +42,7 @@ async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket, app_insta
await app_instance.conn_manager.disconnect(user_id)
logging.info(f"websocket_endpoint: User disconnected: {user_id}")
+
async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class):
"""Handle WebSocket data flow for a specific user"""
if not app_instance.conn_manager.check_user(user_id):
@@ -42,10 +50,7 @@ async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class
last_time = time.time()
try:
while True:
- if (
- app_instance.args.timeout > 0
- and time.time() - last_time > app_instance.args.timeout
- ):
+ if app_instance.args.timeout > 0 and time.time() - last_time > app_instance.args.timeout:
await app_instance.conn_manager.send_json(
user_id,
{
@@ -62,45 +67,45 @@ async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class
params = await app_instance.conn_manager.receive_json(user_id)
params = pipeline_class.InputParams(**params)
params = SimpleNamespace(**params.dict())
-
+
# Check if we need image data based on pipeline
need_image = True
- if app_instance.pipeline and hasattr(app_instance.pipeline, 'pipeline_mode'):
+ if app_instance.pipeline and hasattr(app_instance.pipeline, "pipeline_mode"):
# Need image for img2img OR for txt2img with ControlNets
- has_controlnets = app_instance.pipeline.use_config and app_instance.pipeline.config and 'controlnets' in app_instance.pipeline.config
+ has_controlnets = (
+ app_instance.pipeline.use_config
+ and app_instance.pipeline.config
+ and "controlnets" in app_instance.pipeline.config
+ )
need_image = app_instance.pipeline.pipeline_mode == "img2img" or has_controlnets
- elif app_instance.app_state.uploaded_config and 'mode' in app_instance.app_state.uploaded_config:
+ elif app_instance.app_state.uploaded_config and "mode" in app_instance.app_state.uploaded_config:
# Need image for img2img OR for txt2img with ControlNets
- has_controlnets = 'controlnets' in app_instance.app_state.uploaded_config
- need_image = app_instance.app_state.uploaded_config['mode'] == "img2img" or has_controlnets
-
+ has_controlnets = "controlnets" in app_instance.app_state.uploaded_config
+ need_image = app_instance.app_state.uploaded_config["mode"] == "img2img" or has_controlnets
+
# Get input source manager
input_manager = _get_input_source_manager(app_instance)
-
+
if need_image:
# Receive main webcam stream (fallback)
image_data = await app_instance.conn_manager.receive_bytes(user_id)
if len(image_data) == 0:
- await app_instance.conn_manager.send_json(
- user_id, {"status": "send_frame"}
- )
+ await app_instance.conn_manager.send_json(user_id, {"status": "send_frame"})
continue
-
+
# Update webcam frame in input manager for all webcam sources
input_manager.update_webcam_frame(image_data)
-
+
# Always use direct bytes-to-tensor conversion for efficiency
params.image = bytes_to_pt(image_data)
else:
params.image = None
-
+
# Store the input manager reference in params for later use by img2img.py
params.input_manager = input_manager
-
+
await app_instance.conn_manager.update_data(user_id, params)
except Exception as e:
logging.exception(f"handle_websocket_data: Websocket Error: {e}, {user_id} ")
await app_instance.conn_manager.disconnect(user_id)
-
-
diff --git a/demo/realtime-img2img/util.py b/demo/realtime-img2img/util.py
index 0ef21421b..b311c4faa 100644
--- a/demo/realtime-img2img/util.py
+++ b/demo/realtime-img2img/util.py
@@ -1,11 +1,10 @@
+import io
from importlib import import_module
from types import ModuleType
-from typing import Dict, Any
-from pydantic import BaseModel as PydanticBaseModel, Field
-from PIL import Image
-import io
+
import torch
-from torchvision.io import encode_jpeg, decode_jpeg
+from PIL import Image
+from torchvision.io import decode_jpeg, encode_jpeg
def get_pipeline_class(pipeline_name: str) -> ModuleType:
@@ -30,22 +29,22 @@ def bytes_to_pil(image_bytes: bytes) -> Image.Image:
def bytes_to_pt(image_bytes: bytes) -> torch.Tensor:
"""
Convert JPEG/PNG bytes directly to PyTorch tensor using torchvision
-
+
Args:
image_bytes: Raw image bytes (JPEG/PNG format)
-
+
Returns:
torch.Tensor: Image tensor with shape (C, H, W), values in [0, 1], dtype float32
"""
# Convert bytes to tensor for torchvision
byte_tensor = torch.frombuffer(image_bytes, dtype=torch.uint8)
-
+
# Decode JPEG/PNG directly to tensor (C, H, W) format, uint8 [0, 255]
image_tensor = decode_jpeg(byte_tensor)
-
+
# Convert to float32 and normalize to [0, 1]
image_tensor = image_tensor.float() / 255.0
-
+
return image_tensor
@@ -65,24 +64,24 @@ def pil_to_frame(image: Image.Image) -> bytes:
def pt_to_frame(tensor: torch.Tensor) -> bytes:
"""
Convert PyTorch tensor directly to JPEG frame bytes using torchvision
-
+
Args:
tensor: PyTorch tensor with shape (C, H, W) or (1, C, H, W), values in [0, 1]
-
+
Returns:
bytes: JPEG frame data for streaming
"""
# Handle batch dimension - take first image if batched
if tensor.dim() == 4:
tensor = tensor[0]
-
+
# Convert to uint8 format (0-255) and ensure correct shape (C, H, W)
tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8)
-
+
# Encode directly to JPEG bytes using torchvision
jpeg_bytes = encode_jpeg(tensor_uint8, quality=90)
frame_data = jpeg_bytes.cpu().numpy().tobytes()
-
+
return (
b"--frame\r\n"
+ b"Content-Type: image/jpeg\r\n"
diff --git a/demo/realtime-img2img/utils/video_utils.py b/demo/realtime-img2img/utils/video_utils.py
index ad092415f..4ba9d4924 100644
--- a/demo/realtime-img2img/utils/video_utils.py
+++ b/demo/realtime-img2img/utils/video_utils.py
@@ -6,25 +6,26 @@
"""
import logging
+from pathlib import Path
+from typing import Optional, Tuple
+
import cv2
import numpy as np
import torch
-from pathlib import Path
-from typing import Optional, Tuple
class VideoFrameExtractor:
"""
Extracts frames from video files for use as input sources.
-
+
Handles video playback, looping, and frame extraction with automatic
conversion to PyTorch tensors.
"""
-
+
def __init__(self, video_path: str):
"""
Initialize the video frame extractor.
-
+
Args:
video_path: Path to the video file
"""
@@ -34,34 +35,35 @@ def __init__(self, video_path: str):
self.frame_count = 0
self.current_frame_idx = 0
self._logger = logging.getLogger(f"VideoFrameExtractor.{self.video_path.name}")
-
+
self._initialize_capture()
-
+
def _initialize_capture(self):
"""Initialize the video capture object."""
if not self.video_path.exists():
self._logger.error(f"Video file not found: {self.video_path}")
return
-
+
self.cap = cv2.VideoCapture(str(self.video_path))
-
+
if not self.cap.isOpened():
self._logger.error(f"Failed to open video file: {self.video_path}")
return
-
+
# Get video properties
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
- self._logger.info(f"Initialized video: {self.video_path.name}, "
- f"FPS: {self.fps:.2f}, Frames: {self.frame_count}")
-
+
+ self._logger.info(
+ f"Initialized video: {self.video_path.name}, FPS: {self.fps:.2f}, Frames: {self.frame_count}"
+ )
+
def get_frame(self) -> Optional[torch.Tensor]:
"""
Extract the current frame and advance to the next frame.
-
+
Automatically loops back to the beginning when reaching the end.
-
+
Returns:
torch.Tensor: Frame as tensor with shape (C, H, W), values in [0, 1], dtype float32
None: If frame extraction fails
@@ -69,101 +71,101 @@ def get_frame(self) -> Optional[torch.Tensor]:
if not self.cap or not self.cap.isOpened():
self._logger.error("Video capture not initialized or closed")
return None
-
+
ret, frame = self.cap.read()
-
+
if not ret:
# End of video, loop back to beginning
self._logger.debug("End of video reached, looping back to start")
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
self.current_frame_idx = 0
ret, frame = self.cap.read()
-
+
if not ret:
self._logger.error("Failed to read frame even after reset")
return None
-
+
self.current_frame_idx += 1
-
+
# Convert frame to tensor
return self._frame_to_tensor(frame)
-
+
def get_frame_at_time(self, timestamp: float) -> Optional[torch.Tensor]:
"""
Get frame at a specific timestamp.
-
+
Args:
timestamp: Time in seconds
-
+
Returns:
torch.Tensor: Frame at the specified time or None if failed
"""
if not self.cap or not self.cap.isOpened():
return None
-
+
# Convert timestamp to frame number
frame_number = int(timestamp * self.fps)
frame_number = max(0, min(frame_number, self.frame_count - 1))
-
+
# Seek to frame
self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
self.current_frame_idx = frame_number
-
+
ret, frame = self.cap.read()
if ret:
return self._frame_to_tensor(frame)
-
+
return None
-
+
def _frame_to_tensor(self, frame: np.ndarray) -> torch.Tensor:
"""
Convert OpenCV frame to PyTorch tensor.
-
+
Args:
frame: OpenCV frame in BGR format
-
+
Returns:
torch.Tensor: Frame tensor in RGB format with shape (C, H, W)
"""
# Convert BGR to RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
+
# Convert to tensor and normalize
frame_tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0
-
+
return frame_tensor
-
+
def get_video_info(self) -> dict:
"""
Get information about the video.
-
+
Returns:
Dictionary with video metadata
"""
if not self.cap:
return {}
-
+
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
duration = self.frame_count / self.fps if self.fps > 0 else 0
-
+
return {
- 'path': str(self.video_path),
- 'fps': self.fps,
- 'frame_count': self.frame_count,
- 'width': width,
- 'height': height,
- 'duration': duration,
- 'current_frame': self.current_frame_idx
+ "path": str(self.video_path),
+ "fps": self.fps,
+ "frame_count": self.frame_count,
+ "width": width,
+ "height": height,
+ "duration": duration,
+ "current_frame": self.current_frame_idx,
}
-
+
def reset(self):
"""Reset video to beginning."""
if self.cap:
self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
self.current_frame_idx = 0
self._logger.debug("Video reset to beginning")
-
+
def cleanup(self):
"""Release video capture resources."""
if self.cap:
@@ -175,25 +177,25 @@ def cleanup(self):
def get_video_thumbnail(video_path: str, timestamp: float = 0.0) -> Optional[torch.Tensor]:
"""
Get a thumbnail frame from a video file.
-
+
Args:
video_path: Path to the video file
timestamp: Time in seconds to extract thumbnail from
-
+
Returns:
torch.Tensor: Thumbnail frame or None if failed
"""
try:
extractor = VideoFrameExtractor(video_path)
-
+
if timestamp > 0:
thumbnail = extractor.get_frame_at_time(timestamp)
else:
thumbnail = extractor.get_frame()
-
+
extractor.cleanup()
return thumbnail
-
+
except Exception as e:
logging.getLogger("video_utils").error(f"Failed to get thumbnail: {e}")
return None
@@ -202,52 +204,63 @@ def get_video_thumbnail(video_path: str, timestamp: float = 0.0) -> Optional[tor
def validate_video_file(video_path: str) -> Tuple[bool, str]:
"""
Validate if a file is a readable video.
-
+
Args:
video_path: Path to the video file
-
+
Returns:
Tuple of (is_valid, error_message)
"""
try:
path = Path(video_path)
-
+
if not path.exists():
return False, "Video file does not exist"
-
+
# Try to open with OpenCV
cap = cv2.VideoCapture(str(path))
-
+
if not cap.isOpened():
return False, "Cannot open video file"
-
+
# Try to read first frame
ret, frame = cap.read()
cap.release()
-
+
if not ret:
return False, "Cannot read frames from video"
-
+
return True, "Video file is valid"
-
+
except Exception as e:
return False, f"Video validation error: {str(e)}"
# Supported video formats
SUPPORTED_VIDEO_FORMATS = {
- '.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv',
- '.m4v', '.3gp', '.ogv', '.ts', '.m2ts', '.mts'
+ ".mp4",
+ ".avi",
+ ".mov",
+ ".mkv",
+ ".webm",
+ ".flv",
+ ".wmv",
+ ".m4v",
+ ".3gp",
+ ".ogv",
+ ".ts",
+ ".m2ts",
+ ".mts",
}
def is_supported_video_format(filename: str) -> bool:
"""
Check if a file has a supported video format.
-
+
Args:
filename: Name or path of the file
-
+
Returns:
bool: True if format is supported
"""
diff --git a/demo/realtime-txt2img/config.py b/demo/realtime-txt2img/config.py
index 354941485..f7bb12403 100644
--- a/demo/realtime-txt2img/config.py
+++ b/demo/realtime-txt2img/config.py
@@ -1,8 +1,9 @@
+import os
from dataclasses import dataclass, field
from typing import List, Literal
import torch
-import os
+
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "False") == "True"
diff --git a/demo/vid2vid/app.py b/demo/vid2vid/app.py
index d03da15e2..c7f4a37ff 100644
--- a/demo/vid2vid/app.py
+++ b/demo/vid2vid/app.py
@@ -1,18 +1,18 @@
-import gradio as gr
-
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
-import fire
+import gradio as gr
import torch
from torchvision.io import read_video, write_video
from tqdm import tqdm
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -28,7 +28,6 @@ def main(
enable_similar_image_filter: bool = True,
seed: int = 2,
):
-
"""
Process for generating images based on a prompt using a specified model.
@@ -106,9 +105,5 @@ def main(
return output
-demo = gr.Interface(
- main,
- gr.Video(sources=['upload', 'webcam']),
- "playable_video"
-)
+demo = gr.Interface(main, gr.Video(sources=["upload", "webcam"]), "playable_video")
demo.launch()
diff --git a/docs/pr_audit_2026-05-16.md b/docs/pr_audit_2026-05-16.md
new file mode 100644
index 000000000..4fcde6c3c
--- /dev/null
+++ b/docs/pr_audit_2026-05-16.md
@@ -0,0 +1,44 @@
+# StreamDiffusion PR Audit β 2026-05-16
+
+Cross-check of local dev branch vs dotsimulate/StreamDiffusion upstream to determine which open PRs are effectively landed.
+
+## PR status
+
+| PR | Title | GitHub state | Actual landed state |
+|---|---|---|---|
+| **#4** | perf: Inference performance & pipeline correctness | OPEN | **Effectively landed** β cherry-picked as `ef50c0d` into `dotsimulate/feat/trt10.16-fp8-perf`, then merged via PR #12 |
+| **#5** | feat: IP-Adapter auto-res, VRAM offload & dep updates | OPEN | **Effectively landed** β cherry-pick `18cf3f9`, then PR #12 |
+| **#6** | feat: FP8 quantization & TensorRT build infra | OPEN | **Effectively landed** β cherry-pick `85ca135` (rewritten as ONNX-level quantization), then PR #12 |
+| **#7** | perf: Tier 1 hot-path elimination & TE stutter fix | OPEN | **Effectively landed** β cherry-pick `c0819aa`, then PR #12 |
+| **#8** | perf: TRT engine builder β static shapes, CUDA graphs | OPEN | **Effectively landed** β cherry-pick `b4be27d`, then PR #12 |
+| **#9** | chore(installer): overhaul (superseded by #10) | CLOSED | n/a |
+| **#10** | chore(installer): overhaul install scripts | OPEN | **Effectively landed** β cherry-pick `ea4d3f1`, then PR #12 |
+| **#11** | fix(trt): missing build_all_tactics param | MERGED | Landed cleanly β `4d0cb2b` on SDTD_031_dev, 2026-04-22 |
+| **#12** | feat: TRT 10.16.1.11 + FP8 + CUDA hot-path | MERGED | Landed with conflict resolution β `70b0523` (2026-05-10) + cleanup `13b6651` (2026-05-13) |
+
+## Key refs
+
+- Squashed PR #12 source: `9610624` on `origin/dotsimulate/feat/trt10.16-fp8-perf`
+- PR #12 merge into SDTD_031_dev: `70b0523` (2026-05-10)
+- Formatter-noise cleanup: `13b6651` (2026-05-13)
+- PR #11 merge: `4d0cb2b` (2026-04-22)
+- Cherry-pick anchors (Apr 9): `ef50c0d`, `18cf3f9`, `85ca135`, `c0819aa`, `b4be27d`, `ea4d3f1`
+
+## Dev repo state
+
+- Checked-out branch: `dotsimulate/feat/trt10.16-fp8-perf` at `9610624`
+- Local `forkni/feat-fp8-torch-calibration-v2`: **35 commits ahead** of origin β unsquashed Phase 1 history; content likely fully covered by the squashed `9610624` (squash was created 8 min after last commit)
+- `origin/SDTD_031_dev` is stale at `4d0cb2b` β run `git fetch origin` to get `13b6651`
+
+## Recommended next steps
+
+1. `git fetch origin` in dev repo to update remote-tracking refs
+2. Close PRs #4, #5, #6, #7, #8, #10 on GitHub β comment referencing `70b0523` as superseding merge
+3. Before closing #6, confirm ONNX-level FP8 path (PR #12) is intended replacement for torch-mode approach in pr2 branch
+4. Verify local `forkni/feat-fp8-torch-calibration-v2` 35 ahead commits are fully covered by squash: `git diff forkni/feat-fp8-torch-calibration-v2 origin/dotsimulate/feat/trt10.16-fp8-perf --stat`
+5. FP8 Phase 2: open new branch off SDTD_031_dev at `13b6651`
+6. StreamDiffusion-installer: commit `8c8020a` in installer repo β Phase 1 blocker resolved, safe to revisit push (needs maintainer access, 403 on direct push)
+
+## Process lesson
+
+PR #12 merge required reverting 27 formatter-noise files and re-adding 4 real params afterward. Future PRs: run a pre-PR pass with `git diff --stat` to catch quote flips/whitespace/import reorder before submitting.
diff --git a/examples/benchmark/multi.py b/examples/benchmark/multi.py
index bfa971cf2..443835b99 100644
--- a/examples/benchmark/multi.py
+++ b/examples/benchmark/multi.py
@@ -3,7 +3,7 @@
import sys
import time
from multiprocessing import Process, Queue
-from typing import List, Literal, Optional, Dict
+from typing import Dict, List, Literal, Optional
import fire
import PIL.Image
@@ -13,6 +13,7 @@
from streamdiffusion.image_utils import postprocess_image
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from streamdiffusion import StreamDiffusionWrapper
diff --git a/examples/benchmark/single.py b/examples/benchmark/single.py
index 5e55fb633..133abc13b 100644
--- a/examples/benchmark/single.py
+++ b/examples/benchmark/single.py
@@ -1,7 +1,7 @@
import io
import os
import sys
-from typing import List, Literal, Optional, Dict
+from typing import Dict, List, Literal, Optional
import fire
import PIL.Image
@@ -101,9 +101,7 @@ def run(
delta=0.5,
)
- downloaded_image = download_image("https://github.com/ddpn08.png").resize(
- (width, height)
- )
+ downloaded_image = download_image("https://github.com/ddpn08.png").resize((width, height))
# warmup
for _ in range(warmup):
diff --git a/examples/config/config_ipadapter_stream_test.py b/examples/config/config_ipadapter_stream_test.py
index a66c01e19..b8141e033 100644
--- a/examples/config/config_ipadapter_stream_test.py
+++ b/examples/config/config_ipadapter_stream_test.py
@@ -13,205 +13,217 @@
- Tests the IPAdapter stream behavior fix
"""
-import cv2
-import torch
-import numpy as np
-from PIL import Image
import argparse
-from pathlib import Path
-import sys
-import time
+import json
import os
import shutil
-import json
-from collections import deque
+import sys
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
def tensor_to_opencv(tensor: torch.Tensor, target_width: int, target_height: int) -> np.ndarray:
"""
Convert a PyTorch tensor (output_type='pt') to OpenCV BGR format for video writing.
Uses efficient tensor operations similar to the realtime-img2img demo.
-
+
Args:
tensor: Tensor in range [0,1] with shape [B, C, H, W] or [C, H, W]
target_width: Target width for output
target_height: Target height for output
-
+
Returns:
BGR numpy array ready for OpenCV
"""
# Handle batch dimension - take first image if batched
if tensor.dim() == 4:
tensor = tensor[0]
-
+
# Convert to uint8 format (0-255) and ensure correct shape (C, H, W)
tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8)
-
+
# Convert from [C, H, W] to [H, W, C] format
if tensor_uint8.dim() == 3:
image_np = tensor_uint8.permute(1, 2, 0).cpu().numpy()
else:
raise ValueError(f"tensor_to_opencv: Unexpected tensor shape: {tensor_uint8.shape}")
-
+
# Convert RGB to BGR for OpenCV
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
-
+
# Resize if needed
if image_bgr.shape[:2] != (target_height, target_width):
image_bgr = cv2.resize(image_bgr, (target_width, target_height))
-
+
return image_bgr
def process_video_ipadapter_stream(config_path, input_video, static_image, output_dir, engine_only=False):
"""Process video using IPAdapter as primary driving force with static base image"""
print(f"process_video_ipadapter_stream: Loading config from {config_path}")
-
+
# Import here to avoid loading at module level
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
- from streamdiffusion import load_config, create_wrapper_from_config
-
+ from streamdiffusion import create_wrapper_from_config, load_config
+
# Load configuration
config = load_config(config_path)
-
+
# Force tensor output for better performance
- config['output_type'] = 'pt'
-
+ config["output_type"] = "pt"
+
# Get width and height from config (with defaults)
- width = config.get('width', 512)
- height = config.get('height', 512)
-
+ width = config.get("width", 512)
+ height = config.get("height", 512)
+
print(f"process_video_ipadapter_stream: Using dimensions: {width}x{height}")
- print(f"process_video_ipadapter_stream: Using output_type='pt' for better performance")
-
+ print("process_video_ipadapter_stream: Using output_type='pt' for better performance")
+
# Create output directory
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
-
+
# Copy config, input video, and static image to output directory
config_copy_path = output_dir / f"config_{Path(config_path).name}"
shutil.copy2(config_path, config_copy_path)
print(f"process_video_ipadapter_stream: Copied config to {config_copy_path}")
-
+
input_copy_path = output_dir / f"input_{Path(input_video).name}"
shutil.copy2(input_video, input_copy_path)
print(f"process_video_ipadapter_stream: Copied input video to {input_copy_path}")
-
+
static_copy_path = output_dir / f"static_{Path(static_image).name}"
shutil.copy2(static_image, static_copy_path)
print(f"process_video_ipadapter_stream: Copied static image to {static_copy_path}")
-
+
# Create wrapper using the built-in function
wrapper = create_wrapper_from_config(config)
-
+
if engine_only:
print("Engine-only mode: TensorRT engines have been built (if needed). Exiting.")
return None
-
+
# Load and prepare static image
static_img = Image.open(static_image)
static_img = static_img.resize((width, height), Image.Resampling.LANCZOS)
print(f"process_video_ipadapter_stream: Loaded static image: {static_image}")
-
+
# Open input video
cap = cv2.VideoCapture(str(input_video))
if not cap.isOpened():
raise ValueError(f"Could not open input video: {input_video}")
-
+
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
+
print(f"process_video_ipadapter_stream: Input video - {frame_count} frames at {fps} FPS")
-
+
# Setup output video writer (3-panel display: input, static, generated)
output_video_path = output_dir / "output_video.mp4"
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (width * 3, height))
-
+
# Performance tracking
frame_times = []
total_start_time = time.time()
-
+
print("process_video_ipadapter_stream: Starting IPAdapter stream processing...")
print("process_video_ipadapter_stream: Using static image as base input, video frames for ControlNet + IPAdapter")
-
+
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
-
+
frame_start_time = time.time()
-
+
# Resize frame
frame_resized = cv2.resize(frame, (width, height))
-
+
# Convert frame to PIL
frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
-
+
# Update ControlNet control images (structural guidance from video frames)
- if hasattr(wrapper.stream, '_controlnet_module') and wrapper.stream._controlnet_module:
+ if hasattr(wrapper.stream, "_controlnet_module") and wrapper.stream._controlnet_module:
controlnet_count = len(wrapper.stream._controlnet_module.controlnets)
- print(f"process_video_ipadapter_stream: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}")
+ print(
+ f"process_video_ipadapter_stream: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}"
+ )
for i in range(controlnet_count):
wrapper.update_control_image(i, frame_pil)
else:
print(f"process_video_ipadapter_stream: No ControlNet module found for frame {frame_idx}")
-
+
# Update IPAdapter style image (style/content guidance from video frames)
# This is the key part - using video frames as IPAdapter style images with is_stream=True
- if hasattr(wrapper.stream, '_ipadapter_module') and wrapper.stream._ipadapter_module:
- print(f"process_video_ipadapter_stream: Updating IPAdapter style image on frame {frame_idx} (is_stream=True)")
+ if hasattr(wrapper.stream, "_ipadapter_module") and wrapper.stream._ipadapter_module:
+ print(
+ f"process_video_ipadapter_stream: Updating IPAdapter style image on frame {frame_idx} (is_stream=True)"
+ )
# Update style image with is_stream=True for pipelined processing
wrapper.update_style_image(frame_pil, is_stream=True)
else:
print(f"process_video_ipadapter_stream: No IPAdapter module found for frame {frame_idx}")
-
+
# Process with static image as base input (this is the key difference)
# The static image provides the base structure, while ControlNet and IPAdapter
# provide the dynamic guidance from the video frames
output_tensor = wrapper(static_img)
-
+
# Convert tensor output to OpenCV BGR format
output_bgr = tensor_to_opencv(output_tensor, width, height)
-
+
# Convert static image to display format
static_array = np.array(static_img)
static_bgr = cv2.cvtColor(static_array, cv2.COLOR_RGB2BGR)
-
+
# Create 3-panel display: Input Video | Static Base | Generated Output
combined = np.hstack([frame_resized, static_bgr, output_bgr])
-
+
# Add labels
cv2.putText(combined, "Input Video", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
cv2.putText(combined, "Static Base", (width + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
- cv2.putText(combined, "Generated", (width * 2 + 10, 30),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
-
+ cv2.putText(combined, "Generated", (width * 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
# Add frame info
- cv2.putText(combined, f"Frame: {frame_idx}/{frame_count}", (10, height - 20),
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
-
+ cv2.putText(
+ combined,
+ f"Frame: {frame_idx}/{frame_count}",
+ (10, height - 20),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.5,
+ (255, 255, 255),
+ 1,
+ )
+
# Write frame
out.write(combined)
-
+
# Track performance
frame_time = time.time() - frame_start_time
frame_times.append(frame_time)
-
+
frame_idx += 1
if frame_idx % 10 == 0:
avg_fps = len(frame_times) / sum(frame_times) if frame_times else 0
- print(f"process_video_ipadapter_stream: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})")
-
+ print(
+ f"process_video_ipadapter_stream: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})"
+ )
+
total_time = time.time() - total_start_time
-
+
# Cleanup
cap.release()
out.release()
-
+
# Calculate performance metrics
if frame_times:
avg_frame_time = sum(frame_times) / len(frame_times)
@@ -222,7 +234,7 @@ def process_video_ipadapter_stream(config_path, input_video, static_image, outpu
min_fps = 1.0 / max_frame_time
else:
avg_frame_time = avg_fps = min_frame_time = max_frame_time = max_fps = min_fps = 0
-
+
# Performance metrics
metrics = {
"input_video": str(input_video),
@@ -238,75 +250,94 @@ def process_video_ipadapter_stream(config_path, input_video, static_image, outpu
"avg_frame_time_seconds": avg_frame_time,
"min_frame_time_seconds": min_frame_time,
"max_frame_time_seconds": max_frame_time,
- "model_id": config['model_id'],
- "acceleration": config.get('acceleration', 'none'),
- "frame_buffer_size": config.get('frame_buffer_size', 1),
- "num_inference_steps": config.get('num_inference_steps', 50),
- "guidance_scale": config.get('guidance_scale', 1.1),
- "controlnets": [cn['model_id'] for cn in config.get('controlnets', [])],
- "ipadapter_configs": [ip['ipadapter_model_path'] for ip in config.get('ipadapter_config', [])],
+ "model_id": config["model_id"],
+ "acceleration": config.get("acceleration", "none"),
+ "frame_buffer_size": config.get("frame_buffer_size", 1),
+ "num_inference_steps": config.get("num_inference_steps", 50),
+ "guidance_scale": config.get("guidance_scale", 1.1),
+ "controlnets": [cn["model_id"] for cn in config.get("controlnets", [])],
+ "ipadapter_configs": [ip["ipadapter_model_path"] for ip in config.get("ipadapter_config", [])],
"test_type": "ipadapter_stream_test",
"is_stream_enabled": True,
"output_type": "pt",
- "description": "IPAdapter as primary driving force with static base image using tensor output for performance"
+ "description": "IPAdapter as primary driving force with static base image using tensor output for performance",
}
-
+
# Save metrics
metrics_path = output_dir / "performance_metrics.json"
- with open(metrics_path, 'w') as f:
+ with open(metrics_path, "w") as f:
json.dump(metrics, f, indent=2)
-
- print(f"process_video_ipadapter_stream: Processing completed!")
+
+ print("process_video_ipadapter_stream: Processing completed!")
print(f"process_video_ipadapter_stream: Output video saved to: {output_video_path}")
print(f"process_video_ipadapter_stream: Performance metrics saved to: {metrics_path}")
print(f"process_video_ipadapter_stream: Average FPS: {avg_fps:.2f}")
print(f"process_video_ipadapter_stream: Total time: {total_time:.2f} seconds")
- print(f"process_video_ipadapter_stream: Test completed - IPAdapter stream behavior verified")
-
+ print("process_video_ipadapter_stream: Test completed - IPAdapter stream behavior verified")
+
return metrics
def main():
- parser = argparse.ArgumentParser(description="IPAdapter Stream Test Demo - Tests IPAdapter as primary driving force")
-
- parser.add_argument("--config", type=str, required=True,
- help="Path to configuration file (must include both ControlNet and IPAdapter configs)")
- parser.add_argument("--input-video", type=str, required=True,
- help="Path to input video file (used for both ControlNet and IPAdapter guidance)")
- parser.add_argument("--static-image", type=str, required=True,
- help="Path to static image file (used as base input to StreamDiffusion)")
- parser.add_argument("--output-dir", type=str, default="output",
- help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.")
- parser.add_argument("--engine-only", action="store_true",
- help="Only build TensorRT engines and exit (no video processing)")
-
+ parser = argparse.ArgumentParser(
+ description="IPAdapter Stream Test Demo - Tests IPAdapter as primary driving force"
+ )
+
+ parser.add_argument(
+ "--config",
+ type=str,
+ required=True,
+ help="Path to configuration file (must include both ControlNet and IPAdapter configs)",
+ )
+ parser.add_argument(
+ "--input-video",
+ type=str,
+ required=True,
+ help="Path to input video file (used for both ControlNet and IPAdapter guidance)",
+ )
+ parser.add_argument(
+ "--static-image",
+ type=str,
+ required=True,
+ help="Path to static image file (used as base input to StreamDiffusion)",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="output",
+ help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.",
+ )
+ parser.add_argument(
+ "--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)"
+ )
+
args = parser.parse_args()
-
+
# Create timestamped subdirectory within the specified parent directory
timestamp = time.strftime("%Y%m%d_%H%M%S")
input_name = Path(args.input_video).stem
static_name = Path(args.static_image).stem
config_name = Path(args.config).stem
subdir_name = f"ipadapter_stream_test_{config_name}_{input_name}_{static_name}_{timestamp}"
-
+
# Combine parent directory with generated subdirectory name
final_output_dir = Path(args.output_dir) / subdir_name
args.output_dir = str(final_output_dir)
print(f"main: Using output directory: {args.output_dir}")
-
+
# Validate input files
if not Path(args.config).exists():
print(f"main: Error - Config file not found: {args.config}")
return 1
-
+
if not Path(args.input_video).exists():
print(f"main: Error - Input video not found: {args.input_video}")
return 1
-
+
if not Path(args.static_image).exists():
print(f"main: Error - Static image not found: {args.static_image}")
return 1
-
+
print("IPAdapter Stream Test Demo")
print("=" * 50)
print(f"main: Config: {args.config}")
@@ -321,14 +352,10 @@ def main():
print("- is_stream=True β High-throughput pipelined processing")
print("- Tests IPAdapter stream behavior fix")
print("=" * 50)
-
+
try:
metrics = process_video_ipadapter_stream(
- args.config,
- args.input_video,
- args.static_image,
- args.output_dir,
- engine_only=args.engine_only
+ args.config, args.input_video, args.static_image, args.output_dir, engine_only=args.engine_only
)
if args.engine_only:
print("main: Engine-only mode completed successfully!")
@@ -337,6 +364,7 @@ def main():
return 0
except Exception as e:
import traceback
+
print(f"main: Error during processing: {e}")
print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}")
return 1
diff --git a/examples/config/config_video_test.py b/examples/config/config_video_test.py
index 7ff0a0208..69a8d39e6 100644
--- a/examples/config/config_video_test.py
+++ b/examples/config/config_video_test.py
@@ -7,136 +7,136 @@
of the config and input video to an output directory.
"""
-import cv2
-import torch
-import numpy as np
-from PIL import Image
import argparse
-from pathlib import Path
-import sys
-import time
+import json
import os
import shutil
-import json
-from collections import deque
+import sys
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
def tensor_to_opencv(tensor: torch.Tensor, target_width: int, target_height: int) -> np.ndarray:
"""
Convert a PyTorch tensor (output_type='pt') to OpenCV BGR format for video writing.
Uses efficient tensor operations similar to the realtime-img2img demo.
-
+
Args:
tensor: Tensor in range [0,1] with shape [B, C, H, W] or [C, H, W]
target_width: Target width for output
target_height: Target height for output
-
+
Returns:
BGR numpy array ready for OpenCV
"""
# Handle batch dimension - take first image if batched
if tensor.dim() == 4:
tensor = tensor[0]
-
+
# Convert to uint8 format (0-255) and ensure correct shape (C, H, W)
tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8)
-
+
# Convert from [C, H, W] to [H, W, C] format
if tensor_uint8.dim() == 3:
image_np = tensor_uint8.permute(1, 2, 0).cpu().numpy()
else:
raise ValueError(f"tensor_to_opencv: Unexpected tensor shape: {tensor_uint8.shape}")
-
+
# Convert RGB to BGR for OpenCV
image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
-
+
# Resize if needed
if image_bgr.shape[:2] != (target_height, target_width):
image_bgr = cv2.resize(image_bgr, (target_width, target_height))
-
+
return image_bgr
def process_video(config_path, input_video, output_dir, engine_only=False):
"""Process video through ControlNet pipeline"""
print(f"process_video: Loading config from {config_path}")
-
+
# Import here to avoid loading at module level
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
- from streamdiffusion import load_config, create_wrapper_from_config
-
+ from streamdiffusion import create_wrapper_from_config, load_config
+
# Load configuration
config = load_config(config_path)
-
+
# Force tensor output for better performance
- config['output_type'] = 'pt'
-
+ config["output_type"] = "pt"
+
# Get width and height from config (with defaults)
- width = config.get('width', 512)
- height = config.get('height', 512)
-
+ width = config.get("width", 512)
+ height = config.get("height", 512)
+
print(f"process_video: Using dimensions: {width}x{height}")
- print(f"process_video: Using output_type='pt' for better performance")
-
+ print("process_video: Using output_type='pt' for better performance")
+
# Create output directory
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
-
+
# Copy config and input video to output directory
config_copy_path = output_dir / f"config_{Path(config_path).name}"
shutil.copy2(config_path, config_copy_path)
print(f"process_video: Copied config to {config_copy_path}")
-
+
input_copy_path = output_dir / f"input_{Path(input_video).name}"
shutil.copy2(input_video, input_copy_path)
print(f"process_video: Copied input video to {input_copy_path}")
-
+
# Create wrapper using the built-in function (width/height from config)
wrapper = create_wrapper_from_config(config)
-
+
if engine_only:
print("Engine-only mode: TensorRT engines have been built (if needed). Exiting.")
return None
-
+
# Open input video
cap = cv2.VideoCapture(str(input_video))
if not cap.isOpened():
raise ValueError(f"Could not open input video: {input_video}")
-
+
# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
+
print(f"process_video: Input video - {frame_count} frames at {fps} FPS")
-
+
# Setup output video writer
output_video_path = output_dir / "output_video.mp4"
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (width + width, height))
-
+
# Performance tracking
frame_times = []
total_start_time = time.time()
-
+
print("process_video: Starting video processing...")
-
+
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
-
+
frame_start_time = time.time()
-
+
# Resize frame
frame_resized = cv2.resize(frame, (width, height))
-
+
# Convert frame to PIL
frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
frame_pil = Image.fromarray(frame_rgb)
-
+
# Update control image for all configured ControlNets
- if hasattr(wrapper.stream, '_controlnet_module') and wrapper.stream._controlnet_module:
+ if hasattr(wrapper.stream, "_controlnet_module") and wrapper.stream._controlnet_module:
controlnet_count = len(wrapper.stream._controlnet_module.controlnets)
print(f"process_video: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}")
for i in range(controlnet_count):
@@ -144,36 +144,35 @@ def process_video(config_path, input_video, output_dir, engine_only=False):
else:
print(f"process_video: No ControlNet module found for frame {frame_idx}")
output_tensor = wrapper(frame_pil)
-
+
# Convert tensor output to OpenCV BGR format
output_bgr = tensor_to_opencv(output_tensor, width, height)
-
+
# Create side-by-side display
combined = np.hstack([frame_resized, output_bgr])
-
+
# Add labels
cv2.putText(combined, "Input", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
- cv2.putText(combined, "Generated", (width + 10, 30),
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
-
+ cv2.putText(combined, "Generated", (width + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+
# Write frame
out.write(combined)
-
+
# Track performance
frame_time = time.time() - frame_start_time
frame_times.append(frame_time)
-
+
frame_idx += 1
if frame_idx % 10 == 0:
avg_fps = len(frame_times) / sum(frame_times) if frame_times else 0
print(f"process_video: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})")
-
+
total_time = time.time() - total_start_time
-
+
# Cleanup
cap.release()
out.release()
-
+
# Calculate performance metrics
if frame_times:
avg_frame_time = sum(frame_times) / len(frame_times)
@@ -184,7 +183,7 @@ def process_video(config_path, input_video, output_dir, engine_only=False):
min_fps = 1.0 / max_frame_time
else:
avg_frame_time = avg_fps = min_frame_time = max_frame_time = max_fps = min_fps = 0
-
+
# Performance metrics
metrics = {
"input_video": str(input_video),
@@ -199,72 +198,77 @@ def process_video(config_path, input_video, output_dir, engine_only=False):
"avg_frame_time_seconds": avg_frame_time,
"min_frame_time_seconds": min_frame_time,
"max_frame_time_seconds": max_frame_time,
- "model_id": config['model_id'],
- "acceleration": config.get('acceleration', 'none'),
- "frame_buffer_size": config.get('frame_buffer_size', 1),
- "num_inference_steps": config.get('num_inference_steps', 50),
- "guidance_scale": config.get('guidance_scale', 1.1),
+ "model_id": config["model_id"],
+ "acceleration": config.get("acceleration", "none"),
+ "frame_buffer_size": config.get("frame_buffer_size", 1),
+ "num_inference_steps": config.get("num_inference_steps", 50),
+ "guidance_scale": config.get("guidance_scale", 1.1),
"output_type": "pt",
- "controlnets": [cn['model_id'] for cn in config.get('controlnets', [])],
+ "controlnets": [cn["model_id"] for cn in config.get("controlnets", [])],
"test_type": "controlnet_video_test",
- "description": "ControlNet video processing using tensor output for performance"
+ "description": "ControlNet video processing using tensor output for performance",
}
-
+
# Save metrics
metrics_path = output_dir / "performance_metrics.json"
- with open(metrics_path, 'w') as f:
+ with open(metrics_path, "w") as f:
json.dump(metrics, f, indent=2)
-
- print(f"process_video: Processing completed!")
+
+ print("process_video: Processing completed!")
print(f"process_video: Output video saved to: {output_video_path}")
print(f"process_video: Performance metrics saved to: {metrics_path}")
print(f"process_video: Average FPS: {avg_fps:.2f}")
print(f"process_video: Total time: {total_time:.2f} seconds")
-
+
return metrics
+
def main():
parser = argparse.ArgumentParser(description="ControlNet Video Test Demo")
-
+
# Get the script directory to make paths relative to it
script_dir = Path(__file__).parent
default_config = script_dir.parent.parent / "configs" / "controlnet_examples" / "multi_controlnet_example.yaml"
-
- parser.add_argument("--config", type=str, required=True,
- help="Path to ControlNet configuration file")
- parser.add_argument("--input-video", type=str, required=True,
- help="Path to input video file")
- parser.add_argument("--output-dir", type=str, default="output",
- help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.")
- parser.add_argument("--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)")
-
+
+ parser.add_argument("--config", type=str, required=True, help="Path to ControlNet configuration file")
+ parser.add_argument("--input-video", type=str, required=True, help="Path to input video file")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="output",
+ help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.",
+ )
+ parser.add_argument(
+ "--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)"
+ )
+
args = parser.parse_args()
-
+
# Create timestamped subdirectory within the specified parent directory
timestamp = time.strftime("%Y%m%d_%H%M%S")
input_name = Path(args.input_video).stem
config_name = Path(args.config).stem
subdir_name = f"controlnet_test_{config_name}_{input_name}_{timestamp}"
-
+
# Combine parent directory with generated subdirectory name
final_output_dir = Path(args.output_dir) / subdir_name
args.output_dir = str(final_output_dir)
print(f"main: Using output directory: {args.output_dir}")
-
+
# Validate input files
if not Path(args.config).exists():
print(f"main: Error - Config file not found: {args.config}")
return 1
-
+
if not Path(args.input_video).exists():
print(f"main: Error - Input video not found: {args.input_video}")
return 1
-
+
print("ControlNet Video Test Demo")
print(f"main: Config: {args.config}")
print(f"main: Input video: {args.input_video}")
print(f"main: Output directory: {args.output_dir}")
-
+
try:
metrics = process_video(args.config, args.input_video, args.output_dir, engine_only=args.engine_only)
if args.engine_only:
@@ -274,10 +278,11 @@ def main():
return 0
except Exception as e:
import traceback
+
print(f"main: Error during processing: {e}")
print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}")
return 1
if __name__ == "__main__":
- exit(main())
\ No newline at end of file
+ exit(main())
diff --git a/examples/img2img/multi.py b/examples/img2img/multi.py
index 8912d1439..af112585b 100644
--- a/examples/img2img/multi.py
+++ b/examples/img2img/multi.py
@@ -1,7 +1,7 @@
import glob
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
import fire
@@ -10,6 +10,7 @@
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
diff --git a/examples/img2img/single.py b/examples/img2img/single.py
index 4be8a2185..b894bc031 100644
--- a/examples/img2img/single.py
+++ b/examples/img2img/single.py
@@ -1,6 +1,6 @@
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
import fire
@@ -9,6 +9,7 @@
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py
index 791d88b19..e154dfcc1 100644
--- a/examples/optimal-performance/multi.py
+++ b/examples/optimal-performance/multi.py
@@ -3,7 +3,7 @@
import threading
import time
import tkinter as tk
-from multiprocessing import Process, Queue, get_context
+from multiprocessing import Queue, get_context
from typing import List, Literal
import fire
@@ -98,9 +98,7 @@ def image_generation_process(
return
-def _receive_images(
- queue: Queue, fps_queue: Queue, labels: List[tk.Label], fps_label: tk.Label
-) -> None:
+def _receive_images(queue: Queue, fps_queue: Queue, labels: List[tk.Label], fps_label: tk.Label) -> None:
"""
Continuously receive images from a queue and update the labels.
@@ -120,9 +118,7 @@ def _receive_images(
if not queue.empty():
[
labels[0].after(0, update_image, image_data, labels)
- for image_data in postprocess_image(
- queue.get(block=False), output_type="pil"
- )
+ for image_data in postprocess_image(queue.get(block=False), output_type="pil")
]
if not fps_queue.empty():
fps_label.config(text=f"FPS: {fps_queue.get(block=False):.2f}")
@@ -153,9 +149,7 @@ def receive_images(queue: Queue, fps_queue: Queue) -> None:
fps_label = tk.Label(root, text="FPS: 0")
fps_label.grid(rows=2, columnspan=2)
- thread = threading.Thread(
- target=_receive_images, args=(queue, fps_queue, labels, fps_label), daemon=True
- )
+ thread = threading.Thread(target=_receive_images, args=(queue, fps_queue, labels, fps_label), daemon=True)
thread.start()
try:
@@ -173,7 +167,7 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
- ctx = get_context('spawn')
+ ctx = get_context("spawn")
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
@@ -188,5 +182,6 @@ def main(
process1.join()
process2.join()
+
if __name__ == "__main__":
fire.Fire(main)
diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py
index a8020bb86..ec610de72 100644
--- a/examples/optimal-performance/single.py
+++ b/examples/optimal-performance/single.py
@@ -1,15 +1,17 @@
import os
import sys
import time
-from multiprocessing import Process, Queue, get_context
+from multiprocessing import Queue, get_context
from typing import Literal
import fire
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
-from utils.viewer import receive_images
from streamdiffusion import StreamDiffusionWrapper
+from utils.viewer import receive_images
+
def image_generation_process(
queue: Queue,
@@ -63,6 +65,7 @@ def image_generation_process(
print(f"fps: {fps}")
return
+
def main(
prompt: str = "cat with sunglasses and a hat, photoreal, 8K",
model_id_or_path: str = "stabilityai/sd-turbo",
@@ -71,7 +74,7 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
- ctx = get_context('spawn')
+ ctx = get_context("spawn")
queue = ctx.Queue()
fps_queue = ctx.Queue()
process1 = ctx.Process(
@@ -86,5 +89,6 @@ def main(
process1.join()
process2.join()
+
if __name__ == "__main__":
fire.Fire(main)
diff --git a/examples/screen/main.py b/examples/screen/main.py
index 16042ed62..405ff6eec 100644
--- a/examples/screen/main.py
+++ b/examples/screen/main.py
@@ -1,26 +1,31 @@
import os
import sys
-import time
import threading
-from multiprocessing import Process, Queue, get_context
+import time
+import tkinter as tk
+from multiprocessing import Queue, get_context
from multiprocessing.connection import Connection
-from typing import List, Literal, Dict, Optional
-import torch
+from typing import Dict, Literal, Optional
+
+import fire
+import mss
import PIL.Image
+import torch
+
from streamdiffusion.image_utils import pil2tensor
-import mss
-import fire
-import tkinter as tk
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
-from utils.viewer import receive_images
from streamdiffusion import StreamDiffusionWrapper
+from utils.viewer import receive_images
+
inputs = []
top = 0
left = 0
+
def screen(
event: threading.Event,
height: int = 512,
@@ -37,10 +42,12 @@ def screen(
img = PIL.Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX")
img.resize((height, width))
inputs.append(pil2tensor(img))
- print('exit : screen')
+ print("exit : screen")
+
+
def dummy_screen(
- width: int,
- height: int,
+ width: int,
+ height: int,
):
root = tk.Tk()
root.title("Press Enter to start")
@@ -48,17 +55,22 @@ def dummy_screen(
root.resizable(False, False)
root.attributes("-alpha", 0.8)
root.configure(bg="black")
+
def destroy(event):
root.destroy()
+
root.bind("", destroy)
+
def update_geometry(event):
global top, left
top = root.winfo_y()
left = root.winfo_x()
+
root.bind("", update_geometry)
root.mainloop()
return {"top": top, "left": left, "width": width, "height": height}
+
def monitor_setting_process(
width: int,
height: int,
@@ -67,6 +79,7 @@ def monitor_setting_process(
monitor = dummy_screen(width, height)
monitor_sender.send(monitor)
+
def image_generation_process(
queue: Queue,
fps_queue: Queue,
@@ -88,7 +101,7 @@ def image_generation_process(
enable_similar_image_filter: bool,
similar_image_filter_threshold: float,
similar_image_filter_max_skip_frame: float,
- monitor_receiver : Connection,
+ monitor_receiver: Connection,
) -> None:
"""
Process for generating images based on a prompt using a specified model.
@@ -179,7 +192,7 @@ def image_generation_process(
while True:
try:
- if not close_queue.empty(): # closing check
+ if not close_queue.empty(): # closing check
break
if len(inputs) < frame_buffer_size:
time.sleep(0.005)
@@ -191,9 +204,7 @@ def image_generation_process(
sampled_inputs.append(inputs[len(inputs) - index - 1])
input_batch = torch.cat(sampled_inputs)
inputs.clear()
- output_images = stream.stream(
- input_batch.to(device=stream.device, dtype=stream.dtype)
- ).cpu()
+ output_images = stream.stream(input_batch.to(device=stream.device, dtype=stream.dtype)).cpu()
if frame_buffer_size == 1:
output_images = [output_images]
for output_image in output_images:
@@ -205,10 +216,11 @@ def image_generation_process(
break
print("closing image_generation_process...")
- event.set() # stop capture thread
+ event.set() # stop capture thread
input_screen.join()
print(f"fps: {fps}")
+
def main(
model_id_or_path: str = "KBlueLeaf/kohaku-v2.1",
lora_dict: Optional[Dict[str, float]] = None,
@@ -231,7 +243,7 @@ def main(
"""
Main function to start the image generation and viewer processes.
"""
- ctx = get_context('spawn')
+ ctx = get_context("spawn")
queue = ctx.Queue()
fps_queue = ctx.Queue()
close_queue = Queue()
@@ -262,7 +274,7 @@ def main(
similar_image_filter_threshold,
similar_image_filter_max_skip_frame,
monitor_receiver,
- ),
+ ),
)
process1.start()
@@ -272,7 +284,7 @@ def main(
width,
height,
monitor_sender,
- ),
+ ),
)
monitor_process.start()
monitor_process.join()
@@ -285,10 +297,10 @@ def main(
print("process2 terminated.")
close_queue.put(True)
print("process1 terminating...")
- process1.join(5) # with timeout
+ process1.join(5) # with timeout
if process1.is_alive():
print("process1 still alive. force killing...")
- process1.terminate() # force kill...
+ process1.terminate() # force kill...
process1.join()
print("process1 terminated.")
diff --git a/examples/txt2img/multi.py b/examples/txt2img/multi.py
index 0e50e36f8..1d3301966 100644
--- a/examples/txt2img/multi.py
+++ b/examples/txt2img/multi.py
@@ -1,18 +1,26 @@
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
import fire
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
def main(
- output: str = os.path.join(CURRENT_DIR, "..", "..", "images", "outputs",),
+ output: str = os.path.join(
+ CURRENT_DIR,
+ "..",
+ "..",
+ "images",
+ "outputs",
+ ),
model_id_or_path: str = "KBlueLeaf/kohaku-v2.1",
lora_dict: Optional[Dict[str, float]] = None,
prompt: str = "1girl with brown dog hair, thick glasses, smiling",
@@ -22,7 +30,6 @@ def main(
acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
seed: int = 2,
):
-
"""
Process for generating images based on a prompt using a specified model.
diff --git a/examples/txt2img/single.py b/examples/txt2img/single.py
index 80f5ed238..f3c8763aa 100644
--- a/examples/txt2img/single.py
+++ b/examples/txt2img/single.py
@@ -1,6 +1,6 @@
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
import fire
@@ -9,6 +9,7 @@
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -23,7 +24,6 @@ def main(
use_denoising_batch: bool = False,
seed: int = 2,
):
-
"""
Process for generating images based on a prompt using a specified model.
diff --git a/examples/vid2vid/main.py b/examples/vid2vid/main.py
index c4860d64a..a045b29ad 100644
--- a/examples/vid2vid/main.py
+++ b/examples/vid2vid/main.py
@@ -1,16 +1,18 @@
import os
import sys
-from typing import Literal, Dict, Optional
+from typing import Dict, Literal, Optional
import fire
import torch
from torchvision.io import read_video, write_video
from tqdm import tqdm
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from streamdiffusion import StreamDiffusionWrapper
+
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -26,7 +28,6 @@ def main(
enable_similar_image_filter: bool = True,
seed: int = 2,
):
-
"""
Process for generating images based on a prompt using a specified model.
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 000000000..90f6c5a19
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,4 @@
+# Development + quality-harness dependencies (not needed for production inference)
+lpips>=0.1.4
+scikit-image>=0.21
+PyYAML>=6.0
diff --git a/src/streamdiffusion/__init__.py b/src/streamdiffusion/__init__.py
index 88bfdb895..17b226ce3 100644
--- a/src/streamdiffusion/__init__.py
+++ b/src/streamdiffusion/__init__.py
@@ -1,3 +1,4 @@
+from . import _compat # noqa: F401 β applies kvo_cache patch before any diffusers import
from .config import create_wrapper_from_config, load_config, save_config
from .pipeline import StreamDiffusion
from .preprocessing.processors import list_preprocessors
diff --git a/src/streamdiffusion/_compat/__init__.py b/src/streamdiffusion/_compat/__init__.py
new file mode 100644
index 000000000..23f44fa02
--- /dev/null
+++ b/src/streamdiffusion/_compat/__init__.py
@@ -0,0 +1,4 @@
+from .diffusers_kvo_patch import apply as _apply_kvo_patch
+
+
+_apply_kvo_patch()
diff --git a/src/streamdiffusion/_compat/cuda_ipc/VENDORED_VERSION.txt b/src/streamdiffusion/_compat/cuda_ipc/VENDORED_VERSION.txt
new file mode 100644
index 000000000..6a5bc0258
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/VENDORED_VERSION.txt
@@ -0,0 +1,17 @@
+version: 1.4.1
+source: F:\RD_PROJECTS\COMPONENTS\cuda-link
+head_commit: 92989fc
+vendored: 2026-05-17
+files_vendored:
+ __init__.py
+ _nvtx.py
+ activation_barrier.py
+ cuda_graphs.py
+ cuda_ipc_exporter.py
+ cuda_ipc_importer.py
+ cuda_ipc_wrapper.py
+ cuda_runtime_types.py
+ nvml_observer.py
+ shm_protocol.py
+
+To re-sync: copy src/cuda_link/ from the source repo over this directory, then update head_commit and vendored date above.
diff --git a/src/streamdiffusion/_compat/cuda_ipc/__init__.py b/src/streamdiffusion/_compat/cuda_ipc/__init__.py
new file mode 100644
index 000000000..b7f55e78b
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/__init__.py
@@ -0,0 +1,42 @@
+"""
+cuda-link - Zero-copy GPU texture sharing between processes via CUDA IPC.
+
+This package links TouchDesigner and Python processes using CUDA Inter-Process
+Communication for zero-copy GPU texture transfer. Supports PyTorch (GPU tensors),
+CuPy (GPU arrays), and NumPy (CPU arrays) output modes.
+"""
+
+from .cuda_ipc_exporter import CUDAIPCExporter
+from .cuda_ipc_importer import CUPY_AVAILABLE, NUMPY_AVAILABLE, TORCH_AVAILABLE, CUDAIPCImporter
+from .cuda_ipc_wrapper import CUDARuntimeAPI, get_cuda_runtime
+from .nvml_observer import NVML_AVAILABLE, NVMLObserver
+from .shm_protocol import (
+ AcquireResult,
+ DtypeCodec,
+ Metadata,
+ SHMLayout,
+ SlotState,
+ acquire_slot,
+ publish_frame,
+)
+
+
+__version__ = "1.4.1"
+__all__ = [
+ "CUDAIPCExporter",
+ "CUDAIPCImporter",
+ "CUDARuntimeAPI",
+ "get_cuda_runtime",
+ "CUPY_AVAILABLE",
+ "NUMPY_AVAILABLE",
+ "TORCH_AVAILABLE",
+ "NVML_AVAILABLE",
+ "NVMLObserver",
+ "AcquireResult",
+ "DtypeCodec",
+ "Metadata",
+ "SHMLayout",
+ "SlotState",
+ "acquire_slot",
+ "publish_frame",
+]
diff --git a/src/streamdiffusion/_compat/cuda_ipc/_nvtx.py b/src/streamdiffusion/_compat/cuda_ipc/_nvtx.py
new file mode 100644
index 000000000..30a3c3509
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/_nvtx.py
@@ -0,0 +1,88 @@
+"""NVTX annotation shim for cuda-link profiling.
+
+Enabled via environment variables (read once at import, zero-cost when off):
+ CUDALINK_NVTX=1 β top-level phase ranges on the GPU timeline
+ CUDALINK_NVTX_VERBOSE=1 β sub-operation ranges (implies CUDALINK_NVTX=1)
+
+Requires the `nvtx` PyPI package when enabled: pip install nvtx
+
+Usage:
+ from cuda_link import _nvtx
+
+ _nvtx.push_range("cudalink.exporter.export_frame", "green")
+ try:
+ ...gpu work...
+ finally:
+ _nvtx.pop_range()
+
+ # or as a context manager for sub-ranges:
+ with _nvtx.annotate("cudalink.exporter.memcpy", "green"):
+ cuda.memcpy_async(...)
+"""
+
+from __future__ import annotations
+
+import os
+
+
+_VERBOSE = os.environ.get("CUDALINK_NVTX_VERBOSE", "0") == "1"
+_ENABLED = _VERBOSE or os.environ.get("CUDALINK_NVTX", "0") == "1"
+
+if _ENABLED:
+ try:
+ import nvtx as _lib
+
+ _AVAILABLE = True
+ except ImportError:
+ _lib = None
+ _AVAILABLE = False
+else:
+ _lib = None
+ _AVAILABLE = False
+
+
+class _Noop:
+ __slots__ = ()
+
+ def __enter__(self) -> _Noop:
+ return self
+
+ def __exit__(self, *_: object) -> None:
+ pass
+
+
+_NOOP = _Noop()
+
+
+def annotate(message: str, color: str = "white") -> _Noop:
+ """Context manager for a named NVTX range. No-op if NVTX is disabled."""
+ if _AVAILABLE:
+ return _lib.annotate(message, color=color) # type: ignore[union-attr]
+ return _NOOP
+
+
+def verbose_range(message: str, color: str = "white") -> _Noop:
+ """Context manager for a sub-operation range. Only active when CUDALINK_NVTX_VERBOSE=1."""
+ if _AVAILABLE and _VERBOSE:
+ return _lib.annotate(message, color=color) # type: ignore[union-attr]
+ return _NOOP
+
+
+def push_range(message: str, color: str = "white") -> None:
+ """Push a named NVTX range onto the thread-local stack."""
+ if _AVAILABLE:
+ _lib.push_range(message, color=color) # type: ignore[union-attr]
+
+
+def pop_range() -> None:
+ """Pop the innermost NVTX range from the thread-local stack."""
+ if _AVAILABLE:
+ _lib.pop_range() # type: ignore[union-attr]
+
+
+def is_enabled() -> bool:
+ return _AVAILABLE
+
+
+def is_verbose() -> bool:
+ return _AVAILABLE and _VERBOSE
diff --git a/src/streamdiffusion/_compat/cuda_ipc/activation_barrier.py b/src/streamdiffusion/_compat/cuda_ipc/activation_barrier.py
new file mode 100644
index 000000000..bd1fea403
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/activation_barrier.py
@@ -0,0 +1,106 @@
+"""Cross-process SHM activation barrier for cuda-link.
+
+Coordinates Python producer <-> TD-side Sender activation windows.
+When a Sender is initializing, it increments active_count; the producer
+skips export_frame while non-zero (best-effort, no OS atomics needed β
+the 5 s stale-timeout recovers from any stuck state).
+
+Segment layout (64 bytes, little-endian):
+ Offset Size Field Description
+ ------ ---- ----- -----------
+ 0 4 magic 0xCDA1BAAA β guards against alien segments
+ 4 4 version 1 β bumped if layout changes
+ 8 4 active_count Number of Senders inside an activation window
+ 12 4 _pad Align last_change_ns to 8 bytes
+ 16 8 last_change_ns time.monotonic_ns() of most recent write
+ 24 4 barrier_skips Producer-incremented skip-frame counter
+ 28 4 last_writer_pid Diagnostic: PID of last active_count writer
+ 32 32 reserved Zero-filled; reserved for future fields
+"""
+
+from __future__ import annotations
+
+import struct
+import time
+from multiprocessing.shared_memory import SharedMemory
+
+
+SHM_NAME = "cudalink_activation_barrier"
+SHM_SIZE = 64
+MAGIC = 0xCDA1BAAA
+VERSION = 1
+
+# Struct: magic(u32) version(u32) active_count(u32) pad(u32) last_change_ns(u64)
+# barrier_skips(u32) last_writer_pid(u32) reserved(32s)
+_STRUCT = struct.Struct(" SharedMemory:
+ """Open the existing segment or create and initialise it.
+
+ Args:
+ create: When True, create the segment on FileNotFoundError and write
+ the magic/version header. When False, raise FileNotFoundError
+ if the segment does not yet exist.
+
+ Returns:
+ Open SharedMemory handle (caller must close when done).
+ """
+ try:
+ return SharedMemory(name=SHM_NAME)
+ except FileNotFoundError:
+ if not create:
+ raise
+ shm = SharedMemory(name=SHM_NAME, create=True, size=SHM_SIZE)
+ _STRUCT.pack_into(shm.buf, 0, MAGIC, VERSION, 0, 0, 0, 0, 0, b"\x00" * 32)
+ return shm
+
+
+def read_state(shm: SharedMemory) -> tuple[int, int, int]:
+ """Return (active_count, last_change_ns, barrier_skips).
+
+ Snapshot-reads the full 64-byte segment to avoid tearing.
+ """
+ fields = _STRUCT.unpack(bytes(shm.buf[:SHM_SIZE]))
+ # (magic, version, active_count, pad, last_change_ns, barrier_skips, pid, reserved)
+ return fields[2], fields[4], fields[5]
+
+
+def increment(shm: SharedMemory, pid: int) -> int:
+ """Increment active_count, refresh last_change_ns and last_writer_pid.
+
+ Best-effort: no OS-level atomic. Race window is microseconds; the
+ producer-side stale-timeout absorbs any stuck state.
+
+ Returns:
+ New active_count value.
+ """
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[2] += 1 # active_count
+ fields[4] = time.monotonic_ns() # last_change_ns
+ fields[6] = pid # last_writer_pid
+ _STRUCT.pack_into(shm.buf, 0, *fields)
+ return fields[2]
+
+
+def decrement(shm: SharedMemory, pid: int) -> int:
+ """Decrement active_count (clamps at zero), refresh timestamps.
+
+ Returns:
+ New active_count value.
+ """
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[2] = max(0, fields[2] - 1) # active_count, no underflow
+ fields[4] = time.monotonic_ns() # last_change_ns
+ fields[6] = pid # last_writer_pid
+ _STRUCT.pack_into(shm.buf, 0, *fields)
+ return fields[2]
+
+
+def bump_skip(shm: SharedMemory) -> None:
+ """Increment barrier_skips counter (producer-only diagnostic)."""
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[5] += 1 # barrier_skips
+ _STRUCT.pack_into(shm.buf, 0, *fields)
diff --git a/src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py b/src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py
new file mode 100644
index 000000000..454419579
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/cuda_graphs.py
@@ -0,0 +1,314 @@
+"""
+CUDA Graphs Mixin β CUDA Graph capture, instantiation, launch, and node-update methods.
+
+Provides CUDAGraphsMixin, mixed into CUDARuntimeAPI to contribute the graph-lifecycle
+API. All methods rely on self.cudart (the cudart DLL handle) and self.check_error from
+the host class.
+
+Shared between the pip package (cuda_link) and TouchDesigner textDATs.
+Compatible with both Python package and TD COMP namespace imports.
+"""
+
+from __future__ import annotations
+
+import ctypes
+from ctypes import byref, c_int, c_size_t, c_void_p
+
+from .cuda_runtime_types import ( # noqa: E402
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaExtent,
+ cudaMemcpy3DParms,
+ cudaPitchedPtr,
+ cudaPos,
+)
+
+
+class CUDAGraphsMixin:
+ """Mixin contributing CUDA Graph lifecycle methods to CUDARuntimeAPI.
+
+ Requires self.cudart (cudart DLL handle) and self.check_error from the host class.
+ """
+
+ # --- Phase 2: CUDA Graph API wrappers ---
+
+ def stream_begin_capture(self, stream: CUDAStream_t, mode: int = 0) -> None:
+ """Begin capturing a stream into a CUDA graph.
+
+ After this call, operations enqueued on *stream* are recorded into a
+ graph rather than executed immediately. End with stream_end_capture().
+
+ Args:
+ stream: Stream to capture.
+ mode: cudaStreamCaptureMode β 0=global, 1=thread_local, 2=relaxed.
+ In multi-threaded/multi-engine processes prefer 1 (ThreadLocal);
+ 0 (Global) is only "safest" for single-threaded programs.
+
+ Raises:
+ RuntimeError: If capture start fails (e.g., stream already capturing).
+ """
+ result = self.cudart.cudaStreamBeginCapture(stream, c_int(mode))
+ self.check_error(result, "cudaStreamBeginCapture")
+
+ def stream_end_capture(self, stream: CUDAStream_t) -> CUDAGraph_t:
+ """End stream capture and return the captured graph.
+
+ After this call the stream resumes normal execution mode. The returned
+ graph must be instantiated with graph_instantiate() before use, and
+ destroyed with graph_destroy() when done.
+
+ Args:
+ stream: Stream that was passed to stream_begin_capture().
+
+ Returns:
+ CUDAGraph_t handle to the captured graph template.
+
+ Raises:
+ RuntimeError: If capture end fails.
+ """
+ graph = CUDAGraph_t()
+ result = self.cudart.cudaStreamEndCapture(stream, byref(graph))
+ self.check_error(result, "cudaStreamEndCapture")
+ return graph
+
+ def graph_instantiate(self, graph: CUDAGraph_t, flags: int = 0) -> CUDAGraphExec_t:
+ """Instantiate a graph template into an executable graph.
+
+ The executable graph (CUDAGraphExec_t) can be launched repeatedly via
+ graph_launch(). The template graph can be destroyed after instantiation.
+
+ Args:
+ graph: CUDAGraph_t template returned by stream_end_capture().
+ flags: cudaGraphInstantiateFlagDeviceLaunch (0x02) for device-side
+ launch; 0 for normal host-side launch.
+
+ Returns:
+ CUDAGraphExec_t executable graph handle.
+
+ Raises:
+ RuntimeError: If instantiation fails.
+ """
+ from ctypes import c_uint64
+
+ graph_exec = CUDAGraphExec_t()
+ result = self.cudart.cudaGraphInstantiateWithFlags(byref(graph_exec), graph, c_uint64(flags))
+ self.check_error(result, "cudaGraphInstantiateWithFlags")
+ return graph_exec
+
+ def graph_launch(self, graph_exec: CUDAGraphExec_t, stream: CUDAStream_t) -> None:
+ """Launch an executable graph on a stream (single WDDM submission).
+
+ This replaces N individual API calls (stream_wait_event, memcpy_async,
+ record_event) with one batched WDDM submission, reducing kernel-mode
+ transition overhead from NΓ~15Β΅s to ~15Β΅s on Windows WDDM.
+
+ Args:
+ graph_exec: Executable graph from graph_instantiate().
+ stream: Stream on which to launch the graph.
+
+ Raises:
+ RuntimeError: If launch fails.
+ """
+ result = self.cudart.cudaGraphLaunch(graph_exec, stream)
+ self.check_error(result, "cudaGraphLaunch")
+
+ def graph_get_nodes(self, graph: CUDAGraph_t) -> list[CUDAGraphNode_t]:
+ """Return all nodes in a graph in topological (capture) order.
+
+ Useful for discovering node handles after stream capture, before the
+ template graph is destroyed.
+
+ Args:
+ graph: CUDAGraph_t template (must NOT yet be destroyed).
+
+ Returns:
+ List of CUDAGraphNode_t handles in capture order:
+ [EventWaitNode (if present), MemcpyNode, EventRecordNode].
+
+ Raises:
+ RuntimeError: If query fails.
+ """
+ count = c_size_t(0)
+ result = self.cudart.cudaGraphGetNodes(graph, None, byref(count))
+ self.check_error(result, "cudaGraphGetNodes (count)")
+ node_array = (CUDAGraphNode_t * count.value)()
+ result = self.cudart.cudaGraphGetNodes(graph, node_array, byref(count))
+ self.check_error(result, "cudaGraphGetNodes (fill)")
+ return list(node_array)
+
+ def graph_destroy(self, graph: CUDAGraph_t) -> None:
+ """Destroy a graph template (not the executable β use graph_exec_destroy for that).
+
+ Args:
+ graph: Template graph to destroy.
+
+ Raises:
+ RuntimeError: If destruction fails.
+ """
+ result = self.cudart.cudaGraphDestroy(graph)
+ self.check_error(result, "cudaGraphDestroy")
+
+ def graph_exec_destroy(self, graph_exec: CUDAGraphExec_t) -> None:
+ """Destroy an executable graph and free its resources.
+
+ Args:
+ graph_exec: Executable graph to destroy.
+
+ Raises:
+ RuntimeError: If destruction fails.
+ """
+ result = self.cudart.cudaGraphExecDestroy(graph_exec)
+ self.check_error(result, "cudaGraphExecDestroy")
+
+ @staticmethod
+ def make_memcpy3d_params(dst: c_void_p, src: c_void_p, count: int, kind: int) -> cudaMemcpy3DParms:
+ """Build a cudaMemcpy3DParms struct for a flat 1D memory copy.
+
+ Represents the copy as a single-row 2D memcpy (height=1, depth=1) so
+ that 'count' bytes are transferred from src to dst. This is the required
+ form for cudaGraphExecMemcpyNodeSetParams even when the original copy was
+ issued as cudaMemcpyAsync (1D form).
+
+ Args:
+ dst: Destination pointer.
+ src: Source pointer.
+ count: Number of bytes to copy.
+ kind: cudaMemcpyKind (3 = DeviceToDevice).
+
+ Returns:
+ Populated cudaMemcpy3DParms instance.
+ """
+ params = cudaMemcpy3DParms()
+ params.srcArray = None
+ params.srcPos = cudaPos(0, 0, 0)
+ params.srcPtr = cudaPitchedPtr(
+ ptr=ctypes.cast(src, c_void_p),
+ pitch=count,
+ xsize=count,
+ ysize=1,
+ )
+ params.dstArray = None
+ params.dstPos = cudaPos(0, 0, 0)
+ params.dstPtr = cudaPitchedPtr(
+ ptr=ctypes.cast(dst, c_void_p),
+ pitch=count,
+ xsize=count,
+ ysize=1,
+ )
+ params.extent = cudaExtent(width=count, height=1, depth=1)
+ params.kind = kind
+ return params
+
+ def graph_exec_memcpy_node_set_params(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ dst: c_void_p,
+ src: c_void_p,
+ count: int,
+ kind: int,
+ ) -> None:
+ """Update src/dst/count/kind of a memcpy node in an executable graph.
+
+ This is a CPU-only operation (no WDDM submission). Changes take effect
+ on the next graph_launch() call. The extent (count) must match the
+ extent used when the graph was captured β only pointer reassignment
+ within the same buffer size is supported.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: MemcpyNode handle from graph_get_nodes().
+ dst: New destination pointer.
+ src: New source pointer.
+ count: Copy size in bytes (must match captured size).
+ kind: cudaMemcpyKind (must match captured kind).
+
+ Raises:
+ RuntimeError: If parameter update fails.
+ """
+ params = self.make_memcpy3d_params(dst, src, count, kind)
+ result = self.cudart.cudaGraphExecMemcpyNodeSetParams(graph_exec, node, byref(params))
+ self.check_error(result, "cudaGraphExecMemcpyNodeSetParams")
+
+ def graph_exec_memcpy_node_set_params_1d(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ dst: c_void_p,
+ src: c_void_p,
+ count: int,
+ kind: int,
+ ) -> None:
+ """Update src/dst/count/kind of a 1D memcpy node in an executable graph.
+
+ Use this for nodes captured from cudaMemcpyAsync (1D form). The 3D variant
+ (graph_exec_memcpy_node_set_params) returns INVALID_VALUE on 1D nodes.
+ Requires CUDA 11.3+.
+ """
+ dst_int = dst.value if isinstance(dst, c_void_p) else int(dst)
+ src_int = src.value if isinstance(src, c_void_p) else int(src)
+ result = self.cudart.cudaGraphExecMemcpyNodeSetParams1D(
+ graph_exec,
+ node,
+ c_void_p(dst_int),
+ c_void_p(src_int),
+ c_size_t(count),
+ c_int(kind),
+ )
+ self.check_error(result, "cudaGraphExecMemcpyNodeSetParams1D")
+
+ def graph_exec_event_record_node_set_event(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ event: CUDAEvent_t,
+ ) -> None:
+ """Update the event recorded by an event-record node. CUDA 11.4+.
+
+ CPU-only β takes effect on next graph_launch(). Use this to update the
+ per-ring-slot IPC event when the ring slot changes between launches.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: EventRecordNode handle from graph_get_nodes().
+ event: New CUDAEvent_t to record.
+
+ Raises:
+ RuntimeError: If update fails.
+ """
+ result = self.cudart.cudaGraphExecEventRecordNodeSetEvent(graph_exec, node, event)
+ self.check_error(result, "cudaGraphExecEventRecordNodeSetEvent")
+
+ def graph_exec_event_wait_node_set_event(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ event: CUDAEvent_t,
+ ) -> None:
+ """Update the event waited on by an event-wait node. CUDA 11.4+.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: EventWaitNode handle from graph_get_nodes().
+ event: New CUDAEvent_t to wait on.
+
+ Raises:
+ RuntimeError: If update fails.
+ """
+ result = self.cudart.cudaGraphExecEventWaitNodeSetEvent(graph_exec, node, event)
+ self.check_error(result, "cudaGraphExecEventWaitNodeSetEvent")
+
+ def get_runtime_version(self) -> int:
+ """Return the CUDA runtime version as an int.
+
+ Examples: 11030 = CUDA 11.3, 11040 = CUDA 11.4, 12080 = CUDA 12.8.
+ Used to gate optional API calls when the loaded cudart DLL may be from
+ an older patch level (e.g., TouchDesigner ships ``cudart64_110.dll``).
+ """
+ version = c_int(0)
+ result = self.cudart.cudaRuntimeGetVersion(byref(version))
+ self.check_error(result, "cudaRuntimeGetVersion")
+ return int(version.value)
diff --git a/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_exporter.py b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_exporter.py
new file mode 100644
index 000000000..02a4fdb63
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_exporter.py
@@ -0,0 +1,1164 @@
+"""
+CUDA IPC Exporter for Python Process
+Exports GPU memory FROM Python TO TouchDesigner via CUDA IPC handles.
+
+Enables the reverse direction: Python AI pipeline β TouchDesigner display.
+The TD side receives frames using CUDAIPCExtension in "Receiver" mode.
+
+Usage:
+ from cuda_link import CUDAIPCExporter
+
+ # Export AI-generated frames to TouchDesigner
+ with CUDAIPCExporter(
+ shm_name="ai_output_ipc",
+ height=512, width=512, channels=4, dtype="uint8",
+ ) as exporter:
+ exporter.initialize()
+ while running:
+ # output_tensor: (H, W, 4) uint8 BGRA on GPU
+ exporter.export_frame(
+ gpu_ptr=output_tensor.data_ptr(),
+ size=output_tensor.nelement() * output_tensor.element_size(),
+ )
+
+Architecture:
+ Python GPU tensor --> cudaMemcpy D2D --> Persistent IPC Ring Buffer
+ |
+ IPC Handle in SharedMemory (v0.5.0 protocol)
+ |
+ TouchDesigner (CUDAIPCExtension Receiver) opens handle once
+ --> import_frame(script_top) --> copyCUDAMemory() per frame
+
+Performance:
+ - Initialization: ~1ms (buffer alloc + handle export)
+ - Per-frame: ~177Β΅s at 512x512 @ 60 FPS (includes producer-side stream_synchronize)
+ - Per-frame: ~219Β΅s at 1080p @ 60 FPS (async D2D + stream_synchronize + protocol writes)
+
+Compatibility:
+ - Protocol: v0.5.0 (byte-identical to CUDAIPCExtension/CUDAIPCImporter)
+ - TD side: CUDAIPCExtension in "Receiver" mode reads the SharedMemory
+ - Platform: Windows only (CUDA IPC limitation)
+"""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import os
+import struct
+import threading
+import time
+import traceback
+from ctypes import c_void_p
+from dataclasses import dataclass, field
+from multiprocessing.shared_memory import SharedMemory
+from typing import TYPE_CHECKING
+
+from . import _nvtx
+from .activation_barrier import bump_skip as _ab_bump
+from .activation_barrier import open_or_create as _ab_open
+from .activation_barrier import read_state as _ab_read
+from .cuda_ipc_wrapper import CUDARuntimeAPI, get_cuda_runtime
+from .cuda_runtime_types import (
+ CUDART_GRAPHS_MIN_VERSION,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+)
+
+
+if TYPE_CHECKING:
+ from .nvml_observer import NVMLObserver
+
+logger = logging.getLogger(__name__)
+
+from .shm_protocol import ( # noqa: E402
+ _ST_U32,
+ MAGIC_OFFSET,
+ METADATA_SIZE,
+ NUM_SLOTS_OFFSET,
+ PROTOCOL_MAGIC,
+ SHM_HEADER_SIZE,
+ SHUTDOWN_FLAG_SIZE,
+ SLOT_SIZE,
+ TIMESTAMP_SIZE,
+ WRITE_IDX_OFFSET,
+ DtypeCodec,
+ Metadata,
+ SHMLayout,
+ bump_version,
+ publish_frame,
+)
+
+
+_DTYPE_ITEMSIZE_MAP = {
+ "float32": 4,
+ "float16": 2,
+ "uint8": 1,
+ "uint16": 2,
+}
+
+# _DTYPE_TO_KIND_BITS is now in shm_protocol.DtypeCodec.encode(); kept as alias for call sites below
+_DTYPE_TO_KIND_BITS = {k: DtypeCodec.encode(k) for k in ("float32", "float16", "uint8", "uint16")}
+
+
+@dataclass
+class ProducerActivationBarrier:
+ """Producer-side activation-barrier state.
+
+ Replaces five scattered attributes on CUDAIPCExporter (_barrier_enabled,
+ _barrier_stale_ns, _barrier_shm, _barrier_skip_log_last_ns,
+ _barrier_stale_log_last_ns) with a single cohesive value object.
+ """
+
+ enabled: bool
+ stale_ns: int
+ shm: SharedMemory | None = None
+ _skip_log_last_ns: int = field(init=False, default=0, repr=False)
+ _stale_log_last_ns: int = field(init=False, default=0, repr=False)
+
+ @classmethod
+ def from_env(cls) -> ProducerActivationBarrier:
+ return cls(
+ enabled=os.getenv("CUDALINK_ACTIVATION_BARRIER", "1") != "0",
+ stale_ns=int(os.getenv("CUDALINK_BARRIER_STALE_NS", str(5 * 1_000_000_000))),
+ )
+
+ def should_skip_publish(self) -> bool:
+ """Hot path: True β caller skips this frame.
+
+ Lazily opens the SHM segment on first call. Applies a stale-timeout so a
+ Sender that crashes mid-init cannot block the producer indefinitely.
+ """
+ if self.shm is None:
+ try:
+ self.shm = _ab_open(create=False)
+ except FileNotFoundError:
+ return False
+ try:
+ active_count, last_change_ns, _ = _ab_read(self.shm)
+ except (OSError, RuntimeError, struct.error):
+ return False
+ if active_count <= 0:
+ return False
+ now_ns = time.monotonic_ns()
+ if now_ns - last_change_ns > self.stale_ns:
+ if now_ns - self._stale_log_last_ns > 1_000_000_000:
+ logger.warning(
+ "[ACTIVATION_BARRIER] stale barrier (count=%d, age=%.1fs) β ignoring",
+ active_count,
+ (now_ns - last_change_ns) / 1e9,
+ )
+ self._stale_log_last_ns = now_ns
+ return False
+ with contextlib.suppress(OSError, RuntimeError, struct.error):
+ _ab_bump(self.shm)
+ if now_ns - self._skip_log_last_ns > 1_000_000_000:
+ logger.info("[ACTIVATION_BARRIER] skipping publish (active_count=%d)", active_count)
+ self._skip_log_last_ns = now_ns
+ return True
+
+ def close(self) -> None:
+ """Idempotent: close SHM handle if held."""
+ if self.shm is not None:
+ with contextlib.suppress(OSError, RuntimeError):
+ self.shm.close()
+ self.shm = None
+
+
+def _read_hws_mode() -> str:
+ """Read WDDM Hardware-Accelerated GPU Scheduling registry key (Windows only).
+
+ Returns "0" (software scheduling), "2" (hardware scheduling enabled),
+ or "unknown" on any error (non-Windows, key absent, permission denied).
+ Emitted as cudalink.startup.hws_mode= NVTX range at exporter init
+ so every nsys capture is self-documenting about the WDDM scheduling mode.
+ """
+ try:
+ import winreg # noqa: PLC0415
+
+ key = winreg.OpenKey(
+ winreg.HKEY_LOCAL_MACHINE,
+ r"SYSTEM\CurrentControlSet\Control\GraphicsDrivers",
+ )
+ value, _ = winreg.QueryValueEx(key, "HwSchMode")
+ winreg.CloseKey(key)
+ return str(value)
+ except Exception: # noqa: BLE001
+ return "unknown"
+
+
+class CUDAIPCExporter:
+ """Python-side exporter for CUDA IPC GPU memory.
+
+ Sends GPU frames FROM Python TO TouchDesigner via CUDA IPC.
+ Pairs with CUDAIPCExtension in "Receiver" mode on the TD side.
+
+ Responsibilities:
+ - Allocate persistent GPU ring buffer (cudaMalloc, 2 MiB aligned)
+ - Export IPC handles + metadata via SharedMemory (v0.5.0 protocol, once at startup)
+ - Per-frame: accept raw GPU pointer, async D2D memcpy to ring slot, record IPC event
+ - 7-step cleanup: shutdown signal β events β stream β SHM close β grace β free β unlink
+
+ Performance:
+ - Initialization: ~1ms (buffer alloc + handle export)
+ - Per-frame overhead: ~177Β΅s at 512x512, ~219Β΅s at 1080p @ 60 FPS (async D2D + stream_synchronize)
+ - Zero CPU memory copies (GPU-direct)
+ """
+
+ def __init__(
+ self,
+ shm_name: str,
+ height: int,
+ width: int,
+ channels: int = 4,
+ dtype: str = "uint8",
+ num_slots: int = 2,
+ debug: bool = False,
+ device: int = 0,
+ ) -> None:
+ """Initialize CUDA IPC exporter.
+
+ Args:
+ shm_name: SharedMemory name. Must match the TD Receiver's Ipcmemname parameter.
+ height: Frame height in pixels.
+ width: Frame width in pixels.
+ channels: Number of channels (default: 4 for BGRA/RGBA).
+ dtype: Data type string: "float32", "float16", or "uint8" (default: "uint8").
+ num_slots: Ring buffer slot count (default: 2 for double-buffering). Range: 1-10.
+ debug: Enable verbose per-frame performance logging.
+ device: CUDA device index to use (default: 0). Sender and receiver must
+ use the same device; IPC handles are device-scoped.
+
+ Raises:
+ ValueError: If dtype is unsupported or num_slots is out of range.
+ """
+ if dtype not in _DTYPE_TO_KIND_BITS:
+ raise ValueError(f"Unsupported dtype: {dtype!r}. Must be one of {list(_DTYPE_TO_KIND_BITS)}")
+ if not (0 < num_slots <= 10):
+ raise ValueError(f"num_slots must be 1-10, got {num_slots}")
+
+ self.shm_name = shm_name
+ self.height = height
+ self.width = width
+ self.channels = channels
+ self.dtype = dtype
+ self.num_slots = num_slots
+ self.debug = debug
+ self.device = device
+
+ # CUDALINK_EXPORT_SYNC: block CPU on ipc_stream after each record_event().
+ # Default on (Phase 3.6 β load-bearing for concurrent topologies; prevents
+ # cycle-2 TDR cascade when a TD Sender shares the process with a TD Receiver).
+ # Set to "0" to opt out for low-latency single-producer scenarios (~100-300Β΅s/frame saved).
+ self._export_sync: bool = os.getenv("CUDALINK_EXPORT_SYNC", "1") != "0"
+
+ # CUDALINK_USE_GRAPHS=1: capture the memcpy_async into a 1-node CUDA Graph and
+ # replay via graph_launch. IPC events (cudaEventInterprocess) and external
+ # stream_wait_event deps cannot be captured, so the graph contains only the D2D
+ # memcpy. Per-frame cost when source_sync not used:
+ # graph_launch (1 WDDM) + record_event (1 WDDM) = 2 submissions vs 3 legacy
+ # When record_source_sync() has been called at least once, stream_wait_event is
+ # issued before graph_launch (3 WDDM β same as legacy).
+ # Requires CUDA 12.x runtime (Python side). On by default; set CUDALINK_USE_GRAPHS=0
+ # to revert to the legacy 3-submission stream path.
+ self._use_graphs: bool = os.getenv("CUDALINK_USE_GRAPHS", "1") == "1"
+ self._graphs_disabled: bool = False # set True if build/launch fails at runtime
+ self._source_sync_recorded: bool = False # set True on first record_source_sync()
+ # One CUDAGraphExec_t + template CUDAGraph_t per ring slot.
+ # Template is kept alive so node handles remain valid for SetParams calls.
+ self._graph_execs: list[CUDAGraphExec_t | None] = [None] * num_slots
+ self._graph_templates: list[CUDAGraph_t | None] = [None] * num_slots
+ self._graph_memcpy_nodes: list[CUDAGraphNode_t | None] = [None] * num_slots
+ # CUDALINK_EXPORT_PROFILE=1: enables fine-grained per-region sub-timers in export_frame.
+ # Mirrors td_exporter/CUDAIPCExtension.py's same knob. Forces debug=True.
+ self._export_profile: bool = os.getenv("CUDALINK_EXPORT_PROFILE", "0") == "1"
+ # CUDALINK_EXPORT_FLUSH_PROBE: calls cudaStreamQuery after check_sticky_error when
+ # _export_sync=False. Forces WDDM-deferred commands to submit without CPU blocking.
+ # Default ON per Phase 3 decision (2026-05-04): ~12 Β΅s/frame cost, collapses
+ # Windows Task Manager 3D-engine reading from ~65% to ~7% on rigs where WDDM
+ # defers submissions. NVML true compute load is unchanged. Set to "0" to disable.
+ self._export_flush_probe: bool = os.getenv("CUDALINK_EXPORT_FLUSH_PROBE", "1") == "1"
+ # CUDALINK_ACTIVATION_BARRIER: read cudalink_activation_barrier SHM on each
+ # export_frame and skip publishing while a TD-side Sender is in its activation window.
+ # Cross-process backpressure mechanism β no CUDA stream coupling.
+ # Default on (Phase 3.6 β no-op when no TD-side Sender exists since the SHM
+ # counter stays at 0; gracefully skipped if SHM is missing). Set to "0" to opt out.
+ self._barrier = ProducerActivationBarrier.from_env()
+ if self._export_profile:
+ self.debug = True # profile mode requires timing path (mirrors TD L248-249)
+
+ # Derived sizes
+ itemsize = _DTYPE_ITEMSIZE_MAP[dtype]
+ self.data_size = height * width * channels * itemsize # Actual data bytes
+
+ # CUDA state
+ self.cuda: CUDARuntimeAPI | None = None
+ self._initialized = False
+ self.ipc_stream = None # Dedicated non-blocking CUDA stream
+ self.source_sync_event = None # Cross-stream sync event (GPU-side, non-blocking CPU)
+
+ # Ring buffer state (arrays sized by num_slots)
+ self.dev_ptrs: list = [None] * num_slots # GPU buffer pointers
+ self.ipc_handles: list = [None] * num_slots # IPC memory handles
+ self.ipc_events: list = [None] * num_slots # IPC events for GPU sync
+ self.ipc_event_handles: list = [None] * num_slots # Exportable event handles
+ self.write_idx: int = 0 # Monotonic frame counter
+
+ # SharedMemory
+ self.shm_handle: SharedMemory | None = None
+ self.buffer_size: int = self.data_size # Will be 2MiB-aligned in initialize()
+
+ # Performance tracking
+ self.frame_count: int = 0
+ self.total_memcpy_us: float = 0.0
+ self.total_export_us: float = 0.0
+ self.total_stream_wait_us: float = 0.0
+ self.total_record_event_us: float = 0.0
+ self.total_shm_write_us: float = 0.0
+ self.total_sync_us: float = 0.0
+ self.total_sticky_check_us: float = 0.0
+ self.total_flush_probe_us: float = 0.0
+
+ # Cached layout + offsets (set by _write_handles_to_shm, constant thereafter)
+ self._layout: SHMLayout = SHMLayout(num_slots)
+ self._ts_offset: int = self._layout.timestamp_offset
+ self._shutdown_offset: int = self._layout.shutdown_offset
+
+ # C2: device-affinity validation
+ # CUDALINK_STRICT_DEVICE=1 raises ValueError on mismatch; default warns+continues.
+ self._strict_device: bool = os.getenv("CUDALINK_STRICT_DEVICE", "0") == "1"
+ self._source_sync_device_warned: bool = False # emit at most one log per instance
+ # Cache of ptr values already validated (capped at 8 β covers typical buffer-rotation)
+ self._ptr_device_cache: set[int] = set()
+
+ # ------------------------------------------------------------------
+ # Initialization
+ # ------------------------------------------------------------------
+
+ def initialize(self) -> bool:
+ """Allocate GPU ring buffer, create IPC handles, write to SharedMemory.
+
+ Must be called before export_frame(). Safe to call multiple times
+ (idempotent β returns True if already initialized).
+
+ Returns:
+ True if initialization succeeded, False on error.
+ """
+ if self._initialized:
+ logger.debug("Already initialized")
+ return True
+
+ try:
+ # Load CUDA runtime bound to the requested device
+ self.cuda = get_cuda_runtime(device=self.device)
+ actual_device = self.cuda.get_device()
+ if actual_device != self.device:
+ raise RuntimeError(
+ f"Device mismatch: requested device {self.device} but CUDA context "
+ f"is bound to device {actual_device}. Ensure no other code calls "
+ "cudaSetDevice() with a different index before initialize()."
+ )
+ logger.info("Loaded CUDA runtime on device %d", actual_device)
+
+ hws_mode = _read_hws_mode()
+ logger.info("WDDM HwSchMode: %s (0=software, 2=hardware/GPU-P, unknown=non-Windows)", hws_mode)
+ with _nvtx.annotate(f"cudalink.startup.hws_mode={hws_mode}", "cyan"):
+ pass
+
+ # Create or reuse dedicated non-blocking IPC stream.
+ # cudaStreamNonBlocking (0x01) prevents the default stream from
+ # implicitly synchronising with this stream. Default is high-priority
+ # so the D2D memcpy preempts lower-priority compute work in the TD context.
+ # CUDALINK_LIB_STREAM_PRIO=normal: drop to default-priority stream.
+ # Use when a TD-side Sender-B coexists with this Python producer in the
+ # same machine and the high-priority stream contends with TD init.
+ if self.ipc_stream is None:
+ lib_stream_high_prio = os.environ.get("CUDALINK_LIB_STREAM_PRIO", "high") != "normal"
+ if lib_stream_high_prio:
+ self.ipc_stream = self.cuda.create_stream_with_priority(flags=0x01)
+ logger.info("Created IPC stream (high-priority): 0x%016x", int(self.ipc_stream.value))
+ else:
+ self.ipc_stream = self.cuda.create_stream(flags=0x01)
+ logger.info("Created IPC stream (normal-priority): 0x%016x", int(self.ipc_stream.value))
+ else:
+ logger.debug("Reusing IPC stream: 0x%016x", int(self.ipc_stream.value))
+
+ # Create cross-stream sync event for GPU-side producer ordering.
+ # ipc_stream is non-blocking and does NOT implicitly sync with any other stream,
+ # so callers MUST either call record_source_sync() or torch.cuda.synchronize()
+ # before export_frame(). The event enables the GPU-side-only path.
+ if self.source_sync_event is None:
+ self.source_sync_event = self.cuda.create_sync_event()
+ logger.info("Created cross-stream source sync event")
+
+ # Apply 2 MiB alignment (NVIDIA requirement: prevents information disclosure)
+ alignment = 2 * 1024 * 1024
+ self.buffer_size = ((self.data_size + alignment - 1) // alignment) * alignment
+ logger.info(
+ "Buffer: %.1f KB data, %.1f KB aligned, %d slots",
+ self.data_size / 1024,
+ self.buffer_size / 1024,
+ self.num_slots,
+ )
+
+ # PHASE 1: Allocate GPU ring buffer + create IPC handles
+ for slot in range(self.num_slots):
+ self.dev_ptrs[slot] = self.cuda.malloc(self.buffer_size)
+ logger.info(
+ "Slot %d: allocated %.1f KB at 0x%016x",
+ slot,
+ self.buffer_size / 1024,
+ self.dev_ptrs[slot].value,
+ )
+
+ # Memory handle (once at startup β reused every frame)
+ self.ipc_handles[slot] = self.cuda.ipc_get_mem_handle(self.dev_ptrs[slot])
+ logger.debug("Slot %d: created IPC mem handle (64 bytes)", slot)
+
+ # Event handle for GPU-side synchronization
+ self.ipc_events[slot] = self.cuda.create_ipc_event()
+ self.ipc_event_handles[slot] = self.cuda.ipc_get_event_handle(self.ipc_events[slot])
+ logger.debug("Slot %d: created IPC event (64 bytes)", slot)
+
+ logger.info("Created %d IPC buffer slots with GPU-side sync", self.num_slots)
+
+ # PHASE 2: Create SharedMemory
+ shm_size = (
+ SHM_HEADER_SIZE + (self.num_slots * SLOT_SIZE) + SHUTDOWN_FLAG_SIZE + METADATA_SIZE + TIMESTAMP_SIZE
+ )
+ try:
+ self.shm_handle = SharedMemory(name=self.shm_name)
+ logger.info("Opened existing SharedMemory: %s", self.shm_name)
+ except FileNotFoundError:
+ self.shm_handle = SharedMemory(name=self.shm_name, create=True, size=shm_size)
+ logger.info("Created SharedMemory: %s (%d bytes)", self.shm_name, shm_size)
+
+ # PHASE 3: Write protocol header + IPC handles + metadata
+ self._write_handles_to_shm()
+ self._write_metadata_to_shm()
+
+ # Cache timestamp offset from the pre-computed layout (set by _write_handles_to_shm)
+ self._ts_offset = self._layout.timestamp_offset
+
+ self._initialized = True
+
+ # CUDA Graphs build (after IPC stream / events / ring buffer are ready).
+ # Gated on cudart >= 11.4 (cudaGraphInstantiateWithFlags + the
+ # EventRecordNodeSetEvent / EventWaitNodeSetEvent APIs all require 11.4+).
+ # cudaGraphInstantiateWithFlags was introduced specifically to avoid the
+ # 5-arg (10.0-11.8) vs 3-arg (12.0+) ABI split in cudaGraphInstantiate.
+ if self._use_graphs:
+ try:
+ rt_version = self.cuda.get_runtime_version()
+ except (RuntimeError, OSError) as exc:
+ rt_version = 0
+ logger.warning("cudaRuntimeGetVersion failed (%s) β disabling graphs", exc)
+ if rt_version >= CUDART_GRAPHS_MIN_VERSION:
+ self._build_export_graphs()
+ else:
+ logger.warning(
+ "CUDALINK_USE_GRAPHS=1 ignored: cudart %d < %d "
+ "(cudaGraphInstantiateWithFlags requires 11.4+). "
+ "Falling back to legacy stream path.",
+ rt_version,
+ CUDART_GRAPHS_MIN_VERSION,
+ )
+ self._graphs_disabled = True
+
+ logger.info("Initialization complete β ready for zero-copy GPU transfer")
+ return True
+
+ except (OSError, RuntimeError, ValueError) as e:
+ logger.error("Initialization failed: %s", e)
+ traceback.print_exc()
+ return False
+
+ # ------------------------------------------------------------------
+ # Protocol write helpers
+ # ------------------------------------------------------------------
+
+ def _write_handles_to_shm(self) -> None:
+ """Write v0.5.0 protocol header + IPC handles to SharedMemory.
+
+ Layout:
+ [0-3] magic (uint32 LE) = 0x43495044
+ [4-11] version (uint64 LE) β incremented on each init
+ [12-15] num_slots (uint32 LE)
+ [16-19] write_idx (uint32 LE) β initialized to 0
+
+ Per slot (128 bytes):
+ [base..+63] cudaIpcMemHandle_t (64 bytes)
+ [base+64..+64] cudaIpcEventHandle_t (64 bytes)
+
+ Footer:
+ shutdown_flag (1 byte) = 0
+ """
+ if self.shm_handle is None or not all(self.ipc_handles):
+ return
+
+ # Write protocol header: magic, bump version, reset num_slots and write_idx
+ _ST_U32.pack_into(self.shm_handle.buf, MAGIC_OFFSET, PROTOCOL_MAGIC)
+ new_version = bump_version(self.shm_handle.buf)
+ _ST_U32.pack_into(self.shm_handle.buf, NUM_SLOTS_OFFSET, self.num_slots)
+ _ST_U32.pack_into(self.shm_handle.buf, WRITE_IDX_OFFSET, 0) # write_idx = 0 initially
+
+ # Write per-slot handles
+ for slot in range(self.num_slots):
+ base_offset = SHM_HEADER_SIZE + (slot * SLOT_SIZE)
+
+ mem_handle_bytes = bytes(self.ipc_handles[slot].internal)
+ self.shm_handle.buf[base_offset : base_offset + 64] = mem_handle_bytes
+
+ if self.ipc_event_handles[slot]:
+ event_handle_bytes = bytes(self.ipc_event_handles[slot].reserved)
+ self.shm_handle.buf[base_offset + 64 : base_offset + 128] = event_handle_bytes
+
+ # Pre-compute layout once; cache derived offsets for export_frame() hot-path
+ self._layout = SHMLayout(self.num_slots)
+ self._shutdown_offset = self._layout.shutdown_offset
+ self.shm_handle.buf[self._shutdown_offset] = 0
+
+ logger.info("Wrote IPC handles v%d to SharedMemory", new_version)
+
+ def _write_metadata_to_shm(self) -> None:
+ """Write texture metadata to the extended protocol region.
+
+ Layout (20 bytes after shutdown flag):
+ +0 width (uint32)
+ +4 height (uint32)
+ +8 num_comps (uint32)
+ +12 format_kind (uint8) β cudaChannelFormatKind: 0=Signed,1=Unsigned,2=Float
+ +13 bits_per_comp (uint8) β 8/16/32/64
+ +14 flags (uint16) β bit0=bfloat16; rest reserved=0
+ +16 data_size (uint32) β actual bytes (before 2MiB alignment)
+ """
+ if self.shm_handle is None or self.data_size == 0:
+ return
+
+ kind, bits, flags = DtypeCodec.encode(self.dtype)
+ Metadata(
+ width=self.width,
+ height=self.height,
+ num_comps=self.channels,
+ format_kind=kind,
+ bits_per_comp=bits,
+ flags=flags,
+ data_size=self.data_size,
+ ).pack_into(self.shm_handle.buf, self._layout)
+
+ logger.debug(
+ "Wrote metadata: %dx%dx%d, dtype=%s (kind=%d bits=%d flags=0x%04x), data_size=%dB",
+ self.width,
+ self.height,
+ self.channels,
+ self.dtype,
+ kind,
+ bits,
+ flags,
+ self.data_size,
+ )
+
+ # ------------------------------------------------------------------
+ # CUDA Graph helpers (Phase 2, CUDALINK_USE_GRAPHS=1)
+ # ------------------------------------------------------------------
+
+ def _build_export_graphs(self) -> None:
+ """Capture the D2D memcpy into a 1-node CUDA Graph exec per ring slot.
+
+ Graph topology per slot (1 node):
+ MemcpyNode(D2D, ring_slot[slot] β placeholder_src)
+
+ stream_wait_event is NOT captured: external events (recorded outside this
+ capture) do not produce EventWait nodes in global-mode capture.
+ record_event on IPC events is NOT captured: cudaEventInterprocess events
+ raise cudaErrorStreamCaptureUnsupported (error 900) during capture.
+
+ Per-frame cost:
+ - source_sync not used (common): graph_launch + record_event = 2 WDDM
+ - source_sync used: stream_wait_event + graph_launch + record_event = 3 WDDM
+
+ On failure the stream is restored to normal mode before returning so that
+ the legacy fallback path can use it without error.
+ """
+ assert self.cuda is not None
+ assert self.ipc_stream is not None
+
+ placeholder_src = self.dev_ptrs[0]
+
+ for slot in range(self.num_slots):
+ capture_started = False
+ try:
+ self.cuda.stream_begin_capture(self.ipc_stream, mode=1) # ThreadLocal: safer in multi-engine processes
+ capture_started = True
+ self.cuda.memcpy_async(
+ dst=self.dev_ptrs[slot],
+ src=placeholder_src,
+ count=self.data_size,
+ kind=3, # D2D
+ stream=self.ipc_stream,
+ )
+ template_graph = self.cuda.stream_end_capture(self.ipc_stream)
+ capture_started = False
+
+ nodes = self.cuda.graph_get_nodes(template_graph)
+ if len(nodes) != 1:
+ self.cuda.graph_destroy(template_graph)
+ raise RuntimeError(f"Unexpected graph node count {len(nodes)} (expected 1: MemcpyNode).")
+ memcpy_node = nodes[0]
+
+ graph_exec = self.cuda.graph_instantiate(template_graph)
+ # Keep template alive: node handles from the template must remain
+ # valid for cudaGraphExecMemcpyNodeSetParams1D per-frame updates.
+ # Template is destroyed in _destroy_export_graphs().
+
+ self._graph_execs[slot] = graph_exec
+ self._graph_templates[slot] = template_graph
+ self._graph_memcpy_nodes[slot] = memcpy_node
+ logger.debug("Built export graph for slot %d (1-node: Memcpy)", slot)
+
+ except (RuntimeError, OSError) as exc:
+ if capture_started:
+ try:
+ abandoned_graph = self.cuda.stream_end_capture(self.ipc_stream)
+ self.cuda.graph_destroy(abandoned_graph)
+ except (RuntimeError, OSError):
+ pass
+ logger.warning(
+ "CUDA Graph build failed for slot %d (%s) β "
+ "disabling graphs for this exporter instance and falling back to "
+ "legacy stream path. Set CUDALINK_USE_GRAPHS=0 to suppress.",
+ slot,
+ exc,
+ )
+ self._graphs_disabled = True
+ self._destroy_export_graphs()
+ return
+
+ logger.info("CUDA export graphs built for %d slots (CUDALINK_USE_GRAPHS=1)", self.num_slots)
+
+ def _destroy_export_graphs(self) -> None:
+ """Destroy all CUDA Graph exec objects and their templates (called from cleanup())."""
+ if self.cuda is None:
+ return
+ for slot, graph_exec in enumerate(self._graph_execs):
+ if graph_exec is not None:
+ try:
+ self.cuda.graph_exec_destroy(graph_exec)
+ logger.debug("Destroyed export graph exec slot %d", slot)
+ except (RuntimeError, OSError) as e:
+ logger.error("Error destroying graph exec slot %d: %s", slot, e)
+ self._graph_execs[slot] = None
+ for slot, template in enumerate(getattr(self, "_graph_templates", [])):
+ if template is not None:
+ with contextlib.suppress(RuntimeError, OSError):
+ self.cuda.graph_destroy(template)
+ self._graph_templates[slot] = None
+ self._graph_memcpy_nodes = [None] * self.num_slots
+
+ # ------------------------------------------------------------------
+ # Hot path
+ # ------------------------------------------------------------------
+
+ def record_source_sync(self, producer_stream_handle: int) -> None:
+ """Record sync event on the producer's CUDA stream (GPU-side, non-blocking CPU).
+
+ Call this AFTER your GPU kernel writes to the source buffer, BEFORE export_frame().
+ export_frame() will make ipc_stream wait for this event GPU-side before the D2D
+ memcpy, ensuring source data is ready without blocking the CPU.
+
+ This replaces ``torch.cuda.synchronize()`` and saves ~0.2-0.5ms per frame.
+
+ ``ipc_stream`` is created with cudaStreamNonBlocking and does NOT implicitly
+ synchronize with any other stream (including the legacy default stream). Without
+ this call, export_frame() has no ordering guarantee with the caller's GPU work.
+
+ Args:
+ producer_stream_handle: Raw CUDA stream integer. Examples:
+ - PyTorch: ``torch.cuda.current_stream().cuda_stream``
+ - CuPy: ``cupy.cuda.get_current_stream().ptr``
+ - Raw CUDA: the ``cudaStream_t`` cast to int
+
+ If this method is never called, ``source_sync_event`` stays in its initial
+ unrecorded state and stream_wait_event() in export_frame() is a benign no-op
+ (backward compatible). The caller is then responsible for their own sync.
+ """
+ if self.source_sync_event is not None and self.cuda is not None:
+ # C2: opportunistic stream-device check. CUDA Runtime < 12.8 has no
+ # cudaStreamGetDevice; we validate via the current context device instead.
+ # Emitted at most once per exporter instance to avoid log spam.
+ if not self._source_sync_device_warned:
+ current_device = self.cuda.get_device()
+ if current_device != self.device:
+ msg = (
+ f"record_source_sync: current CUDA device ({current_device}) "
+ f"does not match exporter device ({self.device}). "
+ "Call cudaSetDevice(device) before creating your producer stream. "
+ "Set CUDALINK_STRICT_DEVICE=1 to raise instead of warn."
+ )
+ if self._strict_device:
+ raise ValueError(msg)
+ logger.error(msg)
+ self._source_sync_device_warned = True
+ self.cuda.record_event(self.source_sync_event, CUDAStream_t(producer_stream_handle))
+ self._source_sync_recorded = True
+
+ def export_frame(self, gpu_ptr: int, size: int) -> bool:
+ """Export one frame from GPU memory via IPC ring buffer.
+
+ Args:
+ gpu_ptr: Source GPU pointer (from tensor.data_ptr()).
+ size: Buffer size in bytes (from tensor.nelement() * tensor.element_size()).
+
+ Returns:
+ True if export succeeded, False on error.
+ """
+ if not self._initialized:
+ logger.warning("Not initialized β call initialize() first")
+ return False
+ if size != self.data_size:
+ logger.error("Size mismatch: expected %d, got %d", self.data_size, size)
+ return False
+ # Activation-barrier check: skip publish if a TD-side Sender is in its activation window.
+ if self._barrier.enabled and self._barrier.should_skip_publish():
+ # Reassert the per-frame heartbeat even on the skip path.
+ # The consumer reads shutdown_flag == 1 as "producer gone"; bypassing the
+ # success-path heartbeat write on skip frames would leave any stale 1-byte
+ # uncleared and trip a false "Sender shutdown detected" on the TD receiver.
+ if self.shm_handle is not None and self._shutdown_offset:
+ with contextlib.suppress(OSError, BufferError):
+ self.shm_handle.buf[self._shutdown_offset] = 0
+ return False
+ debug = self.debug
+ if debug:
+ frame_start = time.perf_counter()
+ _nvtx.push_range(f"cudalink.exporter.slot{self.write_idx % self.num_slots}", "green")
+ try:
+ slot = self.write_idx % self.num_slots
+
+ # C2: validate source pointer's device and memory type on first appearance.
+ # Cache keyed by pointer integer (cap at 8 to cover typical buffer-rotation).
+ gpu_ptr_int = gpu_ptr if isinstance(gpu_ptr, int) else int(gpu_ptr)
+ if gpu_ptr_int not in self._ptr_device_cache:
+ attrs = self.cuda.pointer_get_attributes(gpu_ptr_int)
+ if attrs.type not in (2, 3): # 2=device, 3=managed (both valid for D2D)
+ msg = (
+ f"export_frame: gpu_ptr 0x{gpu_ptr_int:016x} is not device/managed "
+ f"memory (type={attrs.type}). Pass a GPU-resident pointer. "
+ "Set CUDALINK_STRICT_DEVICE=1 to raise instead of warn."
+ )
+ if self._strict_device:
+ raise ValueError(msg)
+ logger.error(msg)
+ elif attrs.device != self.device:
+ msg = (
+ f"export_frame: gpu_ptr 0x{gpu_ptr_int:016x} belongs to device "
+ f"{attrs.device}, but exporter is bound to device {self.device}. "
+ "Set CUDALINK_STRICT_DEVICE=1 to raise instead of warn."
+ )
+ if self._strict_device:
+ raise ValueError(msg)
+ logger.error(msg)
+ if len(self._ptr_device_cache) < 8:
+ self._ptr_device_cache.add(gpu_ptr_int)
+
+ # --- GPU copy + sync: graph path or legacy path ---
+ #
+ # Graph path (CUDALINK_USE_GRAPHS=1): replays a 1-node graph (MemcpyNode).
+ # stream_wait_event is issued before graph_launch only when record_source_sync()
+ # has been called (tracked by _source_sync_recorded); otherwise it is skipped
+ # (the event is in its initial "complete" state β no ordering needed).
+ # record_event is issued after graph_launch (IPC events not capturable).
+ # - source_sync not used: graph_launch + record_event = 2 WDDM (vs 3 legacy)
+ # - source_sync used: stream_wait + graph_launch + record_event = 3 WDDM
+ if self._use_graphs and not self._graphs_disabled:
+ if debug:
+ _t = time.perf_counter()
+ try:
+ self.cuda.graph_exec_memcpy_node_set_params_1d(
+ self._graph_execs[slot],
+ self._graph_memcpy_nodes[slot],
+ dst=self.dev_ptrs[slot],
+ src=c_void_p(gpu_ptr),
+ count=self.data_size,
+ kind=3,
+ )
+ if self._source_sync_recorded and self.source_sync_event is not None:
+ self.cuda.stream_wait_event(self.ipc_stream, self.source_sync_event, 0)
+ self.cuda.graph_launch(self._graph_execs[slot], self.ipc_stream)
+ if self.ipc_events[slot]:
+ self.cuda.record_event(self.ipc_events[slot], stream=self.ipc_stream)
+ except (RuntimeError, OSError) as _graph_err:
+ logger.warning(
+ "Graph launch failed (%s) β disabling graphs, retrying via legacy path",
+ _graph_err,
+ )
+ self._graphs_disabled = True
+ goto_legacy = True
+ else:
+ goto_legacy = False
+ if debug:
+ self.total_memcpy_us += (time.perf_counter() - _t) * 1_000_000
+ else:
+ goto_legacy = True
+
+ if goto_legacy:
+ # Legacy path: 3 separate WDDM submissions per frame.
+ # GPU-side wait on source_sync_event is a no-op if record_source_sync()
+ # was never called (event stays in its initial "complete" state).
+ if debug:
+ _t = time.perf_counter()
+ if self.source_sync_event is not None:
+ self.cuda.stream_wait_event(self.ipc_stream, self.source_sync_event, 0)
+ if debug:
+ self.total_stream_wait_us += (time.perf_counter() - _t) * 1_000_000
+
+ if debug:
+ memcpy_start = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.exporter.memcpy", "green"):
+ self.cuda.memcpy_async(
+ dst=self.dev_ptrs[slot],
+ src=c_void_p(gpu_ptr),
+ count=self.data_size,
+ kind=3, # cudaMemcpyDeviceToDevice
+ stream=self.ipc_stream,
+ )
+ if debug:
+ self.total_memcpy_us += (time.perf_counter() - memcpy_start) * 1_000_000
+
+ if debug:
+ _t = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.exporter.record_event", "green"):
+ if self.ipc_events[slot]:
+ self.cuda.record_event(self.ipc_events[slot], stream=self.ipc_stream)
+ if debug:
+ self.total_record_event_us += (time.perf_counter() - _t) * 1_000_000
+
+ # Optional CPU-blocking sync after record_event. Disabled by default.
+ #
+ # When enabled (CUDALINK_EXPORT_SYNC=1): blocks the CPU until the GPU has
+ # executed the memcpy + record_event, guaranteeing query_event() returns True
+ # on the first poll. Cost: ~13Β΅s @ 512Β², ~100Β΅s @ 1080p per frame.
+ #
+ # When disabled (default): the GPU event is recorded asynchronously.
+ # Consumers using the stream-ordered path (get_frame(stream=...)) issue
+ # cudaStreamWaitEvent, which correctly waits for the event-record to execute
+ # regardless of whether the CPU has synced. The _wait_for_slot() timeout
+ # guards the polling path. Skipping this sync saves ~13-100Β΅s/frame.
+ if self._export_sync:
+ if debug and self._export_profile:
+ _t_sync = time.perf_counter()
+ self.cuda.stream_synchronize(self.ipc_stream)
+ if debug and self._export_profile:
+ self.total_sync_us += (time.perf_counter() - _t_sync) * 1_000_000
+
+ if debug and self._export_profile:
+ _t_sticky = time.perf_counter()
+ self.cuda.check_sticky_error("export_frame")
+ if debug and self._export_profile:
+ self.total_sticky_check_us += (time.perf_counter() - _t_sticky) * 1_000_000
+
+ # WDDM deferred-submission probe: forces pending GPU work to submit without
+ # blocking. Per CUDA Handbook p3/pg56, WDDM buffers commands until a flush;
+ # cudaStreamQuery triggers that flush. Only active when EXPORT_FLUSH_PROBE=1
+ # and EXPORT_SYNC=0 (if sync is on, the stream is already flushed above).
+ if self._export_flush_probe and not self._export_sync:
+ if debug and self._export_profile:
+ _t_fp = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.exporter.flush_probe", "green"):
+ self.cuda.stream_query(self.ipc_stream)
+ if debug and self._export_profile:
+ self.total_flush_probe_us += (time.perf_counter() - _t_fp) * 1_000_000
+
+ # Write timestamp, clear shutdown_flag, then publish write_idx LAST.
+ # Ordering matters: the consumer reads shutdown_flag BEFORE write_idx, so
+ # clearing it before incrementing write_idx ensures the consumer always sees
+ # shutdown_flag=0 when it detects a new frame (atomicity improvement).
+ if debug:
+ _t = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.exporter.shm_write", "green"):
+ self.write_idx += 1
+ publish_frame(self.shm_handle.buf, self._layout, self.write_idx, time.perf_counter())
+ if debug:
+ self.total_shm_write_us += (time.perf_counter() - _t) * 1_000_000
+
+ self.frame_count += 1
+
+ if debug:
+ frame_time = (time.perf_counter() - frame_start) * 1_000_000
+ self.total_export_us += frame_time
+
+ if self.frame_count % 97 == 0:
+ n = self.frame_count
+ logger.debug(
+ "Frame %d: slot=%d | stream_wait=%.1fus memcpy=%.1fus "
+ "record_event=%.1fus shm_write=%.1fus | total=%.1fus",
+ n,
+ slot,
+ self.total_stream_wait_us / n,
+ self.total_memcpy_us / n,
+ self.total_record_event_us / n,
+ self.total_shm_write_us / n,
+ self.total_export_us / n,
+ )
+ if self._export_profile:
+ avg_wait = self.total_stream_wait_us / n
+ avg_memcpy = self.total_memcpy_us / n
+ avg_record = self.total_record_event_us / n
+ avg_sync = self.total_sync_us / n
+ avg_sticky = self.total_sticky_check_us / n
+ avg_fp = self.total_flush_probe_us / n
+ avg_shm = self.total_shm_write_us / n
+ avg_total = self.total_export_us / n
+ avg_unacc = avg_total - (
+ avg_wait + avg_memcpy + avg_record + avg_sync + avg_sticky + avg_fp + avg_shm
+ )
+ logger.debug(
+ "Frame %d [PROFILE] pre=0.0us interop=0.0us post=0.0us"
+ " memcpy=%.1fus record=%.1fus sync=%.1fus"
+ " sticky=%.1fus flush_probe=%.1fus shm=%.1fus unacc=%.1fus",
+ n,
+ avg_memcpy,
+ avg_record,
+ avg_sync,
+ avg_sticky,
+ avg_fp,
+ avg_shm,
+ avg_unacc,
+ )
+
+ return True
+
+ except (OSError, RuntimeError) as e:
+ logger.error("Export failed: %s", e)
+ traceback.print_exc()
+ return False
+ finally:
+ _nvtx.pop_range()
+
+ # ------------------------------------------------------------------
+ # Cleanup
+ # ------------------------------------------------------------------
+
+ def _is_cuda_context_valid(self) -> bool:
+ """Check if CUDA context is still valid.
+
+ The CUDA context may be destroyed before cleanup() is called
+ (e.g., when the Python process is terminating). Checking this
+ prevents spurious CUDA errors during cleanup.
+ """
+ if self.cuda is None:
+ return False
+ try:
+ self.cuda.cudart.cudaGetLastError()
+ return True
+ except (OSError, RuntimeError):
+ return False
+
+ def cleanup(self) -> None:
+ """Cleanup all CUDA IPC resources.
+
+ 7-step shutdown sequence (order is critical):
+ 1. Signal shutdown to consumer via SharedMemory flag
+ 2. Destroy IPC events (sender-side resources, safe to destroy)
+ 3. Destroy IPC stream
+ 4. Close SharedMemory (don't unlink yet β consumer may still read flag)
+ 5. Grace period (100ms) for consumer to detect shutdown and close handles
+ 6. Free GPU buffers (cudaFree blocks until consumer closes IPC handles)
+ 7. Unlink SharedMemory (producer owns it and is responsible for cleanup)
+ """
+ # Double-cleanup guard: skip if already cleaned up
+ if not self._initialized and self.shm_handle is None:
+ return
+
+ cuda_valid = self._is_cuda_context_valid()
+ if not cuda_valid:
+ logger.warning("CUDA context already destroyed β skipping GPU cleanup")
+
+ # STEP 1: Signal shutdown to consumer
+ if self.shm_handle:
+ try:
+ shutdown_offset = SHM_HEADER_SIZE + (self.num_slots * SLOT_SIZE)
+ struct.pack_into(" None:
+ try:
+ self.cuda.free(ptr)
+ logger.debug("Freed GPU buffer slot %d", s)
+ except (RuntimeError, OSError) as e:
+ logger.error("Error freeing GPU buffer slot %d: %s", s, e)
+
+ t = threading.Thread(target=_free, args=(dev_ptr,), daemon=True)
+ t.start()
+ t.join(timeout=0.5)
+ if t.is_alive():
+ logger.warning(
+ "cudaFree slot %d timed out (0x%016x) β receiver may not have closed "
+ "the IPC handle. Leaking GPU memory; OS will reclaim on process exit.",
+ slot,
+ dev_ptr.value,
+ )
+
+ # STEP 7: Unlink SharedMemory (producer is owner and responsible for cleanup)
+ try:
+ shm_temp = SharedMemory(name=self.shm_name)
+ shm_temp.close()
+ shm_temp.unlink()
+ logger.info("Unlinked SharedMemory")
+ except FileNotFoundError:
+ pass # Already unlinked
+ except (OSError, RuntimeError) as e:
+ logger.warning("Could not unlink SharedMemory: %s", e)
+
+ self._barrier.close()
+
+ # Reset all state to prevent double-free on re-entry
+ self.dev_ptrs = [None] * self.num_slots
+ self.ipc_events = [None] * self.num_slots
+ self.ipc_handles = [None] * self.num_slots
+ self.ipc_event_handles = [None] * self.num_slots
+ self._initialized = False
+
+ logger.info("Cleanup complete")
+
+ # ------------------------------------------------------------------
+ # Context manager
+ # ------------------------------------------------------------------
+
+ def __enter__(self) -> CUDAIPCExporter:
+ """Enter context manager β returns self for use in 'with' statement."""
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: object,
+ ) -> None:
+ """Exit context manager β cleanup resources regardless of exception."""
+ self.cleanup()
+ return None # Don't suppress exceptions
+
+ def __del__(self) -> None:
+ """Destructor β cleanup on garbage collection if not already done."""
+ if getattr(self, "_initialized", False):
+ self.cleanup()
+
+ # ------------------------------------------------------------------
+ # Status
+ # ------------------------------------------------------------------
+
+ def is_ready(self) -> bool:
+ """Check if exporter is ready to export frames.
+
+ Returns:
+ True if initialized with all GPU buffers allocated.
+ """
+ return self._initialized and all(ptr is not None for ptr in self.dev_ptrs)
+
+ def attach_nvml_observer(self, observer: NVMLObserver) -> None:
+ """Attach an NVMLObserver for GPU telemetry in get_stats().
+
+ Args:
+ observer: NVMLObserver instance (must already be started).
+ """
+ self._nvml_observer = observer
+
+ def get_stats(self) -> dict:
+ """Get exporter statistics for monitoring.
+
+ Returns:
+ Dictionary with current exporter state and performance metrics.
+ Includes an 'nvml' sub-dict when an NVMLObserver is attached.
+ """
+ avg_memcpy = self.total_memcpy_us / self.frame_count if self.frame_count > 0 else 0.0
+ avg_total = self.total_export_us / self.frame_count if self.frame_count > 0 else 0.0
+ stats: dict = {
+ "initialized": self._initialized,
+ "shm_name": self.shm_name,
+ "resolution": f"{self.width}x{self.height}x{self.channels}",
+ "dtype": self.dtype,
+ "num_slots": self.num_slots,
+ "data_size_kb": self.data_size / 1024,
+ "buffer_size_mb": self.buffer_size / (1024 * 1024),
+ "frame_count": self.frame_count,
+ "write_idx": self.write_idx,
+ "avg_memcpy_us": avg_memcpy,
+ "avg_total_us": avg_total,
+ "dev_ptrs": [f"0x{ptr.value:016x}" if ptr else "NULL" for ptr in self.dev_ptrs],
+ }
+ observer = getattr(self, "_nvml_observer", None)
+ if observer is not None:
+ stats["nvml"] = observer.snapshot()
+ return stats
diff --git a/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py
new file mode 100644
index 000000000..b4b821df6
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_importer.py
@@ -0,0 +1,1310 @@
+"""
+CUDA IPC Importer for Python Process
+Imports GPU memory from TouchDesigner via CUDA IPC handles
+
+Usage:
+ # PyTorch tensor (GPU, zero-copy)
+ importer = CUDAIPCImporter(shm_name="cudalink_output_ipc", shape=(512, 512, 4))
+ tensor = importer.get_frame() # torch.Tensor on GPU
+
+ # Numpy array (CPU, D2H copy)
+ importer = CUDAIPCImporter(shm_name="cudalink_output_ipc", shape=(512, 512, 4))
+ array = importer.get_frame_numpy() # numpy array on CPU
+
+Architecture:
+ TouchDesigner Process β IPC Handle in SharedMemory
+ β
+ Python Process β Open Handle β torch.as_tensor() or numpy D2H copy
+ (once) (zero-copy) (GPUβCPU)
+
+Value objects:
+ IPCConnection β CUDA runtime, SHM handle, per-slot dev_ptrs/ipc_events, layout.
+ Format β Parsed metadata (shape, dtype, frame_nbytes, numpy_dtype).
+ TorchBuffers β Per-slot zero-copy tensor views (built eagerly).
+ CupyBuffers β Per-slot zero-copy CuPy array views (built eagerly).
+ NumpyBuffers β Pinned host buffer + D2H streams (built lazily on first get_frame_numpy).
+"""
+
+from __future__ import annotations
+
+import contextlib
+import ctypes
+import logging
+import os
+import struct
+import sys
+import time
+import traceback
+from dataclasses import dataclass
+from multiprocessing.shared_memory import SharedMemory
+from typing import TYPE_CHECKING
+
+from . import _nvtx
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from .nvml_observer import NVMLObserver
+
+# Windows timer-resolution helper β reduces time.sleep floor from ~15ms to ~1ms.
+# The winmm DLL handle is cached at module level so the load cost is paid once.
+if sys.platform == "win32":
+ try:
+ _winmm = ctypes.WinDLL("winmm")
+ except OSError:
+ _winmm = None
+else:
+ _winmm = None
+
+
+class _HighResTimer:
+ """Context manager that requests 1ms timer resolution on Windows.
+
+ On Windows, the default system timer tick is ~15.6ms, making
+ ``time.sleep(0.0001)`` wake up 15-150x later than intended. Calling
+ ``timeBeginPeriod(1)`` drops the floor to ~1ms for the duration of the
+ with-block, then restores the default on exit. No-op on non-Windows.
+ """
+
+ __slots__ = ("_active",)
+
+ def __enter__(self) -> _HighResTimer:
+ self._active = _winmm is not None
+ if self._active:
+ _winmm.timeBeginPeriod(1)
+ return self
+
+ def __exit__(self, *_: object) -> None:
+ if self._active:
+ _winmm.timeEndPeriod(1)
+
+
+# Optional dependencies with fallback
+try:
+ import torch
+
+ TORCH_AVAILABLE = True
+except ImportError:
+ torch = None
+ TORCH_AVAILABLE = False
+
+try:
+ import numpy as np
+
+ NUMPY_AVAILABLE = True
+except ImportError:
+ np = None
+ NUMPY_AVAILABLE = False
+
+try:
+ import cupy as cp
+
+ CUPY_AVAILABLE = True
+except ImportError:
+ cp = None
+ CUPY_AVAILABLE = False
+
+from .cuda_ipc_wrapper import CUDARuntimeAPI, get_cuda_runtime # noqa: E402
+from .cuda_runtime_types import cudaIpcEventHandle_t, cudaIpcMemHandle_t # noqa: E402
+
+
+# Byte size per dtype β module-level constant avoids dict construction on every _dtype_itemsize() call
+_DTYPE_SIZES: dict = {"float32": 4, "float16": 2, "bfloat16": 2, "uint8": 1, "uint16": 2, "int8": 1, "int16": 2}
+
+from .shm_protocol import ( # noqa: E402
+ _ST_BBH,
+ MAGIC_OFFSET,
+ MAGIC_SIZE,
+ NUM_SLOTS_OFFSET,
+ NUM_SLOTS_SIZE,
+ PROTOCOL_MAGIC,
+ SHM_HEADER_SIZE,
+ SLOT_SIZE,
+ VERSION_OFFSET,
+ VERSION_SIZE,
+ AcquireResult,
+ DtypeCodec,
+ SHMLayout,
+ SlotState,
+ acquire_slot,
+)
+
+
+def _decode_dtype_str(kind: int, bits: int, flags: int) -> str:
+ return DtypeCodec.decode(kind, bits, flags)
+
+
+# ============================================================
+# Value objects
+# ============================================================
+
+
+@dataclass(frozen=True)
+class Format:
+ """Parsed frame format β shape, dtype, and precomputed derivations.
+
+ Immutable after construction. Two constructors:
+ - from_shm(): parse the extended metadata block in SharedMemory.
+ - from_overrides(): build from caller-supplied shape/dtype (no SHM read).
+ """
+
+ width: int
+ height: int
+ num_comps: int
+ kind: int
+ bits: int
+ flags: int
+ dtype_str: str
+ shape: tuple
+ numpy_dtype: object # np.dtype or None when numpy not available
+ frame_nbytes: int
+
+ @classmethod
+ def from_shm(cls, shm_buf: object, num_slots: int) -> Format | None:
+ """Parse extended metadata block from shared memory.
+
+ Returns None when the block is absent or contains zeros (sender not yet
+ written metadata).
+ """
+ layout = SHMLayout(num_slots)
+ metadata_offset = layout.metadata_offset
+ try:
+ width = struct.unpack(" 0 and height > 0 and num_comps > 0:
+ dtype_str = _decode_dtype_str(kind, bits, flags)
+ shape = (height, width, num_comps)
+ itemsize = _DTYPE_SIZES.get(dtype_str, bits // 8 or 4)
+ frame_nbytes = height * width * num_comps * itemsize
+ numpy_dtype = np.dtype(dtype_str) if NUMPY_AVAILABLE else None
+ return cls(
+ width=width,
+ height=height,
+ num_comps=num_comps,
+ kind=kind,
+ bits=bits,
+ flags=flags,
+ dtype_str=dtype_str,
+ shape=shape,
+ numpy_dtype=numpy_dtype,
+ frame_nbytes=frame_nbytes,
+ )
+ except (struct.error, ValueError, IndexError):
+ pass
+ return None
+
+ @classmethod
+ def from_overrides(cls, shape: tuple, dtype_str: str) -> Format:
+ """Build from caller-supplied shape/dtype (no SHM read).
+
+ kind/bits/flags are left as 0 sentinels β they are diagnostic fields only
+ and are not used by frame consumers.
+ """
+ height, width, num_comps = shape
+ itemsize = _DTYPE_SIZES.get(dtype_str, 4)
+ frame_nbytes = height * width * num_comps * itemsize
+ numpy_dtype = np.dtype(dtype_str) if NUMPY_AVAILABLE else None
+ return cls(
+ width=width,
+ height=height,
+ num_comps=num_comps,
+ kind=0,
+ bits=0,
+ flags=0,
+ dtype_str=dtype_str,
+ shape=shape,
+ numpy_dtype=numpy_dtype,
+ frame_nbytes=frame_nbytes,
+ )
+
+
+@dataclass
+class IPCConnection:
+ """Live CUDA IPC connection β runtime, SHM handle, per-slot GPU resources, layout.
+
+ Mutable: dev_ptrs/ipc_events/ipc_handles are populated slot-by-slot during
+ _open_ipc_slots(), then nulled in-place by close_ipc_handles() / close().
+ """
+
+ cuda: object # CUDARuntimeAPI
+ shm_handle: object # SharedMemory or None after close()
+ ipc_version: int
+ num_slots: int
+ ipc_handles: list # [cudaIpcMemHandle_t | None]
+ dev_ptrs: list # [c_void_p | None]
+ ipc_events: list # [event_t | None]
+ layout: object # SHMLayout
+ shutdown_offset: int
+ timestamp_offset: int
+
+ def close_ipc_handles(self) -> None:
+ """Close IPC mem handles and events. SharedMemory stays open (used by _reinitialize)."""
+ for slot, dev_ptr in enumerate(self.dev_ptrs):
+ if dev_ptr is not None:
+ try:
+ self.cuda.ipc_close_mem_handle(dev_ptr)
+ logger.info("Closed IPC handle for slot %d", slot)
+ except (RuntimeError, OSError) as e:
+ logger.error("Error closing IPC handle for slot %d: %s", slot, e)
+ self.dev_ptrs[slot] = None
+
+ for slot, event in enumerate(self.ipc_events):
+ if event is not None:
+ try:
+ self.cuda.destroy_event(event)
+ logger.info("Destroyed IPC event for slot %d", slot)
+ except (RuntimeError, OSError) as e:
+ logger.error("Error destroying event for slot %d: %s", slot, e)
+ self.ipc_events[slot] = None
+
+ def close(self) -> None:
+ """Close IPC handles and SharedMemory. Idempotent."""
+ self.close_ipc_handles()
+ if self.shm_handle is not None:
+ try:
+ self.shm_handle.close()
+ logger.info("Closed SharedMemory")
+ except (OSError, BufferError) as e:
+ logger.error("Error closing SharedMemory: %s", e)
+ self.shm_handle = None
+
+
+@dataclass
+class TorchBuffers:
+ """Per-slot zero-copy torch.Tensor views of GPU memory (built eagerly at init)."""
+
+ tensors: list # [torch.Tensor]
+ wrappers: list # GC keep-alive refs for __cuda_array_interface__ wrappers
+
+ @classmethod
+ def build(cls, conn: IPCConnection, fmt: Format) -> TorchBuffers:
+ """Create one zero-copy tensor view per slot via __cuda_array_interface__."""
+ typestr_map = {"float32": " None:
+ self.__cuda_array_interface__ = interface
+
+ wrapper = CUDAArrayWrapper(cuda_array_interface)
+ tensor = torch.as_tensor(wrapper, device="cuda")
+ wrappers.append(wrapper)
+ tensors.append(tensor)
+
+ return cls(tensors=tensors, wrappers=wrappers)
+
+
+@dataclass
+class CupyBuffers:
+ """Per-slot zero-copy CuPy array views of GPU memory (built eagerly at init)."""
+
+ arrays: list # [cp.ndarray]
+
+ @classmethod
+ def build(cls, conn: IPCConnection, fmt: Format) -> CupyBuffers:
+ """Create one zero-copy CuPy array view per slot via UnownedMemory."""
+ dtype_map = {"float32": cp.float32, "float16": cp.float16, "uint8": cp.uint8, "uint16": cp.uint16}
+ cp_dtype = dtype_map.get(fmt.dtype_str)
+ if cp_dtype is None:
+ raise ValueError(f"Unsupported dtype for CuPy: {fmt.dtype_str}")
+
+ arrays = []
+ for slot in range(conn.num_slots):
+ if conn.dev_ptrs[slot] is None:
+ raise RuntimeError(f"Device pointer for slot {slot} not initialized")
+ ptr_value = int(conn.dev_ptrs[slot].value)
+ mem = cp.cuda.UnownedMemory(ptr_value, fmt.frame_nbytes, owner=conn)
+ memptr = cp.cuda.MemoryPointer(mem, 0)
+ arrays.append(cp.ndarray(fmt.shape, dtype=cp_dtype, memptr=memptr))
+
+ return cls(arrays=arrays)
+
+
+@dataclass
+class NumpyBuffers:
+ """Pinned host buffer + D2H streams for numpy frame consumption (built lazily).
+
+ NumpyBuffers owns the CUDA streams and pinned host allocation. close() tears
+ them down idempotently.
+ """
+
+ cuda: object # CUDARuntimeAPI (same instance as IPCConnection.cuda)
+ fmt: Format
+ buffer: object # np.ndarray β reusable D2H destination
+ pinned_ptr: object # cudaMallocHost result or None
+ host_registered_arr: object # cudaHostRegister fallback array or None
+ pinned_memory_available: bool
+ primary_stream: object # primary D2H CUDA stream (also d2h_streams[0])
+ d2h_streams: list # one per CUDALINK_D2H_STREAMS value; slot 0 == primary_stream
+ d2h_events: list # join-barrier sync events, one per stream
+ num_streams: int
+
+ @classmethod
+ def build(cls, conn: IPCConnection, fmt: Format, num_streams: int) -> NumpyBuffers:
+ """Allocate pinned host buffer + D2H streams.
+
+ Allocation ladder: cudaMallocHost (portable pinned) β cudaHostRegister
+ (page-locked) β pageable fallback. Matches current _setup_numpy_buffer logic.
+ """
+ cuda = conn.cuda
+ nbytes = fmt.frame_nbytes
+
+ # Create streams
+ primary_stream = cuda.create_stream(flags=0x01) # cudaStreamNonBlocking
+ logger.debug("Created numpy stream: 0x%016x", int(primary_stream.value))
+ d2h_streams = [primary_stream] + [cuda.create_stream(flags=0x01) for _ in range(num_streams - 1)]
+ d2h_events = [cuda.create_sync_event() for _ in range(num_streams)]
+ if num_streams > 1:
+ logger.info("Multi-stream D2H enabled: %d streams (CUDALINK_D2H_STREAMS=%d)", num_streams, num_streams)
+
+ pinned_ptr = None
+ host_registered_arr = None
+ buffer = None
+ pinned_memory_available = False
+
+ try:
+ # cudaHostAllocPortable (0x01) makes the allocation accessible from any
+ # CUDA context in the process β needed when PyTorch and CuPy coexist.
+ pinned_ptr = cuda.malloc_host_alloc(nbytes, flags=0x01)
+ buf = (ctypes.c_ubyte * nbytes).from_address(pinned_ptr.value)
+ buffer = np.frombuffer(buf, dtype=fmt.numpy_dtype).reshape(fmt.shape)
+ pinned_memory_available = True
+ logger.debug("Allocated portable pinned numpy buffer: %s, %s", fmt.shape, fmt.dtype_str)
+ except (RuntimeError, OSError) as e:
+ logger.warning(
+ "cudaMallocHost failed for %d bytes (%.1f MB) β trying cudaHostRegister: %s",
+ nbytes,
+ nbytes / 1_048_576,
+ e,
+ )
+ try:
+ fallback_arr = np.empty(fmt.shape, dtype=fmt.numpy_dtype)
+ cuda.host_register(fallback_arr.ctypes.data, fallback_arr.nbytes)
+ host_registered_arr = fallback_arr
+ buffer = fallback_arr
+ pinned_memory_available = True
+ logger.info("cudaHostRegister succeeded β using registered pinned memory")
+ except (RuntimeError, OSError) as e2:
+ logger.warning(
+ "cudaHostRegister also failed β falling back to pageable memory "
+ "(expect ~2x slower D2H bandwidth): %s",
+ e2,
+ )
+ buffer = np.empty(fmt.shape, dtype=fmt.numpy_dtype)
+ pinned_memory_available = False
+
+ return cls(
+ cuda=cuda,
+ fmt=fmt,
+ buffer=buffer,
+ pinned_ptr=pinned_ptr,
+ host_registered_arr=host_registered_arr,
+ pinned_memory_available=pinned_memory_available,
+ primary_stream=primary_stream,
+ d2h_streams=d2h_streams,
+ d2h_events=d2h_events,
+ num_streams=num_streams,
+ )
+
+ def needs_rebuild(self, fmt: Format) -> bool:
+ """True when the pre-allocated buffer doesn't match the new format."""
+ return self.buffer.shape != fmt.shape or self.buffer.dtype != fmt.numpy_dtype
+
+ def close(self) -> None:
+ """Idempotent teardown: free pinned allocation, destroy streams and events."""
+ if self.pinned_ptr is not None:
+ try:
+ self.cuda.free_host(self.pinned_ptr)
+ logger.debug("Freed pinned numpy buffer")
+ except (RuntimeError, OSError) as e:
+ logger.debug("free_host skipped (context gone): %s", e)
+ self.pinned_ptr = None
+
+ if self.host_registered_arr is not None:
+ try:
+ self.cuda.host_unregister(self.host_registered_arr.ctypes.data)
+ except (RuntimeError, OSError) as e:
+ logger.debug("host_unregister failed: %s", e)
+ self.host_registered_arr = None
+
+ for evt in self.d2h_events:
+ if evt is not None:
+ with contextlib.suppress(RuntimeError, OSError):
+ self.cuda.destroy_event(evt)
+ self.d2h_events.clear()
+
+ for i, stream in enumerate(self.d2h_streams):
+ if i == 0:
+ continue # primary_stream destroyed below
+ if stream is not None:
+ with contextlib.suppress(RuntimeError, OSError):
+ self.cuda.destroy_stream(stream)
+ self.d2h_streams.clear()
+
+ if self.primary_stream is not None:
+ try:
+ self.cuda.destroy_stream(self.primary_stream)
+ logger.debug("Destroyed numpy stream")
+ except (RuntimeError, OSError) as e:
+ logger.debug("numpy stream destroy skipped (context gone): %s", e)
+ self.primary_stream = None
+
+
+# ============================================================
+# Importer
+# ============================================================
+
+
+class CUDAIPCImporter:
+ """Python-side importer for CUDA IPC GPU memory.
+
+ Responsibilities:
+ - Read 64-byte IPC handle from SharedMemory (once at startup)
+ - Open handle using cudaIpcOpenMemHandle() (once)
+ - Create persistent torch.Tensor view (zero-copy) or numpy array (D2H copy)
+ - Return tensor/array for each frame
+
+ Performance:
+ - Initialization: ~10-100ΞΌs (one-time handle opening)
+ - Per-frame (torch): < 1ΞΌs (just return existing tensor)
+ - Per-frame (numpy): ~300ΞΌs-5ms depending on resolution and dtype (GPUβCPU D2H copy)
+ """
+
+ def __init__(
+ self,
+ shm_name: str = "cudalink_output_ipc",
+ shape: tuple[int, int, int] | None = None,
+ dtype: str | None = None,
+ debug: bool = False,
+ timeout_ms: float = 5000.0,
+ device: int = 0,
+ ) -> None:
+ """Initialize CUDA IPC importer.
+
+ Args:
+ shm_name: SharedMemory name where IPC handle is stored
+ shape: Expected tensor shape (height, width, channels). If None, auto-detect from metadata.
+ dtype: Data type as string: "float32", "float16", or "uint8". If None, auto-detect from metadata.
+ debug: Enable verbose debug logging (default: False)
+ timeout_ms: Timeout for waiting on producer events in milliseconds (default: 5000.0)
+ device: CUDA device index (default: 0). Must match the sender's device.
+ IPC handles are device-scoped; opening a handle on the wrong device
+ causes error 400 (cudaErrorInvalidValue).
+ """
+ # Construction config (kept in sync with _format after init)
+ self.shm_name = shm_name
+ self.shape = shape # May be None initially (will be auto-detected)
+ self.dtype = dtype # May be None initially (will be auto-detected)
+ self.debug = debug
+ self.timeout_ms = timeout_ms
+ self.device = device
+
+ # N1: spin-then-sleep configuration.
+ # Phase 1: tight cudaEventQuery spin for up to _spin_us microseconds (no sleep).
+ # Phase 2: existing time.sleep(0.0001) poll loop (unchanged).
+ # CUDALINK_WAIT_SPIN_US=0 disables Phase 1, restoring pre-batch-2 behaviour.
+ self._spin_us: int = int(os.getenv("CUDALINK_WAIT_SPIN_US", "200"))
+
+ # Multi-stream D2H config (NumpyBuffers reads this at build time)
+ self._d2h_num_streams: int = max(1, int(os.getenv("CUDALINK_D2H_STREAMS", "1")))
+
+ # Initialization gate
+ self._initialized = False
+
+ # Value-object references (all None until _initialize() succeeds)
+ self._conn: IPCConnection | None = None
+ self._format: Format | None = None
+ self._torch: TorchBuffers | None = None
+ self._cupy: CupyBuffers | None = None
+ self._numpy: NumpyBuffers | None = None
+
+ # Frame tracking
+ self.frame_count = 0
+ self._last_write_idx = 0
+
+ # Performance metrics
+ self.total_wait_event_time = 0.0
+ self.total_get_frame_time = 0.0
+ self.total_shm_read_us: float = 0.0
+ self.last_latency = 0.0
+ # N1: spin-phase vs sleep-phase breakdown counters
+ self.total_wait_spin_us: float = 0.0
+ self.total_wait_sleep_us: float = 0.0
+ self.wait_spin_hits: int = 0
+ self.wait_sleep_hits: int = 0
+
+ # _numpy_dtype() cache (for pre-init or post-cleanup calls)
+ self._cached_dtype_str: str = ""
+ self._cached_numpy_dtype: object = None
+
+ # Auto-initialize
+ self._initialize()
+
+ # ------------------------------------------------------------------
+ # Convenience dtype methods (read self.dtype; kept for backward compat)
+ # ------------------------------------------------------------------
+
+ def _dtype_itemsize(self) -> int:
+ """Get byte size per element for the configured dtype."""
+ return _DTYPE_SIZES[self.dtype]
+
+ def _numpy_dtype(self) -> np.dtype:
+ """Get numpy dtype from string dtype (cached)."""
+ if not NUMPY_AVAILABLE:
+ raise RuntimeError("numpy is required but not installed")
+ if self.dtype != self._cached_dtype_str:
+ self._cached_numpy_dtype = np.dtype(self.dtype)
+ self._cached_dtype_str = self.dtype
+ return self._cached_numpy_dtype
+
+ def _torch_dtype(self) -> torch.dtype:
+ """Get torch dtype from string dtype."""
+ if not TORCH_AVAILABLE:
+ raise RuntimeError("torch is required but not installed")
+ mapping = {"float32": torch.float32, "float16": torch.float16, "uint8": torch.uint8}
+ if hasattr(torch, "uint16"):
+ mapping["uint16"] = torch.uint16
+ dtype = mapping.get(self.dtype)
+ if dtype is None:
+ raise RuntimeError(
+ f"dtype '{self.dtype}' requires PyTorch >= 2.5 (torch.uint16 not available). "
+ "Use get_frame_numpy() instead, or upgrade PyTorch."
+ )
+ return dtype
+
+ def _resolve_stream(self, stream: object) -> int | None:
+ """Extract raw CUDA stream pointer from torch/cupy stream or int."""
+ if stream is None:
+ return None
+ if isinstance(stream, int):
+ return stream
+ if TORCH_AVAILABLE and hasattr(stream, "cuda_stream"):
+ return stream.cuda_stream
+ if hasattr(stream, "ptr"):
+ return stream.ptr
+ raise TypeError(
+ f"Unsupported stream type: {type(stream)}. Expected torch.cuda.Stream, cupy.cuda.Stream, or int."
+ )
+
+ # ------------------------------------------------------------------
+ # Phase methods (each returns its piece; orchestrator assembles them)
+ # ------------------------------------------------------------------
+
+ def _setup_runtime(self) -> CUDARuntimeAPI:
+ """Load CUDA runtime on self.device; raise on device mismatch."""
+ cuda = get_cuda_runtime(device=self.device)
+ actual_device = cuda.get_device()
+ if actual_device != self.device:
+ raise RuntimeError(
+ f"Device mismatch: requested device {self.device} but CUDA context "
+ f"is bound to device {actual_device}. Sender and receiver must use "
+ "the same device index."
+ )
+ logger.info("Loaded CUDA runtime on device %d", actual_device)
+ return cuda
+
+ def _open_and_validate_shm(self) -> tuple:
+ """Open SharedMemory and validate protocol magic, version, num_slots, shutdown flag.
+
+ Returns:
+ (shm, num_slots, ipc_version) on success. Raises on any failure.
+ """
+ try:
+ shm = SharedMemory(name=self.shm_name)
+ except FileNotFoundError:
+ logger.error("SharedMemory '%s' not found", self.shm_name)
+ logger.error("Make sure TouchDesigner CUDAIPCExporter is initialized first")
+ raise
+
+ logger.info("Opened SharedMemory: %s", self.shm_name)
+
+ # Validate protocol magic
+ try:
+ magic = struct.unpack(" 10:
+ logger.error(
+ "Invalid num_slots=%d read from SharedMemory. Protocol error or corrupted SHM (expected 1-10).",
+ num_slots,
+ )
+ shm.close()
+ raise ValueError(f"Invalid num_slots={num_slots}")
+
+ shutdown_offset = SHM_HEADER_SIZE + num_slots * SLOT_SIZE
+ try:
+ shutdown_flag = shm.buf[shutdown_offset]
+ except (OSError, BufferError, IndexError) as e:
+ logger.error("Could not read shutdown flag: %s", e)
+ shm.close()
+ raise
+
+ if shutdown_flag == 1:
+ logger.warning("Sender shutdown flag detected - SharedMemory is stale")
+ shm.close()
+ raise RuntimeError("Sender shutdown flag set β SharedMemory is stale")
+
+ logger.info("Ring buffer with %d slots (v%d)", num_slots, ipc_version)
+ return shm, num_slots, ipc_version
+
+ def _parse_format(self, shm: object, num_slots: int) -> Format:
+ """Read extended metadata block and return a Format.
+
+ Uses caller-supplied shape/dtype overrides when provided; falls back to
+ SHM metadata, then to (512,512,4)/'float32' on parse failure.
+ Updates self.shape and self.dtype to stay in sync with the returned Format.
+ """
+ if self.shape is None or self.dtype is None:
+ fmt_from_shm = Format.from_shm(shm.buf, num_slots)
+ if fmt_from_shm is not None:
+ shape = self.shape if self.shape is not None else fmt_from_shm.shape
+ dtype_str = self.dtype if self.dtype is not None else fmt_from_shm.dtype_str
+ if shape != fmt_from_shm.shape or dtype_str != fmt_from_shm.dtype_str:
+ # Override one dimension but parsed the other β rebuild from overrides
+ fmt = Format.from_overrides(shape, dtype_str)
+ else:
+ fmt = fmt_from_shm
+ if self.shape is None:
+ logger.info("Auto-detected shape: %s", fmt.shape)
+ if self.dtype is None:
+ logger.info("Auto-detected dtype: %s", fmt.dtype_str)
+ else:
+ logger.warning("Could not auto-detect metadata; using fallback: shape=(512,512,4), dtype='float32'")
+ shape = self.shape or (512, 512, 4)
+ dtype_str = self.dtype or "float32"
+ fmt = Format.from_overrides(shape, dtype_str)
+ else:
+ # Both provided by caller β no SHM metadata read needed
+ fmt = Format.from_overrides(self.shape, self.dtype)
+
+ # Keep construction hints in sync with the resolved format
+ self.shape = fmt.shape
+ self.dtype = fmt.dtype_str
+ return fmt
+
+ def _open_ipc_slots(
+ self,
+ cuda: CUDARuntimeAPI,
+ shm: object,
+ num_slots: int,
+ ipc_version: int,
+ fmt: Format,
+ ) -> IPCConnection:
+ """Open all IPC mem + event handles; return a live IPCConnection."""
+ ipc_handles: list = [None] * num_slots
+ dev_ptrs: list = [None] * num_slots
+ ipc_events: list = [None] * num_slots
+
+ for slot in range(num_slots):
+ base_offset = SHM_HEADER_SIZE + slot * SLOT_SIZE
+
+ # Read + open memory handle (64 bytes)
+ mem_handle_bytes = bytes(shm.buf[base_offset : base_offset + 64])
+ ipc_handles[slot] = cudaIpcMemHandle_t.from_buffer_copy(mem_handle_bytes)
+ # Flag 1 = cudaIpcMemLazyEnablePeerAccess
+ dev_ptrs[slot] = cuda.ipc_open_mem_handle(ipc_handles[slot], flags=1)
+
+ # Read + open event handle (64 bytes) if present
+ event_handle_bytes = bytes(shm.buf[base_offset + 64 : base_offset + 128])
+ if any(event_handle_bytes):
+ try:
+ ipc_event_handle = cudaIpcEventHandle_t.from_buffer_copy(event_handle_bytes)
+ ipc_events[slot] = cuda.ipc_open_event_handle(ipc_event_handle)
+ except (RuntimeError, OSError) as e:
+ logger.debug("Failed to open IPC event for slot %d: %s", slot, e)
+ ipc_events[slot] = None
+
+ tensor_info = f"tensor shape={fmt.shape}" if TORCH_AVAILABLE else "torch N/A"
+ logger.info(
+ "Slot %d: GPU at 0x%016x, %s, event=%s",
+ slot,
+ dev_ptrs[slot].value,
+ tensor_info,
+ "YES" if ipc_events[slot] else "NO",
+ )
+
+ logger.info("Opened %d IPC buffer slots with GPU-side sync", num_slots)
+
+ layout = SHMLayout(num_slots)
+ return IPCConnection(
+ cuda=cuda,
+ shm_handle=shm,
+ ipc_version=ipc_version,
+ num_slots=num_slots,
+ ipc_handles=ipc_handles,
+ dev_ptrs=dev_ptrs,
+ ipc_events=ipc_events,
+ layout=layout,
+ shutdown_offset=layout.shutdown_offset,
+ timestamp_offset=layout.timestamp_offset,
+ )
+
+ # ------------------------------------------------------------------
+ # Orchestrator
+ # ------------------------------------------------------------------
+
+ def _initialize(self) -> bool:
+ """Initialize CUDA IPC resources.
+
+ Returns True on success; False on any failure (already logged).
+ """
+ if self._initialized:
+ logger.debug("Already initialized")
+ return True
+
+ try:
+ cuda = self._setup_runtime()
+ shm, num_slots, ipc_version = self._open_and_validate_shm()
+ fmt = self._parse_format(shm, num_slots)
+ conn = self._open_ipc_slots(cuda, shm, num_slots, ipc_version, fmt)
+
+ self._conn = conn
+ self._format = fmt
+ self._torch = TorchBuffers.build(conn, fmt) if TORCH_AVAILABLE else None
+ self._cupy = CupyBuffers.build(conn, fmt) if CUPY_AVAILABLE else None
+ self._numpy = None # lazy β built on first get_frame_numpy()
+ self._last_write_idx = 0
+ self._initialized = True
+ logger.info("Initialization complete - ready for zero-copy GPU access")
+ return True
+
+ except (OSError, RuntimeError, ValueError, struct.error, IndexError) as e:
+ logger.error("Initialization failed: %s", e)
+ traceback.print_exc()
+ return False
+
+ # ------------------------------------------------------------------
+ # Slot acquisition + wait
+ # ------------------------------------------------------------------
+
+ def _try_acquire(self) -> AcquireResult | None:
+ """Acquire next frame via acquire_slot(); dispatch on state.
+
+ Returns:
+ AcquireResult on NEW_FRAME (slot/timestamp/write_idx populated), else None.
+ Side-effects:
+ cleanup() on SHUTDOWN; _reinitialize() on VERSION_CHANGED (single-tick stall).
+ """
+ try:
+ result = acquire_slot(
+ self._conn.shm_handle.buf,
+ self._conn.layout,
+ self._last_write_idx,
+ self._conn.ipc_version,
+ )
+ except (OSError, BufferError) as e:
+ logger.debug("SHM buffer inaccessible: %s", e)
+ return None
+ if result.state is SlotState.SHUTDOWN:
+ logger.info("Producer shutdown detected - cleaning up gracefully")
+ self.cleanup()
+ return None
+ if result.state is SlotState.VERSION_CHANGED:
+ logger.debug(
+ "TD re-initialized (v%d -> v%d), reopening IPC handle...",
+ self._conn.ipc_version,
+ result.new_version,
+ )
+ self._reinitialize()
+ return None # pick up frame next call
+ if result.state is SlotState.NO_FRAME:
+ return None
+ self._last_write_idx = result.write_idx
+ return result
+
+ def _wait_for_slot(self, slot: int) -> float:
+ """Wait for producer to finish writing slot, with timeout.
+
+ Returns:
+ Wait time in microseconds.
+
+ Raises:
+ TimeoutError: If wait exceeds timeout_ms.
+ """
+ conn = self._conn
+ wait_start = time.perf_counter()
+
+ if conn.ipc_events[slot]:
+ deadline = wait_start + self.timeout_ms / 1000
+
+ if self._spin_us > 0:
+ spin_deadline = wait_start + self._spin_us / 1_000_000
+ while time.perf_counter() < spin_deadline:
+ if conn.cuda.query_event(conn.ipc_events[slot]):
+ spin_us = (time.perf_counter() - wait_start) * 1_000_000
+ self.total_wait_spin_us += spin_us
+ self.wait_spin_hits += 1
+ return spin_us
+ if time.perf_counter() >= deadline:
+ raise TimeoutError(
+ f"IPC event wait timed out after {self.timeout_ms}ms (slot={slot}) β producer may have crashed"
+ )
+
+ phase2_start = time.perf_counter()
+ with _HighResTimer():
+ while True:
+ if conn.cuda.query_event(conn.ipc_events[slot]):
+ break
+ if time.perf_counter() >= deadline:
+ raise TimeoutError(
+ f"IPC event wait timed out after {self.timeout_ms}ms (slot={slot}) β producer may have crashed"
+ )
+ time.sleep(0.0001)
+ self.total_wait_sleep_us += (time.perf_counter() - phase2_start) * 1_000_000
+ self.wait_sleep_hits += 1
+ elif TORCH_AVAILABLE:
+ torch.cuda.synchronize()
+ else:
+ conn.cuda.synchronize()
+
+ return (time.perf_counter() - wait_start) * 1_000_000
+
+ # ------------------------------------------------------------------
+ # Frame consumers
+ # ------------------------------------------------------------------
+
+ def get_frame(self, stream: object | None = None) -> torch.Tensor | None:
+ """Get current frame as torch.Tensor (GPU, zero-copy).
+
+ Args:
+ stream: Optional CUDA stream (torch.cuda.Stream, cupy.cuda.Stream, int, or None).
+ If provided, issues cudaStreamWaitEvent on this stream
+ (non-blocking to CPU). If None, falls back to blocking
+ cudaEventSynchronize for backward compatibility.
+
+ Returns:
+ Zero-copy torch.Tensor on GPU, or None if not initialized or no new frame.
+
+ Raises:
+ RuntimeError: If torch is not available
+ """
+ if not TORCH_AVAILABLE:
+ raise RuntimeError("torch is required for get_frame(). Use get_frame_numpy() instead.")
+ debug = self.debug
+ if debug:
+ frame_start = time.perf_counter()
+ if not self._initialized:
+ logger.warning("Not initialized - call _initialize() first")
+ return None
+
+ if debug:
+ _shm_t = time.perf_counter()
+ result = self._try_acquire()
+ if result is None:
+ return None
+ read_slot = result.slot
+ producer_timestamp = result.timestamp
+ if debug:
+ self.total_shm_read_us += (time.perf_counter() - _shm_t) * 1_000_000
+
+ if producer_timestamp > 0:
+ self.last_latency = (time.perf_counter() - producer_timestamp) * 1000
+ else:
+ self.last_latency = 0.0
+
+ conn = self._conn
+ if debug:
+ wait_start = time.perf_counter()
+
+ _nvtx.push_range(f"cudalink.importer.get_frame.slot{read_slot}", "purple")
+ with _nvtx.verbose_range("cudalink.importer.event_wait", "purple"):
+ if stream is not None:
+ cuda_stream = self._resolve_stream(stream)
+ if conn.ipc_events[read_slot]:
+ conn.cuda.stream_wait_event(cuda_stream, conn.ipc_events[read_slot], 0)
+ else:
+ try:
+ self._wait_for_slot(read_slot)
+ except TimeoutError:
+ logger.error("Producer timeout β returning None")
+ _nvtx.pop_range()
+ return None
+
+ if debug:
+ self.total_wait_event_time += (time.perf_counter() - wait_start) * 1_000_000
+
+ self.frame_count += 1
+
+ if debug:
+ frame_time = (time.perf_counter() - frame_start) * 1_000_000
+ self.total_get_frame_time += frame_time
+
+ if self.frame_count % 97 == 0:
+ n = self.frame_count
+ sync_mode = "GPU-Events" if all(conn.ipc_events) else "CPU-Sync"
+ spin_hit_pct = 100.0 * self.wait_spin_hits / n if n > 0 else 0.0
+ logger.debug(
+ "Frame %d [%s]: shm_read=%.1fus stream_wait=%.1fus total=%.1fus "
+ "latency=%.2fms | spin_hit=%.0f%% avg_spin=%.1fus avg_sleep=%.1fus",
+ n,
+ sync_mode,
+ self.total_shm_read_us / n,
+ self.total_wait_event_time / n,
+ self.total_get_frame_time / n,
+ self.last_latency,
+ spin_hit_pct,
+ self.total_wait_spin_us / self.wait_spin_hits if self.wait_spin_hits > 0 else 0.0,
+ self.total_wait_sleep_us / self.wait_sleep_hits if self.wait_sleep_hits > 0 else 0.0,
+ )
+
+ _nvtx.pop_range()
+ return self._torch.tensors[read_slot]
+
+ def get_frame_numpy(self) -> np.ndarray | None:
+ """Get current frame as numpy array (CPU, involves D2H copy).
+
+ Returns:
+ Numpy array on CPU, or None if not initialized or no new frame.
+
+ Raises:
+ RuntimeError: If numpy is not available
+ """
+ if not NUMPY_AVAILABLE:
+ raise RuntimeError("numpy is required for get_frame_numpy()")
+ debug = self.debug
+ if debug:
+ frame_start = time.perf_counter()
+ if not self._initialized:
+ logger.warning("Not initialized - call _initialize() first")
+ return None
+
+ if debug:
+ _shm_t = time.perf_counter()
+ result = self._try_acquire()
+ if result is None:
+ return None
+ read_slot = result.slot
+ producer_timestamp = result.timestamp
+ if debug:
+ self.total_shm_read_us += (time.perf_counter() - _shm_t) * 1_000_000
+
+ if producer_timestamp > 0:
+ self.last_latency = (time.perf_counter() - producer_timestamp) * 1000
+ else:
+ self.last_latency = 0.0
+
+ conn = self._conn
+ fmt = self._format
+ nbytes = fmt.frame_nbytes
+
+ # Lazily build (or rebuild) NumpyBuffers when format changes
+ if self._numpy is None or self._numpy.needs_rebuild(fmt):
+ if self._numpy is not None:
+ self._numpy.close()
+ self._numpy = NumpyBuffers.build(conn, fmt, self._d2h_num_streams)
+
+ nb = self._numpy
+
+ # CPU-side event poll + async D2H + synchronize.
+ # Uses _wait_for_slot (query_event CPU poll) rather than stream_wait_event because
+ # cudaStreamWaitEvent on a cross-process IPC event has high kernel-mode latency on
+ # Windows (~100-300ms when followed by stream_synchronize). The producer records
+ # the IPC event BEFORE publishing write_idx (improvement #2), so the event is always
+ # pre-signaled when the consumer reads write_idx β query_event returns True on the
+ # first call with no polling delay.
+ _nvtx.push_range(f"cudalink.importer.get_frame_numpy.slot{read_slot}", "orange")
+ if debug:
+ _wait_t = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.importer.event_wait", "orange"):
+ try:
+ self._wait_for_slot(read_slot)
+ except TimeoutError:
+ logger.error("Producer timeout β returning None")
+ _nvtx.pop_range()
+ return None
+ if debug:
+ self.total_wait_event_time += (time.perf_counter() - _wait_t) * 1_000_000
+
+ if debug:
+ _d2h_t = time.perf_counter()
+ with _nvtx.verbose_range("cudalink.importer.d2h_copy", "orange"):
+ n_streams = nb.num_streams
+ if n_streams <= 1:
+ conn.cuda.memcpy_async(
+ dst=ctypes.c_void_p(nb.buffer.ctypes.data),
+ src=conn.dev_ptrs[read_slot],
+ count=nbytes,
+ kind=2, # cudaMemcpyDeviceToHost
+ stream=nb.primary_stream,
+ )
+ conn.cuda.stream_synchronize(nb.primary_stream)
+ else:
+ # Chunk size: ceil-divided, rounded up to 16-byte alignment for DMA safety.
+ chunk = ((nbytes + n_streams - 1) // n_streams + 15) & ~15
+ dst_base = nb.buffer.ctypes.data
+ src_base = conn.dev_ptrs[read_slot].value
+ issued = 0
+ for i in range(n_streams):
+ offset = i * chunk
+ size = min(chunk, nbytes - offset)
+ if size <= 0:
+ break
+ conn.cuda.memcpy_async(
+ dst=ctypes.c_void_p(dst_base + offset),
+ src=ctypes.c_void_p(src_base + offset),
+ count=size,
+ kind=2,
+ stream=nb.d2h_streams[i],
+ )
+ conn.cuda.record_event(nb.d2h_events[i], stream=nb.d2h_streams[i])
+ issued = i + 1
+ for i in range(issued):
+ conn.cuda.wait_event(nb.d2h_events[i])
+ conn.cuda.check_sticky_error("get_frame_numpy")
+ if debug:
+ d2h_time = (time.perf_counter() - _d2h_t) * 1_000_000
+
+ self.frame_count += 1
+
+ if debug:
+ frame_time = (time.perf_counter() - frame_start) * 1_000_000
+ self.total_get_frame_time += frame_time
+
+ if self.frame_count % 97 == 0:
+ n = self.frame_count
+ logger.debug(
+ "Frame %d (numpy): shm_read=%.1fus wait=%.1fus d2h=%.1fus total=%.1fus latency=%.2fms",
+ n,
+ self.total_shm_read_us / n,
+ self.total_wait_event_time / n,
+ d2h_time,
+ self.total_get_frame_time / n,
+ self.last_latency,
+ )
+
+ _nvtx.pop_range()
+ return nb.buffer
+
+ def get_frame_cupy(self, stream: object | None = None) -> cp.ndarray | None:
+ """Get current frame as CuPy GPU array (zero-copy).
+
+ Args:
+ stream: Optional CuPy stream (cupy.cuda.Stream, torch.cuda.Stream, int, or None).
+ If provided, issues cudaStreamWaitEvent on this stream
+ (non-blocking to CPU). If None, uses CuPy's current stream.
+
+ Returns:
+ Zero-copy CuPy array on GPU, or None if not initialized
+
+ Raises:
+ RuntimeError: If CuPy is not available
+ """
+ if not CUPY_AVAILABLE:
+ raise RuntimeError("cupy is required for get_frame_cupy(). Install: pip install cupy-cuda12x")
+
+ if not self._initialized:
+ logger.warning("Not initialized - call _initialize() first")
+ return None
+
+ result = self._try_acquire()
+ if result is None:
+ return None
+ read_slot = result.slot
+ producer_timestamp = result.timestamp
+ if producer_timestamp > 0:
+ self.last_latency = (time.perf_counter() - producer_timestamp) * 1000
+ else:
+ self.last_latency = 0.0
+
+ conn = self._conn
+
+ if stream is None:
+ stream = cp.cuda.get_current_stream()
+ else:
+ if not isinstance(stream, cp.cuda.Stream):
+ cuda_stream_ptr = self._resolve_stream(stream)
+ stream = cp.cuda.ExternalStream(cuda_stream_ptr)
+
+ if conn.ipc_events[read_slot]:
+ cp.cuda.runtime.streamWaitEvent(stream.ptr, int(conn.ipc_events[read_slot]), 0)
+
+ return self._cupy.arrays[read_slot]
+
+ # ------------------------------------------------------------------
+ # Re-initialization (TD sender restarted with new IPC handles)
+ # ------------------------------------------------------------------
+
+ def _reinitialize(self) -> None:
+ """Re-open all IPC handles after TD re-initialization."""
+ old_conn = self._conn
+ shm = old_conn.shm_handle # keep SHM alive across handle close
+
+ # Close old IPC handles only (SHM stays open)
+ old_conn.close_ipc_handles()
+
+ # Re-read version and num_slots
+ new_ipc_version = struct.unpack(" 0 and height > 0 and num_comps > 0:
+ new_dtype_str = _decode_dtype_str(kind, bits, flags)
+ new_shape = (height, width, num_comps)
+ itemsize = _DTYPE_SIZES.get(new_dtype_str, bits // 8 or 4)
+ new_fmt = Format(
+ width=width,
+ height=height,
+ num_comps=num_comps,
+ kind=kind,
+ bits=bits,
+ flags=flags,
+ dtype_str=new_dtype_str,
+ shape=new_shape,
+ numpy_dtype=np.dtype(new_dtype_str) if NUMPY_AVAILABLE else None,
+ frame_nbytes=height * width * num_comps * itemsize,
+ )
+ except (struct.error, ValueError, IndexError) as e:
+ logger.debug("Could not re-read metadata during reinit: %s", e)
+
+ if new_fmt != self._format:
+ logger.info(
+ "Metadata changed on reinit: %s %s -> %s %s",
+ self._format.shape,
+ self._format.dtype_str,
+ new_fmt.shape,
+ new_fmt.dtype_str,
+ )
+ # Tear down numpy buffers β will be rebuilt lazily on next get_frame_numpy()
+ if self._numpy is not None:
+ self._numpy.close()
+ self._numpy = None
+ self.shape = new_fmt.shape
+ self.dtype = new_fmt.dtype_str
+
+ self._format = new_fmt
+
+ # Rebuild IPC connection (reusing the still-open SHM handle)
+ new_conn = self._open_ipc_slots(old_conn.cuda, shm, new_num_slots, new_ipc_version, new_fmt)
+ self._conn = new_conn
+
+ # Rebuild torch buffers (cupy not rebuilt β matches pre-refactor behavior)
+ if TORCH_AVAILABLE:
+ self._torch = TorchBuffers.build(new_conn, new_fmt)
+
+ logger.debug("Reopened %d IPC handles v%d", new_num_slots, new_ipc_version)
+ for slot in range(new_num_slots):
+ logger.debug("Slot %d: GPU at 0x%016x", slot, new_conn.dev_ptrs[slot].value)
+
+ # ------------------------------------------------------------------
+ # Cleanup
+ # ------------------------------------------------------------------
+
+ def cleanup(self) -> None:
+ """Cleanup CUDA IPC resources."""
+ if getattr(self, "_numpy", None) is not None:
+ self._numpy.close()
+ self._numpy = None
+ if getattr(self, "_conn", None) is not None:
+ self._conn.close()
+ self._conn = None
+ # TorchBuffers and CupyBuffers hold zero-copy views; GC reclaims on deref
+ self._torch = None
+ self._cupy = None
+ self._format = None
+ self._initialized = False
+ logger.info("Cleanup complete")
+
+ def __del__(self) -> None:
+ if getattr(self, "_initialized", False):
+ self.cleanup()
+
+ def __enter__(self) -> CUDAIPCImporter:
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: object,
+ ) -> None:
+ self.cleanup()
+ return None
+
+ # ------------------------------------------------------------------
+ # Status / stats
+ # ------------------------------------------------------------------
+
+ def is_ready(self) -> bool:
+ """Check if importer is ready for frame access."""
+ if not self._initialized or self._conn is None:
+ return False
+ return len(self._conn.dev_ptrs) > 0 and all(ptr is not None for ptr in self._conn.dev_ptrs)
+
+ def attach_nvml_observer(self, observer: NVMLObserver) -> None:
+ """Attach an NVMLObserver for GPU telemetry in get_stats()."""
+ self._nvml_observer = observer
+
+ def get_stats(self) -> dict[str, object]:
+ """Get importer statistics."""
+ conn = self._conn
+ dev_ptrs = conn.dev_ptrs if conn is not None else []
+ num_slots = conn.num_slots if conn is not None else 0
+ tensors = self._torch.tensors if self._torch is not None else []
+
+ stats: dict[str, object] = {
+ "initialized": self._initialized,
+ "shape": self.shape,
+ "dtype": self.dtype,
+ "frame_count": self.frame_count,
+ "shm_name": self.shm_name,
+ "num_slots": num_slots,
+ "torch_available": TORCH_AVAILABLE,
+ "numpy_available": NUMPY_AVAILABLE,
+ "dev_ptrs": [f"0x{ptr.value:016x}" if ptr else "NULL" for ptr in dev_ptrs],
+ "tensor_device": (
+ str(tensors[0].device) if TORCH_AVAILABLE and tensors and tensors[0] is not None else "N/A"
+ ),
+ "wait_spin_hits": self.wait_spin_hits,
+ "wait_sleep_hits": self.wait_sleep_hits,
+ "avg_spin_us": self.total_wait_spin_us / self.wait_spin_hits if self.wait_spin_hits > 0 else 0.0,
+ "avg_sleep_us": self.total_wait_sleep_us / self.wait_sleep_hits if self.wait_sleep_hits > 0 else 0.0,
+ }
+ observer = getattr(self, "_nvml_observer", None)
+ if observer is not None:
+ stats["nvml"] = observer.snapshot()
+ return stats
diff --git a/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py
new file mode 100644
index 000000000..7e0de46ac
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/cuda_ipc_wrapper.py
@@ -0,0 +1,1070 @@
+"""
+CUDA IPC Wrapper for Windows
+Based on vLLM cuda_wrapper.py pattern
+
+Provides ctypes interface to CUDA Runtime API for inter-process communication.
+Compatible with both TouchDesigner and Python processes.
+
+Requirements:
+- CUDA 12.x runtime (cudart64_12.dll)
+- Windows operating system
+- Same GPU visible to both processes
+"""
+
+from __future__ import annotations
+
+import ctypes
+import logging
+import os
+from ctypes import POINTER, byref, c_float, c_int, c_size_t, c_uint, c_uint64, c_void_p
+
+
+_logger = logging.getLogger(__name__)
+
+from .cuda_graphs import CUDAGraphsMixin # noqa: E402
+from .cuda_runtime_types import ( # noqa: E402
+ CUDAError,
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaIpcEventHandle_t,
+ cudaIpcMemHandle_t,
+ cudaMemcpy3DParms,
+ cudaPointerAttributes,
+)
+
+
+class CUDARuntimeAPI(CUDAGraphsMixin):
+ """CUDA Runtime API wrapper using ctypes.
+
+ Provides access to CUDA IPC functions for zero-copy GPU memory
+ sharing between processes.
+
+ Usage:
+ cuda = CUDARuntimeAPI()
+
+ # Allocate GPU memory
+ dev_ptr = cuda.malloc(buffer_size)
+
+ # Export IPC handle (sender process)
+ handle = cuda.ipc_get_mem_handle(dev_ptr)
+
+ # Import IPC handle (receiver process)
+ imported_ptr = cuda.ipc_open_mem_handle(handle)
+
+ # Use memory...
+
+ # Close handle (receiver)
+ cuda.ipc_close_mem_handle(imported_ptr)
+
+ # Free memory (sender)
+ cuda.free(dev_ptr)
+ """
+
+ def __init__(self, device: int = 0) -> None:
+ """Initialize CUDA runtime library.
+
+ Args:
+ device: CUDA device index to bind. Defaults to 0.
+ IPC handles are device-scoped; sender and receiver must
+ use the same device or peer-access must be enabled.
+ """
+ self.device = device
+ self.cudart = self._load_cuda_runtime()
+ self._setup_function_signatures()
+ # Establish CUDA primary context on the requested device.
+ # Prevents cudaIpcOpenMemHandle error 400 when a second cudart DLL is loaded
+ # alongside torch (which has its own bundled cudart). Each DLL instance needs
+ # its own context initialized before IPC handle operations can succeed.
+ self.cudart.cudaSetDevice(device)
+
+ if os.environ.get("CUDA_LAUNCH_BLOCKING") == "1":
+ _logger.warning(
+ "CUDA_LAUNCH_BLOCKING=1 is set β all CUDA operations are serialized. "
+ "This causes ~30x slower frame rates and should only be used for debugging."
+ )
+
+ # Default ON; set CUDALINK_STICKY_ERROR_CHECK=0 to skip the cudaPeekAtLastError call.
+ self._sticky_check_enabled: bool = os.environ.get("CUDALINK_STICKY_ERROR_CHECK", "1") != "0"
+
+ def _load_cuda_runtime(self) -> ctypes.CDLL:
+ """Load CUDA runtime DLL.
+
+ Returns:
+ ctypes.CDLL: Loaded CUDA runtime library
+
+ Raises:
+ RuntimeError: If CUDA runtime cannot be loaded
+ """
+ # Try by name FIRST: if cudart is already loaded in this process (e.g., by
+ # torch), Windows returns the cached handle β ensuring we share the same
+ # runtime instance and CUDA context. Loading by full path can create a second
+ # independent instance with its own state, breaking cross-process IPC.
+ # cudart64_110.dll is preferred for bisect testing (W1): reverts the 12.x
+ # preference introduced in 4695d8f to test whether cudart64_12 ABI is the
+ # driver-error amplifier on WDDM.
+ dll_names = ["cudart64_110.dll", "cudart64_12.dll", "cudart64_11.dll"]
+ for name in dll_names:
+ try:
+ dll = ctypes.CDLL(name)
+ self._log_dll_path(dll, name)
+ return dll
+ except OSError:
+ continue
+
+ # Fallback: try full toolkit paths when not already in PATH
+ dll_paths = [
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.0\bin\cudart64_12.dll",
+ ]
+ for dll_path in dll_paths:
+ if os.path.exists(dll_path):
+ try:
+ dll = ctypes.CDLL(dll_path)
+ self._log_dll_path(dll, dll_path)
+ return dll
+ except OSError:
+ continue
+
+ raise RuntimeError(
+ "Could not load CUDA runtime. Please ensure CUDA 12.x is installed.\n"
+ f"Tried names: {dll_names}\n"
+ f"Tried paths: {dll_paths}"
+ )
+
+ @staticmethod
+ def _log_dll_path(dll: ctypes.CDLL, hint: str) -> None:
+ """Log the resolved filesystem path of a loaded DLL (Windows only)."""
+ try:
+ buf = ctypes.create_unicode_buffer(260)
+ # GetModuleFileNameW needs HMODULE as c_void_p to avoid 32-bit overflow
+ ctypes.windll.kernel32.GetModuleFileNameW(ctypes.c_void_p(dll._handle), buf, 260)
+ _logger.debug("Loaded CUDA runtime: %s", buf.value)
+ except (OSError, AttributeError) as e:
+ _logger.debug("Could not log DLL path: %s", e)
+
+ def _setup_function_signatures(self) -> None:
+ """Define function signatures for CUDA runtime functions."""
+ # cudaMalloc(void** devPtr, size_t size)
+ self.cudart.cudaMalloc.argtypes = [POINTER(c_void_p), c_size_t]
+ self.cudart.cudaMalloc.restype = c_int
+
+ # cudaFree(void* devPtr)
+ self.cudart.cudaFree.argtypes = [c_void_p]
+ self.cudart.cudaFree.restype = c_int
+
+ # cudaMallocHost(void** ptr, size_t size) β allocate pinned (page-locked) host memory
+ self.cudart.cudaMallocHost.argtypes = [POINTER(c_void_p), c_size_t]
+ self.cudart.cudaMallocHost.restype = c_int
+
+ # cudaFreeHost(void* ptr) β free pinned host memory
+ self.cudart.cudaFreeHost.argtypes = [c_void_p]
+ self.cudart.cudaFreeHost.restype = c_int
+
+ # cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind)
+ self.cudart.cudaMemcpy.argtypes = [c_void_p, c_void_p, c_size_t, c_int]
+ self.cudart.cudaMemcpy.restype = c_int
+
+ # cudaIpcGetMemHandle(cudaIpcMemHandle_t* handle, void* devPtr)
+ self.cudart.cudaIpcGetMemHandle.argtypes = [
+ POINTER(cudaIpcMemHandle_t),
+ c_void_p,
+ ]
+ self.cudart.cudaIpcGetMemHandle.restype = c_int
+
+ # cudaIpcOpenMemHandle(void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags)
+ self.cudart.cudaIpcOpenMemHandle.argtypes = [
+ POINTER(c_void_p),
+ cudaIpcMemHandle_t,
+ c_uint,
+ ]
+ self.cudart.cudaIpcOpenMemHandle.restype = c_int
+
+ # cudaIpcCloseMemHandle(void* devPtr)
+ self.cudart.cudaIpcCloseMemHandle.argtypes = [c_void_p]
+ self.cudart.cudaIpcCloseMemHandle.restype = c_int
+
+ # cudaIpcGetEventHandle(cudaIpcEventHandle_t* handle, cudaEvent_t event)
+ self.cudart.cudaIpcGetEventHandle.argtypes = [
+ POINTER(cudaIpcEventHandle_t),
+ CUDAEvent_t,
+ ]
+ self.cudart.cudaIpcGetEventHandle.restype = c_int
+
+ # cudaIpcOpenEventHandle(cudaEvent_t* event, cudaIpcEventHandle_t handle)
+ self.cudart.cudaIpcOpenEventHandle.argtypes = [
+ POINTER(CUDAEvent_t),
+ cudaIpcEventHandle_t,
+ ]
+ self.cudart.cudaIpcOpenEventHandle.restype = c_int
+
+ # cudaEventCreateWithFlags(cudaEvent_t* event, unsigned int flags)
+ self.cudart.cudaEventCreateWithFlags.argtypes = [POINTER(CUDAEvent_t), c_uint]
+ self.cudart.cudaEventCreateWithFlags.restype = c_int
+
+ # cudaEventRecord(cudaEvent_t event, cudaStream_t stream)
+ self.cudart.cudaEventRecord.argtypes = [CUDAEvent_t, CUDAStream_t]
+ self.cudart.cudaEventRecord.restype = c_int
+
+ # cudaEventQuery(cudaEvent_t event)
+ self.cudart.cudaEventQuery.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventQuery.restype = c_int
+
+ # cudaEventSynchronize(cudaEvent_t event)
+ self.cudart.cudaEventSynchronize.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventSynchronize.restype = c_int
+
+ # cudaEventDestroy(cudaEvent_t event)
+ self.cudart.cudaEventDestroy.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventDestroy.restype = c_int
+
+ # cudaEventElapsedTime(float* ms, cudaEvent_t start, cudaEvent_t end)
+ self.cudart.cudaEventElapsedTime.argtypes = [POINTER(c_float), CUDAEvent_t, CUDAEvent_t]
+ self.cudart.cudaEventElapsedTime.restype = c_int
+
+ # cudaDeviceSynchronize()
+ self.cudart.cudaDeviceSynchronize.argtypes = []
+ self.cudart.cudaDeviceSynchronize.restype = c_int
+
+ # cudaGetLastError()
+ self.cudart.cudaGetLastError.argtypes = []
+ self.cudart.cudaGetLastError.restype = c_int
+
+ # cudaPeekAtLastError() β non-destructive sticky-error read (does NOT clear the error)
+ self.cudart.cudaPeekAtLastError.argtypes = []
+ self.cudart.cudaPeekAtLastError.restype = c_int
+
+ # cudaHostRegister(void* ptr, size_t size, unsigned int flags) β page-lock existing host memory
+ self.cudart.cudaHostRegister.argtypes = [c_void_p, c_size_t, c_uint]
+ self.cudart.cudaHostRegister.restype = c_int
+
+ # cudaHostUnregister(void* ptr) β unregister page-locked host memory
+ self.cudart.cudaHostUnregister.argtypes = [c_void_p]
+ self.cudart.cudaHostUnregister.restype = c_int
+
+ # cudaGetErrorString(cudaError_t error)
+ self.cudart.cudaGetErrorString.argtypes = [c_int]
+ self.cudart.cudaGetErrorString.restype = ctypes.c_char_p
+
+ # cudaStreamCreateWithFlags(cudaStream_t* pStream, unsigned int flags)
+ self.cudart.cudaStreamCreateWithFlags.argtypes = [POINTER(CUDAStream_t), c_uint]
+ self.cudart.cudaStreamCreateWithFlags.restype = c_int
+
+ # cudaStreamDestroy(cudaStream_t stream)
+ self.cudart.cudaStreamDestroy.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamDestroy.restype = c_int
+
+ # cudaStreamWaitEvent(cudaStream_t stream, cudaEvent_t event, unsigned int flags)
+ self.cudart.cudaStreamWaitEvent.argtypes = [CUDAStream_t, CUDAEvent_t, c_uint]
+ self.cudart.cudaStreamWaitEvent.restype = c_int
+
+ # cudaStreamSynchronize(cudaStream_t stream)
+ self.cudart.cudaStreamSynchronize.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamSynchronize.restype = c_int
+
+ # cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream)
+ self.cudart.cudaMemcpyAsync.argtypes = [c_void_p, c_void_p, c_size_t, c_int, CUDAStream_t]
+ self.cudart.cudaMemcpyAsync.restype = c_int
+
+ # cudaMemGetInfo(size_t* free, size_t* total)
+ self.cudart.cudaMemGetInfo.argtypes = [POINTER(c_size_t), POINTER(c_size_t)]
+ self.cudart.cudaMemGetInfo.restype = c_int
+
+ # cudaSetDevice(int device)
+ self.cudart.cudaSetDevice.argtypes = [c_int]
+ self.cudart.cudaSetDevice.restype = c_int
+
+ # cudaGetDevice(int* device)
+ self.cudart.cudaGetDevice.argtypes = [POINTER(c_int)]
+ self.cudart.cudaGetDevice.restype = c_int
+
+ # cudaStreamQuery(cudaStream_t stream)
+ self.cudart.cudaStreamQuery.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamQuery.restype = c_int
+
+ # cudaDeviceCanAccessPeer(int* canAccessPeer, int device, int peerDevice)
+ self.cudart.cudaDeviceCanAccessPeer.argtypes = [POINTER(c_int), c_int, c_int]
+ self.cudart.cudaDeviceCanAccessPeer.restype = c_int
+
+ # cudaDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority)
+ self.cudart.cudaDeviceGetStreamPriorityRange.argtypes = [POINTER(c_int), POINTER(c_int)]
+ self.cudart.cudaDeviceGetStreamPriorityRange.restype = c_int
+
+ # cudaStreamCreateWithPriority(cudaStream_t* pStream, unsigned int flags, int priority)
+ self.cudart.cudaStreamCreateWithPriority.argtypes = [POINTER(CUDAStream_t), c_uint, c_int]
+ self.cudart.cudaStreamCreateWithPriority.restype = c_int
+
+ # cudaPointerGetAttributes(cudaPointerAttributes* attributes, const void* ptr)
+ self.cudart.cudaPointerGetAttributes.argtypes = [POINTER(cudaPointerAttributes), c_void_p]
+ self.cudart.cudaPointerGetAttributes.restype = c_int
+
+ # === G1: non-graph helpers (re-enabled Phase 1.1) ===
+ # cudaHostAlloc(void** ptr, size_t size, unsigned int flags)
+ # Replaces cudaMallocHost with explicit flag control.
+ # cudaHostAllocPortable = 0x01 β accessible from any CUDA context in process
+ # cudaHostAllocMapped = 0x02 β map into device address space
+ # cudaHostAllocWriteCombined = 0x04 β write-combined (fast CPU writes, slow CPU reads)
+ self.cudart.cudaHostAlloc.argtypes = [POINTER(c_void_p), c_size_t, c_uint]
+ self.cudart.cudaHostAlloc.restype = c_int
+
+ # cudaDeviceGetAttribute(int* value, cudaDeviceAttr attr, int device)
+ # Used to query cudaDevAttrAsyncEngineCount (attr=4) β how many DMA copy engines exist.
+ self.cudart.cudaDeviceGetAttribute.argtypes = [POINTER(c_int), c_int, c_int]
+ self.cudart.cudaDeviceGetAttribute.restype = c_int
+
+ # === G2: graph lifecycle (re-enabled Phase 1.2) ===
+ # CUDA 10.0+ graph capture/build/launch/teardown + runtime-version gate.
+
+ # cudaStreamBeginCapture(cudaStream_t stream, cudaStreamCaptureMode mode)
+ # mode: 0=global, 1=thread_local, 2=relaxed
+ self.cudart.cudaStreamBeginCapture.argtypes = [CUDAStream_t, c_int]
+ self.cudart.cudaStreamBeginCapture.restype = c_int
+
+ # cudaStreamEndCapture(cudaStream_t stream, cudaGraph_t* pGraph)
+ self.cudart.cudaStreamEndCapture.argtypes = [CUDAStream_t, POINTER(CUDAGraph_t)]
+ self.cudart.cudaStreamEndCapture.restype = c_int
+
+ # cudaGraphInstantiateWithFlags(cudaGraphExec_t* pGraphExec, cudaGraph_t graph,
+ # unsigned long long flags) [CUDA 11.4+ stable 3-arg form]
+ # Prefer this over cudaGraphInstantiate on any cudart 11.x: the latter changed
+ # from 5-arg (CUDA 10.0β11.8) to 3-arg (CUDA 12.0+) β calling the 12.0 3-arg
+ # binding against an 11.x DLL mismatches the ABI and crashes (WDDM access
+ # violation). cudaGraphInstantiateWithFlags has had a stable 3-arg signature
+ # since 11.4 and is available in all 12.x releases as well.
+ self.cudart.cudaGraphInstantiateWithFlags.argtypes = [POINTER(CUDAGraphExec_t), CUDAGraph_t, c_uint64]
+ self.cudart.cudaGraphInstantiateWithFlags.restype = c_int
+
+ # cudaGraphLaunch(cudaGraphExec_t graphExec, cudaStream_t stream)
+ self.cudart.cudaGraphLaunch.argtypes = [CUDAGraphExec_t, CUDAStream_t]
+ self.cudart.cudaGraphLaunch.restype = c_int
+
+ # cudaGraphDestroy(cudaGraph_t graph)
+ self.cudart.cudaGraphDestroy.argtypes = [CUDAGraph_t]
+ self.cudart.cudaGraphDestroy.restype = c_int
+
+ # cudaGraphExecDestroy(cudaGraphExec_t graphExec)
+ self.cudart.cudaGraphExecDestroy.argtypes = [CUDAGraphExec_t]
+ self.cudart.cudaGraphExecDestroy.restype = c_int
+
+ # cudaGraphGetNodes(cudaGraph_t graph, cudaGraphNode_t* nodes, size_t* numNodes)
+ # Pass nodes=NULL to query count; then call again with allocated array.
+ self.cudart.cudaGraphGetNodes.argtypes = [CUDAGraph_t, POINTER(CUDAGraphNode_t), POINTER(c_size_t)]
+ self.cudart.cudaGraphGetNodes.restype = c_int
+
+ # cudaRuntimeGetVersion(int* runtimeVersion)
+ # Returns the version as int (e.g., 11040 = CUDA 11.4, 12080 = CUDA 12.8).
+ # Used to gate optional API calls (e.g., cudaGraphExecMemcpyNodeSetParams1D
+ # requires 11.3+) when the loaded cudart DLL may be a 11.0.x patch.
+ self.cudart.cudaRuntimeGetVersion.argtypes = [POINTER(c_int)]
+ self.cudart.cudaRuntimeGetVersion.restype = c_int
+
+ # === G3: graph node setters (re-enabled Phase 1.3) ===
+ # Per-frame in-place node update for ring-slot remap. Most CUDA-12-flavoured
+ # of the 14 (NodeSetParams1D 11.3+; event-node setters 11.4+).
+
+ # cudaGraphExecMemcpyNodeSetParams(cudaGraphExec_t, cudaGraphNode_t,
+ # const cudaMemcpy3DParms*)
+ # Updates a 3D-captured memcpy node. For nodes captured from cudaMemcpyAsync
+ # (1D form) use cudaGraphExecMemcpyNodeSetParams1D instead.
+ self.cudart.cudaGraphExecMemcpyNodeSetParams.argtypes = [
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ POINTER(cudaMemcpy3DParms),
+ ]
+ self.cudart.cudaGraphExecMemcpyNodeSetParams.restype = c_int
+
+ # cudaGraphExecMemcpyNodeSetParams1D(cudaGraphExec_t, cudaGraphNode_t,
+ # void* dst, const void* src,
+ # size_t count, cudaMemcpyKind kind)
+ # Updates a 1D memcpy node (captured from cudaMemcpyAsync). CUDA 11.3+.
+ self.cudart.cudaGraphExecMemcpyNodeSetParams1D.argtypes = [
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ c_void_p,
+ c_void_p,
+ c_size_t,
+ c_int,
+ ]
+ self.cudart.cudaGraphExecMemcpyNodeSetParams1D.restype = c_int
+
+ # cudaGraphExecEventRecordNodeSetEvent(cudaGraphExec_t, cudaGraphNode_t,
+ # cudaEvent_t event)
+ # Updates the event recorded by an event-record node. CUDA 11.4+.
+ self.cudart.cudaGraphExecEventRecordNodeSetEvent.argtypes = [CUDAGraphExec_t, CUDAGraphNode_t, CUDAEvent_t]
+ self.cudart.cudaGraphExecEventRecordNodeSetEvent.restype = c_int
+
+ # cudaGraphExecEventWaitNodeSetEvent(cudaGraphExec_t, cudaGraphNode_t,
+ # cudaEvent_t event)
+ # Updates the event waited on by an event-wait node. CUDA 11.4+.
+ self.cudart.cudaGraphExecEventWaitNodeSetEvent.argtypes = [CUDAGraphExec_t, CUDAGraphNode_t, CUDAEvent_t]
+ self.cudart.cudaGraphExecEventWaitNodeSetEvent.restype = c_int
+
+ def check_error(self, result: int, operation: str) -> None:
+ """Check CUDA error code and raise exception if failed.
+
+ Args:
+ result: CUDA error code
+ operation: Description of the operation that failed
+
+ Raises:
+ RuntimeError: If result indicates an error
+ """
+ if result != CUDAError.SUCCESS:
+ error_str = self.cudart.cudaGetErrorString(result).decode("utf-8")
+ error_name = CUDAError.get_name(result)
+ raise RuntimeError(f"CUDA {operation} failed: {error_str} (error {result}: {error_name})")
+
+ def peek_at_last_error(self) -> int:
+ """Non-destructively read the thread-local sticky CUDA error.
+
+ Returns SUCCESS (0) normally. A non-zero value means a prior async
+ operation (memcpy, kernel) failed and the error was not yet consumed.
+ Unlike cudaGetLastError this does NOT clear the latched error state.
+ """
+ return int(self.cudart.cudaPeekAtLastError())
+
+ def check_sticky_error(self, context: str) -> None:
+ """Warn and raise if a sticky CUDA error is latched from a prior async op.
+
+ No-op when CUDALINK_STICKY_ERROR_CHECK=0. Enabled by default.
+ Use peek_at_last_error() directly for the raw value without raising.
+ """
+ if not self._sticky_check_enabled:
+ return
+ code = int(self.cudart.cudaPeekAtLastError())
+ if code != CUDAError.SUCCESS:
+ error_str = self.cudart.cudaGetErrorString(code).decode("utf-8")
+ _logger.warning(
+ "Sticky CUDA error detected after %s: %s (code %d). "
+ "The CUDA context is poisoned β restart the process. "
+ "Set CUDALINK_STICKY_ERROR_CHECK=0 to disable this check.",
+ context,
+ error_str,
+ code,
+ )
+ raise RuntimeError(
+ f"Sticky CUDA error after {context}: {error_str} (code {code}). "
+ "The CUDA context is poisoned. Restart the process or set "
+ "CUDALINK_STICKY_ERROR_CHECK=0 to disable this check."
+ )
+
+ def host_register(self, ptr: int, size: int, flags: int = 0) -> None:
+ """Page-lock an existing host allocation via cudaHostRegister.
+
+ Args:
+ ptr: Host pointer as integer (e.g., arr.ctypes.data)
+ size: Number of bytes to register
+ flags: Registration flags (0=default, 1=portable, 2=mapped, 4=write-combined)
+
+ Raises:
+ RuntimeError: If registration fails
+ """
+ result = self.cudart.cudaHostRegister(c_void_p(ptr), c_size_t(size), c_uint(flags))
+ self.check_error(result, "cudaHostRegister")
+
+ def host_unregister(self, ptr: int) -> None:
+ """Unregister a page-locked host allocation registered with host_register().
+
+ Args:
+ ptr: Host pointer as integer (same value passed to host_register())
+
+ Raises:
+ RuntimeError: If unregistration fails
+ """
+ result = self.cudart.cudaHostUnregister(c_void_p(ptr))
+ self.check_error(result, "cudaHostUnregister")
+
+ # High-level API
+
+ def malloc(self, size: int) -> c_void_p:
+ """Allocate GPU memory.
+
+ Args:
+ size: Number of bytes to allocate
+
+ Returns:
+ Device pointer to allocated memory
+
+ Raises:
+ RuntimeError: If allocation fails
+ """
+ dev_ptr = c_void_p()
+ result = self.cudart.cudaMalloc(byref(dev_ptr), size)
+ self.check_error(result, "cudaMalloc")
+ return dev_ptr
+
+ def free(self, dev_ptr: c_void_p) -> None:
+ """Free GPU memory.
+
+ Args:
+ dev_ptr: Device pointer to free
+
+ Raises:
+ RuntimeError: If free fails
+ """
+ result = self.cudart.cudaFree(dev_ptr)
+ self.check_error(result, "cudaFree")
+
+ def malloc_host(self, size: int) -> c_void_p:
+ """Allocate pinned (page-locked) host memory via cudaMallocHost.
+
+ Pinned memory enables direct DMA for D2H transfers, eliminating the
+ CUDA driver's internal staging copy that pageable memory requires.
+
+ Note: this project is single-GPU by construction (get_cuda_runtime rejects
+ a second device). Multi-GPU would require cudaHostAlloc with
+ cudaHostAllocPortable for cross-device visibility (Handbook Β§5.1).
+
+ Args:
+ size: Number of bytes to allocate
+
+ Returns:
+ Host pointer to pinned memory
+
+ Raises:
+ RuntimeError: If allocation fails
+ """
+ ptr = c_void_p()
+ result = self.cudart.cudaMallocHost(byref(ptr), size)
+ self.check_error(result, "cudaMallocHost")
+ return ptr
+
+ def free_host(self, ptr: c_void_p) -> None:
+ """Free pinned host memory allocated with malloc_host().
+
+ Args:
+ ptr: Host pointer to free
+
+ Raises:
+ RuntimeError: If free fails
+ """
+ result = self.cudart.cudaFreeHost(ptr)
+ self.check_error(result, "cudaFreeHost")
+
+ def memcpy(self, dst: c_void_p, src: c_void_p, count: int, kind: int) -> None:
+ """Copy memory (device-to-device, host-to-device, or device-to-host).
+
+ Args:
+ dst: Destination pointer
+ src: Source pointer
+ count: Number of bytes to copy
+ kind: cudaMemcpyKind (0=H2H, 1=H2D, 2=D2H, 3=D2D)
+
+ Raises:
+ RuntimeError: If copy fails
+ """
+ result = self.cudart.cudaMemcpy(dst, src, count, kind)
+ self.check_error(result, "cudaMemcpy")
+
+ def ipc_get_mem_handle(self, dev_ptr: c_void_p) -> cudaIpcMemHandle_t:
+ """Get IPC handle for GPU memory.
+
+ This handle can be transferred to another process via SharedMemory
+ or other IPC mechanism.
+
+ Args:
+ dev_ptr: Device pointer to export
+
+ Returns:
+ IPC handle (128 bytes)
+
+ Raises:
+ RuntimeError: If export fails
+ """
+ handle = cudaIpcMemHandle_t()
+ result = self.cudart.cudaIpcGetMemHandle(byref(handle), dev_ptr)
+ self.check_error(result, "cudaIpcGetMemHandle")
+ return handle
+
+ def ipc_open_mem_handle(self, handle: cudaIpcMemHandle_t, flags: int = 1) -> c_void_p:
+ """Open IPC handle to access GPU memory from another process.
+
+ Args:
+ handle: IPC handle received from another process
+ flags: IPC flags (1 = cudaIpcMemLazyEnablePeerAccess)
+
+ Returns:
+ Device pointer to shared memory
+
+ Raises:
+ RuntimeError: If opening fails
+ """
+ dev_ptr = c_void_p()
+ result = self.cudart.cudaIpcOpenMemHandle(byref(dev_ptr), handle, flags)
+ self.check_error(result, "cudaIpcOpenMemHandle")
+ return dev_ptr
+
+ def ipc_close_mem_handle(self, dev_ptr: c_void_p) -> None:
+ """Close IPC memory handle.
+
+ Args:
+ dev_ptr: Device pointer obtained from ipc_open_mem_handle()
+
+ Raises:
+ RuntimeError: If closing fails
+ """
+ result = self.cudart.cudaIpcCloseMemHandle(dev_ptr)
+ self.check_error(result, "cudaIpcCloseMemHandle")
+
+ def synchronize(self) -> None:
+ """Synchronize all CUDA operations on current device.
+
+ Raises:
+ RuntimeError: If synchronization fails
+ """
+ result = self.cudart.cudaDeviceSynchronize()
+ self.check_error(result, "cudaDeviceSynchronize")
+
+ # CUDA Event API (for async synchronization)
+
+ def create_ipc_event(self) -> CUDAEvent_t:
+ """Create CUDA event suitable for IPC (interprocess communication).
+
+ Returns:
+ Event handle for cross-process synchronization
+
+ Raises:
+ RuntimeError: If event creation fails
+ """
+ event = CUDAEvent_t()
+ # cudaEventInterprocess (4) | cudaEventDisableTiming (2) = 6
+ # NVIDIA requires cudaEventDisableTiming when using cudaEventInterprocess
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 6)
+ self.check_error(result, "cudaEventCreateWithFlags")
+ return event
+
+ def record_event(self, event: CUDAEvent_t, stream: CUDAStream_t | None = None) -> None:
+ """Record event on specified stream (or default stream).
+
+ Args:
+ event: Event handle to record
+ stream: CUDA stream (None = default stream)
+
+ Raises:
+ RuntimeError: If event recording fails
+ """
+ # Convert None to CUDA default stream (0) for ctypes compatibility
+ if stream is None:
+ stream = CUDAStream_t(0)
+ result = self.cudart.cudaEventRecord(event, stream)
+ self.check_error(result, "cudaEventRecord")
+
+ def query_event(self, event: c_void_p) -> bool:
+ """Query if event has completed (non-blocking).
+
+ Args:
+ event: Event handle to query
+
+ Returns:
+ True if event completed, False if still pending
+
+ Raises:
+ RuntimeError: If query fails with unexpected error
+ """
+ result = self.cudart.cudaEventQuery(event)
+ if result == CUDAError.SUCCESS:
+ return True
+ elif result == CUDAError.NOT_READY:
+ return False
+ self.check_error(result, "cudaEventQuery")
+ return False
+
+ def wait_event(self, event: CUDAEvent_t) -> None:
+ """Wait for event to complete (blocking).
+
+ Args:
+ event: Event handle to wait on
+
+ Raises:
+ RuntimeError: If wait fails
+ """
+ result = self.cudart.cudaEventSynchronize(event)
+ self.check_error(result, "cudaEventSynchronize")
+
+ def ipc_get_event_handle(self, event: CUDAEvent_t) -> cudaIpcEventHandle_t:
+ """Get IPC handle for event (for cross-process signaling).
+
+ Args:
+ event: Event created with create_ipc_event()
+
+ Returns:
+ IPC event handle (64 bytes)
+
+ Raises:
+ RuntimeError: If export fails
+ """
+ handle = cudaIpcEventHandle_t()
+ result = self.cudart.cudaIpcGetEventHandle(byref(handle), event)
+ self.check_error(result, "cudaIpcGetEventHandle")
+ return handle
+
+ def ipc_open_event_handle(self, handle: cudaIpcEventHandle_t) -> CUDAEvent_t:
+ """Open IPC event handle from another process.
+
+ Args:
+ handle: IPC event handle received from another process
+
+ Returns:
+ Event handle for this process
+
+ Raises:
+ RuntimeError: If opening fails
+ """
+ event = CUDAEvent_t()
+ result = self.cudart.cudaIpcOpenEventHandle(byref(event), handle)
+ self.check_error(result, "cudaIpcOpenEventHandle")
+ return event
+
+ def destroy_event(self, event: CUDAEvent_t) -> None:
+ """Destroy CUDA event.
+
+ Args:
+ event: Event handle to destroy
+
+ Raises:
+ RuntimeError: If destruction fails
+ """
+ result = self.cudart.cudaEventDestroy(event)
+ self.check_error(result, "cudaEventDestroy")
+
+ def create_timing_event(self) -> CUDAEvent_t:
+ """Create CUDA event suitable for GPU timing (NOT for IPC).
+
+ Returns:
+ Event handle for GPU-accurate timing measurements
+
+ Raises:
+ RuntimeError: If event creation fails
+
+ Note:
+ This creates an event with timing enabled (flags=0).
+ Use this for benchmarking, NOT for IPC synchronization.
+ IPC events require cudaEventDisableTiming flag.
+ """
+ event = CUDAEvent_t()
+ # flags=0 enables timing (no cudaEventDisableTiming, no cudaEventInterprocess)
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 0)
+ self.check_error(result, "cudaEventCreateWithFlags(timing)")
+ return event
+
+ def create_sync_event(self) -> CUDAEvent_t:
+ """Create CUDA event optimized for stream ordering (NOT timing, NOT IPC).
+
+ Returns:
+ Event handle for use with stream_wait_event() ordering
+
+ Raises:
+ RuntimeError: If event creation fails
+
+ Note:
+ Uses cudaEventDisableTiming (0x02). Per NVIDIA docs this provides
+ best performance when used with cudaStreamWaitEvent() and
+ cudaEventQuery() β removes per-record timing instrumentation overhead.
+ Do not use with event_elapsed_time(); use create_timing_event() for that.
+ """
+ event = CUDAEvent_t()
+ # cudaEventDisableTiming = 0x02 β optimal for ordering-only events
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 0x02)
+ self.check_error(result, "cudaEventCreateWithFlags(sync)")
+ return event
+
+ def event_elapsed_time(self, start: CUDAEvent_t, end: CUDAEvent_t) -> float:
+ """Get elapsed GPU time between two events.
+
+ Args:
+ start: Starting event (must be recorded before end event)
+ end: Ending event
+
+ Returns:
+ Elapsed time in milliseconds (GPU-measured)
+
+ Raises:
+ RuntimeError: If elapsed time query fails
+
+ Note:
+ Both events must have timing enabled (created with create_timing_event).
+ Events with cudaEventDisableTiming flag cannot be used for timing.
+ """
+ elapsed_ms = c_float()
+ result = self.cudart.cudaEventElapsedTime(byref(elapsed_ms), start, end)
+ self.check_error(result, "cudaEventElapsedTime")
+ return elapsed_ms.value
+
+ def get_device(self) -> int:
+ """Return the CUDA device index currently bound to this context.
+
+ Returns:
+ Integer device index (matches self.device if context is healthy)
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ device = c_int()
+ result = self.cudart.cudaGetDevice(byref(device))
+ self.check_error(result, "cudaGetDevice")
+ return device.value
+
+ def create_stream(self, flags: int = 0x01) -> CUDAStream_t:
+ """Create CUDA stream with specified flags.
+
+ Args:
+ flags: Stream creation flags. Default 0x01 = cudaStreamNonBlocking
+
+ Returns:
+ CUDAStream_t: Opaque stream handle
+
+ Raises:
+ RuntimeError: If stream creation fails
+ """
+ stream = CUDAStream_t()
+ result = self.cudart.cudaStreamCreateWithFlags(byref(stream), flags)
+ self.check_error(result, "cudaStreamCreateWithFlags")
+ return stream
+
+ def create_stream_with_priority(self, flags: int = 0x01, priority: int | None = None) -> CUDAStream_t:
+ """Create CUDA stream at the specified (or highest available) priority.
+
+ On CUDA, stream priority is an integer where a smaller value means
+ higher priority. cudaDeviceGetStreamPriorityRange returns [least, greatest]
+ where greatest is the most-negative value β i.e., the highest priority.
+
+ Args:
+ flags: Stream flags. Default 0x01 = cudaStreamNonBlocking.
+ priority: Stream priority. None means use highest available (greatest).
+
+ Returns:
+ CUDAStream_t: Opaque stream handle
+
+ Raises:
+ RuntimeError: If stream creation fails
+ """
+ if priority is None:
+ least = c_int()
+ greatest = c_int()
+ result = self.cudart.cudaDeviceGetStreamPriorityRange(byref(least), byref(greatest))
+ self.check_error(result, "cudaDeviceGetStreamPriorityRange")
+ priority = greatest.value
+ stream = CUDAStream_t()
+ result = self.cudart.cudaStreamCreateWithPriority(byref(stream), flags, priority)
+ self.check_error(result, "cudaStreamCreateWithPriority")
+ return stream
+
+ def destroy_stream(self, stream: CUDAStream_t) -> None:
+ """Destroy CUDA stream.
+
+ Args:
+ stream: Stream handle to destroy
+
+ Raises:
+ RuntimeError: If destruction fails
+ """
+ result = self.cudart.cudaStreamDestroy(stream)
+ self.check_error(result, "cudaStreamDestroy")
+
+ def stream_wait_event(self, stream: CUDAStream_t, event: CUDAEvent_t, flags: int = 0) -> None:
+ """Make stream wait on event (GPU-side, non-blocking to CPU).
+
+ Args:
+ stream: Stream to wait
+ event: Event to wait for
+ flags: Wait flags (default 0)
+
+ Raises:
+ RuntimeError: If wait enqueue fails
+ """
+ result = self.cudart.cudaStreamWaitEvent(stream, event, flags)
+ self.check_error(result, "cudaStreamWaitEvent")
+
+ def stream_synchronize(self, stream: CUDAStream_t) -> None:
+ """Wait for all operations on stream to complete (CPU-blocking).
+
+ Args:
+ stream: Stream to synchronize
+
+ Raises:
+ RuntimeError: If synchronization fails
+ """
+ result = self.cudart.cudaStreamSynchronize(stream)
+ self.check_error(result, "cudaStreamSynchronize")
+
+ def memcpy_async(self, dst: c_void_p, src: c_void_p, count: int, kind: int, stream: CUDAStream_t) -> None:
+ """Asynchronous memory copy on a stream.
+
+ Args:
+ dst: Destination pointer
+ src: Source pointer
+ count: Number of bytes to copy
+ kind: cudaMemcpyKind (0=H2H, 1=H2D, 2=D2H, 3=D2D)
+ stream: CUDA stream for async operation
+
+ Raises:
+ RuntimeError: If async copy enqueue fails
+ """
+ result = self.cudart.cudaMemcpyAsync(dst, src, count, kind, stream)
+ self.check_error(result, "cudaMemcpyAsync")
+
+ def mem_get_info(self) -> tuple[int, int]:
+ """Get free and total device memory in bytes.
+
+ Returns:
+ Tuple of (free_bytes, total_bytes)
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ free = c_size_t()
+ total = c_size_t()
+ result = self.cudart.cudaMemGetInfo(byref(free), byref(total))
+ self.check_error(result, "cudaMemGetInfo")
+ return free.value, total.value
+
+ def stream_query(self, stream: CUDAStream_t) -> bool:
+ """Non-blocking check if all operations on stream have completed.
+
+ Args:
+ stream: CUDA stream to query
+
+ Returns:
+ True if all stream operations have completed, False if still executing
+
+ Raises:
+ RuntimeError: If query fails with an error other than cudaErrorNotReady
+ """
+ result = self.cudart.cudaStreamQuery(stream)
+ if result == CUDAError.SUCCESS:
+ return True
+ if result == CUDAError.NOT_READY:
+ return False
+ self.check_error(result, "cudaStreamQuery")
+ return False # unreachable
+
+ def pointer_get_attributes(self, ptr: int) -> cudaPointerAttributes:
+ """Query memory type and owning device for a GPU pointer.
+
+ Args:
+ ptr: GPU pointer as integer (e.g., tensor.data_ptr())
+
+ Returns:
+ cudaPointerAttributes with .type (2=device, 3=managed) and .device (GPU index)
+
+ Raises:
+ RuntimeError: If query fails (e.g., unregistered host pointer passed)
+ """
+ attrs = cudaPointerAttributes()
+ result = self.cudart.cudaPointerGetAttributes(byref(attrs), c_void_p(ptr))
+ self.check_error(result, "cudaPointerGetAttributes")
+ return attrs
+
+ def device_can_access_peer(self, device: int, peer_device: int) -> bool:
+ """Check if device can directly access peer_device memory via IPC/NVLink.
+
+ Useful for validating multi-GPU setups before attempting IPC handle operations.
+ On single-GPU systems or systems without peer access, cudaIpcOpenMemHandle
+ may fall back to slower paths without warning.
+
+ Args:
+ device: Source device ID
+ peer_device: Target peer device ID
+
+ Returns:
+ True if direct peer access is available, False otherwise
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ can_access = c_int(0)
+ result = self.cudart.cudaDeviceCanAccessPeer(byref(can_access), device, peer_device)
+ self.check_error(result, "cudaDeviceCanAccessPeer")
+ return bool(can_access.value)
+
+ # --- Phase 1: cudaHostAlloc (replaces cudaMallocHost with portable flag) ---
+
+ def malloc_host_alloc(self, size: int, flags: int = 0x01) -> c_void_p:
+ """Allocate pinned host memory via cudaHostAlloc with explicit flags.
+
+ Unlike malloc_host() which calls cudaMallocHost (no flags), this lets
+ callers pass cudaHostAllocPortable (0x01) to make the allocation visible
+ from any CUDA context in the process β useful when PyTorch and CuPy share
+ the same process.
+
+ Args:
+ size: Number of bytes to allocate.
+ flags: OR-combination of:
+ cudaHostAllocPortable = 0x01 (cross-context visibility)
+ cudaHostAllocMapped = 0x02 (map into device address space)
+ cudaHostAllocWriteCombined = 0x04 (WC; fast write, slow CPU read)
+
+ Returns:
+ Host pointer to allocated pinned memory.
+
+ Raises:
+ RuntimeError: If allocation fails.
+ """
+ ptr = c_void_p()
+ result = self.cudart.cudaHostAlloc(byref(ptr), c_size_t(size), c_uint(flags))
+ self.check_error(result, "cudaHostAlloc")
+ return ptr
+
+ # --- Phase 0: device attribute query ---
+
+ def get_device_attribute(self, attr: int, device: int | None = None) -> int:
+ """Query a cudaDeviceAttr value for a given device.
+
+ Common attrs:
+ cudaDevAttrAsyncEngineCount = 4 β number of DMA copy engines
+
+ Args:
+ attr: cudaDeviceAttr integer constant.
+ device: GPU device index. Defaults to self.device.
+
+ Returns:
+ Integer attribute value.
+
+ Raises:
+ RuntimeError: If query fails.
+ """
+ if device is None:
+ device = self.device
+ value = c_int()
+ result = self.cudart.cudaDeviceGetAttribute(byref(value), c_int(attr), c_int(device))
+ self.check_error(result, "cudaDeviceGetAttribute")
+ return value.value
+
+
+# Global singleton instance (lazy initialization)
+_cuda_runtime: CUDARuntimeAPI | None = None
+
+
+def get_cuda_runtime(device: int = 0) -> CUDARuntimeAPI:
+ """Get global CUDA runtime instance (singleton).
+
+ The singleton is created on first call. Subsequent calls with a *different*
+ device index will raise RuntimeError β a single process context can only
+ be bound to one device via this shared-cudart pattern.
+
+ Args:
+ device: CUDA device index (default 0). Must match across all callers
+ within the same process.
+
+ Returns:
+ CUDARuntimeAPI: Global CUDA runtime wrapper
+
+ Raises:
+ RuntimeError: If called with a device index that conflicts with the
+ already-initialized singleton.
+ """
+ global _cuda_runtime
+ if _cuda_runtime is None:
+ _cuda_runtime = CUDARuntimeAPI(device=device)
+ elif _cuda_runtime.device != device:
+ raise RuntimeError(
+ f"CUDA runtime singleton was initialized for device {_cuda_runtime.device}, "
+ f"but caller requested device {device}. A single process can only bind to "
+ "one device via the shared-cudart singleton. Create a separate "
+ "CUDARuntimeAPI(device=...) instance for multi-device use."
+ )
+ return _cuda_runtime
diff --git a/src/streamdiffusion/_compat/cuda_ipc/cuda_runtime_types.py b/src/streamdiffusion/_compat/cuda_ipc/cuda_runtime_types.py
new file mode 100644
index 000000000..a432c2630
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/cuda_runtime_types.py
@@ -0,0 +1,134 @@
+"""
+CUDA Runtime Types β ctypes structs, type aliases, and error codes for CUDA IPC.
+
+Shared between the pip package (cuda_link) and TouchDesigner textDATs.
+Compatible with both Python package and TD COMP namespace imports.
+"""
+
+from __future__ import annotations
+
+import ctypes
+from ctypes import c_int, c_size_t, c_uint64, c_void_p
+
+
+# CUDA handle types - use unsigned 64-bit to prevent overflow on Windows x64
+# See: https://github.com/pytorch/pytorch/pull/162920
+CUDAEvent_t = c_uint64 # cudaEvent_t opaque pointer
+CUDAStream_t = c_uint64 # cudaStream_t opaque pointer
+CUDAGraph_t = c_uint64 # cudaGraph_t opaque pointer (CUDA 10.0+)
+CUDAGraphExec_t = c_uint64 # cudaGraphExec_t opaque pointer (CUDA 10.0+)
+CUDAGraphNode_t = c_uint64 # cudaGraphNode_t opaque pointer (CUDA 10.0+)
+
+# Minimum cudart version required for all CUDA Graphs APIs used by this module.
+# cudaGraphInstantiateWithFlags, cudaGraphExecEventRecordNodeSetEvent, and
+# cudaGraphExecEventWaitNodeSetEvent are all CUDA 11.4+ (version integer 11040).
+CUDART_GRAPHS_MIN_VERSION = 11040
+
+# --- CUDA Graph parameter structs ---
+
+
+class cudaPos(ctypes.Structure):
+ """cudaPos: {x, y, z} offsets into an array or pitched memory."""
+
+ _fields_ = [("x", c_size_t), ("y", c_size_t), ("z", c_size_t)]
+
+
+class cudaPitchedPtr(ctypes.Structure):
+ """cudaPitchedPtr: pointer + pitch metadata for 2D/3D copies."""
+
+ _fields_ = [
+ ("ptr", c_void_p),
+ ("pitch", c_size_t),
+ ("xsize", c_size_t),
+ ("ysize", c_size_t),
+ ]
+
+
+class cudaExtent(ctypes.Structure):
+ """cudaExtent: width/height/depth dimensions in bytes for 3D copies."""
+
+ _fields_ = [("width", c_size_t), ("height", c_size_t), ("depth", c_size_t)]
+
+
+class cudaMemcpy3DParms(ctypes.Structure):
+ """cudaMemcpy3DParms: full parameter struct for cudaMemcpy3D and graph node updates."""
+
+ _fields_ = [
+ ("srcArray", c_void_p), # cudaArray_t β NULL for linear memory
+ ("srcPos", cudaPos),
+ ("srcPtr", cudaPitchedPtr),
+ ("dstArray", c_void_p), # cudaArray_t β NULL for linear memory
+ ("dstPos", cudaPos),
+ ("dstPtr", cudaPitchedPtr),
+ ("extent", cudaExtent),
+ ("kind", c_int), # cudaMemcpyKind
+ ]
+
+
+# CUDA IPC Handle structure (64 bytes, CUDA_IPC_HANDLE_SIZE per NVIDIA spec)
+class cudaIpcMemHandle_t(ctypes.Structure):
+ """CUDA IPC memory handle structure.
+
+ This opaque handle can be transferred between processes via
+ SharedMemory or other IPC mechanisms to enable GPU memory sharing.
+ """
+
+ _fields_ = [("internal", ctypes.c_byte * 64)]
+
+
+# CUDA IPC Event Handle structure (64 bytes per NVIDIA spec)
+class cudaIpcEventHandle_t(ctypes.Structure):
+ """CUDA IPC event handle structure.
+
+ Used for lightweight cross-process synchronization.
+ """
+
+ _fields_ = [("reserved", ctypes.c_byte * 64)]
+
+
+# CUDA pointer attributes β memory type and owning device for a GPU pointer
+class cudaPointerAttributes(ctypes.Structure):
+ """Result of cudaPointerGetAttributes.
+
+ Useful for validating that a caller-supplied GPU pointer belongs to the
+ expected device before issuing D2D operations (C2 affinity check).
+
+ .type values: 0=unregistered, 1=host, 2=device, 3=managed
+ .device: GPU index that owns the allocation
+ """
+
+ _fields_ = [
+ ("type", c_int), # cudaMemoryType enum (2 = cudaMemoryTypeDevice)
+ ("device", c_int), # GPU device index owning this allocation
+ ("devicePointer", c_void_p),
+ ("hostPointer", c_void_p),
+ ]
+
+
+# CUDA Error codes (subset)
+class CUDAError:
+ """CUDA runtime error codes."""
+
+ SUCCESS = 0
+ INVALID_VALUE = 1
+ MEMORY_ALLOCATION = 2
+ INVALID_DEVICE_POINTER = 17
+ INVALID_DEVICE = 101
+ INVALID_CONTEXT = 201 # Common in same-process IPC testing
+ NOT_READY = 600
+ PEER_ACCESS_ALREADY_ENABLED = 704
+
+ @staticmethod
+ def get_name(code: int) -> str:
+ """Get human-readable error name."""
+ names = {
+ 0: "SUCCESS",
+ 1: "INVALID_VALUE",
+ 2: "MEMORY_ALLOCATION",
+ 17: "INVALID_DEVICE_POINTER",
+ 101: "INVALID_DEVICE",
+ 201: "INVALID_CONTEXT",
+ 600: "NOT_READY",
+ 704: "PEER_ACCESS_ALREADY_ENABLED",
+ }
+ return names.get(code, f"UNKNOWN_ERROR_{code}")
diff --git a/src/streamdiffusion/_compat/cuda_ipc/debug_utils.py b/src/streamdiffusion/_compat/cuda_ipc/debug_utils.py
new file mode 100644
index 000000000..03405e33e
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/debug_utils.py
@@ -0,0 +1,204 @@
+"""
+Debug and profiling utilities for CUDA operations.
+
+Extracted subset from StreamDiffusion project.
+Requires PyTorch to be installed.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Callable
+from typing import Any
+
+
+logger = logging.getLogger(__name__)
+
+try:
+ import torch
+
+ TORCH_AVAILABLE = True
+except ImportError:
+ torch = None
+ TORCH_AVAILABLE = False
+
+
+def benchmark_with_events(
+ fn: Callable[..., Any], *args: Any, warmup: int = 3, iterations: int = 10, **kwargs: Any
+) -> float:
+ """
+ GPU-accurate timing using CUDA events.
+
+ Args:
+ fn: Function to benchmark
+ *args: Positional arguments for fn
+ warmup: Number of warmup iterations
+ iterations: Number of timed iterations
+ **kwargs: Keyword arguments for fn
+
+ Returns:
+ Average time per iteration in milliseconds
+
+ Raises:
+ ImportError: If PyTorch is not installed
+ """
+ if not TORCH_AVAILABLE:
+ raise ImportError(
+ "benchmark_with_events requires PyTorch. "
+ "Install with: pip install cuda-link[torch] or pip install torch>=2.0"
+ )
+
+ # Warmup
+ for _ in range(warmup):
+ fn(*args, **kwargs)
+
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+
+ torch.cuda.synchronize()
+ start.record()
+ for _ in range(iterations):
+ fn(*args, **kwargs)
+ end.record()
+ torch.cuda.synchronize()
+
+ return start.elapsed_time(end) / iterations
+
+
+class ProfileSection:
+ """Context manager for profiling code sections with CUDA events.
+
+ Usage:
+ with ProfileSection("UNet Forward"):
+ output = unet(input)
+ # Prints: [PROFILE] UNet Forward: 45.3ms
+ """
+
+ def __init__(self, name: str, enabled: bool = True) -> None:
+ """Initialize profiling section with a name and optional enable flag.
+
+ Args:
+ name: Label shown in the profiling log output.
+ enabled: If False, profiling is a no-op. Defaults to True.
+ """
+ if not TORCH_AVAILABLE:
+ raise ImportError(
+ "ProfileSection requires PyTorch. Install with: pip install cuda-link[torch] or pip install torch>=2.0"
+ )
+ self.name = name
+ self.enabled = enabled and torch.cuda.is_available()
+ self.start_event = None
+ self.end_event = None
+
+ def __enter__(self) -> ProfileSection:
+ """Enter profiling context and record CUDA start event."""
+ if self.enabled:
+ self.start_event = torch.cuda.Event(enable_timing=True)
+ self.end_event = torch.cuda.Event(enable_timing=True)
+ torch.cuda.synchronize()
+ self.start_event.record()
+ return self
+
+ def __exit__(self, *args: object) -> None:
+ """Exit profiling context, synchronize GPU, and log elapsed time."""
+ if self.enabled:
+ self.end_event.record()
+ torch.cuda.synchronize()
+ elapsed = self.start_event.elapsed_time(self.end_event)
+ logger.debug("[PROFILE] %s: %.1fms", self.name, elapsed)
+
+
+# ---------------------------------------------------------------------------
+# snoop helpers
+# ---------------------------------------------------------------------------
+
+
+def create_snoop_config(
+ out: str | None = None,
+ *,
+ enabled: bool = True,
+) -> Any:
+ """Create a snoop.Config with timestamp column output.
+
+ Note: call-depth tracing is configured per-decorator via ``cfg.snoop(depth=N)``,
+ not at Config creation time.
+
+ Args:
+ out: Output destination β file path string, or None for stderr.
+ enabled: Set to False to get a no-op config object.
+
+ Returns:
+ ``snoop.Config`` instance, or ``None`` if snoop is not installed.
+
+ Example::
+
+ cfg = create_snoop_config(out="debug.log")
+ if cfg:
+ @cfg.snoop(depth=2, watch=("self.write_idx",))
+ def _initialize(self): ...
+ """
+ try:
+ import snoop as _snoop
+
+ kwargs: dict[str, object] = {
+ "columns": "time",
+ "enabled": enabled,
+ }
+ if out is not None:
+ kwargs["out"] = out
+ return _snoop.Config(**kwargs)
+ except ImportError:
+ logger.debug("snoop not installed; create_snoop_config() is a no-op")
+ return None
+
+
+def snoop_decorator(
+ fn: Callable[..., Any] | None = None,
+ *,
+ depth: int = 1,
+ watch: tuple[str, ...] = (),
+ enabled: bool = True,
+) -> Callable[..., Any]:
+ """Return a @snoop decorator, or a transparent no-op if snoop is unavailable.
+
+ Designed so ``@snoop_decorator`` can be left on functions in development
+ branches without breaking production (no snoop installed = zero overhead).
+
+ Args:
+ fn: When used as ``@snoop_decorator`` (no args), receives the function
+ directly. When used as ``@snoop_decorator(depth=2)``, is None.
+ depth: Levels of called functions to trace.
+ watch: Extra expressions to evaluate and display (e.g.
+ ``("self.write_idx", "slot")``).
+ enabled: Pass ``False`` to get a no-op decorator regardless of
+ whether snoop is installed.
+
+ Returns:
+ Decorator or decorated function.
+
+ Example::
+
+ @snoop_decorator(depth=2, watch=("self.write_idx", "slot"))
+ def export_frame(self, gpu_ptr, size):
+ ...
+ """
+
+ def _noop(f: Callable[..., Any]) -> Callable[..., Any]:
+ return f
+
+ try:
+ import snoop as _snoop
+ except ImportError:
+ decorator: Callable[..., Any] = _noop
+ else:
+ if not enabled:
+ decorator = _noop
+ elif watch:
+ decorator = _snoop(depth=depth, watch=watch)
+ else:
+ decorator = _snoop(depth=depth)
+
+ if fn is not None:
+ # Called as @snoop_decorator with no arguments
+ return decorator(fn)
+ return decorator
diff --git a/src/streamdiffusion/_compat/cuda_ipc/nvml_observer.py b/src/streamdiffusion/_compat/cuda_ipc/nvml_observer.py
new file mode 100644
index 000000000..f0a80ae52
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/nvml_observer.py
@@ -0,0 +1,222 @@
+"""
+NVML Observability Hook for CUDA Link.
+
+Optional GPU telemetry via NVIDIA Management Library (pynvml).
+Follows the existing optional-dep pattern: import fails gracefully,
+NVML_AVAILABLE flag gates all usage.
+
+Usage:
+ from cuda_link import NVMLObserver, NVML_AVAILABLE
+
+ if NVML_AVAILABLE:
+ obs = NVMLObserver(device=0)
+ obs.start()
+ exporter.attach_nvml_observer(obs)
+ stats = exporter.get_stats() # includes stats["nvml"]
+ obs.stop()
+
+Install optional dep:
+ pip install "cuda-link[nvml]" # adds nvidia-ml-py>=12.535
+"""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import os
+import threading
+
+
+logger = logging.getLogger(__name__)
+
+try:
+ import pynvml
+
+ NVML_AVAILABLE = True
+except ImportError:
+ pynvml = None # type: ignore[assignment]
+ NVML_AVAILABLE = False
+
+
+class _NvmlRefCounter:
+ """Process-global ref-count for nvmlInit/nvmlShutdown.
+
+ Tolerates multiple NVMLObserver instances without double-init/shutdown errors.
+ """
+
+ def __init__(self) -> None:
+ self._count: int = 0
+ self._lock: threading.Lock = threading.Lock()
+
+ def acquire(self) -> None:
+ if not NVML_AVAILABLE:
+ return
+ with self._lock:
+ if self._count == 0:
+ pynvml.nvmlInit()
+ self._count += 1
+
+ def release(self) -> None:
+ if not NVML_AVAILABLE:
+ return
+ with self._lock:
+ self._count = max(0, self._count - 1)
+ if self._count == 0:
+ pynvml.nvmlShutdown()
+
+
+_NVML_REFS = _NvmlRefCounter()
+
+
+_THROTTLE_NAMES: dict[int, str] = {
+ 0x0000000000000001: "gpu_idle",
+ 0x0000000000000002: "applications_clocks_setting",
+ 0x0000000000000004: "sw_power_cap",
+ 0x0000000000000008: "hw_slowdown",
+ 0x0000000000000010: "sync_boost",
+ 0x0000000000000020: "sw_thermal_slowdown",
+ 0x0000000000000040: "hw_thermal_slowdown",
+ 0x0000000000000080: "hw_power_brake_slowdown",
+ 0x0000000000000100: "display_clocks_setting",
+}
+
+
+def _decode_throttle(bitmask: int) -> list[str]:
+ return [name for bit, name in _THROTTLE_NAMES.items() if bitmask & bit]
+
+
+class NVMLObserver:
+ """Pull-based GPU telemetry via pynvml.
+
+ Call snapshot() (or let get_stats() call it) to sample once.
+ No background thread β caller controls cadence.
+
+ Metrics returned by snapshot():
+ gpu_util_pct, mem_bw_util_pct (from nvmlDeviceGetUtilizationRates)
+ mem_used_mb, mem_total_mb (from nvmlDeviceGetMemoryInfo)
+ sm_clock_mhz, mem_clock_mhz (from nvmlDeviceGetClockInfo)
+ pcie_tx_kbps, pcie_rx_kbps (from nvmlDeviceGetPcieThroughput)
+ temp_c (from nvmlDeviceGetTemperature)
+ power_w, power_limit_w (from nvmlDeviceGetPowerUsage)
+ throttle_reasons (decoded bitmask list)
+ driver_model "WDDM" / "TCC" / "MCDM" (Windows only; absent on Linux)
+ """
+
+ def __init__(self, device: int = 0, enabled: bool | None = None) -> None:
+ """Initialize NVML observer.
+
+ Args:
+ device: CUDA device index (default 0).
+ enabled: If None, reads CUDALINK_NVML env var ("1" = enabled).
+ If False, snapshot() returns {"nvml_available": False} immediately.
+ """
+ self.device = device
+ if enabled is None:
+ self.enabled = os.getenv("CUDALINK_NVML", "0") == "1"
+ else:
+ self.enabled = enabled
+ self._handle = None
+ self._started = False
+ self._driver_model: str | None = None
+
+ def start(self) -> bool:
+ """Initialize NVML and open device handle.
+
+ Returns:
+ True if NVML is available and handle opened, False otherwise.
+ """
+ if not NVML_AVAILABLE or not self.enabled:
+ return False
+ if self._started:
+ return True
+ try:
+ _NVML_REFS.acquire()
+ self._handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
+ with contextlib.suppress(pynvml.NVMLError):
+ # Raises NVMLError_NotSupported on Linux (driver-model is Windows-only).
+ _model = pynvml.nvmlDeviceGetCurrentDriverModel(self._handle)
+ _names = {
+ pynvml.NVML_DRIVER_WDDM: "WDDM",
+ pynvml.NVML_DRIVER_WDM: "TCC",
+ }
+ if hasattr(pynvml, "NVML_DRIVER_MCDM"):
+ _names[pynvml.NVML_DRIVER_MCDM] = "MCDM"
+ self._driver_model = _names.get(_model, f"unknown({_model})")
+ self._started = True
+ return True
+ except (pynvml.NVMLError, RuntimeError, OSError) as e:
+ logger.warning("NVML start failed for device %d: %s", self.device, e)
+ return False
+
+ def stop(self) -> None:
+ """Release NVML handle and decrement global ref-count."""
+ if self._started:
+ _NVML_REFS.release()
+ self._handle = None
+ self._started = False
+
+ def __enter__(self) -> NVMLObserver:
+ self.start()
+ return self
+
+ def __exit__(self, *_: object) -> None:
+ self.stop()
+
+ def snapshot(self) -> dict:
+ """Sample all GPU metrics once (non-blocking, ~50-200Β΅s total).
+
+ Returns:
+ Dict of metric name β value. If NVML is unavailable or not started,
+ returns {"nvml_available": False}.
+ """
+ if not self._started or self._handle is None:
+ return {"nvml_available": False}
+
+ out: dict = {"nvml_available": True}
+ h = self._handle
+
+ try:
+ util = pynvml.nvmlDeviceGetUtilizationRates(h)
+ out["gpu_util_pct"] = util.gpu
+ out["mem_bw_util_pct"] = util.memory
+ except pynvml.NVMLError:
+ pass
+
+ try:
+ mem = pynvml.nvmlDeviceGetMemoryInfo(h)
+ out["mem_used_mb"] = mem.used / (1024 * 1024)
+ out["mem_total_mb"] = mem.total / (1024 * 1024)
+ except pynvml.NVMLError:
+ pass
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["sm_clock_mhz"] = pynvml.nvmlDeviceGetClockInfo(h, pynvml.NVML_CLOCK_SM)
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["mem_clock_mhz"] = pynvml.nvmlDeviceGetClockInfo(h, pynvml.NVML_CLOCK_MEM)
+
+ try:
+ out["pcie_tx_kbps"] = pynvml.nvmlDeviceGetPcieThroughput(h, pynvml.NVML_PCIE_UTIL_TX_BYTES)
+ out["pcie_rx_kbps"] = pynvml.nvmlDeviceGetPcieThroughput(h, pynvml.NVML_PCIE_UTIL_RX_BYTES)
+ except pynvml.NVMLError:
+ pass
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["temp_c"] = pynvml.nvmlDeviceGetTemperature(h, pynvml.NVML_TEMPERATURE_GPU)
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["power_w"] = pynvml.nvmlDeviceGetPowerUsage(h) / 1000.0
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["power_limit_w"] = pynvml.nvmlDeviceGetEnforcedPowerLimit(h) / 1000.0
+
+ try:
+ bitmask = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(h)
+ out["throttle_reasons"] = _decode_throttle(bitmask)
+ except pynvml.NVMLError:
+ pass
+
+ if self._driver_model is not None:
+ out["driver_model"] = self._driver_model
+
+ return out
diff --git a/src/streamdiffusion/_compat/cuda_ipc/py.typed b/src/streamdiffusion/_compat/cuda_ipc/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/streamdiffusion/_compat/cuda_ipc/shm_protocol.py b/src/streamdiffusion/_compat/cuda_ipc/shm_protocol.py
new file mode 100644
index 000000000..c1953f9dc
--- /dev/null
+++ b/src/streamdiffusion/_compat/cuda_ipc/shm_protocol.py
@@ -0,0 +1,310 @@
+"""
+SHM protocol v0.5.0 β canonical source of truth for the CUDA IPC shared-memory layout.
+
+All binary constants, struct codecs, dtype mappings, and publish/acquire ordering
+live here. Every module that reads or writes the SHM region must import from here;
+never define SHM_HEADER_SIZE or _ST_U32 locally.
+
+Binary layout (total = SHMLayout(num_slots).total_size):
+ [0-3] magic uint32 LE = PROTOCOL_MAGIC
+ [4-11] version uint64 LE (monotonic; incremented each sender init)
+ [12-15] num_slots uint32 LE
+ [16-19] write_idx uint32 LE (monotonic; 0 = no frames written yet)
+ [20 + slot*128 ...] IPC handles (64B mem + 64B event per slot, N slots)
+ [20 + N*128] shutdown_flag uint8
+ [21 + N*128 ...] metadata 20B (width/height/num_comps/kind/bits/flags/data_size)
+ [41 + N*128 ...] timestamp float64 LE (producer wall-clock time)
+"""
+
+from __future__ import annotations
+
+import struct
+import threading
+from dataclasses import dataclass
+from enum import Enum
+
+
+# ---------------------------------------------------------------------------
+# Protocol constants
+# ---------------------------------------------------------------------------
+
+PROTOCOL_MAGIC: int = 0x43495044 # "CIPD" β protocol validation magic number (v1.0.0)
+
+MAGIC_OFFSET: int = 0
+MAGIC_SIZE: int = 4
+VERSION_OFFSET: int = 4
+VERSION_SIZE: int = 8
+NUM_SLOTS_OFFSET: int = 12
+NUM_SLOTS_SIZE: int = 4
+WRITE_IDX_OFFSET: int = 16
+WRITE_IDX_SIZE: int = 4
+SHM_HEADER_SIZE: int = 20 # 4B magic + 8B version + 4B num_slots + 4B write_idx
+
+SLOT_SIZE: int = 128 # 64B cudaIpcMemHandle_t + 64B cudaIpcEventHandle_t
+
+SHUTDOWN_FLAG_SIZE: int = 1
+METADATA_SIZE: int = 20 # 4B width + 4B height + 4B num_comps + 1B kind + 1B bits + 2B flags + 4B data_size
+TIMESTAMP_SIZE: int = 8 # float64 LE producer wall-clock time
+
+# ---------------------------------------------------------------------------
+# DtypeCodec constants (cudaChannelFormatKind)
+# ---------------------------------------------------------------------------
+
+FORMAT_KIND_SIGNED: int = 0 # cudaChannelFormatKindSigned
+FORMAT_KIND_UNSIGNED: int = 1 # cudaChannelFormatKindUnsigned
+FORMAT_KIND_FLOAT: int = 2 # cudaChannelFormatKindFloat
+FLAGS_BFLOAT16: int = 0x0001 # bit0: bfloat16 (kind=Float, bits=16)
+
+# dtype string β (format_kind, bits_per_component, flags)
+_DTYPE_TO_KIND_BITS: dict[str, tuple[int, int, int]] = {
+ "float32": (FORMAT_KIND_FLOAT, 32, 0),
+ "float16": (FORMAT_KIND_FLOAT, 16, 0),
+ "uint8": (FORMAT_KIND_UNSIGNED, 8, 0),
+ "uint16": (FORMAT_KIND_UNSIGNED, 16, 0),
+}
+
+# ---------------------------------------------------------------------------
+# Pre-compiled struct codecs (hot-path, saves ~50-100ns per call)
+# ---------------------------------------------------------------------------
+
+_ST_U32 = struct.Struct(" None:
+ with _fence_lock:
+ pass
+
+
+# ---------------------------------------------------------------------------
+# SHMLayout β pre-computes all byte offsets for a given num_slots
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class SHMLayout:
+ """Pre-computed byte offsets for a SHM region with num_slots IPC slots."""
+
+ num_slots: int
+
+ def slot_offset(self, i: int) -> int:
+ return SHM_HEADER_SIZE + i * SLOT_SIZE
+
+ @property
+ def shutdown_offset(self) -> int:
+ return SHM_HEADER_SIZE + self.num_slots * SLOT_SIZE
+
+ @property
+ def metadata_offset(self) -> int:
+ return self.shutdown_offset + SHUTDOWN_FLAG_SIZE
+
+ @property
+ def timestamp_offset(self) -> int:
+ return self.metadata_offset + METADATA_SIZE
+
+ @property
+ def total_size(self) -> int:
+ return self.timestamp_offset + TIMESTAMP_SIZE
+
+
+# ---------------------------------------------------------------------------
+# Metadata β typed representation of the 20-byte metadata region
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class Metadata:
+ """Typed representation of the 20-byte metadata region."""
+
+ width: int
+ height: int
+ num_comps: int
+ format_kind: int # cudaChannelFormatKind
+ bits_per_comp: int
+ flags: int
+ data_size: int
+
+ def pack_into(self, buf: memoryview, layout: SHMLayout) -> None:
+ offset = layout.metadata_offset
+ _ST_U32.pack_into(buf, offset, self.width)
+ _ST_U32.pack_into(buf, offset + 4, self.height)
+ _ST_U32.pack_into(buf, offset + 8, self.num_comps)
+ _ST_BBH.pack_into(buf, offset + 12, self.format_kind, self.bits_per_comp, self.flags)
+ _ST_U32.pack_into(buf, offset + 16, self.data_size)
+
+ @classmethod
+ def read_from(cls, buf: memoryview, layout: SHMLayout) -> Metadata:
+ offset = layout.metadata_offset
+ width = _ST_U32.unpack_from(buf, offset)[0]
+ height = _ST_U32.unpack_from(buf, offset + 4)[0]
+ num_comps = _ST_U32.unpack_from(buf, offset + 8)[0]
+ kind, bits, flags = _ST_BBH.unpack_from(buf, offset + 12)
+ data_size = _ST_U32.unpack_from(buf, offset + 16)[0]
+ return cls(
+ width=width,
+ height=height,
+ num_comps=num_comps,
+ format_kind=kind,
+ bits_per_comp=bits,
+ flags=flags,
+ data_size=data_size,
+ )
+
+
+# ---------------------------------------------------------------------------
+# DtypeCodec β encode/decode dtype strings
+# ---------------------------------------------------------------------------
+
+
+class DtypeCodec:
+ """Encode/decode dtype strings to/from (format_kind, bits_per_comp, flags).
+
+ Folds _DTYPE_TO_KIND_BITS (exporter) and _decode_dtype_str (importer).
+ Adding a dtype is a single-file edit here.
+ """
+
+ @staticmethod
+ def encode(dtype: str) -> tuple[int, int, int]:
+ """dtype string β (format_kind, bits_per_comp, flags).
+
+ Raises:
+ KeyError: if dtype is not supported.
+ """
+ return _DTYPE_TO_KIND_BITS[dtype]
+
+ @staticmethod
+ def decode(kind: int, bits: int, flags: int) -> str:
+ """(format_kind, bits_per_comp, flags) β dtype string."""
+ if kind == FORMAT_KIND_FLOAT and bits == 16 and not (flags & FLAGS_BFLOAT16):
+ return "float16"
+ if kind == FORMAT_KIND_FLOAT:
+ return "float32"
+ if bits == 8:
+ return "uint8"
+ if bits == 16:
+ return "uint16"
+ return "float32" # safe fallback for future extensions
+
+
+# ---------------------------------------------------------------------------
+# Header helpers β read/write the 20-byte header region
+# ---------------------------------------------------------------------------
+
+
+def read_magic(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, MAGIC_OFFSET)[0]
+
+
+def read_version(buf: memoryview) -> int:
+ return _ST_U64.unpack_from(buf, VERSION_OFFSET)[0]
+
+
+def read_num_slots(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, NUM_SLOTS_OFFSET)[0]
+
+
+def read_write_idx(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, WRITE_IDX_OFFSET)[0]
+
+
+def bump_version(buf: memoryview) -> int:
+ """Increment the version counter in-place; return the new version."""
+ try:
+ current = read_version(buf)
+ except (struct.error, ValueError, IndexError):
+ current = 0
+ new_version = current + 1
+ _ST_U64.pack_into(buf, VERSION_OFFSET, new_version)
+ return new_version
+
+
+# ---------------------------------------------------------------------------
+# publish_frame β the only place that encodes the C3 ordering guarantee
+# ---------------------------------------------------------------------------
+
+
+def publish_frame(buf: memoryview, layout: SHMLayout, write_idx: int, timestamp: float) -> None:
+ """Write timestamp, clear shutdown_flag, fence, then publish write_idx LAST.
+
+ Ordering is critical: the consumer reads shutdown_flag BEFORE write_idx.
+ Clearing shutdown_flag before incrementing write_idx ensures the consumer
+ always sees shutdown_flag=0 when it first observes a new frame.
+
+ Callers must not replicate this sequence outside this function.
+ """
+ _ST_F64.pack_into(buf, layout.timestamp_offset, timestamp)
+ buf[layout.shutdown_offset] = 0
+ _release_fence() # C3 release barrier: shutdown_flag visible before write_idx
+ _ST_U32.pack_into(buf, WRITE_IDX_OFFSET, write_idx)
+
+
+# ---------------------------------------------------------------------------
+# acquire_slot β consumer-side frame acquisition
+# ---------------------------------------------------------------------------
+
+
+class SlotState(Enum):
+ NO_FRAME = "no_frame"
+ NEW_FRAME = "new_frame"
+ SHUTDOWN = "shutdown"
+ VERSION_CHANGED = "version_changed"
+
+
+@dataclass
+class AcquireResult:
+ """Result of acquire_slot()."""
+
+ state: SlotState
+ slot: int = -1
+ timestamp: float = 0.0
+ new_version: int = 0
+ write_idx: int = 0
+
+
+def acquire_slot(
+ buf: memoryview,
+ layout: SHMLayout,
+ last_write_idx: int,
+ last_version: int,
+) -> AcquireResult:
+ """Read SHM state and classify the result for the consumer.
+
+ Returns an AcquireResult with one of four states:
+ - NO_FRAME: write_idx unchanged; nothing to consume.
+ - NEW_FRAME: new frame at .slot; read and process it, update last_write_idx to .write_idx.
+ - SHUTDOWN: shutdown_flag=1; producer has exited, consumer should clean up.
+ - VERSION_CHANGED: SHM was re-initialised; consumer must reopen IPC handles.
+
+ Folds _get_read_slot() (importer) and the three identical preambles in
+ get_frame / get_frame_numpy / get_frame_cupy into one location.
+ """
+ if buf[layout.shutdown_offset] != 0:
+ return AcquireResult(state=SlotState.SHUTDOWN)
+
+ version = read_version(buf)
+ if version != last_version and last_version != 0:
+ return AcquireResult(state=SlotState.VERSION_CHANGED, new_version=version)
+
+ write_idx = read_write_idx(buf)
+ if write_idx == 0 or write_idx == last_write_idx:
+ return AcquireResult(state=SlotState.NO_FRAME)
+
+ slot = (write_idx - 1) % layout.num_slots
+ try:
+ timestamp = _ST_F64.unpack_from(buf, layout.timestamp_offset)[0]
+ except struct.error:
+ timestamp = 0.0
+
+ return AcquireResult(state=SlotState.NEW_FRAME, slot=slot, timestamp=timestamp, write_idx=write_idx)
diff --git a/src/streamdiffusion/_compat/diffusers_kvo_patch.py b/src/streamdiffusion/_compat/diffusers_kvo_patch.py
new file mode 100644
index 000000000..a45e82b26
--- /dev/null
+++ b/src/streamdiffusion/_compat/diffusers_kvo_patch.py
@@ -0,0 +1,792 @@
+"""Re-applies varshith15/diffusers@3e3b72f kvo_cache patch onto upstream diffusers.
+
+Source fork: https://github.com/varshith15/diffusers @ 3e3b72f557e91546894340edabc845e894f00922
+Target: diffusers >= 0.38.0 (upstream HuggingFace)
+
+The patch threads an optional KV-cache through the UNet2DConditionModel forward pass so
+that StreamDiffusion's TRT cached-attention pipeline (unet_unified_export.py) works with
+vanilla upstream diffusers. When kvo_cache=None (the default) behaviour is identical to
+upstream diffusers β no performance or correctness impact for non-cached paths.
+
+Called automatically at ``import streamdiffusion`` via _compat/__init__.py.
+"""
+
+from __future__ import annotations
+
+import inspect
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+_PATCHED = False
+
+# Sentinel used by _patch_down_block so we can distinguish "caller passed kvo_cache=None
+# explicitly" (patched UNet β must return 3-tuple) from "caller never passed kvo_cache"
+# (ControlNet / any other diffusers module that uses CrossAttnDownBlock2D β must return the
+# original 2-tuple to avoid "too many values to unpack (expected 2)").
+_KVO_NOT_GIVEN = object()
+
+
+# ---------------------------------------------------------------------------
+# Public entry point
+# ---------------------------------------------------------------------------
+
+
+def apply() -> None:
+ """Apply kvo_cache patch. Idempotent β safe to call multiple times."""
+ global _PATCHED
+ if _PATCHED:
+ return
+ if _is_patched():
+ _PATCHED = True
+ return
+
+ _patch_attn_processor()
+ _patch_attention_forward()
+ _patch_basic_transformer_block()
+ _patch_transformer2d()
+ _patch_mid_block()
+ _patch_down_block()
+ _patch_up_block()
+ _patch_unet2d()
+
+ _PATCHED = True
+ logger.debug("diffusers_kvo_patch: kvo_cache patch applied")
+
+
+def _is_patched() -> bool:
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+ return "kvo_cache" in inspect.signature(UNet2DConditionModel.forward).parameters
+
+
+# ---------------------------------------------------------------------------
+# Individual patches
+# ---------------------------------------------------------------------------
+
+
+def _patch_attn_processor() -> None:
+ """AttnProcessor2_0.__call__: accept kvo_cache kwarg, return (hidden_states, kvo_cache)."""
+ from diffusers.models.attention_processor import AttnProcessor2_0
+
+ _orig = AttnProcessor2_0.__call__
+
+ def _call(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ kvo_cache=None,
+ *args,
+ **kwargs,
+ ):
+ result = _orig(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ temb=temb,
+ *args,
+ **kwargs,
+ )
+ return result, kvo_cache
+
+ AttnProcessor2_0.__call__ = _call
+
+
+def _patch_attention_forward() -> None:
+ """Attention.forward: accept kvo_cache, route to processor only for self-attn."""
+ from diffusers.models.attention_processor import Attention
+ from diffusers.utils import logging as _dlog
+
+ _dlogger = _dlog.get_logger("diffusers.models.attention_processor")
+
+ def _forward(
+ self, hidden_states, encoder_hidden_states=None, attention_mask=None, kvo_cache=None, **cross_attention_kwargs
+ ):
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [
+ k for k in cross_attention_kwargs if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if unused_kwargs:
+ _dlogger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by "
+ f"{self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ if encoder_hidden_states is None:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ kvo_cache=kvo_cache,
+ **cross_attention_kwargs,
+ )
+ else:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ Attention.forward = _forward
+
+
+def _patch_basic_transformer_block() -> None:
+ """BasicTransformerBlock.forward: thread kvo_cache through attn1, handle attn2 tuple return."""
+ from diffusers.models.attention import BasicTransformerBlock
+
+ def _forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ timestep=None,
+ cross_attention_kwargs=None,
+ class_labels=None,
+ added_cond_kwargs=None,
+ kvo_cache=None,
+ ):
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ batch_size = hidden_states.shape[0]
+
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.norm_type == "ada_norm_zero":
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm1(hidden_states)
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output, kvo_cache_out = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ kvo_cache=kvo_cache,
+ **cross_attention_kwargs,
+ )
+
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ if self.attn2 is not None:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ norm_hidden_states = self.norm2(hidden_states)
+ elif self.norm_type == "ada_norm_single":
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ if isinstance(attn_output, tuple):
+ attn_output = attn_output[0]
+ hidden_states = attn_output + hidden_states
+
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ from diffusers.models.attention import _chunked_feed_forward
+
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states, kvo_cache_out
+
+ BasicTransformerBlock.forward = _forward
+
+
+def _patch_transformer2d() -> None:
+ """Transformer2DModel.forward: thread kvo_cache through transformer_blocks."""
+ import torch
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.transformers.transformer_2d import Transformer2DModel
+
+ def _forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ added_cond_kwargs=None,
+ class_labels=None,
+ cross_attention_kwargs=None,
+ attention_mask=None,
+ encoder_attention_mask=None,
+ kvo_cache=None,
+ return_dict=True,
+ ):
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ if attention_mask is not None and attention_mask.ndim == 2:
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if self.is_input_continuous:
+ batch_size, _, height, width = hidden_states.shape
+ residual = hidden_states
+ hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+ )
+
+ kvo_cache_out = []
+ for idx, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ cross_attention_kwargs,
+ class_labels,
+ )
+ else:
+ block_cache_in = kvo_cache[idx] if kvo_cache else None
+ hidden_states, block_cache_out = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ kvo_cache=block_cache_in,
+ )
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+
+ if self.is_input_continuous:
+ output = self._get_output_for_continuous_inputs(
+ hidden_states=hidden_states,
+ residual=residual,
+ batch_size=batch_size,
+ height=height,
+ width=width,
+ inner_dim=inner_dim,
+ )
+ elif self.is_input_vectorized:
+ output = self._get_output_for_vectorized_inputs(hidden_states)
+ elif self.is_input_patches:
+ output = self._get_output_for_patched_inputs(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ class_labels=class_labels,
+ embedded_timestep=embedded_timestep,
+ height=height,
+ width=width,
+ )
+
+ if not return_dict:
+ return (output, kvo_cache_out)
+
+ return Transformer2DModelOutput(sample=output)
+
+ Transformer2DModel.forward = _forward
+
+
+def _patch_mid_block() -> None:
+ """UNetMidBlock2DCrossAttn.forward: thread kvo_cache through attention loop.
+
+ Same sentinel trick as _patch_down_block: callers that do not pass kvo_cache explicitly
+ (e.g. ControlNet) get the original single-return-value (hidden_states); the patched UNet
+ path β which always passes kvo_cache= explicitly β gets the 2-tuple (hidden, kvo_cache_out).
+ """
+ import torch
+ from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
+
+ def _forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ encoder_attention_mask=None,
+ kvo_cache=_KVO_NOT_GIVEN,
+ ):
+ _caller_passed_kvo = kvo_cache is not _KVO_NOT_GIVEN
+ if not _caller_passed_kvo:
+ kvo_cache = None
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ kvo_cache_out = []
+ for idx, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ else:
+ block_cache_in = kvo_cache[idx] if kvo_cache else None
+ hidden_states, block_cache_out = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ return_dict=False,
+ )
+ hidden_states = resnet(hidden_states, temb)
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+
+ if _caller_passed_kvo:
+ return hidden_states, kvo_cache_out
+ return hidden_states
+
+ UNetMidBlock2DCrossAttn.forward = _forward
+
+
+def _patch_down_block() -> None:
+ """CrossAttnDownBlock2D.forward: thread kvo_cache, return (hidden, output_states, kvo_cache_out).
+
+ Uses _KVO_NOT_GIVEN sentinel so the patched forward stays backward-compatible with callers
+ (e.g. ControlNet) that never pass kvo_cache: those callers get the original 2-tuple return,
+ while _patch_unet2d (which always passes kvo_cache= explicitly) gets the 3-tuple.
+ """
+ import torch
+ from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D
+
+ def _forward(
+ self,
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ encoder_attention_mask=None,
+ additional_residuals=None,
+ kvo_cache=_KVO_NOT_GIVEN,
+ ):
+ _caller_passed_kvo = kvo_cache is not _KVO_NOT_GIVEN
+ if not _caller_passed_kvo:
+ kvo_cache = None
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+ blocks = list(zip(self.resnets, self.attentions))
+ kvo_cache_out = []
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ block_cache_in = kvo_cache[i] if kvo_cache else None
+ hidden_states, block_cache_out = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ return_dict=False,
+ )
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+ output_states = output_states + (hidden_states,)
+
+ if _caller_passed_kvo:
+ return hidden_states, output_states, kvo_cache_out
+ return hidden_states, output_states
+
+ CrossAttnDownBlock2D.forward = _forward
+
+
+def _patch_up_block() -> None:
+ """CrossAttnUpBlock2D.forward: thread kvo_cache, return (hidden, kvo_cache_out)."""
+ from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D
+
+ try:
+ from diffusers.models.unets.unet_2d_blocks import apply_freeu
+ except ImportError:
+ apply_freeu = None
+ import torch
+
+ def _forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ cross_attention_kwargs=None,
+ upsample_size=None,
+ attention_mask=None,
+ encoder_attention_mask=None,
+ kvo_cache=None,
+ ):
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ is_freeu_enabled = (
+ apply_freeu is not None
+ and getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ kvo_cache_out = []
+ for idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ block_cache_in = kvo_cache[idx] if kvo_cache else None
+ hidden_states, block_cache_out = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ return_dict=False,
+ )
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states, kvo_cache_out
+
+ CrossAttnUpBlock2D.forward = _forward
+
+
+def _patch_unet2d() -> None:
+ """UNet2DConditionModel.forward: add kvo_cache param, wire through all blocks."""
+ import torch
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
+ from diffusers.utils import deprecate
+
+ def _forward(
+ self,
+ sample,
+ timestep,
+ encoder_hidden_states,
+ class_labels=None,
+ timestep_cond=None,
+ attention_mask=None,
+ cross_attention_kwargs=None,
+ added_cond_kwargs=None,
+ down_block_additional_residuals=None,
+ mid_block_additional_residual=None,
+ down_intrablock_additional_residuals=None,
+ encoder_attention_mask=None,
+ kvo_cache=None,
+ return_dict=True,
+ ):
+ default_overall_up_factor = 2**self.num_upsamplers
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ forward_upsample_size = True
+ break
+
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ sample = self.conv_in(sample)
+
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ is_adapter = down_intrablock_additional_residuals is not None
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
+ "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
+ "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ cache_idx = 0
+ kvo_cache_out = []
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ block_cache_in = kvo_cache[cache_idx] if kvo_cache else None
+ sample, res_samples, block_cache_out = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ **additional_residuals,
+ )
+ cache_idx += 1
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+ down_block_res_samples = new_down_block_res_samples
+
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ block_cache_in = kvo_cache[cache_idx] if kvo_cache else None
+ sample, block_cache_out = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ )
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+ cache_idx += 1
+ else:
+ sample = self.mid_block(sample, emb)
+
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ block_cache_in = kvo_cache[cache_idx] if kvo_cache else None
+ sample, block_cache_out = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ kvo_cache=block_cache_in,
+ )
+ cache_idx += 1
+ if block_cache_out is not None:
+ kvo_cache_out.append(block_cache_out)
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample, kvo_cache_out)
+
+ return UNet2DConditionOutput(sample=sample)
+
+ UNet2DConditionModel.forward = _forward
diff --git a/src/streamdiffusion/_compat/td_exporter/ActivationBarrier.py b/src/streamdiffusion/_compat/td_exporter/ActivationBarrier.py
new file mode 100644
index 000000000..bd1fea403
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/ActivationBarrier.py
@@ -0,0 +1,106 @@
+"""Cross-process SHM activation barrier for cuda-link.
+
+Coordinates Python producer <-> TD-side Sender activation windows.
+When a Sender is initializing, it increments active_count; the producer
+skips export_frame while non-zero (best-effort, no OS atomics needed β
+the 5 s stale-timeout recovers from any stuck state).
+
+Segment layout (64 bytes, little-endian):
+ Offset Size Field Description
+ ------ ---- ----- -----------
+ 0 4 magic 0xCDA1BAAA β guards against alien segments
+ 4 4 version 1 β bumped if layout changes
+ 8 4 active_count Number of Senders inside an activation window
+ 12 4 _pad Align last_change_ns to 8 bytes
+ 16 8 last_change_ns time.monotonic_ns() of most recent write
+ 24 4 barrier_skips Producer-incremented skip-frame counter
+ 28 4 last_writer_pid Diagnostic: PID of last active_count writer
+ 32 32 reserved Zero-filled; reserved for future fields
+"""
+
+from __future__ import annotations
+
+import struct
+import time
+from multiprocessing.shared_memory import SharedMemory
+
+
+SHM_NAME = "cudalink_activation_barrier"
+SHM_SIZE = 64
+MAGIC = 0xCDA1BAAA
+VERSION = 1
+
+# Struct: magic(u32) version(u32) active_count(u32) pad(u32) last_change_ns(u64)
+# barrier_skips(u32) last_writer_pid(u32) reserved(32s)
+_STRUCT = struct.Struct(" SharedMemory:
+ """Open the existing segment or create and initialise it.
+
+ Args:
+ create: When True, create the segment on FileNotFoundError and write
+ the magic/version header. When False, raise FileNotFoundError
+ if the segment does not yet exist.
+
+ Returns:
+ Open SharedMemory handle (caller must close when done).
+ """
+ try:
+ return SharedMemory(name=SHM_NAME)
+ except FileNotFoundError:
+ if not create:
+ raise
+ shm = SharedMemory(name=SHM_NAME, create=True, size=SHM_SIZE)
+ _STRUCT.pack_into(shm.buf, 0, MAGIC, VERSION, 0, 0, 0, 0, 0, b"\x00" * 32)
+ return shm
+
+
+def read_state(shm: SharedMemory) -> tuple[int, int, int]:
+ """Return (active_count, last_change_ns, barrier_skips).
+
+ Snapshot-reads the full 64-byte segment to avoid tearing.
+ """
+ fields = _STRUCT.unpack(bytes(shm.buf[:SHM_SIZE]))
+ # (magic, version, active_count, pad, last_change_ns, barrier_skips, pid, reserved)
+ return fields[2], fields[4], fields[5]
+
+
+def increment(shm: SharedMemory, pid: int) -> int:
+ """Increment active_count, refresh last_change_ns and last_writer_pid.
+
+ Best-effort: no OS-level atomic. Race window is microseconds; the
+ producer-side stale-timeout absorbs any stuck state.
+
+ Returns:
+ New active_count value.
+ """
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[2] += 1 # active_count
+ fields[4] = time.monotonic_ns() # last_change_ns
+ fields[6] = pid # last_writer_pid
+ _STRUCT.pack_into(shm.buf, 0, *fields)
+ return fields[2]
+
+
+def decrement(shm: SharedMemory, pid: int) -> int:
+ """Decrement active_count (clamps at zero), refresh timestamps.
+
+ Returns:
+ New active_count value.
+ """
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[2] = max(0, fields[2] - 1) # active_count, no underflow
+ fields[4] = time.monotonic_ns() # last_change_ns
+ fields[6] = pid # last_writer_pid
+ _STRUCT.pack_into(shm.buf, 0, *fields)
+ return fields[2]
+
+
+def bump_skip(shm: SharedMemory) -> None:
+ """Increment barrier_skips counter (producer-only diagnostic)."""
+ fields = list(_STRUCT.unpack(bytes(shm.buf[:SHM_SIZE])))
+ fields[5] += 1 # barrier_skips
+ _STRUCT.pack_into(shm.buf, 0, *fields)
diff --git a/src/streamdiffusion/_compat/td_exporter/CUDAGraphs.py b/src/streamdiffusion/_compat/td_exporter/CUDAGraphs.py
new file mode 100644
index 000000000..6b2e8da32
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/CUDAGraphs.py
@@ -0,0 +1,327 @@
+"""
+CUDA Graphs Mixin β CUDA Graph capture, instantiation, launch, and node-update methods.
+
+Provides CUDAGraphsMixin, mixed into CUDARuntimeAPI to contribute the graph-lifecycle
+API. All methods rely on self.cudart (the cudart DLL handle) and self.check_error from
+the host class.
+
+Shared between the pip package (cuda_link) and TouchDesigner textDATs.
+Compatible with both Python package and TD COMP namespace imports.
+"""
+
+from __future__ import annotations
+
+import ctypes
+from ctypes import byref, c_int, c_size_t, c_void_p
+
+
+try:
+ from cuda_link.cuda_runtime_types import ( # noqa: E402
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaExtent,
+ cudaMemcpy3DParms,
+ cudaPitchedPtr,
+ cudaPos,
+ )
+except ImportError:
+ from CUDARuntimeTypes import ( # type: ignore[no-redef] # noqa: E402
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaExtent,
+ cudaMemcpy3DParms,
+ cudaPitchedPtr,
+ cudaPos,
+ )
+
+
+class CUDAGraphsMixin:
+ """Mixin contributing CUDA Graph lifecycle methods to CUDARuntimeAPI.
+
+ Requires self.cudart (cudart DLL handle) and self.check_error from the host class.
+ """
+
+ # --- Phase 2: CUDA Graph API wrappers ---
+
+ def stream_begin_capture(self, stream: CUDAStream_t, mode: int = 0) -> None:
+ """Begin capturing a stream into a CUDA graph.
+
+ After this call, operations enqueued on *stream* are recorded into a
+ graph rather than executed immediately. End with stream_end_capture().
+
+ Args:
+ stream: Stream to capture.
+ mode: cudaStreamCaptureMode β 0=global (safest), 1=thread_local,
+ 2=relaxed. Use 0 unless you know what you're doing.
+
+ Raises:
+ RuntimeError: If capture start fails (e.g., stream already capturing).
+ """
+ result = self.cudart.cudaStreamBeginCapture(stream, c_int(mode))
+ self.check_error(result, "cudaStreamBeginCapture")
+
+ def stream_end_capture(self, stream: CUDAStream_t) -> CUDAGraph_t:
+ """End stream capture and return the captured graph.
+
+ After this call the stream resumes normal execution mode. The returned
+ graph must be instantiated with graph_instantiate() before use, and
+ destroyed with graph_destroy() when done.
+
+ Args:
+ stream: Stream that was passed to stream_begin_capture().
+
+ Returns:
+ CUDAGraph_t handle to the captured graph template.
+
+ Raises:
+ RuntimeError: If capture end fails.
+ """
+ graph = CUDAGraph_t()
+ result = self.cudart.cudaStreamEndCapture(stream, byref(graph))
+ self.check_error(result, "cudaStreamEndCapture")
+ return graph
+
+ def graph_instantiate(self, graph: CUDAGraph_t, flags: int = 0) -> CUDAGraphExec_t:
+ """Instantiate a graph template into an executable graph.
+
+ The executable graph (CUDAGraphExec_t) can be launched repeatedly via
+ graph_launch(). The template graph can be destroyed after instantiation.
+
+ Args:
+ graph: CUDAGraph_t template returned by stream_end_capture().
+ flags: cudaGraphInstantiateFlagDeviceLaunch (0x02) for device-side
+ launch; 0 for normal host-side launch.
+
+ Returns:
+ CUDAGraphExec_t executable graph handle.
+
+ Raises:
+ RuntimeError: If instantiation fails.
+ """
+ from ctypes import c_uint64
+
+ graph_exec = CUDAGraphExec_t()
+ result = self.cudart.cudaGraphInstantiateWithFlags(byref(graph_exec), graph, c_uint64(flags))
+ self.check_error(result, "cudaGraphInstantiateWithFlags")
+ return graph_exec
+
+ def graph_launch(self, graph_exec: CUDAGraphExec_t, stream: CUDAStream_t) -> None:
+ """Launch an executable graph on a stream (single WDDM submission).
+
+ This replaces N individual API calls (stream_wait_event, memcpy_async,
+ record_event) with one batched WDDM submission, reducing kernel-mode
+ transition overhead from NΓ~15Β΅s to ~15Β΅s on Windows WDDM.
+
+ Args:
+ graph_exec: Executable graph from graph_instantiate().
+ stream: Stream on which to launch the graph.
+
+ Raises:
+ RuntimeError: If launch fails.
+ """
+ result = self.cudart.cudaGraphLaunch(graph_exec, stream)
+ self.check_error(result, "cudaGraphLaunch")
+
+ def graph_get_nodes(self, graph: CUDAGraph_t) -> list[CUDAGraphNode_t]:
+ """Return all nodes in a graph in topological (capture) order.
+
+ Useful for discovering node handles after stream capture, before the
+ template graph is destroyed.
+
+ Args:
+ graph: CUDAGraph_t template (must NOT yet be destroyed).
+
+ Returns:
+ List of CUDAGraphNode_t handles in capture order:
+ [EventWaitNode (if present), MemcpyNode, EventRecordNode].
+
+ Raises:
+ RuntimeError: If query fails.
+ """
+ count = c_size_t(0)
+ result = self.cudart.cudaGraphGetNodes(graph, None, byref(count))
+ self.check_error(result, "cudaGraphGetNodes (count)")
+ node_array = (CUDAGraphNode_t * count.value)()
+ result = self.cudart.cudaGraphGetNodes(graph, node_array, byref(count))
+ self.check_error(result, "cudaGraphGetNodes (fill)")
+ return list(node_array)
+
+ def graph_destroy(self, graph: CUDAGraph_t) -> None:
+ """Destroy a graph template (not the executable β use graph_exec_destroy for that).
+
+ Args:
+ graph: Template graph to destroy.
+
+ Raises:
+ RuntimeError: If destruction fails.
+ """
+ result = self.cudart.cudaGraphDestroy(graph)
+ self.check_error(result, "cudaGraphDestroy")
+
+ def graph_exec_destroy(self, graph_exec: CUDAGraphExec_t) -> None:
+ """Destroy an executable graph and free its resources.
+
+ Args:
+ graph_exec: Executable graph to destroy.
+
+ Raises:
+ RuntimeError: If destruction fails.
+ """
+ result = self.cudart.cudaGraphExecDestroy(graph_exec)
+ self.check_error(result, "cudaGraphExecDestroy")
+
+ @staticmethod
+ def make_memcpy3d_params(dst: c_void_p, src: c_void_p, count: int, kind: int) -> cudaMemcpy3DParms:
+ """Build a cudaMemcpy3DParms struct for a flat 1D memory copy.
+
+ Represents the copy as a single-row 2D memcpy (height=1, depth=1) so
+ that 'count' bytes are transferred from src to dst. This is the required
+ form for cudaGraphExecMemcpyNodeSetParams even when the original copy was
+ issued as cudaMemcpyAsync (1D form).
+
+ Args:
+ dst: Destination pointer.
+ src: Source pointer.
+ count: Number of bytes to copy.
+ kind: cudaMemcpyKind (3 = DeviceToDevice).
+
+ Returns:
+ Populated cudaMemcpy3DParms instance.
+ """
+ params = cudaMemcpy3DParms()
+ params.srcArray = None
+ params.srcPos = cudaPos(0, 0, 0)
+ params.srcPtr = cudaPitchedPtr(
+ ptr=ctypes.cast(src, c_void_p),
+ pitch=count,
+ xsize=count,
+ ysize=1,
+ )
+ params.dstArray = None
+ params.dstPos = cudaPos(0, 0, 0)
+ params.dstPtr = cudaPitchedPtr(
+ ptr=ctypes.cast(dst, c_void_p),
+ pitch=count,
+ xsize=count,
+ ysize=1,
+ )
+ params.extent = cudaExtent(width=count, height=1, depth=1)
+ params.kind = kind
+ return params
+
+ def graph_exec_memcpy_node_set_params(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ dst: c_void_p,
+ src: c_void_p,
+ count: int,
+ kind: int,
+ ) -> None:
+ """Update src/dst/count/kind of a memcpy node in an executable graph.
+
+ This is a CPU-only operation (no WDDM submission). Changes take effect
+ on the next graph_launch() call. The extent (count) must match the
+ extent used when the graph was captured β only pointer reassignment
+ within the same buffer size is supported.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: MemcpyNode handle from graph_get_nodes().
+ dst: New destination pointer.
+ src: New source pointer.
+ count: Copy size in bytes (must match captured size).
+ kind: cudaMemcpyKind (must match captured kind).
+
+ Raises:
+ RuntimeError: If parameter update fails.
+ """
+ params = self.make_memcpy3d_params(dst, src, count, kind)
+ result = self.cudart.cudaGraphExecMemcpyNodeSetParams(graph_exec, node, byref(params))
+ self.check_error(result, "cudaGraphExecMemcpyNodeSetParams")
+
+ def graph_exec_memcpy_node_set_params_1d(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ dst: c_void_p,
+ src: c_void_p,
+ count: int,
+ kind: int,
+ ) -> None:
+ """Update src/dst/count/kind of a 1D memcpy node in an executable graph.
+
+ Use this for nodes captured from cudaMemcpyAsync (1D form). The 3D variant
+ (graph_exec_memcpy_node_set_params) returns INVALID_VALUE on 1D nodes.
+ Requires CUDA 11.3+.
+ """
+ dst_int = dst.value if isinstance(dst, c_void_p) else int(dst)
+ src_int = src.value if isinstance(src, c_void_p) else int(src)
+ result = self.cudart.cudaGraphExecMemcpyNodeSetParams1D(
+ graph_exec,
+ node,
+ c_void_p(dst_int),
+ c_void_p(src_int),
+ c_size_t(count),
+ c_int(kind),
+ )
+ self.check_error(result, "cudaGraphExecMemcpyNodeSetParams1D")
+
+ def graph_exec_event_record_node_set_event(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ event: CUDAEvent_t,
+ ) -> None:
+ """Update the event recorded by an event-record node. CUDA 11.4+.
+
+ CPU-only β takes effect on next graph_launch(). Use this to update the
+ per-ring-slot IPC event when the ring slot changes between launches.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: EventRecordNode handle from graph_get_nodes().
+ event: New CUDAEvent_t to record.
+
+ Raises:
+ RuntimeError: If update fails.
+ """
+ result = self.cudart.cudaGraphExecEventRecordNodeSetEvent(graph_exec, node, event)
+ self.check_error(result, "cudaGraphExecEventRecordNodeSetEvent")
+
+ def graph_exec_event_wait_node_set_event(
+ self,
+ graph_exec: CUDAGraphExec_t,
+ node: CUDAGraphNode_t,
+ event: CUDAEvent_t,
+ ) -> None:
+ """Update the event waited on by an event-wait node. CUDA 11.4+.
+
+ Args:
+ graph_exec: Executable graph containing the node.
+ node: EventWaitNode handle from graph_get_nodes().
+ event: New CUDAEvent_t to wait on.
+
+ Raises:
+ RuntimeError: If update fails.
+ """
+ result = self.cudart.cudaGraphExecEventWaitNodeSetEvent(graph_exec, node, event)
+ self.check_error(result, "cudaGraphExecEventWaitNodeSetEvent")
+
+ def get_runtime_version(self) -> int:
+ """Return the CUDA runtime version as an int.
+
+ Examples: 11030 = CUDA 11.3, 11040 = CUDA 11.4, 12080 = CUDA 12.8.
+ Used to gate optional API calls when the loaded cudart DLL may be from
+ an older patch level (e.g., TouchDesigner ships ``cudart64_110.dll``).
+ """
+ version = c_int(0)
+ result = self.cudart.cudaRuntimeGetVersion(byref(version))
+ self.check_error(result, "cudaRuntimeGetVersion")
+ return int(version.value)
diff --git a/src/streamdiffusion/_compat/td_exporter/CUDAIPCExtension.py b/src/streamdiffusion/_compat/td_exporter/CUDAIPCExtension.py
new file mode 100644
index 000000000..74e5e1d2b
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/CUDAIPCExtension.py
@@ -0,0 +1,267 @@
+"""
+CUDA IPC Extension for TouchDesigner - Dual-Mode Sender/Receiver
+Supports both exporting (Sender) and importing (Receiver) GPU textures via CUDA IPC
+
+Usage in TouchDesigner:
+ Sender: ext.CUDAIPCExtension.export_frame(top_op)
+ Receiver: ext.CUDAIPCExtension.import_frame(import_buffer)
+
+Architecture:
+ Sender: TD GPU -> cudaMemory() -> Persistent Buffer -> IPC Handle -> SharedMemory
+ Receiver: SharedMemory -> IPC Handle -> Opened GPU Buffer -> scriptTOP.copyCUDAMemory()
+
+Facade: delegates all Sender work to TDSenderEngine and all Receiver work to
+TDReceiverEngine. Mode switches create a fresh engine instance β zero state leak.
+"""
+
+from __future__ import annotations
+
+import contextlib
+
+
+try:
+ from td import COMP, TOP, CUDAMemoryShape
+except ImportError:
+ from typing import Any as COMP
+ from typing import Any as TOP
+
+ CUDAMemoryShape = None
+
+from SHMProtocol import ( # noqa: E402
+ FLAGS_BFLOAT16,
+ FORMAT_KIND_FLOAT,
+ FORMAT_KIND_SIGNED,
+ FORMAT_KIND_UNSIGNED,
+ PROTOCOL_MAGIC,
+ SHM_HEADER_SIZE,
+ SLOT_SIZE,
+)
+from TDConfig import TDSenderConfig # noqa: E402
+from TDHost import RealTDHost, RealTOPHandle, TDHost # noqa: E402
+from TDReceiver import TDReceiverEngine # noqa: E402
+from TDSender import TDSenderEngine # noqa: E402
+
+
+# Re-export protocol constants for backward compatibility (tests import these from here)
+__all__ = [
+ "CUDAIPCExtension",
+ "FORMAT_KIND_FLOAT",
+ "FORMAT_KIND_SIGNED",
+ "FORMAT_KIND_UNSIGNED",
+ "PROTOCOL_MAGIC",
+ "SLOT_SIZE",
+ "SHM_HEADER_SIZE",
+ "FLAGS_BFLOAT16",
+]
+
+# CuPy deferred import flag (tests may patch this)
+CUPY_AVAILABLE: bool = False
+cp = None
+
+
+class CUDAIPCExtension:
+ """TouchDesigner extension facade for dual-mode CUDA IPC texture sharing.
+
+ Delegates all Sender work to TDSenderEngine and all Receiver work to
+ TDReceiverEngine. Mode switches tear down the old engine and create a fresh
+ one β guaranteeing zero cross-mode state leak.
+
+ Public API is unchanged from v1.x so existing .tox callback templates continue
+ to work without modification.
+ """
+
+ def __init__(
+ self,
+ ownerComp: COMP,
+ host: TDHost | None = None,
+ config: TDSenderConfig | None = None,
+ ) -> None:
+ self.ownerComp = ownerComp
+ self._host: TDHost = host if host is not None else RealTDHost(ownerComp)
+ self._config: TDSenderConfig = config if config is not None else TDSenderConfig.from_env()
+
+ _mode_val = self._host.param_value("Mode")
+ self._mode: str = str(_mode_val) if _mode_val is not None else "Sender"
+
+ # Read construction params once (engine uses them at build time)
+ _slots_val = self._host.param_value("Numslots")
+ try:
+ self._num_slots: int = int(_slots_val) if _slots_val is not None else 3
+ except (ValueError, TypeError):
+ self._num_slots = 3
+
+ _dev_val = self._host.param_value("Cudadevice")
+ try:
+ self._device: int = int(_dev_val) if _dev_val is not None else 0
+ except (ValueError, TypeError):
+ self._device = 0
+
+ _shm_val = self._host.param_value("Ipcmemname")
+ self._shm_name: str = str(_shm_val) if _shm_val is not None else "cudalink_output_ipc"
+
+ _debug_val = self._host.param_value("Debug")
+ self._verbose: bool = bool(_debug_val) if _debug_val is not None else False
+ if self._config.export_profile:
+ self._verbose = True
+
+ _hide_val = self._host.param_value("Hidebuiltin")
+ if _hide_val is not None:
+ self._host.show_custom_only(bool(_hide_val))
+
+ self._engine: TDSenderEngine | TDReceiverEngine = self._make_engine()
+
+ self._log(f"Extension initialized on {ownerComp} [Mode: {self._mode}]", force=True)
+
+ if self._mode == "Receiver":
+ self._host.set_param_enabled("Numslots", False)
+
+ # ------------------------------------------------------------------
+ # Engine factory
+ # ------------------------------------------------------------------
+
+ def _make_engine(self) -> TDSenderEngine | TDReceiverEngine:
+ if self._mode == "Sender":
+ return TDSenderEngine(
+ host=self._host,
+ config=self._config,
+ cuda=None,
+ log_fn=self._log,
+ num_slots=self._num_slots,
+ device=self._device,
+ shm_name=self._shm_name,
+ verbose=self._verbose,
+ )
+ return TDReceiverEngine(
+ host=self._host,
+ config=self._config,
+ cuda=None,
+ log_fn=self._log,
+ num_slots=self._num_slots,
+ device=self._device,
+ shm_name=self._shm_name,
+ verbose=self._verbose,
+ )
+
+ # ------------------------------------------------------------------
+ # Logging (faΓ§ade owns this; engine holds a reference to it)
+ # ------------------------------------------------------------------
+
+ def _log(self, msg: str, force: bool = False) -> None:
+ prefix = f"[CUDAIPCExtension:{self._mode}]"
+ if force or self._verbose:
+ print(f"{prefix} {msg}")
+
+ # ------------------------------------------------------------------
+ # Public API β all delegate to engine
+ # ------------------------------------------------------------------
+
+ @property
+ def mode(self) -> str:
+ return self._mode
+
+ def initialize(self, width: int, height: int, channels: int = 4, buffer_size: int | None = None) -> bool:
+ """Delegate to sender engine's initialize() (kept for test injection)."""
+ return self._engine.initialize(width, height, channels, buffer_size)
+
+ def export_frame(self, top_op: TOP | None = None) -> bool:
+ if self._mode != "Sender":
+ return False
+ return self._engine.export_frame(top_op)
+
+ def import_frame(self, import_buffer: TOP) -> bool:
+ if self._mode != "Receiver":
+ return False
+ handle = RealTOPHandle(import_buffer) if import_buffer is not None else None
+ return self._engine.import_frame(handle)
+
+ def _check_deferred_cleanup(self) -> None:
+ if self._mode == "Sender":
+ self._engine._check_deferred_cleanup()
+
+ def update_receiver_resolution(self, import_buffer: TOP) -> None:
+ if self._mode == "Receiver":
+ handle = RealTOPHandle(import_buffer) if import_buffer is not None else None
+ self._engine.update_receiver_resolution(handle)
+
+ def is_active(self) -> bool:
+ """Delegate to host's active-parameter check (hot-path safe)."""
+ return self._host.is_active()
+
+ def initialize_receiver(self) -> bool:
+ """Delegate to receiver engine's initialize_receiver() (backward compat)."""
+ return self._engine.initialize_receiver()
+
+ def cleanup(self) -> None:
+ self._engine.cleanup()
+
+ def __delTD__(self) -> None:
+ self.cleanup()
+
+ def is_ready(self) -> bool:
+ return self._engine.is_ready()
+
+ def get_stats(self) -> dict:
+ return self._engine.get_stats()
+
+ def switch_mode(self, new_mode: str) -> None:
+ if new_mode == self._mode:
+ return
+ self._log(f"Switching mode: {self._mode} -> {new_mode}", force=True)
+ # Tear down old engine (guaranteed no state leak β new engine is a fresh instance)
+ self._engine.cleanup()
+ self._mode = new_mode
+ # When switching to Sender: re-read num_slots from UI (receiver may have updated it)
+ if new_mode == "Sender":
+ _ns = self._host.param_value("Numslots")
+ if _ns is not None:
+ with contextlib.suppress(ValueError, TypeError):
+ self._num_slots = int(_ns)
+ self._engine = self._make_engine()
+ self._host.set_param_enabled("Numslots", new_mode == "Sender")
+ self._log(f"Mode switched to {new_mode}. Will initialize on next frame.", force=True)
+
+ # ------------------------------------------------------------------
+ # Attribute bridges β callbacks in parexecute_callbacks.py write
+ # these directly; properties propagate to the current engine.
+ # ------------------------------------------------------------------
+
+ @property
+ def shm_name(self) -> str:
+ return self._engine.shm_name
+
+ @shm_name.setter
+ def shm_name(self, value: str) -> None:
+ self._shm_name = value
+ self._engine.shm_name = value
+
+ @property
+ def num_slots(self) -> int:
+ return self._engine.num_slots
+
+ @num_slots.setter
+ def num_slots(self, value: int) -> None:
+ self._num_slots = value
+ self._engine.num_slots = value
+
+ @property
+ def verbose_performance(self) -> bool:
+ return self._engine.verbose_performance
+
+ @verbose_performance.setter
+ def verbose_performance(self, value: bool) -> None:
+ self._verbose = value
+ self._engine.verbose_performance = value
+
+ def request_immediate_reconnect(self) -> None:
+ """Force next import_frame to attempt reconnection (called from parexecute callbacks)."""
+ if self._mode == "Receiver":
+ self._engine.request_immediate_reconnect()
+
+ def consume_pending_resolution(self) -> tuple | None:
+ """Return (width, height) if resolution update is pending, else None.
+
+ Called from script_top_callbacks.onCook to drive ImportBuffer Script TOP par updates.
+ """
+ if self._mode == "Receiver":
+ return self._engine.consume_pending_resolution()
+ return None
diff --git a/src/streamdiffusion/_compat/td_exporter/CUDAIPCWrapper.py b/src/streamdiffusion/_compat/td_exporter/CUDAIPCWrapper.py
new file mode 100644
index 000000000..6d4cfad5a
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/CUDAIPCWrapper.py
@@ -0,0 +1,1088 @@
+"""
+CUDA IPC Wrapper for Windows
+Based on vLLM cuda_wrapper.py pattern
+
+Provides ctypes interface to CUDA Runtime API for inter-process communication.
+Compatible with both TouchDesigner and Python processes.
+
+Requirements:
+- CUDA 12.x runtime (cudart64_12.dll)
+- Windows operating system
+- Same GPU visible to both processes
+"""
+
+from __future__ import annotations
+
+import ctypes
+import logging
+import os
+from ctypes import POINTER, byref, c_float, c_int, c_size_t, c_uint, c_uint64, c_void_p
+
+
+_logger = logging.getLogger(__name__)
+
+try:
+ from cuda_link.cuda_runtime_types import ( # noqa: E402
+ CUDAError,
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaIpcEventHandle_t,
+ cudaIpcMemHandle_t,
+ cudaMemcpy3DParms,
+ cudaPointerAttributes,
+ )
+except ImportError:
+ from CUDARuntimeTypes import ( # type: ignore[no-redef] # noqa: E402
+ CUDAError,
+ CUDAEvent_t,
+ CUDAGraph_t,
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ CUDAStream_t,
+ cudaIpcEventHandle_t,
+ cudaIpcMemHandle_t,
+ cudaMemcpy3DParms,
+ cudaPointerAttributes,
+ )
+
+try:
+ from cuda_link.cuda_graphs import CUDAGraphsMixin # noqa: E402
+except ImportError:
+ from CUDAGraphs import CUDAGraphsMixin # type: ignore[no-redef] # noqa: E402
+
+
+class CUDARuntimeAPI(CUDAGraphsMixin):
+ """CUDA Runtime API wrapper using ctypes.
+
+ Provides access to CUDA IPC functions for zero-copy GPU memory
+ sharing between processes.
+
+ Usage:
+ cuda = CUDARuntimeAPI()
+
+ # Allocate GPU memory
+ dev_ptr = cuda.malloc(buffer_size)
+
+ # Export IPC handle (sender process)
+ handle = cuda.ipc_get_mem_handle(dev_ptr)
+
+ # Import IPC handle (receiver process)
+ imported_ptr = cuda.ipc_open_mem_handle(handle)
+
+ # Use memory...
+
+ # Close handle (receiver)
+ cuda.ipc_close_mem_handle(imported_ptr)
+
+ # Free memory (sender)
+ cuda.free(dev_ptr)
+ """
+
+ def __init__(self, device: int = 0) -> None:
+ """Initialize CUDA runtime library.
+
+ Args:
+ device: CUDA device index to bind. Defaults to 0.
+ IPC handles are device-scoped; sender and receiver must
+ use the same device or peer-access must be enabled.
+ """
+ self.device = device
+ self.cudart = self._load_cuda_runtime()
+ self._setup_function_signatures()
+ # Establish CUDA primary context on the requested device.
+ # Prevents cudaIpcOpenMemHandle error 400 when a second cudart DLL is loaded
+ # alongside torch (which has its own bundled cudart). Each DLL instance needs
+ # its own context initialized before IPC handle operations can succeed.
+ self.cudart.cudaSetDevice(device)
+
+ if os.environ.get("CUDA_LAUNCH_BLOCKING") == "1":
+ _logger.warning(
+ "CUDA_LAUNCH_BLOCKING=1 is set β all CUDA operations are serialized. "
+ "This causes ~30x slower frame rates and should only be used for debugging."
+ )
+
+ # Default ON; set CUDALINK_STICKY_ERROR_CHECK=0 to skip the cudaPeekAtLastError call.
+ self._sticky_check_enabled: bool = os.environ.get("CUDALINK_STICKY_ERROR_CHECK", "1") != "0"
+
+ def _load_cuda_runtime(self) -> ctypes.CDLL:
+ """Load CUDA runtime DLL.
+
+ Returns:
+ ctypes.CDLL: Loaded CUDA runtime library
+
+ Raises:
+ RuntimeError: If CUDA runtime cannot be loaded
+ """
+ # Try by name FIRST: if cudart is already loaded in this process (e.g., by
+ # torch), Windows returns the cached handle β ensuring we share the same
+ # runtime instance and CUDA context. Loading by full path can create a second
+ # independent instance with its own state, breaking cross-process IPC.
+ # cudart64_110.dll is preferred for bisect testing (W1): reverts the 12.x
+ # preference introduced in 4695d8f to test whether cudart64_12 ABI is the
+ # driver-error amplifier on WDDM.
+ dll_names = ["cudart64_110.dll", "cudart64_12.dll", "cudart64_11.dll"]
+ for name in dll_names:
+ try:
+ dll = ctypes.CDLL(name)
+ self._log_dll_path(dll, name)
+ return dll
+ except OSError:
+ continue
+
+ # Fallback: try full toolkit paths when not already in PATH
+ dll_paths = [
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudart64_12.dll",
+ r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.0\bin\cudart64_12.dll",
+ ]
+ for dll_path in dll_paths:
+ if os.path.exists(dll_path):
+ try:
+ dll = ctypes.CDLL(dll_path)
+ self._log_dll_path(dll, dll_path)
+ return dll
+ except OSError:
+ continue
+
+ raise RuntimeError(
+ "Could not load CUDA runtime. Please ensure CUDA 12.x is installed.\n"
+ f"Tried names: {dll_names}\n"
+ f"Tried paths: {dll_paths}"
+ )
+
+ @staticmethod
+ def _log_dll_path(dll: ctypes.CDLL, hint: str) -> None:
+ """Log the resolved filesystem path of a loaded DLL (Windows only)."""
+ try:
+ buf = ctypes.create_unicode_buffer(260)
+ # GetModuleFileNameW needs HMODULE as c_void_p to avoid 32-bit overflow
+ ctypes.windll.kernel32.GetModuleFileNameW(ctypes.c_void_p(dll._handle), buf, 260)
+ _logger.debug("Loaded CUDA runtime: %s", buf.value)
+ except (OSError, AttributeError) as e:
+ _logger.debug("Could not log DLL path: %s", e)
+
+ def _setup_function_signatures(self) -> None:
+ """Define function signatures for CUDA runtime functions."""
+ # cudaMalloc(void** devPtr, size_t size)
+ self.cudart.cudaMalloc.argtypes = [POINTER(c_void_p), c_size_t]
+ self.cudart.cudaMalloc.restype = c_int
+
+ # cudaFree(void* devPtr)
+ self.cudart.cudaFree.argtypes = [c_void_p]
+ self.cudart.cudaFree.restype = c_int
+
+ # cudaMallocHost(void** ptr, size_t size) β allocate pinned (page-locked) host memory
+ self.cudart.cudaMallocHost.argtypes = [POINTER(c_void_p), c_size_t]
+ self.cudart.cudaMallocHost.restype = c_int
+
+ # cudaFreeHost(void* ptr) β free pinned host memory
+ self.cudart.cudaFreeHost.argtypes = [c_void_p]
+ self.cudart.cudaFreeHost.restype = c_int
+
+ # cudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind)
+ self.cudart.cudaMemcpy.argtypes = [c_void_p, c_void_p, c_size_t, c_int]
+ self.cudart.cudaMemcpy.restype = c_int
+
+ # cudaIpcGetMemHandle(cudaIpcMemHandle_t* handle, void* devPtr)
+ self.cudart.cudaIpcGetMemHandle.argtypes = [
+ POINTER(cudaIpcMemHandle_t),
+ c_void_p,
+ ]
+ self.cudart.cudaIpcGetMemHandle.restype = c_int
+
+ # cudaIpcOpenMemHandle(void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags)
+ self.cudart.cudaIpcOpenMemHandle.argtypes = [
+ POINTER(c_void_p),
+ cudaIpcMemHandle_t,
+ c_uint,
+ ]
+ self.cudart.cudaIpcOpenMemHandle.restype = c_int
+
+ # cudaIpcCloseMemHandle(void* devPtr)
+ self.cudart.cudaIpcCloseMemHandle.argtypes = [c_void_p]
+ self.cudart.cudaIpcCloseMemHandle.restype = c_int
+
+ # cudaIpcGetEventHandle(cudaIpcEventHandle_t* handle, cudaEvent_t event)
+ self.cudart.cudaIpcGetEventHandle.argtypes = [
+ POINTER(cudaIpcEventHandle_t),
+ CUDAEvent_t,
+ ]
+ self.cudart.cudaIpcGetEventHandle.restype = c_int
+
+ # cudaIpcOpenEventHandle(cudaEvent_t* event, cudaIpcEventHandle_t handle)
+ self.cudart.cudaIpcOpenEventHandle.argtypes = [
+ POINTER(CUDAEvent_t),
+ cudaIpcEventHandle_t,
+ ]
+ self.cudart.cudaIpcOpenEventHandle.restype = c_int
+
+ # cudaEventCreateWithFlags(cudaEvent_t* event, unsigned int flags)
+ self.cudart.cudaEventCreateWithFlags.argtypes = [POINTER(CUDAEvent_t), c_uint]
+ self.cudart.cudaEventCreateWithFlags.restype = c_int
+
+ # cudaEventRecord(cudaEvent_t event, cudaStream_t stream)
+ self.cudart.cudaEventRecord.argtypes = [CUDAEvent_t, CUDAStream_t]
+ self.cudart.cudaEventRecord.restype = c_int
+
+ # cudaEventQuery(cudaEvent_t event)
+ self.cudart.cudaEventQuery.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventQuery.restype = c_int
+
+ # cudaEventSynchronize(cudaEvent_t event)
+ self.cudart.cudaEventSynchronize.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventSynchronize.restype = c_int
+
+ # cudaEventDestroy(cudaEvent_t event)
+ self.cudart.cudaEventDestroy.argtypes = [CUDAEvent_t]
+ self.cudart.cudaEventDestroy.restype = c_int
+
+ # cudaEventElapsedTime(float* ms, cudaEvent_t start, cudaEvent_t end)
+ self.cudart.cudaEventElapsedTime.argtypes = [POINTER(c_float), CUDAEvent_t, CUDAEvent_t]
+ self.cudart.cudaEventElapsedTime.restype = c_int
+
+ # cudaDeviceSynchronize()
+ self.cudart.cudaDeviceSynchronize.argtypes = []
+ self.cudart.cudaDeviceSynchronize.restype = c_int
+
+ # cudaGetLastError()
+ self.cudart.cudaGetLastError.argtypes = []
+ self.cudart.cudaGetLastError.restype = c_int
+
+ # cudaPeekAtLastError() β non-destructive sticky-error read (does NOT clear the error)
+ self.cudart.cudaPeekAtLastError.argtypes = []
+ self.cudart.cudaPeekAtLastError.restype = c_int
+
+ # cudaHostRegister(void* ptr, size_t size, unsigned int flags) β page-lock existing host memory
+ self.cudart.cudaHostRegister.argtypes = [c_void_p, c_size_t, c_uint]
+ self.cudart.cudaHostRegister.restype = c_int
+
+ # cudaHostUnregister(void* ptr) β unregister page-locked host memory
+ self.cudart.cudaHostUnregister.argtypes = [c_void_p]
+ self.cudart.cudaHostUnregister.restype = c_int
+
+ # cudaGetErrorString(cudaError_t error)
+ self.cudart.cudaGetErrorString.argtypes = [c_int]
+ self.cudart.cudaGetErrorString.restype = ctypes.c_char_p
+
+ # cudaStreamCreateWithFlags(cudaStream_t* pStream, unsigned int flags)
+ self.cudart.cudaStreamCreateWithFlags.argtypes = [POINTER(CUDAStream_t), c_uint]
+ self.cudart.cudaStreamCreateWithFlags.restype = c_int
+
+ # cudaStreamDestroy(cudaStream_t stream)
+ self.cudart.cudaStreamDestroy.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamDestroy.restype = c_int
+
+ # cudaStreamWaitEvent(cudaStream_t stream, cudaEvent_t event, unsigned int flags)
+ self.cudart.cudaStreamWaitEvent.argtypes = [CUDAStream_t, CUDAEvent_t, c_uint]
+ self.cudart.cudaStreamWaitEvent.restype = c_int
+
+ # cudaStreamSynchronize(cudaStream_t stream)
+ self.cudart.cudaStreamSynchronize.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamSynchronize.restype = c_int
+
+ # cudaMemcpyAsync(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream)
+ self.cudart.cudaMemcpyAsync.argtypes = [c_void_p, c_void_p, c_size_t, c_int, CUDAStream_t]
+ self.cudart.cudaMemcpyAsync.restype = c_int
+
+ # cudaMemGetInfo(size_t* free, size_t* total)
+ self.cudart.cudaMemGetInfo.argtypes = [POINTER(c_size_t), POINTER(c_size_t)]
+ self.cudart.cudaMemGetInfo.restype = c_int
+
+ # cudaSetDevice(int device)
+ self.cudart.cudaSetDevice.argtypes = [c_int]
+ self.cudart.cudaSetDevice.restype = c_int
+
+ # cudaGetDevice(int* device)
+ self.cudart.cudaGetDevice.argtypes = [POINTER(c_int)]
+ self.cudart.cudaGetDevice.restype = c_int
+
+ # cudaStreamQuery(cudaStream_t stream)
+ self.cudart.cudaStreamQuery.argtypes = [CUDAStream_t]
+ self.cudart.cudaStreamQuery.restype = c_int
+
+ # cudaDeviceCanAccessPeer(int* canAccessPeer, int device, int peerDevice)
+ self.cudart.cudaDeviceCanAccessPeer.argtypes = [POINTER(c_int), c_int, c_int]
+ self.cudart.cudaDeviceCanAccessPeer.restype = c_int
+
+ # cudaDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority)
+ self.cudart.cudaDeviceGetStreamPriorityRange.argtypes = [POINTER(c_int), POINTER(c_int)]
+ self.cudart.cudaDeviceGetStreamPriorityRange.restype = c_int
+
+ # cudaStreamCreateWithPriority(cudaStream_t* pStream, unsigned int flags, int priority)
+ self.cudart.cudaStreamCreateWithPriority.argtypes = [POINTER(CUDAStream_t), c_uint, c_int]
+ self.cudart.cudaStreamCreateWithPriority.restype = c_int
+
+ # cudaPointerGetAttributes(cudaPointerAttributes* attributes, const void* ptr)
+ self.cudart.cudaPointerGetAttributes.argtypes = [POINTER(cudaPointerAttributes), c_void_p]
+ self.cudart.cudaPointerGetAttributes.restype = c_int
+
+ # === G1: non-graph helpers (re-enabled Phase 1.1) ===
+ # cudaHostAlloc(void** ptr, size_t size, unsigned int flags)
+ # Replaces cudaMallocHost with explicit flag control.
+ # cudaHostAllocPortable = 0x01 β accessible from any CUDA context in process
+ # cudaHostAllocMapped = 0x02 β map into device address space
+ # cudaHostAllocWriteCombined = 0x04 β write-combined (fast CPU writes, slow CPU reads)
+ self.cudart.cudaHostAlloc.argtypes = [POINTER(c_void_p), c_size_t, c_uint]
+ self.cudart.cudaHostAlloc.restype = c_int
+
+ # cudaDeviceGetAttribute(int* value, cudaDeviceAttr attr, int device)
+ # Used to query cudaDevAttrAsyncEngineCount (attr=4) β how many DMA copy engines exist.
+ self.cudart.cudaDeviceGetAttribute.argtypes = [POINTER(c_int), c_int, c_int]
+ self.cudart.cudaDeviceGetAttribute.restype = c_int
+
+ # === G2: graph lifecycle (re-enabled Phase 1.2) ===
+ # CUDA 10.0+ graph capture/build/launch/teardown + runtime-version gate.
+
+ # cudaStreamBeginCapture(cudaStream_t stream, cudaStreamCaptureMode mode)
+ # mode: 0=global, 1=thread_local, 2=relaxed
+ self.cudart.cudaStreamBeginCapture.argtypes = [CUDAStream_t, c_int]
+ self.cudart.cudaStreamBeginCapture.restype = c_int
+
+ # cudaStreamEndCapture(cudaStream_t stream, cudaGraph_t* pGraph)
+ self.cudart.cudaStreamEndCapture.argtypes = [CUDAStream_t, POINTER(CUDAGraph_t)]
+ self.cudart.cudaStreamEndCapture.restype = c_int
+
+ # cudaGraphInstantiateWithFlags(cudaGraphExec_t* pGraphExec, cudaGraph_t graph,
+ # unsigned long long flags) [CUDA 11.4+ stable 3-arg form]
+ # Prefer this over cudaGraphInstantiate on any cudart 11.x: the latter changed
+ # from 5-arg (CUDA 10.0β11.8) to 3-arg (CUDA 12.0+) β calling the 12.0 3-arg
+ # binding against an 11.x DLL mismatches the ABI and crashes (WDDM access
+ # violation). cudaGraphInstantiateWithFlags has had a stable 3-arg signature
+ # since 11.4 and is available in all 12.x releases as well.
+ self.cudart.cudaGraphInstantiateWithFlags.argtypes = [POINTER(CUDAGraphExec_t), CUDAGraph_t, c_uint64]
+ self.cudart.cudaGraphInstantiateWithFlags.restype = c_int
+
+ # cudaGraphLaunch(cudaGraphExec_t graphExec, cudaStream_t stream)
+ self.cudart.cudaGraphLaunch.argtypes = [CUDAGraphExec_t, CUDAStream_t]
+ self.cudart.cudaGraphLaunch.restype = c_int
+
+ # cudaGraphDestroy(cudaGraph_t graph)
+ self.cudart.cudaGraphDestroy.argtypes = [CUDAGraph_t]
+ self.cudart.cudaGraphDestroy.restype = c_int
+
+ # cudaGraphExecDestroy(cudaGraphExec_t graphExec)
+ self.cudart.cudaGraphExecDestroy.argtypes = [CUDAGraphExec_t]
+ self.cudart.cudaGraphExecDestroy.restype = c_int
+
+ # cudaGraphGetNodes(cudaGraph_t graph, cudaGraphNode_t* nodes, size_t* numNodes)
+ # Pass nodes=NULL to query count; then call again with allocated array.
+ self.cudart.cudaGraphGetNodes.argtypes = [CUDAGraph_t, POINTER(CUDAGraphNode_t), POINTER(c_size_t)]
+ self.cudart.cudaGraphGetNodes.restype = c_int
+
+ # cudaRuntimeGetVersion(int* runtimeVersion)
+ # Returns the version as int (e.g., 11040 = CUDA 11.4, 12080 = CUDA 12.8).
+ # Used to gate optional API calls (e.g., cudaGraphExecMemcpyNodeSetParams1D
+ # requires 11.3+) when the loaded cudart DLL may be a 11.0.x patch.
+ self.cudart.cudaRuntimeGetVersion.argtypes = [POINTER(c_int)]
+ self.cudart.cudaRuntimeGetVersion.restype = c_int
+
+ # === G3: graph node setters (re-enabled Phase 1.3) ===
+ # Per-frame in-place node update for ring-slot remap. Most CUDA-12-flavoured
+ # of the 14 (NodeSetParams1D 11.3+; event-node setters 11.4+).
+
+ # cudaGraphExecMemcpyNodeSetParams(cudaGraphExec_t, cudaGraphNode_t,
+ # const cudaMemcpy3DParms*)
+ # Updates a 3D-captured memcpy node. For nodes captured from cudaMemcpyAsync
+ # (1D form) use cudaGraphExecMemcpyNodeSetParams1D instead.
+ self.cudart.cudaGraphExecMemcpyNodeSetParams.argtypes = [
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ POINTER(cudaMemcpy3DParms),
+ ]
+ self.cudart.cudaGraphExecMemcpyNodeSetParams.restype = c_int
+
+ # cudaGraphExecMemcpyNodeSetParams1D(cudaGraphExec_t, cudaGraphNode_t,
+ # void* dst, const void* src,
+ # size_t count, cudaMemcpyKind kind)
+ # Updates a 1D memcpy node (captured from cudaMemcpyAsync). CUDA 11.3+.
+ self.cudart.cudaGraphExecMemcpyNodeSetParams1D.argtypes = [
+ CUDAGraphExec_t,
+ CUDAGraphNode_t,
+ c_void_p,
+ c_void_p,
+ c_size_t,
+ c_int,
+ ]
+ self.cudart.cudaGraphExecMemcpyNodeSetParams1D.restype = c_int
+
+ # cudaGraphExecEventRecordNodeSetEvent(cudaGraphExec_t, cudaGraphNode_t,
+ # cudaEvent_t event)
+ # Updates the event recorded by an event-record node. CUDA 11.4+.
+ self.cudart.cudaGraphExecEventRecordNodeSetEvent.argtypes = [CUDAGraphExec_t, CUDAGraphNode_t, CUDAEvent_t]
+ self.cudart.cudaGraphExecEventRecordNodeSetEvent.restype = c_int
+
+ # cudaGraphExecEventWaitNodeSetEvent(cudaGraphExec_t, cudaGraphNode_t,
+ # cudaEvent_t event)
+ # Updates the event waited on by an event-wait node. CUDA 11.4+.
+ self.cudart.cudaGraphExecEventWaitNodeSetEvent.argtypes = [CUDAGraphExec_t, CUDAGraphNode_t, CUDAEvent_t]
+ self.cudart.cudaGraphExecEventWaitNodeSetEvent.restype = c_int
+
+ def check_error(self, result: int, operation: str) -> None:
+ """Check CUDA error code and raise exception if failed.
+
+ Args:
+ result: CUDA error code
+ operation: Description of the operation that failed
+
+ Raises:
+ RuntimeError: If result indicates an error
+ """
+ if result != CUDAError.SUCCESS:
+ error_str = self.cudart.cudaGetErrorString(result).decode("utf-8")
+ error_name = CUDAError.get_name(result)
+ raise RuntimeError(f"CUDA {operation} failed: {error_str} (error {result}: {error_name})")
+
+ def peek_at_last_error(self) -> int:
+ """Non-destructively read the thread-local sticky CUDA error.
+
+ Returns SUCCESS (0) normally. A non-zero value means a prior async
+ operation (memcpy, kernel) failed and the error was not yet consumed.
+ Unlike cudaGetLastError this does NOT clear the latched error state.
+ """
+ return int(self.cudart.cudaPeekAtLastError())
+
+ def check_sticky_error(self, context: str) -> None:
+ """Warn and raise if a sticky CUDA error is latched from a prior async op.
+
+ No-op when CUDALINK_STICKY_ERROR_CHECK=0. Enabled by default.
+ Use peek_at_last_error() directly for the raw value without raising.
+ """
+ if not self._sticky_check_enabled:
+ return
+ code = int(self.cudart.cudaPeekAtLastError())
+ if code != CUDAError.SUCCESS:
+ error_str = self.cudart.cudaGetErrorString(code).decode("utf-8")
+ _logger.warning(
+ "Sticky CUDA error detected after %s: %s (code %d). "
+ "The CUDA context is poisoned β restart the process. "
+ "Set CUDALINK_STICKY_ERROR_CHECK=0 to disable this check.",
+ context,
+ error_str,
+ code,
+ )
+ raise RuntimeError(
+ f"Sticky CUDA error after {context}: {error_str} (code {code}). "
+ "The CUDA context is poisoned. Restart the process or set "
+ "CUDALINK_STICKY_ERROR_CHECK=0 to disable this check."
+ )
+
+ def host_register(self, ptr: int, size: int, flags: int = 0) -> None:
+ """Page-lock an existing host allocation via cudaHostRegister.
+
+ Args:
+ ptr: Host pointer as integer (e.g., arr.ctypes.data)
+ size: Number of bytes to register
+ flags: Registration flags (0=default, 1=portable, 2=mapped, 4=write-combined)
+
+ Raises:
+ RuntimeError: If registration fails
+ """
+ result = self.cudart.cudaHostRegister(c_void_p(ptr), c_size_t(size), c_uint(flags))
+ self.check_error(result, "cudaHostRegister")
+
+ def host_unregister(self, ptr: int) -> None:
+ """Unregister a page-locked host allocation registered with host_register().
+
+ Args:
+ ptr: Host pointer as integer (same value passed to host_register())
+
+ Raises:
+ RuntimeError: If unregistration fails
+ """
+ result = self.cudart.cudaHostUnregister(c_void_p(ptr))
+ self.check_error(result, "cudaHostUnregister")
+
+ # High-level API
+
+ def malloc(self, size: int) -> c_void_p:
+ """Allocate GPU memory.
+
+ Args:
+ size: Number of bytes to allocate
+
+ Returns:
+ Device pointer to allocated memory
+
+ Raises:
+ RuntimeError: If allocation fails
+ """
+ dev_ptr = c_void_p()
+ result = self.cudart.cudaMalloc(byref(dev_ptr), size)
+ self.check_error(result, "cudaMalloc")
+ return dev_ptr
+
+ def free(self, dev_ptr: c_void_p) -> None:
+ """Free GPU memory.
+
+ Args:
+ dev_ptr: Device pointer to free
+
+ Raises:
+ RuntimeError: If free fails
+ """
+ result = self.cudart.cudaFree(dev_ptr)
+ self.check_error(result, "cudaFree")
+
+ def malloc_host(self, size: int) -> c_void_p:
+ """Allocate pinned (page-locked) host memory via cudaMallocHost.
+
+ Pinned memory enables direct DMA for D2H transfers, eliminating the
+ CUDA driver's internal staging copy that pageable memory requires.
+
+ Note: this project is single-GPU by construction (get_cuda_runtime rejects
+ a second device). Multi-GPU would require cudaHostAlloc with
+ cudaHostAllocPortable for cross-device visibility (Handbook Β§5.1).
+
+ Args:
+ size: Number of bytes to allocate
+
+ Returns:
+ Host pointer to pinned memory
+
+ Raises:
+ RuntimeError: If allocation fails
+ """
+ ptr = c_void_p()
+ result = self.cudart.cudaMallocHost(byref(ptr), size)
+ self.check_error(result, "cudaMallocHost")
+ return ptr
+
+ def free_host(self, ptr: c_void_p) -> None:
+ """Free pinned host memory allocated with malloc_host().
+
+ Args:
+ ptr: Host pointer to free
+
+ Raises:
+ RuntimeError: If free fails
+ """
+ result = self.cudart.cudaFreeHost(ptr)
+ self.check_error(result, "cudaFreeHost")
+
+ def memcpy(self, dst: c_void_p, src: c_void_p, count: int, kind: int) -> None:
+ """Copy memory (device-to-device, host-to-device, or device-to-host).
+
+ Args:
+ dst: Destination pointer
+ src: Source pointer
+ count: Number of bytes to copy
+ kind: cudaMemcpyKind (0=H2H, 1=H2D, 2=D2H, 3=D2D)
+
+ Raises:
+ RuntimeError: If copy fails
+ """
+ result = self.cudart.cudaMemcpy(dst, src, count, kind)
+ self.check_error(result, "cudaMemcpy")
+
+ def ipc_get_mem_handle(self, dev_ptr: c_void_p) -> cudaIpcMemHandle_t:
+ """Get IPC handle for GPU memory.
+
+ This handle can be transferred to another process via SharedMemory
+ or other IPC mechanism.
+
+ Args:
+ dev_ptr: Device pointer to export
+
+ Returns:
+ IPC handle (128 bytes)
+
+ Raises:
+ RuntimeError: If export fails
+ """
+ handle = cudaIpcMemHandle_t()
+ result = self.cudart.cudaIpcGetMemHandle(byref(handle), dev_ptr)
+ self.check_error(result, "cudaIpcGetMemHandle")
+ return handle
+
+ def ipc_open_mem_handle(self, handle: cudaIpcMemHandle_t, flags: int = 1) -> c_void_p:
+ """Open IPC handle to access GPU memory from another process.
+
+ Args:
+ handle: IPC handle received from another process
+ flags: IPC flags (1 = cudaIpcMemLazyEnablePeerAccess)
+
+ Returns:
+ Device pointer to shared memory
+
+ Raises:
+ RuntimeError: If opening fails
+ """
+ dev_ptr = c_void_p()
+ result = self.cudart.cudaIpcOpenMemHandle(byref(dev_ptr), handle, flags)
+ self.check_error(result, "cudaIpcOpenMemHandle")
+ return dev_ptr
+
+ def ipc_close_mem_handle(self, dev_ptr: c_void_p) -> None:
+ """Close IPC memory handle.
+
+ Args:
+ dev_ptr: Device pointer obtained from ipc_open_mem_handle()
+
+ Raises:
+ RuntimeError: If closing fails
+ """
+ result = self.cudart.cudaIpcCloseMemHandle(dev_ptr)
+ self.check_error(result, "cudaIpcCloseMemHandle")
+
+ def synchronize(self) -> None:
+ """Synchronize all CUDA operations on current device.
+
+ Raises:
+ RuntimeError: If synchronization fails
+ """
+ result = self.cudart.cudaDeviceSynchronize()
+ self.check_error(result, "cudaDeviceSynchronize")
+
+ # CUDA Event API (for async synchronization)
+
+ def create_ipc_event(self) -> CUDAEvent_t:
+ """Create CUDA event suitable for IPC (interprocess communication).
+
+ Returns:
+ Event handle for cross-process synchronization
+
+ Raises:
+ RuntimeError: If event creation fails
+ """
+ event = CUDAEvent_t()
+ # cudaEventInterprocess (4) | cudaEventDisableTiming (2) = 6
+ # NVIDIA requires cudaEventDisableTiming when using cudaEventInterprocess
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 6)
+ self.check_error(result, "cudaEventCreateWithFlags")
+ return event
+
+ def record_event(self, event: CUDAEvent_t, stream: CUDAStream_t | None = None) -> None:
+ """Record event on specified stream (or default stream).
+
+ Args:
+ event: Event handle to record
+ stream: CUDA stream (None = default stream)
+
+ Raises:
+ RuntimeError: If event recording fails
+ """
+ # Convert None to CUDA default stream (0) for ctypes compatibility
+ if stream is None:
+ stream = CUDAStream_t(0)
+ result = self.cudart.cudaEventRecord(event, stream)
+ self.check_error(result, "cudaEventRecord")
+
+ def query_event(self, event: c_void_p) -> bool:
+ """Query if event has completed (non-blocking).
+
+ Args:
+ event: Event handle to query
+
+ Returns:
+ True if event completed, False if still pending
+
+ Raises:
+ RuntimeError: If query fails with unexpected error
+ """
+ result = self.cudart.cudaEventQuery(event)
+ if result == CUDAError.SUCCESS:
+ return True
+ elif result == CUDAError.NOT_READY:
+ return False
+ self.check_error(result, "cudaEventQuery")
+ return False
+
+ def wait_event(self, event: CUDAEvent_t) -> None:
+ """Wait for event to complete (blocking).
+
+ Args:
+ event: Event handle to wait on
+
+ Raises:
+ RuntimeError: If wait fails
+ """
+ result = self.cudart.cudaEventSynchronize(event)
+ self.check_error(result, "cudaEventSynchronize")
+
+ def ipc_get_event_handle(self, event: CUDAEvent_t) -> cudaIpcEventHandle_t:
+ """Get IPC handle for event (for cross-process signaling).
+
+ Args:
+ event: Event created with create_ipc_event()
+
+ Returns:
+ IPC event handle (64 bytes)
+
+ Raises:
+ RuntimeError: If export fails
+ """
+ handle = cudaIpcEventHandle_t()
+ result = self.cudart.cudaIpcGetEventHandle(byref(handle), event)
+ self.check_error(result, "cudaIpcGetEventHandle")
+ return handle
+
+ def ipc_open_event_handle(self, handle: cudaIpcEventHandle_t) -> CUDAEvent_t:
+ """Open IPC event handle from another process.
+
+ Args:
+ handle: IPC event handle received from another process
+
+ Returns:
+ Event handle for this process
+
+ Raises:
+ RuntimeError: If opening fails
+ """
+ event = CUDAEvent_t()
+ result = self.cudart.cudaIpcOpenEventHandle(byref(event), handle)
+ self.check_error(result, "cudaIpcOpenEventHandle")
+ return event
+
+ def destroy_event(self, event: CUDAEvent_t) -> None:
+ """Destroy CUDA event.
+
+ Args:
+ event: Event handle to destroy
+
+ Raises:
+ RuntimeError: If destruction fails
+ """
+ result = self.cudart.cudaEventDestroy(event)
+ self.check_error(result, "cudaEventDestroy")
+
+ def create_timing_event(self) -> CUDAEvent_t:
+ """Create CUDA event suitable for GPU timing (NOT for IPC).
+
+ Returns:
+ Event handle for GPU-accurate timing measurements
+
+ Raises:
+ RuntimeError: If event creation fails
+
+ Note:
+ This creates an event with timing enabled (flags=0).
+ Use this for benchmarking, NOT for IPC synchronization.
+ IPC events require cudaEventDisableTiming flag.
+ """
+ event = CUDAEvent_t()
+ # flags=0 enables timing (no cudaEventDisableTiming, no cudaEventInterprocess)
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 0)
+ self.check_error(result, "cudaEventCreateWithFlags(timing)")
+ return event
+
+ def create_sync_event(self) -> CUDAEvent_t:
+ """Create CUDA event optimized for stream ordering (NOT timing, NOT IPC).
+
+ Returns:
+ Event handle for use with stream_wait_event() ordering
+
+ Raises:
+ RuntimeError: If event creation fails
+
+ Note:
+ Uses cudaEventDisableTiming (0x02). Per NVIDIA docs this provides
+ best performance when used with cudaStreamWaitEvent() and
+ cudaEventQuery() β removes per-record timing instrumentation overhead.
+ Do not use with event_elapsed_time(); use create_timing_event() for that.
+ """
+ event = CUDAEvent_t()
+ # cudaEventDisableTiming = 0x02 β optimal for ordering-only events
+ result = self.cudart.cudaEventCreateWithFlags(byref(event), 0x02)
+ self.check_error(result, "cudaEventCreateWithFlags(sync)")
+ return event
+
+ def event_elapsed_time(self, start: CUDAEvent_t, end: CUDAEvent_t) -> float:
+ """Get elapsed GPU time between two events.
+
+ Args:
+ start: Starting event (must be recorded before end event)
+ end: Ending event
+
+ Returns:
+ Elapsed time in milliseconds (GPU-measured)
+
+ Raises:
+ RuntimeError: If elapsed time query fails
+
+ Note:
+ Both events must have timing enabled (created with create_timing_event).
+ Events with cudaEventDisableTiming flag cannot be used for timing.
+ """
+ elapsed_ms = c_float()
+ result = self.cudart.cudaEventElapsedTime(byref(elapsed_ms), start, end)
+ self.check_error(result, "cudaEventElapsedTime")
+ return elapsed_ms.value
+
+ def get_device(self) -> int:
+ """Return the CUDA device index currently bound to this context.
+
+ Returns:
+ Integer device index (matches self.device if context is healthy)
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ device = c_int()
+ result = self.cudart.cudaGetDevice(byref(device))
+ self.check_error(result, "cudaGetDevice")
+ return device.value
+
+ def create_stream(self, flags: int = 0x01) -> CUDAStream_t:
+ """Create CUDA stream with specified flags.
+
+ Args:
+ flags: Stream creation flags. Default 0x01 = cudaStreamNonBlocking
+
+ Returns:
+ CUDAStream_t: Opaque stream handle
+
+ Raises:
+ RuntimeError: If stream creation fails
+ """
+ stream = CUDAStream_t()
+ result = self.cudart.cudaStreamCreateWithFlags(byref(stream), flags)
+ self.check_error(result, "cudaStreamCreateWithFlags")
+ return stream
+
+ def create_stream_with_priority(self, flags: int = 0x01, priority: int | None = None) -> CUDAStream_t:
+ """Create CUDA stream at the specified (or highest available) priority.
+
+ On CUDA, stream priority is an integer where a smaller value means
+ higher priority. cudaDeviceGetStreamPriorityRange returns [least, greatest]
+ where greatest is the most-negative value β i.e., the highest priority.
+
+ Args:
+ flags: Stream flags. Default 0x01 = cudaStreamNonBlocking.
+ priority: Stream priority. None means use highest available (greatest).
+
+ Returns:
+ CUDAStream_t: Opaque stream handle
+
+ Raises:
+ RuntimeError: If stream creation fails
+ """
+ if priority is None:
+ least = c_int()
+ greatest = c_int()
+ result = self.cudart.cudaDeviceGetStreamPriorityRange(byref(least), byref(greatest))
+ self.check_error(result, "cudaDeviceGetStreamPriorityRange")
+ priority = greatest.value
+ stream = CUDAStream_t()
+ result = self.cudart.cudaStreamCreateWithPriority(byref(stream), flags, priority)
+ self.check_error(result, "cudaStreamCreateWithPriority")
+ return stream
+
+ def destroy_stream(self, stream: CUDAStream_t) -> None:
+ """Destroy CUDA stream.
+
+ Args:
+ stream: Stream handle to destroy
+
+ Raises:
+ RuntimeError: If destruction fails
+ """
+ result = self.cudart.cudaStreamDestroy(stream)
+ self.check_error(result, "cudaStreamDestroy")
+
+ def stream_wait_event(self, stream: CUDAStream_t, event: CUDAEvent_t, flags: int = 0) -> None:
+ """Make stream wait on event (GPU-side, non-blocking to CPU).
+
+ Args:
+ stream: Stream to wait
+ event: Event to wait for
+ flags: Wait flags (default 0)
+
+ Raises:
+ RuntimeError: If wait enqueue fails
+ """
+ result = self.cudart.cudaStreamWaitEvent(stream, event, flags)
+ self.check_error(result, "cudaStreamWaitEvent")
+
+ def stream_synchronize(self, stream: CUDAStream_t) -> None:
+ """Wait for all operations on stream to complete (CPU-blocking).
+
+ Args:
+ stream: Stream to synchronize
+
+ Raises:
+ RuntimeError: If synchronization fails
+ """
+ result = self.cudart.cudaStreamSynchronize(stream)
+ self.check_error(result, "cudaStreamSynchronize")
+
+ def memcpy_async(self, dst: c_void_p, src: c_void_p, count: int, kind: int, stream: CUDAStream_t) -> None:
+ """Asynchronous memory copy on a stream.
+
+ Args:
+ dst: Destination pointer
+ src: Source pointer
+ count: Number of bytes to copy
+ kind: cudaMemcpyKind (0=H2H, 1=H2D, 2=D2H, 3=D2D)
+ stream: CUDA stream for async operation
+
+ Raises:
+ RuntimeError: If async copy enqueue fails
+ """
+ result = self.cudart.cudaMemcpyAsync(dst, src, count, kind, stream)
+ self.check_error(result, "cudaMemcpyAsync")
+
+ def mem_get_info(self) -> tuple[int, int]:
+ """Get free and total device memory in bytes.
+
+ Returns:
+ Tuple of (free_bytes, total_bytes)
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ free = c_size_t()
+ total = c_size_t()
+ result = self.cudart.cudaMemGetInfo(byref(free), byref(total))
+ self.check_error(result, "cudaMemGetInfo")
+ return free.value, total.value
+
+ def stream_query(self, stream: CUDAStream_t) -> bool:
+ """Non-blocking check if all operations on stream have completed.
+
+ Args:
+ stream: CUDA stream to query
+
+ Returns:
+ True if all stream operations have completed, False if still executing
+
+ Raises:
+ RuntimeError: If query fails with an error other than cudaErrorNotReady
+ """
+ result = self.cudart.cudaStreamQuery(stream)
+ if result == CUDAError.SUCCESS:
+ return True
+ if result == CUDAError.NOT_READY:
+ return False
+ self.check_error(result, "cudaStreamQuery")
+ return False # unreachable
+
+ def pointer_get_attributes(self, ptr: int) -> cudaPointerAttributes:
+ """Query memory type and owning device for a GPU pointer.
+
+ Args:
+ ptr: GPU pointer as integer (e.g., tensor.data_ptr())
+
+ Returns:
+ cudaPointerAttributes with .type (2=device, 3=managed) and .device (GPU index)
+
+ Raises:
+ RuntimeError: If query fails (e.g., unregistered host pointer passed)
+ """
+ attrs = cudaPointerAttributes()
+ result = self.cudart.cudaPointerGetAttributes(byref(attrs), c_void_p(ptr))
+ self.check_error(result, "cudaPointerGetAttributes")
+ return attrs
+
+ def device_can_access_peer(self, device: int, peer_device: int) -> bool:
+ """Check if device can directly access peer_device memory via IPC/NVLink.
+
+ Useful for validating multi-GPU setups before attempting IPC handle operations.
+ On single-GPU systems or systems without peer access, cudaIpcOpenMemHandle
+ may fall back to slower paths without warning.
+
+ Args:
+ device: Source device ID
+ peer_device: Target peer device ID
+
+ Returns:
+ True if direct peer access is available, False otherwise
+
+ Raises:
+ RuntimeError: If query fails
+ """
+ can_access = c_int(0)
+ result = self.cudart.cudaDeviceCanAccessPeer(byref(can_access), device, peer_device)
+ self.check_error(result, "cudaDeviceCanAccessPeer")
+ return bool(can_access.value)
+
+ # --- Phase 1: cudaHostAlloc (replaces cudaMallocHost with portable flag) ---
+
+ def malloc_host_alloc(self, size: int, flags: int = 0x01) -> c_void_p:
+ """Allocate pinned host memory via cudaHostAlloc with explicit flags.
+
+ Unlike malloc_host() which calls cudaMallocHost (no flags), this lets
+ callers pass cudaHostAllocPortable (0x01) to make the allocation visible
+ from any CUDA context in the process β useful when PyTorch and CuPy share
+ the same process.
+
+ Args:
+ size: Number of bytes to allocate.
+ flags: OR-combination of:
+ cudaHostAllocPortable = 0x01 (cross-context visibility)
+ cudaHostAllocMapped = 0x02 (map into device address space)
+ cudaHostAllocWriteCombined = 0x04 (WC; fast write, slow CPU read)
+
+ Returns:
+ Host pointer to allocated pinned memory.
+
+ Raises:
+ RuntimeError: If allocation fails.
+ """
+ ptr = c_void_p()
+ result = self.cudart.cudaHostAlloc(byref(ptr), c_size_t(size), c_uint(flags))
+ self.check_error(result, "cudaHostAlloc")
+ return ptr
+
+ # --- Phase 0: device attribute query ---
+
+ def get_device_attribute(self, attr: int, device: int | None = None) -> int:
+ """Query a cudaDeviceAttr value for a given device.
+
+ Common attrs:
+ cudaDevAttrAsyncEngineCount = 4 β number of DMA copy engines
+
+ Args:
+ attr: cudaDeviceAttr integer constant.
+ device: GPU device index. Defaults to self.device.
+
+ Returns:
+ Integer attribute value.
+
+ Raises:
+ RuntimeError: If query fails.
+ """
+ if device is None:
+ device = self.device
+ value = c_int()
+ result = self.cudart.cudaDeviceGetAttribute(byref(value), c_int(attr), c_int(device))
+ self.check_error(result, "cudaDeviceGetAttribute")
+ return value.value
+
+
+# Global singleton instance (lazy initialization)
+_cuda_runtime: CUDARuntimeAPI | None = None
+
+
+def get_cuda_runtime(device: int = 0) -> CUDARuntimeAPI:
+ """Get global CUDA runtime instance (singleton).
+
+ The singleton is created on first call. Subsequent calls with a *different*
+ device index will raise RuntimeError β a single process context can only
+ be bound to one device via this shared-cudart pattern.
+
+ Args:
+ device: CUDA device index (default 0). Must match across all callers
+ within the same process.
+
+ Returns:
+ CUDARuntimeAPI: Global CUDA runtime wrapper
+
+ Raises:
+ RuntimeError: If called with a device index that conflicts with the
+ already-initialized singleton.
+ """
+ global _cuda_runtime
+ if _cuda_runtime is None:
+ _cuda_runtime = CUDARuntimeAPI(device=device)
+ elif _cuda_runtime.device != device:
+ raise RuntimeError(
+ f"CUDA runtime singleton was initialized for device {_cuda_runtime.device}, "
+ f"but caller requested device {device}. A single process can only bind to "
+ "one device via the shared-cudart singleton. Create a separate "
+ "CUDARuntimeAPI(device=...) instance for multi-device use."
+ )
+ return _cuda_runtime
diff --git a/src/streamdiffusion/_compat/td_exporter/CUDARuntimeTypes.py b/src/streamdiffusion/_compat/td_exporter/CUDARuntimeTypes.py
new file mode 100644
index 000000000..a432c2630
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/CUDARuntimeTypes.py
@@ -0,0 +1,134 @@
+"""
+CUDA Runtime Types β ctypes structs, type aliases, and error codes for CUDA IPC.
+
+Shared between the pip package (cuda_link) and TouchDesigner textDATs.
+Compatible with both Python package and TD COMP namespace imports.
+"""
+
+from __future__ import annotations
+
+import ctypes
+from ctypes import c_int, c_size_t, c_uint64, c_void_p
+
+
+# CUDA handle types - use unsigned 64-bit to prevent overflow on Windows x64
+# See: https://github.com/pytorch/pytorch/pull/162920
+CUDAEvent_t = c_uint64 # cudaEvent_t opaque pointer
+CUDAStream_t = c_uint64 # cudaStream_t opaque pointer
+CUDAGraph_t = c_uint64 # cudaGraph_t opaque pointer (CUDA 10.0+)
+CUDAGraphExec_t = c_uint64 # cudaGraphExec_t opaque pointer (CUDA 10.0+)
+CUDAGraphNode_t = c_uint64 # cudaGraphNode_t opaque pointer (CUDA 10.0+)
+
+# Minimum cudart version required for all CUDA Graphs APIs used by this module.
+# cudaGraphInstantiateWithFlags, cudaGraphExecEventRecordNodeSetEvent, and
+# cudaGraphExecEventWaitNodeSetEvent are all CUDA 11.4+ (version integer 11040).
+CUDART_GRAPHS_MIN_VERSION = 11040
+
+# --- CUDA Graph parameter structs ---
+
+
+class cudaPos(ctypes.Structure):
+ """cudaPos: {x, y, z} offsets into an array or pitched memory."""
+
+ _fields_ = [("x", c_size_t), ("y", c_size_t), ("z", c_size_t)]
+
+
+class cudaPitchedPtr(ctypes.Structure):
+ """cudaPitchedPtr: pointer + pitch metadata for 2D/3D copies."""
+
+ _fields_ = [
+ ("ptr", c_void_p),
+ ("pitch", c_size_t),
+ ("xsize", c_size_t),
+ ("ysize", c_size_t),
+ ]
+
+
+class cudaExtent(ctypes.Structure):
+ """cudaExtent: width/height/depth dimensions in bytes for 3D copies."""
+
+ _fields_ = [("width", c_size_t), ("height", c_size_t), ("depth", c_size_t)]
+
+
+class cudaMemcpy3DParms(ctypes.Structure):
+ """cudaMemcpy3DParms: full parameter struct for cudaMemcpy3D and graph node updates."""
+
+ _fields_ = [
+ ("srcArray", c_void_p), # cudaArray_t β NULL for linear memory
+ ("srcPos", cudaPos),
+ ("srcPtr", cudaPitchedPtr),
+ ("dstArray", c_void_p), # cudaArray_t β NULL for linear memory
+ ("dstPos", cudaPos),
+ ("dstPtr", cudaPitchedPtr),
+ ("extent", cudaExtent),
+ ("kind", c_int), # cudaMemcpyKind
+ ]
+
+
+# CUDA IPC Handle structure (64 bytes, CUDA_IPC_HANDLE_SIZE per NVIDIA spec)
+class cudaIpcMemHandle_t(ctypes.Structure):
+ """CUDA IPC memory handle structure.
+
+ This opaque handle can be transferred between processes via
+ SharedMemory or other IPC mechanisms to enable GPU memory sharing.
+ """
+
+ _fields_ = [("internal", ctypes.c_byte * 64)]
+
+
+# CUDA IPC Event Handle structure (64 bytes per NVIDIA spec)
+class cudaIpcEventHandle_t(ctypes.Structure):
+ """CUDA IPC event handle structure.
+
+ Used for lightweight cross-process synchronization.
+ """
+
+ _fields_ = [("reserved", ctypes.c_byte * 64)]
+
+
+# CUDA pointer attributes β memory type and owning device for a GPU pointer
+class cudaPointerAttributes(ctypes.Structure):
+ """Result of cudaPointerGetAttributes.
+
+ Useful for validating that a caller-supplied GPU pointer belongs to the
+ expected device before issuing D2D operations (C2 affinity check).
+
+ .type values: 0=unregistered, 1=host, 2=device, 3=managed
+ .device: GPU index that owns the allocation
+ """
+
+ _fields_ = [
+ ("type", c_int), # cudaMemoryType enum (2 = cudaMemoryTypeDevice)
+ ("device", c_int), # GPU device index owning this allocation
+ ("devicePointer", c_void_p),
+ ("hostPointer", c_void_p),
+ ]
+
+
+# CUDA Error codes (subset)
+class CUDAError:
+ """CUDA runtime error codes."""
+
+ SUCCESS = 0
+ INVALID_VALUE = 1
+ MEMORY_ALLOCATION = 2
+ INVALID_DEVICE_POINTER = 17
+ INVALID_DEVICE = 101
+ INVALID_CONTEXT = 201 # Common in same-process IPC testing
+ NOT_READY = 600
+ PEER_ACCESS_ALREADY_ENABLED = 704
+
+ @staticmethod
+ def get_name(code: int) -> str:
+ """Get human-readable error name."""
+ names = {
+ 0: "SUCCESS",
+ 1: "INVALID_VALUE",
+ 2: "MEMORY_ALLOCATION",
+ 17: "INVALID_DEVICE_POINTER",
+ 101: "INVALID_DEVICE",
+ 201: "INVALID_CONTEXT",
+ 600: "NOT_READY",
+ 704: "PEER_ACCESS_ALREADY_ENABLED",
+ }
+ return names.get(code, f"UNKNOWN_ERROR_{code}")
diff --git a/src/streamdiffusion/_compat/td_exporter/HELP_DOC.md b/src/streamdiffusion/_compat/td_exporter/HELP_DOC.md
new file mode 100644
index 000000000..b3f046d14
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/HELP_DOC.md
@@ -0,0 +1,249 @@
+# CUDA-Link β Component Help
+
+> **Name:** CUDA-Link
+> **Description:** Zero-copy GPU texture sharing via CUDA IPC
+> **Author:** forkni (forkni@gmail.com)
+
+**Zero-copy GPU texture sharing between TouchDesigner and external Python processes using CUDA Inter-Process Communication (IPC).**
+
+---
+
+## Overview
+
+CUDAIPCLink transfers GPU textures between TouchDesigner and a Python process without copying data through CPU memory. Texture data stays on the GPU at all times β only a small control packet (~433 bytes) is exchanged through OS shared memory to coordinate access.
+
+The component operates in two modes: **Sender** (TouchDesigner exports textures to Python) and **Receiver** (Python sends frames back into TouchDesigner). Both directions use the same underlying protocol, so two TouchDesigner instances can also communicate directly with each other.
+
+Per-frame overhead is typically **0.5β2 Β΅s** β roughly 750Γ faster than copying textures through CPU shared memory (~1.5 ms at 1080p).
+
+---
+
+## How It Works
+
+### Sender Mode β TD β Python
+
+1. Each frame, the component calls `top_op.cudaMemory()` to get a raw GPU pointer to the upstream texture.
+2. That texture is copied into a pre-allocated ring buffer slot on the GPU using `cudaMemcpyAsync` (device-to-device, never touching CPU memory).
+3. A CUDA IPC event is recorded on that slot β a lightweight GPU-side signal (~1 Β΅s).
+4. A shared memory channel is updated with the current slot index and a producer timestamp.
+5. The Python process reads the slot index, waits on the GPU event (without blocking the CPU), and accesses the texture as a zero-copy `torch.Tensor` or `cupy.ndarray`.
+
+### Receiver Mode β Python β TD
+
+1. An external Python process allocates GPU buffers, writes IPC handles into shared memory, and signals via CUDA IPC events.
+2. On each frame start, the component reads the IPC handles from shared memory, waits on the GPU event, and copies the data into a Script TOP via `copyCUDAMemory()`.
+3. The result is a live TD texture that updates every frame with the Python process's output.
+
+### Ring Buffer Architecture
+
+The component maintains **N independent GPU buffer slots** (N = `Numslots`, default 3). The producer writes into the current slot while the consumer simultaneously reads from the previous slot. This pipeline prevents either side from ever waiting on the other:
+
+```
+Frame 0: Producer β Slot 0 Consumer idle
+Frame 1: Producer β Slot 1 Consumer β Slot 0
+Frame 2: Producer β Slot 2 Consumer β Slot 1
+Frame 3: Producer β Slot 0 Consumer β Slot 2 (wraps)
+```
+
+The consumer is always one frame behind the producer. At 60 FPS this is ~16 ms β negligible for real-time AI pipelines.
+
+### Shared Memory Protocol
+
+The shared memory channel carries only control data (no pixel data):
+
+| Field | Size | Purpose |
+|-------|------|---------|
+| Magic number | 4 B | Protocol validation (`CIPD`) |
+| Version counter | 8 B | Increments on sender re-init; receiver detects reconnection |
+| Slot count | 4 B | Number of ring buffer slots |
+| Write index | 4 B | Current producer slot (atomic counter) |
+| IPC mem handle Γ N | 128 B each | GPU memory handle per slot |
+| IPC event handle Γ N | 64 B each | GPU sync event handle per slot |
+| Shutdown flag | 1 B | Reasserted to 0 every frame; set to 1 on exit |
+| Texture metadata | 20 B | Width, height, components, dtype, buffer size |
+| Producer timestamp | 8 B | `perf_counter()` for latency measurement |
+
+Total for 3 slots: **433 bytes**.
+
+### Lazy Initialization
+
+GPU resources (buffer allocation, IPC handle creation, shared memory setup) are not allocated when `Active` is toggled on. Initialization happens on the **first frame** after activation. This avoids startup overhead and allows resolution to be detected automatically from the live texture.
+
+If the sender is not yet running, the receiver retries connection with **exponential backoff** (doubling the wait interval up to ~2 seconds between attempts), then keeps retrying silently.
+
+### Automatic Re-initialization
+
+If the upstream texture resolution or format changes, the component detects the mismatch on the next frame, tears down the existing buffers, and re-initializes with the new dimensions. This takes ~50β100 Β΅s (one-time) and is transparent to the connected Python process.
+
+---
+
+## Parameters
+
+### Active
+**Type:** Toggle | **Default:** On
+
+Master enable/disable switch for the CUDA IPC pipeline.
+
+- **On:** The component initializes GPU resources on the first frame and processes every frame thereafter.
+- **Off:** All GPU work stops immediately. `export_frame()` and `import_frame()` return without doing anything. Calling cleanup frees all GPU buffers, destroys IPC events, closes shared memory, and (in Sender mode) signals shutdown to connected consumers. The `Numslots` parameter is re-enabled for editing in Sender mode.
+- **Toggling On** does not re-initialize immediately β GPU resources are re-created lazily on the next frame callback.
+- Hot-swappable: can be toggled at any time without restarting TouchDesigner.
+
+---
+
+### Mode
+**Type:** Menu | **Default:** Sender | **Options:** Sender / Receiver
+
+Sets the direction of data flow.
+
+- **Sender:** This component is the producer. It captures the upstream texture each frame, copies it into the GPU ring buffer, and makes it available to an external Python process (or another TD instance in Receiver mode).
+- **Receiver:** This component is the consumer. It reads GPU frames produced by an external Python process (using `CUDAIPCExporter`) and imports them into a Script TOP for use in the TD network.
+
+Switching modes triggers a full cleanup of the current state and lazy re-initialization on the next frame. In Receiver mode, the `Numslots` parameter is locked and read-only β the slot count is determined by the sender's shared memory protocol and automatically reflected in the parameter display.
+
+---
+
+### Ipcmemname
+**Type:** String | **Default:** `cudalink_output_ipc`
+
+The name of the OS shared memory segment used to exchange GPU handles between the sender and receiver.
+
+Both sides **must use the exact same name**. On Windows, this maps to a named `CreateFileMapping` kernel object.
+
+Changing this parameter while active triggers a full cleanup and reconnection:
+- In Sender mode: re-initializes on the next frame export.
+- In Receiver mode: immediately resets the retry counter and attempts to connect on the next frame start (without waiting through the current backoff interval).
+
+Use different names to run multiple independent sender/receiver pairs simultaneously in the same TouchDesigner session.
+
+---
+
+### Numslots
+**Type:** Integer Menu | **Default:** 3 | **Options:** 2 / 3 / 4
+
+Number of ring buffer slots in the GPU pipeline.
+
+- **Higher values** (e.g., 4) reduce the chance of producer/consumer contention when frame processing takes variable time. Each additional slot uses one full texture worth of GPU memory (`ceil(W Γ H Γ C Γ sizeof(dtype) / 2 MiB) Γ 2 MiB`).
+- **Lower values** (e.g., 2) reduce GPU memory usage at the cost of slightly increased contention risk.
+- **3 slots (default)** is sufficient for the vast majority of use cases.
+
+**Lock behavior:**
+- Only editable when `Mode = Sender` and `Active = Off`.
+- Locked automatically when `Active` is turned On.
+- In Receiver mode: always locked. The actual slot count is read from the sender's shared memory and displayed here for reference.
+
+Changing this parameter while active is silently ignored. Changing it while inactive triggers a cleanup and lazy re-initialization on the next frame.
+
+---
+
+### Debug
+**Type:** Toggle | **Default:** Off
+
+Enables verbose performance logging to the TouchDesigner Textport.
+
+- **Off:** Only critical errors and state changes are logged.
+- **On:** every ~97 frames, prints an average timing breakdown:
+ - `cudaMemory` β OpenGLβCUDA interop time
+ - `memcpy` β D2D memcpy enqueue time
+ - `record` β IPC event record time
+ - `total` β full `export_frame()` wall-clock time
+ - `GPU memcpy` β actual GPU elapsed time measured via CUDA timing events (only available if Debug was On at initialization)
+ - `sync mode` β whether GPU-event synchronization or CPU-sync fallback is active
+
+The first frame after initialization always prints a detailed timing diagnostic regardless of this setting.
+
+Hot-swappable: can be toggled at runtime without affecting the pipeline. However, GPU timing events (`cudaEventElapsedTime`) are only created during initialization. If Debug is turned On after the component is already running, CPU-side timing is enabled immediately but the `GPU memcpy` metric will not appear until the next full cleanup/re-init cycle.
+
+---
+
+### Hide Built-In
+**Type:** Toggle | **Default:** Off
+
+Hides the built-in TouchDesigner parameter pages (Common, Extensions) from the parameter dialog, leaving only the CUDA IPC page visible.
+
+- **Off:** All parameter pages are shown β Common, Extensions, and CUDA IPC.
+- **On:** Only the CUDA IPC parameter page is shown. Built-in pages are not deleted; they are just hidden from the UI. Toggling Off restores them immediately.
+
+Hot-swappable: takes effect instantly without restarting or reinitializing the component. The setting is also applied automatically at component load time.
+
+Use this when distributing the component to end-users who should not need to interact with TD's built-in parameters.
+
+---
+
+## Quick Start
+
+### TD β Python (Sender mode)
+
+1. Drop `CUDAIPCLink_v1.4.1.tox` into your TD network.
+2. Wire your source TOP into the component's input.
+3. Set **Mode** = `Sender`.
+4. Set **Ipcmemname** to a unique name, e.g. `my_pipeline`.
+5. Toggle **Active** = On.
+6. In Python, install `cuda-link` and connect:
+ ```python
+ from cuda_link import CUDAIPCImporter
+ importer = CUDAIPCImporter(shm_name="my_pipeline")
+ frame = importer.get_frame() # torch.Tensor on GPU (zero-copy)
+ frame_np = importer.get_frame_numpy() # numpy array (CPU copy)
+ ```
+
+### Python β TD (Receiver mode)
+
+1. In Python, create an exporter:
+ ```python
+ from cuda_link import CUDAIPCExporter
+ exporter = CUDAIPCExporter(shm_name="ai_output", width=1920, height=1080)
+ exporter.export_frame(gpu_tensor)
+ ```
+2. Drop the component into TD and set **Mode** = `Receiver`.
+3. Set **Ipcmemname** to the same name (`ai_output`).
+4. Toggle **Active** = On. The receiver will connect automatically once the Python exporter is running.
+
+---
+
+## Performance Reference
+
+| Operation | Typical Time | Notes |
+|-----------|-------------|-------|
+| Per-frame IPC overhead | 0.5β2 Β΅s | GPU event record + `write_idx` update |
+| First-frame initialization | 50β100 Β΅s | One-time GPU buffer allocation + IPC handle creation |
+| D2D texture copy (1080p RGBA float32) | 60β80 Β΅s | Runs fully on GPU |
+| Receiver `copyCUDAMemory` into TD (1080p) | ~3 ms | Includes CUDAβOpenGL interop inside TD |
+| D2H numpy copy (1080p RGBA float32) | 400β600 Β΅s | Only when using `get_frame_numpy()` |
+
+**Baseline comparison:** CPU SharedMemory at 1080p RGBA float32 costs ~1.5 ms per frame β roughly **750Γ slower** than CUDA IPC.
+
+---
+
+## Troubleshooting
+
+**Receiver stays in "waiting for sender" state**
+- Confirm the sender is running and `Active` is On before starting the receiver.
+- Verify `Ipcmemname` is identical on both sides (case-sensitive).
+- Check the Textport for retry messages β the receiver uses exponential backoff up to ~2 seconds between attempts.
+
+**"Stale SharedMemory" or version mismatch logged**
+- The sender was restarted while the receiver is still holding old IPC handles. Toggle the receiver's `Active` Off β On to force reconnection.
+
+**"Protocol magic mismatch" error**
+- Another process is using the same `Ipcmemname` for a different purpose. Change `Ipcmemname` to a unique value.
+
+**GPU memory not freed after deactivation**
+- `cudaFree` of ring buffer slots is deferred briefly after cleanup (a 100 ms grace period) to allow the consumer to finish its current frame. This is normal behavior.
+
+**`Numslots` is greyed out**
+- In Sender mode: toggle `Active` Off first to edit slot count.
+- In Receiver mode: slot count is controlled by the sender and cannot be set locally.
+
+**Debug shows high `cudaMemory` time (>0.5 ms)**
+- This is the OpenGLβCUDA interop step inside TouchDesigner's `top_op.cudaMemory()` call and is not controllable by this component. It is normal for large textures or when the GPU is under heavy load.
+
+---
+
+## Requirements
+
+- **OS:** Windows 10 / 11 (CUDA IPC handle sharing is Windows-only)
+- **CUDA:** 12.x (tested with 12.4)
+- **GPU:** NVIDIA, CUDA compute capability 3.5 or higher
+- **TouchDesigner:** 2022.x or later
+- **Python (consumer side):** 3.9+, `cuda-link` package (`pip install cuda-link`)
diff --git a/src/streamdiffusion/_compat/td_exporter/NVMLObserver.py b/src/streamdiffusion/_compat/td_exporter/NVMLObserver.py
new file mode 100644
index 000000000..f0a80ae52
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/NVMLObserver.py
@@ -0,0 +1,222 @@
+"""
+NVML Observability Hook for CUDA Link.
+
+Optional GPU telemetry via NVIDIA Management Library (pynvml).
+Follows the existing optional-dep pattern: import fails gracefully,
+NVML_AVAILABLE flag gates all usage.
+
+Usage:
+ from cuda_link import NVMLObserver, NVML_AVAILABLE
+
+ if NVML_AVAILABLE:
+ obs = NVMLObserver(device=0)
+ obs.start()
+ exporter.attach_nvml_observer(obs)
+ stats = exporter.get_stats() # includes stats["nvml"]
+ obs.stop()
+
+Install optional dep:
+ pip install "cuda-link[nvml]" # adds nvidia-ml-py>=12.535
+"""
+
+from __future__ import annotations
+
+import contextlib
+import logging
+import os
+import threading
+
+
+logger = logging.getLogger(__name__)
+
+try:
+ import pynvml
+
+ NVML_AVAILABLE = True
+except ImportError:
+ pynvml = None # type: ignore[assignment]
+ NVML_AVAILABLE = False
+
+
+class _NvmlRefCounter:
+ """Process-global ref-count for nvmlInit/nvmlShutdown.
+
+ Tolerates multiple NVMLObserver instances without double-init/shutdown errors.
+ """
+
+ def __init__(self) -> None:
+ self._count: int = 0
+ self._lock: threading.Lock = threading.Lock()
+
+ def acquire(self) -> None:
+ if not NVML_AVAILABLE:
+ return
+ with self._lock:
+ if self._count == 0:
+ pynvml.nvmlInit()
+ self._count += 1
+
+ def release(self) -> None:
+ if not NVML_AVAILABLE:
+ return
+ with self._lock:
+ self._count = max(0, self._count - 1)
+ if self._count == 0:
+ pynvml.nvmlShutdown()
+
+
+_NVML_REFS = _NvmlRefCounter()
+
+
+_THROTTLE_NAMES: dict[int, str] = {
+ 0x0000000000000001: "gpu_idle",
+ 0x0000000000000002: "applications_clocks_setting",
+ 0x0000000000000004: "sw_power_cap",
+ 0x0000000000000008: "hw_slowdown",
+ 0x0000000000000010: "sync_boost",
+ 0x0000000000000020: "sw_thermal_slowdown",
+ 0x0000000000000040: "hw_thermal_slowdown",
+ 0x0000000000000080: "hw_power_brake_slowdown",
+ 0x0000000000000100: "display_clocks_setting",
+}
+
+
+def _decode_throttle(bitmask: int) -> list[str]:
+ return [name for bit, name in _THROTTLE_NAMES.items() if bitmask & bit]
+
+
+class NVMLObserver:
+ """Pull-based GPU telemetry via pynvml.
+
+ Call snapshot() (or let get_stats() call it) to sample once.
+ No background thread β caller controls cadence.
+
+ Metrics returned by snapshot():
+ gpu_util_pct, mem_bw_util_pct (from nvmlDeviceGetUtilizationRates)
+ mem_used_mb, mem_total_mb (from nvmlDeviceGetMemoryInfo)
+ sm_clock_mhz, mem_clock_mhz (from nvmlDeviceGetClockInfo)
+ pcie_tx_kbps, pcie_rx_kbps (from nvmlDeviceGetPcieThroughput)
+ temp_c (from nvmlDeviceGetTemperature)
+ power_w, power_limit_w (from nvmlDeviceGetPowerUsage)
+ throttle_reasons (decoded bitmask list)
+ driver_model "WDDM" / "TCC" / "MCDM" (Windows only; absent on Linux)
+ """
+
+ def __init__(self, device: int = 0, enabled: bool | None = None) -> None:
+ """Initialize NVML observer.
+
+ Args:
+ device: CUDA device index (default 0).
+ enabled: If None, reads CUDALINK_NVML env var ("1" = enabled).
+ If False, snapshot() returns {"nvml_available": False} immediately.
+ """
+ self.device = device
+ if enabled is None:
+ self.enabled = os.getenv("CUDALINK_NVML", "0") == "1"
+ else:
+ self.enabled = enabled
+ self._handle = None
+ self._started = False
+ self._driver_model: str | None = None
+
+ def start(self) -> bool:
+ """Initialize NVML and open device handle.
+
+ Returns:
+ True if NVML is available and handle opened, False otherwise.
+ """
+ if not NVML_AVAILABLE or not self.enabled:
+ return False
+ if self._started:
+ return True
+ try:
+ _NVML_REFS.acquire()
+ self._handle = pynvml.nvmlDeviceGetHandleByIndex(self.device)
+ with contextlib.suppress(pynvml.NVMLError):
+ # Raises NVMLError_NotSupported on Linux (driver-model is Windows-only).
+ _model = pynvml.nvmlDeviceGetCurrentDriverModel(self._handle)
+ _names = {
+ pynvml.NVML_DRIVER_WDDM: "WDDM",
+ pynvml.NVML_DRIVER_WDM: "TCC",
+ }
+ if hasattr(pynvml, "NVML_DRIVER_MCDM"):
+ _names[pynvml.NVML_DRIVER_MCDM] = "MCDM"
+ self._driver_model = _names.get(_model, f"unknown({_model})")
+ self._started = True
+ return True
+ except (pynvml.NVMLError, RuntimeError, OSError) as e:
+ logger.warning("NVML start failed for device %d: %s", self.device, e)
+ return False
+
+ def stop(self) -> None:
+ """Release NVML handle and decrement global ref-count."""
+ if self._started:
+ _NVML_REFS.release()
+ self._handle = None
+ self._started = False
+
+ def __enter__(self) -> NVMLObserver:
+ self.start()
+ return self
+
+ def __exit__(self, *_: object) -> None:
+ self.stop()
+
+ def snapshot(self) -> dict:
+ """Sample all GPU metrics once (non-blocking, ~50-200Β΅s total).
+
+ Returns:
+ Dict of metric name β value. If NVML is unavailable or not started,
+ returns {"nvml_available": False}.
+ """
+ if not self._started or self._handle is None:
+ return {"nvml_available": False}
+
+ out: dict = {"nvml_available": True}
+ h = self._handle
+
+ try:
+ util = pynvml.nvmlDeviceGetUtilizationRates(h)
+ out["gpu_util_pct"] = util.gpu
+ out["mem_bw_util_pct"] = util.memory
+ except pynvml.NVMLError:
+ pass
+
+ try:
+ mem = pynvml.nvmlDeviceGetMemoryInfo(h)
+ out["mem_used_mb"] = mem.used / (1024 * 1024)
+ out["mem_total_mb"] = mem.total / (1024 * 1024)
+ except pynvml.NVMLError:
+ pass
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["sm_clock_mhz"] = pynvml.nvmlDeviceGetClockInfo(h, pynvml.NVML_CLOCK_SM)
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["mem_clock_mhz"] = pynvml.nvmlDeviceGetClockInfo(h, pynvml.NVML_CLOCK_MEM)
+
+ try:
+ out["pcie_tx_kbps"] = pynvml.nvmlDeviceGetPcieThroughput(h, pynvml.NVML_PCIE_UTIL_TX_BYTES)
+ out["pcie_rx_kbps"] = pynvml.nvmlDeviceGetPcieThroughput(h, pynvml.NVML_PCIE_UTIL_RX_BYTES)
+ except pynvml.NVMLError:
+ pass
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["temp_c"] = pynvml.nvmlDeviceGetTemperature(h, pynvml.NVML_TEMPERATURE_GPU)
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["power_w"] = pynvml.nvmlDeviceGetPowerUsage(h) / 1000.0
+
+ with contextlib.suppress(pynvml.NVMLError):
+ out["power_limit_w"] = pynvml.nvmlDeviceGetEnforcedPowerLimit(h) / 1000.0
+
+ try:
+ bitmask = pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(h)
+ out["throttle_reasons"] = _decode_throttle(bitmask)
+ except pynvml.NVMLError:
+ pass
+
+ if self._driver_model is not None:
+ out["driver_model"] = self._driver_model
+
+ return out
diff --git a/src/streamdiffusion/_compat/td_exporter/NVTXShim.py b/src/streamdiffusion/_compat/td_exporter/NVTXShim.py
new file mode 100644
index 000000000..a631ccd7c
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/NVTXShim.py
@@ -0,0 +1,78 @@
+"""NVTX annotation shim for the td_exporter COMP namespace.
+
+Mirror of src/cuda_link/_nvtx.py for use by TDSender and TDReceiver.
+Identical semantics; different module name since td_exporter uses flat imports.
+
+Enabled via environment variables (read once at import, zero-cost when off):
+ CUDALINK_NVTX=1 β top-level phase ranges on the GPU timeline
+ CUDALINK_NVTX_VERBOSE=1 β sub-operation ranges (implies CUDALINK_NVTX=1)
+
+Requires the `nvtx` PyPI package when enabled: pip install nvtx
+"""
+
+from __future__ import annotations
+
+import os
+
+
+_VERBOSE = os.environ.get("CUDALINK_NVTX_VERBOSE", "0") == "1"
+_ENABLED = _VERBOSE or os.environ.get("CUDALINK_NVTX", "0") == "1"
+
+if _ENABLED:
+ try:
+ import nvtx as _lib
+
+ _AVAILABLE = True
+ except ImportError:
+ _lib = None
+ _AVAILABLE = False
+else:
+ _lib = None
+ _AVAILABLE = False
+
+
+class _Noop:
+ __slots__ = ()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *_):
+ pass
+
+
+_NOOP = _Noop()
+
+
+def annotate(message, color="white"):
+ """Context manager for a named NVTX range. No-op if NVTX is disabled."""
+ if _AVAILABLE:
+ return _lib.annotate(message, color=color)
+ return _NOOP
+
+
+def verbose_range(message, color="white"):
+ """Context manager for a sub-operation range. Only active when CUDALINK_NVTX_VERBOSE=1."""
+ if _AVAILABLE and _VERBOSE:
+ return _lib.annotate(message, color=color)
+ return _NOOP
+
+
+def push_range(message, color="white"):
+ """Push a named NVTX range onto the thread-local stack."""
+ if _AVAILABLE:
+ _lib.push_range(message, color=color)
+
+
+def pop_range():
+ """Pop the innermost NVTX range from the thread-local stack."""
+ if _AVAILABLE:
+ _lib.pop_range()
+
+
+def is_enabled():
+ return _AVAILABLE
+
+
+def is_verbose():
+ return _AVAILABLE and _VERBOSE
diff --git a/src/streamdiffusion/_compat/td_exporter/SHMProtocol.py b/src/streamdiffusion/_compat/td_exporter/SHMProtocol.py
new file mode 100644
index 000000000..c1953f9dc
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/SHMProtocol.py
@@ -0,0 +1,310 @@
+"""
+SHM protocol v0.5.0 β canonical source of truth for the CUDA IPC shared-memory layout.
+
+All binary constants, struct codecs, dtype mappings, and publish/acquire ordering
+live here. Every module that reads or writes the SHM region must import from here;
+never define SHM_HEADER_SIZE or _ST_U32 locally.
+
+Binary layout (total = SHMLayout(num_slots).total_size):
+ [0-3] magic uint32 LE = PROTOCOL_MAGIC
+ [4-11] version uint64 LE (monotonic; incremented each sender init)
+ [12-15] num_slots uint32 LE
+ [16-19] write_idx uint32 LE (monotonic; 0 = no frames written yet)
+ [20 + slot*128 ...] IPC handles (64B mem + 64B event per slot, N slots)
+ [20 + N*128] shutdown_flag uint8
+ [21 + N*128 ...] metadata 20B (width/height/num_comps/kind/bits/flags/data_size)
+ [41 + N*128 ...] timestamp float64 LE (producer wall-clock time)
+"""
+
+from __future__ import annotations
+
+import struct
+import threading
+from dataclasses import dataclass
+from enum import Enum
+
+
+# ---------------------------------------------------------------------------
+# Protocol constants
+# ---------------------------------------------------------------------------
+
+PROTOCOL_MAGIC: int = 0x43495044 # "CIPD" β protocol validation magic number (v1.0.0)
+
+MAGIC_OFFSET: int = 0
+MAGIC_SIZE: int = 4
+VERSION_OFFSET: int = 4
+VERSION_SIZE: int = 8
+NUM_SLOTS_OFFSET: int = 12
+NUM_SLOTS_SIZE: int = 4
+WRITE_IDX_OFFSET: int = 16
+WRITE_IDX_SIZE: int = 4
+SHM_HEADER_SIZE: int = 20 # 4B magic + 8B version + 4B num_slots + 4B write_idx
+
+SLOT_SIZE: int = 128 # 64B cudaIpcMemHandle_t + 64B cudaIpcEventHandle_t
+
+SHUTDOWN_FLAG_SIZE: int = 1
+METADATA_SIZE: int = 20 # 4B width + 4B height + 4B num_comps + 1B kind + 1B bits + 2B flags + 4B data_size
+TIMESTAMP_SIZE: int = 8 # float64 LE producer wall-clock time
+
+# ---------------------------------------------------------------------------
+# DtypeCodec constants (cudaChannelFormatKind)
+# ---------------------------------------------------------------------------
+
+FORMAT_KIND_SIGNED: int = 0 # cudaChannelFormatKindSigned
+FORMAT_KIND_UNSIGNED: int = 1 # cudaChannelFormatKindUnsigned
+FORMAT_KIND_FLOAT: int = 2 # cudaChannelFormatKindFloat
+FLAGS_BFLOAT16: int = 0x0001 # bit0: bfloat16 (kind=Float, bits=16)
+
+# dtype string β (format_kind, bits_per_component, flags)
+_DTYPE_TO_KIND_BITS: dict[str, tuple[int, int, int]] = {
+ "float32": (FORMAT_KIND_FLOAT, 32, 0),
+ "float16": (FORMAT_KIND_FLOAT, 16, 0),
+ "uint8": (FORMAT_KIND_UNSIGNED, 8, 0),
+ "uint16": (FORMAT_KIND_UNSIGNED, 16, 0),
+}
+
+# ---------------------------------------------------------------------------
+# Pre-compiled struct codecs (hot-path, saves ~50-100ns per call)
+# ---------------------------------------------------------------------------
+
+_ST_U32 = struct.Struct(" None:
+ with _fence_lock:
+ pass
+
+
+# ---------------------------------------------------------------------------
+# SHMLayout β pre-computes all byte offsets for a given num_slots
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class SHMLayout:
+ """Pre-computed byte offsets for a SHM region with num_slots IPC slots."""
+
+ num_slots: int
+
+ def slot_offset(self, i: int) -> int:
+ return SHM_HEADER_SIZE + i * SLOT_SIZE
+
+ @property
+ def shutdown_offset(self) -> int:
+ return SHM_HEADER_SIZE + self.num_slots * SLOT_SIZE
+
+ @property
+ def metadata_offset(self) -> int:
+ return self.shutdown_offset + SHUTDOWN_FLAG_SIZE
+
+ @property
+ def timestamp_offset(self) -> int:
+ return self.metadata_offset + METADATA_SIZE
+
+ @property
+ def total_size(self) -> int:
+ return self.timestamp_offset + TIMESTAMP_SIZE
+
+
+# ---------------------------------------------------------------------------
+# Metadata β typed representation of the 20-byte metadata region
+# ---------------------------------------------------------------------------
+
+
+@dataclass(frozen=True)
+class Metadata:
+ """Typed representation of the 20-byte metadata region."""
+
+ width: int
+ height: int
+ num_comps: int
+ format_kind: int # cudaChannelFormatKind
+ bits_per_comp: int
+ flags: int
+ data_size: int
+
+ def pack_into(self, buf: memoryview, layout: SHMLayout) -> None:
+ offset = layout.metadata_offset
+ _ST_U32.pack_into(buf, offset, self.width)
+ _ST_U32.pack_into(buf, offset + 4, self.height)
+ _ST_U32.pack_into(buf, offset + 8, self.num_comps)
+ _ST_BBH.pack_into(buf, offset + 12, self.format_kind, self.bits_per_comp, self.flags)
+ _ST_U32.pack_into(buf, offset + 16, self.data_size)
+
+ @classmethod
+ def read_from(cls, buf: memoryview, layout: SHMLayout) -> Metadata:
+ offset = layout.metadata_offset
+ width = _ST_U32.unpack_from(buf, offset)[0]
+ height = _ST_U32.unpack_from(buf, offset + 4)[0]
+ num_comps = _ST_U32.unpack_from(buf, offset + 8)[0]
+ kind, bits, flags = _ST_BBH.unpack_from(buf, offset + 12)
+ data_size = _ST_U32.unpack_from(buf, offset + 16)[0]
+ return cls(
+ width=width,
+ height=height,
+ num_comps=num_comps,
+ format_kind=kind,
+ bits_per_comp=bits,
+ flags=flags,
+ data_size=data_size,
+ )
+
+
+# ---------------------------------------------------------------------------
+# DtypeCodec β encode/decode dtype strings
+# ---------------------------------------------------------------------------
+
+
+class DtypeCodec:
+ """Encode/decode dtype strings to/from (format_kind, bits_per_comp, flags).
+
+ Folds _DTYPE_TO_KIND_BITS (exporter) and _decode_dtype_str (importer).
+ Adding a dtype is a single-file edit here.
+ """
+
+ @staticmethod
+ def encode(dtype: str) -> tuple[int, int, int]:
+ """dtype string β (format_kind, bits_per_comp, flags).
+
+ Raises:
+ KeyError: if dtype is not supported.
+ """
+ return _DTYPE_TO_KIND_BITS[dtype]
+
+ @staticmethod
+ def decode(kind: int, bits: int, flags: int) -> str:
+ """(format_kind, bits_per_comp, flags) β dtype string."""
+ if kind == FORMAT_KIND_FLOAT and bits == 16 and not (flags & FLAGS_BFLOAT16):
+ return "float16"
+ if kind == FORMAT_KIND_FLOAT:
+ return "float32"
+ if bits == 8:
+ return "uint8"
+ if bits == 16:
+ return "uint16"
+ return "float32" # safe fallback for future extensions
+
+
+# ---------------------------------------------------------------------------
+# Header helpers β read/write the 20-byte header region
+# ---------------------------------------------------------------------------
+
+
+def read_magic(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, MAGIC_OFFSET)[0]
+
+
+def read_version(buf: memoryview) -> int:
+ return _ST_U64.unpack_from(buf, VERSION_OFFSET)[0]
+
+
+def read_num_slots(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, NUM_SLOTS_OFFSET)[0]
+
+
+def read_write_idx(buf: memoryview) -> int:
+ return _ST_U32.unpack_from(buf, WRITE_IDX_OFFSET)[0]
+
+
+def bump_version(buf: memoryview) -> int:
+ """Increment the version counter in-place; return the new version."""
+ try:
+ current = read_version(buf)
+ except (struct.error, ValueError, IndexError):
+ current = 0
+ new_version = current + 1
+ _ST_U64.pack_into(buf, VERSION_OFFSET, new_version)
+ return new_version
+
+
+# ---------------------------------------------------------------------------
+# publish_frame β the only place that encodes the C3 ordering guarantee
+# ---------------------------------------------------------------------------
+
+
+def publish_frame(buf: memoryview, layout: SHMLayout, write_idx: int, timestamp: float) -> None:
+ """Write timestamp, clear shutdown_flag, fence, then publish write_idx LAST.
+
+ Ordering is critical: the consumer reads shutdown_flag BEFORE write_idx.
+ Clearing shutdown_flag before incrementing write_idx ensures the consumer
+ always sees shutdown_flag=0 when it first observes a new frame.
+
+ Callers must not replicate this sequence outside this function.
+ """
+ _ST_F64.pack_into(buf, layout.timestamp_offset, timestamp)
+ buf[layout.shutdown_offset] = 0
+ _release_fence() # C3 release barrier: shutdown_flag visible before write_idx
+ _ST_U32.pack_into(buf, WRITE_IDX_OFFSET, write_idx)
+
+
+# ---------------------------------------------------------------------------
+# acquire_slot β consumer-side frame acquisition
+# ---------------------------------------------------------------------------
+
+
+class SlotState(Enum):
+ NO_FRAME = "no_frame"
+ NEW_FRAME = "new_frame"
+ SHUTDOWN = "shutdown"
+ VERSION_CHANGED = "version_changed"
+
+
+@dataclass
+class AcquireResult:
+ """Result of acquire_slot()."""
+
+ state: SlotState
+ slot: int = -1
+ timestamp: float = 0.0
+ new_version: int = 0
+ write_idx: int = 0
+
+
+def acquire_slot(
+ buf: memoryview,
+ layout: SHMLayout,
+ last_write_idx: int,
+ last_version: int,
+) -> AcquireResult:
+ """Read SHM state and classify the result for the consumer.
+
+ Returns an AcquireResult with one of four states:
+ - NO_FRAME: write_idx unchanged; nothing to consume.
+ - NEW_FRAME: new frame at .slot; read and process it, update last_write_idx to .write_idx.
+ - SHUTDOWN: shutdown_flag=1; producer has exited, consumer should clean up.
+ - VERSION_CHANGED: SHM was re-initialised; consumer must reopen IPC handles.
+
+ Folds _get_read_slot() (importer) and the three identical preambles in
+ get_frame / get_frame_numpy / get_frame_cupy into one location.
+ """
+ if buf[layout.shutdown_offset] != 0:
+ return AcquireResult(state=SlotState.SHUTDOWN)
+
+ version = read_version(buf)
+ if version != last_version and last_version != 0:
+ return AcquireResult(state=SlotState.VERSION_CHANGED, new_version=version)
+
+ write_idx = read_write_idx(buf)
+ if write_idx == 0 or write_idx == last_write_idx:
+ return AcquireResult(state=SlotState.NO_FRAME)
+
+ slot = (write_idx - 1) % layout.num_slots
+ try:
+ timestamp = _ST_F64.unpack_from(buf, layout.timestamp_offset)[0]
+ except struct.error:
+ timestamp = 0.0
+
+ return AcquireResult(state=SlotState.NEW_FRAME, slot=slot, timestamp=timestamp, write_idx=write_idx)
diff --git a/src/streamdiffusion/_compat/td_exporter/TDConfig.py b/src/streamdiffusion/_compat/td_exporter/TDConfig.py
new file mode 100644
index 000000000..72e452de3
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/TDConfig.py
@@ -0,0 +1,65 @@
+"""
+TDConfig β frozen configuration dataclasses for CUDAIPCExtension.
+
+Centralises all os.environ reads so the interaction matrix between toggles
+is visible in one place and the extension body only references self._config..
+
+textDAT name: TDConfig (must match the importable module name inside the COMP namespace)
+"""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class TDSenderConfig:
+ """Immutable sender configuration, resolved once at init time.
+
+ All booleans map 1-to-1 to CUDALINK_* env vars. Defaults match
+ the validated production stack described in docs/ARCHITECTURE.md.
+ """
+
+ export_sync: bool = True
+ export_profile: bool = False
+ export_flush_probe: bool = True
+ use_graphs: bool = False
+ graphs_deferred: bool = False
+ stream_high_prio: bool = False
+ init_pace: bool = False
+ persist_stream: bool = True
+ activation_barrier: bool = True
+ barrier_settle_frames: int = 30
+ nvml: bool = False
+
+ @classmethod
+ def from_env(cls) -> TDSenderConfig:
+ """Build a config from environment variables (production path)."""
+ return cls(
+ export_sync=os.environ.get("CUDALINK_EXPORT_SYNC", "1") != "0",
+ export_profile=os.environ.get("CUDALINK_EXPORT_PROFILE", "0") == "1",
+ export_flush_probe=os.environ.get("CUDALINK_EXPORT_FLUSH_PROBE", "1") == "1",
+ use_graphs=os.environ.get("CUDALINK_TD_USE_GRAPHS", "0") == "1",
+ graphs_deferred=os.environ.get("CUDALINK_TD_GRAPHS_DEFERRED", "0") == "1",
+ stream_high_prio=os.environ.get("CUDALINK_TD_STREAM_PRIO", "normal") == "high",
+ init_pace=os.environ.get("CUDALINK_TD_INIT_PACE", "0") == "1",
+ persist_stream=os.environ.get("CUDALINK_TD_PERSIST_STREAM", "1") != "0",
+ activation_barrier=os.environ.get("CUDALINK_TD_ACTIVATION_BARRIER", "1") != "0",
+ barrier_settle_frames=int(os.environ.get("CUDALINK_TD_BARRIER_SETTLE_FRAMES", "30")),
+ nvml=os.environ.get("CUDALINK_NVML", "0") == "1",
+ )
+
+ def __post_init__(self) -> None:
+ # export_flush_probe only takes effect when export_sync is False;
+ # no error, just a documented no-op.
+ if self.barrier_settle_frames < 0:
+ raise ValueError(f"barrier_settle_frames must be >= 0, got {self.barrier_settle_frames}")
+
+
+@dataclass(frozen=True)
+class TDReceiverConfig:
+ """Immutable receiver configuration.
+
+ No env vars are Receiver-only at present; placeholder for future additions.
+ """
diff --git a/src/streamdiffusion/_compat/td_exporter/TDHost.py b/src/streamdiffusion/_compat/td_exporter/TDHost.py
new file mode 100644
index 000000000..4038930b9
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/TDHost.py
@@ -0,0 +1,333 @@
+"""
+TDHost adapter β isolates all TouchDesigner runtime access behind a Protocol seam.
+
+Every call that touches ownerComp, a TOP, or a Script TOP goes through this module.
+Engine code imports nothing from the TD runtime; it calls TDHost / TOPHandle methods only.
+
+textDAT name: TDHost (must match the importable module name inside the COMP namespace)
+"""
+
+from __future__ import annotations
+
+import contextlib
+from dataclasses import dataclass, field
+from typing import Any
+
+
+# ---------------------------------------------------------------------------
+# CUDAMemoryRef β TD-agnostic result of top.cudaMemory()
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class CUDAMemoryRef:
+ """Wraps the raw CUDAMemory object returned by TOP.cudaMemory().
+
+ All fields are plain Python types β no TD types leak out.
+ """
+
+ ptr: int # GPU pointer as plain int
+ width: int
+ height: int
+ channels: int # shape.numComps
+ size: int
+ data_type: Any = field(default=None) # shape.dataType (TD-specific; forwarded opaquely)
+
+
+# ---------------------------------------------------------------------------
+# TOPHandle protocol
+# ---------------------------------------------------------------------------
+
+
+class TOPHandle:
+ """Protocol-compatible base for wrapping a single TouchDesigner TOP operator.
+
+ All concrete methods raise NotImplementedError; subclass RealTOPHandle provides
+ the TD-connected implementation and FakeTOPHandle provides the test double.
+ """
+
+ def cuda_memory(self, stream: Any = None) -> CUDAMemoryRef:
+ """Call top.cudaMemory(stream=stream) and return a CUDAMemoryRef."""
+ raise NotImplementedError
+
+ @property
+ def pixel_format(self) -> str:
+ """top.pixelFormat as a string."""
+ raise NotImplementedError
+
+ @property
+ def inputs(self) -> list[TOPHandle]:
+ """Wrapped TOPHandle for each upstream input operator."""
+ raise NotImplementedError
+
+ def set_format(self, fmt: str) -> None:
+ """Write top.par.format = fmt."""
+ raise NotImplementedError
+
+ def copy_cuda_memory(self, ptr: int, size: int, shape: Any, *, stream: int) -> None:
+ """Call script_top.copyCUDAMemory(ptr, size, shape, stream=stream)."""
+ raise NotImplementedError
+
+ def copy_numpy_array(self, arr: Any) -> None:
+ """Call script_top.copyNumpyArray(arr)."""
+ raise NotImplementedError
+
+ def set_resolution(self, width: int, height: int) -> None:
+ """Set Script TOP to custom resolution: outputresolution=9, resolutionw, resolutionh."""
+ raise NotImplementedError
+
+ def is_valid(self) -> bool:
+ """Return True if the underlying TD operator is still present in the network."""
+ raise NotImplementedError
+
+
+# ---------------------------------------------------------------------------
+# TDHost protocol
+# ---------------------------------------------------------------------------
+
+
+class TDHost:
+ """Protocol-compatible base for wrapping ownerComp.
+
+ All parameter reads/writes and operator lookups go through this class.
+ Subclass RealTDHost is the TD-connected implementation;
+ FakeTDHost (in tests) is the in-process test double.
+ """
+
+ def param_value(self, name: str) -> Any:
+ """Read ownerComp.par..eval()."""
+ raise NotImplementedError
+
+ def set_param_value(self, name: str, value: Any) -> None:
+ """Write ownerComp.par. = value."""
+ raise NotImplementedError
+
+ def set_param_enabled(self, name: str, enabled: bool) -> None:
+ """Write ownerComp.par..enable = enabled."""
+ raise NotImplementedError
+
+ def show_custom_only(self, value: bool) -> None:
+ """Write ownerComp.showCustomOnly = value."""
+ raise NotImplementedError
+
+ def is_active(self) -> bool:
+ """Read ownerComp.par.Active.eval() via cached reference (hot-path safe)."""
+ raise NotImplementedError
+
+ def find_top(self, name: str) -> TOPHandle | None:
+ """Return ownerComp.op(name) wrapped as a TOPHandle, or None."""
+ raise NotImplementedError
+
+ def set_warning_status(self, msg: str) -> None:
+ """Tint ownerComp yellow to signal a recoverable warning (e.g. bad pixel format)."""
+ raise NotImplementedError
+
+ def set_error_status(self, msg: str) -> None:
+ """Tint ownerComp red and emit a persistent script-error badge for fatal failures."""
+ raise NotImplementedError
+
+ def clear_status(self) -> None:
+ """Restore ownerComp to its original color and clear any script-error badges."""
+ raise NotImplementedError
+
+ def set_info_status(self, msg: str) -> None:
+ """Write an informational status message to the Status par (no tint/cook side effects)."""
+ raise NotImplementedError
+
+
+# ---------------------------------------------------------------------------
+# Production adapters
+# ---------------------------------------------------------------------------
+
+
+class RealTOPHandle(TOPHandle):
+ """Wraps a real TD TOP operator."""
+
+ def __init__(self, top: Any) -> None:
+ self._top = top
+
+ def cuda_memory(self, stream: Any = None) -> CUDAMemoryRef:
+ cm = self._top.cudaMemory(stream=stream) if stream is not None else self._top.cudaMemory()
+ shape = cm.shape
+ return CUDAMemoryRef(
+ ptr=int(cm.ptr),
+ width=int(shape.width),
+ height=int(shape.height),
+ channels=int(shape.numComps),
+ size=int(cm.size),
+ data_type=getattr(shape, "dataType", None),
+ )
+
+ @property
+ def pixel_format(self) -> str:
+ return str(getattr(self._top, "pixelFormat", ""))
+
+ @property
+ def inputs(self) -> list[TOPHandle]:
+ try:
+ return [RealTOPHandle(t) for t in self._top.inputs]
+ except (AttributeError, TypeError):
+ return []
+
+ def set_format(self, fmt: str) -> None:
+ with contextlib.suppress(AttributeError):
+ self._top.par.format = fmt
+
+ def copy_cuda_memory(self, ptr: int, size: int, shape: Any, *, stream: int) -> None:
+ self._top.copyCUDAMemory(ptr, size, shape, stream=stream)
+
+ def copy_numpy_array(self, arr: Any) -> None:
+ self._top.copyNumpyArray(arr)
+
+ def set_resolution(self, width: int, height: int) -> None:
+ with contextlib.suppress(AttributeError):
+ self._top.par.outputresolution = 9 # Custom Resolution mode
+ self._top.par.resolutionw = width
+ self._top.par.resolutionh = height
+
+ def is_valid(self) -> bool:
+ try:
+ return bool(getattr(self._top, "valid", True))
+ except (AttributeError, RuntimeError):
+ return False
+
+
+_WARNING_COLOR: tuple[float, float, float] = (0.9137, 1.0, 0.0)
+_ERROR_COLOR: tuple[float, float, float] = (0.7, 0.0, 0.0)
+_DEFAULT_NODE_COLOR: tuple[float, float, float] = (0.55, 0.55, 0.55)
+_MANAGED_COLORS = (_WARNING_COLOR, _ERROR_COLOR)
+
+
+class RealTDHost(TDHost):
+ """Wraps a real TD ownerComp.
+
+ Caches the Active parameter reference so is_active() avoids a 3-deep
+ attribute chain on every frame.
+ """
+
+ def __init__(self, owner_comp: Any) -> None:
+ self._comp = owner_comp
+ try:
+ self._active_par = owner_comp.par.Active
+ except AttributeError:
+ self._active_par = None
+ # _default_color is captured lazily on the first set_warning_status /
+ # set_error_status call so a tinted .tox save doesn't poison the cache.
+ # _reset_stale_tint() clears any visible managed-colour tint immediately
+ # so the COMP boots grey regardless of how it was saved.
+ self._default_color: tuple[float, float, float] | None = None
+ self._warning_emitter: Any = None # lazily resolved; False = looked up, not found
+ self._status_msg: str | None = None # current stored status; drives cook-on-transition
+ self._status_par_value: str | None = None # last value written to Status par; drives dedup
+ self._reset_stale_tint()
+
+ def param_value(self, name: str) -> Any:
+ try:
+ return getattr(self._comp.par, name).eval()
+ except AttributeError:
+ return None
+
+ def set_param_value(self, name: str, value: Any) -> None:
+ with contextlib.suppress(AttributeError):
+ setattr(self._comp.par, name, value)
+
+ def set_param_enabled(self, name: str, enabled: bool) -> None:
+ with contextlib.suppress(AttributeError):
+ getattr(self._comp.par, name).enable = enabled
+
+ def show_custom_only(self, value: bool) -> None:
+ with contextlib.suppress(AttributeError):
+ self._comp.showCustomOnly = value
+
+ def is_active(self) -> bool:
+ if self._active_par is None:
+ return True # no Active par β always active (backward compat)
+ try:
+ return bool(self._active_par.eval())
+ except AttributeError:
+ return True
+
+ def find_top(self, name: str) -> RealTOPHandle | None:
+ try:
+ top = self._comp.op(name)
+ return RealTOPHandle(top) if top is not None else None
+ except (AttributeError, RuntimeError):
+ return None
+
+ def _cook_warning_emitter(self) -> None:
+ if self._warning_emitter is None:
+ with contextlib.suppress(AttributeError, RuntimeError):
+ self._warning_emitter = self._comp.op("warning_emitter") or False
+ if self._warning_emitter:
+ with contextlib.suppress(AttributeError, RuntimeError):
+ self._warning_emitter.cook(force=True)
+
+ def _write_status_par(self, value: str) -> None:
+ if self._status_par_value == value:
+ return
+ self._status_par_value = value
+ self.set_param_value("Status", value)
+
+ def _reset_stale_tint(self) -> None:
+ with contextlib.suppress(AttributeError, RuntimeError):
+ c = self._comp.color
+ current = (float(c[0]), float(c[1]), float(c[2]))
+ if current in _MANAGED_COLORS:
+ self._comp.color = _DEFAULT_NODE_COLOR
+ self._comp.clearScriptErrors(error="*")
+ self._comp.unstore("cuda_link_status_msg")
+
+ def _capture_default_color(self) -> None:
+ if self._default_color is not None:
+ return
+ with contextlib.suppress(AttributeError, RuntimeError):
+ c = self._comp.color
+ current = (float(c[0]), float(c[1]), float(c[2]))
+ if current not in _MANAGED_COLORS:
+ self._default_color = current
+ return
+ # Fallback: current color is managed (stale tint from prior session) or
+ # unreadable β use TD's default node grey so clear_status always restores
+ # to a neutral colour rather than staying stuck at warning/error tint.
+ if self._default_color is None:
+ self._default_color = _DEFAULT_NODE_COLOR
+
+ def set_warning_status(self, msg: str) -> None:
+ self._capture_default_color()
+ full_msg = f"WARNING: {msg}"
+ needs_cook = self._status_msg != full_msg
+ self._status_msg = full_msg
+ with contextlib.suppress(AttributeError, RuntimeError):
+ self._comp.color = _WARNING_COLOR
+ self._comp.store("cuda_link_status_msg", full_msg)
+ self._write_status_par(full_msg)
+ if needs_cook:
+ self._cook_warning_emitter()
+
+ def set_error_status(self, msg: str) -> None:
+ self._capture_default_color()
+ full_msg = f"ERROR: {msg}"
+ needs_cook = self._status_msg != full_msg
+ self._status_msg = full_msg
+ with contextlib.suppress(AttributeError, RuntimeError):
+ self._comp.color = _ERROR_COLOR
+ self._comp.addScriptError(msg)
+ self._comp.store("cuda_link_status_msg", full_msg)
+ self._write_status_par(full_msg)
+ if needs_cook:
+ self._cook_warning_emitter()
+
+ def clear_status(self) -> None:
+ needs_cook = self._status_msg is not None
+ self._status_msg = None
+ with contextlib.suppress(AttributeError, RuntimeError):
+ if self._default_color is not None:
+ self._comp.color = self._default_color
+ self._comp.clearScriptErrors(error="*")
+ self._comp.unstore("cuda_link_status_msg")
+ self._write_status_par("Idle")
+ if needs_cook:
+ self._cook_warning_emitter()
+
+ def set_info_status(self, msg: str) -> None:
+ self._write_status_par(msg)
diff --git a/src/streamdiffusion/_compat/td_exporter/TDReceiver.py b/src/streamdiffusion/_compat/td_exporter/TDReceiver.py
new file mode 100644
index 000000000..be1cc3ac0
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/TDReceiver.py
@@ -0,0 +1,1063 @@
+"""
+TDReceiver - Receiver engine for CUDAIPCExtension.
+
+Owns all Receiver-mode CUDA IPC resources: SHM attachment, IPC handle opening,
+per-frame GPU event sync, and Script TOP copyCUDAMemory calls.
+
+textDAT name: TDReceiver (must match the importable module name inside the COMP namespace)
+"""
+
+from __future__ import annotations
+
+import contextlib
+import struct
+import time
+import traceback
+from ctypes import c_void_p
+from dataclasses import dataclass, field
+from multiprocessing.shared_memory import SharedMemory
+from typing import Any, Callable
+
+
+try:
+ import numpy
+except ImportError:
+ numpy = None # Will be imported at runtime in TD
+
+from CUDAIPCWrapper import get_cuda_runtime # noqa: E402
+from CUDARuntimeTypes import cudaIpcEventHandle_t, cudaIpcMemHandle_t # noqa: E402
+from NVTXShim import pop_range as _nvtx_pop # noqa: E402
+from NVTXShim import push_range as _nvtx_push
+from NVTXShim import verbose_range as _nvtx_verbose
+from SHMProtocol import ( # noqa: E402
+ _ST_BBH,
+ FLAGS_BFLOAT16,
+ FORMAT_KIND_FLOAT,
+ FORMAT_KIND_UNSIGNED,
+ MAGIC_OFFSET,
+ MAGIC_SIZE,
+ METADATA_SIZE,
+ NUM_SLOTS_OFFSET,
+ NUM_SLOTS_SIZE,
+ PROTOCOL_MAGIC,
+ SHM_HEADER_SIZE,
+ SHUTDOWN_FLAG_SIZE,
+ SLOT_SIZE,
+ VERSION_OFFSET,
+ VERSION_SIZE,
+ WRITE_IDX_OFFSET,
+ SHMLayout,
+ SlotState,
+ acquire_slot,
+)
+from TDConfig import TDSenderConfig # noqa: E402
+from TDHost import TDHost # noqa: E402
+
+
+# CuPy import deferred (heavy; only needed for float16 receiver path)
+CUPY_AVAILABLE: bool = False
+cp = None
+
+
+# ---------------------------------------------------------------------------
+# Value objects β extract the _rx_* bag into typed containers
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ReceiverConnection:
+ """Holds all CUDA IPC + SHM handles for one active receiver session.
+
+ Created by initialize_receiver(); torn down by close(). close() is idempotent.
+ """
+
+ shm_handle: object = None # SharedMemory | None
+ dev_ptrs: list = field(default_factory=list)
+ ipc_handles: list = field(default_factory=list)
+ ipc_events: list = field(default_factory=list)
+ stream: object = None
+ layout: object = None # SHMLayout | None
+ num_slots: int = 0
+ ipc_version: int = 0
+ shutdown_offset: int = 0
+ last_write_idx: int = 0 # per-frame protocol cursor; mutates inside import_frame
+
+ def is_open(self) -> bool:
+ return self.shm_handle is not None and bool(self.dev_ptrs)
+
+ def close(self, cuda: object, log_fn: Callable) -> None:
+ """Idempotent teardown β safe to call multiple times.
+
+ Consolidates cleanup() L745-789: mem handles β events β stream β SHM, in order.
+ """
+ _t0 = time.perf_counter()
+ _close_ms = _events_ms = _stream_ms = 0.0
+
+ if cuda and self.dev_ptrs:
+ _ct0 = time.perf_counter()
+ for slot, dev_ptr in enumerate(self.dev_ptrs):
+ if dev_ptr:
+ _st0 = time.perf_counter()
+ try:
+ cuda.ipc_close_mem_handle(dev_ptr)
+ log_fn(f"Closed IPC handle for slot {slot} ({(time.perf_counter() - _st0) * 1000:.1f} ms)")
+ except (RuntimeError, OSError) as e:
+ log_fn(f"Error closing IPC handle for slot {slot}: {e}", force=True)
+ _close_ms = (time.perf_counter() - _ct0) * 1000.0
+
+ if cuda and self.ipc_events:
+ _et0 = time.perf_counter()
+ for slot, event in enumerate(self.ipc_events):
+ if event:
+ _st0 = time.perf_counter()
+ try:
+ cuda.destroy_event(event)
+ log_fn(f"Destroyed IPC event for slot {slot} ({(time.perf_counter() - _st0) * 1000:.1f} ms)")
+ except (RuntimeError, OSError) as e:
+ log_fn(f"Error destroying event for slot {slot}: {e}", force=True)
+ _events_ms = (time.perf_counter() - _et0) * 1000.0
+
+ if cuda and self.stream:
+ _st0 = time.perf_counter()
+ try:
+ cuda.destroy_stream(self.stream)
+ _stream_ms = (time.perf_counter() - _st0) * 1000.0
+ log_fn(f"Destroyed receiver stream ({_stream_ms:.1f} ms)", force=True)
+ except (RuntimeError, OSError) as e:
+ log_fn(f"Error destroying receiver stream: {e}", force=True)
+
+ if self.shm_handle is not None:
+ try:
+ self.shm_handle.close()
+ except (OSError, BufferError) as e:
+ log_fn(f"Error closing SharedMemory: {e}", force=True)
+
+ self.dev_ptrs = []
+ self.ipc_handles = []
+ self.ipc_events = []
+ self.stream = None
+ self.shm_handle = None
+ self.num_slots = 0
+
+ _total_ms = (time.perf_counter() - _t0) * 1000.0
+ log_fn(
+ f"Receiver cleanup complete (total {_total_ms:.1f} ms, "
+ f"bypass 0.0 ms, ipc_close {_close_ms:.1f} ms, "
+ f"events {_events_ms:.1f} ms, stream {_stream_ms:.1f} ms)",
+ force=True,
+ )
+
+
+@dataclass
+class FormatDescriptor:
+ """Frame format negotiated from SHM metadata during initialize_receiver()."""
+
+ width: int = 0
+ height: int = 0
+ num_comps: int = 0
+ format_kind: int = FORMAT_KIND_FLOAT
+ bits_per_comp: int = 32
+ flags: int = 0
+ buffer_size: int = 0
+
+ @property
+ def is_bfloat16(self) -> bool:
+ return bool(self.flags & FLAGS_BFLOAT16)
+
+ @property
+ def is_float16(self) -> bool:
+ return self.format_kind == FORMAT_KIND_FLOAT and self.bits_per_comp == 16 and not self.is_bfloat16
+
+
+@dataclass
+class RetryState:
+ """Retry policy and transient counters for the connection-attempt loop."""
+
+ connect_attempts: int = 0
+ max_connect_attempts: int = 20
+ backoff_intervals: tuple = (1, 2, 4, 8, 16, 32, 64, 120)
+ retry_interval_frames: int = 1
+ frames_since_last_retry: int = 0
+ needs_resolution_update: bool = False
+
+ def request_immediate_reconnect(self) -> None:
+ """Force the next import_frame call to attempt reconnection."""
+ self.frames_since_last_retry = self.retry_interval_frames
+
+ def consume_resolution_update(self) -> bool:
+ """Return True and clear the flag if a resolution update is pending."""
+ if self.needs_resolution_update:
+ self.needs_resolution_update = False
+ return True
+ return False
+
+
+# ---------------------------------------------------------------------------
+# Engine
+# ---------------------------------------------------------------------------
+
+
+class TDReceiverEngine:
+ """Receiver-mode engine: owns all GPU/SHM resources for the Receiver path.
+
+ Constructed by the CUDAIPCExtension facade and replaced (not mutated) on
+ mode switches - guaranteeing zero state leak between Sender and Receiver.
+ """
+
+ def __init__(
+ self,
+ host: TDHost,
+ config: TDSenderConfig,
+ cuda: Any,
+ log_fn: Callable,
+ num_slots: int,
+ device: int,
+ shm_name: str,
+ verbose: bool,
+ ) -> None:
+ self._host = host
+ self._config = config
+ self.cuda = cuda
+ self._log = log_fn
+ self.num_slots = num_slots
+ self.device = device
+ self.shm_name = shm_name
+ self.verbose_performance = verbose
+
+ self._initialized = False
+
+ self._connection = ReceiverConnection()
+ self._format = FormatDescriptor()
+ self._retry = RetryState()
+
+ # Engine-private F16 conversion scratch (mutable per-format caches; not value objects)
+ self._f16_cpu_buf = None
+ self._f32_cpu_buf = None
+ self._f16_pinned_ptr = None
+ self._cupy_f32_buf = None
+ self._cupy_f16_views: list = []
+ self._cached_shape = None
+
+ self._diag_frames_since_reinit: int = 0
+
+ # frame_count mirrored from sender SHM - exposed for get_stats()
+ self.frame_count = 0
+
+ # --- Facade-compat property wrappers (keep CUDAIPCExtension getattr calls working) ---
+
+ @property
+ def shm_handle(self) -> object:
+ return self._connection.shm_handle
+
+ @shm_handle.setter
+ def shm_handle(self, value: object) -> None:
+ self._connection.shm_handle = value
+
+ @property
+ def dev_ptrs(self) -> list:
+ return self._connection.dev_ptrs
+
+ @property
+ def ipc_handles(self) -> list:
+ return self._connection.ipc_handles
+
+ @property
+ def write_idx(self) -> int:
+ return self._connection.last_write_idx
+
+ # --- Engine verbs (replace facade-poke patterns) ---
+
+ def request_immediate_reconnect(self) -> None:
+ """Force next import_frame to attempt reconnection.
+
+ Called from parexecute_callbacks after IPC name or slot-count changes.
+ """
+ self._retry.request_immediate_reconnect()
+
+ def consume_pending_resolution(self) -> tuple | None:
+ """Return (width, height) if resolution sync is pending, else None (and clear flag).
+
+ Called from script_top_callbacks.onCook to drive ImportBuffer Script TOP par updates.
+ """
+ if self._retry.consume_resolution_update():
+ return (self._format.width, self._format.height)
+ return None
+
+ # --- Core API ---
+
+ def is_ready(self) -> bool:
+ """True when initialized and all GPU buffer slots are open."""
+ return (
+ self._initialized
+ and bool(self._connection.dev_ptrs)
+ and all(ptr is not None for ptr in self._connection.dev_ptrs)
+ )
+
+ def get_stats(self) -> dict:
+ """Receiver statistics dict."""
+ return {
+ "mode": "Receiver",
+ "initialized": self._initialized,
+ "frame_count": self.frame_count,
+ "shm_name": self.shm_name,
+ "num_slots": self.num_slots,
+ "rx_resolution": (
+ f"{self._format.width}x{self._format.height}x{self._format.num_comps}"
+ if self._format.width > 0
+ else "N/A"
+ ),
+ "rx_buffer_size_mb": self._format.buffer_size / 1024 / 1024 if self._format.buffer_size > 0 else 0,
+ "rx_last_write_idx": self._connection.last_write_idx,
+ "rx_dev_ptrs": [f"0x{ptr.value:016x}" if ptr else "NULL" for ptr in self._connection.dev_ptrs],
+ }
+
+ def import_frame(self, handle: object) -> bool:
+ """Import frame from CUDA IPC into ImportBuffer (Script TOP).
+
+ Can be called from:
+ - Inside ImportBuffer's onCook callback (TD 2023+ compatibility)
+ - Execute DAT onFrameStart with modoutsidecook enabled (TD 2025+)
+
+ Args:
+ handle: TOPHandle wrapping the ImportBuffer Script TOP (wrapped by facade)
+
+ Returns:
+ True if import successful, False otherwise.
+ """
+ # Check Active parameter (hot path via TDHost.is_active())
+ if not self._host.is_active():
+ return False
+
+ # Lazy initialization with exponential backoff retry logic
+ if not self._initialized:
+ self._retry.frames_since_last_retry += 1
+ if self._retry.frames_since_last_retry < self._retry.retry_interval_frames:
+ return False # Wait before retrying
+
+ self._retry.frames_since_last_retry = 0
+ self._retry.connect_attempts += 1
+
+ if not self.initialize_receiver():
+ backoff_idx = min(self._retry.connect_attempts, len(self._retry.backoff_intervals) - 1)
+ self._retry.retry_interval_frames = self._retry.backoff_intervals[backoff_idx]
+ if self._retry.connect_attempts <= self._retry.max_connect_attempts:
+ self._log(
+ f"Waiting for sender... (attempt {self._retry.connect_attempts}, "
+ f"next retry in {self._retry.retry_interval_frames} frames)"
+ )
+ elif self._retry.connect_attempts == self._retry.max_connect_attempts + 1:
+ self._log("Sender not found. Will keep retrying silently.", force=True)
+ return False
+
+ _nvtx_push(
+ f"cudalink.receiver.import_frame.slot{(self._connection.last_write_idx) % max(self._connection.num_slots, 1)}",
+ "blue",
+ )
+ try:
+ result = acquire_slot(
+ self._connection.shm_handle.buf,
+ self._connection.layout,
+ self._connection.last_write_idx,
+ self._connection.ipc_version,
+ )
+ if result.state is SlotState.SHUTDOWN:
+ self._log("Sender shutdown detected. Cleaning up.", force=True)
+ self.cleanup()
+ return False
+ if result.state is SlotState.VERSION_CHANGED:
+ self._log(
+ f"Sender updated (v{self._connection.ipc_version} -> v{result.new_version}). Refreshing in-place...",
+ force=True,
+ )
+ if not self._refresh_on_version_change(result.new_version):
+ self._log("In-place refresh failed β falling back to full reinit.", force=True)
+ self.cleanup()
+ return False # No frame to consume this tick regardless of refresh outcome
+ if result.state is SlotState.NO_FRAME:
+ return False
+
+ self._connection.last_write_idx = result.write_idx
+ write_idx = result.write_idx
+ read_slot = result.slot
+
+ _diag = self._diag_frames_since_reinit < 5
+ _t_event = _t_copy = 0.0 # pre-init for static analyzers; only read when _diag is True
+ if _diag:
+ self._diag_frames_since_reinit += 1
+ _t_event = time.perf_counter()
+
+ # Wait on IPC event for this slot (stream-ordered, non-blocking to CPU)
+ with _nvtx_verbose("cudalink.receiver.event_wait", "blue"):
+ if self._connection.ipc_events[read_slot]:
+ self.cuda.stream_wait_event(
+ self._connection.stream,
+ self._connection.ipc_events[read_slot],
+ 0,
+ )
+ else:
+ # Fallback when no IPC event: drain the stream now.
+ # Note: float16 path will call stream_synchronize again below, but
+ # synchronizing an already-idle stream is a no-op in CUDA.
+ self.cuda.stream_synchronize(self._connection.stream)
+
+ if _diag:
+ _event_ms = (time.perf_counter() - _t_event) * 1000.0
+ _t_copy = time.perf_counter()
+
+ # Copy CUDA memory into ImportBuffer texture using cached shape
+ address = self._connection.dev_ptrs[read_slot].value
+
+ if self._format.is_float16:
+ if CUPY_AVAILABLE and self._cupy_f32_buf is not None:
+ # GPU-side float16βfloat32 conversion (Ch5: minimize PCIe traffic).
+ # stream_wait_event (enqueued above on _connection.stream) guarantees GPU data is ready.
+ # We create a zero-copy CuPy view of the IPC pointer, run an elementwise
+ # f16βf32 cast entirely on GPU via ExternalStream, then call copyCUDAMemory β
+ # eliminating two PCIe roundtrips and the CPU numpy.copyto call.
+ rx_stream_int = int(self._connection.stream.value)
+ f16_size = self._format.buffer_size # original float16 byte count
+ f32_size = f16_size * 2 # float32 = 2Γ bytes
+
+ cupy_f16 = self._cupy_f16_views[read_slot]
+ # Run conversion on _connection.stream so copyCUDAMemory (also on _connection.stream)
+ # automatically serializes after the elementwise cast kernel.
+ with cp.cuda.ExternalStream(rx_stream_int):
+ cp.copyto(self._cupy_f32_buf, cupy_f16, casting="same_kind")
+
+ handle.copy_cuda_memory(
+ self._cupy_f32_buf.data.ptr,
+ f32_size,
+ self._cached_shape, # dataType=float32 set during initialize_receiver()
+ stream=rx_stream_int,
+ )
+ else:
+ # CPU fallback: D2H + numpy convert + copyNumpyArray.
+ # Used when CuPy is not installed or GPU buffer allocation failed.
+ if self._f16_cpu_buf is None or self._f32_cpu_buf is None:
+ debug("[CUDAIPCLink] float16 CPU buffers not allocated β skipping frame")
+ return False
+
+ # D2H on _connection.stream: stream_wait_event (enqueued earlier) guarantees data is ready.
+ cpu_ptr = self._f16_cpu_buf.ctypes.data_as(c_void_p)
+ self.cuda.memcpy_async(
+ cpu_ptr, c_void_p(address), self._format.buffer_size, 2, self._connection.stream
+ )
+ self.cuda.stream_synchronize(self._connection.stream)
+ numpy.copyto(
+ self._f32_cpu_buf,
+ self._f16_cpu_buf.reshape(self._format.height, self._format.width, self._format.num_comps),
+ casting="same_kind",
+ )
+ handle.copy_numpy_array(self._f32_cpu_buf)
+ else:
+ handle.copy_cuda_memory(
+ address,
+ self._format.buffer_size,
+ self._cached_shape,
+ stream=int(self._connection.stream.value),
+ )
+
+ if _diag:
+ _copy_ms = (time.perf_counter() - _t_copy) * 1000.0
+ self._log(
+ f"[DIAG] import_frame #{self._diag_frames_since_reinit}: "
+ f"slot={read_slot} write_idx={write_idx} addr=0x{address:x} "
+ f"stream_wait={_event_ms:.2f}ms copyCUDAMemory={_copy_ms:.2f}ms",
+ force=True,
+ )
+
+ self.frame_count += 1
+ self._connection.last_write_idx = write_idx
+
+ _dt = self._cached_shape.dataType
+ _dtype_str = (
+ getattr(_dt, "name", None) or getattr(_dt, "__name__", str(_dt)) if _dt is not None else "unknown"
+ )
+ self._host.set_info_status(
+ f"{self._format.width}x{self._format.height} {_dtype_str} {self._format.num_comps}ch"
+ )
+
+ # Debug logging (97 = prime, avoids aliasing with slot counts 2,4,5)
+ if self.verbose_performance and self.frame_count % 97 == 0:
+ self._log(f"Frame {self.frame_count}: read_slot={read_slot}, write_idx={write_idx}")
+
+ return True
+
+ except (RuntimeError, OSError) as e:
+ self._log(f"Import failed: {e}", force=True)
+
+ traceback.print_exc()
+ return False
+ finally:
+ _nvtx_pop()
+
+ def update_receiver_resolution(self, handle: object) -> bool:
+ """Update ImportBuffer resolution from outside the cook cycle.
+
+ Safe to call from Execute DAT when modoutsidecook is enabled on the Script TOP (TD 2025+).
+ When modoutsidecook is NOT available, this is a no-op (resolution handled in onCook).
+
+ Args:
+ handle: TOPHandle wrapping the ImportBuffer Script TOP (wrapped by facade)
+
+ Returns:
+ True if resolution was updated, False if no update needed or not applicable
+ """
+ if not self._retry.needs_resolution_update:
+ return False
+
+ try:
+ handle.set_resolution(self._format.width, self._format.height)
+ self._retry.needs_resolution_update = False
+ self._log(
+ f"Set ImportBuffer resolution to {self._format.width}x{self._format.height} (from Execute DAT)",
+ force=True,
+ )
+ return True
+ except (AttributeError, RuntimeError) as e:
+ self._log(f"Could not set ImportBuffer resolution: {e}", force=True)
+ return False
+
+ def initialize_receiver(self) -> bool:
+ """Initialize receiver: open SharedMemory, read handles, open IPC handles.
+
+ Returns:
+ True if initialization successful, False otherwise.
+ """
+ if self._initialized:
+ return True
+
+ # Numslots is always disabled in Receiver mode (sender controls slot count)
+ self._host.set_param_enabled("Numslots", False)
+
+ _t0 = time.perf_counter()
+ try:
+ self.cuda = get_cuda_runtime(device=self.device)
+ self._log(f"Loaded CUDA runtime on device {self.cuda.get_device()}", force=True)
+
+ # Open SharedMemory (sender must have created it)
+ try:
+ shm_handle = SharedMemory(name=self.shm_name)
+ except FileNotFoundError:
+ self._log(f"SharedMemory '{self.shm_name}' not found. Sender not ready?")
+ return False
+
+ # Validate protocol magic number (new in this version)
+ try:
+ magic = struct.unpack(
+ " 10:
+ self._log(
+ f"Invalid num_slots: {num_slots}. Protocol error.",
+ force=True,
+ )
+ shm_handle.close()
+ return False
+
+ # Sync UI parameter to show sender's slot count (informational only).
+ # Do NOT set self.num_slots β that's the sender-specific working value.
+ # Receiver always uses connection.num_slots for its own arrays.
+ self._host.set_param_value("Numslots", num_slots)
+
+ # Cache receiver layout once β avoids per-frame arithmetic in import_frame()
+ layout = SHMLayout(num_slots)
+ shutdown_offset = layout.shutdown_offset
+ metadata_offset = shutdown_offset + SHUTDOWN_FLAG_SIZE
+
+ # Check if SharedMemory is large enough for metadata
+ if len(shm_handle.buf) >= metadata_offset + METADATA_SIZE:
+ width = struct.unpack(
+ " None:
+ """Cleanup partially-opened resources when initialization fails mid-slot.
+
+ Called when initialize_receiver() fails partway through slot iteration.
+ Closes IPC handles already opened for slots 0..failed_slot-1 to prevent
+ GPU resource leaks across backoff retries.
+
+ Args:
+ failed_slot: The slot index that failed (0-based). Cleans up slots 0..failed_slot-1.
+ dev_ptrs: In-progress dev_ptrs list from this init attempt.
+ ipc_events: In-progress ipc_events list from this init attempt.
+ stream: Stream created for this init attempt (only closed if freshly created).
+ shm_handle: SHM handle to close and clear.
+ """
+ for i in range(failed_slot):
+ if dev_ptrs[i] is not None:
+ try:
+ self.cuda.ipc_close_mem_handle(dev_ptrs[i])
+ self._log(f"Cleaned up partial slot {i} mem handle")
+ except (RuntimeError, OSError):
+ pass
+ dev_ptrs[i] = None
+ if ipc_events[i] is not None:
+ with contextlib.suppress(RuntimeError, OSError):
+ self.cuda.destroy_event(ipc_events[i])
+ ipc_events[i] = None
+
+ # Close SharedMemory so next retry re-opens fresh (avoids reading stale content)
+ if shm_handle is not None:
+ with contextlib.suppress(OSError, BufferError):
+ shm_handle.close()
+
+ def cleanup(self) -> None:
+ """Cleanup Receiver CUDA IPC resources."""
+ # Guard against double-cleanup (matches cleanup_sender() pattern)
+ if not self._initialized and self._connection.shm_handle is None:
+ return
+
+ self._connection.close(self.cuda, self._log)
+
+ # Free pinned float16 D2H buffer if allocated
+ if self._f16_pinned_ptr is not None:
+ try:
+ self.cuda.free_host(self._f16_pinned_ptr)
+ except (RuntimeError, OSError) as e:
+ self._log(f"free_host skipped (context gone): {e}")
+ self._f16_pinned_ptr = None
+ self._f16_cpu_buf = None
+ self._f32_cpu_buf = None
+ self._cupy_f32_buf = None # CuPy memory pool handles GPU free on GC
+ self._cupy_f16_views = []
+ self._cached_shape = None
+
+ self._initialized = False
+ self._retry.connect_attempts = 0
+ self._retry.frames_since_last_retry = 0
+
+ def _refresh_on_version_change(self, new_version: int) -> bool:
+ """Refresh format and IPC handles in-place after a sender version bump.
+
+ Keeps SHM, stream, and unchanged IPC handles open. Only re-reads the
+ 20-byte metadata block and rebuilds self._format and self._cached_shape.
+ For genuine sender re-inits (new IPC handles), also closes old handles
+ and opens the new ones β preserving the SHM connection throughout.
+
+ Mirrors src/cuda_link/cuda_ipc_importer.py:_reinitialize.
+
+ Returns:
+ True if refresh succeeded (caller skips cleanup and continues).
+ False on any error (caller falls back to cleanup + full reinit).
+ """
+ conn = self._connection
+ shm = conn.shm_handle
+ if shm is None or conn.layout is None:
+ return False
+
+ layout = conn.layout
+ metadata_offset = layout.metadata_offset
+ try:
+ if len(shm.buf) < metadata_offset + METADATA_SIZE:
+ return False
+ width = struct.unpack(" SenderActivationBarrier:
+ return cls(
+ enabled=config.activation_barrier,
+ settle_frames=config.barrier_settle_frames,
+ )
+
+ def acquire(self, pid: int, *, log_fn: Callable) -> None:
+ """Open-or-create the segment, increment, set held=True. Log+swallow failures."""
+ if not self.enabled:
+ return
+ try:
+ if self.shm is None:
+ self.shm = _ab_open_or_create(create=True)
+ count = _ab_increment(self.shm, pid)
+ self.held = True
+ log_fn(f"[ACTIVATION_BARRIER] held +1 (count={count}) for Sender init", force=True)
+ except (OSError, RuntimeError, struct.error) as _exc:
+ log_fn(f"[ACTIVATION_BARRIER] init increment failed (ignored): {_exc}", force=True)
+
+ def arm_settle_countdown(self) -> None:
+ """Called from initialize() tail when init succeeds β settle_remaining = settle_frames."""
+ if self.held:
+ self.settle_remaining = self.settle_frames
+
+ def tick_and_maybe_release(self, pid: int, *, log_fn: Callable) -> bool:
+ """Per-frame: decrement settle_remaining. When it hits 0 and held, release barrier.
+
+ Returns True iff the release fired this frame.
+ """
+ if self.settle_remaining <= 0:
+ return False
+ self.settle_remaining -= 1
+ if self.settle_remaining == 0 and self.held and self.shm is not None:
+ try:
+ count = _ab_decrement(self.shm, pid)
+ log_fn(
+ f"[ACTIVATION_BARRIER] released after {self.settle_frames}-frame settle (count now {count})",
+ force=True,
+ )
+ return True
+ except (OSError, RuntimeError, struct.error) as _exc:
+ log_fn(f"[ACTIVATION_BARRIER] settle decrement failed (ignored): {_exc}", force=True)
+ finally:
+ self.held = False
+ return False
+
+ def force_release(self, pid: int, *, log_fn: Callable) -> None:
+ """Cleanup-time: if still held, decrement and clear. Idempotent."""
+ if not (self.held and self.shm is not None):
+ return
+ try:
+ count = _ab_decrement(self.shm, pid)
+ log_fn(
+ f"[ACTIVATION_BARRIER] released on cleanup (mid-settle, count now {count})",
+ force=True,
+ )
+ except (OSError, RuntimeError, struct.error) as _exc:
+ log_fn(f"[ACTIVATION_BARRIER] cleanup decrement failed (ignored): {_exc}", force=True)
+ finally:
+ self.held = False
+
+ def close(self) -> None:
+ """Idempotent: close SHM handle if held."""
+ if self.shm is not None:
+ with contextlib.suppress(OSError, RuntimeError):
+ self.shm.close()
+ self.shm = None
+
+
+class TDSenderEngine:
+ """Sender-mode engine: owns all GPU/SHM resources for the Sender path.
+
+ Constructed by the CUDAIPCExtension facade and replaced (not mutated) on
+ mode switches - guaranteeing zero state leak between Sender and Receiver.
+ """
+
+ def __init__(
+ self,
+ host: TDHost,
+ config: TDSenderConfig,
+ cuda: Any,
+ log_fn: Callable,
+ num_slots: int,
+ device: int,
+ shm_name: str,
+ verbose: bool,
+ ) -> None:
+ self._host = host
+ self._config = config
+ self.cuda = cuda
+ self._log = log_fn
+ self.num_slots = num_slots
+ self.device = device
+ self.shm_name = shm_name
+ self.verbose_performance = verbose
+
+ self._initialized = False
+
+ self.dev_ptrs = [None] * self.num_slots
+ self.buffer_size = 0
+ self.data_size = 0
+ self.width = 0
+ self.height = 0
+ self.channels = 4
+
+ self.ipc_handles = [None] * self.num_slots
+ self.ipc_events = [None] * self.num_slots
+ self.ipc_event_handles = [None] * self.num_slots
+
+ self._pending_free_ptrs: list = []
+ self._pending_free_events: list = []
+ self._deferred_free_at_frame = 0
+
+ self.write_idx = 0
+ self.shm_handle = None
+ self._layout: SHMLayout | None = None
+ self._shutdown_offset = 0
+ self._ts_offset = 0
+ self.frame_count = 0
+ self.cuda_mem_ref = None
+ self.sync_interval = 10
+
+ self._export_sync: bool = self._config.export_sync
+ self._export_profile: bool = self._config.export_profile
+ self._export_flush_probe: bool = self._config.export_flush_probe
+ self._use_graphs: bool = self._config.use_graphs
+ self._graphs_disabled: bool = False
+ self._graph_execs: list = [None] * self.num_slots
+ self._graph_templates: list = [None] * self.num_slots
+ self._graph_memcpy_nodes: list = [None] * self.num_slots
+ self._stream_high_prio: bool = self._config.stream_high_prio
+ self._init_pace: bool = self._config.init_pace
+ self._graphs_pending: bool = False
+ self._graphs_deferred: bool = self._config.graphs_deferred
+ self._persist_stream: bool = self._config.persist_stream
+ self._barrier = SenderActivationBarrier.from_config(self._config)
+
+ self._nvml_observer: NVMLObserver | None = None
+
+ self.total_memcpy_time = 0.0
+ self.total_record_event_time = 0.0
+ self.total_export_time = 0.0
+ self.total_cuda_memory_time = 0.0
+ self.total_pre_interop_us: float = 0.0
+ self.total_post_interop_us: float = 0.0
+ self.total_sync_us: float = 0.0
+ self.total_sticky_check_us: float = 0.0
+ self.total_flush_probe_us: float = 0.0
+ self.total_shm_publish_us: float = 0.0
+ self.total_unaccounted_us: float = 0.0
+
+ self._warned_format = False
+ self._export_buffer: object = None
+ self._last_pixel_fmt: str = ""
+ self._last_fmt_needs_conv: bool = False
+
+ self.ipc_stream = None
+ self._last_cuda_mem_err = ""
+ self._detected_numpy_dtype: object = None
+ self._last_numpy_dtype: object = None
+
+ # Profiling events (created lazily in initialize when _export_profile=True)
+ self._timing_start = None
+ self._timing_end = None
+
+ def is_ready(self) -> bool:
+ """True when initialized and all GPU buffer slots are allocated."""
+ return self._initialized and all(ptr is not None for ptr in self.dev_ptrs)
+
+ def get_stats(self) -> dict:
+ """Sender statistics dict."""
+ return {
+ "mode": "Sender",
+ "initialized": self._initialized,
+ "frame_count": self.frame_count,
+ "shm_name": self.shm_name,
+ "num_slots": self.num_slots,
+ "buffer_size_mb": self.buffer_size / 1024 / 1024 if self.buffer_size > 0 else 0,
+ "resolution": f"{self.width}x{self.height}x{self.channels}" if self.width > 0 else "N/A",
+ "write_idx": self.write_idx,
+ "dev_ptrs": [f"0x{ptr.value:016x}" if ptr else "NULL" for ptr in self.dev_ptrs],
+ }
+
+ def _is_unsupported_format(self, top_op: object) -> bool:
+ """Return True if the TOP's pixel format is unsupported by cudaMemory() in TD 2025.
+
+ Empirical probe (verification/results/cuda_memory_probe_20260510_090919.json,
+ TD 2025.32820): cudaMemory() rejects all 4 float16 variants and 10:10:10:2 fixed
+ outright; 11:11:10 float "succeeds" but returns dataType=uint8/numComps=4 (raw
+ byte layout, NOT semantic) β silent corruption. On True: sender skips the frame
+ and emits a component warning (addScriptError); on False: warning is cleared.
+
+ top_op may be a RealTOPHandle, FakeTOPHandle, or raw TD TOP (backward compat).
+ """
+ if hasattr(top_op, "pixel_format"):
+ pixel_fmt = str(top_op.pixel_format)
+ else:
+ pixel_fmt = str(getattr(top_op, "pixelFormat", ""))
+ if pixel_fmt == self._last_pixel_fmt:
+ return self._last_fmt_needs_conv
+ self._last_pixel_fmt = pixel_fmt
+ pixel_lower = pixel_fmt.lower()
+ self._last_fmt_needs_conv = any(u in pixel_lower for u in _CUDA_UNSUPPORTED_PIXEL_FORMATS)
+ return self._last_fmt_needs_conv
+
+ def initialize(self, width: int, height: int, channels: int = 4, buffer_size: int | None = None) -> bool:
+ """Initialize CUDA IPC resources.
+
+ Args:
+ width: Texture width in pixels
+ height: Texture height in pixels
+ channels: Number of channels (default: 4 for RGBA)
+ buffer_size: Actual buffer size in bytes (optional, auto-calculated if None)
+
+ Returns:
+ True if initialization successful, False otherwise
+ """
+ if self._initialized:
+ self._log("Already initialized")
+ return True
+
+ # Lock Numslots while active β changing slot count at runtime causes array size mismatch
+ self._host.set_param_enabled("Numslots", False)
+
+ try:
+ # Activation-barrier hold: signal the Python producer to pause pushes
+ # during this Sender's WDDM-saturating init burst.
+ self._barrier.acquire(os.getpid(), log_fn=self._log)
+
+ # Load CUDA runtime bound to the configured device
+ self.cuda = get_cuda_runtime(device=self.device)
+ self._log(f"Loaded CUDA runtime on device {self.cuda.get_device()}", force=True)
+
+ # Create dedicated non-blocking stream for IPC operations.
+ # Reuse existing stream on re-init to avoid leaks.
+ if self.ipc_stream is None:
+ # Default normal-priority (Phase 4.1). Set CUDALINK_TD_STREAM_PRIO=high
+ # for explicit single-pair lowest-latency.
+ if self._stream_high_prio:
+ self.ipc_stream = self.cuda.create_stream_with_priority(flags=0x01)
+ self._log(
+ f"Created IPC stream (high-priority): 0x{int(self.ipc_stream.value):016x}",
+ force=True,
+ )
+ else:
+ self.ipc_stream = self.cuda.create_stream(flags=0x01)
+ self._log(
+ f"Created IPC stream (normal-priority): 0x{int(self.ipc_stream.value):016x}",
+ force=True,
+ )
+ else:
+ self._log(
+ f"Reusing IPC stream: 0x{int(self.ipc_stream.value):016x}",
+ force=True,
+ )
+
+ # Store dimensions
+ self.width = width
+ self.height = height
+ self.channels = channels
+ # Use provided buffer_size (from cuda_mem.size) or calculate
+ raw_size = buffer_size if buffer_size is not None else width * height * channels * 4
+ # Round up to 2MiB alignment (NVIDIA requirement: prevents unintended information disclosure)
+ alignment = 2 * 1024 * 1024 # 2 MiB
+ self.buffer_size = ((raw_size + alignment - 1) // alignment) * alignment
+ self.data_size = raw_size # Store actual data size for memcpy and comparisons
+
+ # Defensive array resize: num_slots may have changed between cleanup and init
+ # (e.g. handle_numslots_change() sets num_slots after cleanup resets arrays)
+ if len(self.dev_ptrs) != self.num_slots:
+ self.dev_ptrs = [None] * self.num_slots
+ self.ipc_handles = [None] * self.num_slots
+ self.ipc_events = [None] * self.num_slots
+ self.ipc_event_handles = [None] * self.num_slots
+
+ # Allocate ring buffer slots
+ for slot in range(self.num_slots):
+ # Allocate persistent GPU buffer for this slot
+ self.dev_ptrs[slot] = self.cuda.malloc(self.buffer_size)
+ self._log(
+ f"Allocated GPU buffer slot {slot}: "
+ f"{self.buffer_size / 1024 / 1024:.1f} MB at 0x{self.dev_ptrs[slot].value:016x}",
+ force=True,
+ )
+
+ # Create IPC handle for this buffer (ONCE - reuse for all frames)
+ self.ipc_handles[slot] = self.cuda.ipc_get_mem_handle(self.dev_ptrs[slot])
+ self._log(f"Created IPC handle for slot {slot} (64 bytes)")
+
+ # Create IPC event for GPU-side synchronization (per-slot)
+ self.ipc_events[slot] = self.cuda.create_ipc_event()
+ self.ipc_event_handles[slot] = self.cuda.ipc_get_event_handle(self.ipc_events[slot])
+ self._log(f"Created IPC event for slot {slot} (64 bytes)")
+
+ self._log(f"Created {self.num_slots} IPC buffer slots with events", force=True)
+
+ # INIT_PACE checkpoint 1/3 β CUDALINK_TD_INIT_PACE=1: flush WDDM queue after per-slot
+ # alloc burst (cudaMalloc + IpcGetMemHandle + EventCreate + IpcGetEventHandle Γ N).
+ if self._init_pace:
+ self.cuda.stream_synchronize(self.ipc_stream)
+ time.sleep(0.02)
+ self._log("[INIT_PACE] checkpoint 1/3 (post-slot-alloc)", force=True)
+
+ # Create SharedMemory for IPC handle transfer
+ # Size: header + slots + shutdown flag + metadata + timestamp (for extended protocol)
+ shm_size = (
+ SHM_HEADER_SIZE + (self.num_slots * SLOT_SIZE) + SHUTDOWN_FLAG_SIZE + METADATA_SIZE + TIMESTAMP_SIZE
+ )
+
+ try:
+ # Try to open existing SharedMemory first
+ self.shm_handle = SharedMemory(name=self.shm_name)
+ self._log(f"Opened existing SharedMemory: {self.shm_name}", force=True)
+ except FileNotFoundError:
+ # Create new SharedMemory if doesn't exist
+ self.shm_handle = SharedMemory(name=self.shm_name, create=True, size=shm_size)
+ self._log(
+ f"Created new SharedMemory: {self.shm_name} ({shm_size} bytes)",
+ force=True,
+ )
+
+ # Write IPC handle to SharedMemory (ONCE - Python process reads at startup)
+ self._write_handle_to_shm()
+
+ # Write texture metadata to extended protocol region
+ self._write_metadata_to_shm()
+
+ # INIT_PACE checkpoint 2/3 β after SHM segment creation + handle/metadata writes.
+ if self._init_pace:
+ self.cuda.stream_synchronize(self.ipc_stream)
+ time.sleep(0.02)
+ self._log("[INIT_PACE] checkpoint 2/3 (post-SHM-write)", force=True)
+
+ # Cache SHM offsets: avoid recomputing these on every export_frame() call
+ self._layout = SHMLayout(self.num_slots)
+ self._shutdown_offset = self._layout.shutdown_offset
+ self._ts_offset = self._layout.timestamp_offset
+
+ # Cache ExportBuffer as TOPHandle β eliminates per-frame ownerComp.op() lookup
+ self._export_buffer = self._host.find_top(_EXPORT_BUFFER_NAME)
+
+ # Create GPU timing events (only when Debug is ON for benchmarking)
+ if self.verbose_performance:
+ self._timing_start = self.cuda.create_timing_event()
+ self._timing_end = self.cuda.create_timing_event()
+ self._log("Created GPU timing events for benchmarking", force=False)
+ else:
+ self._timing_start = None
+ self._timing_end = None
+
+ self._initialized = True
+ self._barrier.arm_settle_countdown()
+ self._log("Initialization complete - ready for zero-copy GPU transfer", force=True)
+
+ # CUDA Graphs build (after IPC stream / events / ring buffer are ready).
+ # Gated on CUDALINK_TD_USE_GRAPHS=1 AND cudart >= 11.4
+ # (cudaGraphInstantiateWithFlags + EventRecordNodeSetEvent require 11.4+).
+ if self._use_graphs:
+ try:
+ rt_version = self.cuda.get_runtime_version()
+ except (RuntimeError, OSError) as exc:
+ rt_version = 0
+ self._log(f"cudaRuntimeGetVersion failed ({exc}) β disabling graphs", force=True)
+ if rt_version >= CUDART_GRAPHS_MIN_VERSION:
+ if self._graphs_deferred:
+ # CUDALINK_TD_GRAPHS_DEFERRED=1: defer graph capture to first
+ # warm frame (frame_count >= 30) so the init burst doesn't overlap
+ # with Receiver-A's 60 Hz stream. First 30 frames use legacy memcpy_async.
+ self._graphs_pending = True
+ self._log(
+ "CUDA export graph build deferred to first warm frame (CUDALINK_TD_GRAPHS_DEFERRED=1)",
+ force=True,
+ )
+ else:
+ self._build_export_graphs()
+ # INIT_PACE checkpoint 3/3 β after graph capture + instantiation.
+ if self._init_pace:
+ self.cuda.stream_synchronize(self.ipc_stream)
+ time.sleep(0.02)
+ self._log("[INIT_PACE] checkpoint 3/3 (post-graph-build)", force=True)
+ else:
+ self._log(
+ f"CUDALINK_TD_USE_GRAPHS=1 ignored: cudart {rt_version} < {CUDART_GRAPHS_MIN_VERSION} "
+ "(cudaGraphInstantiateWithFlags requires 11.4+).",
+ force=True,
+ )
+ self._graphs_disabled = True
+
+ if NVML_AVAILABLE and self._config.nvml:
+ obs = NVMLObserver(device=self.device, enabled=True)
+ if obs.start():
+ self._nvml_observer = obs
+ self._log(f"NVMLObserver attached on device {self.device}", force=True)
+
+ return True
+
+ except (OSError, RuntimeError, ValueError) as e:
+ self._log(f"Initialization failed: {e}", force=True)
+ self._host.set_error_status(f"Initialization failed: {e}")
+ traceback.print_exc()
+ return False
+
+ def _build_export_graphs(self) -> None:
+ """Capture the D2D memcpy into a 1-node CUDA Graph exec per ring slot.
+
+ Mirrors CUDAIPCExporter._build_export_graphs() on the Python side.
+ Captures only the memcpy_async (IPC events / external waits cannot be
+ captured in global mode). On failure the stream is restored to normal
+ mode and self._graphs_disabled is set so the legacy stream path is used.
+ """
+ if self.cuda is None or self.ipc_stream is None:
+ return
+
+ placeholder_src = self.dev_ptrs[0]
+
+ for slot in range(self.num_slots):
+ capture_started = False
+ try:
+ self.cuda.stream_begin_capture(self.ipc_stream, mode=0)
+ capture_started = True
+ self.cuda.memcpy_async(
+ dst=self.dev_ptrs[slot],
+ src=placeholder_src,
+ count=self.data_size,
+ kind=3, # cudaMemcpyDeviceToDevice
+ stream=self.ipc_stream,
+ )
+ template_graph = self.cuda.stream_end_capture(self.ipc_stream)
+ capture_started = False
+
+ nodes = self.cuda.graph_get_nodes(template_graph)
+ if len(nodes) != 1:
+ self.cuda.graph_destroy(template_graph)
+ raise RuntimeError(f"Unexpected graph node count {len(nodes)} (expected 1: MemcpyNode).")
+ memcpy_node = nodes[0]
+
+ graph_exec = self.cuda.graph_instantiate(template_graph)
+ # Keep template alive so the captured node handle stays valid for
+ # the per-frame cudaGraphExecMemcpyNodeSetParams1D updates.
+ self._graph_execs[slot] = graph_exec
+ self._graph_templates[slot] = template_graph
+ self._graph_memcpy_nodes[slot] = memcpy_node
+ self._log(f"Built export graph for slot {slot} (1-node: Memcpy)")
+
+ except (RuntimeError, OSError) as exc:
+ if capture_started:
+ try:
+ abandoned_graph = self.cuda.stream_end_capture(self.ipc_stream)
+ self.cuda.graph_destroy(abandoned_graph)
+ except (RuntimeError, OSError):
+ pass
+ self._log(
+ f"CUDA Graph build failed for slot {slot} ({exc}) β disabling graphs "
+ "for this session and falling back to legacy stream path. "
+ "Set CUDALINK_TD_USE_GRAPHS=0 to suppress.",
+ force=True,
+ )
+ self._graphs_disabled = True
+ self._destroy_export_graphs()
+ return
+
+ self._log(
+ f"CUDA export graphs built for {self.num_slots} slots (CUDALINK_TD_USE_GRAPHS=1)",
+ force=True,
+ )
+
+ def _destroy_export_graphs(self) -> None:
+ """Destroy all CUDA Graph exec objects and their templates."""
+ if self.cuda is None:
+ return
+ for slot, graph_exec in enumerate(getattr(self, "_graph_execs", [])):
+ if graph_exec is not None:
+ try:
+ self.cuda.graph_exec_destroy(graph_exec)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error destroying graph exec slot {slot}: {e}", force=True)
+ self._graph_execs[slot] = None
+ for slot, template in enumerate(getattr(self, "_graph_templates", [])):
+ if template is not None:
+ with contextlib.suppress(RuntimeError, OSError):
+ self.cuda.graph_destroy(template)
+ self._graph_templates[slot] = None
+ if hasattr(self, "_graph_memcpy_nodes"):
+ self._graph_memcpy_nodes = [None] * self.num_slots
+
+ def _write_handle_to_shm(self) -> None:
+ """Write magic + version + num_slots + write_idx + all IPC handles to SharedMemory.
+
+ Layout (20 + NUM_SLOTS*192 + 1 bytes):
+ [0-3] magic (4B) - protocol validation "CIPD"
+ [4-11] version (8B)
+ [12-15] num_slots (4B)
+ [16-19] write_idx (4B)
+
+ For each slot (128 bytes per slot):
+ [20+slot*128 : 84+slot*128] mem_handle (64B)
+ [84+slot*128 : 148+slot*128] event_handle (64B)
+
+ [20+NUM_SLOTS*128] shutdown flag (1B)
+ """
+ if self.shm_handle is None or not all(self.ipc_handles):
+ return
+
+ self._layout = SHMLayout(self.num_slots)
+
+ # Write protocol header: magic, bump version, num_slots, reset write_idx
+ _ST_U32.pack_into(self.shm_handle.buf, MAGIC_OFFSET, PROTOCOL_MAGIC)
+ new_version = bump_version(self.shm_handle.buf)
+ _ST_U32.pack_into(self.shm_handle.buf, NUM_SLOTS_OFFSET, self.num_slots)
+ _ST_U32.pack_into(self.shm_handle.buf, WRITE_IDX_OFFSET, 0) # write_idx=0 initially
+
+ # Write handles for each slot
+ for slot in range(self.num_slots):
+ base_offset = self._layout.slot_offset(slot)
+
+ # Write memory handle (64 bytes)
+ mem_handle_bytes = bytes(self.ipc_handles[slot].internal)
+ self.shm_handle.buf[base_offset : base_offset + 64] = mem_handle_bytes
+
+ # Write event handle (64 bytes) if available
+ if self.ipc_event_handles[slot]:
+ event_handle_bytes = bytes(self.ipc_event_handles[slot].reserved)
+ self.shm_handle.buf[base_offset + 64 : base_offset + 128] = event_handle_bytes
+ self._log(f"Wrote slot {slot} handles: mem={len(mem_handle_bytes)}B, event={len(event_handle_bytes)}B")
+ else:
+ self._log(f"Wrote slot {slot} mem handle: {len(mem_handle_bytes)}B")
+
+ # Clear shutdown flag β matches CUDAIPCExporter._write_handles_to_shm() on the Python side.
+ # Without this, a stale shutdown_flag=1 from a previous session (or a race where another
+ # sender initialised after this one) would block the receiver indefinitely.
+ self.shm_handle.buf[self._layout.shutdown_offset] = 0
+
+ self._log(
+ f"Wrote all IPC handles v{new_version} to SharedMemory ({SHM_HEADER_SIZE + self.num_slots * SLOT_SIZE + SHUTDOWN_FLAG_SIZE + METADATA_SIZE + TIMESTAMP_SIZE} bytes total)",
+ force=True,
+ )
+
+ def _write_metadata_to_shm(self) -> None:
+ """Write texture metadata to the extended protocol region after shutdown flag.
+
+ Extended protocol layout (20 bytes):
+ [+0 : 4B] width (uint32 LE)
+ [+4 : 4B] height (uint32 LE)
+ [+8 : 4B] num_comps (uint32 LE)
+ [+12 : 1B] format_kind (uint8) β cudaChannelFormatKind: 0=Signed,1=Unsigned,2=Float
+ [+13 : 1B] bits_per_comp (uint8) β 8/16/32/64
+ [+14 : 2B] flags (uint16 LE) β bit0=bfloat16; rest reserved=0
+ [+16 : 4B] data_size (uint32 LE) β actual bytes (before 2MiB alignment)
+ """
+ if self.shm_handle is None or self.data_size == 0:
+ return
+
+ # Encode format as (kind, bits, flags) β self-describing, receiver-compatible.
+ # Primary source: _detected_numpy_dtype from cuda_mem.data_type (authoritative).
+ # The GPU allocation size (self.data_size) may be padded or reflect the previous
+ # format when dtype changes with a constant allocation, so it must not drive
+ # kind/bits alone. Ratio-based fallback is used only when dtype is unavailable.
+ pixel_count = self.width * self.height * self.channels if (self.width and self.height and self.channels) else 0
+
+ flags = 0
+ # Fallback: derive bits/kind from GPU allocation ratio.
+ _ratio_bits = (
+ self.data_size // pixel_count * 8 if pixel_count > 0 and self.data_size % pixel_count == 0 else 32
+ )
+ bits = _ratio_bits
+ kind = FORMAT_KIND_UNSIGNED if bits == 8 else FORMAT_KIND_FLOAT
+
+ # Override with authoritative dtype hint when cuda_mem.data_type was reported.
+ _hint = self._detected_numpy_dtype
+ if _hint is not None:
+ try:
+ import numpy as _np
+
+ _hint = _np.dtype(_hint)
+ if _hint == _np.dtype("uint8"):
+ bits, kind = 8, FORMAT_KIND_UNSIGNED
+ elif _hint == _np.dtype("uint16"):
+ bits, kind = 16, FORMAT_KIND_UNSIGNED
+ elif _hint == _np.dtype("float16"):
+ bits, kind = 16, FORMAT_KIND_FLOAT
+ elif _hint == _np.dtype("float64"):
+ bits, kind = 64, FORMAT_KIND_FLOAT
+ else: # float32 and any future dtype
+ bits, kind = 32, FORMAT_KIND_FLOAT
+ except Exception: # noqa: BLE001
+ pass # keep ratio-derived fallback
+
+ # Use active-region size (W*H*C*(bits/8)) as the metadata data_size so
+ # the receiver invariant W*H*C*(bits/8)==data_size is always satisfied.
+ # self.data_size is the GPU allocation (may be padded/stale when dims change
+ # with a constant allocation), so it must not flow directly into the metadata
+ # field that the receiver validates. Python-side exporter does the same.
+ meta_data_size = pixel_count * (bits // 8)
+ Metadata(
+ width=self.width,
+ height=self.height,
+ num_comps=self.channels,
+ format_kind=kind,
+ bits_per_comp=bits,
+ flags=flags,
+ data_size=meta_data_size,
+ ).pack_into(self.shm_handle.buf, self._layout)
+
+ # Track last written dtype for change detection
+ self._last_numpy_dtype = self._detected_numpy_dtype
+
+ self._log(
+ f"Wrote metadata: {self.width}x{self.height}x{self.channels}, "
+ f"kind={kind} bits={bits} flags=0x{flags:04x}, size={meta_data_size}B"
+ )
+
+ def _has_dtype_changed(self) -> bool:
+ """Check if detected numpy dtype differs from last written metadata.
+
+ Both attributes are pre-initialized to None in __init__ and set as numpy.dtype
+ objects (from cuda_mem.shape.dataType / _write_metadata_to_shm), so direct
+ comparison is safe β no per-frame np.dtype() construction needed.
+ """
+ if self._detected_numpy_dtype is None or self._last_numpy_dtype is None:
+ return False # Not yet detected or not yet written
+ return self._detected_numpy_dtype != self._last_numpy_dtype
+
+ def _bump_version(self) -> None:
+ """Increment SharedMemory version counter to signal consumers to re-read metadata."""
+ if self.shm_handle is None:
+ return
+ new_version = bump_version(self.shm_handle.buf)
+ self._log(f"Version bumped to {new_version} (metadata-only change)")
+
+ def export_frame(self, top_op: TOP | None = None) -> bool:
+ """Export the ExportBuffer TOP texture via CUDA IPC.
+
+ Resolves ExportBuffer internally from ownerComp so the correct frame
+ is always exported regardless of what the caller previously passed.
+
+ Args:
+ top_op: Deprecated. Accepted for backwards compatibility but ignored.
+ ExportBuffer is always resolved from ownerComp internally.
+
+ Returns:
+ True if export successful, False otherwise
+ """
+ top_op = self._export_buffer
+ if top_op is None or not top_op.is_valid():
+ self._export_buffer = None # invalidate stale cache
+ # Lazy lookup: op may have been added after initialize() (e.g. dynamic network edits)
+ top_op = self._host.find_top(_EXPORT_BUFFER_NAME)
+ if top_op is None:
+ self._log(f"'{_EXPORT_BUFFER_NAME}' not found in component", force=True)
+ return False
+ self._export_buffer = top_op # cache for subsequent frames
+
+ # Check if Active parameter is enabled (hot path via TDHost.is_active())
+ if not self._host.is_active():
+ return False
+
+ # Start frame timer (only if verbose)
+ if self.verbose_performance:
+ frame_start = time.perf_counter()
+ if self._export_profile:
+ _t_pre = frame_start
+ # initialize per-frame profile locals so unaccounted calc is always defined
+ _this_pre = _this_post = _this_sync = _this_sticky = _this_fp = _this_shm = 0.0
+ # record_event_time is only set in the ipc_events path; init here for the fallback
+ record_event_time = 0.0
+
+ _nvtx_push(f"cudalink.sender.export_frame.slot{self.write_idx % self.num_slots}", "green")
+ try:
+ # Ensure CUDA runtime and stream exist BEFORE first cudaMemory() call.
+ # Always use a non-blocking stream (never None/default stream) for TD 2025 compat.
+ if self.cuda is None:
+ self.cuda = get_cuda_runtime(device=self.device)
+ if self.ipc_stream is None:
+ # Honour CUDALINK_TD_STREAM_PRIO in the pre-init lazy path too (mirror of init).
+ if self._stream_high_prio:
+ self.ipc_stream = self.cuda.create_stream_with_priority(flags=0x01)
+ self._log(
+ f"Created IPC stream (pre-init, high-priority): 0x{int(self.ipc_stream.value):016x}",
+ force=True,
+ )
+ else:
+ self.ipc_stream = self.cuda.create_stream(flags=0x01)
+ self._log(
+ f"Created IPC stream (pre-init, normal-priority): 0x{int(self.ipc_stream.value):016x}",
+ force=True,
+ )
+
+ # Block transfer when the source pixel format is unsupported by cudaMemory().
+ # Probe (verification/results/cuda_memory_probe_20260510_090919.json) confirmed
+ # 6 formats fail: all 4 float16 variants (hard exception), 10:10:10:2 (hard
+ # exception), 11:11:10 (succeeds but returns raw uint8/4ch β silent corruption).
+ # Tint the COMP yellow every bad frame (idempotent; keeps tint alive); log once.
+ if self._is_unsupported_format(top_op):
+ src_fmt = (
+ top_op.pixel_format if hasattr(top_op, "pixel_format") else getattr(top_op, "pixelFormat", "?")
+ )
+ warn_msg = f"unsupported pixel format {src_fmt!r}"
+ self._host.set_warning_status(warn_msg)
+ if not self._warned_format:
+ self._log(
+ f"Pixel format {src_fmt!r} unsupported by cudaMemory() β "
+ "transfer suspended; component tinted yellow",
+ force=True,
+ )
+ self._warned_format = True
+ return False
+ if self._warned_format:
+ self._host.clear_status()
+ self._log("Source pixel format now supported β transfer resumed", force=True)
+ self._warned_format = False
+
+ # Time cudaMemory() call (OpenGLβCUDA interop)
+ if self.verbose_performance:
+ if self._export_profile:
+ _this_pre = (time.perf_counter() - _t_pre) * 1_000_000
+ self.total_pre_interop_us += _this_pre
+ cuda_mem_start = time.perf_counter()
+
+ # Get TOP's CUDA memory β always pass a valid stream (never None)
+ try:
+ cuda_mem = top_op.cuda_memory(stream=int(self.ipc_stream.value))
+ except Exception as cuda_err:
+ pixel_fmt = (
+ top_op.pixel_format
+ if hasattr(top_op, "pixel_format")
+ else getattr(top_op, "pixelFormat", "unknown")
+ )
+ err_msg = f"cudaMemory() failed (pixelFormat={pixel_fmt}): {cuda_err}"
+ if err_msg != self._last_cuda_mem_err:
+ self._log(err_msg, force=True)
+ self._last_cuda_mem_err = err_msg
+ return False
+
+ if self.verbose_performance:
+ cuda_mem_time = (time.perf_counter() - cuda_mem_start) * 1_000_000 # microseconds
+ self.total_cuda_memory_time += cuda_mem_time
+ if self._export_profile:
+ _t_post = time.perf_counter()
+
+ # Reset error suppression on success
+ if self._last_cuda_mem_err:
+ self._log("cudaMemory() recovered.", force=True)
+ self._last_cuda_mem_err = ""
+
+ if cuda_mem is None:
+ self._log(f"Failed to get CUDA memory from {top_op}", force=True)
+ return False
+
+ # CRITICAL: Keep reference to prevent garbage collection
+ self.cuda_mem_ref = cuda_mem
+
+ # CUDAMemoryRef fields are plain Python ints β direct access, no shape indirection
+ actual_width = cuda_mem.width
+ actual_height = cuda_mem.height
+ actual_channels = cuda_mem.channels
+ actual_size = cuda_mem.size
+ self._detected_numpy_dtype = cuda_mem.data_type # numpy.dtype or None
+ _dt = self._detected_numpy_dtype
+ _dtype_str = (
+ getattr(_dt, "name", None) or getattr(_dt, "__name__", str(_dt)) if _dt is not None else "unknown"
+ )
+ self._host.set_info_status(f"{actual_width}x{actual_height} {_dtype_str} {actual_channels}ch")
+
+ # Check if we need to (re)initialize
+ if not self._initialized or actual_size != self.data_size:
+ if self._initialized:
+ self._log(
+ f"Resolution changed: {self.width}x{self.height}x{self.channels} -> {actual_width}x{actual_height}x{actual_channels}",
+ force=True,
+ )
+ # Queue old resources for deferred free (cudaFree blocks on IPC memory)
+ self._pending_free_ptrs.extend([p for p in self.dev_ptrs if p])
+ self._pending_free_events.extend([e for e in self.ipc_events if e])
+ self.dev_ptrs = [None] * self.num_slots
+ self.ipc_events = [None] * self.num_slots
+ self.ipc_handles = [None] * self.num_slots
+ self.ipc_event_handles = [None] * self.num_slots
+ self._initialized = False
+ # Schedule deferred free after 30 frames (receiver needs time to close handles)
+ self._deferred_free_at_frame = self.frame_count + 30
+
+ if not self.initialize(actual_width, actual_height, actual_channels, actual_size):
+ return False
+ # Metadata already written by initialize()
+
+ elif (
+ actual_width != self.width
+ or actual_height != self.height
+ or actual_channels != self.channels
+ or self._has_dtype_changed()
+ ):
+ # Metadata-only update: buffer size unchanged so GPU handles stay valid.
+ # Rewrite the 20-byte metadata region and bump version to signal consumers.
+ self.width = actual_width
+ self.height = actual_height
+ self.channels = actual_channels
+ self._write_metadata_to_shm()
+ self._bump_version()
+ self._log(
+ "Metadata changed (dtype/dimensions) without size change β updated in-place",
+ force=True,
+ )
+
+ # Calculate current slot for ring buffer rotation
+ slot = self.write_idx % self.num_slots
+
+ # Time cudaMemcpyAsync D2D (non-blocking) - only if verbose
+ if self.verbose_performance:
+ if self._export_profile:
+ _this_post = (time.perf_counter() - _t_post) * 1_000_000
+ self.total_post_interop_us += _this_post
+ memcpy_start = time.perf_counter()
+ # Record GPU timing start event (actual GPU time measurement)
+ if self._timing_start:
+ self.cuda.record_event(self._timing_start, stream=self.ipc_stream)
+
+ # Deferred graph build (CUDALINK_TD_GRAPHS_DEFERRED=1): fires once after 30
+ # steady-state frames so the capture burst doesn't overlap Sender-B's cold activation.
+ if self._graphs_pending and self.frame_count >= 30:
+ self._build_export_graphs()
+ self._graphs_pending = False
+
+ # Copy TOP texture to this slot's persistent buffer (async on IPC stream).
+ # When CUDALINK_TD_USE_GRAPHS=1 and the per-slot graph exec is built, replay
+ # a 1-node CUDA Graph (MemcpyNode) instead of the imperative memcpy_async β
+ # this collapses the kernel-mode submission into a single cudaGraphLaunch.
+ # Falls back automatically (and permanently for this instance) if launch fails.
+ if self._use_graphs and not self._graphs_disabled and self._graph_execs[slot] is not None:
+ try:
+ self.cuda.graph_exec_memcpy_node_set_params_1d(
+ self._graph_execs[slot],
+ self._graph_memcpy_nodes[slot],
+ dst=self.dev_ptrs[slot],
+ src=c_void_p(cuda_mem.ptr),
+ count=self.data_size,
+ kind=3, # cudaMemcpyDeviceToDevice
+ )
+ self.cuda.graph_launch(self._graph_execs[slot], self.ipc_stream)
+ except (RuntimeError, OSError) as _graph_err:
+ self._log(
+ f"Graph launch failed ({_graph_err}) β disabling graphs, "
+ "falling back to legacy memcpy_async this frame",
+ force=True,
+ )
+ self._graphs_disabled = True
+ self.cuda.memcpy_async(
+ dst=self.dev_ptrs[slot],
+ src=c_void_p(cuda_mem.ptr),
+ count=self.data_size,
+ kind=3,
+ stream=self.ipc_stream,
+ )
+ else:
+ with _nvtx_verbose("cudalink.sender.memcpy", "green"):
+ self.cuda.memcpy_async(
+ dst=self.dev_ptrs[slot],
+ src=c_void_p(cuda_mem.ptr),
+ count=self.data_size,
+ kind=3, # cudaMemcpyDeviceToDevice
+ stream=self.ipc_stream,
+ )
+
+ if self.verbose_performance:
+ # Record GPU timing end event (actual GPU time measurement)
+ if self._timing_end:
+ self.cuda.record_event(self._timing_end, stream=self.ipc_stream)
+ memcpy_time = (
+ time.perf_counter() - memcpy_start
+ ) * 1_000_000 # microseconds (enqueue time only, copy is async)
+ self.total_memcpy_time += memcpy_time
+
+ # GPU-side synchronization with CUDA IPC Events
+ if self.ipc_events[slot]:
+ if self.verbose_performance:
+ record_start = time.perf_counter()
+
+ # Record event for this slot after async memcpy (stream-ordered)
+ with _nvtx_verbose("cudalink.sender.record_event", "green"):
+ self.cuda.record_event(self.ipc_events[slot], stream=self.ipc_stream)
+
+ if self.verbose_performance:
+ record_event_time = (time.perf_counter() - record_start) * 1_000_000
+ self.total_record_event_time += record_event_time
+
+ # CUDALINK_EXPORT_SYNC=1: CPU-blocks on ipc_stream after record_event.
+ # Default is now "0" (receiver cudaStreamWaitEvent guarantees correctness).
+ # Enable for regression testing or if downstream consumers rely on CPU-timing.
+ if self._export_sync:
+ if self.verbose_performance and self._export_profile:
+ _t_sync = time.perf_counter()
+ self.cuda.stream_synchronize(self.ipc_stream)
+ if self.verbose_performance and self._export_profile:
+ _this_sync = (time.perf_counter() - _t_sync) * 1_000_000
+ self.total_sync_us += _this_sync
+
+ if self.verbose_performance and self._export_profile:
+ _t_sticky = time.perf_counter()
+ self.cuda.check_sticky_error("export_frame")
+ if self.verbose_performance and self._export_profile:
+ _this_sticky = (time.perf_counter() - _t_sticky) * 1_000_000
+ self.total_sticky_check_us += _this_sticky
+
+ # WDDM deferred-submission probe: forces pending GPU work to submit without
+ # blocking. Per CUDA Handbook p3/pg56, WDDM buffers commands until a flush;
+ # cudaStreamQuery triggers that flush. Only active when EXPORT_FLUSH_PROBE=1
+ # and EXPORT_SYNC=0 (if sync is on, the stream is already flushed above).
+ if self._export_flush_probe and not self._export_sync:
+ if self.verbose_performance and self._export_profile:
+ _t_fp = time.perf_counter()
+ self.cuda.stream_query(self.ipc_stream)
+ if self.verbose_performance and self._export_profile:
+ _this_fp = (time.perf_counter() - _t_fp) * 1_000_000
+ self.total_flush_probe_us += _this_fp
+
+ else:
+ # FALLBACK: Conditional CPU synchronization
+ if self.frame_count % self.sync_interval == 0:
+ self.cuda.synchronize()
+
+ # Publish: timestamp + clear shutdown_flag + fence + write_idx β in that order.
+ # publish_frame() encodes the C3 ordering guarantee; do not replicate inline.
+ if self.verbose_performance and self._export_profile:
+ _t_shm = time.perf_counter()
+ self.write_idx += 1
+ publish_frame(self.shm_handle.buf, self._layout, self.write_idx, time.perf_counter())
+ if self.verbose_performance and self._export_profile:
+ _this_shm = (time.perf_counter() - _t_shm) * 1_000_000
+ self.total_shm_publish_us += _this_shm
+
+ # Frame tracking
+ self.frame_count += 1
+
+ # Barrier settle countdown: release the cross-process activation barrier
+ # after settle_frames successful exports have elapsed post-init.
+ self._barrier.tick_and_maybe_release(os.getpid(), log_fn=self._log)
+
+ # Calculate total frame time (only if verbose)
+ if self.verbose_performance:
+ frame_time = (time.perf_counter() - frame_start) * 1_000_000
+ self.total_export_time += frame_time
+ if self._export_profile:
+ _this_accounted = (
+ _this_pre
+ + cuda_mem_time
+ + _this_post
+ + memcpy_time
+ + record_event_time
+ + _this_sync
+ + _this_sticky
+ + _this_fp
+ + _this_shm
+ )
+ self.total_unaccounted_us += frame_time - _this_accounted
+
+ # Detailed first-frame diagnostic (one-time, not affected by 100-frame interval)
+ if self.verbose_performance and self.frame_count == 1:
+ self._log(
+ f"FIRST FRAME: cudaMemory={cuda_mem_time:.1f}us, "
+ f"memcpy={memcpy_time:.1f}us, total={frame_time:.1f}us, "
+ f"res={actual_width}x{actual_height}, size={actual_size / (1024 * 1024):.1f}MB",
+ force=True,
+ )
+
+ # Log performance metrics every 97 frames (prime β avoids aliasing with slot counts 2,4,5)
+ if self.verbose_performance and self.frame_count % 97 == 0:
+ avg_memcpy = self.total_memcpy_time / self.frame_count
+ avg_record = self.total_record_event_time / self.frame_count if all(self.ipc_events) else 0
+ avg_total = self.total_export_time / self.frame_count
+ avg_cuda_mem = self.total_cuda_memory_time / self.frame_count
+ sync_mode = (
+ f"GPU-Events[{self.num_slots}]" if all(self.ipc_events) else f"CPU-Sync(1/{self.sync_interval})"
+ )
+
+ graphs_label = "ON" if self._use_graphs and not self._graphs_disabled else "OFF"
+ log_msg = (
+ f"Frame {self.frame_count}: slot {slot}, "
+ f"avg cudaMemory={avg_cuda_mem:.1f}us, "
+ f"avg memcpy={avg_memcpy:.1f}us, record={avg_record:.1f}us, "
+ f"total={avg_total:.1f}us, mode={sync_mode}, graphs={graphs_label}"
+ )
+
+ # Add GPU elapsed time if timing events available
+ if self._timing_start and self._timing_end:
+ try:
+ # Wait for timing events to complete before reading (prevents error 600)
+ self.cuda.wait_event(self._timing_end)
+ gpu_memcpy_ms = self.cuda.event_elapsed_time(self._timing_start, self._timing_end)
+ log_msg += f", GPU memcpy={gpu_memcpy_ms * 1000:.1f}us (actual GPU time)"
+ except RuntimeError as e:
+ # Rare: event wait/query failed
+ log_msg += f", GPU timing: {e}"
+
+ if self._nvml_observer is not None:
+ snap = self._nvml_observer.snapshot()
+ if snap.get("nvml_available"):
+ log_msg += (
+ f" | [NVML] gpu={snap.get('gpu_util_pct', '?')}%"
+ f" mem={snap.get('mem_bw_util_pct', '?')}%"
+ f" sm={snap.get('sm_clock_mhz', '?')}MHz"
+ f" pcie_tx={snap.get('pcie_tx_kbps', '?')}kbps"
+ f" pcie_rx={snap.get('pcie_rx_kbps', '?')}kbps"
+ f" temp={snap.get('temp_c', '?')}C"
+ f" power={snap.get('power_w', '?')}W"
+ )
+ reasons = snap.get("throttle_reasons") or []
+ if reasons:
+ log_msg += f" throttle={','.join(reasons)}"
+
+ if self._export_profile:
+ avg_pre = self.total_pre_interop_us / self.frame_count
+ avg_post = self.total_post_interop_us / self.frame_count
+ avg_sync = self.total_sync_us / self.frame_count
+ avg_sticky = self.total_sticky_check_us / self.frame_count
+ avg_fp = self.total_flush_probe_us / self.frame_count
+ avg_shm = self.total_shm_publish_us / self.frame_count
+ avg_unacc = self.total_unaccounted_us / self.frame_count
+ log_msg += (
+ f" | [PROFILE] pre={avg_pre:.1f}us"
+ f" interop={avg_cuda_mem:.1f}us"
+ f" post={avg_post:.1f}us"
+ f" memcpy={avg_memcpy:.1f}us"
+ f" record={avg_record:.1f}us"
+ f" sync={avg_sync:.1f}us"
+ f" sticky={avg_sticky:.1f}us"
+ f" flush_probe={avg_fp:.1f}us"
+ f" shm={avg_shm:.1f}us"
+ f" unacc={avg_unacc:.1f}us"
+ )
+
+ self._log(log_msg, force=False)
+
+ return True
+
+ except (OSError, RuntimeError, AttributeError) as e:
+ self._log(f"Export failed: {e}", force=True)
+
+ traceback.print_exc()
+ return False
+ finally:
+ _nvtx_pop()
+
+ def _check_deferred_cleanup(self) -> None:
+ """Execute deferred GPU cleanup if scheduled and enough frames have passed.
+
+ Lightweight check meant to be called from onFrameStart for minimal overhead.
+ """
+ if self._pending_free_ptrs and self.frame_count >= self._deferred_free_at_frame:
+ self._deferred_free()
+
+ def _deferred_free(self) -> None:
+ """Free GPU resources queued from export_frame() when deferred frame threshold is reached.
+
+ Called via _check_deferred_cleanup() after receiver has had time to close IPC handles.
+ """
+ if self.cuda is None:
+ return
+
+ freed_count = 0
+ for ptr in self._pending_free_ptrs:
+ try:
+ self.cuda.free(ptr)
+ freed_count += 1
+ except (RuntimeError, OSError) as e:
+ self._log(f"Deferred free failed: {e}")
+ self._pending_free_ptrs.clear()
+
+ for event in self._pending_free_events:
+ try:
+ self.cuda.destroy_event(event)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Deferred event destroy failed: {e}")
+ self._pending_free_events.clear()
+
+ if freed_count > 0:
+ self._log(
+ f"Deferred cleanup complete: freed {freed_count} GPU buffers",
+ force=True,
+ )
+
+ def _is_cuda_context_valid(self) -> bool:
+ """Check if CUDA context is still valid (TD may destroy it before __delTD__)."""
+ if self.cuda is None:
+ return False
+ try:
+ self.cuda.cudart.cudaGetLastError()
+ return True
+ except (OSError, RuntimeError):
+ return False
+
+ def cleanup(self) -> None:
+ """Cleanup Sender CUDA IPC resources (all ring buffer slots).
+
+ CRITICAL ORDER: Signal shutdown FIRST, then free GPU resources.
+ cudaFree() blocks until all processes close IPC handles.
+ """
+ # Release activation barrier if still held (mid-settle cleanup path).
+ self._barrier.force_release(os.getpid(), log_fn=self._log)
+ self._barrier.close()
+
+ # Skip if already cleaned up (prevents double-cleanup from Active toggle + __delTD__)
+ if not self._initialized and self.shm_handle is None:
+ return
+
+ if self._nvml_observer is not None:
+ self._nvml_observer.stop()
+ self._nvml_observer = None
+
+ cuda_valid = self._is_cuda_context_valid()
+ if not cuda_valid:
+ self._log("CUDA context already destroyed β skipping GPU cleanup", force=True)
+
+ # Signal shutdown to consumer (before closing SharedMemory)
+ if self.shm_handle and self.shm_handle.buf is not None:
+ try:
+ shutdown_offset = SHM_HEADER_SIZE + (self.num_slots * SLOT_SIZE)
+ self.shm_handle.buf[shutdown_offset] = 1
+ self._log("Shutdown signal sent to consumer", force=True)
+ except (OSError, BufferError) as e:
+ self._log(f"Warning: Could not write shutdown signal: {e}", force=True)
+
+ # Zero out IPC handle bytes so any reader sees invalid handles.
+ # On Windows, unlink() is a no-op (SharedMemory uses CreateFileMapping kernel
+ # objects), so the SharedMemory may persist with stale non-zero handles that
+ # pass the all-zero validation check. Zeroing them prevents error 201 when a
+ # new Receiver reads before the SHM is destroyed or overwritten by a new producer.
+ if self.shm_handle and self.shm_handle.buf is not None:
+ try:
+ for slot in range(self.num_slots):
+ base_offset = SHM_HEADER_SIZE + (slot * SLOT_SIZE)
+ self.shm_handle.buf[base_offset : base_offset + SLOT_SIZE] = b"\x00" * SLOT_SIZE
+ self._log("Zeroed IPC handle bytes in SharedMemory", force=True)
+ except (OSError, BufferError) as e:
+ self._log(f"Warning: Could not zero IPC handles: {e}", force=True)
+
+ # Destroy CUDA Graph execs first β they hold references into the IPC stream
+ # and (transitively) the ring-buffer pointers, so they must be torn down before
+ # the events/stream/buffers below.
+ if cuda_valid and getattr(self, "_use_graphs", False):
+ self._destroy_export_graphs()
+
+ # Destroy IPC events (sender-side resources, safe to destroy)
+ if cuda_valid and hasattr(self, "ipc_events") and self.cuda:
+ for slot, event in enumerate(self.ipc_events):
+ if event:
+ try:
+ self.cuda.destroy_event(event)
+ self._log(f"Destroyed IPC event slot {slot}", force=True)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error destroying event slot {slot}: {e}", force=True)
+
+ # Destroy GPU timing events (benchmarking resources)
+ if cuda_valid and self.cuda:
+ if hasattr(self, "_timing_start") and self._timing_start:
+ try:
+ self.cuda.destroy_event(self._timing_start)
+ self._log("Destroyed GPU timing start event", force=False)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error destroying timing start event: {e}", force=True)
+ finally:
+ self._timing_start = None
+ if hasattr(self, "_timing_end") and self._timing_end:
+ try:
+ self.cuda.destroy_event(self._timing_end)
+ self._log("Destroyed GPU timing end event", force=False)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error destroying timing end event: {e}", force=True)
+ finally:
+ self._timing_end = None
+
+ # Destroy dedicated IPC stream (set to None to prevent double-free).
+ # CUDALINK_TD_PERSIST_STREAM=1: skip destroy so the stream survives
+ # deactivate/reactivate cycles; initialize() reuses it via the existing
+ # `if self.ipc_stream is None` guard.
+ if cuda_valid and hasattr(self, "ipc_stream") and self.ipc_stream and self.cuda:
+ if self._persist_stream:
+ self._log(
+ f"[PERSIST_STREAM] keeping ipc_stream=0x{int(self.ipc_stream.value):016x} across cleanup",
+ force=True,
+ )
+ else:
+ try:
+ self.cuda.destroy_stream(self.ipc_stream)
+ self._log("Destroyed IPC stream", force=True)
+ self.ipc_stream = None
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error destroying IPC stream: {e}", force=True)
+
+ # Close SharedMemory (but don't unlink yet)
+ if self.shm_handle:
+ try:
+ self.shm_handle.close()
+ self._log("Closed SharedMemory", force=True)
+ except (OSError, BufferError) as e:
+ self._log(f"Error closing SharedMemory: {e}", force=True)
+
+ # Grace period for receiver to close IPC handles
+ if cuda_valid:
+ time.sleep(0.1) # 100ms for receiver to detect shutdown and close handles
+
+ # Free GPU buffers (now safe, receiver has closed IPC handles)
+ if cuda_valid and hasattr(self, "dev_ptrs") and self.cuda:
+ for slot, dev_ptr in enumerate(self.dev_ptrs):
+ if dev_ptr:
+ try:
+ self.cuda.free(dev_ptr)
+ self._log(f"Freed GPU buffer slot {slot}", force=True)
+ except (RuntimeError, OSError) as e:
+ self._log(f"Error freeing GPU buffer slot {slot}: {e}", force=True)
+
+ # Free any pending deferred resources
+ if cuda_valid and hasattr(self, "_pending_free_ptrs"):
+ self._deferred_free()
+
+ if self._warned_format:
+ self._host.clear_status()
+ self._warned_format = False
+
+ # Unlink SharedMemory (sender is owner and should clean up)
+ if hasattr(self, "shm_name"):
+ try:
+ try:
+ shm_temp = SharedMemory(name=self.shm_name)
+ shm_temp.close()
+ shm_temp.unlink()
+ self._log("Unlinked SharedMemory", force=True)
+ except FileNotFoundError:
+ pass # Already unlinked
+ except (OSError, RuntimeError, AttributeError) as e:
+ self._log(f"Warning: Could not unlink SharedMemory: {e}", force=True)
+
+ # Reset state to prevent double-free on re-entry.
+ # Use empty lists β initialize() will resize to current self.num_slots.
+ self.dev_ptrs = []
+ self.ipc_events = []
+ self.ipc_handles = []
+ self.ipc_event_handles = []
+ if not self._persist_stream:
+ self.ipc_stream = None
+ self.shm_handle = None
+ self._warned_format = False
+ self._export_buffer = None
+
+ # Reset per-session counters so averages are accurate after reinit
+ # and slot selection starts from 0 (matching SharedMemory write_idx=0 written on init).
+ self.write_idx = 0
+ self.frame_count = 0
+ self.total_memcpy_time = 0.0
+ self.total_record_event_time = 0.0
+ self.total_export_time = 0.0
+ self.total_cuda_memory_time = 0.0
+
+ self._initialized = False
+ self._log("Sender cleanup complete", force=True)
diff --git a/src/streamdiffusion/_compat/td_exporter/VENDORED_VERSION.txt b/src/streamdiffusion/_compat/td_exporter/VENDORED_VERSION.txt
new file mode 100644
index 000000000..51c527747
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/VENDORED_VERSION.txt
@@ -0,0 +1,12 @@
+version: 1.4.1
+source: F:\RD_PROJECTS\COMPONENTS\cuda-link\td_exporter\
+head_commit: 92989fc
+vendored: 2026-05-17
+
+NOTE: These files use flat-namespace imports (from SHMProtocol import ...) designed
+for TouchDesigner's Python environment where each Text DAT in a component is an
+importable module by its DAT name. They cannot be imported as a regular Python
+package β there is intentionally no __init__.py here.
+
+To deploy into a .tox component: copy each file into a Text DAT whose name matches
+the filename (without .py), e.g. CUDAIPCExtension.py β Text DAT named CUDAIPCExtension.
diff --git a/src/streamdiffusion/_compat/td_exporter/benchmark_timestamp.py b/src/streamdiffusion/_compat/td_exporter/benchmark_timestamp.py
new file mode 100644
index 000000000..f9fae45b4
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/benchmark_timestamp.py
@@ -0,0 +1,74 @@
+"""
+TouchDesigner Execute DAT script for shared timestamp channel.
+
+This script creates a Python SharedMemory segment that both CUDA IPC and
+Shared Mem Out TOP benchmarks can read to measure end-to-end latency.
+
+Setup in TouchDesigner:
+1. Create Execute DAT
+2. Paste this script into the DAT
+3. Set to: DAT Execute β Active = ON
+4. Callbacks: onFrameEnd = ON (all others OFF)
+
+The timestamp channel writes:
+- frame_counter (uint32) - increments each frame
+- timestamp (float64) - time.perf_counter() when frame ends
+
+Both benchmark scripts read this to compute end-to-end latency:
+ consumer_time - producer_time = latency
+"""
+
+import contextlib
+import struct
+import time
+from multiprocessing.shared_memory import SharedMemory
+
+
+# Global state (persists across frame callbacks)
+shm = None
+frame_counter = 0
+
+# SharedMemory name must match benchmark script --timestamp-shm argument
+TIMESTAMP_SHM_NAME = "cuda_ipc_benchmark_ts"
+
+
+def onFrameEnd(frame: int) -> None:
+ """Called after all operators finish cooking each frame.
+
+ This is the ideal timing point for producer timestamps since all
+ TOPs (including Shared Mem Out TOP and CUDA IPC sender) have already
+ written their data.
+ """
+ global shm, frame_counter
+
+ # Create SharedMemory on first frame
+ if shm is None:
+ try:
+ # Try opening existing segment first (consumer may have created it)
+ shm = SharedMemory(name=TIMESTAMP_SHM_NAME)
+ except FileNotFoundError:
+ # Create new segment if it doesn't exist
+ # Size: 4 bytes (uint32) + 8 bytes (float64) = 12 bytes (use 16 for alignment)
+ shm = SharedMemory(name=TIMESTAMP_SHM_NAME, create=True, size=16)
+
+ # Increment frame counter
+ frame_counter += 1
+
+ # Write timestamp: frame_counter (uint32) + timestamp (float64)
+ # Use perf_counter() for high-resolution timing
+ timestamp = time.perf_counter()
+ struct.pack_into(" None:
+ """Called when TD closes or DAT is deleted.
+
+ Cleanup SharedMemory to avoid stale segments.
+ """
+ global shm
+ if shm is not None:
+ # Note: We don't unlink() because consumer may still be reading
+ # The OS will clean up when both processes close the segment
+ with contextlib.suppress(OSError, BufferError):
+ shm.close()
+ shm = None
diff --git a/src/streamdiffusion/_compat/td_exporter/callbacks_template.py b/src/streamdiffusion/_compat/td_exporter/callbacks_template.py
new file mode 100644
index 000000000..b0b68bcf0
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/callbacks_template.py
@@ -0,0 +1,81 @@
+"""
+Execute DAT Callback for CUDAIPCExtension
+
+Copy this into an Execute DAT inside your .tox component.
+Enable "Frame Start", "Frame End", and "On Exit" toggles.
+
+Architecture:
+- Sender: onFrameStart=housekeeping, onFrameEnd=export (avoids 8.8ms GPU wait)
+- Receiver: onFrameStart=force-cook ImportBuffer (triggers Script TOP onCook)
+"""
+
+
+def onFrameStart(frame: int) -> None:
+ """Called at the start of every frame.
+
+ Sender: Lightweight housekeeping (deferred GPU cleanup).
+ Receiver: Force-cook ImportBuffer (triggers Script TOP onCook).
+
+ Args:
+ frame: Current frame number
+ """
+ ext = parent().ext.CUDAIPCExtension
+ if ext is None:
+ return
+
+ if ext.mode == "Sender":
+ # Check if deferred GPU cleanup is scheduled (lightweight, ~0ms normally)
+ ext._check_deferred_cleanup()
+
+ elif ext.mode == "Receiver":
+ import_buffer = op("ImportBuffer")
+ if import_buffer is None:
+ return
+
+ # TD 2025+: modoutsidecook enables copyCUDAMemory from Execute DAT
+ # This eliminates force-cook overhead and fixes resolution delay
+ if hasattr(import_buffer.par, "modoutsidecook") and import_buffer.par.modoutsidecook.eval():
+ # Import frame first: initialize_receiver() sets resolution flag
+ ext.import_frame(import_buffer)
+ # Resolution update after: catches flag set during initialization
+ ext.update_receiver_resolution(import_buffer)
+ else:
+ # TD 2023 fallback: force-cook triggers Script TOP onCook
+ # Resolution update happens inside onCook (1-frame delay for changes)
+ import_buffer.cook(force=True)
+
+
+def onFrameEnd(frame: int) -> None:
+ """Called at the end of every frame.
+
+ Sender: Export frame AFTER cook phase (texture already rendered on GPU).
+ cudaMemory() returns instantly instead of blocking 8.8ms waiting for GPU.
+ Receiver: Nothing (import already happened via Script TOP onCook).
+
+ Args:
+ frame: Current frame number
+ """
+ ext = parent().ext.CUDAIPCExtension
+ if ext is None:
+ return
+
+ if ext.mode == "Sender":
+ ext.export_frame()
+
+
+def onExit() -> None:
+ """Called when TouchDesigner exits or when this DAT is destroyed."""
+ ext = parent().ext.CUDAIPCExtension
+ if ext is not None:
+ ext.cleanup()
+
+
+# Other callback stubs (not used for CUDA IPC, but required by TD)
+def onStart() -> None:
+ """TD required callback - not used."""
+ return
+
+
+def onCreate() -> None:
+ """TD required callback - not used."""
+ return
diff --git a/src/streamdiffusion/_compat/td_exporter/example_sender_launcher.py b/src/streamdiffusion/_compat/td_exporter/example_sender_launcher.py
new file mode 100644
index 000000000..ff00db750
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/example_sender_launcher.py
@@ -0,0 +1,117 @@
+"""
+Execute DAT β CUDA-Link Python β TouchDesigner Launcher
+
+Paste this into an Execute DAT in your example project.
+Enable "Start", "Frame Start", and "On Exit" toggles.
+
+This DAT spawns example_sender_python.py as a separate OS process on project
+start and terminates it on exit. CUDA IPC requires separate processes β sender
+and receiver cannot share GPU handles within the same process.
+
+Pipeline:
+ onStart() β subprocess.Popen(example_sender_python.py)
+ β CUDA IPC (cudalink_output_ipc)
+ CUDAIPCLink_from_Python (Receiver mode, same project)
+ β
+ Script TOP output β cycling solid colors
+
+TD Setup:
+ 1. Add CUDAIPCLink_from_Python component to the network
+ 2. Set Mode β Receiver
+ 3. Set Ipcmemname β cudalink_output_ipc
+ 4. Set Active β ON
+ 5. Paste THIS script into an Execute DAT β enable Start, Frame Start, On Exit
+ 6. Press Play (or reopen the project) to trigger onStart()
+"""
+
+import os
+import signal
+import subprocess
+
+
+_process = None # Sender subprocess handle
+
+
+def onStart() -> None:
+ """Launch the Python sender as a separate subprocess."""
+ global _process
+
+ script = os.path.join(project.folder, "td_exporter", "example_sender_python.py")
+
+ if not os.path.isfile(script):
+ print("[CUDA-Link Launcher] ERROR: sender script not found:")
+ print(f" {script}")
+ return
+
+ _process = subprocess.Popen(
+ ["python", script],
+ # CREATE_NEW_CONSOLE: opens a visible console window for the sender.
+ # CREATE_NEW_PROCESS_GROUP: required to send CTRL_BREAK_EVENT on shutdown
+ # (CTRL_C_EVENT is blocked for new process groups on Windows).
+ creationflags=subprocess.CREATE_NEW_CONSOLE | subprocess.CREATE_NEW_PROCESS_GROUP,
+ )
+ print(f"[CUDA-Link Launcher] Sender subprocess started (PID {_process.pid})")
+ print(f" Script: {script}")
+
+
+def onCreate() -> None:
+ return
+
+
+def onExit() -> None:
+ """Terminate the sender subprocess when the project closes."""
+ global _process
+
+ if _process is None:
+ return
+
+ if _process.poll() is None:
+ pid = _process.pid
+ try:
+ # CTRL_BREAK_EVENT gives the Python sender a chance to run its IPC cleanup
+ # (7-step GPU teardown). CTRL_C_EVENT cannot cross CREATE_NEW_PROCESS_GROUP
+ # boundaries on Windows; CTRL_BREAK_EVENT can.
+ _process.send_signal(signal.CTRL_BREAK_EVENT)
+ _process.wait(timeout=3)
+ print(f"[CUDA-Link Launcher] Sender subprocess exited gracefully (PID {pid}).")
+ except subprocess.TimeoutExpired:
+ _process.terminate()
+ try:
+ _process.wait(timeout=2)
+ print(f"[CUDA-Link Launcher] Sender subprocess terminated (PID {pid}).")
+ except subprocess.TimeoutExpired:
+ _process.kill()
+ print(f"[CUDA-Link Launcher] Sender subprocess force-killed (PID {pid}).")
+ except OSError:
+ _process.kill()
+ print(f"[CUDA-Link Launcher] Sender subprocess force-killed (PID {pid}).")
+
+ _process = None
+
+
+def onFrameStart(frame: int) -> None:
+ """Check if the subprocess is still running; warn if it exited unexpectedly."""
+ if _process is not None and _process.poll() is not None:
+ code = _process.returncode
+ if code != 0:
+ print(f"[CUDA-Link Launcher] WARNING: sender subprocess exited unexpectedly (code={code}).")
+
+
+def onFrameEnd(frame: int) -> None:
+ return
+
+
+def onPlayStateChange(state: bool) -> None:
+ return
+
+
+def onDeviceChange() -> None:
+ return
+
+
+def onProjectPreSave() -> None:
+ return
+
+
+def onProjectPostSave() -> None:
+ return
diff --git a/src/streamdiffusion/_compat/td_exporter/example_sender_python.py b/src/streamdiffusion/_compat/td_exporter/example_sender_python.py
new file mode 100644
index 000000000..a43330b21
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/example_sender_python.py
@@ -0,0 +1,370 @@
+"""
+CUDA-Link Example β Python Sender (subprocess target)
+
+Sends animated solid RGBA color frames to TouchDesigner via CUDA IPC.
+Run as a subprocess launched by example_sender_launcher.py (Execute DAT),
+or directly from the command line:
+
+ python td_exporter/example_sender_python.py
+
+Pipeline: this script (separate OS process)
+ β CUDA IPC (cudalink_output_ipc)
+ CUDAIPCLink_from_Python (Receiver mode, in TouchDesigner)
+ β
+ Script TOP output β cycling solid colors
+
+TD Setup (handled by example_sender_launcher.py Execute DAT):
+ CUDAIPCLink_from_Python β Mode=Receiver, Ipcmemname=cudalink_output_ipc, Active=ON
+"""
+
+from __future__ import annotations
+
+import contextlib
+import ctypes
+import logging
+import os
+import struct
+import sys
+import threading
+import time
+
+
+# When CUDALINK_EXPORT_PROFILE=1 the lib promotes self.debug=True and emits
+# [PROFILE] lines via logger.debug(). Configure the root logger so those
+# messages reach stdout (standard Python logging convention requires the host
+# application to set up handlers; the lib itself cannot do it).
+if os.environ.get("CUDALINK_EXPORT_PROFILE", "0") == "1":
+ logging.basicConfig(level=logging.DEBUG, format="[lib] %(message)s", stream=sys.stdout)
+
+_probe_log_file = os.environ.get("CUDALINK_PROBE_LOG_FILE", "")
+if _probe_log_file:
+ _root_logger = logging.getLogger()
+ if not any(isinstance(h, logging.FileHandler) for h in _root_logger.handlers):
+ _fh = logging.FileHandler(_probe_log_file, mode="w", encoding="utf-8")
+ _fh.setFormatter(logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"))
+ _root_logger.addHandler(_fh)
+ if _root_logger.level == logging.NOTSET:
+ _root_logger.setLevel(logging.INFO)
+
+# ---------------------------------------------------------------------------
+# Windows console control handler β ensures GPU IPC cleanup runs even when
+# the user closes the console window via the X button (CTRL_CLOSE_EVENT),
+# which does NOT raise KeyboardInterrupt in Python by default.
+# ---------------------------------------------------------------------------
+
+if sys.platform == "win32":
+ from ctypes import wintypes as _wintypes
+
+ CTRL_C_EVENT = 0
+ CTRL_BREAK_EVENT = 1
+ CTRL_CLOSE_EVENT = 2
+ CTRL_LOGOFF_EVENT = 5
+ CTRL_SHUTDOWN_EVENT = 6
+
+ _HandlerRoutine = ctypes.WINFUNCTYPE(_wintypes.BOOL, _wintypes.DWORD)
+
+# Module-level refs so the handler thread can access them regardless of stack.
+_cuda_ref = None
+_exporter_ref = None
+_staging_ptr_ref = None
+_cleaned_up = False
+# Track which event triggered shutdown β controls end-of-main "Press Enter" pause:
+# "ctrl_c" β user pressed Ctrl+C in console β pause (let user read messages).
+# "ctrl_break" β launcher sent CTRL_BREAK_EVENT (graceful .toe close) β no pause.
+# "ctrl_close" β console close / logoff / shutdown (incl. orchestrator-driven
+# taskkill, ncu/nsys captures) β no pause; OS is already exiting.
+# None β main loop exited some other way β pause as a safety net.
+_shutdown_via: str | None = None
+# Set by CTRL_CLOSE_EVENT handler to request the main loop to break and run
+# finally: cleanup from the main thread instead of the handler thread.
+# Avoids the race where the handler freed staging_ptr while main was mid-cudaMemcpy.
+_stop_requested: bool = False
+
+
+def _do_cleanup() -> None:
+ """Idempotent GPU IPC cleanup β safe to call from handler thread and from finally:."""
+ global _cleaned_up
+ if _cleaned_up:
+ return
+ _cleaned_up = True
+
+ # Under ncu kernel-replay the GPU command queue is paused inside ncu's replay
+ # state. cudaFree on the staging buffer implicitly synchronises the device and
+ # blocks until the queue drains β which never happens in that state, causing a
+ # 30+ s hang. Wrap in a daemon thread with a 0.5 s watchdog; same pattern as
+ # cuda_ipc_exporter.cleanup() Step 6. The 1 MB staging buffer is reclaimed by
+ # the OS on process exit, so leaking it here is harmless.
+ if _staging_ptr_ref is not None and _cuda_ref is not None:
+
+ def _free_staging() -> None:
+ try:
+ _cuda_ref.free(_staging_ptr_ref)
+ except Exception as exc:
+ print(f"[sender] cleanup: cuda.free(staging) error: {exc}", flush=True)
+
+ t = threading.Thread(target=_free_staging, daemon=True)
+ t.start()
+ t.join(timeout=0.5)
+ if t.is_alive():
+ print("[sender] cudaFree(staging) timed out β OS will reclaim on process exit", flush=True)
+
+ # Under ncu kernel-replay, Steps 1c/2/3 of exporter.cleanup()
+ # (graph_exec_destroy, destroy_event, destroy_stream) can block on a
+ # paused command queue. Bound total cleanup time so main returns and
+ # ncu finalizes. Same pattern as staging watchdog above and
+ # cuda_ipc_exporter.cleanup() Step 6.
+ if _exporter_ref is not None:
+
+ def _do_exporter_cleanup() -> None:
+ try:
+ _exporter_ref.cleanup()
+ except Exception as exc:
+ print(f"[sender] cleanup: exporter.cleanup error: {exc}", flush=True)
+
+ t = threading.Thread(target=_do_exporter_cleanup, daemon=True)
+ t.start()
+ t.join(timeout=3.0)
+ if t.is_alive():
+ print("[sender] exporter.cleanup() timed out β OS will reclaim resources on process exit", flush=True)
+
+
+if sys.platform == "win32":
+
+ def _ctrl_handler(ctrl_type: int) -> bool:
+ global _shutdown_via
+ if ctrl_type == CTRL_C_EVENT:
+ _shutdown_via = "ctrl_c"
+ print("\n[sender] Ctrl+C β stopping ...", flush=True)
+ return False # Chain to Python's default β raises KeyboardInterrupt in main.
+ if ctrl_type == CTRL_BREAK_EVENT:
+ _shutdown_via = "ctrl_break"
+ print("\n[sender] Ctrl+Break / launcher shutdown β stopping ...", flush=True)
+ return False # Chain to Python's default β raises KeyboardInterrupt in main.
+ if ctrl_type in (CTRL_CLOSE_EVENT, CTRL_LOGOFF_EVENT, CTRL_SHUTDOWN_EVENT):
+ # Console X-button, user logoff, or system shutdown.
+ # OS allows ~5 s (CLOSE) or ~20 s (LOGOFF/SHUTDOWN) before forced termination.
+ # Signal main to break out of the loop; main's finally: runs _do_cleanup()
+ # from the main thread. Calling _do_cleanup() here (handler thread) races
+ # with an in-flight cudaMemcpy in _fill_ctypes β INVALID_VALUE crash.
+ global _stop_requested
+ _shutdown_via = "ctrl_close"
+ _stop_requested = True
+ print(
+ f"\n[sender] Console control event {ctrl_type} (close/logoff/shutdown) β signaling main loop to stop ...",
+ flush=True,
+ )
+ return True # Handled β OS grace period covers main's exit + cleanup.
+ return False
+
+ # The launcher uses CREATE_NEW_PROCESS_GROUP, which DISABLES Ctrl+C delivery to the
+ # child process by default. SetConsoleCtrlHandler(NULL, FALSE) re-enables it before
+ # we install our own handler.
+ ctypes.windll.kernel32.SetConsoleCtrlHandler(None, False)
+
+ # MUST be module-level; a local variable would be GC'd and Windows would call freed memory.
+ _ctrl_handler_ref = _HandlerRoutine(_ctrl_handler)
+ if not ctypes.windll.kernel32.SetConsoleCtrlHandler(_ctrl_handler_ref, True):
+ print("[sender] WARNING: SetConsoleCtrlHandler failed β console-close cleanup unavailable")
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+SHM_NAME = "cudalink_output_ipc"
+WIDTH = 512
+HEIGHT = 512
+DTYPE = "uint8" # "uint8" or "float32"
+NUM_SLOTS = 3
+TARGET_FPS = 60.0
+FRAMES_PER_COLOR = 30 # Hold each solid color this many frames
+REPORT_EVERY = 150 # Print status every N frames
+
+
+# ---------------------------------------------------------------------------
+# Color cycle (RGBA uint8)
+# ---------------------------------------------------------------------------
+
+_COLORS = [
+ (255, 0, 0, 255), # Red
+ (0, 255, 0, 255), # Green
+ (0, 0, 255, 255), # Blue
+ (255, 255, 0, 255), # Yellow
+ (0, 255, 255, 255), # Cyan
+ (255, 0, 255, 255), # Magenta
+ (255, 255, 255, 255), # White
+ (64, 64, 64, 255), # Grey
+]
+_COLOR_NAMES = ["Red", "Green", "Blue", "Yellow", "Cyan", "Magenta", "White", "Grey"]
+
+
+# ---------------------------------------------------------------------------
+# GPU fill helpers
+# ---------------------------------------------------------------------------
+
+
+def _fill_ctypes(cuda: object, ptr: object, data_size: int, color: tuple) -> None:
+ """Write a solid RGBA color into a GPU buffer via H2D ctypes copy."""
+ r, g, b, a = color
+ if DTYPE == "uint8":
+ pixel = bytes([int(r), int(g), int(b), int(a)])
+ data = pixel * (data_size // 4)
+ buf = (ctypes.c_uint8 * data_size).from_buffer_copy(data)
+ else: # float32
+ pixel = struct.pack("<4f", r / 255.0, g / 255.0, b / 255.0, a / 255.0)
+ data = pixel * (data_size // 16)
+ buf = (ctypes.c_uint8 * data_size).from_buffer_copy(data)
+
+ cuda.memcpy(
+ dst=ptr,
+ src=ctypes.c_void_p(ctypes.addressof(buf)),
+ count=data_size,
+ kind=1, # cudaMemcpyHostToDevice
+ )
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main() -> None:
+ global _cuda_ref, _exporter_ref, _staging_ptr_ref
+ # Ensure cuda_link is importable β try src/ relative to this script
+ try:
+ from cuda_link import CUDAIPCExporter
+ from cuda_link.cuda_ipc_wrapper import get_cuda_runtime
+ except ImportError:
+ src_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "src")
+ src_dir = os.path.normpath(src_dir)
+ if src_dir not in sys.path:
+ sys.path.insert(0, src_dir)
+ try:
+ from cuda_link import CUDAIPCExporter
+ from cuda_link.cuda_ipc_wrapper import get_cuda_runtime
+ except ImportError:
+ print(f"[sender] ERROR: cuda_link not found. Searched: {src_dir}")
+ print("[sender] Run: pip install cuda-link (from the project root)")
+ sys.exit(1)
+
+ cuda = get_cuda_runtime()
+ _cuda_ref = cuda
+
+ print("=" * 58)
+ print(" CUDA-Link Example -- Python -> TouchDesigner Sender")
+ print("=" * 58)
+ print(f" channel : {SHM_NAME}")
+ print(f" resolution: {WIDTH}x{HEIGHT} RGBA {DTYPE}")
+ print(f" fps target: {TARGET_FPS}")
+ print()
+ print(" TD: CUDAIPCLink_from_Python Mode=Receiver Active=ON")
+ print()
+
+ exporter = CUDAIPCExporter(
+ shm_name=SHM_NAME,
+ height=HEIGHT,
+ width=WIDTH,
+ channels=4,
+ dtype=DTYPE,
+ num_slots=NUM_SLOTS,
+ debug=False,
+ )
+ _exporter_ref = exporter
+
+ if not exporter.initialize():
+ print("[sender] ERROR: exporter.initialize() failed.")
+ sys.exit(1)
+
+ graphs_active = bool(getattr(exporter, "_use_graphs", False) and not getattr(exporter, "_graphs_disabled", False))
+ graphs_label = "ON" if graphs_active else "OFF"
+ profile_on = os.environ.get("CUDALINK_EXPORT_PROFILE", "0") == "1"
+ env_setting = os.environ.get("CUDALINK_USE_GRAPHS", "(default=1)")
+ try:
+ rt_version = cuda.get_runtime_version()
+ rt_label = f"{rt_version // 1000}.{(rt_version % 1000) // 10}"
+ except Exception:
+ rt_version = 0
+ rt_label = "unknown"
+ print(f"[sender] cudart runtime: {rt_label} ({rt_version})")
+ print(f"[sender] CUDA Graphs path: {graphs_label} (CUDALINK_USE_GRAPHS={env_setting})")
+ if not graphs_active and env_setting in ("1", "(default=1)"):
+ print("[sender] (graphs requested but disabled β see exporter logs for reason)")
+ print("[sender] Initialized β waiting for TD receiver to connect ...\n")
+
+ staging_ptr = cuda.malloc(exporter.data_size)
+ _staging_ptr_ref = staging_ptr
+ frame_interval = 1.0 / TARGET_FPS
+ frame_count = 0
+ start_time = time.perf_counter()
+ last_report = start_time
+
+ try:
+ while not _stop_requested:
+ t0 = time.perf_counter()
+ color_idx = (frame_count // FRAMES_PER_COLOR) % len(_COLORS)
+ color = _COLORS[color_idx]
+
+ _fill_ctypes(cuda, staging_ptr, exporter.data_size, color)
+ exporter.export_frame(
+ gpu_ptr=int(staging_ptr.value),
+ size=exporter.data_size,
+ )
+ frame_count += 1
+
+ now = time.perf_counter()
+ if frame_count % REPORT_EVERY == 0 or (now - last_report) >= 5.0:
+ elapsed = now - start_time
+ fps = frame_count / elapsed if elapsed > 0 else 0.0
+ export_us = (now - t0) * 1e6
+ if profile_on:
+ stats = exporter.get_stats()
+ profile_suffix = (
+ f" | avg_total={stats.get('avg_total_us', 0.0):.1f} Β΅s"
+ f" | avg_memcpy={stats.get('avg_memcpy_us', 0.0):.1f} Β΅s"
+ )
+ else:
+ profile_suffix = ""
+ print(
+ f" Frame {frame_count:5d} | {fps:5.1f} FPS | "
+ f"color={_COLOR_NAMES[color_idx]:<8s} | "
+ f"export={export_us:.0f} Β΅s"
+ f"{profile_suffix} | "
+ f"graphs={graphs_label}"
+ )
+ last_report = now
+
+ remaining = frame_interval - (time.perf_counter() - t0)
+ if remaining > 0:
+ time.sleep(remaining)
+
+ except KeyboardInterrupt:
+ print(f"\n[sender] Stopped after {frame_count} frames.")
+
+ finally:
+ try:
+ final_stats = exporter.get_stats() if profile_on else {}
+ except Exception:
+ final_stats = {}
+ _do_cleanup()
+ total = time.perf_counter() - start_time
+ avg_fps = frame_count / total if total > 0 else 0.0
+ print(f"[sender] Done β {frame_count} frames in {total:.1f}s ({avg_fps:.1f} FPS avg)", flush=True)
+ if final_stats:
+ print(
+ f"[sender] Final stats: graphs={graphs_label} "
+ f"avg_total={final_stats.get('avg_total_us', 0.0):.1f} Β΅s "
+ f"avg_memcpy={final_stats.get('avg_memcpy_us', 0.0):.1f} Β΅s "
+ f"frames={final_stats.get('frame_count', 0)}",
+ flush=True,
+ )
+ print("[sender] TD Receiver will detect shutdown on next cook.", flush=True)
+
+ # Hold the console window open so the user can read the cleanup output β
+ # but ONLY for user-initiated shutdowns. CTRL_BREAK_EVENT is also how the
+ # launcher signals graceful .toe-close, so we skip the pause in that case.
+ if _shutdown_via not in ("ctrl_break", "ctrl_close"):
+ with contextlib.suppress(EOFError, KeyboardInterrupt):
+ input("\n[sender] Press Enter to close this window ...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/streamdiffusion/_compat/td_exporter/parexecute_callbacks.py b/src/streamdiffusion/_compat/td_exporter/parexecute_callbacks.py
new file mode 100644
index 000000000..21e8f09c7
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/parexecute_callbacks.py
@@ -0,0 +1,264 @@
+"""
+Parameter Execute DAT Callback for CUDAIPCExtension
+
+Copy this into a Parameter Execute DAT inside your .tox component.
+Enable the parameters you want to monitor (Active, Ipcmemname, Numslots, Debug, Hidebuiltin, Mode).
+
+Handles parameter changes with debug logging and triggers appropriate re-initialization.
+"""
+
+import contextlib
+
+
+def onValueChange(par: object, prev: object) -> None:
+ """Called when any monitored parameter changes.
+
+ Args:
+ par: The parameter that changed
+ prev: The previous value of the parameter
+ """
+ ext = parent().ext.CUDAIPCExtension
+
+ if ext is None:
+ return
+
+ param_name = par.name
+ new_value = par.eval()
+
+ # Log the parameter change if debug is enabled
+ ext._log(f"Parameter '{param_name}' changed: {prev} -> {new_value}", force=False)
+
+ # Handle each parameter
+ if param_name == "Active":
+ handle_active_change(ext, new_value, prev)
+
+ elif param_name == "Ipcmemname":
+ handle_ipcmemname_change(ext, new_value, prev)
+
+ elif param_name == "Numslots":
+ handle_numslots_change(ext, new_value, prev)
+
+ elif param_name == "Debug":
+ handle_debug_change(ext, new_value, prev)
+
+ elif param_name == "Hidebuiltin":
+ handle_hidebuiltin_change(ext, new_value, prev)
+
+ elif param_name == "Mode":
+ handle_mode_change(ext, new_value, prev)
+
+
+def handle_active_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Active parameter toggle.
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New Active state (bool or int)
+ prev: Previous Active state
+ """
+ # Convert to bool
+ new_value = bool(new_value)
+
+ if new_value:
+ ext._log("Component activated", force=True)
+ # Re-initialize based on current mode
+ if ext.mode == "Sender":
+ # Sender initialization happens on first export_frame() call
+ ext._log("Sender mode ready - will initialize on first frame export", force=False)
+ elif ext.mode == "Receiver":
+ # Trigger receiver initialization attempt
+ ext._log("Receiver mode activated - will attempt connection", force=False)
+ else:
+ ext._log("Component deactivated - cleaning up", force=True)
+ # Clean up current mode resources
+ ext.cleanup()
+ ext._host.clear_status()
+
+ # Disable Numslots while active to prevent runtime array size mismatch.
+ # Receiver mode always keeps Numslots disabled (sender controls slot count).
+ # Sender mode: editable only when inactive.
+ with contextlib.suppress(AttributeError):
+ parent().par.Numslots.enable = not new_value and ext.mode == "Sender"
+
+
+def handle_ipcmemname_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Ipcmemname parameter change.
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New IPC memory name (str)
+ prev: Previous IPC memory name
+ """
+ # Convert to string
+ new_value = str(new_value)
+ prev = str(prev) if prev is not None else ""
+
+ if new_value == prev:
+ return
+
+ ext._log("IPC memory name changed - reinitializing", force=True)
+
+ # Clean up existing connection
+ ext.cleanup()
+
+ # Update internal state
+ ext.shm_name = new_value
+
+ # Re-initialize based on mode
+ if ext.mode == "Sender":
+ ext._log("Sender will reinitialize on next frame export", force=False)
+ elif ext.mode == "Receiver":
+ # Force immediate reconnection on next frame
+ ext.request_immediate_reconnect()
+ ext._log("Receiver will attempt reconnection on next frame", force=False)
+
+
+def handle_numslots_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Numslots parameter change.
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New number of ring buffer slots (int or str)
+ prev: Previous number of slots
+ """
+ # Convert to int if string
+ new_value = int(new_value)
+ prev = int(prev) if prev is not None else 0
+
+ if new_value == prev:
+ return
+
+ # Receiver ignores manual Numslots changes β slot count comes from sender via SharedMemory.
+ # The parameter is disabled in the UI when in Receiver mode, but this guard handles
+ # any edge case where the callback fires anyway.
+ if ext.mode == "Receiver":
+ ext._log("Numslots change ignored in Receiver mode (controlled by sender)", force=True)
+ return
+
+ # Skip if component is active β Numslots should be disabled in UI, but guard
+ # against script-based changes which bypass the UI parameter enable state.
+ if ext.is_active():
+ ext._log("Numslots change ignored while Active (deactivate first)", force=True)
+ return
+
+ # Validate slot count (2-5 slots supported)
+ if new_value < 2 or new_value > 5:
+ ext._log(f"WARNING: Numslots={new_value} outside recommended range (2-5)", force=True)
+
+ ext._log("Ring buffer slot count changed - reinitializing", force=True)
+
+ # Clean up existing buffers
+ ext.cleanup()
+
+ # Update internal state
+ ext.num_slots = new_value
+
+ # Re-initialize based on mode
+ if ext.mode == "Sender":
+ ext._log("Sender will recreate ring buffer on next frame export", force=False)
+ elif ext.mode == "Receiver":
+ # Force immediate reconnection on next frame
+ ext.request_immediate_reconnect()
+ ext._log("Receiver will reconnect with new slot count on next frame", force=False)
+
+
+def handle_debug_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Debug parameter toggle.
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New debug state (bool or int)
+ prev: Previous debug state
+ """
+ # Convert to bool
+ new_value = bool(new_value)
+
+ ext.verbose_performance = new_value
+
+ if new_value:
+ ext._log("Debug logging ENABLED", force=True)
+ else:
+ ext._log("Debug logging DISABLED", force=True)
+
+
+def handle_hidebuiltin_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Hidebuiltin parameter toggle.
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New hide state (bool or int)
+ prev: Previous hide state
+ """
+ new_value = bool(new_value)
+ parent().showCustomOnly = new_value
+ ext._log(f"Built-in parameters {'hidden' if new_value else 'visible'}", force=True)
+
+
+def handle_mode_change(ext: object, new_value: object, prev: object) -> None:
+ """Handle Mode parameter change ('Sender' <-> 'Receiver').
+
+ Args:
+ ext: CUDAIPCExtension instance
+ new_value: New mode ('Sender' or 'Receiver')
+ prev: Previous mode
+ """
+ # Convert to string
+ new_value = str(new_value)
+ prev = str(prev) if prev is not None else ""
+
+ if new_value == prev:
+ return
+
+ ext._log(f"Mode switching: {prev} -> {new_value}", force=True)
+
+ # Use extension's built-in switch_mode method
+ try:
+ ext.switch_mode(new_value)
+ ext._log(f"Mode switch complete: now in {new_value} mode", force=True)
+
+ # Update 'bg' selectTOP to display correct buffer
+ try:
+ bg_select = parent().op("bg")
+ if bg_select:
+ if new_value == "Sender":
+ bg_select.par.top = "ExportBuffer"
+ ext._log("Updated bg selectTOP -> ExportBuffer", force=False)
+ elif new_value == "Receiver":
+ bg_select.par.top = "ImportBuffer"
+ ext._log("Updated bg selectTOP -> ImportBuffer", force=False)
+ except (AttributeError, RuntimeError) as e:
+ ext._log(f"Could not update bg selectTOP: {e}", force=False)
+
+ except (AttributeError, RuntimeError) as e:
+ ext._log(f"ERROR switching mode: {e}", force=True)
+
+
+# Other callback stubs (not used for parameter monitoring)
+def onPulse(par: object) -> None:
+ """Called when a pulse parameter is triggered."""
+ pass
+
+
+def onExpressionChange(par: object, val: object, prev: object) -> None:
+ """Called when an expression parameter changes."""
+ pass
+
+
+def onExportChange(par: object, val: object, prev: object) -> None:
+ """Called when an export parameter changes."""
+ pass
+
+
+def onEnableChange(par: object, val: object, prev: object) -> None:
+ """Called when a parameter's enable state changes."""
+ pass
+
+
+def onModeChange(par: object, val: object, prev: object) -> None:
+ """Called when a parameter's mode changes."""
+ pass
+
+
+def onNameChange(par: object, val: object, prev: object) -> None:
+ """Called when a parameter's name changes."""
+ pass
diff --git a/src/streamdiffusion/_compat/td_exporter/script_top_callbacks.py b/src/streamdiffusion/_compat/td_exporter/script_top_callbacks.py
new file mode 100644
index 000000000..800cb7127
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/script_top_callbacks.py
@@ -0,0 +1,65 @@
+"""
+ImportBuffer Callback for CUDAIPCExtension Receiver Mode
+
+Copy this into the ImportBuffer's Callbacks DAT inside the .tox component.
+This is ONLY used when Mode = 'Receiver'.
+
+The ImportBuffer's callbacks parameter should point to this DAT:
+ ImportBuffer.par.callbacks = op('script_top_callbacks')
+
+This Script TOP is force-cooked from Execute DAT onFrameStart (nothing pulls it downstream).
+The onCook callback handles resolution update (one-time) and frame import.
+"""
+
+
+def onCook(scriptTop: object) -> None:
+ """Called every time the Script TOP needs to cook.
+
+ TD 2023 path: Handles resolution update (one-time) and imports frame from CUDA IPC.
+ TD 2025+ with modoutsidecook: This callback may still fire but import_frame()
+ is driven from Execute DAT. The resolution update here serves as a safety net.
+
+ Args:
+ scriptTop: The Script TOP operator instance (same as 'me')
+ """
+ ext = parent().ext.CUDAIPCExtension
+ if ext is None:
+ return
+
+ # Handle resolution update (one-time, after initialize_receiver)
+ # With modoutsidecook, this may already be handled by Execute DAT
+ pending = ext.consume_pending_resolution()
+ if pending is not None:
+ width, height = pending
+ try:
+ scriptTop.par.outputresolution = 9 # Custom Resolution
+ scriptTop.par.resolutionw = width
+ scriptTop.par.resolutionh = height
+ ext._log(
+ f"Set ImportBuffer resolution to {width}x{height}",
+ force=True,
+ )
+ except (AttributeError, RuntimeError) as e:
+ ext._log(f"Could not set ImportBuffer resolution: {e}", force=True)
+
+ # TD 2023 path: Import frame from CUDA IPC into this Script TOP
+ # With modoutsidecook (TD 2025+), import_frame() is called from Execute DAT instead
+ # Check if modoutsidecook is active; if so, skip to avoid double-import
+ try:
+ if hasattr(scriptTop.par, "modoutsidecook") and scriptTop.par.modoutsidecook.eval():
+ return # Import handled by Execute DAT
+ except (AttributeError, RuntimeError):
+ pass # Parameter doesn't exist or can't be read, proceed with import
+
+ ext.import_frame(scriptTop)
+
+
+def onSetupParameters(scriptTop: object, page: object) -> None:
+ """Called when Setup Parameters is pressed.
+
+ Args:
+ scriptTop: The Script TOP
+ page: The custom parameter page
+ """
+ # No custom parameters needed for receiver
+ pass
diff --git a/src/streamdiffusion/_compat/td_exporter/warning_emitter_callbacks.py b/src/streamdiffusion/_compat/td_exporter/warning_emitter_callbacks.py
new file mode 100644
index 000000000..254685122
--- /dev/null
+++ b/src/streamdiffusion/_compat/td_exporter/warning_emitter_callbacks.py
@@ -0,0 +1,25 @@
+"""Script TOP callbacks for the warning_emitter operator inside CUDAIPCLink.
+
+Reads the status message written by RealTDHost (via ownerComp storage key
+'cuda_link_status_msg') and re-emits it as a local addWarning badge on this
+Script TOP. Produces a visible warning indicator inside the COMP alongside
+the COMP-body tint set by RealTDHost.set_warning_status / set_error_status.
+
+RealTDHost force-cooks this TOP on every status transition so the badge stays
+in sync without relying on continuous cooking. Cook Type should be set to
+'Off' (Pulse to Cook) in the TD parameter dialog.
+"""
+
+
+def onCook(scriptOp):
+ msg = scriptOp.parent().fetch("cuda_link_status_msg", None)
+ if msg:
+ scriptOp.addWarning(str(msg))
+
+
+def onSetupParameters(scriptOp):
+ return
+
+
+def onPulse(par):
+ return
diff --git a/src/streamdiffusion/_hf_tracing_patches.py b/src/streamdiffusion/_hf_tracing_patches.py
index d2d611f30..422b6bf65 100644
--- a/src/streamdiffusion/_hf_tracing_patches.py
+++ b/src/streamdiffusion/_hf_tracing_patches.py
@@ -5,7 +5,10 @@
import torch
+
_ALREADY = False # idempotence guard
+
+
# --------------------------------------------------------------------------- #
# 1. UNet2DConditionModel: guard in_channels % up_factor
# --------------------------------------------------------------------------- #
@@ -16,11 +19,11 @@ def _patch_unet():
def patched(self, sample, *args, **kwargs):
if torch.jit.is_tracing():
- dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels))
+ dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels))
up_factor = torch.as_tensor(getattr(self.config, "default_overall_up_factor", 1))
torch._assert(
torch.remainder(dim, up_factor) == 0,
- f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}"
+ f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}",
)
return orig_fwd(self, sample, *args, **kwargs)
@@ -32,12 +35,13 @@ def patched(self, sample, *args, **kwargs):
# --------------------------------------------------------------------------- #
def _patch_downsample():
import diffusers.models.downsampling as d
+
orig_fwd = d.Downsample2D.forward
def patched(self, hidden_states, *args, **kwargs):
torch._assert(
hidden_states.shape[1] == self.channels,
- f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}"
+ f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}",
)
return orig_fwd(self, hidden_states, *args, **kwargs)
@@ -49,12 +53,13 @@ def patched(self, hidden_states, *args, **kwargs):
# --------------------------------------------------------------------------- #
def _patch_upsample():
import diffusers.models.upsampling as u
+
orig_fwd = u.Upsample2D.forward
def patched(self, hidden_states, *args, **kwargs):
torch._assert(
hidden_states.shape[1] == self.channels,
- f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}"
+ f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}",
)
return orig_fwd(self, hidden_states, *args, **kwargs)
diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py
index dd3a97e0b..f79fff72e 100644
--- a/src/streamdiffusion/acceleration/tensorrt/builder.py
+++ b/src/streamdiffusion/acceleration/tensorrt/builder.py
@@ -286,6 +286,9 @@ def _quant_fn():
# --- FP8 Q/DQ layer count (sanity gate: < 100 means quantization is inactive) ---
if fp8 and os.path.exists(engine_path):
try:
+ import json as _json
+ import re as _re
+
import tensorrt as trt
_rt = trt.Runtime(trt.Logger(trt.Logger.WARNING))
@@ -300,6 +303,28 @@ def _quant_fn():
_build_logger.warning(
f"[BUILD] Low Q/DQ count ({_qdq} < 500) β FP8 quantization likely inactive or incomplete"
)
+
+ # Fused-MHA check: count attention layers TRT fused into a single kernel.
+ # Pattern is empirical β FLUX uses "_gemm_mha_v2"; SDXL on Ada may differ.
+ # First build logs sample names so the regex can be confirmed or tightened.
+ _MHA_RE = _re.compile(r"mha|fmha|MultiHead|FlashAttn", _re.IGNORECASE)
+ try:
+ _layers = _json.loads(_info).get("Layers", [])
+ except Exception:
+ _layers = []
+ _total = len(_layers)
+ _mha_names = [_l.get("Name", "") for _l in _layers if _MHA_RE.search(_l.get("Name", ""))]
+ _mha_count = len(_mha_names)
+ stats["mha_fused_kernels"] = _mha_count
+ stats["total_engine_layers"] = _total
+ _build_logger.info(f"[BUILD] FP8 engine fused MHA layers: {_mha_count} / {_total} total")
+ if _mha_count == 0 and _total > 0:
+ _build_logger.warning(
+ "[BUILD] No fused MHA layers detected β attention may be running decomposed "
+ "(slower). Sample layer names (first 5): " + str([_l.get("Name", "") for _l in _layers[:5]])
+ )
+ else:
+ _build_logger.info(f"[BUILD] Sample fused-MHA layer names: {_mha_names[:3]}")
except Exception as _e:
_build_logger.warning(f"[BUILD] FP8 inspector check skipped: {_e}")
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py
index 13532a0c4..be5408075 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py
@@ -1,15 +1,16 @@
from .controlnet_export import SDXLControlNetExportWrapper
from .unet_controlnet_export import ControlNetUNetExportWrapper, MultiControlNetUNetExportWrapper
from .unet_ipadapter_export import IPAdapterUNetExportWrapper
-from .unet_sdxl_export import SDXLExportWrapper, SDXLConditioningHandler
+from .unet_sdxl_export import SDXLConditioningHandler, SDXLExportWrapper
from .unet_unified_export import UnifiedExportWrapper
+
__all__ = [
"SDXLControlNetExportWrapper",
"ControlNetUNetExportWrapper",
- "MultiControlNetUNetExportWrapper",
+ "MultiControlNetUNetExportWrapper",
"IPAdapterUNetExportWrapper",
"SDXLExportWrapper",
"SDXLConditioningHandler",
"UnifiedExportWrapper",
-]
\ No newline at end of file
+]
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py
index 946917b1c..43809a1a7 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py
@@ -1,23 +1,24 @@
import torch
+
class SDXLControlNetExportWrapper(torch.nn.Module):
"""Wrapper for SDXL ControlNet models to handle added_cond_kwargs properly during ONNX export"""
-
+
def __init__(self, controlnet_model):
super().__init__()
self.controlnet = controlnet_model
-
+
# Get device and dtype from model
- if hasattr(controlnet_model, 'device'):
+ if hasattr(controlnet_model, "device"):
self.device = controlnet_model.device
else:
# Try to infer from first parameter
try:
self.device = next(controlnet_model.parameters()).device
except:
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- if hasattr(controlnet_model, 'dtype'):
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ if hasattr(controlnet_model, "dtype"):
self.dtype = controlnet_model.dtype
else:
# Try to infer from first parameter
@@ -25,15 +26,14 @@ def __init__(self, controlnet_model):
self.dtype = next(controlnet_model.parameters()).dtype
except:
self.dtype = torch.float16
-
- def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids):
+
+ def forward(
+ self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids
+ ):
"""Forward pass that handles SDXL ControlNet requirements and produces 9 down blocks"""
# Use the provided SDXL conditioning
- added_cond_kwargs = {
- 'text_embeds': text_embeds,
- 'time_ids': time_ids
- }
-
+ added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
+
# Call the ControlNet with proper arguments including conditioning_scale
result = self.controlnet(
sample=sample,
@@ -42,40 +42,49 @@ def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, cond
controlnet_cond=controlnet_cond,
conditioning_scale=conditioning_scale,
added_cond_kwargs=added_cond_kwargs,
- return_dict=False
+ return_dict=False,
)
-
+
# Extract down blocks and mid block from result
if isinstance(result, tuple) and len(result) >= 2:
down_block_res_samples, mid_block_res_sample = result[0], result[1]
- elif hasattr(result, 'down_block_res_samples') and hasattr(result, 'mid_block_res_sample'):
+ elif hasattr(result, "down_block_res_samples") and hasattr(result, "mid_block_res_sample"):
down_block_res_samples = result.down_block_res_samples
mid_block_res_sample = result.mid_block_res_sample
else:
raise ValueError(f"Unexpected ControlNet output format: {type(result)}")
-
+
# SDXL ControlNet should have exactly 9 down blocks
if len(down_block_res_samples) != 9:
raise ValueError(f"SDXL ControlNet expected 9 down blocks, got {len(down_block_res_samples)}")
-
+
# Return 9 down blocks + 1 mid block with explicit names matching UNet pattern
# Following the pattern from controlnet_wrapper.py and models.py:
# down_block_00: Initial sample (320 channels)
- # down_block_01-03: Block 0 residuals (320 channels)
+ # down_block_01-03: Block 0 residuals (320 channels)
# down_block_04-06: Block 1 residuals (640 channels)
# down_block_07-08: Block 2 residuals (1280 channels)
down_block_00 = down_block_res_samples[0] # Initial: 320 channels, 88x88
down_block_01 = down_block_res_samples[1] # Block0: 320 channels, 88x88
- down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88
+ down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88
down_block_03 = down_block_res_samples[3] # Block0: 320 channels, 44x44
down_block_04 = down_block_res_samples[4] # Block1: 640 channels, 44x44
down_block_05 = down_block_res_samples[5] # Block1: 640 channels, 44x44
down_block_06 = down_block_res_samples[6] # Block1: 640 channels, 22x22
down_block_07 = down_block_res_samples[7] # Block2: 1280 channels, 22x22
down_block_08 = down_block_res_samples[8] # Block2: 1280 channels, 22x22
- mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22
-
+ mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22
+
# Return as individual tensors to preserve names in ONNX
- return (down_block_00, down_block_01, down_block_02, down_block_03,
- down_block_04, down_block_05, down_block_06, down_block_07,
- down_block_08, mid_block)
\ No newline at end of file
+ return (
+ down_block_00,
+ down_block_01,
+ down_block_02,
+ down_block_03,
+ down_block_04,
+ down_block_05,
+ down_block_06,
+ down_block_07,
+ down_block_08,
+ mid_block,
+ )
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
index e7a834bc1..fb6b5733c 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py
@@ -1,30 +1,32 @@
"""ControlNet-aware UNet wrapper for ONNX export"""
+from typing import Dict, List, Optional
+
import torch
-from typing import List, Optional, Dict, Any
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
from ..models.utils import convert_list_to_structure
class ControlNetUNetExportWrapper(torch.nn.Module):
"""Wrapper that combines UNet with ControlNet inputs for ONNX export"""
-
+
def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], kvo_cache_structure: List[int]):
super().__init__()
self.unet = unet
self.control_input_names = control_input_names
self.kvo_cache_structure = kvo_cache_structure
-
+
self.control_names = []
for name in control_input_names:
if "input_control" in name or "output_control" in name or "middle_control" in name:
self.control_names.append(name)
-
+
self.num_controlnet_args = len(self.control_names)
-
+
# Detect if this is SDXL based on UNet config
self.is_sdxl = self._detect_sdxl_architecture(unet)
-
+
# SDXL ControlNet has different structure than SD1.5
if self.is_sdxl:
# SDXL has 1 initial + 3 down blocks producing 9 control tensors total
@@ -32,30 +34,30 @@ def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], k
else:
# SD1.5 has 12 down blocks
self.expected_down_blocks = 12
-
+
def _detect_sdxl_architecture(self, unet):
"""Detect if UNet is SDXL based on architecture"""
- if hasattr(unet, 'config'):
+ if hasattr(unet, "config"):
config = unet.config
# SDXL has 3 down blocks vs SD1.5's 4 down blocks
- block_out_channels = getattr(config, 'block_out_channels', None)
+ block_out_channels = getattr(config, "block_out_channels", None)
if block_out_channels and len(block_out_channels) == 3:
return True
return False
-
+
def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
"""Forward pass that organizes control inputs and calls UNet"""
-
- control_args = args[:self.num_controlnet_args]
- kvo_cache = args[self.num_controlnet_args:]
-
+
+ control_args = args[: self.num_controlnet_args]
+ kvo_cache = args[self.num_controlnet_args :]
+
down_block_controls = []
mid_block_control = None
-
+
if control_args:
all_control_tensors = []
middle_tensor = None
-
+
for tensor, name in zip(control_args, self.control_names):
if "input_control" in name:
if "middle" in name:
@@ -64,7 +66,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
all_control_tensors.append(tensor)
elif "middle_control" in name:
middle_tensor = tensor
-
+
if len(all_control_tensors) == self.expected_down_blocks:
down_block_controls = all_control_tensors
mid_block_control = middle_tensor
@@ -73,7 +75,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
if len(all_control_tensors) > 0:
if len(all_control_tensors) > self.expected_down_blocks:
# Too many tensors - take the first expected_down_blocks
- down_block_controls = all_control_tensors[:self.expected_down_blocks]
+ down_block_controls = all_control_tensors[: self.expected_down_blocks]
else:
# Too few tensors - use what we have
down_block_controls = all_control_tensors
@@ -82,30 +84,29 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
# No control tensors available - skip ControlNet
down_block_controls = None
mid_block_control = None
-
+
formatted_kvo_cache = []
if len(kvo_cache) > 0:
formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure)
unet_kwargs = {
- 'sample': sample,
- 'timestep': timestep,
- 'encoder_hidden_states': encoder_hidden_states,
- 'kvo_cache': formatted_kvo_cache,
- 'return_dict': False,
+ "sample": sample,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "kvo_cache": formatted_kvo_cache,
+ "return_dict": False,
}
-
+
# Pass through all additional kwargs (for SDXL models)
unet_kwargs.update(kwargs)
# Auto-generate SDXL conditioning if missing and UNet requires it
- if 'added_cond_kwargs' not in unet_kwargs or unet_kwargs.get('added_cond_kwargs') is None:
- if (hasattr(self.unet, 'config') and
- getattr(self.unet.config, 'addition_embed_type', None) == 'text_time'):
+ if "added_cond_kwargs" not in unet_kwargs or unet_kwargs.get("added_cond_kwargs") is None:
+ if hasattr(self.unet, "config") and getattr(self.unet.config, "addition_embed_type", None) == "text_time":
batch_size = sample.shape[0]
- unet_kwargs['added_cond_kwargs'] = {
- 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype),
- 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype),
+ unet_kwargs["added_cond_kwargs"] = {
+ "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype),
+ "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype),
}
if down_block_controls:
@@ -115,30 +116,30 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
# Control tensors are now generated in the correct order to match UNet's down_block_res_samples
# For SDXL: [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22]
# This directly aligns with UNet's: [initial_sample] + [block0_residuals] + [block1_residuals] + [block2_residuals]
- unet_kwargs['down_block_additional_residuals'] = adapted_controls
-
+ unet_kwargs["down_block_additional_residuals"] = adapted_controls
+
if mid_block_control is not None:
# Adapt middle control tensor shape if needed
adapted_mid_control = self._adapt_middle_control_tensor(mid_block_control, sample)
- unet_kwargs['mid_block_additional_residual'] = adapted_mid_control
-
+ unet_kwargs["mid_block_additional_residual"] = adapted_mid_control
+
try:
res = self.unet(**unet_kwargs)
if len(kvo_cache) > 0:
return res
else:
return res[0]
- except Exception as e:
+ except Exception:
raise
-
+
def _adapt_control_tensors(self, control_tensors, sample):
"""Adapt control tensor shapes to match UNet expectations"""
if not control_tensors:
return control_tensors
-
+
adapted_tensors = []
sample_height, sample_width = sample.shape[-2:]
-
+
# Updated factors to match the corrected control tensor generation
# SDXL: 9 tensors [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22]
# Factors: [1, 1, 1, 2, 2, 2, 4, 4, 4] to match UNet down_block_res_samples structure
@@ -146,30 +147,31 @@ def _adapt_control_tensors(self, control_tensors, sample):
expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4] # 9 tensors for SDXL
else:
expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8] # 12 tensors for SD1.5
-
+
for i, control_tensor in enumerate(control_tensors):
if control_tensor is None:
adapted_tensors.append(control_tensor)
continue
-
+
# Check if tensor needs spatial adaptation
if len(control_tensor.shape) >= 4:
control_height, control_width = control_tensor.shape[-2:]
-
+
# Use the correct downsampling factor for this tensor index
if i < len(expected_downsample_factors):
downsample_factor = expected_downsample_factors[i]
expected_height = sample_height // downsample_factor
expected_width = sample_width // downsample_factor
-
+
if control_height != expected_height or control_width != expected_width:
# Use interpolation to adapt size
import torch.nn.functional as F
+
adapted_tensor = F.interpolate(
- control_tensor,
+ control_tensor,
size=(expected_height, expected_width),
- mode='bilinear',
- align_corners=False
+ mode="bilinear",
+ align_corners=False,
)
adapted_tensors.append(adapted_tensor)
else:
@@ -179,94 +181,94 @@ def _adapt_control_tensors(self, control_tensors, sample):
adapted_tensors.append(control_tensor)
else:
adapted_tensors.append(control_tensor)
-
+
return adapted_tensors
-
+
def _adapt_middle_control_tensor(self, mid_control, sample):
"""Adapt middle control tensor shape to match UNet expectations"""
if mid_control is None:
return mid_control
-
+
# Middle control is typically at the bottleneck, so heavily downsampled
if len(mid_control.shape) >= 4 and len(sample.shape) >= 4:
sample_height, sample_width = sample.shape[-2:]
control_height, control_width = mid_control.shape[-2:]
-
+
# For SDXL: middle block is at 4x downsampling (22x22 from 88x88)
# For SD1.5: middle block is at 8x downsampling
expected_factor = 4 if self.is_sdxl else 8
expected_height = sample_height // expected_factor
expected_width = sample_width // expected_factor
-
+
if control_height != expected_height or control_width != expected_width:
import torch.nn.functional as F
+
adapted_tensor = F.interpolate(
- mid_control,
- size=(expected_height, expected_width),
- mode='bilinear',
- align_corners=False
+ mid_control, size=(expected_height, expected_width), mode="bilinear", align_corners=False
)
return adapted_tensor
-
+
return mid_control
class MultiControlNetUNetExportWrapper(torch.nn.Module):
"""Advanced wrapper for multiple ControlNets with different scales"""
-
- def __init__(self,
- unet: UNet2DConditionModel,
- control_input_names: List[str],
- kvo_cache_structure: List[int],
- num_controlnets: int = 1,
- conditioning_scales: Optional[List[float]] = None):
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ control_input_names: List[str],
+ kvo_cache_structure: List[int],
+ num_controlnets: int = 1,
+ conditioning_scales: Optional[List[float]] = None,
+ ):
super().__init__()
self.unet = unet
self.control_input_names = control_input_names
self.num_controlnets = num_controlnets
self.conditioning_scales = conditioning_scales or [1.0] * num_controlnets
self.kvo_cache_structure = kvo_cache_structure
-
+
self.control_names = []
for name in control_input_names:
if "input_control" in name or "output_control" in name or "middle_control" in name:
self.control_names.append(name)
-
+
self.num_controlnet_args = len(self.control_names)
self.controlnet_indices = []
controls_per_net = self.num_controlnet_args // num_controlnets
-
+
for cn_idx in range(num_controlnets):
start_idx = cn_idx * controls_per_net
end_idx = start_idx + controls_per_net
self.controlnet_indices.append(list(range(start_idx, end_idx)))
-
+
def forward(self, sample, timestep, encoder_hidden_states, *args):
"""Forward pass for multiple ControlNets"""
- control_args = args[:self.num_controlnet_args]
- kvo_cache = args[self.num_controlnet_args:]
+ control_args = args[: self.num_controlnet_args]
+ kvo_cache = args[self.num_controlnet_args :]
combined_down_controls = None
combined_mid_control = None
-
+
for cn_idx, indices in enumerate(self.controlnet_indices):
scale = self.conditioning_scales[cn_idx]
if scale == 0:
continue
-
+
cn_controls = [control_args[i] for i in indices if i < len(control_args)]
-
+
if not cn_controls:
continue
-
+
num_down = len(cn_controls) - 1
down_controls = cn_controls[:num_down]
mid_control = cn_controls[num_down] if num_down < len(cn_controls) else None
-
+
scaled_down = [ctrl * scale for ctrl in down_controls]
scaled_mid = mid_control * scale if mid_control is not None else None
-
+
if combined_down_controls is None:
combined_down_controls = scaled_down
combined_mid_control = scaled_mid
@@ -275,24 +277,24 @@ def forward(self, sample, timestep, encoder_hidden_states, *args):
combined_down_controls[i] += scaled_down[i]
if scaled_mid is not None and combined_mid_control is not None:
combined_mid_control += scaled_mid
-
+
formatted_kvo_cache = []
if len(kvo_cache) > 0:
formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure)
unet_kwargs = {
- 'sample': sample,
- 'timestep': timestep,
- 'encoder_hidden_states': encoder_hidden_states,
- 'kvo_cache': formatted_kvo_cache,
- 'return_dict': False,
+ "sample": sample,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "kvo_cache": formatted_kvo_cache,
+ "return_dict": False,
}
-
+
if combined_down_controls:
- unet_kwargs['down_block_additional_residuals'] = list(reversed(combined_down_controls))
+ unet_kwargs["down_block_additional_residuals"] = list(reversed(combined_down_controls))
if combined_mid_control is not None:
- unet_kwargs['mid_block_additional_residual'] = combined_mid_control
-
+ unet_kwargs["mid_block_additional_residual"] = combined_mid_control
+
res = self.unet(**unet_kwargs)
if len(kvo_cache) > 0:
return res
@@ -301,11 +303,13 @@ def forward(self, sample, timestep, encoder_hidden_states, *args):
return res
-def create_controlnet_wrapper(unet: UNet2DConditionModel,
- control_input_names: List[str],
- kvo_cache_structure: List[int],
- num_controlnets: int = 1,
- conditioning_scales: Optional[List[float]] = None) -> torch.nn.Module:
+def create_controlnet_wrapper(
+ unet: UNet2DConditionModel,
+ control_input_names: List[str],
+ kvo_cache_structure: List[int],
+ num_controlnets: int = 1,
+ conditioning_scales: Optional[List[float]] = None,
+) -> torch.nn.Module:
"""Factory function to create appropriate ControlNet wrapper"""
if num_controlnets == 1:
return ControlNetUNetExportWrapper(unet, control_input_names, kvo_cache_structure)
@@ -315,17 +319,18 @@ def create_controlnet_wrapper(unet: UNet2DConditionModel,
)
-def organize_control_tensors(control_tensors: List[torch.Tensor],
- control_input_names: List[str]) -> Dict[str, List[torch.Tensor]]:
+def organize_control_tensors(
+ control_tensors: List[torch.Tensor], control_input_names: List[str]
+) -> Dict[str, List[torch.Tensor]]:
"""Organize control tensors by type (input, output, middle)"""
- organized = {'input': [], 'output': [], 'middle': []}
-
+ organized = {"input": [], "output": [], "middle": []}
+
for tensor, name in zip(control_tensors, control_input_names):
if "input_control" in name:
- organized['input'].append(tensor)
+ organized["input"].append(tensor)
elif "output_control" in name:
- organized['output'].append(tensor)
+ organized["output"].append(tensor)
elif "middle_control" in name:
- organized['middle'].append(tensor)
-
- return organized
\ No newline at end of file
+ organized["middle"].append(tensor)
+
+ return organized
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py
index f7eb18615..93310ad87 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py
@@ -1,30 +1,37 @@
+from typing import List
+
import torch
from diffusers import UNet2DConditionModel
-from typing import Optional, Dict, Any, List
-
-from ....model_detection import detect_model, detect_model_from_diffusers_unet
from diffusers_ipadapter.ip_adapter.attention_processor import TRTIPAttnProcessor, TRTIPAttnProcessor2_0
+from ....model_detection import detect_model_from_diffusers_unet
+
class IPAdapterUNetExportWrapper(torch.nn.Module):
"""
Wrapper that bakes IPAdapter attention processors into the UNet for ONNX export.
-
+
This approach installs IPAdapter attention processors before ONNX export,
allowing the specialized attention logic to be compiled into TensorRT.
The UNet expects concatenated embeddings (text + image) as encoder_hidden_states.
"""
-
- def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tokens: int = 4, install_processors: bool = True):
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ cross_attention_dim: int,
+ num_tokens: int = 4,
+ install_processors: bool = True,
+ ):
super().__init__()
self.unet = unet
self.num_image_tokens = num_tokens # 4 for standard, 16 for plus
self.cross_attention_dim = cross_attention_dim # 768 for SD1.5, 2048 for SDXL
self.install_processors = install_processors
-
+
# Convert to float32 BEFORE installing processors (to avoid resetting them)
self.unet = self.unet.to(dtype=torch.float32)
-
+
# Track installed TRT processors
self._ip_trt_processors: List[torch.nn.Module] = []
self.num_ip_layers: int = 0
@@ -36,8 +43,10 @@ def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tok
# Install IPAdapter processors AFTER dtype conversion
self._install_ipadapter_processors()
else:
- print("IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!")
-
+ print(
+ "IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!"
+ )
+
def _has_ipadapter_processors(self) -> bool:
"""Check if the UNet already has IPAdapter processors installed"""
try:
@@ -45,44 +54,48 @@ def _has_ipadapter_processors(self) -> bool:
for name, processor in processors.items():
# Check for IPAdapter processor class names
processor_class = processor.__class__.__name__
- if 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class:
+ if "IPAttn" in processor_class or "IPAttnProcessor" in processor_class:
return True
return False
except Exception as e:
print(f"IPAdapterUNetExportWrapper: Error checking existing processors: {e}")
return False
-
+
def _ensure_processor_dtype_consistency(self):
"""Ensure existing IPAdapter processors have correct dtype for ONNX export"""
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor
+
IPProcClass = TRTIPAttnProcessor2_0
else:
from diffusers.models.attention_processor import AttnProcessor
+
IPProcClass = TRTIPAttnProcessor
try:
processors = self.unet.attn_processors
updated_processors = {}
self._ip_trt_processors = []
ip_layer_index = 0
-
+
for name, processor in processors.items():
processor_class = processor.__class__.__name__
- if 'TRTIPAttn' in processor_class:
+ if "TRTIPAttn" in processor_class:
# Already TRT processors: ensure dtype and record
proc = processor.to(dtype=torch.float32)
proc._scale_index = ip_layer_index
self._ip_trt_processors.append(proc)
ip_layer_index += 1
updated_processors[name] = proc
- elif 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class:
+ elif "IPAttn" in processor_class or "IPAttnProcessor" in processor_class:
# Replace standard processors with TRT variants, preserving weights where applicable
- hidden_size = getattr(processor, 'hidden_size', None)
- cross_attention_dim = getattr(processor, 'cross_attention_dim', None)
- num_tokens = getattr(processor, 'num_tokens', self.num_image_tokens)
- proc = IPProcClass(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
+ hidden_size = getattr(processor, "hidden_size", None)
+ cross_attention_dim = getattr(processor, "cross_attention_dim", None)
+ num_tokens = getattr(processor, "num_tokens", self.num_image_tokens)
+ proc = IPProcClass(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens
+ )
# Copy IP projection weights if present
- if hasattr(processor, 'to_k_ip') and hasattr(processor, 'to_v_ip') and hasattr(proc, 'to_k_ip'):
+ if hasattr(processor, "to_k_ip") and hasattr(processor, "to_v_ip") and hasattr(proc, "to_k_ip"):
with torch.no_grad():
proc.to_k_ip.weight.copy_(processor.to_k_ip.weight.to(dtype=torch.float32))
proc.to_v_ip.weight.copy_(processor.to_v_ip.weight.to(dtype=torch.float32))
@@ -93,16 +106,17 @@ def _ensure_processor_dtype_consistency(self):
updated_processors[name] = proc
else:
updated_processors[name] = AttnProcessor()
-
+
# Update all processors to ensure consistency
self.unet.set_attn_processor(updated_processors)
self.num_ip_layers = len(self._ip_trt_processors)
-
+
except Exception as e:
print(f"IPAdapterUNetExportWrapper: Error updating processor dtypes: {e}")
import traceback
+
traceback.print_exc()
-
+
def _install_ipadapter_processors(self):
"""
Install IPAdapter attention processors that will be baked into ONNX.
@@ -112,19 +126,23 @@ def _install_ipadapter_processors(self):
try:
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor
+
IPProcClass = TRTIPAttnProcessor2_0
else:
from diffusers.models.attention_processor import AttnProcessor
+
IPProcClass = TRTIPAttnProcessor
-
+
# Install attention processors with proper configuration
processor_names = list(self.unet.attn_processors.keys())
-
+
attn_procs = {}
ip_layer_index = 0
for name in processor_names:
- cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
-
+ cross_attention_dim = (
+ None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+ )
+
# Determine hidden_size based on processor location
hidden_size = None
if name.startswith("mid_block"):
@@ -138,7 +156,7 @@ def _install_ipadapter_processors(self):
else:
# Fallback for any unexpected processor names
hidden_size = self.unet.config.block_out_channels[0] # Use first block size as fallback
-
+
if cross_attention_dim is None:
# Self-attention layers use standard processors
attn_procs[name] = AttnProcessor()
@@ -154,38 +172,46 @@ def _install_ipadapter_processors(self):
self._ip_trt_processors.append(proc)
ip_layer_index += 1
attn_procs[name] = proc
-
+
self.unet.set_attn_processor(attn_procs)
self.num_ip_layers = len(self._ip_trt_processors)
-
-
except Exception as e:
print(f"IPAdapterUNetExportWrapper: ERROR - Could not install IPAdapter processors: {e}")
print(f"IPAdapterUNetExportWrapper: Exception type: {type(e).__name__}")
print("IPAdapterUNetExportWrapper: IPAdapter functionality will not work without processors!")
import traceback
+
traceback.print_exc()
raise e
-
+
def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None:
"""Assign per-layer scale tensor to installed TRTIPAttn processors."""
if not isinstance(ipadapter_scale, torch.Tensor):
import logging
- logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}")
+
+ logging.getLogger(__name__).error(
+ f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}"
+ )
raise TypeError("ipadapter_scale must be a torch.Tensor")
if self.num_ip_layers <= 0 or not self._ip_trt_processors:
raise RuntimeError("No TRTIPAttn processors installed")
if ipadapter_scale.ndim != 1 or ipadapter_scale.shape[0] != self.num_ip_layers:
import logging
- logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)")
+
+ logging.getLogger(__name__).error(
+ f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)"
+ )
raise ValueError(f"ipadapter_scale must have shape [{self.num_ip_layers}]")
# Ensure float32 for ONNX export stability
scale_vec = ipadapter_scale.to(dtype=torch.float32)
try:
import logging
- logging.getLogger(__name__).debug(f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}")
+
+ logging.getLogger(__name__).debug(
+ f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}"
+ )
except Exception:
pass
for proc in self._ip_trt_processors:
@@ -194,27 +220,27 @@ def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None:
def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torch.Tensor = None):
"""
Forward pass with concatenated embeddings (text + image).
-
+
The IPAdapter processors installed in the UNet will automatically:
1. Split the concatenated embeddings into text and image parts
2. Process image tokens with separate attention computation
3. Apply scaling and blending between text and image attention
-
+
Args:
sample: Latent input tensor
- timestep: Timestep tensor
+ timestep: Timestep tensor
encoder_hidden_states: Concatenated embeddings [text_tokens + image_tokens, cross_attention_dim]
-
+
Returns:
UNet output (noise prediction)
"""
# Validate input shapes
batch_size, seq_len, embed_dim = encoder_hidden_states.shape
-
+
# Check that we have the expected number of image tokens
if embed_dim != self.cross_attention_dim:
raise ValueError(f"Embedding dimension {embed_dim} doesn't match expected {self.cross_attention_dim}")
-
+
# Ensure dtype consistency for ONNX export
if encoder_hidden_states.dtype != torch.float32:
encoder_hidden_states = encoder_hidden_states.to(torch.float32)
@@ -223,29 +249,28 @@ def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torc
if ipadapter_scale is None:
raise RuntimeError("IPAdapterUNetExportWrapper.forward requires ipadapter_scale tensor")
self.set_ipadapter_scale(ipadapter_scale)
-
+
# Pass concatenated embeddings to UNet with baked-in IPAdapter processors
return self.unet(
- sample=sample,
- timestep=timestep,
- encoder_hidden_states=encoder_hidden_states,
- return_dict=False
+ sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False
)
-def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True) -> IPAdapterUNetExportWrapper:
+def create_ipadapter_wrapper(
+ unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True
+) -> IPAdapterUNetExportWrapper:
"""
Create an IPAdapter wrapper with automatic architecture detection and baked-in processors.
-
+
Handles both cases:
1. UNet with pre-loaded IPAdapter processors (preserves existing weights)
2. UNet without IPAdapter processors (installs new ones if install_processors=True)
-
+
Args:
unet: UNet2DConditionModel to wrap
num_tokens: Number of image tokens (4 for standard, 16 for plus)
install_processors: Whether to install IPAdapter processors if none exist
-
+
Returns:
IPAdapterUNetExportWrapper with baked-in IPAdapter attention processors
"""
@@ -253,23 +278,21 @@ def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, in
try:
model_type = detect_model_from_diffusers_unet(unet)
cross_attention_dim = unet.config.cross_attention_dim
-
+
# Check if UNet already has IPAdapter processors installed
existing_processors = unet.attn_processors
- has_ipadapter = any('IPAttn' in proc.__class__.__name__ or 'IPAttnProcessor' in proc.__class__.__name__
- for proc in existing_processors.values())
-
+ has_ipadapter = any(
+ "IPAttn" in proc.__class__.__name__ or "IPAttnProcessor" in proc.__class__.__name__
+ for proc in existing_processors.values()
+ )
+
# Validate expected dimensions
- expected_dims = {
- "SD15": 768,
- "SDXL": 2048,
- "SD21": 1024
- }
-
+ expected_dims = {"SD15": 768, "SDXL": 2048, "SD21": 1024}
+
expected_dim = expected_dims.get(model_type)
-
+
return IPAdapterUNetExportWrapper(unet, cross_attention_dim, num_tokens, install_processors)
-
+
except Exception as e:
print(f"create_ipadapter_wrapper: Error during model detection: {e}")
- return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors)
\ No newline at end of file
+ return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors)
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
index fa1f0f890..078b1f913 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py
@@ -4,13 +4,17 @@
conditioning parameters, and Turbo variants
"""
+import logging
+from typing import Any, Dict
+
import torch
-from typing import Dict, List, Optional, Tuple, Any, Union
from diffusers import UNet2DConditionModel
+
from ....model_detection import (
detect_model,
)
-import logging
+
+
logger = logging.getLogger(__name__)
# Handle different diffusers versions for CLIPTextModel import
@@ -18,7 +22,7 @@
from diffusers.models.transformers.clip_text_model import CLIPTextModel
except ImportError:
try:
- from diffusers.models.clip_text_model import CLIPTextModel
+ from diffusers.models.clip_text_model import CLIPTextModel
except ImportError:
try:
from transformers import CLIPTextModel
@@ -29,79 +33,81 @@
class SDXLExportWrapper(torch.nn.Module):
"""Wrapper for SDXL UNet to handle optional conditioning in legacy TensorRT"""
-
+
def __init__(self, unet):
super().__init__()
self.unet = unet
self.base_unet = self._get_base_unet(unet)
self.supports_added_cond = self._test_added_cond_support()
-
+
def _get_base_unet(self, unet):
"""Extract the base UNet from wrappers"""
# Handle ControlNet wrapper
- if hasattr(unet, 'unet_model') and hasattr(unet.unet_model, 'config'):
+ if hasattr(unet, "unet_model") and hasattr(unet.unet_model, "config"):
return unet.unet_model
- elif hasattr(unet, 'unet') and hasattr(unet.unet, 'config'):
+ elif hasattr(unet, "unet") and hasattr(unet.unet, "config"):
return unet.unet
- elif hasattr(unet, 'config'):
+ elif hasattr(unet, "config"):
return unet
else:
# Fallback: try to find any attribute that has config
for attr_name in dir(unet):
- if not attr_name.startswith('_'):
+ if not attr_name.startswith("_"):
attr = getattr(unet, attr_name, None)
- if hasattr(attr, 'config') and hasattr(attr.config, 'addition_embed_type'):
+ if hasattr(attr, "config") and hasattr(attr.config, "addition_embed_type"):
return attr
return unet
-
+
def _test_added_cond_support(self):
"""Test if this SDXL model supports added_cond_kwargs"""
try:
# Create minimal test inputs
- sample = torch.randn(1, 4, 8, 8, device='cuda', dtype=torch.float16)
- timestep = torch.tensor([0.5], device='cuda', dtype=torch.float32)
- encoder_hidden_states = torch.randn(1, 77, 2048, device='cuda', dtype=torch.float16)
-
+ sample = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16)
+ timestep = torch.tensor([0.5], device="cuda", dtype=torch.float32)
+ encoder_hidden_states = torch.randn(1, 77, 2048, device="cuda", dtype=torch.float16)
+
# Test with added_cond_kwargs
test_added_cond = {
- 'text_embeds': torch.randn(1, 1280, device='cuda', dtype=torch.float16),
- 'time_ids': torch.randn(1, 6, device='cuda', dtype=torch.float16)
+ "text_embeds": torch.randn(1, 1280, device="cuda", dtype=torch.float16),
+ "time_ids": torch.randn(1, 6, device="cuda", dtype=torch.float16),
}
-
+
with torch.no_grad():
_ = self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=test_added_cond)
-
+
logger.info("SDXL model supports added_cond_kwargs")
return True
-
+
except Exception as e:
logger.error(f"SDXL model does not support added_cond_kwargs: {e}")
return False
-
+
def forward(self, *args, **kwargs):
"""Forward pass that handles SDXL conditioning gracefully"""
try:
# Ensure added_cond_kwargs is never None to prevent TypeError
- if 'added_cond_kwargs' in kwargs and kwargs['added_cond_kwargs'] is None:
- kwargs['added_cond_kwargs'] = {}
-
+ if "added_cond_kwargs" in kwargs and kwargs["added_cond_kwargs"] is None:
+ kwargs["added_cond_kwargs"] = {}
+
# Auto-generate SDXL conditioning if missing and model needs it
- if (len(args) >= 3 and 'added_cond_kwargs' not in kwargs and
- hasattr(self.base_unet.config, 'addition_embed_type') and
- self.base_unet.config.addition_embed_type == 'text_time'):
-
+ if (
+ len(args) >= 3
+ and "added_cond_kwargs" not in kwargs
+ and hasattr(self.base_unet.config, "addition_embed_type")
+ and self.base_unet.config.addition_embed_type == "text_time"
+ ):
sample = args[0]
device = sample.device
batch_size = sample.shape[0]
-
+
logger.info("Auto-generating required SDXL conditioning...")
- kwargs['added_cond_kwargs'] = {
- 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype),
- 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype)
+ kwargs["added_cond_kwargs"] = {
+ "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype),
+ "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype),
}
-
+
# If model supports added conditioning and we have the kwargs, use them
- if self.supports_added_cond and 'added_cond_kwargs' in kwargs:
+ if self.supports_added_cond and "added_cond_kwargs" in kwargs:
result = self.unet(*args, **kwargs)
return result
elif len(args) >= 3:
@@ -110,7 +116,7 @@ def forward(self, *args, **kwargs):
else:
# Fallback
return self.unet(*args, **kwargs)
-
+
except (TypeError, AttributeError) as e:
logger.error(f"[SDXL_WRAPPER] forward: Exception caught: {e}")
if "NoneType" in str(e) or "iterable" in str(e) or "text_embeds" in str(e):
@@ -120,15 +126,17 @@ def forward(self, *args, **kwargs):
sample, timestep, encoder_hidden_states = args[0], args[1], args[2]
device = sample.device
batch_size = sample.shape[0]
-
+
# Create minimal valid SDXL conditioning
minimal_conditioning = {
- 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype),
- 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype)
+ "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype),
+ "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype),
}
-
+
try:
- return self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning)
+ return self.unet(
+ sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning
+ )
except Exception as final_e:
logger.info(f"Final fallback to basic call: {final_e}")
return self.unet(sample, timestep, encoder_hidden_states)
@@ -136,181 +144,180 @@ def forward(self, *args, **kwargs):
return self.unet(*args)
else:
raise e
-
+
+
class SDXLConditioningHandler:
"""Handles SDXL conditioning parameters and dual text encoders"""
-
+
def __init__(self, unet_info: Dict[str, Any]):
self.unet_info = unet_info
- self.is_sdxl = unet_info['is_sdxl']
- self.has_time_cond = unet_info['has_time_cond']
- self.has_addition_embed = unet_info['has_addition_embed']
-
+ self.is_sdxl = unet_info["is_sdxl"]
+ self.has_time_cond = unet_info["has_time_cond"]
+ self.has_addition_embed = unet_info["has_addition_embed"]
+
def get_conditioning_spec(self) -> Dict[str, Any]:
"""Get conditioning specification for ONNX export and TensorRT"""
spec = {
- 'text_encoder_dim': 768, # CLIP ViT-L
- 'context_dim': 768, # Default SD1.5
- 'pooled_embeds': False,
- 'time_ids': False,
- 'dual_encoders': False
+ "text_encoder_dim": 768, # CLIP ViT-L
+ "context_dim": 768, # Default SD1.5
+ "pooled_embeds": False,
+ "time_ids": False,
+ "dual_encoders": False,
}
-
+
if self.is_sdxl:
- spec.update({
- 'text_encoder_dim': 768, # CLIP ViT-L
- 'text_encoder_2_dim': 1280, # OpenCLIP ViT-bigG
- 'context_dim': 2048, # Concatenated 768 + 1280
- 'pooled_embeds': True, # Pooled text embeddings
- 'time_ids': self.has_time_cond, # Size/crop conditioning
- 'dual_encoders': True
- })
-
+ spec.update(
+ {
+ "text_encoder_dim": 768, # CLIP ViT-L
+ "text_encoder_2_dim": 1280, # OpenCLIP ViT-bigG
+ "context_dim": 2048, # Concatenated 768 + 1280
+ "pooled_embeds": True, # Pooled text embeddings
+ "time_ids": self.has_time_cond, # Size/crop conditioning
+ "dual_encoders": True,
+ }
+ )
+
return spec
-
- def create_sample_conditioning(self, batch_size: int = 1, device: str = 'cuda') -> Dict[str, torch.Tensor]:
+
+ def create_sample_conditioning(self, batch_size: int = 1, device: str = "cuda") -> Dict[str, torch.Tensor]:
"""Create sample conditioning tensors for testing/export"""
spec = self.get_conditioning_spec()
dtype = torch.float16
-
+
conditioning = {
- 'encoder_hidden_states': torch.randn(
- batch_size, 77, spec['context_dim'],
- device=device, dtype=dtype
- )
+ "encoder_hidden_states": torch.randn(batch_size, 77, spec["context_dim"], device=device, dtype=dtype)
}
-
- if spec['pooled_embeds']:
- conditioning['text_embeds'] = torch.randn(
- batch_size, spec['text_encoder_2_dim'],
- device=device, dtype=dtype
+
+ if spec["pooled_embeds"]:
+ conditioning["text_embeds"] = torch.randn(
+ batch_size, spec["text_encoder_2_dim"], device=device, dtype=dtype
)
-
- if spec['time_ids']:
- conditioning['time_ids'] = torch.randn(
- batch_size, 6, # [height, width, crop_h, crop_w, target_height, target_width]
- device=device, dtype=dtype
+
+ if spec["time_ids"]:
+ conditioning["time_ids"] = torch.randn(
+ batch_size,
+ 6, # [height, width, crop_h, crop_w, target_height, target_width]
+ device=device,
+ dtype=dtype,
)
-
+
return conditioning
-
+
def test_unet_conditioning(self, unet: UNet2DConditionModel) -> Dict[str, bool]:
"""Test what conditioning the UNet actually supports"""
- results = {
- 'basic': False,
- 'added_cond_kwargs': False,
- 'separate_args': False
- }
-
+ results = {"basic": False, "added_cond_kwargs": False, "separate_args": False}
+
try:
# Ensure model is on CUDA and in eval mode for testing
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ device = "cuda" if torch.cuda.is_available() else "cpu"
unet_test = unet.to(device).eval()
-
+
# Create test inputs on the same device
sample = torch.randn(1, 4, 8, 8, device=device, dtype=torch.float16)
timestep = torch.tensor([0.5], device=device, dtype=torch.float32)
conditioning = self.create_sample_conditioning(1, device=device)
-
+
# Test basic call
try:
with torch.no_grad():
- _ = unet_test(sample, timestep, conditioning['encoder_hidden_states'])
- results['basic'] = True
+ _ = unet_test(sample, timestep, conditioning["encoder_hidden_states"])
+ results["basic"] = True
except Exception:
pass
-
+
# Test added_cond_kwargs (standard SDXL)
if self.is_sdxl:
try:
added_cond = {}
- if 'text_embeds' in conditioning:
- added_cond['text_embeds'] = conditioning['text_embeds']
- if 'time_ids' in conditioning:
- added_cond['time_ids'] = conditioning['time_ids']
-
+ if "text_embeds" in conditioning:
+ added_cond["text_embeds"] = conditioning["text_embeds"]
+ if "time_ids" in conditioning:
+ added_cond["time_ids"] = conditioning["time_ids"]
+
with torch.no_grad():
- _ = unet_test(sample, timestep, conditioning['encoder_hidden_states'],
- added_cond_kwargs=added_cond)
- results['added_cond_kwargs'] = True
+ _ = unet_test(
+ sample, timestep, conditioning["encoder_hidden_states"], added_cond_kwargs=added_cond
+ )
+ results["added_cond_kwargs"] = True
except Exception:
pass
-
+
# Test separate arguments (some implementations)
try:
- args = [sample, timestep, conditioning['encoder_hidden_states']]
- if 'text_embeds' in conditioning:
- args.append(conditioning['text_embeds'])
- if 'time_ids' in conditioning:
- args.append(conditioning['time_ids'])
-
+ args = [sample, timestep, conditioning["encoder_hidden_states"]]
+ if "text_embeds" in conditioning:
+ args.append(conditioning["text_embeds"])
+ if "time_ids" in conditioning:
+ args.append(conditioning["time_ids"])
+
with torch.no_grad():
_ = unet_test(*args)
- results['separate_args'] = True
+ results["separate_args"] = True
except Exception:
pass
-
+
except Exception as e:
# If testing fails completely, provide safe defaults
print(f"β οΈ UNet conditioning test setup failed: {e}")
results = {
- 'basic': True, # Assume basic call works
- 'added_cond_kwargs': self.is_sdxl, # Assume SDXL models support this
- 'separate_args': False
+ "basic": True, # Assume basic call works
+ "added_cond_kwargs": self.is_sdxl, # Assume SDXL models support this
+ "separate_args": False,
}
-
+
return results
def get_onnx_export_spec(self) -> Dict[str, Any]:
"""Get specification for ONNX export"""
spec = self.conditioning_handler.get_conditioning_spec()
-
+
# Add export-specific details
- spec.update({
- 'input_names': ['sample', 'timestep', 'encoder_hidden_states'],
- 'output_names': ['noise_pred'],
- 'dynamic_axes': {
- 'sample': {0: 'batch_size'},
- 'timestep': {0: 'batch_size'},
- 'encoder_hidden_states': {0: 'batch_size'},
- 'noise_pred': {0: 'batch_size'}
+ spec.update(
+ {
+ "input_names": ["sample", "timestep", "encoder_hidden_states"],
+ "output_names": ["noise_pred"],
+ "dynamic_axes": {
+ "sample": {0: "batch_size"},
+ "timestep": {0: "batch_size"},
+ "encoder_hidden_states": {0: "batch_size"},
+ "noise_pred": {0: "batch_size"},
+ },
}
- })
-
+ )
+
# Add SDXL-specific inputs if supported
- if self.is_sdxl and self.supported_calls['added_cond_kwargs']:
- if spec['pooled_embeds']:
- spec['input_names'].append('text_embeds')
- spec['dynamic_axes']['text_embeds'] = {0: 'batch_size'}
-
- if spec['time_ids']:
- spec['input_names'].append('time_ids')
- spec['dynamic_axes']['time_ids'] = {0: 'batch_size'}
-
- return spec
+ if self.is_sdxl and self.supported_calls["added_cond_kwargs"]:
+ if spec["pooled_embeds"]:
+ spec["input_names"].append("text_embeds")
+ spec["dynamic_axes"]["text_embeds"] = {0: "batch_size"}
+
+ if spec["time_ids"]:
+ spec["input_names"].append("time_ids")
+ spec["dynamic_axes"]["time_ids"] = {0: "batch_size"}
+ return spec
def get_sdxl_tensorrt_config(model_path: str, unet: UNet2DConditionModel) -> Dict[str, Any]:
"""Get complete TensorRT configuration for SDXL model"""
# Use the new detection function
detection_result = detect_model(unet)
-
+
# Create a config dict compatible with SDXLConditioningHandler
config = {
- 'is_sdxl': detection_result['is_sdxl'],
- 'has_time_cond': detection_result['architecture_details']['has_time_conditioning'],
- 'has_addition_embed': detection_result['architecture_details']['has_addition_embeds'],
- 'model_type': detection_result['model_type'],
- 'is_turbo': detection_result['is_turbo'],
- 'is_sd3': detection_result['is_sd3'],
- 'confidence': detection_result['confidence'],
- 'architecture_details': detection_result['architecture_details'],
- 'compatibility_info': detection_result['compatibility_info']
+ "is_sdxl": detection_result["is_sdxl"],
+ "has_time_cond": detection_result["architecture_details"]["has_time_conditioning"],
+ "has_addition_embed": detection_result["architecture_details"]["has_addition_embeds"],
+ "model_type": detection_result["model_type"],
+ "is_turbo": detection_result["is_turbo"],
+ "is_sd3": detection_result["is_sd3"],
+ "confidence": detection_result["confidence"],
+ "architecture_details": detection_result["architecture_details"],
+ "compatibility_info": detection_result["compatibility_info"],
}
-
+
# Add conditioning specification
conditioning_handler = SDXLConditioningHandler(config)
- config['conditioning_spec'] = conditioning_handler.get_conditioning_spec()
-
- return config
\ No newline at end of file
+ config["conditioning_spec"] = conditioning_handler.get_conditioning_spec()
+
+ return config
diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py
index 1c87efbf9..cb4765b82 100644
--- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py
+++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py
@@ -1,23 +1,33 @@
+from typing import List, Optional
+
import torch
from diffusers import UNet2DConditionModel
-from typing import Optional, List
+
+from streamdiffusion._compat.diffusers_kvo_patch import apply as _apply_kvo_patch
+
+
+_apply_kvo_patch() # ensure kvo_cache patch is present even if diffusers was imported first
+
+from ..models.utils import convert_list_to_structure
from .unet_controlnet_export import create_controlnet_wrapper
from .unet_ipadapter_export import create_ipadapter_wrapper
-from ..models.utils import convert_list_to_structure
+
class UnifiedExportWrapper(torch.nn.Module):
"""
- Unified wrapper that composes wrappers for conditioning modules.
+ Unified wrapper that composes wrappers for conditioning modules.
"""
-
- def __init__(self,
- unet: UNet2DConditionModel,
- use_controlnet: bool = False,
- use_ipadapter: bool = False,
- control_input_names: Optional[List[str]] = None,
- num_tokens: int = 4,
- kvo_cache_structure: List[int] = [],
- **kwargs):
+
+ def __init__(
+ self,
+ unet: UNet2DConditionModel,
+ use_controlnet: bool = False,
+ use_ipadapter: bool = False,
+ control_input_names: Optional[List[str]] = None,
+ num_tokens: int = 4,
+ kvo_cache_structure: List[int] = [],
+ **kwargs,
+ ):
super().__init__()
self.use_controlnet = use_controlnet
self.use_ipadapter = use_ipadapter
@@ -25,23 +35,24 @@ def __init__(self,
self.ipadapter_wrapper = None
self.unet = unet
self.kvo_cache_structure = kvo_cache_structure
-
+
# Apply IPAdapter first (installs processors into UNet)
if use_ipadapter:
- ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ['install_processors']}
- if 'install_processors' not in ipadapter_kwargs:
- ipadapter_kwargs['install_processors'] = True
-
+ ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ["install_processors"]}
+ if "install_processors" not in ipadapter_kwargs:
+ ipadapter_kwargs["install_processors"] = True
self.ipadapter_wrapper = create_ipadapter_wrapper(unet, num_tokens=num_tokens, **ipadapter_kwargs)
self.unet = self.ipadapter_wrapper.unet
-
+
# Apply ControlNet second (wraps whatever UNet we have)
if use_controlnet and control_input_names:
- controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ['num_controlnets', 'conditioning_scales']}
+ controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ["num_controlnets", "conditioning_scales"]}
+
+ self.controlnet_wrapper = create_controlnet_wrapper(
+ self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs
+ )
- self.controlnet_wrapper = create_controlnet_wrapper(self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs)
-
def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cache, **kwargs):
"""Basic UNet forward that passes through all parameters to handle any model type"""
formatted_kvo_cache = []
@@ -49,52 +60,57 @@ def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cach
formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure)
# Auto-generate SDXL conditioning if missing and UNet requires it
- if 'added_cond_kwargs' not in kwargs or kwargs.get('added_cond_kwargs') is None:
+ if "added_cond_kwargs" not in kwargs or kwargs.get("added_cond_kwargs") is None:
base_unet = self.unet
- if (hasattr(base_unet, 'config') and
- getattr(base_unet.config, 'addition_embed_type', None) == 'text_time'):
+ if hasattr(base_unet, "config") and getattr(base_unet.config, "addition_embed_type", None) == "text_time":
batch_size = sample.shape[0]
- kwargs['added_cond_kwargs'] = {
- 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype),
- 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype),
+ kwargs["added_cond_kwargs"] = {
+ "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype),
+ "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype),
}
unet_kwargs = {
- 'sample': sample,
- 'timestep': timestep,
- 'encoder_hidden_states': encoder_hidden_states,
- 'return_dict': False,
- 'kvo_cache': formatted_kvo_cache,
- **kwargs # Pass through all additional parameters (SDXL, future model types, etc.)
+ "sample": sample,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "return_dict": False,
+ "kvo_cache": formatted_kvo_cache,
+ **kwargs, # Pass through all additional parameters (SDXL, future model types, etc.)
}
res = self.unet(**unet_kwargs)
if len(kvo_cache) > 0:
return res
else:
return res[0]
-
- def forward(self,
- sample: torch.Tensor,
- timestep: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs) -> torch.Tensor:
+
+ def forward(
+ self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, **kwargs
+ ) -> torch.Tensor:
"""Forward pass that handles any UNet parameters via **kwargs passthrough"""
# Handle IP-Adapter runtime scale vector as a positional argument placed before control tensors
if self.use_ipadapter and self.ipadapter_wrapper is not None:
# ipadapter_scale is appended as the first extra positional input after the 3 base inputs
if len(args) == 0:
import logging
- logging.getLogger(__name__).error("UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True")
+
+ logging.getLogger(__name__).error(
+ "UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True"
+ )
raise RuntimeError("UnifiedExportWrapper: ipadapter_scale tensor is required when use_ipadapter=True")
ipadapter_scale = args[0]
if not isinstance(ipadapter_scale, torch.Tensor):
import logging
- logging.getLogger(__name__).error(f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}")
+
+ logging.getLogger(__name__).error(
+ f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}"
+ )
raise TypeError("ipadapter_scale must be a torch.Tensor")
try:
import logging
- logging.getLogger(__name__).debug(f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}")
+
+ logging.getLogger(__name__).debug(
+ f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}"
+ )
except Exception:
pass
# assign per-layer scale tensors into processors
@@ -107,4 +123,4 @@ def forward(self,
return self.controlnet_wrapper(sample, timestep, encoder_hidden_states, *args, **kwargs)
else:
# Basic UNet call with all parameters passed through
- return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
\ No newline at end of file
+ return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
diff --git a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py
index 762c638ac..7cd517f4f 100644
--- a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py
+++ b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py
@@ -302,6 +302,59 @@ def _read_onnx_input_specs(onnx_path: str) -> Dict[str, tuple]:
return result
+def _assert_finite_qdq_scales(onnx_path: str) -> None:
+ """Raise RuntimeError if any FP8 Q/DQ scale in onnx_path is non-finite.
+
+ Root cause: modelopt/onnx/quantization/fp8.py computes
+ np_fp8_scale = (np_scale * 448.0) / 127.0
+ in the source dtype (FP16). An INT8 amax > ~18,500 overflows to +inf;
+ the resulting Q/DQ scale produces zero output at inference.
+
+ Checks both initializer scales AND Constant-node scales (the latter appear
+ on some residual-add Q nodes injected by modelopt on SDXL UNet).
+ On failure: raises RuntimeError listing the first 5 offending node names
+ and directs the user to extend _DEFAULT_EXCLUDE_PATTERNS and delete the
+ cached .fp8.onnx so modelopt reruns without those layers.
+ """
+ import onnx as _onnx
+ from onnx import numpy_helper as _numpy_helper
+
+ model = _onnx.load(onnx_path, load_external_data=True)
+ graph = model.graph
+ init_map = {init.name: init for init in graph.initializer}
+ const_map: dict = {}
+ for node in graph.node:
+ if node.op_type == "Constant" and node.output:
+ attr = {a.name: a for a in node.attribute}
+ if "value" in attr:
+ const_map[node.output[0]] = _numpy_helper.to_array(attr["value"].t)
+
+ bad: list = []
+ for node in graph.node:
+ if node.op_type not in ("QuantizeLinear", "DequantizeLinear"):
+ continue
+ if len(node.input) < 2:
+ continue
+ scale_name = node.input[1]
+ if scale_name in init_map:
+ arr = _numpy_helper.to_array(init_map[scale_name]).flatten().astype(np.float64)
+ elif scale_name in const_map:
+ arr = const_map[scale_name].flatten().astype(np.float64)
+ else:
+ continue
+ if not np.isfinite(arr).all():
+ bad.append(node.name or scale_name)
+
+ if bad:
+ names = ", ".join(bad[:5]) + ("..." if len(bad) > 5 else "")
+ raise RuntimeError(
+ f"[FP8] Non-finite Q/DQ scale in {len(bad)} node(s): {names}. "
+ f"Add the offending layer substring(s) to _DEFAULT_EXCLUDE_PATTERNS "
+ f"in fp8_quantize.py, delete the cached .fp8.onnx, and rebuild. "
+ f"Diagnostic: modelopt/onnx/quantization/fp8.py overflow when INT8 amax > ~18500."
+ )
+
+
def quantize_onnx_fp8(
onnx_path: str,
output_path: str,
@@ -376,7 +429,8 @@ def quantize_onnx_fp8(
# e.g. SDXL exports `sample` as FP32 even though the unet runs FP16.
# modelopt's CalibrationDataProvider asserts strict count match and ORT's
# inference probe rejects dtype mismatches, so filter+cast accordingly.
- _onnx_inputs = {k: v[0] for k, v in _read_onnx_input_specs(onnx_path).items()}
+ _specs = _read_onnx_input_specs(onnx_path) # {name: (dtype, dims)}
+ _onnx_inputs = {k: v[0] for k, v in _specs.items()}
_dropped = set(calibration_data.keys()) - set(_onnx_inputs)
if _dropped:
logger.info(f"[FP8] Dropping calibration keys not exposed by ONNX: {sorted(_dropped)}")
@@ -389,6 +443,27 @@ def quantize_onnx_fp8(
logger.info(f"[FP8] Casting calibration '{_k}': {calibration_data[_k].dtype} β {_expected}")
calibration_data[_k] = calibration_data[_k].astype(_expected)
+ # Per-input tile: target rows = n_itr Γ resolved_dim0(name) so every input
+ # splits into exactly n_itr chunks of shape (resolved_dim0, ...).
+ # Mirrors modelopt CalibrationDataProvider: symbolic dims β 1, static dims kept.
+ # NaΓ―ve _max_rows tile breaks kvo_cache_in_* (ONNX dim0=2 static) by pumping
+ # sample to 2Γ_n_itr rows, causing modelopt to split kvo into (1,...) chunks.
+ import math as _math
+
+ _resolved_dim0 = {name: max(1, (_specs[name][1][0] or 1)) for name in calibration_data}
+ _n_itr = max(arr.shape[0] // _resolved_dim0[name] for name, arr in calibration_data.items())
+ _n_itr = max(1, _n_itr)
+ for _k in list(calibration_data.keys()):
+ _arr = calibration_data[_k]
+ _target_rows = _n_itr * _resolved_dim0[_k]
+ if _arr.shape[0] != _target_rows:
+ _repeats = _math.ceil(_target_rows / max(1, _arr.shape[0]))
+ calibration_data[_k] = np.tile(_arr, (_repeats,) + (1,) * (_arr.ndim - 1))[:_target_rows]
+ logger.info(
+ f"[FP8] Tiled '{_k}' {_arr.shape[0]} β {_target_rows} rows "
+ f"(n_itr={_n_itr} Γ resolved_dim0={_resolved_dim0[_k]})"
+ )
+
import inspect as _inspect
_params = set(_inspect.signature(modelopt_quantize).parameters.keys())
@@ -421,6 +496,8 @@ def quantize_onnx_fp8(
if not os.path.exists(output_path):
raise RuntimeError(f"[FP8] modelopt_quantize completed but output not found: {output_path}")
+ _assert_finite_qdq_scales(output_path)
+
size_mb = os.path.getsize(output_path) / (1024**2)
logger.info(f"[FP8] FP8 ONNX written: {output_path} ({size_mb:.1f} MB)")
if size_mb > 5000:
diff --git a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py
index f0f2c4b93..b6a1bd62d 100644
--- a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py
+++ b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py
@@ -1,13 +1,14 @@
-from .models import Optimizer, BaseModel, CLIP, UNet, VAE, VAEEncoder
-from .controlnet_models import ControlNetTRT, ControlNetSDXLTRT
+from .controlnet_models import ControlNetSDXLTRT, ControlNetTRT
+from .models import CLIP, VAE, BaseModel, Optimizer, UNet, VAEEncoder
+
__all__ = [
"Optimizer",
- "BaseModel",
+ "BaseModel",
"CLIP",
"UNet",
"VAE",
"VAEEncoder",
"ControlNetTRT",
"ControlNetSDXLTRT",
-]
\ No newline at end of file
+]
diff --git a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py
index 6179ffc9a..be41d502f 100644
--- a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py
+++ b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py
@@ -2,10 +2,10 @@
import torch
import torch.nn.functional as F
-
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
+
class CachedSTAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -28,9 +28,9 @@ def __init__(self):
# clone/contiguous path. Set to True by wrapper.py after engine build.
self._curr_key_buf: Optional[torch.Tensor] = None
self._curr_value_buf: Optional[torch.Tensor] = None
- self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key
+ self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key
self._cached_value_tr_buf: Optional[torch.Tensor] = None # transposed cache value
- self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H)
+ self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H)
self._use_prealloc: bool = False
def _ensure_buffers(
diff --git a/src/streamdiffusion/acceleration/tensorrt/models/models.py b/src/streamdiffusion/acceleration/tensorrt/models/models.py
index a85b2415d..d59e5a416 100644
--- a/src/streamdiffusion/acceleration/tensorrt/models/models.py
+++ b/src/streamdiffusion/acceleration/tensorrt/models/models.py
@@ -61,9 +61,9 @@ def infer_shapes(self, return_onnx=False):
onnx_graph = gs.export_onnx(self.graph)
if onnx_graph.ByteSize() > 2147483648:
print(
- f"β οΈ Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models"
+ f"[WARN] Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models"
)
- print("π§ ONNX shape inference will be skipped for large models to avoid memory issues")
+ print("[INFO] ONNX shape inference will be skipped for large models to avoid memory issues")
# For large models like SDXL, skip shape inference to avoid memory/size issues
# The model will still work with TensorRT's own shape inference during engine building
else:
diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py
index 165c261e4..7fa98f855 100644
--- a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py
+++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py
@@ -1,12 +1,13 @@
"""Runtime TensorRT engine wrappers."""
-from .unet_engine import UNet2DConditionModelEngine, AutoencoderKLEngine
-from .controlnet_engine import ControlNetModelEngine
from ..engine_manager import EngineManager
+from .controlnet_engine import ControlNetModelEngine
+from .unet_engine import AutoencoderKLEngine, UNet2DConditionModelEngine
+
__all__ = [
"UNet2DConditionModelEngine",
- "AutoencoderKLEngine",
+ "AutoencoderKLEngine",
"ControlNetModelEngine",
"EngineManager",
-]
\ No newline at end of file
+]
diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py
index b6b110946..0bbd1251e 100644
--- a/src/streamdiffusion/acceleration/tensorrt/utilities.py
+++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py
@@ -1015,6 +1015,11 @@ def infer(self, feed_dict, stream, use_cuda_graph=False):
if not noerror:
raise ValueError("ERROR: inference failed.")
stream.synchronize()
+ # Drain the legacy/NULL stream before capture. The polygraphy Stream
+ # is created via cudaStreamCreate (blocking), which implicitly syncs
+ # with legacy. Any pending GPU work on legacy at capture time triggers
+ # cudaErrorStreamCaptureInvalidated (901). One-time cost per engine.
+ torch.cuda.current_stream().synchronize()
# ThreadLocal mode: only captures ops on this thread's stream.
# Global mode would also capture any GPU work submitted from other
# threads (e.g. the TouchDesigner render thread), producing a
diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py
index 001fda3db..352d33d3b 100644
--- a/src/streamdiffusion/config.py
+++ b/src/streamdiffusion/config.py
@@ -1,91 +1,96 @@
-import os
-import sys
-import yaml
import json
-from typing import Dict, List, Optional, Union, Any, Tuple
from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import yaml
+
def load_config(config_path: Union[str, Path]) -> Dict[str, Any]:
"""Load StreamDiffusion configuration from YAML or JSON file"""
config_path = Path(config_path)
-
+
if not config_path.exists():
raise FileNotFoundError(f"load_config: Configuration file not found: {config_path}")
- with open(config_path, 'r', encoding='utf-8') as f:
- if config_path.suffix.lower() in ['.yaml', '.yml']:
+ with open(config_path, "r", encoding="utf-8") as f:
+ if config_path.suffix.lower() in [".yaml", ".yml"]:
config_data = yaml.safe_load(f)
- elif config_path.suffix.lower() == '.json':
+ elif config_path.suffix.lower() == ".json":
config_data = json.load(f)
else:
raise ValueError(f"load_config: Unsupported configuration file format: {config_path.suffix}")
-
+
_validate_config(config_data)
-
+
return config_data
def save_config(config: Dict[str, Any], config_path: Union[str, Path]) -> None:
"""Save StreamDiffusion configuration to YAML or JSON file"""
config_path = Path(config_path)
-
+
_validate_config(config)
config_path.parent.mkdir(parents=True, exist_ok=True)
- with open(config_path, 'w', encoding='utf-8') as f:
- if config_path.suffix.lower() in ['.yaml', '.yml']:
+ with open(config_path, "w", encoding="utf-8") as f:
+ if config_path.suffix.lower() in [".yaml", ".yml"]:
yaml.dump(config, f, default_flow_style=False, indent=2)
- elif config_path.suffix.lower() == '.json':
+ elif config_path.suffix.lower() == ".json":
json.dump(config, f, indent=2)
else:
raise ValueError(f"save_config: Unsupported configuration file format: {config_path.suffix}")
+
def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any:
"""Create StreamDiffusionWrapper from configuration dictionary
-
+
Prompt Interface:
- Legacy: Use 'prompt' field for single prompt
- New: Use 'prompt_blending' with 'prompt_list' for multiple weighted prompts
- If both are provided, 'prompt_blending' takes precedence and 'prompt' is ignored
- negative_prompt: Currently a single string (not list) for all prompt types
"""
+
from streamdiffusion import StreamDiffusionWrapper
- import torch
final_config = {**config, **overrides}
wrapper_params = _extract_wrapper_params(final_config)
wrapper = StreamDiffusionWrapper(**wrapper_params)
-
+
prepare_params = _extract_prepare_params(final_config)
# Handle prompt configuration with clear precedence
- if 'prompt_blending' in final_config:
+ if "prompt_blending" in final_config:
# Use prompt blending (new interface) - ignore legacy 'prompt' field
- blend_config = final_config['prompt_blending']
-
+ blend_config = final_config["prompt_blending"]
+
# Prepare with prompt blending directly using unified interface
- prepare_params_with_blending = {k: v for k, v in prepare_params.items()
- if k not in ['prompt_blending', 'seed_blending']}
- prepare_params_with_blending['prompt'] = blend_config.get('prompt_list', [])
- prepare_params_with_blending['prompt_interpolation_method'] = blend_config.get('interpolation_method', 'slerp')
-
+ prepare_params_with_blending = {
+ k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"]
+ }
+ prepare_params_with_blending["prompt"] = blend_config.get("prompt_list", [])
+ prepare_params_with_blending["prompt_interpolation_method"] = blend_config.get("interpolation_method", "slerp")
+
# Add seed blending if configured
- if 'seed_blending' in final_config:
- seed_blend_config = final_config['seed_blending']
- prepare_params_with_blending['seed_list'] = seed_blend_config.get('seed_list', [])
- prepare_params_with_blending['seed_interpolation_method'] = seed_blend_config.get('interpolation_method', 'linear')
-
+ if "seed_blending" in final_config:
+ seed_blend_config = final_config["seed_blending"]
+ prepare_params_with_blending["seed_list"] = seed_blend_config.get("seed_list", [])
+ prepare_params_with_blending["seed_interpolation_method"] = seed_blend_config.get(
+ "interpolation_method", "linear"
+ )
+
wrapper.prepare(**prepare_params_with_blending)
- elif prepare_params.get('prompt'):
+ elif prepare_params.get("prompt"):
# Use legacy single prompt interface
- clean_prepare_params = {k: v for k, v in prepare_params.items()
- if k not in ['prompt_blending', 'seed_blending']}
+ clean_prepare_params = {
+ k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"]
+ }
wrapper.prepare(**clean_prepare_params)
# Apply seed blending if configured and not already handled in prepare
- if 'seed_blending' in final_config and 'prompt_blending' not in final_config:
- seed_blend_config = final_config['seed_blending']
+ if "seed_blending" in final_config and "prompt_blending" not in final_config:
+ seed_blend_config = final_config["seed_blending"]
wrapper.update_stream_params(
- seed_list=seed_blend_config.get('seed_list', []),
- interpolation_method=seed_blend_config.get('interpolation_method', 'linear')
+ seed_list=seed_blend_config.get("seed_list", []),
+ interpolation_method=seed_blend_config.get("interpolation_method", "linear"),
)
return wrapper
@@ -93,256 +98,269 @@ def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any:
def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]:
"""Extract parameters for StreamDiffusionWrapper.__init__() from config"""
- import torch
param_map = {
- 'model_id_or_path': config.get('model_id', 'stabilityai/sd-turbo'),
- 't_index_list': config.get('t_index_list', [0, 16, 32, 45]),
- 'lora_dict': config.get('lora_dict'),
- 'mode': config.get('mode', 'img2img'),
- 'output_type': config.get('output_type', 'pil'),
- 'vae_id': config.get('vae_id'),
- 'device': config.get('device', 'cuda'),
- 'dtype': _parse_dtype(config.get('dtype', 'float16')),
- 'frame_buffer_size': config.get('frame_buffer_size', 1),
- 'width': config.get('width', 512),
- 'height': config.get('height', 512),
- 'warmup': config.get('warmup', 10),
- 'acceleration': config.get('acceleration', 'tensorrt'),
- 'do_add_noise': config.get('do_add_noise', True),
- 'device_ids': config.get('device_ids'),
- 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility
- 'use_tiny_vae': config.get('use_tiny_vae', True),
- 'enable_similar_image_filter': config.get('enable_similar_image_filter', False),
- 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98),
- 'similar_image_filter_max_skip_frame': config.get('similar_image_filter_max_skip_frame', 10),
- 'similar_filter_sleep_fraction': config.get('similar_filter_sleep_fraction', 0.025),
- 'use_denoising_batch': config.get('use_denoising_batch', True),
- 'cfg_type': config.get('cfg_type', 'self'),
- 'seed': config.get('seed', 2),
- 'use_safety_checker': config.get('use_safety_checker', False),
- 'skip_diffusion': config.get('skip_diffusion', False),
- 'engine_dir': config.get('engine_dir', 'engines'),
- 'normalize_prompt_weights': config.get('normalize_prompt_weights', True),
- 'normalize_seed_weights': config.get('normalize_seed_weights', True),
- 'scheduler': config.get('scheduler', 'lcm'),
- 'sampler': config.get('sampler', 'normal'),
- 'compile_engines_only': config.get('compile_engines_only', False),
- 'static_shapes': config.get('static_shapes', False),
- 'fp8': config.get('fp8', False),
- 'builder_optimization_level': config.get('builder_optimization_level'),
- 'build_engines_if_missing': config.get('build_engines_if_missing', True),
- 'fp8_allow_fp16_fallback': config.get('fp8_allow_fp16_fallback', False),
+ "model_id_or_path": config.get("model_id", "stabilityai/sd-turbo"),
+ "t_index_list": config.get("t_index_list", [0, 16, 32, 45]),
+ "lora_dict": config.get("lora_dict"),
+ "mode": config.get("mode", "img2img"),
+ "output_type": config.get("output_type", "pil"),
+ "vae_id": config.get("vae_id"),
+ "device": config.get("device", "cuda"),
+ "dtype": _parse_dtype(config.get("dtype", "float16")),
+ "frame_buffer_size": config.get("frame_buffer_size", 1),
+ "width": config.get("width", 512),
+ "height": config.get("height", 512),
+ "warmup": config.get("warmup", 10),
+ "acceleration": config.get("acceleration", "tensorrt"),
+ "do_add_noise": config.get("do_add_noise", True),
+ "device_ids": config.get("device_ids"),
+ "use_lcm_lora": config.get("use_lcm_lora"), # Backwards compatibility
+ "use_tiny_vae": config.get("use_tiny_vae", True),
+ "enable_similar_image_filter": config.get("enable_similar_image_filter", False),
+ "similar_image_filter_threshold": config.get("similar_image_filter_threshold", 0.98),
+ "similar_image_filter_max_skip_frame": config.get("similar_image_filter_max_skip_frame", 10),
+ "similar_filter_sleep_fraction": config.get("similar_filter_sleep_fraction", 0.025),
+ "use_denoising_batch": config.get("use_denoising_batch", True),
+ "cfg_type": config.get("cfg_type", "self"),
+ "seed": config.get("seed", 2),
+ "use_safety_checker": config.get("use_safety_checker", False),
+ "skip_diffusion": config.get("skip_diffusion", False),
+ "engine_dir": config.get("engine_dir", "engines"),
+ "normalize_prompt_weights": config.get("normalize_prompt_weights", True),
+ "normalize_seed_weights": config.get("normalize_seed_weights", True),
+ "scheduler": config.get("scheduler", "lcm"),
+ "sampler": config.get("sampler", "normal"),
+ "compile_engines_only": config.get("compile_engines_only", False),
+ "static_shapes": config.get("static_shapes", False),
+ "fp8": config.get("fp8", False),
+ "builder_optimization_level": config.get("builder_optimization_level"),
+ "build_engines_if_missing": config.get("build_engines_if_missing", True),
+ "fp8_allow_fp16_fallback": config.get("fp8_allow_fp16_fallback", False),
}
- if 'controlnets' in config and config['controlnets']:
- param_map['use_controlnet'] = True
- param_map['controlnet_config'] = _prepare_controlnet_configs(config)
+ if "controlnets" in config and config["controlnets"]:
+ param_map["use_controlnet"] = True
+ param_map["controlnet_config"] = _prepare_controlnet_configs(config)
else:
- param_map['use_controlnet'] = config.get('use_controlnet', False)
- param_map['controlnet_config'] = config.get('controlnet_config')
-
+ param_map["use_controlnet"] = config.get("use_controlnet", False)
+ param_map["controlnet_config"] = config.get("controlnet_config")
+
# Set IPAdapter usage if IPAdapters are configured
- if 'ipadapters' in config and config['ipadapters']:
- param_map['use_ipadapter'] = True
- param_map['ipadapter_config'] = _prepare_ipadapter_configs(config)
+ if "ipadapters" in config and config["ipadapters"]:
+ param_map["use_ipadapter"] = True
+ param_map["ipadapter_config"] = _prepare_ipadapter_configs(config)
else:
- param_map['use_ipadapter'] = config.get('use_ipadapter', False)
- param_map['ipadapter_config'] = config.get('ipadapter_config')
-
- param_map['use_cached_attn'] = config.get('use_cached_attn', False)
-
- param_map['cache_maxframes'] = config.get('cache_maxframes', 1)
- param_map['cache_interval'] = config.get('cache_interval', 1)
-
+ param_map["use_ipadapter"] = config.get("use_ipadapter", False)
+ param_map["ipadapter_config"] = config.get("ipadapter_config")
+
+ param_map["use_cached_attn"] = config.get("use_cached_attn", False)
+
+ param_map["cache_maxframes"] = config.get("cache_maxframes", 1)
+ param_map["cache_interval"] = config.get("cache_interval", 1)
+
+ # CUDA IPC output (SDβTD zero-copy GPU transport via cuda-link)
+ param_map["use_cuda_ipc_output"] = config.get("use_cuda_ipc_output", False)
+ param_map["cuda_ipc_shm_name"] = config.get("cuda_ipc_shm_name")
+ param_map["cuda_ipc_num_slots"] = config.get("cuda_ipc_num_slots", 2)
+
# Pipeline hook configurations (Phase 4: Configuration Integration)
hook_configs = _prepare_pipeline_hook_configs(config)
param_map.update(hook_configs)
-
+
return {k: v for k, v in param_map.items() if v is not None}
def _extract_prepare_params(config: Dict[str, Any]) -> Dict[str, Any]:
"""Extract parameters for wrapper.prepare() from config"""
prepare_params = {
- 'prompt': config.get('prompt', ''),
- 'negative_prompt': config.get('negative_prompt', ''),
- 'num_inference_steps': config.get('num_inference_steps', 50),
- 'guidance_scale': config.get('guidance_scale', 1.2),
- 'delta': config.get('delta', 1.0),
+ "prompt": config.get("prompt", ""),
+ "negative_prompt": config.get("negative_prompt", ""),
+ "num_inference_steps": config.get("num_inference_steps", 50),
+ "guidance_scale": config.get("guidance_scale", 1.2),
+ "delta": config.get("delta", 1.0),
}
-
+
# Handle prompt blending configuration
- if 'prompt_blending' in config:
- blend_config = config['prompt_blending']
- prepare_params['prompt_blending'] = {
- 'prompt_list': blend_config.get('prompt_list', []),
- 'interpolation_method': blend_config.get('interpolation_method', 'slerp'),
- 'enable_caching': blend_config.get('enable_caching', True)
+ if "prompt_blending" in config:
+ blend_config = config["prompt_blending"]
+ prepare_params["prompt_blending"] = {
+ "prompt_list": blend_config.get("prompt_list", []),
+ "interpolation_method": blend_config.get("interpolation_method", "slerp"),
+ "enable_caching": blend_config.get("enable_caching", True),
}
-
+
# Handle seed blending configuration
- if 'seed_blending' in config:
- seed_blend_config = config['seed_blending']
- prepare_params['seed_blending'] = {
- 'seed_list': seed_blend_config.get('seed_list', []),
- 'interpolation_method': seed_blend_config.get('interpolation_method', 'linear'),
- 'enable_caching': seed_blend_config.get('enable_caching', True)
+ if "seed_blending" in config:
+ seed_blend_config = config["seed_blending"]
+ prepare_params["seed_blending"] = {
+ "seed_list": seed_blend_config.get("seed_list", []),
+ "interpolation_method": seed_blend_config.get("interpolation_method", "linear"),
+ "enable_caching": seed_blend_config.get("enable_caching", True),
}
-
+
return prepare_params
+
def _prepare_controlnet_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Prepare ControlNet configurations for wrapper"""
controlnet_configs = []
- pipeline_type = config.get('pipeline_type', 'sd1.5')
- for cn_config in config['controlnets']:
+ pipeline_type = config.get("pipeline_type", "sd1.5")
+ for cn_config in config["controlnets"]:
controlnet_config = {
- 'model_id': cn_config['model_id'],
- 'preprocessor': cn_config.get('preprocessor', 'passthrough'),
- 'conditioning_scale': cn_config.get('conditioning_scale', 1.0),
- 'enabled': cn_config.get('enabled', True),
- 'preprocessor_params': cn_config.get('preprocessor_params'),
- 'conditioning_channels': cn_config.get('conditioning_channels'),
- 'pipeline_type': pipeline_type,
- 'control_guidance_start': cn_config.get('control_guidance_start', 0.0),
- 'control_guidance_end': cn_config.get('control_guidance_end', 1.0),
+ "model_id": cn_config["model_id"],
+ "preprocessor": cn_config.get("preprocessor", "passthrough"),
+ "conditioning_scale": cn_config.get("conditioning_scale", 1.0),
+ "enabled": cn_config.get("enabled", True),
+ "preprocessor_params": cn_config.get("preprocessor_params"),
+ "conditioning_channels": cn_config.get("conditioning_channels"),
+ "pipeline_type": pipeline_type,
+ "control_guidance_start": cn_config.get("control_guidance_start", 0.0),
+ "control_guidance_end": cn_config.get("control_guidance_end", 1.0),
}
controlnet_configs.append(controlnet_config)
-
+
return controlnet_configs
def _prepare_ipadapter_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Prepare IPAdapter configurations for wrapper"""
ipadapter_configs = []
-
- for ip_config in config['ipadapters']:
+
+ for ip_config in config["ipadapters"]:
ipadapter_config = {
- 'ipadapter_model_path': ip_config['ipadapter_model_path'],
- 'image_encoder_path': ip_config['image_encoder_path'],
- 'style_image': ip_config.get('style_image'),
- 'scale': ip_config.get('scale', 1.0),
- 'enabled': ip_config.get('enabled', True),
+ "ipadapter_model_path": ip_config["ipadapter_model_path"],
+ "image_encoder_path": ip_config["image_encoder_path"],
+ "style_image": ip_config.get("style_image"),
+ "scale": ip_config.get("scale", 1.0),
+ "enabled": ip_config.get("enabled", True),
# Preserve FaceID options from config for downstream wrapper/module handling
- 'type': ip_config.get('type', 'regular'),
- 'insightface_model_name': ip_config.get('insightface_model_name'),
+ "type": ip_config.get("type", "regular"),
+ "insightface_model_name": ip_config.get("insightface_model_name"),
}
ipadapter_configs.append(ipadapter_config)
-
+
return ipadapter_configs
def _prepare_pipeline_hook_configs(config: Dict[str, Any]) -> Dict[str, Any]:
"""Prepare pipeline hook configurations for wrapper following ControlNet/IPAdapter pattern"""
hook_configs = {}
-
+
# Image preprocessing hooks
- if 'image_preprocessing' in config and config['image_preprocessing']:
- if config['image_preprocessing'].get('enabled', True):
- hook_configs['image_preprocessing_config'] = _prepare_single_hook_config(
- config['image_preprocessing'], 'image_preprocessing'
+ if "image_preprocessing" in config and config["image_preprocessing"]:
+ if config["image_preprocessing"].get("enabled", True):
+ hook_configs["image_preprocessing_config"] = _prepare_single_hook_config(
+ config["image_preprocessing"], "image_preprocessing"
)
-
- # Image postprocessing hooks
- if 'image_postprocessing' in config and config['image_postprocessing']:
- if config['image_postprocessing'].get('enabled', True):
- hook_configs['image_postprocessing_config'] = _prepare_single_hook_config(
- config['image_postprocessing'], 'image_postprocessing'
+
+ # Image postprocessing hooks
+ if "image_postprocessing" in config and config["image_postprocessing"]:
+ if config["image_postprocessing"].get("enabled", True):
+ hook_configs["image_postprocessing_config"] = _prepare_single_hook_config(
+ config["image_postprocessing"], "image_postprocessing"
)
-
+
# Latent preprocessing hooks
- if 'latent_preprocessing' in config and config['latent_preprocessing']:
- if config['latent_preprocessing'].get('enabled', True):
- hook_configs['latent_preprocessing_config'] = _prepare_single_hook_config(
- config['latent_preprocessing'], 'latent_preprocessing'
+ if "latent_preprocessing" in config and config["latent_preprocessing"]:
+ if config["latent_preprocessing"].get("enabled", True):
+ hook_configs["latent_preprocessing_config"] = _prepare_single_hook_config(
+ config["latent_preprocessing"], "latent_preprocessing"
)
-
+
# Latent postprocessing hooks
- if 'latent_postprocessing' in config and config['latent_postprocessing']:
- if config['latent_postprocessing'].get('enabled', True):
- hook_configs['latent_postprocessing_config'] = _prepare_single_hook_config(
- config['latent_postprocessing'], 'latent_postprocessing'
+ if "latent_postprocessing" in config and config["latent_postprocessing"]:
+ if config["latent_postprocessing"].get("enabled", True):
+ hook_configs["latent_postprocessing_config"] = _prepare_single_hook_config(
+ config["latent_postprocessing"], "latent_postprocessing"
)
-
+
return hook_configs
def _prepare_single_hook_config(hook_config: Dict[str, Any], hook_type: str) -> Dict[str, Any]:
"""Prepare configuration for a single hook type"""
return {
- 'enabled': hook_config.get('enabled', True),
- 'processors': hook_config.get('processors', []),
- 'hook_type': hook_type,
+ "enabled": hook_config.get("enabled", True),
+ "processors": hook_config.get("processors", []),
+ "hook_type": hook_type,
}
def _validate_pipeline_hook_configs(config: Dict[str, Any]) -> None:
"""Validate pipeline hook configurations following ControlNet/IPAdapter validation pattern"""
- hook_types = ['image_preprocessing', 'image_postprocessing', 'latent_preprocessing', 'latent_postprocessing']
-
+ hook_types = ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"]
+
for hook_type in hook_types:
if hook_type in config:
hook_config = config[hook_type]
if not isinstance(hook_config, dict):
raise ValueError(f"_validate_config: '{hook_type}' must be a dictionary")
-
+
# Validate enabled field
- if 'enabled' in hook_config:
- enabled = hook_config['enabled']
+ if "enabled" in hook_config:
+ enabled = hook_config["enabled"]
if not isinstance(enabled, bool):
raise ValueError(f"_validate_config: '{hook_type}.enabled' must be a boolean")
-
+
# Validate processors field
- if 'processors' in hook_config:
- processors = hook_config['processors']
+ if "processors" in hook_config:
+ processors = hook_config["processors"]
if not isinstance(processors, list):
raise ValueError(f"_validate_config: '{hook_type}.processors' must be a list")
-
+
for i, processor in enumerate(processors):
if not isinstance(processor, dict):
raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' must be a dictionary")
-
+
# Validate processor type (required)
- if 'type' not in processor:
- raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field")
-
- if not isinstance(processor['type'], str):
+ if "type" not in processor:
+ raise ValueError(
+ f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field"
+ )
+
+ if not isinstance(processor["type"], str):
raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].type' must be a string")
-
+
# Validate enabled field (optional, defaults to True)
- if 'enabled' in processor:
- enabled = processor['enabled']
+ if "enabled" in processor:
+ enabled = processor["enabled"]
if not isinstance(enabled, bool):
- raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean")
-
+ raise ValueError(
+ f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean"
+ )
+
# Validate order field (optional)
- if 'order' in processor:
- order = processor['order']
+ if "order" in processor:
+ order = processor["order"]
if not isinstance(order, int):
- raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer")
-
+ raise ValueError(
+ f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer"
+ )
+
# Validate params field (optional, coerce None to empty dict)
- if 'params' in processor:
- if processor['params'] is None:
- processor['params'] = {}
- elif not isinstance(processor['params'], dict):
- raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary")
+ if "params" in processor:
+ if processor["params"] is None:
+ processor["params"] = {}
+ elif not isinstance(processor["params"], dict):
+ raise ValueError(
+ f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary"
+ )
def create_prompt_blending_config(
base_config: Dict[str, Any],
prompt_list: List[Tuple[str, float]],
prompt_interpolation_method: str = "slerp",
- enable_caching: bool = True
+ enable_caching: bool = True,
) -> Dict[str, Any]:
"""Create a configuration with prompt blending settings"""
config = base_config.copy()
-
- config['prompt_blending'] = {
- 'prompt_list': prompt_list,
- 'interpolation_method': prompt_interpolation_method,
- 'enable_caching': enable_caching
+
+ config["prompt_blending"] = {
+ "prompt_list": prompt_list,
+ "interpolation_method": prompt_interpolation_method,
+ "enable_caching": enable_caching,
}
-
+
return config
@@ -350,150 +368,152 @@ def create_seed_blending_config(
base_config: Dict[str, Any],
seed_list: List[Tuple[int, float]],
interpolation_method: str = "linear",
- enable_caching: bool = True
+ enable_caching: bool = True,
) -> Dict[str, Any]:
"""Create a configuration with seed blending settings"""
config = base_config.copy()
-
- config['seed_blending'] = {
- 'seed_list': seed_list,
- 'interpolation_method': interpolation_method,
- 'enable_caching': enable_caching
+
+ config["seed_blending"] = {
+ "seed_list": seed_list,
+ "interpolation_method": interpolation_method,
+ "enable_caching": enable_caching,
}
-
+
return config
def set_normalize_weights_config(
- base_config: Dict[str, Any],
- normalize_prompt_weights: bool = True,
- normalize_seed_weights: bool = True
+ base_config: Dict[str, Any], normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True
) -> Dict[str, Any]:
"""Create a configuration with separate normalize weight settings"""
config = base_config.copy()
-
- config['normalize_prompt_weights'] = normalize_prompt_weights
- config['normalize_seed_weights'] = normalize_seed_weights
-
+
+ config["normalize_prompt_weights"] = normalize_prompt_weights
+ config["normalize_seed_weights"] = normalize_seed_weights
+
return config
+
def _parse_dtype(dtype_str: str) -> Any:
"""Parse dtype string to torch dtype"""
import torch
-
+
dtype_map = {
- 'float16': torch.float16,
- 'float32': torch.float32,
- 'half': torch.float16,
- 'float': torch.float32,
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "half": torch.float16,
+ "float": torch.float32,
}
-
+
if isinstance(dtype_str, str):
return dtype_map.get(dtype_str.lower(), torch.float16)
return dtype_str # Assume it's already a torch dtype
+
+
def _validate_config(config: Dict[str, Any]) -> None:
"""Basic validation of configuration dictionary"""
if not isinstance(config, dict):
raise ValueError("_validate_config: Configuration must be a dictionary")
-
- if 'model_id' not in config:
+
+ if "model_id" not in config:
raise ValueError("_validate_config: Missing required field: model_id")
-
- if 'controlnets' in config:
- if not isinstance(config['controlnets'], list):
+
+ if "controlnets" in config:
+ if not isinstance(config["controlnets"], list):
raise ValueError("_validate_config: 'controlnets' must be a list")
-
- for i, controlnet in enumerate(config['controlnets']):
+
+ for i, controlnet in enumerate(config["controlnets"]):
if not isinstance(controlnet, dict):
raise ValueError(f"_validate_config: ControlNet {i} must be a dictionary")
-
- if 'model_id' not in controlnet:
+
+ if "model_id" not in controlnet:
raise ValueError(f"_validate_config: ControlNet {i} missing required 'model_id'")
-
+
# Validate conditioning_channels if present
- if 'conditioning_channels' in controlnet:
- channels = controlnet['conditioning_channels']
+ if "conditioning_channels" in controlnet:
+ channels = controlnet["conditioning_channels"]
if not isinstance(channels, int) or channels <= 0:
- raise ValueError(f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}")
-
+ raise ValueError(
+ f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}"
+ )
+
# Validate ipadapters if present
- if 'ipadapters' in config:
- if not isinstance(config['ipadapters'], list):
+ if "ipadapters" in config:
+ if not isinstance(config["ipadapters"], list):
raise ValueError("_validate_config: 'ipadapters' must be a list")
-
- for i, ipadapter in enumerate(config['ipadapters']):
+
+ for i, ipadapter in enumerate(config["ipadapters"]):
if not isinstance(ipadapter, dict):
raise ValueError(f"_validate_config: IPAdapter {i} must be a dictionary")
-
- if 'ipadapter_model_path' not in ipadapter:
+
+ if "ipadapter_model_path" not in ipadapter:
raise ValueError(f"_validate_config: IPAdapter {i} missing required 'ipadapter_model_path'")
-
- if 'image_encoder_path' not in ipadapter:
+
+ if "image_encoder_path" not in ipadapter:
raise ValueError(f"_validate_config: IPAdapter {i} missing required 'image_encoder_path'")
# Validate prompt blending configuration if present
- if 'prompt_blending' in config:
- blend_config = config['prompt_blending']
+ if "prompt_blending" in config:
+ blend_config = config["prompt_blending"]
if not isinstance(blend_config, dict):
raise ValueError("_validate_config: 'prompt_blending' must be a dictionary")
-
- if 'prompt_list' in blend_config:
- prompt_list = blend_config['prompt_list']
+
+ if "prompt_list" in blend_config:
+ prompt_list = blend_config["prompt_list"]
if not isinstance(prompt_list, list):
raise ValueError("_validate_config: 'prompt_list' must be a list")
-
+
for i, prompt_item in enumerate(prompt_list):
if not isinstance(prompt_item, (list, tuple)) or len(prompt_item) != 2:
raise ValueError(f"_validate_config: Prompt item {i} must be [text, weight] pair")
-
+
text, weight = prompt_item
if not isinstance(text, str):
raise ValueError(f"_validate_config: Prompt text {i} must be a string")
-
+
if not isinstance(weight, (int, float)) or weight < 0:
raise ValueError(f"_validate_config: Prompt weight {i} must be a non-negative number")
-
- interpolation_method = blend_config.get('interpolation_method', 'slerp')
- if interpolation_method not in ['linear', 'slerp']:
+
+ interpolation_method = blend_config.get("interpolation_method", "slerp")
+ if interpolation_method not in ["linear", "slerp"]:
raise ValueError("_validate_config: interpolation_method must be 'linear' or 'slerp'")
# Validate seed blending configuration if present
- if 'seed_blending' in config:
- seed_blend_config = config['seed_blending']
+ if "seed_blending" in config:
+ seed_blend_config = config["seed_blending"]
if not isinstance(seed_blend_config, dict):
raise ValueError("_validate_config: 'seed_blending' must be a dictionary")
-
- if 'seed_list' in seed_blend_config:
- seed_list = seed_blend_config['seed_list']
+
+ if "seed_list" in seed_blend_config:
+ seed_list = seed_blend_config["seed_list"]
if not isinstance(seed_list, list):
raise ValueError("_validate_config: 'seed_list' must be a list")
-
+
for i, seed_item in enumerate(seed_list):
if not isinstance(seed_item, (list, tuple)) or len(seed_item) != 2:
raise ValueError(f"_validate_config: Seed item {i} must be [seed, weight] pair")
-
+
seed_value, weight = seed_item
if not isinstance(seed_value, int) or seed_value < 0:
raise ValueError(f"_validate_config: Seed value {i} must be a non-negative integer")
-
+
if not isinstance(weight, (int, float)) or weight < 0:
raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number")
-
- interpolation_method = seed_blend_config.get('interpolation_method', 'linear')
- if interpolation_method not in ['linear', 'slerp']:
+
+ interpolation_method = seed_blend_config.get("interpolation_method", "linear")
+ if interpolation_method not in ["linear", "slerp"]:
raise ValueError("_validate_config: seed blending interpolation_method must be 'linear' or 'slerp'")
# Validate pipeline hook configurations if present (Phase 4: Configuration Integration)
_validate_pipeline_hook_configs(config)
# Validate separate normalize settings if present
- if 'normalize_prompt_weights' in config:
- normalize_prompt_weights = config['normalize_prompt_weights']
+ if "normalize_prompt_weights" in config:
+ normalize_prompt_weights = config["normalize_prompt_weights"]
if not isinstance(normalize_prompt_weights, bool):
raise ValueError("_validate_config: 'normalize_prompt_weights' must be a boolean value")
-
- if 'normalize_seed_weights' in config:
- normalize_seed_weights = config['normalize_seed_weights']
+
+ if "normalize_seed_weights" in config:
+ normalize_seed_weights = config["normalize_seed_weights"]
if not isinstance(normalize_seed_weights, bool):
raise ValueError("_validate_config: 'normalize_seed_weights' must be a boolean value")
-
diff --git a/src/streamdiffusion/hooks.py b/src/streamdiffusion/hooks.py
index ec5db4155..02f10270d 100644
--- a/src/streamdiffusion/hooks.py
+++ b/src/streamdiffusion/hooks.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
+
import torch
@@ -13,6 +14,7 @@ class EmbedsCtx:
- prompt_embeds: [batch, seq_len, dim]
- negative_prompt_embeds: optional [batch, seq_len, dim]
"""
+
prompt_embeds: torch.Tensor
negative_prompt_embeds: Optional[torch.Tensor] = None
@@ -28,6 +30,7 @@ class StepCtx:
- guidance_mode: one of {"none","full","self","initialize"}
- sdxl_cond: optional dict with SDXL micro-cond tensors
"""
+
x_t_latent: torch.Tensor
t_list: torch.Tensor
step_index: Optional[int]
@@ -38,6 +41,7 @@ class StepCtx:
@dataclass
class UnetKwargsDelta:
"""Delta produced by UNet hooks to augment UNet call kwargs."""
+
down_block_additional_residuals: Optional[List[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
@@ -48,37 +52,37 @@ class UnetKwargsDelta:
@dataclass
class ImageCtx:
"""Context passed to image processing hooks.
-
+
Fields:
- image: [B, C, H, W] tensor in image space
- width: image width
- - height: image height
+ - height: image height
- step_index: optional step index for multi-step processing
"""
+
image: torch.Tensor
width: int
height: int
step_index: Optional[int] = None
-@dataclass
+@dataclass
class LatentCtx:
"""Context passed to latent processing hooks.
-
+
Fields:
- latent: [B, C, H/8, W/8] tensor in latent space
- timestep: optional timestep tensor for diffusion context
- step_index: optional step index for multi-step processing
"""
+
latent: torch.Tensor
timestep: Optional[torch.Tensor] = None
step_index: Optional[int] = None
-
# Type aliases for clarity
EmbeddingHook = Callable[[EmbedsCtx], EmbedsCtx]
UnetHook = Callable[[StepCtx], UnetKwargsDelta]
ImageHook = Callable[[ImageCtx], ImageCtx]
LatentHook = Callable[[LatentCtx], LatentCtx]
-
diff --git a/src/streamdiffusion/image_filter.py b/src/streamdiffusion/image_filter.py
index 5523c8869..e975567a0 100644
--- a/src/streamdiffusion/image_filter.py
+++ b/src/streamdiffusion/image_filter.py
@@ -1,5 +1,5 @@
-from typing import Optional
import random
+from typing import Optional
import torch
import torch.nn.functional as F
diff --git a/src/streamdiffusion/image_utils.py b/src/streamdiffusion/image_utils.py
index 200295b37..77d7275c2 100644
--- a/src/streamdiffusion/image_utils.py
+++ b/src/streamdiffusion/image_utils.py
@@ -30,9 +30,7 @@ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
- pil_images = [
- PIL.Image.fromarray(image.squeeze(), mode="L") for image in images
- ]
+ pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [PIL.Image.fromarray(image) for image in images]
@@ -56,12 +54,7 @@ def postprocess_image(
if do_denormalize is None:
do_denormalize = [do_normalize_flg] * image.shape[0]
- image = torch.stack(
- [
- denormalize(image[i]) if do_denormalize[i] else image[i]
- for i in range(image.shape[0])
- ]
- )
+ image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])])
if output_type == "pt":
return image
@@ -91,8 +84,6 @@ def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor:
img, _ = process_image(image_pil)
imgs.append(img)
imgs = torch.vstack(imgs)
- images = torch.nn.functional.interpolate(
- imgs, size=(height, width), mode="bilinear"
- )
+ images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear")
image_tensors = images.to(torch.float16)
return image_tensors
diff --git a/src/streamdiffusion/model_detection.py b/src/streamdiffusion/model_detection.py
index e9eef252c..fbd28933e 100644
--- a/src/streamdiffusion/model_detection.py
+++ b/src/streamdiffusion/model_detection.py
@@ -1,13 +1,15 @@
"""Comprehensive model detection for TensorRT and pipeline support"""
-from typing import Dict, Tuple, Optional, Any, List
+from typing import Any, Dict, Optional
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
# Gracefully import the SD3 model class; it might not exist in older diffusers versions.
try:
from diffusers.models.transformers.mm_dit import MMDiTTransformer2DModel
+
HAS_MMDIT = True
except ImportError:
# Create a dummy class if the import fails to prevent runtime errors.
@@ -15,6 +17,8 @@
HAS_MMDIT = False
import logging
+
+
logger = logging.getLogger(__name__)
@@ -23,7 +27,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
Comprehensive and robust model detection using definitive architectural features.
This function replaces heuristic-based analysis with a deterministic,
- rule-based approach by first inspecting the model's class and then its key
+ rule-based approach by first inspecting the model's class and then its key
configuration parameters that define the architecture.
Args:
@@ -50,9 +54,9 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
confidence = 1.0
# Differentiating SD3 vs. SD3-Turbo from the MMDiT config alone is currently
# speculative. A check on the pipeline's scheduler is a reasonable proxy.
- if pipe and hasattr(pipe, 'scheduler'):
- scheduler_name = getattr(pipe.scheduler.config, '_class_name', '').lower()
- if 'lcm' in scheduler_name or 'turbo' in scheduler_name:
+ if pipe and hasattr(pipe, "scheduler"):
+ scheduler_name = getattr(pipe.scheduler.config, "_class_name", "").lower()
+ if "lcm" in scheduler_name or "turbo" in scheduler_name:
is_turbo = True
model_type = "SD3-Turbo"
else:
@@ -62,7 +66,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
# 2. UNet-based Model Detection (SDXL, SD2.1, SD1.5)
elif isinstance(model, UNet2DConditionModel):
config = model.config
-
+
# 2a. SDXL vs. non-SDXL
# The `addition_embed_type` is the clearest indicator for the SDXL architecture.
if config.get("addition_embed_type") is not None:
@@ -73,7 +77,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
# Base SDXL has `time_cond_proj_dim` (e.g., 256), while Turbo has it set to `None`.
if config.get("time_cond_proj_dim") is None:
is_turbo = True
-
+
# 2b. SD2.1 vs. SD1.5 (if not SDXL)
# Differentiate based on the text encoder's projection dimension.
else:
@@ -90,10 +94,10 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
confidence = 0.7
# 3. ControlNet Model Detection (detect underlying architecture)
- elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'):
+ elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"):
# ControlNet models have UNet-like configs, detect their base architecture
config = model.config
-
+
# Apply same detection logic as UNet models
if config.get("addition_embed_type") is not None:
model_type = "SDXL"
@@ -107,12 +111,12 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
model_type = "SD2.1"
confidence = 0.95
elif cross_attention_dim == 768:
- model_type = "SD1.5"
+ model_type = "SD1.5"
confidence = 0.95
else:
model_type = "SD-finetune"
confidence = 0.7
-
+
else:
# The model is not a known UNet or MMDiT class.
confidence = 0.0
@@ -120,44 +124,46 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
# Populate architecture and compatibility details (can be expanded as needed)
architecture_details = {
- 'model_class': model.__class__.__name__,
- 'in_channels': getattr(model.config, 'in_channels', 'N/A'),
- 'cross_attention_dim': getattr(model.config, 'cross_attention_dim', 'N/A'),
- 'block_out_channels': getattr(model.config, 'block_out_channels', 'N/A'),
+ "model_class": model.__class__.__name__,
+ "in_channels": getattr(model.config, "in_channels", "N/A"),
+ "cross_attention_dim": getattr(model.config, "cross_attention_dim", "N/A"),
+ "block_out_channels": getattr(model.config, "block_out_channels", "N/A"),
}
-
+
# For UNet models, add detailed characteristics that SDXL code expects
if isinstance(model, UNet2DConditionModel):
unet_chars = detect_unet_characteristics(model)
- architecture_details.update({
- 'has_time_conditioning': unet_chars['has_time_cond'],
- 'has_addition_embeds': unet_chars['has_addition_embed'],
- })
-
+ architecture_details.update(
+ {
+ "has_time_conditioning": unet_chars["has_time_cond"],
+ "has_addition_embeds": unet_chars["has_addition_embed"],
+ }
+ )
+
# For ControlNet models, add similar characteristics
- elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'):
+ elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"):
# ControlNet models have similar config structure to UNet
config = model.config
has_addition_embed = config.get("addition_embed_type") is not None
- has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None
-
- architecture_details.update({
- 'has_time_conditioning': has_time_cond,
- 'has_addition_embeds': has_addition_embed,
- })
-
- compatibility_info = {
- 'notes': f"Detected as {model_type} with {confidence:.2f} confidence based on architecture."
- }
+ has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None
+
+ architecture_details.update(
+ {
+ "has_time_conditioning": has_time_cond,
+ "has_addition_embeds": has_addition_embed,
+ }
+ )
+
+ compatibility_info = {"notes": f"Detected as {model_type} with {confidence:.2f} confidence based on architecture."}
result = {
- 'model_type': model_type,
- 'is_turbo': is_turbo,
- 'is_sdxl': is_sdxl,
- 'is_sd3': is_sd3,
- 'confidence': confidence,
- 'architecture_details': architecture_details,
- 'compatibility_info': compatibility_info,
+ "model_type": model_type,
+ "is_turbo": is_turbo,
+ "is_sdxl": is_sdxl,
+ "is_sd3": is_sd3,
+ "confidence": confidence,
+ "architecture_details": architecture_details,
+ "compatibility_info": compatibility_info,
}
return result
@@ -166,13 +172,13 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]:
"""Detect detailed UNet characteristics including SDXL-specific features"""
config = unet.config
-
+
# Get cross attention dimensions to detect model type
- cross_attention_dim = getattr(config, 'cross_attention_dim', None)
-
+ cross_attention_dim = getattr(config, "cross_attention_dim", None)
+
# Detect SDXL by multiple indicators
is_sdxl = False
-
+
# Check cross attention dimension
if isinstance(cross_attention_dim, (list, tuple)):
# SDXL typically has [1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280]
@@ -180,73 +186,74 @@ def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]:
elif isinstance(cross_attention_dim, int):
# Single value - SDXL uses 2048 for concatenated embeddings, or 1280+ for individual encoders
is_sdxl = cross_attention_dim >= 1280
-
+
# Check addition_embed_type for SDXL detection (strong indicator)
- addition_embed_type = getattr(config, 'addition_embed_type', None)
+ addition_embed_type = getattr(config, "addition_embed_type", None)
has_addition_embed = addition_embed_type is not None
-
- if addition_embed_type in ['text_time', 'text_time_guidance']:
+
+ if addition_embed_type in ["text_time", "text_time_guidance"]:
is_sdxl = True # This is a definitive SDXL indicator
-
+
# Check if model has time conditioning projection (SDXL feature)
- has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None
-
+ has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None
+
# Additional SDXL detection checks
- if hasattr(config, 'num_class_embeds') and config.num_class_embeds is not None:
+ if hasattr(config, "num_class_embeds") and config.num_class_embeds is not None:
is_sdxl = True # SDXL often has class embeddings
-
+
# Check sample size (SDXL typically uses 128 vs 64 for SD1.5)
- sample_size = getattr(config, 'sample_size', 64)
+ sample_size = getattr(config, "sample_size", 64)
if sample_size >= 128:
is_sdxl = True
-
+
return {
- 'is_sdxl': is_sdxl,
- 'has_time_cond': has_time_cond,
- 'has_addition_embed': has_addition_embed,
- 'cross_attention_dim': cross_attention_dim,
- 'addition_embed_type': addition_embed_type,
- 'in_channels': getattr(config, 'in_channels', 4),
- 'sample_size': getattr(config, 'sample_size', 64 if not is_sdxl else 128),
- 'block_out_channels': tuple(getattr(config, 'block_out_channels', [])),
- 'attention_head_dim': getattr(config, 'attention_head_dim', None)
+ "is_sdxl": is_sdxl,
+ "has_time_cond": has_time_cond,
+ "has_addition_embed": has_addition_embed,
+ "cross_attention_dim": cross_attention_dim,
+ "addition_embed_type": addition_embed_type,
+ "in_channels": getattr(config, "in_channels", 4),
+ "sample_size": getattr(config, "sample_size", 64 if not is_sdxl else 128),
+ "block_out_channels": tuple(getattr(config, "block_out_channels", [])),
+ "attention_head_dim": getattr(config, "attention_head_dim", None),
}
+
# This is used for controlnet/ipadapter model detection - can be deprecated (along with detect_unet_characteristics)
def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str:
"""Detect model type from diffusers UNet configuration"""
characteristics = detect_unet_characteristics(unet)
-
- in_channels = characteristics['in_channels']
- block_out_channels = characteristics['block_out_channels']
- cross_attention_dim = characteristics['cross_attention_dim']
- is_sdxl = characteristics['is_sdxl']
-
+
+ in_channels = characteristics["in_channels"]
+ block_out_channels = characteristics["block_out_channels"]
+ cross_attention_dim = characteristics["cross_attention_dim"]
+ is_sdxl = characteristics["is_sdxl"]
+
# Use enhanced SDXL detection
if is_sdxl:
return "SDXL"
-
+
# Original detection logic for other models
- if (cross_attention_dim == 768 and
- block_out_channels == (320, 640, 1280, 1280) and
- in_channels == 4):
+ if cross_attention_dim == 768 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4:
return "SD15"
-
- elif (cross_attention_dim == 1024 and
- block_out_channels == (320, 640, 1280, 1280) and
- in_channels == 4):
+
+ elif cross_attention_dim == 1024 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4:
return "SD21"
-
+
elif cross_attention_dim == 768 and in_channels == 4:
return "SD15"
elif cross_attention_dim == 1024 and in_channels == 4:
return "SD21"
-
+
if cross_attention_dim == 768:
- print(f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15")
+ print(
+ f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15"
+ )
return "SD15"
elif cross_attention_dim == 1024:
- print(f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21")
+ print(
+ f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21"
+ )
return "SD21"
else:
raise ValueError(
@@ -260,58 +267,58 @@ def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str:
def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]:
"""
Extract UNet architecture details needed for TensorRT engine building.
-
+
This function provides the essential architecture information needed
for TensorRT engine compilation in a clean, structured format.
-
+
Args:
unet: The UNet model to analyze
-
+
Returns:
Dict with architecture parameters for TensorRT engine building
"""
config = unet.config
-
+
# Basic model parameters
model_channels = config.block_out_channels[0] if config.block_out_channels else 320
block_out_channels = tuple(config.block_out_channels)
channel_mult = tuple(ch // model_channels for ch in block_out_channels)
-
+
# Resolution blocks
- if hasattr(config, 'layers_per_block'):
+ if hasattr(config, "layers_per_block"):
if isinstance(config.layers_per_block, (list, tuple)):
num_res_blocks = tuple(config.layers_per_block)
else:
num_res_blocks = tuple([config.layers_per_block] * len(block_out_channels))
else:
num_res_blocks = tuple([2] * len(block_out_channels))
-
+
# Attention and context dimensions
context_dim = config.cross_attention_dim
in_channels = config.in_channels
-
+
# Attention head configuration
- attention_head_dim = getattr(config, 'attention_head_dim', 8)
+ attention_head_dim = getattr(config, "attention_head_dim", 8)
if isinstance(attention_head_dim, (list, tuple)):
attention_head_dim = attention_head_dim[0]
-
+
# Transformer depth
- transformer_depth = getattr(config, 'transformer_layers_per_block', 1)
+ transformer_depth = getattr(config, "transformer_layers_per_block", 1)
if isinstance(transformer_depth, (list, tuple)):
transformer_depth = tuple(transformer_depth)
else:
transformer_depth = tuple([transformer_depth] * len(block_out_channels))
-
+
# Time embedding
- time_embed_dim = getattr(config, 'time_embedding_dim', None)
+ time_embed_dim = getattr(config, "time_embedding_dim", None)
if time_embed_dim is None:
time_embed_dim = model_channels * 4
-
+
# Build architecture dictionary
architecture_dict = {
"model_channels": model_channels,
"in_channels": in_channels,
- "out_channels": getattr(config, 'out_channels', in_channels),
+ "out_channels": getattr(config, "out_channels", in_channels),
"num_res_blocks": num_res_blocks,
"channel_mult": channel_mult,
"context_dim": context_dim,
@@ -319,48 +326,50 @@ def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]:
"transformer_depth": transformer_depth,
"time_embed_dim": time_embed_dim,
"block_out_channels": block_out_channels,
-
# Additional configuration
- "use_linear_in_transformer": getattr(config, 'use_linear_in_transformer', False),
- "conv_in_kernel": getattr(config, 'conv_in_kernel', 3),
- "conv_out_kernel": getattr(config, 'conv_out_kernel', 3),
- "resnet_time_scale_shift": getattr(config, 'resnet_time_scale_shift', 'default'),
- "class_embed_type": getattr(config, 'class_embed_type', None),
- "num_class_embeds": getattr(config, 'num_class_embeds', None),
-
+ "use_linear_in_transformer": getattr(config, "use_linear_in_transformer", False),
+ "conv_in_kernel": getattr(config, "conv_in_kernel", 3),
+ "conv_out_kernel": getattr(config, "conv_out_kernel", 3),
+ "resnet_time_scale_shift": getattr(config, "resnet_time_scale_shift", "default"),
+ "class_embed_type": getattr(config, "class_embed_type", None),
+ "num_class_embeds": getattr(config, "num_class_embeds", None),
# Block types
- "down_block_types": getattr(config, 'down_block_types', []),
- "up_block_types": getattr(config, 'up_block_types', []),
+ "down_block_types": getattr(config, "down_block_types", []),
+ "up_block_types": getattr(config, "up_block_types", []),
}
-
+
return architecture_dict
def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[str, Any]:
"""
Validate and fix architecture dictionary using model type presets.
-
+
Ensures that all required architecture parameters are present and
have reasonable values for the specified model type.
-
+
Args:
arch_dict: Architecture dictionary to validate
model_type: Expected model type for validation
-
+
Returns:
Validated and corrected architecture dictionary
"""
-
+
# Check for required keys
required_keys = [
- "model_channels", "channel_mult", "num_res_blocks",
- "context_dim", "in_channels", "block_out_channels"
+ "model_channels",
+ "channel_mult",
+ "num_res_blocks",
+ "context_dim",
+ "in_channels",
+ "block_out_channels",
]
-
+
for key in required_keys:
if key not in arch_dict:
raise ValueError(f"Missing required architecture parameter: {key}")
-
+
# Ensure tuple format for sequence parameters
for key in ["channel_mult", "num_res_blocks", "transformer_depth", "block_out_channels"]:
if key in arch_dict and not isinstance(arch_dict[key], tuple):
@@ -371,12 +380,11 @@ def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[st
arch_dict[key] = tuple(arch_dict[key])
else:
arch_dict[key] = preset[key]
-
+
# Validate sequence lengths match
expected_levels = len(arch_dict["channel_mult"])
for key in ["num_res_blocks", "transformer_depth"]:
if key in arch_dict and len(arch_dict[key]) != expected_levels:
arch_dict[key] = preset[key]
-
- return arch_dict
+ return arch_dict
diff --git a/src/streamdiffusion/modules/__init__.py b/src/streamdiffusion/modules/__init__.py
index 54954961a..f3242ca59 100644
--- a/src/streamdiffusion/modules/__init__.py
+++ b/src/streamdiffusion/modules/__init__.py
@@ -1,22 +1,21 @@
# StreamDiffusion Modules Package
from .controlnet_module import ControlNetModule
+from .image_processing_module import ImagePostprocessingModule, ImagePreprocessingModule, ImageProcessingModule
from .ipadapter_module import IPAdapterModule
-from .image_processing_module import ImageProcessingModule, ImagePreprocessingModule, ImagePostprocessingModule
-from .latent_processing_module import LatentProcessingModule, LatentPreprocessingModule, LatentPostprocessingModule
+from .latent_processing_module import LatentPostprocessingModule, LatentPreprocessingModule, LatentProcessingModule
+
__all__ = [
# Existing modules
- 'ControlNetModule',
- 'IPAdapterModule',
-
+ "ControlNetModule",
+ "IPAdapterModule",
# Pipeline processing base classes
- 'ImageProcessingModule',
- 'LatentProcessingModule',
-
+ "ImageProcessingModule",
+ "LatentProcessingModule",
# Pipeline processing timing-specific modules
- 'ImagePreprocessingModule',
- 'ImagePostprocessingModule',
- 'LatentPreprocessingModule',
- 'LatentPostprocessingModule',
+ "ImagePreprocessingModule",
+ "ImagePostprocessingModule",
+ "LatentPreprocessingModule",
+ "LatentPostprocessingModule",
]
diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py
index 6c8655b13..e0a57f3b6 100644
--- a/src/streamdiffusion/modules/controlnet_module.py
+++ b/src/streamdiffusion/modules/controlnet_module.py
@@ -1,18 +1,18 @@
from __future__ import annotations
+import logging
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import torch
from diffusers.models import ControlNetModel
-import logging
-from streamdiffusion.hooks import StepCtx, UnetKwargsDelta, UnetHook
+from streamdiffusion.hooks import StepCtx, UnetHook, UnetKwargsDelta
+from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser
from streamdiffusion.preprocessing.preprocessing_orchestrator import (
PreprocessingOrchestrator,
)
-from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser
@dataclass
@@ -55,17 +55,17 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->
self._prepared_dtype: Optional[torch.dtype] = None
self._prepared_batch: Optional[int] = None
self._images_version: int = 0
-
+
# Cache expensive lookups to avoid repeated hasattr/getattr calls
self._engines_by_id: Dict[str, Any] = {}
self._engines_cache_valid: bool = False
self._is_sdxl: Optional[bool] = None
self._expected_text_len: int = 77
-
+
# SDXL-specific caching for performance optimization
self._sdxl_conditioning_cache: Optional[Dict[str, torch.Tensor]] = None
self._sdxl_conditioning_valid: bool = False
-
+
# Cache engine type detection to avoid repeated hasattr calls
self._engine_type_cache: Dict[str, bool] = {}
@@ -78,9 +78,9 @@ def install(self, stream) -> None:
# Register UNet hook
stream.unet_hooks.append(self.build_unet_hook())
# Expose controlnet collections so existing updater can find them
- setattr(stream, 'controlnets', self.controlnets)
- setattr(stream, 'controlnet_scales', self.controlnet_scales)
- setattr(stream, 'preprocessors', self.preprocessors)
+ setattr(stream, "controlnets", self.controlnets)
+ setattr(stream, "controlnet_scales", self.controlnet_scales)
+ setattr(stream, "preprocessors", self.preprocessors)
# Reset prepared tensors on install
self._prepared_tensors = []
self._prepared_device = None
@@ -92,18 +92,26 @@ def install(self, stream) -> None:
self._sdxl_conditioning_valid = False
self._engine_type_cache.clear()
- def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
+ def add_controlnet(
+ self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None
+ ) -> None:
model = self._load_pytorch_controlnet_model(cfg.model_id, cfg.conditioning_channels)
preproc = None
if cfg.preprocessor:
from streamdiffusion.preprocessing.processors import get_preprocessor
- preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, normalization_context='controlnet', params=cfg.preprocessor_params)
+
+ preproc = get_preprocessor(
+ cfg.preprocessor,
+ pipeline_ref=self._stream,
+ normalization_context="controlnet",
+ params=cfg.preprocessor_params,
+ )
# Apply provided parameters to the preprocessor instance
if cfg.preprocessor_params:
params = cfg.preprocessor_params or {}
# If the preprocessor exposes a 'params' dict, update it
- if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict):
+ if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), dict):
preproc.params.update(params)
# Also set attributes directly when they exist
for name, value in params.items():
@@ -113,16 +121,15 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
except Exception:
pass
-
# Align preprocessor target size with stream resolution once (avoid double-resize later)
try:
- if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict):
- preproc.params['image_width'] = int(self._stream.width)
- preproc.params['image_height'] = int(self._stream.height)
- if hasattr(preproc, 'image_width'):
- setattr(preproc, 'image_width', int(self._stream.width))
- if hasattr(preproc, 'image_height'):
- setattr(preproc, 'image_height', int(self._stream.height))
+ if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), dict):
+ preproc.params["image_width"] = int(self._stream.width)
+ preproc.params["image_height"] = int(self._stream.height)
+ if hasattr(preproc, "image_width"):
+ setattr(preproc, "image_width", int(self._stream.width))
+ if hasattr(preproc, "image_height"):
+ setattr(preproc, "image_height", int(self._stream.height))
except Exception:
pass
@@ -142,7 +149,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
# Invalidate SDXL conditioning cache when ControlNet configuration changes
self._sdxl_conditioning_valid = False
- def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None:
+ def update_control_image_efficient(
+ self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None
+ ) -> None:
if self._preprocessing_orchestrator is None:
return
with self._collections_lock:
@@ -150,23 +159,15 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
return
total = len(self.controlnets)
# Build active scales, respecting enabled_list if present
- scales = [
- (self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0)
- for i in range(total)
- ]
- if hasattr(self, 'enabled_list') and self.enabled_list and len(self.enabled_list) == total:
+ scales = [(self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) for i in range(total)]
+ if hasattr(self, "enabled_list") and self.enabled_list and len(self.enabled_list) == total:
scales = [sc if bool(self.enabled_list[i]) else 0.0 for i, sc in enumerate(scales)]
preprocessors = [self.preprocessors[i] if i < len(self.preprocessors) else None for i in range(total)]
# Single-index fast path
if index is not None:
results = self._preprocessing_orchestrator.process_sync(
- control_image,
- preprocessors,
- scales,
- self._stream.width,
- self._stream.height,
- index
+ control_image, preprocessors, scales, self._stream.width, self._stream.height, index
)
processed = results[index] if results and len(results) > index else None
with self._collections_lock:
@@ -182,11 +183,7 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
# Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
processed_images = self._preprocessing_orchestrator.process_pipelined(
- control_image,
- preprocessors,
- scales,
- self._stream.width,
- self._stream.height
+ control_image, preprocessors, scales, self._stream.width, self._stream.height
)
# If orchestrator returns empty list, it indicates no update needed for this frame
@@ -243,7 +240,7 @@ def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None
# Build current mapping from model_id to index
current_ids: List[str] = []
for i, cn in enumerate(self.controlnets):
- model_id = getattr(cn, 'model_id', f'controlnet_{i}')
+ model_id = getattr(cn, "model_id", f"controlnet_{i}")
current_ids.append(model_id)
# Compute new index order
@@ -275,23 +272,29 @@ def get_current_config(self) -> List[Dict[str, Any]]:
cfg: List[Dict[str, Any]] = []
with self._collections_lock:
for i, cn in enumerate(self.controlnets):
- model_id = getattr(cn, 'model_id', f'controlnet_{i}')
+ model_id = getattr(cn, "model_id", f"controlnet_{i}")
scale = self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0
- preproc_params = getattr(self.preprocessors[i], 'params', {}) if i < len(self.preprocessors) and self.preprocessors[i] else {}
- cfg.append({
- 'model_id': model_id,
- 'conditioning_scale': scale,
- 'preprocessor_params': preproc_params,
- 'enabled': (self.enabled_list[i] if i < len(self.enabled_list) else True),
- })
+ preproc_params = (
+ getattr(self.preprocessors[i], "params", {})
+ if i < len(self.preprocessors) and self.preprocessors[i]
+ else {}
+ )
+ cfg.append(
+ {
+ "model_id": model_id,
+ "conditioning_scale": scale,
+ "preprocessor_params": preproc_params,
+ "enabled": (self.enabled_list[i] if i < len(self.enabled_list) else True),
+ }
+ )
return cfg
def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None:
"""Prepare control image tensors for the current frame.
-
+
This method is called once per frame to prepare all control images with the correct
device, dtype, and batch size. This avoids redundant operations during each denoising step.
-
+
Args:
device: Target device for tensors
dtype: Target dtype for tensors
@@ -300,22 +303,22 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_
with self._collections_lock:
# Check if we need to re-prepare tensors
cache_valid = (
- self._prepared_device == device and
- self._prepared_dtype == dtype and
- self._prepared_batch == batch_size and
- len(self._prepared_tensors) == len(self.controlnet_images)
+ self._prepared_device == device
+ and self._prepared_dtype == dtype
+ and self._prepared_batch == batch_size
+ and len(self._prepared_tensors) == len(self.controlnet_images)
)
-
+
if cache_valid:
return
-
+
# Prepare tensors for current frame
self._prepared_tensors = []
for img in self.controlnet_images:
if img is None:
self._prepared_tensors.append(None)
continue
-
+
# Prepare tensor with correct batch size
prepared = img
if prepared.dim() == 4 and prepared.shape[0] != batch_size:
@@ -324,63 +327,62 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_
else:
repeat_factor = max(1, batch_size // prepared.shape[0])
prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size]
-
+
# Move to correct device and dtype
prepared = prepared.to(device=device, dtype=dtype)
self._prepared_tensors.append(prepared)
-
+
# Update cache state
self._prepared_device = device
self._prepared_dtype = dtype
self._prepared_batch = batch_size
- def _get_cached_sdxl_conditioning(self, ctx: 'StepCtx') -> Optional[Dict[str, torch.Tensor]]:
+ def _get_cached_sdxl_conditioning(self, ctx: "StepCtx") -> Optional[Dict[str, torch.Tensor]]:
"""Get cached SDXL conditioning to avoid repeated preparation"""
if not self._is_sdxl or ctx.sdxl_cond is None:
return None
-
+
# Check if cache is valid
if self._sdxl_conditioning_valid and self._sdxl_conditioning_cache is not None:
cached = self._sdxl_conditioning_cache
# Verify batch size matches current context
- if ('text_embeds' in cached and
- cached['text_embeds'].shape[0] == ctx.x_t_latent.shape[0]):
+ if "text_embeds" in cached and cached["text_embeds"].shape[0] == ctx.x_t_latent.shape[0]:
return cached
-
+
# Cache miss or invalid - prepare new conditioning
try:
conditioning = {}
- if 'text_embeds' in ctx.sdxl_cond:
- text_embeds = ctx.sdxl_cond['text_embeds']
+ if "text_embeds" in ctx.sdxl_cond:
+ text_embeds = ctx.sdxl_cond["text_embeds"]
batch_size = ctx.x_t_latent.shape[0]
-
+
# Optimize batch expansion for SDXL text embeddings
if text_embeds.shape[0] != batch_size:
if text_embeds.shape[0] == 1:
- conditioning['text_embeds'] = text_embeds.repeat(batch_size, 1)
+ conditioning["text_embeds"] = text_embeds.repeat(batch_size, 1)
else:
- conditioning['text_embeds'] = text_embeds[:batch_size]
+ conditioning["text_embeds"] = text_embeds[:batch_size]
else:
- conditioning['text_embeds'] = text_embeds
-
- if 'time_ids' in ctx.sdxl_cond:
- time_ids = ctx.sdxl_cond['time_ids']
+ conditioning["text_embeds"] = text_embeds
+
+ if "time_ids" in ctx.sdxl_cond:
+ time_ids = ctx.sdxl_cond["time_ids"]
batch_size = ctx.x_t_latent.shape[0]
-
+
# Optimize batch expansion for SDXL time IDs
if time_ids.shape[0] != batch_size:
if time_ids.shape[0] == 1:
- conditioning['time_ids'] = time_ids.repeat(batch_size, 1)
+ conditioning["time_ids"] = time_ids.repeat(batch_size, 1)
else:
- conditioning['time_ids'] = time_ids[:batch_size]
+ conditioning["time_ids"] = time_ids[:batch_size]
else:
- conditioning['time_ids'] = time_ids
-
+ conditioning["time_ids"] = time_ids
+
# Cache the prepared conditioning
self._sdxl_conditioning_cache = conditioning
self._sdxl_conditioning_valid = True
return conditioning
-
+
except Exception:
# Fallback to original conditioning on any error
return ctx.sdxl_cond
@@ -399,8 +401,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
# Single pass to collect active ControlNet data
active_data = []
enabled_flags = self.enabled_list if len(self.enabled_list) == len(self.controlnets) else None
-
- for i, (cn, img, scale) in enumerate(zip(self.controlnets, self.controlnet_images, self.controlnet_scales)):
+
+ for i, (cn, img, scale) in enumerate(
+ zip(self.controlnets, self.controlnet_images, self.controlnet_scales)
+ ):
if cn is not None and img is not None and scale > 0:
enabled = enabled_flags[i] if enabled_flags else True
if enabled:
@@ -413,9 +417,11 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
if not self._engines_cache_valid:
self._engines_by_id.clear()
try:
- if hasattr(self._stream, 'controlnet_engines') and isinstance(self._stream.controlnet_engines, list):
+ if hasattr(self._stream, "controlnet_engines") and isinstance(
+ self._stream.controlnet_engines, list
+ ):
for eng in self._stream.controlnet_engines:
- mid = getattr(eng, 'model_id', None)
+ mid = getattr(eng, "model_id", None)
if mid:
self._engines_by_id[mid] = eng
self._engines_cache_valid = True
@@ -425,17 +431,17 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
# Cache SDXL detection to avoid repeated hasattr calls
if self._is_sdxl is None:
try:
- self._is_sdxl = getattr(self._stream, 'is_sdxl', False)
+ self._is_sdxl = getattr(self._stream, "is_sdxl", False)
except Exception:
self._is_sdxl = False
- encoder_hidden_states = self._stream.prompt_embeds[:, :self._expected_text_len, :]
+ encoder_hidden_states = self._stream.prompt_embeds[:, : self._expected_text_len, :]
base_kwargs: Dict[str, Any] = {
- 'sample': x_t,
- 'timestep': t_list,
- 'encoder_hidden_states': encoder_hidden_states,
- 'return_dict': False,
+ "sample": x_t,
+ "timestep": t_list,
+ "encoder_hidden_states": encoder_hidden_states,
+ "return_dict": False,
}
down_samples_list: List[List[torch.Tensor]] = []
@@ -443,20 +449,22 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
# Ensure tensors are prepared for this frame
# This should have been called earlier, but we call it here as a safety net
- if (self._prepared_device != x_t.device or
- self._prepared_dtype != x_t.dtype or
- self._prepared_batch != x_t.shape[0]):
+ if (
+ self._prepared_device != x_t.device
+ or self._prepared_dtype != x_t.dtype
+ or self._prepared_batch != x_t.shape[0]
+ ):
self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0])
-
+
# Use pre-prepared tensors
prepared_images = self._prepared_tensors
for cn, img, scale, idx_i in active_data:
# Swap to TRT engine if available for this model_id (use cached lookup)
- model_id = getattr(cn, 'model_id', None)
+ model_id = getattr(cn, "model_id", None)
if model_id and model_id in self._engines_by_id:
cn = self._engines_by_id[model_id]
-
+
# Use pre-prepared tensor
current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img
if current_img is None:
@@ -467,12 +475,12 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
if cache_key in self._engine_type_cache:
is_trt_engine = self._engine_type_cache[cache_key]
else:
- is_trt_engine = hasattr(cn, 'engine') and hasattr(cn, 'stream')
+ is_trt_engine = hasattr(cn, "engine") and hasattr(cn, "stream")
self._engine_type_cache[cache_key] = is_trt_engine
-
+
# Get optimized SDXL conditioning (uses caching to avoid repeated tensor operations)
added_cond_kwargs = self._get_cached_sdxl_conditioning(ctx)
-
+
try:
if is_trt_engine:
# TensorRT engine path
@@ -483,7 +491,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=current_img,
conditioning_scale=float(scale),
- **added_cond_kwargs
+ **added_cond_kwargs,
)
else:
down_samples, mid_sample = cn(
@@ -491,7 +499,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
timestep=t_list,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=current_img,
- conditioning_scale=float(scale)
+ conditioning_scale=float(scale),
)
else:
# PyTorch ControlNet path
@@ -503,7 +511,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
controlnet_cond=current_img,
conditioning_scale=float(scale),
return_dict=False,
- added_cond_kwargs=added_cond_kwargs
+ added_cond_kwargs=added_cond_kwargs,
)
else:
down_samples, mid_sample = cn(
@@ -512,21 +520,30 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=current_img,
conditioning_scale=float(scale),
- return_dict=False
+ return_dict=False,
)
except Exception as e:
import traceback
- __import__('logging').getLogger(__name__).error("ControlNetModule: controlnet forward failed: %s", e)
+
+ __import__("logging").getLogger(__name__).error(
+ "ControlNetModule: controlnet forward failed: %s", e
+ )
try:
- __import__('logging').getLogger(__name__).error("ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s",
- (tuple(encoder_hidden_states.shape) if isinstance(encoder_hidden_states, torch.Tensor) else None),
- (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None),
- scale,
- self._is_sdxl,
- is_trt_engine)
+ __import__("logging").getLogger(__name__).error(
+ "ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s",
+ (
+ tuple(encoder_hidden_states.shape)
+ if isinstance(encoder_hidden_states, torch.Tensor)
+ else None
+ ),
+ (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None),
+ scale,
+ self._is_sdxl,
+ is_trt_engine,
+ )
except Exception:
pass
- __import__('logging').getLogger(__name__).error(traceback.format_exc())
+ __import__("logging").getLogger(__name__).error(traceback.format_exc())
continue
down_samples_list.append(down_samples)
mid_samples_list.append(mid_sample)
@@ -555,48 +572,53 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
return _unet_hook
- def _prepare_control_image(self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any]) -> torch.Tensor:
+ def _prepare_control_image(
+ self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any]
+ ) -> torch.Tensor:
if self._preprocessing_orchestrator is None:
raise RuntimeError("ControlNetModule: preprocessing orchestrator is not initialized")
# Reuse orchestrator API used by BaseControlNetPipeline
images = self._preprocessing_orchestrator.process_sync(
- control_image,
- [preprocessor],
- [1.0],
- self._stream.width,
- self._stream.height,
- 0
+ control_image, [preprocessor], [1.0], self._stream.width, self._stream.height, 0
)
# API returns a list; pick first if present
return images[0] if images else None
- #FIXME: more robust model management is needed in general.
- def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: Optional[int] = None) -> ControlNetModel:
- from pathlib import Path
- import logging
+ # FIXME: more robust model management is needed in general.
+ def _load_pytorch_controlnet_model(
+ self, model_id: str, conditioning_channels: Optional[int] = None
+ ) -> ControlNetModel:
import os
+ from pathlib import Path
+
logger = logging.getLogger(__name__)
-
+
try:
# Prepare loading kwargs
load_kwargs = {"torch_dtype": self.dtype}
if conditioning_channels is not None:
load_kwargs["conditioning_channels"] = conditioning_channels
-
+
# Check if offline mode is enabled via environment variables
- is_offline = os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1"
-
+ is_offline = (
+ os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1"
+ )
+
if Path(model_id).exists():
model_path = Path(model_id)
-
+
# Check if it's a direct file path to a safetensors/ckpt file
- if model_path.is_file() and model_path.suffix in ['.safetensors', '.ckpt', '.bin']:
- logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})")
+ if model_path.is_file() and model_path.suffix in [".safetensors", ".ckpt", ".bin"]:
+ logger.info(
+ f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})"
+ )
# Try loading from single file (works for most ControlNet models)
try:
controlnet = ControlNetModel.from_single_file(str(model_path), **load_kwargs)
except Exception as e:
- logger.warning(f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}")
+ logger.warning(
+ f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}"
+ )
# Fallback: try pretrained loading in case it's in a proper directory structure
load_kwargs["local_files_only"] = True
controlnet = ControlNetModel.from_pretrained(str(model_path.parent), **load_kwargs)
@@ -608,29 +630,27 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O
# Loading from HuggingFace Hub - respect offline mode
if is_offline:
load_kwargs["local_files_only"] = True
- logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only")
-
+ logger.info(
+ f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only"
+ )
+
if "/" in model_id and model_id.count("/") > 1:
parts = model_id.split("/")
repo_id = "/".join(parts[:2])
subfolder = "/".join(parts[2:])
- controlnet = ControlNetModel.from_pretrained(
- repo_id, subfolder=subfolder, **load_kwargs
- )
+ controlnet = ControlNetModel.from_pretrained(repo_id, subfolder=subfolder, **load_kwargs)
else:
controlnet = ControlNetModel.from_pretrained(model_id, **load_kwargs)
controlnet = controlnet.to(device=self.device, dtype=self.dtype)
# Track model_id for updater diffing
try:
- setattr(controlnet, 'model_id', model_id)
+ setattr(controlnet, "model_id", model_id)
except Exception:
pass
return controlnet
except Exception as e:
import traceback
+
logger.error(f"ControlNetModule: failed to load model '{model_id}': {e}")
logger.error(traceback.format_exc())
raise
-
-
-
diff --git a/src/streamdiffusion/modules/image_processing_module.py b/src/streamdiffusion/modules/image_processing_module.py
index ffea6e5fe..b96f0c0b6 100644
--- a/src/streamdiffusion/modules/image_processing_module.py
+++ b/src/streamdiffusion/modules/image_processing_module.py
@@ -1,55 +1,57 @@
-from typing import List, Optional, Any, Dict
+from typing import Any, Dict, List
+
import torch
-from ..preprocessing.orchestrator_user import OrchestratorUser
-from ..preprocessing.pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator
from ..hooks import ImageCtx, ImageHook
+from ..preprocessing.orchestrator_user import OrchestratorUser
class ImageProcessingModule(OrchestratorUser):
"""
Shared base class for image domain processing modules.
-
+
Handles sequential chain execution for both preprocessing and postprocessing
timing variants. Processing domain is always image tensors.
"""
-
+
def __init__(self):
"""Initialize image processing module."""
self.processors = []
-
+
def _process_image_chain(self, input_image: torch.Tensor) -> torch.Tensor:
"""Execute sequential chain of processors in image domain.
-
+
Uses the shared orchestrator's sequential chain processing.
"""
if not self.processors:
return input_image
-
+
ordered_processors = self._get_ordered_processors()
return self._preprocessing_orchestrator.execute_pipeline_chain(
input_image, ordered_processors, processing_domain="image"
)
-
+
def add_processor(self, proc_config: Dict[str, Any]) -> None:
"""Add a processor using the existing registry, following ControlNet pattern."""
from streamdiffusion.preprocessing.processors import get_preprocessor
-
- processor_type = proc_config.get('type')
+
+ processor_type = proc_config.get("type")
if not processor_type:
raise ValueError("Processor config missing 'type' field")
-
+
# Check if processor is enabled (default to True, same as ControlNet)
- enabled = proc_config.get('enabled', True)
-
+ enabled = proc_config.get("enabled", True)
+
# Create processor using existing registry (same as ControlNet)
# ImageProcessingModule uses 'pipeline' normalization context
- processor = get_preprocessor(processor_type, pipeline_ref=getattr(self, '_stream', None), normalization_context='pipeline')
-
+ processor = get_preprocessor(
+ processor_type, pipeline_ref=getattr(self, "_stream", None), normalization_context="pipeline"
+ )
+
# Apply parameters (same pattern as ControlNet)
- processor_params = proc_config.get('params', {})
+ processor_params = proc_config.get("params", {})
if processor_params:
- if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict):
+ if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict):
processor.params.update(processor_params)
for name, value in processor_params.items():
try:
@@ -57,109 +59,109 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None:
setattr(processor, name, value)
except Exception:
pass
-
+
# Set order for sequential execution
- order = proc_config.get('order', len(self.processors))
- setattr(processor, 'order', order)
-
+ order = proc_config.get("order", len(self.processors))
+ setattr(processor, "order", order)
+
# Set enabled state
- setattr(processor, 'enabled', enabled)
-
+ setattr(processor, "enabled", enabled)
+
# Align preprocessor target size with stream resolution (same as ControlNet)
- if hasattr(self, '_stream'):
+ if hasattr(self, "_stream"):
try:
- if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict):
- processor.params['image_width'] = int(self._stream.width)
- processor.params['image_height'] = int(self._stream.height)
- if hasattr(processor, 'image_width'):
- setattr(processor, 'image_width', int(self._stream.width))
- if hasattr(processor, 'image_height'):
- setattr(processor, 'image_height', int(self._stream.height))
+ if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict):
+ processor.params["image_width"] = int(self._stream.width)
+ processor.params["image_height"] = int(self._stream.height)
+ if hasattr(processor, "image_width"):
+ setattr(processor, "image_width", int(self._stream.width))
+ if hasattr(processor, "image_height"):
+ setattr(processor, "image_height", int(self._stream.height))
except Exception:
pass
-
+
self.processors.append(processor)
-
+
def _get_ordered_processors(self) -> List[Any]:
"""Return enabled processors in execution order based on their order attribute."""
# Filter for enabled processors first, then sort by order
- enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)]
- return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0))
+ enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)]
+ return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0))
class ImagePreprocessingModule(ImageProcessingModule):
"""
Image domain preprocessing module - executes before VAE encoding.
-
+
Timing: After image_processor.preprocess(), before similar_image_filter
Uses pipelined processing for performance optimization.
"""
-
+
def install(self, stream) -> None:
"""Install module by registering hook with stream and attaching orchestrators."""
self._stream = stream # Store stream reference for dimension access
self.attach_orchestrator(stream) # For sequential chain processing (fallback)
self.attach_pipeline_preprocessing_orchestrator(stream) # For pipelined processing
stream.image_preprocessing_hooks.append(self.build_image_hook())
-
+
def build_image_hook(self) -> ImageHook:
"""Build hook function that processes image context with pipelined processing."""
+
def hook(ctx: ImageCtx) -> ImageCtx:
ctx.image = self._process_image_pipelined(ctx.image)
return ctx
+
return hook
-
+
def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor:
"""Execute pipelined processing of preprocessors for performance.
-
+
Uses PipelinePreprocessingOrchestrator for Frame N-1 results while starting Frame N processing.
Falls back to synchronous processing when needed.
"""
if not self.processors:
return input_image
-
+
ordered_processors = self._get_ordered_processors()
-
+
# Use pipelined pipeline preprocessing orchestrator for performance
- return self._pipeline_preprocessing_orchestrator.process_pipelined(
- input_image, ordered_processors
- )
+ return self._pipeline_preprocessing_orchestrator.process_pipelined(input_image, ordered_processors)
class ImagePostprocessingModule(ImageProcessingModule):
"""
Image domain postprocessing module - executes after VAE decoding.
-
+
Timing: After decode_image(), before returning final output
Uses pipelined processing for performance optimization.
"""
-
+
def install(self, stream) -> None:
"""Install module by registering hook with stream and attaching orchestrators."""
self._stream = stream # Store stream reference for dimension access
self.attach_preprocessing_orchestrator(stream) # For sequential chain processing (fallback)
self.attach_postprocessing_orchestrator(stream) # For pipelined processing
stream.image_postprocessing_hooks.append(self.build_image_hook())
-
+
def build_image_hook(self) -> ImageHook:
"""Build hook function that processes image context with pipelined processing."""
+
def hook(ctx: ImageCtx) -> ImageCtx:
ctx.image = self._process_image_pipelined(ctx.image)
return ctx
+
return hook
-
+
def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor:
"""Execute pipelined processing of postprocessors for performance.
-
+
Uses PostprocessingOrchestrator for Frame N-1 results while starting Frame N processing.
Falls back to synchronous processing when needed.
"""
if not self.processors:
return input_image
-
+
ordered_processors = self._get_ordered_processors()
-
+
# Use pipelined postprocessing orchestrator for performance
- return self._postprocessing_orchestrator.process_pipelined(
- input_image, ordered_processors
- )
+ return self._postprocessing_orchestrator.process_pipelined(input_image, ordered_processors)
diff --git a/src/streamdiffusion/modules/ipadapter_module.py b/src/streamdiffusion/modules/ipadapter_module.py
index b283799f3..4e3b95ead 100644
--- a/src/streamdiffusion/modules/ipadapter_module.py
+++ b/src/streamdiffusion/modules/ipadapter_module.py
@@ -1,16 +1,18 @@
from __future__ import annotations
+import logging
+import os
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Any
from enum import Enum
+from typing import Any, Dict, Optional, Tuple
+
import torch
-from streamdiffusion.hooks import EmbedsCtx, EmbeddingHook, StepCtx, UnetKwargsDelta, UnetHook
-import os
+from streamdiffusion.hooks import EmbeddingHook, EmbedsCtx, StepCtx, UnetHook, UnetKwargsDelta
from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser
-import logging
from streamdiffusion.utils.reporting import report_error
+
logger = logging.getLogger(__name__)
@@ -27,6 +29,7 @@ class IPAdapterConfig:
This module focuses only on embedding composition (step 2 of migration).
Runtime installation and wrapper wiring will come in later steps.
"""
+
style_image_key: Optional[str] = None
num_image_tokens: int = 4 # e.g., 4 for standard, 16 for plus
ipadapter_model_path: Optional[str] = None
@@ -59,7 +62,7 @@ class IPAdapterConfig:
"image_encoder_path": "h94/IP-Adapter/models/image_encoder",
},
("SD2.1", IPAdapterType.REGULAR): None, # not available from h94 (ip-adapter_sd21.bin was never released)
- ("SD2.1", IPAdapterType.PLUS): None, # not available from h94
+ ("SD2.1", IPAdapterType.PLUS): None, # not available from h94
("SD2.1", IPAdapterType.FACEID): None, # not available from h94
("SDXL", IPAdapterType.REGULAR): {
"model_path": "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin",
@@ -78,15 +81,15 @@ class IPAdapterConfig:
# Set of all known HF model paths β used to distinguish known vs custom paths.
# Custom/local paths are never overridden.
_KNOWN_IPADAPTER_PATHS: frozenset = frozenset(
- entry["model_path"]
- for entry in IPADAPTER_MODEL_MAP.values()
- if entry is not None
+ entry["model_path"] for entry in IPADAPTER_MODEL_MAP.values() if entry is not None
)
-_KNOWN_ENCODER_PATHS: frozenset = frozenset({
- "h94/IP-Adapter/models/image_encoder",
- "h94/IP-Adapter/sdxl_models/image_encoder",
-})
+_KNOWN_ENCODER_PATHS: frozenset = frozenset(
+ {
+ "h94/IP-Adapter/models/image_encoder",
+ "h94/IP-Adapter/sdxl_models/image_encoder",
+ }
+)
def _normalize_model_type(detected_model_type: str, is_sdxl: bool) -> Optional[str]:
@@ -183,10 +186,7 @@ def resolve_ipadapter_paths(
# Resolve encoder path (only if it's a known HF encoder β custom encoders untouched)
if current_encoder_path in _KNOWN_ENCODER_PATHS and current_encoder_path != correct_encoder_path:
- logger.info(
- f"IP-Adapter: resolving image encoder "
- f"'{current_encoder_path}' β '{correct_encoder_path}'."
- )
+ logger.info(f"IP-Adapter: resolving image encoder '{current_encoder_path}' β '{correct_encoder_path}'.")
cfg["image_encoder_path"] = correct_encoder_path
return cfg
@@ -209,7 +209,9 @@ def build_embedding_hook(self, stream) -> EmbeddingHook:
def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx:
# Fetch cached image token embeddings (prompt, negative)
- cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings(style_key)
+ cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings(
+ style_key
+ )
image_prompt_tokens: Optional[torch.Tensor] = None
image_negative_tokens: Optional[torch.Tensor] = None
if cached is not None:
@@ -220,7 +222,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx:
batch_size = ctx.prompt_embeds.shape[0]
if image_prompt_tokens is None:
image_prompt_tokens = torch.zeros(
- (batch_size, num_tokens, hidden_dim), dtype=ctx.prompt_embeds.dtype, device=ctx.prompt_embeds.device
+ (batch_size, num_tokens, hidden_dim),
+ dtype=ctx.prompt_embeds.dtype,
+ device=ctx.prompt_embeds.device,
)
else:
if image_prompt_tokens.shape[1] != num_tokens:
@@ -242,7 +246,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx:
if neg_with_image is not None:
if image_negative_tokens is None:
image_negative_tokens = torch.zeros(
- (neg_with_image.shape[0], num_tokens, hidden_dim), dtype=neg_with_image.dtype, device=neg_with_image.device
+ (neg_with_image.shape[0], num_tokens, hidden_dim),
+ dtype=neg_with_image.dtype,
+ device=neg_with_image.device,
)
else:
if image_negative_tokens.shape[0] != neg_with_image.shape[0]:
@@ -291,14 +297,14 @@ def install(self, stream) -> None:
# Create IP-Adapter and install processors into UNet (FaceID-aware)
ip_kwargs = {
- 'pipe': stream.pipe,
- 'ipadapter_ckpt_path': resolved_ip_path,
- 'image_encoder_path': resolved_encoder_path,
- 'device': stream.device,
- 'dtype': stream.dtype,
+ "pipe": stream.pipe,
+ "ipadapter_ckpt_path": resolved_ip_path,
+ "image_encoder_path": resolved_encoder_path,
+ "device": stream.device,
+ "dtype": stream.dtype,
}
if self.config.type == IPAdapterType.FACEID and self.config.insightface_model_name:
- ip_kwargs['insightface_model_name'] = self.config.insightface_model_name
+ ip_kwargs["insightface_model_name"] = self.config.insightface_model_name
print(
f"IPAdapterModule.install: Initializing FaceID IP-Adapter with InsightFace model: {self.config.insightface_model_name}"
)
@@ -311,6 +317,7 @@ def install(self, stream) -> None:
# AttnProcessor2_0 which accepts kvo_cache and returns (hidden_states, kvo_cache).
try:
from diffusers.models.attention_processor import AttnProcessor2_0 as NativeAttnProcessor2_0
+
attn_procs = stream.pipe.unet.attn_processors
for name in attn_procs:
if name.endswith("attn1.processor"):
@@ -324,6 +331,7 @@ def install(self, stream) -> None:
if self.config.type == IPAdapterType.FACEID:
try:
from streamdiffusion.preprocessing.processors.faceid_embedding import FaceIDEmbeddingPreprocessor
+
embedding_preprocessor = FaceIDEmbeddingPreprocessor(
ipadapter=ipadapter,
device=stream.device,
@@ -357,11 +365,11 @@ def install(self, stream) -> None:
# Expose IPAdapter instance as single source of truth
try:
- setattr(stream, 'ipadapter', ipadapter)
+ setattr(stream, "ipadapter", ipadapter)
# Extend IPAdapter with our custom attributes since diffusers IPAdapter doesn't expose current state
- setattr(ipadapter, 'weight_type', self.config.weight_type) # For build_layer_weights
- setattr(ipadapter, 'scale', float(self.config.scale)) # Track current scale
- setattr(ipadapter, 'enabled', bool(self.config.enabled)) # Track enabled state
+ setattr(ipadapter, "weight_type", self.config.weight_type) # For build_layer_weights
+ setattr(ipadapter, "scale", float(self.config.scale)) # Track current scale
+ setattr(ipadapter, "enabled", bool(self.config.enabled)) # Track enabled state
except Exception:
pass
@@ -389,7 +397,10 @@ def _resolve_model_path(self, model_path: Optional[str]) -> str:
from huggingface_hub import hf_hub_download, snapshot_download
except Exception as e:
import logging
- logging.getLogger(__name__).error(f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}")
+
+ logging.getLogger(__name__).error(
+ f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}"
+ )
raise
parts = model_path.split("/")
@@ -419,28 +430,28 @@ def build_unet_hook(self, stream) -> UnetHook:
- For PyTorch UNet with installed IP processors, modulate per-layer processor scale by time factor
"""
_last_enabled_state = None # Track previous enabled state to avoid redundant updates
-
+
def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
# If no IP-Adapter installed, do nothing
- if not hasattr(stream, 'ipadapter') or stream.ipadapter is None:
+ if not hasattr(stream, "ipadapter") or stream.ipadapter is None:
return UnetKwargsDelta()
# Check if IPAdapter is enabled
- enabled = getattr(stream.ipadapter, 'enabled', True)
+ enabled = getattr(stream.ipadapter, "enabled", True)
# Read base weight and weight type from IPAdapter instance
try:
- base_weight = float(getattr(stream.ipadapter, 'scale', 1.0)) if enabled else 0.0
+ base_weight = float(getattr(stream.ipadapter, "scale", 1.0)) if enabled else 0.0
except Exception:
base_weight = 0.0 if not enabled else 1.0
- weight_type = getattr(stream.ipadapter, 'weight_type', None)
+ weight_type = getattr(stream.ipadapter, "weight_type", None)
# Determine total steps and current step index for time scheduling
total_steps = None
try:
- if hasattr(stream, 'denoising_steps_num') and isinstance(stream.denoising_steps_num, int):
+ if hasattr(stream, "denoising_steps_num") and isinstance(stream.denoising_steps_num, int):
total_steps = int(stream.denoising_steps_num)
- elif hasattr(stream, 't_list') and stream.t_list is not None:
+ elif hasattr(stream, "t_list") and stream.t_list is not None:
total_steps = len(stream.t_list)
except Exception:
total_steps = None
@@ -449,6 +460,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
if total_steps is not None and ctx.step_index is not None:
try:
from diffusers_ipadapter.ip_adapter.attention_processor import build_time_weight_factor
+
time_factor = float(build_time_weight_factor(weight_type, int(ctx.step_index), int(total_steps)))
except Exception:
# Do not add fallback mechanisms
@@ -456,18 +468,20 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
# TensorRT engine path: supply ipadapter_scale vector via extra kwargs
try:
- is_trt_unet = hasattr(stream, 'unet') and hasattr(stream.unet, 'engine') and hasattr(stream.unet, 'stream')
+ is_trt_unet = (
+ hasattr(stream, "unet") and hasattr(stream.unet, "engine") and hasattr(stream.unet, "stream")
+ )
except Exception:
is_trt_unet = False
- if is_trt_unet and getattr(stream.unet, 'use_ipadapter', False):
+ if is_trt_unet and getattr(stream.unet, "use_ipadapter", False):
try:
from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights
except Exception:
# If helper unavailable, do not construct weights here
build_layer_weights = None # type: ignore
- num_ip_layers = getattr(stream.unet, 'num_ip_layers', None)
+ num_ip_layers = getattr(stream.unet, "num_ip_layers", None)
if isinstance(num_ip_layers, int) and num_ip_layers > 0:
weights_tensor = None
try:
@@ -476,24 +490,26 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
except Exception:
weights_tensor = None
if weights_tensor is None:
- weights_tensor = torch.full((num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device)
+ weights_tensor = torch.full(
+ (num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device
+ )
# Apply per-step time factor
try:
weights_tensor = weights_tensor * float(time_factor)
except Exception:
pass
- return UnetKwargsDelta(extra_unet_kwargs={'ipadapter_scale': weights_tensor})
+ return UnetKwargsDelta(extra_unet_kwargs={"ipadapter_scale": weights_tensor})
# PyTorch UNet path: modulate installed processor scales by time factor and enabled state
try:
nonlocal _last_enabled_state
# Only process if we need to make changes (time scaling or state transition)
- needs_update = (time_factor != 1.0 or enabled != _last_enabled_state)
- if needs_update and hasattr(stream.pipe, 'unet') and hasattr(stream.pipe.unet, 'attn_processors'):
+ needs_update = time_factor != 1.0 or enabled != _last_enabled_state
+ if needs_update and hasattr(stream.pipe, "unet") and hasattr(stream.pipe.unet, "attn_processors"):
_last_enabled_state = enabled
for proc in stream.pipe.unet.attn_processors.values():
- if hasattr(proc, 'scale') and hasattr(proc, '_ip_layer_index'):
- base_val = getattr(proc, '_base_scale', proc.scale)
+ if hasattr(proc, "scale") and hasattr(proc, "_ip_layer_index"):
+ base_val = getattr(proc, "_base_scale", proc.scale)
# Apply both enabled state and time factor
final_scale = float(base_val) * float(time_factor) if enabled else 0.0
proc.scale = final_scale
@@ -503,4 +519,3 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
return UnetKwargsDelta()
return _unet_hook
-
diff --git a/src/streamdiffusion/modules/latent_processing_module.py b/src/streamdiffusion/modules/latent_processing_module.py
index 256c66f0e..78edf1b37 100644
--- a/src/streamdiffusion/modules/latent_processing_module.py
+++ b/src/streamdiffusion/modules/latent_processing_module.py
@@ -1,54 +1,55 @@
-from typing import List, Optional, Any, Dict
+from typing import Any, Dict, List
+
import torch
-from ..preprocessing.orchestrator_user import OrchestratorUser
from ..hooks import LatentCtx, LatentHook
+from ..preprocessing.orchestrator_user import OrchestratorUser
class LatentProcessingModule(OrchestratorUser):
"""
Shared base class for latent domain processing modules.
-
+
Handles sequential chain execution for both preprocessing and postprocessing
timing variants. Processing domain is always latent tensors.
"""
-
+
def __init__(self):
"""Initialize latent processing module."""
self.processors = []
-
+
def _process_latent_chain(self, input_latent: torch.Tensor) -> torch.Tensor:
"""Execute sequential chain of processors in latent domain.
-
+
Uses the shared orchestrator's sequential chain processing.
"""
if not self.processors:
return input_latent
-
+
ordered_processors = self._get_ordered_processors()
return self._preprocessing_orchestrator.execute_pipeline_chain(
input_latent, ordered_processors, processing_domain="latent"
)
-
+
def add_processor(self, proc_config: Dict[str, Any]) -> None:
"""Add a processor using the existing registry, following ControlNet pattern."""
from streamdiffusion.preprocessing.processors import get_preprocessor
-
- processor_type = proc_config.get('type')
+
+ processor_type = proc_config.get("type")
if not processor_type:
raise ValueError("Processor config missing 'type' field")
-
+
# Check if processor is enabled (default to True, same as ControlNet)
- enabled = proc_config.get('enabled', True)
-
+ enabled = proc_config.get("enabled", True)
+
# Create processor using existing registry (same as ControlNet)
# LatentProcessingModule uses 'latent' normalization context (works in latent space)
- processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context='latent')
-
+ processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context="latent")
+
# Apply parameters (same pattern as ControlNet)
- processor_params = proc_config.get('params', {})
+ processor_params = proc_config.get("params", {})
if processor_params:
- if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict):
+ if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict):
processor.params.update(processor_params)
for name, value in processor_params.items():
try:
@@ -56,62 +57,66 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None:
setattr(processor, name, value)
except Exception:
pass
-
+
# Set order for sequential execution
- order = proc_config.get('order', len(self.processors))
- setattr(processor, 'order', order)
-
+ order = proc_config.get("order", len(self.processors))
+ setattr(processor, "order", order)
+
# Set enabled state
- setattr(processor, 'enabled', enabled)
-
+ setattr(processor, "enabled", enabled)
+
# Pipeline reference is now automatically handled by the factory function
-
+
self.processors.append(processor)
-
+
def _get_ordered_processors(self) -> List[Any]:
"""Return enabled processors in execution order based on their order attribute."""
# Filter for enabled processors first, then sort by order
- enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)]
- return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0))
+ enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)]
+ return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0))
class LatentPreprocessingModule(LatentProcessingModule):
"""
Latent domain preprocessing module - executes after VAE encoding, before diffusion.
-
+
Timing: After encode_image(), before predict_x0_batch()
"""
-
+
def install(self, stream) -> None:
"""Install module by registering hook with stream and attaching orchestrator."""
self.attach_orchestrator(stream)
self._stream = stream # Store stream reference like ControlNet module does
stream.latent_preprocessing_hooks.append(self.build_latent_hook())
-
+
def build_latent_hook(self) -> LatentHook:
"""Build hook function that processes latent context."""
+
def hook(ctx: LatentCtx) -> LatentCtx:
ctx.latent = self._process_latent_chain(ctx.latent)
return ctx
+
return hook
class LatentPostprocessingModule(LatentProcessingModule):
"""
Latent domain postprocessing module - executes after diffusion, before VAE decoding.
-
+
Timing: After predict_x0_batch(), before decode_image()
"""
-
+
def install(self, stream) -> None:
"""Install module by registering hook with stream and attaching orchestrator."""
self.attach_orchestrator(stream)
self._stream = stream # Store stream reference like ControlNet module does
stream.latent_postprocessing_hooks.append(self.build_latent_hook())
-
+
def build_latent_hook(self) -> LatentHook:
"""Build hook function that processes latent context."""
+
def hook(ctx: LatentCtx) -> LatentCtx:
ctx.latent = self._process_latent_chain(ctx.latent)
return ctx
+
return hook
diff --git a/src/streamdiffusion/pip_utils.py b/src/streamdiffusion/pip_utils.py
index 6ae3f11cd..4a28c0a0e 100644
--- a/src/streamdiffusion/pip_utils.py
+++ b/src/streamdiffusion/pip_utils.py
@@ -27,13 +27,16 @@ def _check_torch_installed():
raise RuntimeError(msg)
if not torch.version.cuda:
- raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.")
+ raise RuntimeError(
+ "Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package."
+ )
def get_cuda_version() -> str | None:
_check_torch_installed()
import torch
+
return torch.version.cuda
@@ -66,7 +69,7 @@ def is_installed(package: str) -> bool:
def run_python(command: str, env: Dict[str, str] | None = None) -> str:
run_kwargs = {
- "args": f"\"{python}\" {command}",
+ "args": f'"{python}" {command}',
"shell": True,
"env": os.environ if env is None else env,
"encoding": "utf8",
diff --git a/src/streamdiffusion/preprocessing/__init__.py b/src/streamdiffusion/preprocessing/__init__.py
index 4228ee69b..c52a8a2ed 100644
--- a/src/streamdiffusion/preprocessing/__init__.py
+++ b/src/streamdiffusion/preprocessing/__init__.py
@@ -1,13 +1,14 @@
-from .preprocessing_orchestrator import PreprocessingOrchestrator
-from .postprocessing_orchestrator import PostprocessingOrchestrator
-from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator
from .base_orchestrator import BaseOrchestrator
from .orchestrator_user import OrchestratorUser
+from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator
+from .postprocessing_orchestrator import PostprocessingOrchestrator
+from .preprocessing_orchestrator import PreprocessingOrchestrator
+
__all__ = [
"PreprocessingOrchestrator",
"PostprocessingOrchestrator",
"PipelinePreprocessingOrchestrator",
"BaseOrchestrator",
- "OrchestratorUser"
+ "OrchestratorUser",
]
diff --git a/src/streamdiffusion/preprocessing/base_orchestrator.py b/src/streamdiffusion/preprocessing/base_orchestrator.py
index d6d86bf2b..e5f6c1b22 100644
--- a/src/streamdiffusion/preprocessing/base_orchestrator.py
+++ b/src/streamdiffusion/preprocessing/base_orchestrator.py
@@ -1,144 +1,148 @@
-import torch
-from typing import List, Optional, Union, Dict, Any, Tuple, Callable, TypeVar, Generic
-from abc import ABC, abstractmethod
-import numpy as np
import concurrent.futures
import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, Optional, TypeVar
+
+import torch
+
logger = logging.getLogger(__name__)
# Type variables for generic orchestrator
-T = TypeVar('T') # Input type (e.g., ControlImage for preprocessing)
-R = TypeVar('R') # Result type (e.g., List[torch.Tensor] for preprocessing)
+T = TypeVar("T") # Input type (e.g., ControlImage for preprocessing)
+R = TypeVar("R") # Result type (e.g., List[torch.Tensor] for preprocessing)
class BaseOrchestrator(Generic[T, R], ABC):
"""
Generic base orchestrator for parallelized and pipelined processing.
-
+
Handles thread pool management, pipeline state, and inter-frame pipelining
while leaving domain-specific processing logic to subclasses.
-
+
Type Parameters:
T: Input type for processing operations
R: Result type returned from processing operations
"""
-
- def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, timeout_ms: float = 10.0, pipeline_ref: Optional[Any] = None):
+
+ def __init__(
+ self,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.float16,
+ max_workers: int = 4,
+ timeout_ms: float = 10.0,
+ pipeline_ref: Optional[Any] = None,
+ ):
self.device = device
self.dtype = dtype
self.timeout_ms = timeout_ms
self.pipeline_ref = pipeline_ref
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
-
+
# Pipeline state for pipelined processing
self._next_frame_future = None
self._next_frame_result = None
-
+
# CUDA stream for background processing to avoid GPU contention
self._background_stream = None
device_str = str(device)
if device_str.startswith("cuda") and torch.cuda.is_available():
self._background_stream = torch.cuda.Stream()
-
-
def cleanup(self) -> None:
"""Cleanup thread pool and CUDA stream resources"""
- if hasattr(self, '_executor'):
+ if hasattr(self, "_executor"):
self._executor.shutdown(wait=True)
-
+
# Cleanup CUDA stream if it exists
- if hasattr(self, '_background_stream') and self._background_stream is not None:
+ if hasattr(self, "_background_stream") and self._background_stream is not None:
# Synchronize the stream before cleanup
torch.cuda.synchronize()
self._background_stream = None
-
+
def __del__(self):
"""Cleanup on destruction"""
try:
self.cleanup()
except:
pass
-
+
@abstractmethod
def _should_use_sync_processing(self, *args, **kwargs) -> bool:
"""
Determine if synchronous processing should be used instead of pipelined.
-
+
Subclasses implement domain-specific logic (e.g., feedback preprocessor detection).
-
+
Returns:
True if sync processing should be used, False for pipelined processing
"""
pass
-
+
@abstractmethod
def _process_frame_background(self, *args, **kwargs) -> Dict[str, Any]:
"""
Process a frame in the background thread.
-
+
Subclasses implement their specific processing logic here.
-
+
Returns:
Dictionary containing processing results and status
"""
pass
-
+
def process_pipelined(self, input_data: T, *args, **kwargs) -> R:
"""
Process input with intelligent pipelining.
-
+
Automatically falls back to sync processing when required by domain logic,
otherwise uses pipelined processing for performance.
-
+
Args:
input_data: Input data to process
*args, **kwargs: Additional arguments passed to processing methods
-
+
Returns:
Processing results
"""
# Check if sync processing is required (domain-specific logic)
if self._should_use_sync_processing(*args, **kwargs):
return self.process_sync(input_data, *args, **kwargs)
-
+
# Use pipelined processing
# Wait for previous frame processing; non-blocking with short timeout
self._wait_for_previous_processing()
-
+
# Start next frame processing in background
self._start_next_frame_processing(input_data, *args, **kwargs)
-
+
# Apply current frame processing results if available; otherwise signal no update
return self._apply_current_frame_processing(*args, **kwargs)
-
+
@abstractmethod
def process_sync(self, input_data: T, *args, **kwargs) -> R:
"""
Process input synchronously.
-
+
Subclasses implement their specific synchronous processing logic.
-
+
Args:
input_data: Input data to process
*args, **kwargs: Additional arguments passed to processing methods
-
+
Returns:
Processing results
"""
pass
-
+
def _start_next_frame_processing(self, input_data: T, *args, **kwargs) -> None:
"""Start processing for next frame in background thread"""
# Submit background processing
- self._next_frame_future = self._executor.submit(
- self._process_frame_background, input_data, *args, **kwargs
- )
-
+ self._next_frame_future = self._executor.submit(self._process_frame_background, input_data, *args, **kwargs)
+
def _wait_for_previous_processing(self) -> None:
"""Wait for previous frame processing with configurable timeout"""
- if hasattr(self, '_next_frame_future') and self._next_frame_future is not None:
+ if hasattr(self, "_next_frame_future") and self._next_frame_future is not None:
try:
# Use configurable timeout based on orchestrator type
self._next_frame_result = self._next_frame_future.result(timeout=self.timeout_ms / 1000.0)
@@ -150,52 +154,52 @@ def _wait_for_previous_processing(self) -> None:
self._next_frame_result = None
else:
self._next_frame_result = None
-
+
def _apply_current_frame_processing(self, processors=None, *args, **kwargs) -> R:
"""
Apply processing results from previous iteration.
-
+
Default implementation provides common fallback logic for tensor-to-tensor orchestrators.
Subclasses can override this method for specialized behavior.
-
+
Args:
processors: List of processors/postprocessors to apply (parameter name varies by subclass)
*args, **kwargs: Additional arguments
-
+
Returns:
Processing results, or processed current input if no results available
"""
- if not hasattr(self, '_next_frame_result') or self._next_frame_result is None:
+ if not hasattr(self, "_next_frame_result") or self._next_frame_result is None:
# First frame or no background results - process current input synchronously
- if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None:
+ if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None:
if processors:
return self.process_sync(self._current_input_tensor, processors)
else:
return self._current_input_tensor
-
+
# If we don't have current input stored, we have an issue
class_name = self.__class__.__name__
logger.error(f"{class_name}: No background results and no current input tensor available")
raise RuntimeError(f"{class_name}: No processing results available")
-
+
result = self._next_frame_result
- if result['status'] != 'success':
+ if result["status"] != "success":
class_name = self.__class__.__name__
logger.warning(f"{class_name}: Background processing failed: {result.get('error', 'Unknown error')}")
# Process current input synchronously on error
- if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None:
+ if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None:
if processors:
return self.process_sync(self._current_input_tensor, processors)
else:
return self._current_input_tensor
raise RuntimeError(f"{class_name}: Background processing failed and no fallback available")
-
- return result['result']
-
+
+ return result["result"]
+
def _set_background_stream_context(self):
"""
Set CUDA stream context for background processing.
-
+
Returns:
The original stream to restore later, or None if no background stream
"""
@@ -204,11 +208,11 @@ def _set_background_stream_context(self):
torch.cuda.set_stream(self._background_stream)
return original_stream
return None
-
+
def _restore_stream_context(self, original_stream):
"""
Restore the original CUDA stream context.
-
+
Args:
original_stream: The stream to restore, or None to do nothing
"""
diff --git a/src/streamdiffusion/preprocessing/orchestrator_user.py b/src/streamdiffusion/preprocessing/orchestrator_user.py
index 2503c14e3..e731540b9 100644
--- a/src/streamdiffusion/preprocessing/orchestrator_user.py
+++ b/src/streamdiffusion/preprocessing/orchestrator_user.py
@@ -2,9 +2,9 @@
from typing import Optional
-from .preprocessing_orchestrator import PreprocessingOrchestrator
-from .postprocessing_orchestrator import PostprocessingOrchestrator
from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator
+from .postprocessing_orchestrator import PostprocessingOrchestrator
+from .preprocessing_orchestrator import PreprocessingOrchestrator
class OrchestratorUser:
@@ -20,32 +20,36 @@ class OrchestratorUser:
def attach_orchestrator(self, stream) -> None:
"""Attach preprocessing orchestrator (backward compatibility)."""
self.attach_preprocessing_orchestrator(stream)
-
+
def attach_preprocessing_orchestrator(self, stream) -> None:
"""Attach shared preprocessing orchestrator from stream."""
- orchestrator = getattr(stream, 'preprocessing_orchestrator', None)
+ orchestrator = getattr(stream, "preprocessing_orchestrator", None)
if orchestrator is None:
# Lazy-create on stream once, on first user that needs it
- orchestrator = PreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream)
- setattr(stream, 'preprocessing_orchestrator', orchestrator)
+ orchestrator = PreprocessingOrchestrator(
+ device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream
+ )
+ setattr(stream, "preprocessing_orchestrator", orchestrator)
self._preprocessing_orchestrator = orchestrator
-
+
def attach_postprocessing_orchestrator(self, stream) -> None:
"""Attach shared postprocessing orchestrator from stream."""
- orchestrator = getattr(stream, 'postprocessing_orchestrator', None)
+ orchestrator = getattr(stream, "postprocessing_orchestrator", None)
if orchestrator is None:
# Lazy-create on stream once, on first user that needs it
- orchestrator = PostprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream)
- setattr(stream, 'postprocessing_orchestrator', orchestrator)
+ orchestrator = PostprocessingOrchestrator(
+ device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream
+ )
+ setattr(stream, "postprocessing_orchestrator", orchestrator)
self._postprocessing_orchestrator = orchestrator
-
+
def attach_pipeline_preprocessing_orchestrator(self, stream) -> None:
"""Attach shared pipeline preprocessing orchestrator from stream."""
- orchestrator = getattr(stream, 'pipeline_preprocessing_orchestrator', None)
+ orchestrator = getattr(stream, "pipeline_preprocessing_orchestrator", None)
if orchestrator is None:
# Lazy-create on stream once, on first user that needs it
- orchestrator = PipelinePreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream)
- setattr(stream, 'pipeline_preprocessing_orchestrator', orchestrator)
+ orchestrator = PipelinePreprocessingOrchestrator(
+ device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream
+ )
+ setattr(stream, "pipeline_preprocessing_orchestrator", orchestrator)
self._pipeline_preprocessing_orchestrator = orchestrator
-
-
diff --git a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py
index 8cf4e7171..382e874fb 100644
--- a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py
+++ b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py
@@ -1,32 +1,42 @@
-import torch
-from typing import List, Dict, Any, Optional
import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+
from .base_orchestrator import BaseOrchestrator
+
logger = logging.getLogger(__name__)
+
class PipelinePreprocessingOrchestrator(BaseOrchestrator[torch.Tensor, torch.Tensor]):
"""
Orchestrates pipeline input preprocessing with parallelization and pipelining.
-
+
Handles preprocessing of input tensors before they enter the diffusion pipeline.
-
+
Tensor ranges:
- Input: Receives [-1, 1] tensors from image_processor.preprocess()
- Processors: Work in [-1, 1] space when normalization_context='pipeline'
- Output: Returns [-1, 1] tensors for pipeline processing
-
+
Note: Processors created with normalization_context='pipeline' expect and preserve
[-1, 1] range. No automatic conversion happens in this orchestrator.
"""
-
- def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None):
+
+ def __init__(
+ self,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.float16,
+ max_workers: int = 4,
+ pipeline_ref: Optional[Any] = None,
+ ):
# Pipeline preprocessing: 10ms timeout for responsive processing
super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref)
-
+
# Pipeline preprocessing specific state
self._current_input_tensor = None # For BaseOrchestrator fallback logic
-
+
def _should_use_sync_processing(self, *args, **kwargs) -> bool:
"""
Determine if synchronous processing should be used instead of pipelined.
@@ -44,123 +54,102 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool:
if not processors:
return False
for proc in processors:
- if proc is not None and getattr(proc, 'requires_sync_processing', False):
+ if proc is not None and getattr(proc, "requires_sync_processing", False):
return True
return False
-
- def process_pipelined(self,
- input_tensor: torch.Tensor,
- processors: List[Any],
- *args, **kwargs) -> torch.Tensor:
+
+ def process_pipelined(self, input_tensor: torch.Tensor, processors: List[Any], *args, **kwargs) -> torch.Tensor:
"""
Process input with intelligent pipelining.
-
+
Overrides base method to store current input tensor for fallback logic.
"""
# Store current input for fallback logic
self._current_input_tensor = input_tensor
-
+
# RACE CONDITION FIX: Check if there are actually enabled processors
# Filter to only enabled processors (same logic as _get_ordered_processors)
- enabled_processors = [p for p in processors if getattr(p, 'enabled', True)] if processors else []
-
+ enabled_processors = [p for p in processors if getattr(p, "enabled", True)] if processors else []
+
if not enabled_processors:
return input_tensor
-
+
# Call parent implementation
return super().process_pipelined(input_tensor, processors, *args, **kwargs)
-
- def process_sync(self,
- input_tensor: torch.Tensor,
- processors: List[Any]) -> torch.Tensor:
+
+ def process_sync(self, input_tensor: torch.Tensor, processors: List[Any]) -> torch.Tensor:
"""
Process pipeline input tensor synchronously through preprocessors.
-
+
Implementation of BaseOrchestrator.process_sync for pipeline preprocessing.
-
+
Args:
input_tensor: Input tensor to preprocess (already normalized)
processors: List of preprocessor instances
-
+
Returns:
Preprocessed tensor ready for pipeline processing
"""
if not processors:
return input_tensor
-
+
# Sequential application of processors
current_tensor = input_tensor
for processor in processors:
if processor is not None:
current_tensor = self._apply_single_processor(current_tensor, processor)
-
+
return current_tensor
-
- def _process_frame_background(self,
- input_tensor: torch.Tensor,
- processors: List[Any]) -> Dict[str, Any]:
+
+ def _process_frame_background(self, input_tensor: torch.Tensor, processors: List[Any]) -> Dict[str, Any]:
"""
Process a frame in the background thread.
-
+
Implementation of BaseOrchestrator._process_frame_background for pipeline preprocessing.
-
+
Returns:
Dictionary containing processing results and status
"""
try:
# Set CUDA stream for background processing
original_stream = self._set_background_stream_context()
-
+
if not processors:
- return {
- 'result': input_tensor,
- 'status': 'success'
- }
-
+ return {"result": input_tensor, "status": "success"}
+
# Process processors sequentially (most pipeline preprocessing is dependent)
current_tensor = input_tensor
for processor in processors:
if processor is not None:
current_tensor = self._apply_single_processor(current_tensor, processor)
-
- return {
- 'result': current_tensor,
- 'status': 'success'
- }
-
+
+ return {"result": current_tensor, "status": "success"}
+
except Exception as e:
logger.error(f"PipelinePreprocessingOrchestrator: Background processing failed: {e}")
# Return original input tensor on error
- return {
- 'result': input_tensor,
- 'error': str(e),
- 'status': 'error'
- }
+ return {"result": input_tensor, "error": str(e), "status": "error"}
finally:
# Restore original CUDA stream
self._restore_stream_context(original_stream)
-
-
-
- def _apply_single_processor(self,
- input_tensor: torch.Tensor,
- processor: Any) -> torch.Tensor:
+
+ def _apply_single_processor(self, input_tensor: torch.Tensor, processor: Any) -> torch.Tensor:
"""
Apply a single processor to the input tensor.
-
+
Args:
input_tensor: Input tensor to process
processor: Processor instance
-
+
Returns:
Processed tensor
"""
try:
# Apply processor
- if hasattr(processor, 'process_tensor'):
+ if hasattr(processor, "process_tensor"):
# Prefer tensor processing method
result = processor.process_tensor(input_tensor)
- elif hasattr(processor, 'process'):
+ elif hasattr(processor, "process"):
# Use general process method
result = processor.process(input_tensor)
elif callable(processor):
@@ -169,18 +158,18 @@ def _apply_single_processor(self,
else:
logger.warning(f"PipelinePreprocessingOrchestrator: Unknown processor type: {type(processor)}")
return input_tensor
-
+
# Ensure result is a tensor
if isinstance(result, torch.Tensor):
return result
else:
logger.warning(f"PipelinePreprocessingOrchestrator: Processor returned non-tensor: {type(result)}")
return input_tensor
-
+
except Exception as e:
logger.error(f"PipelinePreprocessingOrchestrator: Processor failed: {e}")
return input_tensor # Return original on error
-
+
def clear_cache(self) -> None:
"""Clear preprocessing cache"""
pass
diff --git a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py
index 742a80e60..ef5dceb32 100644
--- a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py
+++ b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py
@@ -73,7 +73,7 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool:
if not processors:
return False
for proc in processors:
- if proc is not None and getattr(proc, 'requires_sync_processing', False):
+ if proc is not None and getattr(proc, "requires_sync_processing", False):
return True
return False
diff --git a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py
index cf05fbfaa..fd554ac6f 100644
--- a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py
+++ b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py
@@ -1,12 +1,15 @@
-import torch
-from typing import List, Optional, Union, Dict, Any, Tuple
-from PIL import Image
-import numpy as np
import logging
-from diffusers.utils import load_image
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
import torchvision.transforms as transforms
+from diffusers.utils import load_image
+from PIL import Image
+
from .base_orchestrator import BaseOrchestrator
+
logger = logging.getLogger(__name__)
# Type alias for control image input
@@ -16,40 +19,45 @@
class PreprocessingOrchestrator(BaseOrchestrator[ControlImage, List[Optional[torch.Tensor]]]):
"""
Orchestrates module preprocessing with typical orchestrator pipelining, but with additional intraframe parallelization, caching, and optimization.
- Modules (IPAdapter, Controlnet) share intraframe parallelism.
+ Modules (IPAdapter, Controlnet) share intraframe parallelism.
Handles image format conversion (while most are GPU native,some preprocessors are CPU only), preprocessor execution, and result caching.
"""
-
- def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None):
+
+ def __init__(
+ self,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.float16,
+ max_workers: int = 4,
+ pipeline_ref: Optional[Any] = None,
+ ):
# Preprocessing: 10ms timeout for fast frame-skipping behavior
super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref)
-
+
# Caching
self._preprocessed_cache: Dict[str, torch.Tensor] = {}
self._last_input_frame = None
-
+
# Optimized transforms
self._cached_transform = transforms.ToTensor()
-
+
# Cache pipelining decision to avoid hot path checks
self._preprocessors_cache_key = None
self._has_feedback_cache = False
-
-
-
-
- #Abstract method implementations
- def process_sync(self,
- control_image: ControlImage,
- preprocessors: List[Optional[Any]],
- scales: List[float] = None,
- stream_width: int = None,
- stream_height: int = None,
- index: Optional[int] = None,
- processing_type: str = "controlnet") -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]:
+
+ # Abstract method implementations
+ def process_sync(
+ self,
+ control_image: ControlImage,
+ preprocessors: List[Optional[Any]],
+ scales: List[float] = None,
+ stream_width: int = None,
+ stream_height: int = None,
+ index: Optional[int] = None,
+ processing_type: str = "controlnet",
+ ) -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Process images synchronously for ControlNet or IPAdapter preprocessing.
-
+
Args:
control_image: Input image to process
preprocessors: List of preprocessor instances
@@ -58,7 +66,7 @@ def process_sync(self,
stream_height: Target height for processing
index: If specified, only process this single ControlNet index (ControlNet only)
processing_type: "controlnet" or "ipadapter" to specify processing mode
-
+
Returns:
ControlNet: List of processed tensors for each ControlNet
IPAdapter: List of (positive_embeds, negative_embeds) tuples
@@ -79,410 +87,390 @@ def process_sync(self,
)
else:
raise ValueError(f"Invalid processing_type: {processing_type}. Must be 'controlnet' or 'ipadapter'")
-
+
def _should_use_sync_processing(self, *args, **kwargs) -> bool:
"""
Check for pipeline-aware preprocessors that require sync processing.
-
- Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing
+
+ Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing
to avoid temporal artifacts and ensure access to previous pipeline outputs.
-
+
Args:
*args: Arguments from process_pipelined call (preprocessors, scales, stream_width, stream_height)
**kwargs: Keyword arguments
-
+
Returns:
True if pipeline-aware preprocessors detected, False otherwise
"""
# Extract preprocessors from args - they're the first argument after control_image
if len(args) < 1:
return False
-
+
preprocessors = args[0] # preprocessors is first arg after control_image
return self._check_pipeline_aware_cached(preprocessors)
- def _process_frame_background(self,
- control_image: ControlImage,
- *args, **kwargs) -> Dict[str, Any]:
+ def _process_frame_background(self, control_image: ControlImage, *args, **kwargs) -> Dict[str, Any]:
"""
Process a frame in the background thread.
-
+
Implementation of BaseOrchestrator._process_frame_background for ControlNet preprocessing.
Automatically detects processing mode based on current state.
-
+
Returns:
Dictionary containing processing results and status
"""
try:
# Set CUDA stream for background processing
original_stream = self._set_background_stream_context()
-
+
# Check if last argument is "ipadapter" processing type
if args and len(args) >= 5 and args[4] == "ipadapter":
# Handle embedding preprocessing
embedding_preprocessors = args[0]
- stream_width = args[2]
+ stream_width = args[2]
stream_height = args[3]
-
+
# Prepare processing data
control_variants = self._prepare_input_variants(control_image, thread_safe=True)
-
+
# Process using existing IPAdapter logic
try:
results = self._process_ipadapter_preprocessors_parallel(
embedding_preprocessors, control_variants, stream_width, stream_height
)
- return {
- 'results': results,
- 'status': 'success'
- }
+ return {"results": results, "status": "success"}
except Exception as e:
import traceback
+
traceback.print_exc()
- return {
- 'error': str(e),
- 'status': 'error'
- }
- elif hasattr(self, '_current_processing_mode') and self._current_processing_mode == "embedding":
+ return {"error": str(e), "status": "error"}
+ elif hasattr(self, "_current_processing_mode") and self._current_processing_mode == "embedding":
# Handle embedding preprocessing (legacy path)
embedding_preprocessors = args[0]
- stream_width = args[2]
+ stream_width = args[2]
stream_height = args[3]
-
+
# Prepare processing data
control_variants = self._prepare_input_variants(control_image, thread_safe=True)
-
+
# Process using existing IPAdapter logic
try:
results = self._process_ipadapter_preprocessors_parallel(
embedding_preprocessors, control_variants, stream_width, stream_height
)
- return {
- 'results': results,
- 'status': 'success'
- }
+ return {"results": results, "status": "success"}
except Exception as e:
import traceback
+
traceback.print_exc()
- return {
- 'error': str(e),
- 'status': 'error'
- }
+ return {"error": str(e), "status": "error"}
else:
# Handle ControlNet preprocessing (default mode)
preprocessors = args[0]
scales = args[1]
stream_width = args[2]
stream_height = args[3]
-
+
# Check if any processing is needed
if not any(scale > 0 for scale in scales):
- return {'status': 'success', 'results': [None] * len(preprocessors)}
- #TODO: can we reuse similarity filter here?
- if (self._last_input_frame is not None and
- isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and
- control_image is self._last_input_frame):
- return {'status': 'success', 'results': []} # Signal no update needed
-
+ return {"status": "success", "results": [None] * len(preprocessors)}
+ # TODO: can we reuse similarity filter here?
+ if (
+ self._last_input_frame is not None
+ and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image))
+ and control_image is self._last_input_frame
+ ):
+ return {"status": "success", "results": []} # Signal no update needed
+
self._last_input_frame = control_image
-
+
# Prepare processing data
preprocessor_groups = self._group_preprocessors(preprocessors, scales)
active_indices = [i for i, scale in enumerate(scales) if scale > 0]
-
+
if not active_indices:
- return {'status': 'success', 'results': [None] * len(preprocessors)}
-
+ return {"status": "success", "results": [None] * len(preprocessors)}
+
# Optimize input preparation
control_variants = self._prepare_input_variants(control_image, thread_safe=True)
-
+
# Process using unified parallel logic
processed_images = self._process_controlnet_preprocessors_parallel(
preprocessor_groups, control_variants, stream_width, stream_height, preprocessors
)
-
- return {
- 'results': processed_images,
- 'status': 'success'
- }
-
+
+ return {"results": processed_images, "status": "success"}
+
except Exception as e:
logger.error(f"PreprocessingOrchestrator: Background processing failed: {e}")
- return {
- 'error': str(e),
- 'status': 'error'
- }
+ return {"error": str(e), "status": "error"}
finally:
# Restore original CUDA stream
self._restore_stream_context(original_stream)
-
- def _apply_current_frame_processing(self,
- preprocessors: List[Optional[Any]] = None,
- scales: List[float] = None,
- *args, **kwargs) -> List[Optional[torch.Tensor]]:
+
+ def _apply_current_frame_processing(
+ self, preprocessors: List[Optional[Any]] = None, scales: List[float] = None, *args, **kwargs
+ ) -> List[Optional[torch.Tensor]]:
"""
Apply processing results from previous iteration.
-
+
Overrides BaseOrchestrator._apply_current_frame_processing for module preprocessing.
-
+
Returns:
List of processed tensors, or empty list to signal no update needed
"""
- if not hasattr(self, '_next_frame_result') or self._next_frame_result is None:
+ if not hasattr(self, "_next_frame_result") or self._next_frame_result is None:
# Return empty list to signal no update needed
return []
-
+
# Handle case where preprocessors is None
if preprocessors is None:
return []
-
+
processed_images = [None] * len(preprocessors)
-
+
result = self._next_frame_result
- if result['status'] != 'success':
+ if result["status"] != "success":
# Return empty list to signal no update needed on error
return []
-
+
# Handle case where no update is needed (cached input)
- if 'results' in result and len(result['results']) == 0:
+ if "results" in result and len(result["results"]) == 0:
return []
-
+
# Get the processed results directly
- processed_images = result.get('results', [])
+ processed_images = result.get("results", [])
if not processed_images:
return []
-
+
return processed_images
-
- #Controlnet methods
- def prepare_control_image(self,
- control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
- preprocessor: Optional[Any],
- target_width: int,
- target_height: int) -> torch.Tensor:
+
+ # Controlnet methods
+ def prepare_control_image(
+ self,
+ control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
+ preprocessor: Optional[Any],
+ target_width: int,
+ target_height: int,
+ ) -> torch.Tensor:
"""
Prepare a single control image for ControlNet input with format conversion and preprocessing.
-
+
Args:
control_image: Input image in various formats
preprocessor: Optional preprocessor to apply
target_width: Target width for the output tensor
target_height: Target height for the output tensor
-
+
Returns:
Processed tensor ready for ControlNet
"""
# Load image if path
if isinstance(control_image, str):
control_image = load_image(control_image)
-
+
# Fast tensor processing path
if isinstance(control_image, torch.Tensor):
return self._process_tensor_input(control_image, preprocessor, target_width, target_height)
-
+
# Apply preprocessor to non-tensor inputs
if preprocessor is not None:
control_image = preprocessor.process(control_image)
-
+
# Convert to tensor
return self._convert_to_tensor(control_image, target_width, target_height)
-
- def _process_multiple_controlnets_sync(self,
- control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
- preprocessors: List[Optional[Any]],
- scales: List[float],
- stream_width: int,
- stream_height: int) -> List[Optional[torch.Tensor]]:
+
+ def _process_multiple_controlnets_sync(
+ self,
+ control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
+ preprocessors: List[Optional[Any]],
+ scales: List[float],
+ stream_width: int,
+ stream_height: int,
+ ) -> List[Optional[torch.Tensor]]:
"""Process multiple ControlNets synchronously with parallel execution"""
# Check if any processing is needed
if not any(scale > 0 for scale in scales):
return [None] * len(preprocessors)
-
- #TODO: can we reuse similarity filter here?
+
+ # TODO: can we reuse similarity filter here?
# Check cache for same input - return early without changing anything
- if (self._last_input_frame is not None and
- isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and
- control_image is self._last_input_frame):
+ if (
+ self._last_input_frame is not None
+ and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image))
+ and control_image is self._last_input_frame
+ ):
# Return empty list to signal no update needed
return []
-
+
self._last_input_frame = control_image
self.clear_cache()
-
+
# Prepare input variants for optimal processing
control_variants = self._prepare_input_variants(control_image, stream_width, stream_height)
-
+
# Group preprocessors to avoid duplicate work
preprocessor_groups = self._group_preprocessors(preprocessors, scales)
-
+
if not preprocessor_groups:
return [None] * len(preprocessors)
-
+
# Process groups using parallel logic (efficient for 1 or many items)
return self._process_controlnet_preprocessors_parallel(
preprocessor_groups, control_variants, stream_width, stream_height, preprocessors
)
-
- def _process_single_controlnet(self,
- control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
- preprocessors: List[Optional[Any]],
- scales: List[float],
- stream_width: int,
- stream_height: int,
- index: int) -> List[Optional[torch.Tensor]]:
+
+ def _process_single_controlnet(
+ self,
+ control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
+ preprocessors: List[Optional[Any]],
+ scales: List[float],
+ stream_width: int,
+ stream_height: int,
+ index: int,
+ ) -> List[Optional[torch.Tensor]]:
"""Process a single ControlNet by index"""
if not (0 <= index < len(preprocessors)):
raise IndexError(f"ControlNet index {index} out of range")
-
+
if scales[index] == 0:
return [None] * len(preprocessors)
-
+
processed_images = [None] * len(preprocessors)
- processed_image = self.prepare_control_image(
- control_image, preprocessors[index], stream_width, stream_height
- )
+ processed_image = self.prepare_control_image(control_image, preprocessors[index], stream_width, stream_height)
processed_images[index] = processed_image
-
+
return processed_images
-
- def _process_controlnet_preprocessors_parallel(self,
- preprocessor_groups: Dict[str, Dict[str, Any]],
- control_variants: Dict[str, Any],
- stream_width: int,
- stream_height: int,
- preprocessors: List[Optional[Any]]) -> List[Optional[torch.Tensor]]:
+
+ def _process_controlnet_preprocessors_parallel(
+ self,
+ preprocessor_groups: Dict[str, Dict[str, Any]],
+ control_variants: Dict[str, Any],
+ stream_width: int,
+ stream_height: int,
+ preprocessors: List[Optional[Any]],
+ ) -> List[Optional[torch.Tensor]]:
"""Process ControlNet preprocessor groups in parallel"""
futures = [
self._executor.submit(
- self._process_single_preprocessor_group,
- prep_key, group, control_variants, stream_width, stream_height
+ self._process_single_preprocessor_group, prep_key, group, control_variants, stream_width, stream_height
)
for prep_key, group in preprocessor_groups.items()
]
-
+
processed_images = [None] * len(preprocessors)
-
+
for future in futures:
result = future.result()
- if result and result['processed_image'] is not None:
- prep_key = result['prep_key']
- processed_image = result['processed_image']
- indices = result['indices']
-
+ if result and result["processed_image"] is not None:
+ prep_key = result["prep_key"]
+ processed_image = result["processed_image"]
+ indices = result["indices"]
+
# Cache and assign
cache_key = f"prep_{prep_key}"
self._preprocessed_cache[cache_key] = processed_image
for index in indices:
processed_images[index] = processed_image
-
+
return processed_images
-
- #IPAdapter methods
- def _process_multiple_ipadapters_sync(self,
- control_image: ControlImage,
- preprocessors: List[Optional[Any]],
- stream_width: int,
- stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+
+ # IPAdapter methods
+ def _process_multiple_ipadapters_sync(
+ self, control_image: ControlImage, preprocessors: List[Optional[Any]], stream_width: int, stream_height: int
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Process IPAdapter preprocessors synchronously.
-
+
This is the implementation that was previously in process_ipadapter_preprocessors().
"""
if not preprocessors:
return []
-
+
# For IPAdapter preprocessing, we don't skip on cache hits - we need the actual embeddings
# (Unlike spatial preprocessing where empty list means "no update needed")
-
+
# Prepare input variants for processing
control_variants = self._prepare_input_variants(control_image, stream_width, stream_height)
-
+
# Process using parallel logic (efficient for 1 or many items)
results = self._process_ipadapter_preprocessors_parallel(
preprocessors, control_variants, stream_width, stream_height
)
-
+
return results
-
- def _process_ipadapter_preprocessors_parallel(self,
- ipadapter_preprocessors: List[Any],
- control_variants: Dict[str, Any],
- stream_width: int,
- stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+
+ def _process_ipadapter_preprocessors_parallel(
+ self,
+ ipadapter_preprocessors: List[Any],
+ control_variants: Dict[str, Any],
+ stream_width: int,
+ stream_height: int,
+ ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Process multiple IPAdapter preprocessors in parallel"""
futures = [
self._executor.submit(
- self._process_single_ipadapter,
- i, preprocessor, control_variants, stream_width, stream_height
+ self._process_single_ipadapter, i, preprocessor, control_variants, stream_width, stream_height
)
for i, preprocessor in enumerate(ipadapter_preprocessors)
]
-
+
results = [None] * len(ipadapter_preprocessors)
-
+
for future in futures:
result = future.result()
- if result and result['embeddings'] is not None:
- index = result['index']
- embeddings = result['embeddings']
+ if result and result["embeddings"] is not None:
+ index = result["index"]
+ embeddings = result["embeddings"]
results[index] = embeddings
-
+
return results
-
- def _process_single_ipadapter(self,
- index: int,
- preprocessor: Any,
- control_variants: Dict[str, Any],
- stream_width: int,
- stream_height: int) -> Optional[Dict[str, Any]]:
+
+ def _process_single_ipadapter(
+ self, index: int, preprocessor: Any, control_variants: Dict[str, Any], stream_width: int, stream_height: int
+ ) -> Optional[Dict[str, Any]]:
"""Process a single IPAdapter preprocessor"""
try:
# Use tensor processing if available and input is tensor
- if (hasattr(preprocessor, 'process_tensor') and
- control_variants['tensor'] is not None):
- embeddings = preprocessor.process_tensor(control_variants['tensor'])
- return {
- 'index': index,
- 'embeddings': embeddings
- }
-
+ if hasattr(preprocessor, "process_tensor") and control_variants["tensor"] is not None:
+ embeddings = preprocessor.process_tensor(control_variants["tensor"])
+ return {"index": index, "embeddings": embeddings}
+
# Use PIL processing for non-tensor inputs
- if control_variants['image'] is not None:
- embeddings = preprocessor.process(control_variants['image'])
- return {
- 'index': index,
- 'embeddings': embeddings
- }
-
+ if control_variants["image"] is not None:
+ embeddings = preprocessor.process(control_variants["image"])
+ return {"index": index, "embeddings": embeddings}
+
return None
-
- except Exception as e:
+
+ except Exception:
import traceback
+
traceback.print_exc()
return None
- #Helper methods
+ # Helper methods
def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bool:
"""
Efficiently check for pipeline-aware preprocessors using caching
-
+
Only performs expensive isinstance checks when preprocessor list actually changes.
"""
# Create cache key from preprocessor identities
cache_key = tuple(id(p) for p in preprocessors)
-
+
# Return cached result if preprocessors haven't changed
if cache_key == self._preprocessors_cache_key:
return self._has_feedback_cache # Reuse cache variable for backward compatibility
-
+
# Preprocessors changed - recompute and cache
self._preprocessors_cache_key = cache_key
self._has_feedback_cache = False
-
+
try:
# Check for the mixin or class attribute first
for prep in preprocessors:
- if prep is not None and getattr(prep, 'requires_sync_processing', False):
+ if prep is not None and getattr(prep, "requires_sync_processing", False):
self._has_feedback_cache = True
break
except Exception:
@@ -490,6 +478,7 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo
try:
from .processors.feedback import FeedbackPreprocessor
from .processors.temporal_net import TemporalNetPreprocessor
+
for prep in preprocessors:
if isinstance(prep, (FeedbackPreprocessor, TemporalNetPreprocessor)):
self._has_feedback_cache = True
@@ -499,44 +488,43 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo
for prep in preprocessors:
if prep is not None:
class_name = prep.__class__.__name__.lower()
- if any(name in class_name for name in ['feedback', 'temporal']):
+ if any(name in class_name for name in ["feedback", "temporal"]):
self._has_feedback_cache = True
break
-
+
return self._has_feedback_cache
def clear_cache(self) -> None:
"""Clear preprocessing cache"""
self._preprocessed_cache.clear()
self._last_input_frame = None
-
+
# =========================================================================
# Pipeline Chain Processing Methods (For Hook System Compatibility)
# =========================================================================
-
- def execute_pipeline_chain(self,
- input_data: torch.Tensor,
- processors: List[Any],
- processing_domain: str = "image") -> torch.Tensor:
+
+ def execute_pipeline_chain(
+ self, input_data: torch.Tensor, processors: List[Any], processing_domain: str = "image"
+ ) -> torch.Tensor:
"""Execute ordered sequential chain of processors for pipeline hooks.
-
+
This method provides compatibility with the hook system modules that
expect sequential processor execution rather than pipelined processing.
-
+
Args:
input_data: Input tensor (image or latent domain)
processors: List of processor instances to execute in sequence
processing_domain: "image" or "latent" to determine processing path
-
+
Returns:
Processed tensor in same domain as input
"""
if not processors:
return input_data
-
+
result = input_data
ordered_processors = self._order_processors(processors)
-
+
for processor in ordered_processors:
try:
if processing_domain == "image":
@@ -549,58 +537,58 @@ def execute_pipeline_chain(self,
logger.error(f"execute_pipeline_chain: Processor {type(processor).__name__} failed: {e}")
# Continue with next processor rather than failing entire chain
continue
-
+
return result
-
+
def _order_processors(self, processors: List[Any]) -> List[Any]:
"""Order processors based on their configuration.
-
+
Processors can define an 'order' attribute to control execution sequence.
"""
- return sorted(processors, key=lambda p: getattr(p, 'order', 0))
-
+ return sorted(processors, key=lambda p: getattr(p, "order", 0))
+
def _process_image_processor_chain(self, image_tensor: torch.Tensor, processor: Any) -> torch.Tensor:
"""Process single image processor in chain, handling tensor<->PIL conversion.
-
+
Leverages existing format conversion and processing logic.
"""
# Convert tensor to PIL for processor (reuse existing conversion logic)
try:
# Use existing tensor to PIL conversion from prepare_control_image logic
pil_image = self._tensor_to_pil_safe(image_tensor)
-
+
# Process using existing processor execution pattern
- if hasattr(processor, 'process'):
+ if hasattr(processor, "process"):
processed_pil = processor.process(pil_image)
else:
processed_pil = processor(pil_image)
-
+
# Convert back to tensor (reuse existing PIL to tensor logic)
result_tensor = self._pil_to_tensor_safe(processed_pil, image_tensor.device, image_tensor.dtype)
return result_tensor
-
+
except Exception as e:
logger.error(f"_process_image_processor_chain: Failed processing {type(processor).__name__}: {e}")
return image_tensor # Return input unchanged on failure
-
+
def _process_latent_processor_chain(self, latent_tensor: torch.Tensor, processor: Any) -> torch.Tensor:
"""Process single latent processor in chain.
-
+
Direct tensor processing - no format conversion needed for latent domain.
"""
try:
# Latent processors work directly on tensors
- if hasattr(processor, 'process_tensor'):
+ if hasattr(processor, "process_tensor"):
return processor.process_tensor(latent_tensor)
- elif hasattr(processor, 'process'):
+ elif hasattr(processor, "process"):
return processor.process(latent_tensor)
else:
return processor(latent_tensor)
-
+
except Exception as e:
logger.error(f"_process_latent_processor_chain: Failed processing {type(processor).__name__}: {e}")
return latent_tensor # Return input unchanged on failure
-
+
def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image:
"""Convert tensor to PIL Image safely (reuse existing conversion logic)."""
# Leverage existing tensor conversion from prepare_control_image
@@ -609,50 +597,48 @@ def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image:
if tensor.dim() == 3 and tensor.shape[0] == 3:
# Convert from CHW to HWC
tensor = tensor.permute(1, 2, 0)
-
+
# CRITICAL FIX: Handle VAE output range [-1, 1] -> [0, 1] -> [0, 255]
# VAE decode_image() outputs in [-1, 1] range, need to convert to [0, 1] first
if tensor.min() < 0:
- logger.debug(f"_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]")
+ logger.debug("_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]")
tensor = (tensor / 2.0 + 0.5).clamp(0, 1) # Convert [-1, 1] -> [0, 1]
-
+
# Ensure proper range [0, 1] -> [0, 255]
if tensor.max() <= 1.0:
tensor = tensor * 255.0
-
+
# Convert to numpy and then PIL
numpy_image = tensor.detach().cpu().numpy().astype(np.uint8)
return Image.fromarray(numpy_image)
-
+
def _pil_to_tensor_safe(self, pil_image: Image.Image, device: str, dtype: torch.dtype) -> torch.Tensor:
"""Convert PIL Image to tensor safely (reuse existing conversion logic)."""
# Convert PIL to numpy
numpy_image = np.array(pil_image)
-
+
# Convert to tensor and normalize to [0, 1]
tensor = torch.from_numpy(numpy_image).float() / 255.0
-
+
# Convert HWC to CHW
if tensor.dim() == 3:
tensor = tensor.permute(2, 0, 1)
-
+
# Add batch dimension and move to device
tensor = tensor.unsqueeze(0).to(device=device, dtype=dtype)
-
+
# CRITICAL: Convert back to VAE input range [-1, 1] for postprocessing
# VAE expects inputs in [-1, 1] range, so convert [0, 1] -> [-1, 1]
tensor = (tensor - 0.5) * 2.0 # Convert [0, 1] -> [-1, 1]
-
+
return tensor
-
- def _process_tensor_input(self,
- control_tensor: torch.Tensor,
- preprocessor: Optional[Any],
- target_width: int,
- target_height: int) -> torch.Tensor:
+
+ def _process_tensor_input(
+ self, control_tensor: torch.Tensor, preprocessor: Optional[Any], target_width: int, target_height: int
+ ) -> torch.Tensor:
"""Process tensor input with GPU acceleration when possible"""
# Fast path for tensor input with GPU preprocessor
- if preprocessor is not None and hasattr(preprocessor, 'process_tensor'):
+ if preprocessor is not None and hasattr(preprocessor, "process_tensor"):
try:
processed_tensor = preprocessor.process_tensor(control_tensor)
# Ensure NCHW shape
@@ -661,155 +647,139 @@ def _process_tensor_input(self,
return processed_tensor.to(device=self.device, dtype=self.dtype)
except Exception:
pass # Fall through to standard processing
-
+
# Direct tensor passthrough (no preprocessor) - preprocessors handle their own sizing
if preprocessor is None:
# For passthrough, we still need basic format handling
if control_tensor.dim() == 3:
control_tensor = control_tensor.unsqueeze(0)
return control_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Convert to PIL for preprocessor, then back to tensor
if control_tensor.dim() == 4:
control_tensor = control_tensor[0]
if control_tensor.dim() == 3 and control_tensor.shape[0] in [1, 3]:
control_tensor = control_tensor.permute(1, 2, 0)
-
+
if control_tensor.is_cuda:
control_tensor = control_tensor.cpu()
-
+
control_array = control_tensor.numpy()
if control_array.max() <= 1.0:
control_array = (control_array * 255).astype(np.uint8)
-
+
control_image = Image.fromarray(control_array.astype(np.uint8))
return self.prepare_control_image(control_image, preprocessor, target_width, target_height)
-
- def _convert_to_tensor(self,
- control_image: Union[Image.Image, np.ndarray],
- target_width: int,
- target_height: int) -> torch.Tensor:
+
+ def _convert_to_tensor(
+ self, control_image: Union[Image.Image, np.ndarray], target_width: int, target_height: int
+ ) -> torch.Tensor:
"""Convert PIL Image or numpy array to tensor - preprocessors handle their own sizing"""
# Handle PIL Images - no resizing here, preprocessors handle their target size
if isinstance(control_image, Image.Image):
control_tensor = self._cached_transform(control_image).unsqueeze(0)
return control_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Handle numpy arrays
if isinstance(control_image, np.ndarray):
if control_image.max() <= 1.0:
control_image = (control_image * 255).astype(np.uint8)
control_image = Image.fromarray(control_image)
return self._convert_to_tensor(control_image, target_width, target_height)
-
+
raise ValueError(f"Unsupported control image type: {type(control_image)}")
-
+
def _to_tensor_safe(self, image: Image.Image) -> torch.Tensor:
"""Thread-safe tensor conversion from PIL Image"""
return self._cached_transform(image).unsqueeze(0).to(device=self.device, dtype=self.dtype)
-
- def _prepare_input_variants(self,
- control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
- stream_width: int = None,
- stream_height: int = None,
- thread_safe: bool = False) -> Dict[str, Any]:
+
+ def _prepare_input_variants(
+ self,
+ control_image: Union[str, Image.Image, np.ndarray, torch.Tensor],
+ stream_width: int = None,
+ stream_height: int = None,
+ thread_safe: bool = False,
+ ) -> Dict[str, Any]:
"""Prepare optimized input variants for different processing paths
-
+
Args:
control_image: Input image in various formats
stream_width: Target width (unused, kept for backward compatibility)
stream_height: Target height (unused, kept for backward compatibility)
thread_safe: If True, use thread-safe key naming for background processing
-
+
Returns:
Dictionary with 'tensor' and 'image'/'image_safe' keys
"""
- image_key = 'image_safe' if thread_safe else 'image'
-
+ image_key = "image_safe" if thread_safe else "image"
+
if isinstance(control_image, torch.Tensor):
return {
- 'tensor': control_image,
- image_key: None # Will create if needed
+ "tensor": control_image,
+ image_key: None, # Will create if needed
}
elif isinstance(control_image, Image.Image):
image_copy = control_image.copy()
- return {
- image_key: image_copy,
- 'tensor': self._to_tensor_safe(image_copy)
- }
+ return {image_key: image_copy, "tensor": self._to_tensor_safe(image_copy)}
elif isinstance(control_image, str):
image_loaded = load_image(control_image)
- return {
- image_key: image_loaded,
- 'tensor': self._to_tensor_safe(image_loaded)
- }
+ return {image_key: image_loaded, "tensor": self._to_tensor_safe(image_loaded)}
else:
- return {
- image_key: control_image,
- 'tensor': None
- }
-
- def _group_preprocessors(self,
- preprocessors: List[Optional[Any]],
- scales: List[float]) -> Dict[str, Dict[str, Any]]:
+ return {image_key: control_image, "tensor": None}
+
+ def _group_preprocessors(
+ self, preprocessors: List[Optional[Any]], scales: List[float]
+ ) -> Dict[str, Dict[str, Any]]:
"""Group preprocessors by type to avoid duplicate processing"""
preprocessor_groups = {}
-
+
for i, scale in enumerate(scales):
if scale > 0:
preprocessor = preprocessors[i]
- preprocessor_key = id(preprocessor) if preprocessor is not None else 'passthrough'
-
+ preprocessor_key = id(preprocessor) if preprocessor is not None else "passthrough"
+
if preprocessor_key not in preprocessor_groups:
- preprocessor_groups[preprocessor_key] = {
- 'preprocessor': preprocessor,
- 'indices': []
- }
- preprocessor_groups[preprocessor_key]['indices'].append(i)
-
+ preprocessor_groups[preprocessor_key] = {"preprocessor": preprocessor, "indices": []}
+ preprocessor_groups[preprocessor_key]["indices"].append(i)
+
return preprocessor_groups
- def _process_single_preprocessor_group(self,
- prep_key: str,
- group: Dict[str, Any],
- control_variants: Dict[str, Any],
- stream_width: int,
- stream_height: int) -> Optional[Dict[str, Any]]:
+ def _process_single_preprocessor_group(
+ self,
+ prep_key: str,
+ group: Dict[str, Any],
+ control_variants: Dict[str, Any],
+ stream_width: int,
+ stream_height: int,
+ ) -> Optional[Dict[str, Any]]:
"""Process a single preprocessor group with optimal input selection"""
try:
- preprocessor = group['preprocessor']
- indices = group['indices']
-
+ preprocessor = group["preprocessor"]
+ indices = group["indices"]
+
# Try tensor processing first (fastest path)
- if (preprocessor is not None and
- hasattr(preprocessor, 'process_tensor') and
- control_variants['tensor'] is not None):
+ if (
+ preprocessor is not None
+ and hasattr(preprocessor, "process_tensor")
+ and control_variants["tensor"] is not None
+ ):
try:
processed_image = self.prepare_control_image(
- control_variants['tensor'], preprocessor, stream_width, stream_height
+ control_variants["tensor"], preprocessor, stream_width, stream_height
)
- return {
- 'prep_key': prep_key,
- 'indices': indices,
- 'processed_image': processed_image
- }
+ return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image}
except Exception:
pass # Fall through to PIL processing
-
+
# PIL processing fallback
- if control_variants['image'] is not None:
+ if control_variants["image"] is not None:
processed_image = self.prepare_control_image(
- control_variants['image'], preprocessor, stream_width, stream_height
+ control_variants["image"], preprocessor, stream_width, stream_height
)
- return {
- 'prep_key': prep_key,
- 'indices': indices,
- 'processed_image': processed_image
- }
-
+ return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image}
+
return None
-
+
except Exception as e:
logger.error(f"PreprocessingOrchestrator: Preprocessor {prep_key} failed: {e}")
return None
-
diff --git a/src/streamdiffusion/preprocessing/processors/__init__.py b/src/streamdiffusion/preprocessing/processors/__init__.py
index 5674ab4af..26f00e732 100644
--- a/src/streamdiffusion/preprocessing/processors/__init__.py
+++ b/src/streamdiffusion/preprocessing/processors/__init__.py
@@ -1,26 +1,29 @@
-from .base import BasePreprocessor, PipelineAwareProcessor
from typing import Any
+
+from .base import BasePreprocessor, PipelineAwareProcessor
+from .blur import BlurPreprocessor
from .canny import CannyPreprocessor
from .depth import DepthPreprocessor
-from .openpose import OpenPosePreprocessor
-from .lineart import LineartPreprocessor
-from .standard_lineart import StandardLineartPreprocessor
-from .passthrough import PassthroughPreprocessor
from .external import ExternalPreprocessor
-from .soft_edge import SoftEdgePreprocessor
-from .hed import HEDPreprocessor
-from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor
from .faceid_embedding import FaceIDEmbeddingPreprocessor
from .feedback import FeedbackPreprocessor
+from .hed import HEDPreprocessor
+from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor
from .latent_feedback import LatentFeedbackPreprocessor
+from .lineart import LineartPreprocessor
+from .openpose import OpenPosePreprocessor
+from .passthrough import PassthroughPreprocessor
+from .realesrgan_trt import RealESRGANProcessor
from .sharpen import SharpenPreprocessor
+from .soft_edge import SoftEdgePreprocessor
+from .standard_lineart import StandardLineartPreprocessor
from .upscale import UpscalePreprocessor
-from .blur import BlurPreprocessor
-from .realesrgan_trt import RealESRGANProcessor
+
# Try to import TensorRT preprocessors - might not be available on all systems
try:
from .depth_tensorrt import DepthAnythingTensorrtPreprocessor
+
DEPTH_TENSORRT_AVAILABLE = True
except ImportError:
DepthAnythingTensorrtPreprocessor = None
@@ -28,6 +31,7 @@
try:
from .pose_tensorrt import YoloNasPoseTensorrtPreprocessor
+
POSE_TENSORRT_AVAILABLE = True
except ImportError:
YoloNasPoseTensorrtPreprocessor = None
@@ -35,6 +39,7 @@
try:
from .temporal_net_tensorrt import TemporalNetTensorRTPreprocessor
+
TEMPORAL_NET_TENSORRT_AVAILABLE = True
except ImportError:
TemporalNetTensorRTPreprocessor = None
@@ -42,6 +47,7 @@
try:
from .mediapipe_pose import MediaPipePosePreprocessor
+
MEDIAPIPE_POSE_AVAILABLE = True
except ImportError:
MediaPipePosePreprocessor = None
@@ -49,6 +55,7 @@
try:
from .mediapipe_segmentation import MediaPipeSegmentationPreprocessor
+
MEDIAPIPE_SEGMENTATION_AVAILABLE = True
except ImportError:
MediaPipeSegmentationPreprocessor = None
@@ -71,7 +78,7 @@
"upscale": UpscalePreprocessor,
"blur": BlurPreprocessor,
"realesrgan_trt": RealESRGANProcessor,
-}
+}
# Add TensorRT preprocessors if available
if DEPTH_TENSORRT_AVAILABLE:
@@ -94,27 +101,29 @@
def get_preprocessor_class(name: str) -> type:
"""
Get a preprocessor class by name
-
+
Args:
name: Name of the preprocessor
-
+
Returns:
Preprocessor class
-
+
Raises:
ValueError: If preprocessor name is not found
"""
if name not in _preprocessor_registry:
available = ", ".join(_preprocessor_registry.keys())
raise ValueError(f"Unknown preprocessor '{name}'. Available: {available}")
-
+
return _preprocessor_registry[name]
-def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: str = 'controlnet', params: Any = None) -> BasePreprocessor:
+def get_preprocessor(
+ name: str, pipeline_ref: Any = None, normalization_context: str = "controlnet", params: Any = None
+) -> BasePreprocessor:
"""
Get a preprocessor by name
-
+
Args:
name: Name of the preprocessor
pipeline_ref: Pipeline reference for pipeline-aware processors (required for some processors)
@@ -122,20 +131,25 @@ def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context:
- 'controlnet': Expects/produces [0,1] range for ControlNet conditioning
- 'pipeline': Expects/produces [-1,1] range for pipeline image processing
- 'latent': Works in latent space (no normalization needed)
-
+
Returns:
Preprocessor instance
-
+
Raises:
ValueError: If preprocessor name is not found or pipeline_ref missing for pipeline-aware processor
"""
processor_class = get_preprocessor_class(name)
-
+
# Check if this is a pipeline-aware processor
- if hasattr(processor_class, 'requires_sync_processing') and processor_class.requires_sync_processing:
+ if hasattr(processor_class, "requires_sync_processing") and processor_class.requires_sync_processing:
if pipeline_ref is None:
raise ValueError(f"Processor '{name}' requires a pipeline_ref")
- return processor_class(pipeline_ref=pipeline_ref, normalization_context=normalization_context, _registry_name=name, **(params or {}))
+ return processor_class(
+ pipeline_ref=pipeline_ref,
+ normalization_context=normalization_context,
+ _registry_name=name,
+ **(params or {}),
+ )
else:
return processor_class(normalization_context=normalization_context, _registry_name=name, **(params or {}))
@@ -143,7 +157,7 @@ def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context:
def register_preprocessor(name: str, preprocessor_class):
"""
Register a new preprocessor
-
+
Args:
name: Name to register under
preprocessor_class: Preprocessor class
@@ -160,7 +174,7 @@ def list_preprocessors():
"BasePreprocessor",
"PipelineAwareProcessor",
"CannyPreprocessor",
- "DepthPreprocessor",
+ "DepthPreprocessor",
"OpenPosePreprocessor",
"LineartPreprocessor",
"StandardLineartPreprocessor",
@@ -195,14 +209,16 @@ def list_preprocessors():
# region Custom Processor Discovery
-import logging
-import os
import importlib.util
import inspect
+import logging
+import os
from pathlib import Path
+
_logger = logging.getLogger(__name__)
+
def _discover_custom_processors():
"""Auto-discover custom processors from repo_root/custom_processors/ folder."""
if os.getenv("STREAMDIFFUSION_DISABLE_CUSTOM_PROCESSORS") == "1":
@@ -216,7 +232,7 @@ def _discover_custom_processors():
return
_logger.info("Scanning custom_processors/ for custom processors...")
for item in custom_dir.iterdir():
- if not item.is_dir() or item.name.startswith(('.', '_')):
+ if not item.is_dir() or item.name.startswith((".", "_")):
continue
manifest_file = item / "processors.yaml"
if manifest_file.exists():
@@ -226,20 +242,22 @@ def _discover_custom_processors():
except Exception as e:
_logger.error(f"Custom processor discovery failed: {e}")
+
def _load_processor_collection(collection_dir, manifest_file):
"""Load processors from a collection with processors.yaml manifest."""
import yaml
+
try:
- with open(manifest_file, 'r') as f:
+ with open(manifest_file, "r") as f:
manifest = yaml.safe_load(f)
- processor_files = manifest.get('processors', [])
+ processor_files = manifest.get("processors", [])
if not processor_files:
_logger.warning(f"Collection '{collection_dir.name}' has empty processors list")
return
_logger.info(f"Loading collection '{collection_dir.name}' ({len(processor_files)} processors)")
for proc_file in processor_files:
if isinstance(proc_file, dict):
- filename, enabled = proc_file.get('file'), proc_file.get('enabled', True)
+ filename, enabled = proc_file.get("file"), proc_file.get("enabled", True)
if not enabled:
continue
else:
@@ -252,24 +270,28 @@ def _load_processor_collection(collection_dir, manifest_file):
except Exception as e:
_logger.error(f"Failed to load collection {collection_dir.name}: {e}")
+
def _load_processor_folder_auto(folder):
"""Auto-discover processors by scanning for .py files (no manifest)."""
_logger.info(f"Auto-scanning folder: {folder.name}")
for py_file in folder.glob("*.py"):
- if py_file.name.startswith('_') or py_file.name in ['base.py', 'setup.py']:
+ if py_file.name.startswith("_") or py_file.name in ["base.py", "setup.py"]:
continue
_load_processor_from_file(py_file, py_file.stem)
+
def _load_processor_from_file(file_path, proc_name):
"""Load and register a processor class from a Python file."""
try:
spec = importlib.util.spec_from_file_location(
- f"custom_processors.{file_path.parent.name}.{file_path.stem}", file_path)
+ f"custom_processors.{file_path.parent.name}.{file_path.stem}", file_path
+ )
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
found_classes = [
- (name, obj) for name, obj in inspect.getmembers(module, inspect.isclass)
+ (name, obj)
+ for name, obj in inspect.getmembers(module, inspect.isclass)
if issubclass(obj, (BasePreprocessor, PipelineAwareProcessor))
and obj not in [BasePreprocessor, PipelineAwareProcessor]
]
@@ -284,5 +306,6 @@ def _load_processor_from_file(file_path, proc_name):
except Exception as e:
_logger.error(f" Failed to load {file_path.name}: {e}")
+
_discover_custom_processors()
-# endregion
\ No newline at end of file
+# endregion
diff --git a/src/streamdiffusion/preprocessing/processors/base.py b/src/streamdiffusion/preprocessing/processors/base.py
index 218a459f5..155dfddde 100644
--- a/src/streamdiffusion/preprocessing/processors/base.py
+++ b/src/streamdiffusion/preprocessing/processors/base.py
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
-from typing import Union, Dict, Any, Tuple, Optional
+from typing import Any, Dict, Tuple, Union
+
+import numpy as np
import torch
import torch.nn.functional as F
-import numpy as np
from PIL import Image
@@ -10,12 +11,11 @@ class BasePreprocessor(ABC):
"""
Base class for ControlNet preprocessors with template method pattern
"""
-
-
- def __init__(self, normalization_context: str = 'controlnet', **kwargs):
+
+ def __init__(self, normalization_context: str = "controlnet", **kwargs):
"""
Initialize the preprocessor
-
+
Args:
normalization_context: Context for normalization handling.
- 'controlnet': Expects/produces [0,1] range for ControlNet conditioning
@@ -25,15 +25,15 @@ def __init__(self, normalization_context: str = 'controlnet', **kwargs):
"""
self.params = kwargs
self.normalization_context = normalization_context
- self.device = kwargs.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
- self.dtype = kwargs.get('dtype', torch.float16)
-
+ self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
+ self.dtype = kwargs.get("dtype", torch.float16)
+
@classmethod
def get_preprocessor_metadata(cls) -> Dict[str, Any]:
"""
Get comprehensive metadata for this preprocessor.
Subclasses should override this to define their specific metadata.
-
+
Returns:
Dictionary containing:
- display_name: Human-readable name
@@ -45,9 +45,9 @@ def get_preprocessor_metadata(cls) -> Dict[str, Any]:
"display_name": cls.__name__.replace("Preprocessor", ""),
"description": f"Preprocessor for {cls.__name__.replace('Preprocessor', '').lower()}",
"parameters": {},
- "use_cases": []
+ "use_cases": [],
}
-
+
def process(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
"""
Template method - handles all common operations
@@ -55,7 +55,7 @@ def process(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.
image = self.validate_input(image)
processed = self._process_core(image)
return self._ensure_target_size(processed)
-
+
def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Template method for GPU tensor processing
@@ -63,14 +63,14 @@ def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor:
tensor = self.validate_tensor_input(image_tensor)
processed = self._process_tensor_core(tensor)
return self._ensure_target_size_tensor(processed)
-
+
@abstractmethod
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Subclasses implement ONLY their specific algorithm
"""
pass
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Optional GPU processing (fallback to PIL if not overridden)
@@ -78,7 +78,7 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
pil_image = self.tensor_to_pil(tensor)
processed_pil = self._process_core(pil_image)
return self.pil_to_tensor(processed_pil)
-
+
def _ensure_target_size(self, image: Image.Image) -> Image.Image:
"""
Centralized PIL resize logic
@@ -87,7 +87,7 @@ def _ensure_target_size(self, image: Image.Image) -> Image.Image:
if image.size != (target_width, target_height):
return image.resize((target_width, target_height), Image.LANCZOS)
return image
-
+
def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Centralized tensor resize logic
@@ -95,54 +95,54 @@ def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
target_width, target_height = self.get_target_dimensions()
current_size = tensor.shape[-2:]
target_size = (target_height, target_width)
-
+
if current_size != target_size:
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
- tensor = F.interpolate(tensor, size=target_size, mode='bilinear', align_corners=False)
+ tensor = F.interpolate(tensor, size=target_size, mode="bilinear", align_corners=False)
if tensor.shape[0] == 1:
tensor = tensor.squeeze(0)
return tensor
-
+
def validate_tensor_input(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Validate and normalize tensor input for processing
-
+
Args:
image_tensor: Input tensor
-
+
Returns:
Tensor in CHW format, on correct device
Range: [0,1] if input was [0,255], otherwise preserves input range
-
+
Note: This preserves [-1,1] tensors (from pipeline) since max() <= 1.0
"""
# Handle batch dimension
if image_tensor.dim() == 4:
image_tensor = image_tensor[0] # Take first image from batch
-
+
# Convert to CHW format if needed
if image_tensor.dim() == 3 and image_tensor.shape[0] not in [1, 3]:
# Likely HWC format, convert to CHW
image_tensor = image_tensor.permute(2, 0, 1)
-
+
# Ensure correct device and dtype
image_tensor = image_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Normalize to [0,1] range only if tensor is in [0,255] uint8 range
# Preserves [-1,1] and [0,1] ranges (max <= 1.0)
if image_tensor.max() > 1.0:
image_tensor = image_tensor / 255.0
-
+
return image_tensor
-
+
def tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
"""
Convert tensor to PIL Image (minimize CPU transfers)
-
+
Args:
tensor: Input tensor
-
+
Returns:
PIL Image
"""
@@ -151,39 +151,39 @@ def tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image:
tensor = tensor[0]
if tensor.dim() == 3 and tensor.shape[0] in [1, 3]:
tensor = tensor.permute(1, 2, 0)
-
+
# Convert to numpy (unavoidable for PIL)
if tensor.is_cuda:
tensor = tensor.cpu()
-
+
# Convert to uint8
if tensor.max() <= 1.0:
tensor = (tensor * 255).clamp(0, 255).to(torch.uint8)
else:
tensor = tensor.clamp(0, 255).to(torch.uint8)
-
+
array = tensor.numpy()
-
+
if array.shape[-1] == 3:
- return Image.fromarray(array, 'RGB')
+ return Image.fromarray(array, "RGB")
elif array.shape[-1] == 1:
- return Image.fromarray(array.squeeze(-1), 'L')
+ return Image.fromarray(array.squeeze(-1), "L")
else:
return Image.fromarray(array)
-
+
def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
"""
Convert PIL Image to tensor on GPU
-
+
Args:
image: PIL Image
-
+
Returns:
Tensor on correct device
"""
# Convert to numpy first
array = np.array(image)
-
+
# Convert to tensor
if len(array.shape) == 2: # Grayscale
tensor = torch.from_numpy(array).float() / 255.0
@@ -191,25 +191,25 @@ def pil_to_tensor(self, image: Image.Image) -> torch.Tensor:
else: # RGB
tensor = torch.from_numpy(array).float() / 255.0
tensor = tensor.permute(2, 0, 1) # HWC to CHW
-
+
# Move to device
tensor = tensor.to(device=self.device, dtype=self.dtype)
return tensor.unsqueeze(0) # Add batch dimension
-
+
def validate_input(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.Image:
"""
Convert input to PIL Image for processing
-
+
Args:
image: Input image in various formats
-
+
Returns:
PIL Image
"""
if isinstance(image, torch.Tensor):
# Use tensor_to_pil method for better handling
return self.tensor_to_pil(image)
-
+
if isinstance(image, np.ndarray):
# Ensure uint8 format
if image.dtype != np.uint8:
@@ -217,83 +217,83 @@ def validate_input(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) ->
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
-
+
# Convert to PIL Image
if len(image.shape) == 3:
- image = Image.fromarray(image, 'RGB')
+ image = Image.fromarray(image, "RGB")
else:
- image = Image.fromarray(image, 'L')
-
+ image = Image.fromarray(image, "L")
+
if not isinstance(image, Image.Image):
raise ValueError(f"Unsupported image type: {type(image)}")
-
+
return image
-
+
def get_target_dimensions(self) -> Tuple[int, int]:
"""
Get target output dimensions (width, height)
"""
# Check for explicit width/height parameters first
- width = self.params.get('image_width')
- height = self.params.get('image_height')
-
+ width = self.params.get("image_width")
+ height = self.params.get("image_height")
+
if width is not None and height is not None:
return (width, height)
-
+
# Fallback to square resolution for backwards compatibility
- resolution = self.params.get('image_resolution', 512)
+ resolution = self.params.get("image_resolution", 512)
return (resolution, resolution)
-
+
def __call__(self, image: Union[Image.Image, np.ndarray, torch.Tensor], **kwargs) -> Image.Image:
"""
Process an image (convenience method)
-
+
Args:
image: Input image
**kwargs: Additional parameters to override defaults
-
+
Returns:
Processed PIL Image
"""
# Update parameters for this call
params = {**self.params, **kwargs}
-
+
# Store original params and update
original_params = self.params
self.params = params
-
+
try:
result = self.process(image)
finally:
# Restore original params
self.params = original_params
-
+
return result
class PipelineAwareProcessor(BasePreprocessor):
"""
Abstract base class for processors that need access to pipeline state (previous outputs).
-
- This base class marks processors as requiring synchronous processing to avoid
+
+ This base class marks processors as requiring synchronous processing to avoid
temporal artifacts and ensures they have access to pipeline references.
-
+
Usage:
class MyProcessor(PipelineAwareProcessor):
pass
-
+
Examples:
- FeedbackPreprocessor: Needs previous diffusion output
- TemporalNetPreprocessor: Needs previous frame for optical flow
"""
-
+
# Class attribute to mark processors as requiring sync processing
requires_sync_processing = True
-
- def __init__(self, pipeline_ref: Any, normalization_context: str = 'controlnet', **kwargs):
+
+ def __init__(self, pipeline_ref: Any, normalization_context: str = "controlnet", **kwargs):
"""
Initialize pipeline-aware functionality
-
+
Args:
pipeline_ref: Reference to the StreamDiffusion pipeline instance (required)
normalization_context: Context for normalization handling
@@ -302,4 +302,4 @@ def __init__(self, pipeline_ref: Any, normalization_context: str = 'controlnet',
if pipeline_ref is None:
raise ValueError(f"{self.__class__.__name__} requires a pipeline_ref")
super().__init__(normalization_context=normalization_context, **kwargs)
- self.pipeline_ref = pipeline_ref
\ No newline at end of file
+ self.pipeline_ref = pipeline_ref
diff --git a/src/streamdiffusion/preprocessing/processors/blur.py b/src/streamdiffusion/preprocessing/processors/blur.py
index 2e12c8e6a..694d9eb06 100644
--- a/src/streamdiffusion/preprocessing/processors/blur.py
+++ b/src/streamdiffusion/preprocessing/processors/blur.py
@@ -1,19 +1,18 @@
import torch
import torch.nn.functional as F
-import numpy as np
from PIL import Image
-from typing import Union
+
from .base import BasePreprocessor
class BlurPreprocessor(BasePreprocessor):
"""
Gaussian blur preprocessor for ControlNet
-
+
Applies Gaussian blur to the input image using GPU-accelerated operations.
Useful for creating soft, dreamy effects or reducing image detail.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -24,22 +23,22 @@ def get_preprocessor_metadata(cls):
"type": "float",
"default": 2.0,
"range": [0.1, 10.0],
- "description": "Intensity of the blur effect. Higher values create stronger blur."
+ "description": "Intensity of the blur effect. Higher values create stronger blur.",
},
"kernel_size": {
"type": "int",
"default": 15,
"range": [3, 51],
- "description": "Size of the blur kernel. Must be odd. Larger values create smoother blur."
- }
+ "description": "Size of the blur kernel. Must be odd. Larger values create smoother blur.",
+ },
},
- "use_cases": ["Soft focus effects", "Background blur", "Artistic rendering", "Detail reduction"]
+ "use_cases": ["Soft focus effects", "Background blur", "Artistic rendering", "Detail reduction"],
}
-
+
def __init__(self, blur_intensity: float = 2.0, kernel_size: int = 15, **kwargs):
"""
Initialize Blur preprocessor
-
+
Args:
blur_intensity: Standard deviation for Gaussian kernel (higher = more blur)
kernel_size: Size of the blur kernel (must be odd)
@@ -48,58 +47,55 @@ def __init__(self, blur_intensity: float = 2.0, kernel_size: int = 15, **kwargs)
# Ensure kernel_size is odd
if kernel_size % 2 == 0:
kernel_size += 1
-
- super().__init__(
- blur_intensity=blur_intensity,
- kernel_size=kernel_size,
- **kwargs
- )
-
+
+ super().__init__(blur_intensity=blur_intensity, kernel_size=kernel_size, **kwargs)
+
# Cache the Gaussian kernel for efficiency
self._cached_kernel = None
self._cached_kernel_size = None
self._cached_intensity = None
-
+
def _create_gaussian_kernel(self, kernel_size: int, intensity: float) -> torch.Tensor:
"""
Create a 2D Gaussian kernel for blurring
-
+
Args:
kernel_size: Size of the kernel (must be odd)
intensity: Standard deviation of the Gaussian
-
+
Returns:
2D Gaussian kernel tensor
"""
# Create coordinate grids
coords = torch.arange(kernel_size, dtype=self.dtype, device=self.device)
coords = coords - (kernel_size - 1) / 2
-
+
# Create 2D coordinate grids
- y_grid, x_grid = torch.meshgrid(coords, coords, indexing='ij')
-
+ y_grid, x_grid = torch.meshgrid(coords, coords, indexing="ij")
+
# Calculate Gaussian values
gaussian = torch.exp(-(x_grid**2 + y_grid**2) / (2 * intensity**2))
-
+
# Normalize to sum to 1
gaussian = gaussian / gaussian.sum()
-
+
return gaussian
-
+
def _get_gaussian_kernel(self, kernel_size: int, intensity: float) -> torch.Tensor:
"""
Get cached Gaussian kernel or create new one
"""
- if (self._cached_kernel is None or
- self._cached_kernel_size != kernel_size or
- self._cached_intensity != intensity):
-
+ if (
+ self._cached_kernel is None
+ or self._cached_kernel_size != kernel_size
+ or self._cached_intensity != intensity
+ ):
self._cached_kernel = self._create_gaussian_kernel(kernel_size, intensity)
self._cached_kernel_size = kernel_size
self._cached_intensity = intensity
-
+
return self._cached_kernel
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply Gaussian blur to the input image using PIL/numpy fallback
@@ -107,46 +103,46 @@ def _process_core(self, image: Image.Image) -> Image.Image:
# Convert to tensor for processing
tensor = self.pil_to_tensor(image)
tensor = tensor.squeeze(0) # Remove batch dimension
-
+
# Process on GPU
blurred = self._process_tensor_core(tensor)
-
+
# Convert back to PIL
return self.tensor_to_pil(blurred)
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU for Gaussian blur
"""
- blur_intensity = self.params.get('blur_intensity', 2.0)
- kernel_size = self.params.get('kernel_size', 15)
-
+ blur_intensity = self.params.get("blur_intensity", 2.0)
+ kernel_size = self.params.get("kernel_size", 15)
+
# Ensure kernel_size is odd
if kernel_size % 2 == 0:
kernel_size += 1
-
+
# Get the Gaussian kernel
kernel = self._get_gaussian_kernel(kernel_size, blur_intensity)
-
+
# Ensure tensor has batch dimension
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
-
+
# Ensure tensor is on the correct device and dtype
image_tensor = image_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Reshape kernel for conv2d: (out_channels, in_channels/groups, H, W)
# We'll apply the same kernel to each channel separately
num_channels = image_tensor.shape[1]
kernel_conv = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1)
-
+
# Apply Gaussian blur using conv2d with groups=num_channels for per-channel convolution
padding = kernel_size // 2
blurred = F.conv2d(
image_tensor,
kernel_conv,
padding=padding,
- groups=num_channels # Apply kernel separately to each channel
+ groups=num_channels, # Apply kernel separately to each channel
)
-
+
return blurred
diff --git a/src/streamdiffusion/preprocessing/processors/canny.py b/src/streamdiffusion/preprocessing/processors/canny.py
index 7c25e9ab4..90e47241e 100644
--- a/src/streamdiffusion/preprocessing/processors/canny.py
+++ b/src/streamdiffusion/preprocessing/processors/canny.py
@@ -1,18 +1,19 @@
import cv2
import numpy as np
-from PIL import Image
import torch
-from typing import Union
+from PIL import Image
+
from .base import BasePreprocessor
-#TODO provide gpu native edge detection
+
+# TODO provide gpu native edge detection
class CannyPreprocessor(BasePreprocessor):
"""
Canny edge detection preprocessor for ControlNet
-
+
Detects edges in the input image using the Canny edge detection algorithm.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -23,52 +24,48 @@ def get_preprocessor_metadata(cls):
"type": "int",
"default": 100,
"range": [1, 255],
- "description": "Lower threshold for edge detection. Lower values detect more edges."
+ "description": "Lower threshold for edge detection. Lower values detect more edges.",
},
"high_threshold": {
- "type": "int",
+ "type": "int",
"default": 200,
"range": [1, 255],
- "description": "Upper threshold for edge detection. Higher values are more selective."
- }
+ "description": "Upper threshold for edge detection. Higher values are more selective.",
+ },
},
- "use_cases": ["Line art", "Architecture", "Technical drawings", "Clean edge detection"]
+ "use_cases": ["Line art", "Architecture", "Technical drawings", "Clean edge detection"],
}
-
+
def __init__(self, low_threshold: int = 100, high_threshold: int = 200, **kwargs):
"""
Initialize Canny preprocessor
-
+
Args:
low_threshold: Lower threshold for edge detection
high_threshold: Upper threshold for edge detection
**kwargs: Additional parameters
"""
- super().__init__(
- low_threshold=low_threshold,
- high_threshold=high_threshold,
- **kwargs
- )
-
+ super().__init__(low_threshold=low_threshold, high_threshold=high_threshold, **kwargs)
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply Canny edge detection to the input image
"""
image_np = np.array(image)
-
+
if len(image_np.shape) == 3:
gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
else:
gray = image_np
-
- low_threshold = self.params.get('low_threshold', 100)
- high_threshold = self.params.get('high_threshold', 200)
-
+
+ low_threshold = self.params.get("low_threshold", 100)
+ high_threshold = self.params.get("high_threshold", 200)
+
edges = cv2.Canny(gray, low_threshold, high_threshold)
edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
-
+
return Image.fromarray(edges_rgb)
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU for Canny edge detection
@@ -77,18 +74,18 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
gray_tensor = 0.299 * image_tensor[0] + 0.587 * image_tensor[1] + 0.114 * image_tensor[2]
else:
gray_tensor = image_tensor[0] if image_tensor.shape[0] == 1 else image_tensor
-
+
gray_cpu = gray_tensor.cpu()
gray_np = (gray_cpu * 255).clamp(0, 255).to(torch.uint8).numpy()
-
- low_threshold = self.params.get('low_threshold', 100)
- high_threshold = self.params.get('high_threshold', 200)
-
+
+ low_threshold = self.params.get("low_threshold", 100)
+ high_threshold = self.params.get("high_threshold", 200)
+
edges = cv2.Canny(gray_np, low_threshold, high_threshold)
-
+
edges_tensor = torch.from_numpy(edges).float() / 255.0
edges_tensor = edges_tensor.to(device=self.device, dtype=self.dtype)
-
+
edges_rgb = edges_tensor.unsqueeze(0).repeat(3, 1, 1)
-
- return edges_rgb
\ No newline at end of file
+
+ return edges_rgb
diff --git a/src/streamdiffusion/preprocessing/processors/depth.py b/src/streamdiffusion/preprocessing/processors/depth.py
index fbf57dc83..b7287ffa0 100644
--- a/src/streamdiffusion/preprocessing/processors/depth.py
+++ b/src/streamdiffusion/preprocessing/processors/depth.py
@@ -1,12 +1,14 @@
import numpy as np
-from PIL import Image
import torch
-from typing import Union, Optional
+from PIL import Image
+
from .base import BasePreprocessor
+
try:
import torch
from transformers import pipeline
+
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
@@ -15,29 +17,25 @@
class DepthPreprocessor(BasePreprocessor):
"""
Depth estimation preprocessor for ControlNet using MiDaS
-
+
Estimates depth maps from input images using the MiDaS depth estimation model.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Depth Estimation",
"description": "Estimates depth from the input image using MiDaS. Good for adding depth-based control to generation.",
- "parameters": {
-
- },
- "use_cases": ["3D-aware generation", "Depth preservation", "Scene understanding"]
+ "parameters": {},
+ "use_cases": ["3D-aware generation", "Depth preservation", "Scene understanding"],
}
-
- def __init__(self,
- model_name: str = "Intel/dpt-large",
- detect_resolution: int = 512,
- image_resolution: int = 512,
- **kwargs):
+
+ def __init__(
+ self, model_name: str = "Intel/dpt-large", detect_resolution: int = 512, image_resolution: int = 512, **kwargs
+ ):
"""
Initialize depth preprocessor
-
+
Args:
model_name: Name of the depth estimation model to use
detect_resolution: Resolution for depth detection
@@ -46,102 +44,94 @@ def __init__(self,
"""
if not TRANSFORMERS_AVAILABLE:
raise ImportError(
- "transformers library is required for depth preprocessing. "
- "Install it with: pip install transformers"
+ "transformers library is required for depth preprocessing. Install it with: pip install transformers"
)
-
+
super().__init__(
- model_name=model_name,
- detect_resolution=detect_resolution,
- image_resolution=image_resolution,
- **kwargs
+ model_name=model_name, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs
)
-
+
self._depth_estimator = None
-
+
@property
def depth_estimator(self):
"""Lazy loading of the depth estimation model"""
if self._depth_estimator is None:
- model_name = self.params.get('model_name', 'Intel/dpt-large')
+ model_name = self.params.get("model_name", "Intel/dpt-large")
print(f"Loading depth estimation model: {model_name}")
self._depth_estimator = pipeline(
- 'depth-estimation',
- model=model_name,
- device=0 if torch.cuda.is_available() else -1
+ "depth-estimation", model=model_name, device=0 if torch.cuda.is_available() else -1
)
return self._depth_estimator
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply depth estimation to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 512)
+ detect_resolution = self.params.get("detect_resolution", 512)
image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS)
-
+
depth_result = self.depth_estimator(image_resized)
- depth_map = depth_result['depth']
-
- if hasattr(depth_map, 'cpu'):
+ depth_map = depth_result["depth"]
+
+ if hasattr(depth_map, "cpu"):
depth_np = depth_map.cpu().numpy()
else:
depth_np = np.array(depth_map)
-
+
depth_min = depth_np.min()
depth_max = depth_np.max()
if depth_max > depth_min:
depth_normalized = ((depth_np - depth_min) / (depth_max - depth_min) * 255).astype(np.uint8)
else:
depth_normalized = np.zeros_like(depth_np, dtype=np.uint8)
-
+
depth_rgb = np.stack([depth_normalized] * 3, axis=-1)
return Image.fromarray(depth_rgb)
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU for depth estimation
"""
- detect_resolution = self.params.get('detect_resolution', 512)
+ detect_resolution = self.params.get("detect_resolution", 512)
current_size = image_tensor.shape[-2:]
-
+
if current_size != (detect_resolution, detect_resolution):
import torch.nn.functional as F
+
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
-
+
resized_tensor = F.interpolate(
- image_tensor,
- size=(detect_resolution, detect_resolution),
- mode='bilinear',
- align_corners=False
+ image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False
)
-
+
if image_tensor.shape[0] == 1:
resized_tensor = resized_tensor.squeeze(0)
else:
resized_tensor = image_tensor
-
+
pil_image = self.tensor_to_pil(resized_tensor)
-
+
depth_result = self.depth_estimator(pil_image)
- depth_map = depth_result['depth']
-
- if hasattr(depth_map, 'to'):
+ depth_map = depth_result["depth"]
+
+ if hasattr(depth_map, "to"):
depth_tensor = depth_map.to(device=self.device, dtype=self.dtype)
else:
depth_np = np.array(depth_map)
depth_tensor = torch.from_numpy(depth_np).to(device=self.device, dtype=self.dtype)
-
+
depth_min = depth_tensor.min()
depth_max = depth_tensor.max()
if depth_max > depth_min:
depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min)
else:
depth_normalized = torch.zeros_like(depth_tensor)
-
+
if depth_normalized.dim() == 2:
depth_rgb = depth_normalized.unsqueeze(0).repeat(3, 1, 1)
else:
depth_rgb = depth_normalized
-
- return depth_rgb
\ No newline at end of file
+
+ return depth_rgb
diff --git a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py
index 993ee242d..70ce9dbce 100644
--- a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py
+++ b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py
@@ -1,19 +1,23 @@
-#NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt
+# NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt
import os
+
+import cv2
import numpy as np
import torch
import torch.nn.functional as F
-import cv2
from PIL import Image
-from typing import Union, Optional
+
from .base import BasePreprocessor
+
try:
+ from collections import OrderedDict
+
import tensorrt as trt
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes
- from collections import OrderedDict
+
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
@@ -40,7 +44,7 @@
class TensorRTEngine:
"""Simplified TensorRT engine wrapper for depth estimation inference (optimized)"""
-
+
def __init__(self, engine_path):
self.engine_path = engine_path
self.engine = None
@@ -65,13 +69,11 @@ def allocate_buffers(self, device="cuda"):
name = self.engine.get_tensor_name(idx)
shape = self.context.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
-
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
-
- tensor = torch.empty(
- tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
- ).to(device=device)
+
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[name] = tensor
def infer(self, feed_dict, stream=None):
@@ -79,7 +81,7 @@ def infer(self, feed_dict, stream=None):
# Use cached stream if none provided
if stream is None:
stream = self._cuda_stream
-
+
# Copy input data to tensors
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
@@ -87,39 +89,35 @@ def infer(self, feed_dict, stream=None):
# Set tensor addresses
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
-
+
# Execute inference
success = self.context.execute_async_v3(stream)
if not success:
raise ValueError("ERROR: TensorRT inference failed.")
-
+
return self.tensors
class DepthAnythingTensorrtPreprocessor(BasePreprocessor):
"""
Depth Anything TensorRT preprocessor for ControlNet
-
+
Uses TensorRT-optimized Depth Anything model for fast depth estimation.
"""
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Depth Estimation (TensorRT)",
"description": "Fast TensorRT-optimized depth estimation using Depth Anything model. Significantly faster than standard depth estimation.",
- "parameters": {
-
- },
- "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"]
+ "parameters": {},
+ "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"],
}
- def __init__(self,
- engine_path: str = None,
- detect_resolution: int = 518,
- image_resolution: int = 512,
- **kwargs):
+
+ def __init__(self, engine_path: str = None, detect_resolution: int = 518, image_resolution: int = 512, **kwargs):
"""
Initialize TensorRT depth preprocessor
-
+
Args:
engine_path: Path to TensorRT engine file
detect_resolution: Resolution for depth detection (should match engine input)
@@ -131,74 +129,68 @@ def __init__(self,
"TensorRT and polygraphy libraries are required for TensorRT depth preprocessing. "
"Install them with: pip install tensorrt polygraphy"
)
-
+
super().__init__(
- engine_path=engine_path,
- detect_resolution=detect_resolution,
- image_resolution=image_resolution,
- **kwargs
+ engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs
)
-
+
self._engine = None
-
+
@property
def engine(self):
"""Lazy loading of the TensorRT engine"""
if self._engine is None:
- engine_path = self.params.get('engine_path')
+ engine_path = self.params.get("engine_path")
if engine_path is None:
raise ValueError(
"engine_path is required for TensorRT depth preprocessing. "
"Please provide it in the preprocessor_params config."
)
-
+
if not os.path.exists(engine_path):
raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")
-
+
print(f"Loading TensorRT depth estimation engine: {engine_path}")
-
+
self._engine = TensorRTEngine(engine_path)
self._engine.load()
self._engine.activate()
self._engine.allocate_buffers()
-
+
return self._engine
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply TensorRT depth estimation to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 518)
-
+ detect_resolution = self.params.get("detect_resolution", 518)
+
image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
-
+
image_resized = F.interpolate(
- image_tensor,
- size=(detect_resolution, detect_resolution),
- mode='bilinear',
- align_corners=False
+ image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False
)
-
+
if torch.cuda.is_available():
image_resized = image_resized.cuda()
-
+
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.engine.infer({"input": image_resized}, cuda_stream)
- depth = result['output']
-
+ depth = result["output"]
+
depth = np.reshape(depth.cpu().numpy(), (detect_resolution, detect_resolution))
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth = depth.astype(np.uint8)
-
+
original_size = image.size
depth = cv2.resize(depth, original_size)
-
+
depth_rgb = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB)
result = Image.fromarray(depth_rgb)
-
+
return result
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU to avoid CPU transfers
@@ -207,20 +199,19 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
image_tensor = image_tensor.unsqueeze(0)
if not image_tensor.is_cuda:
image_tensor = image_tensor.cuda()
-
- detect_resolution = self.params.get('detect_resolution', 518)
-
+
+ detect_resolution = self.params.get("detect_resolution", 518)
+
image_resized = torch.nn.functional.interpolate(
- image_tensor, size=(detect_resolution, detect_resolution),
- mode='bilinear', align_corners=False
+ image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False
)
-
+
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.engine.infer({"input": image_resized}, cuda_stream)
- depth_tensor = result['output']
-
+ depth_tensor = result["output"]
+
depth_tensor = depth_tensor.squeeze() if depth_tensor.dim() > 2 else depth_tensor
depth_min, depth_max = depth_tensor.min(), depth_tensor.max()
depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min)
-
- return depth_normalized.repeat(3, 1, 1).unsqueeze(0)
\ No newline at end of file
+
+ return depth_normalized.repeat(3, 1, 1).unsqueeze(0)
diff --git a/src/streamdiffusion/preprocessing/processors/external.py b/src/streamdiffusion/preprocessing/processors/external.py
index 80bd7fe8d..3a205b132 100644
--- a/src/streamdiffusion/preprocessing/processors/external.py
+++ b/src/streamdiffusion/preprocessing/processors/external.py
@@ -1,95 +1,87 @@
+from typing import Union
+
import numpy as np
import torch
from PIL import Image
-from typing import Union, Optional, Dict, Any
+
from .base import BasePreprocessor
class ExternalPreprocessor(BasePreprocessor):
"""
External source preprocessor for client-processed control data
-
+
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "External",
"description": "Allows using external preprocessing tools or custom processing pipelines.",
- "parameters": {
-
- },
- "use_cases": ["Custom processing", "Third-party tools integration", "Pre-processed control images"]
+ "parameters": {},
+ "use_cases": ["Custom processing", "Third-party tools integration", "Pre-processed control images"],
}
-
- def __init__(self,
- image_resolution: int = 512,
- validate_input: bool = True,
- **kwargs):
+
+ def __init__(self, image_resolution: int = 512, validate_input: bool = True, **kwargs):
"""
Initialize external source preprocessor
-
+
Args:
image_resolution: Target output resolution
validate_input: Whether to validate the control image format
**kwargs: Additional parameters
"""
- super().__init__(
- image_resolution=image_resolution,
- validate_input=validate_input,
- **kwargs
- )
-
+ super().__init__(image_resolution=image_resolution, validate_input=validate_input, **kwargs)
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Process client-preprocessed control image
-
+
Applies minimal server-side validation to control images
that have already been processed by external sources.
"""
# Optional validation of control image format
- if self.params.get('validate_input', True):
+ if self.params.get("validate_input", True):
image = self._validate_control_image(image)
-
+
return image
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly (optimized path for external sources)
-
+
For external sources, tensor input likely comes from client WebGL/Canvas
processing, so minimal processing needed.
"""
return tensor
-
+
def _validate_control_image(self, image: Image.Image) -> Image.Image:
"""
Validate that the control image is in proper format
"""
# Convert to RGB if needed
- if image.mode != 'RGB':
- image = image.convert('RGB')
-
+ if image.mode != "RGB":
+ image = image.convert("RGB")
+
# Basic validation - check if image has content
# (not completely black, which might indicate processing failure)
img_array = np.array(image)
brightness = np.mean(img_array)
-
+
if brightness < 1.0: # Very dark image, might be processing error
print("ExternalPreprocessor._validate_control_image: Warning - control image appears very dark")
-
+
return image
-
-
+
def __call__(self, image: Union[Image.Image, np.ndarray, torch.Tensor], **kwargs) -> Image.Image:
"""
Process control image (convenience method)
"""
# Store any client metadata if provided
- client_metadata = kwargs.get('client_metadata', {})
+ client_metadata = kwargs.get("client_metadata", {})
if client_metadata:
- source = client_metadata.get('source', 'unknown')
- control_type = client_metadata.get('type', 'unknown')
+ source = client_metadata.get("source", "unknown")
+ control_type = client_metadata.get("type", "unknown")
print(f"ExternalPreprocessor: Received {control_type} control from {source}")
-
- return super().__call__(image, **kwargs)
\ No newline at end of file
+
+ return super().__call__(image, **kwargs)
diff --git a/src/streamdiffusion/preprocessing/processors/faceid_embedding.py b/src/streamdiffusion/preprocessing/processors/faceid_embedding.py
index c6421897a..31bb042ae 100644
--- a/src/streamdiffusion/preprocessing/processors/faceid_embedding.py
+++ b/src/streamdiffusion/preprocessing/processors/faceid_embedding.py
@@ -1,9 +1,12 @@
-from typing import Tuple, Any
+from typing import Any, Tuple
+
import torch
from PIL import Image
-from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor
+
from streamdiffusion.utils.reporting import report_error
+from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor
+
class FaceIDEmbeddingPreprocessor(IPAdapterEmbeddingPreprocessor):
"""
@@ -45,9 +48,7 @@ def __init__(self, ipadapter: Any, faceid_v2_weight: float = 1.0, **kwargs):
self.faceid_v2_weight = float(faceid_v2_weight)
if not hasattr(ipadapter, "insightface_model") or ipadapter.insightface_model is None:
- raise ValueError(
- "FaceIDEmbeddingPreprocessor: ipadapter must have an initialized InsightFace model"
- )
+ raise ValueError("FaceIDEmbeddingPreprocessor: ipadapter must have an initialized InsightFace model")
def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -78,8 +79,4 @@ def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]
def update_faceid_v2_weight(self, weight: float) -> None:
self.faceid_v2_weight = float(weight)
- print(
- f"FaceIDEmbeddingPreprocessor.update_faceid_v2_weight: Updated weight to {self.faceid_v2_weight}"
- )
-
-
+ print(f"FaceIDEmbeddingPreprocessor.update_faceid_v2_weight: Updated weight to {self.faceid_v2_weight}")
diff --git a/src/streamdiffusion/preprocessing/processors/feedback.py b/src/streamdiffusion/preprocessing/processors/feedback.py
index 72a37a7f0..05cce45fa 100644
--- a/src/streamdiffusion/preprocessing/processors/feedback.py
+++ b/src/streamdiffusion/preprocessing/processors/feedback.py
@@ -1,28 +1,30 @@
+from typing import Any
+
import torch
from PIL import Image
-from typing import Union, Optional, Any
+
from .base import PipelineAwareProcessor
class FeedbackPreprocessor(PipelineAwareProcessor):
"""
Feedback preprocessor for ControlNet
-
+
Creates a configurable blend between the current input image and the previous frame's diffusion output.
This creates a feedback loop where each generated frame influences the next generation,
while allowing control over the blend strength for stability and creative effects.
-
+
Formula: output = (1 - feedback_strength) * input_image + feedback_strength * previous_output
-
+
Examples:
- feedback_strength = 0.0: Pure passthrough (input only)
- feedback_strength = 0.5: 50/50 blend (default)
- feedback_strength = 1.0: Pure feedback (previous output only)
-
+
The preprocessor accesses the pipeline's prev_image_result to get the previous output.
For the first frame (when no previous output exists), it falls back to the input image.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -34,21 +36,29 @@ def get_preprocessor_metadata(cls):
"default": 0.5,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Strength of feedback blend (0.0 = pure input, 1.0 = pure feedback)"
+ "description": "Strength of feedback blend (0.0 = pure input, 1.0 = pure feedback)",
}
},
- "use_cases": ["Temporal consistency", "Video-like generation", "Smooth transitions", "Deforum", "Blast off"]
+ "use_cases": [
+ "Temporal consistency",
+ "Video-like generation",
+ "Smooth transitions",
+ "Deforum",
+ "Blast off",
+ ],
}
-
- def __init__(self,
- pipeline_ref: Any,
- normalization_context: str = 'controlnet',
- image_resolution: int = 512,
- feedback_strength: float = 0.5,
- **kwargs):
+
+ def __init__(
+ self,
+ pipeline_ref: Any,
+ normalization_context: str = "controlnet",
+ image_resolution: int = 512,
+ feedback_strength: float = 0.5,
+ **kwargs,
+ ):
"""
Initialize feedback preprocessor
-
+
Args:
pipeline_ref: Reference to the StreamDiffusion pipeline instance (required)
normalization_context: Context for normalization handling
@@ -61,41 +71,42 @@ def __init__(self,
normalization_context=normalization_context,
image_resolution=image_resolution,
feedback_strength=feedback_strength,
- **kwargs
+ **kwargs,
)
self.feedback_strength = max(0.0, min(1.0, feedback_strength)) # Clamp to [0, 1]
self._first_frame = True
-
+
def reset(self):
"""Reset the processor state (useful for new sequences)"""
self._first_frame = True
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Process using configurable blend of input image + previous frame output
-
+
Args:
image: Current input image
-
+
Returns:
Blended PIL Image (blend strength controlled by feedback_strength), or input image for first frame
"""
# Check if we have a pipeline reference and previous output
- if (self.pipeline_ref is not None and
- hasattr(self.pipeline_ref, 'prev_image_result') and
- self.pipeline_ref.prev_image_result is not None and
- not self._first_frame):
-
+ if (
+ self.pipeline_ref is not None
+ and hasattr(self.pipeline_ref, "prev_image_result")
+ and self.pipeline_ref.prev_image_result is not None
+ and not self._first_frame
+ ):
prev_output_tensor = self.pipeline_ref.prev_image_result
# Convert previous output tensor to PIL Image
if prev_output_tensor.dim() == 4:
prev_output_tensor = prev_output_tensor[0] # Remove batch dimension
-
+
# Context-aware normalization handling
- if self.normalization_context == 'controlnet':
+ if self.normalization_context == "controlnet":
# ControlNet context: Convert from [-1, 1] (VAE output) to [0, 1] (ControlNet input)
prev_output_tensor = (prev_output_tensor / 2.0 + 0.5).clamp(0, 1)
- elif self.normalization_context == 'pipeline':
+ elif self.normalization_context == "pipeline":
# Pipeline context: prev_output is already [-1, 1], but pil_to_tensor produces [0, 1]
# So we need to convert input to [-1, 1] to match prev_output
# Convert prev_output to [0, 1] for blending in standard image space
@@ -103,15 +114,15 @@ def _process_core(self, image: Image.Image) -> Image.Image:
else:
# Unknown context - assume controlnet for backward compatibility
prev_output_tensor = (prev_output_tensor / 2.0 + 0.5).clamp(0, 1)
-
+
# Convert both to tensors for blending
prev_output_pil = self.tensor_to_pil(prev_output_tensor)
input_tensor = self.pil_to_tensor(image).squeeze(0) # Remove batch dim for blending
prev_tensor = self.pil_to_tensor(prev_output_pil).squeeze(0)
-
+
# Blend with configurable strength (both tensors now in [0, 1] range)
blended_tensor = (1 - self.feedback_strength) * input_tensor + self.feedback_strength * prev_tensor
-
+
# Convert back to PIL
blended_pil = self.tensor_to_pil(blended_tensor)
return blended_pil
@@ -119,35 +130,36 @@ def _process_core(self, image: Image.Image) -> Image.Image:
# First frame, no pipeline ref, or no previous output available - use input image
self._first_frame = False
return image
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Process using configurable blend of input tensor + previous frame output (GPU-optimized path)
-
+
Args:
tensor: Current input tensor
-
+
Returns:
Blended tensor (blend strength controlled by feedback_strength), or input tensor for first frame
"""
# Check if we have a pipeline reference and previous output
- if (self.pipeline_ref is not None and
- hasattr(self.pipeline_ref, 'prev_image_result') and
- self.pipeline_ref.prev_image_result is not None and
- not self._first_frame):
-
+ if (
+ self.pipeline_ref is not None
+ and hasattr(self.pipeline_ref, "prev_image_result")
+ and self.pipeline_ref.prev_image_result is not None
+ and not self._first_frame
+ ):
prev_output = self.pipeline_ref.prev_image_result
input_tensor = tensor
-
+
# Context-aware normalization handling
- if self.normalization_context == 'controlnet':
+ if self.normalization_context == "controlnet":
# ControlNet context: prev_output is [-1, 1] from VAE, input is [0, 1]
# Convert prev_output from [-1, 1] to [0, 1] to match input
prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1)
# Normalize input tensor to [0, 1] if needed
if input_tensor.max() > 1.0:
input_tensor = input_tensor / 255.0
- elif self.normalization_context == 'pipeline':
+ elif self.normalization_context == "pipeline":
# Pipeline context: both prev_output and input_tensor are in [-1, 1] range
# - prev_output comes from VAE decode (always [-1, 1])
# - input_tensor arrives as [-1, 1] from image_processor.preprocess()
@@ -157,17 +169,20 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
else:
# Unknown context - log warning and assume controlnet behavior for backward compatibility
import logging
- logging.warning(f"FeedbackPreprocessor: Unknown normalization_context '{self.normalization_context}', using controlnet behavior")
+
+ logging.warning(
+ f"FeedbackPreprocessor: Unknown normalization_context '{self.normalization_context}', using controlnet behavior"
+ )
prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1)
if input_tensor.max() > 1.0:
input_tensor = input_tensor / 255.0
-
+
# Ensure both tensors have same format for blending
if prev_output.dim() == 4 and prev_output.shape[0] == 1:
prev_output = prev_output[0] # Remove batch dimension
if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
input_tensor = input_tensor[0] # Remove batch dimension
-
+
# Resize if dimensions don't match
if prev_output.shape != input_tensor.shape:
# Use the input tensor's shape as target
@@ -176,18 +191,18 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
if prev_output.dim() == 3:
prev_output = prev_output.unsqueeze(0)
prev_output = torch.nn.functional.interpolate(
- prev_output, size=target_size, mode='bilinear', align_corners=False
+ prev_output, size=target_size, mode="bilinear", align_corners=False
)
if prev_output.shape[0] == 1:
prev_output = prev_output.squeeze(0)
-
+
# Blend with configurable strength
blended_tensor = (1 - self.feedback_strength) * input_tensor + self.feedback_strength * prev_output
-
+
# Ensure correct output format
if blended_tensor.dim() == 3:
blended_tensor = blended_tensor.unsqueeze(0) # Add batch dimension back
-
+
# Ensure correct device and dtype
blended_tensor = blended_tensor.to(device=self.device, dtype=self.dtype)
return blended_tensor
diff --git a/src/streamdiffusion/preprocessing/processors/hed.py b/src/streamdiffusion/preprocessing/processors/hed.py
index 78c770878..0c5ef9a6b 100644
--- a/src/streamdiffusion/preprocessing/processors/hed.py
+++ b/src/streamdiffusion/preprocessing/processors/hed.py
@@ -1,11 +1,13 @@
-import torch
import numpy as np
+import torch
from PIL import Image
-from typing import Union, Optional
+
from .base import BasePreprocessor
+
try:
from controlnet_aux import HEDdetector
+
CONTROLNET_AUX_AVAILABLE = True
except ImportError:
CONTROLNET_AUX_AVAILABLE = False
@@ -15,83 +17,81 @@
class HEDPreprocessor(BasePreprocessor):
"""
HED (Holistically-Nested Edge Detection) preprocessor
-
+
Uses controlnet_aux HEDdetector for high-quality edge detection.
"""
-
+
_model_cache = {}
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "HED Edge Detection",
"description": "Holistically-Nested Edge Detection for clean, structured edge maps.",
"parameters": {
- "safe": {
- "type": "bool",
- "default": True,
- "description": "Whether to use safe mode for edge detection"
- }
+ "safe": {"type": "bool", "default": True, "description": "Whether to use safe mode for edge detection"}
},
- "use_cases": ["Structured edge detection", "Clean architectural edges", "Line art generation"]
+ "use_cases": ["Structured edge detection", "Clean architectural edges", "Line art generation"],
}
-
+
def __init__(self, safe: bool = True, **kwargs):
if not CONTROLNET_AUX_AVAILABLE:
- raise ImportError("controlnet_aux is required for HED preprocessor. Install with: pip install controlnet_aux")
-
+ raise ImportError(
+ "controlnet_aux is required for HED preprocessor. Install with: pip install controlnet_aux"
+ )
+
super().__init__(**kwargs)
self.safe = safe
self.model = None
self._load_model()
-
+
def _load_model(self):
"""Load controlnet_aux HED model with caching"""
cache_key = f"hed_{self.device}"
-
+
if cache_key in self._model_cache:
self.model = self._model_cache[cache_key]
return
-
+
print("HEDPreprocessor: Loading controlnet_aux HED model")
try:
# Initialize HED detector
self.model = HEDdetector.from_pretrained("lllyasviel/Annotators")
- if hasattr(self.model, 'to'):
+ if hasattr(self.model, "to"):
self.model = self.model.to(self.device)
-
+
# Cache the model
self._model_cache[cache_key] = self.model
print(f"HEDPreprocessor: Successfully loaded model on {self.device}")
-
+
except Exception as e:
raise RuntimeError(f"Failed to load HED model: {e}")
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""Apply HED edge detection to the input image"""
# Get target dimensions
target_width, target_height = self.get_target_dimensions()
-
+
# Process with controlnet_aux
result = self.model(image, output_type="pil")
-
+
# Ensure result is PIL Image
if not isinstance(result, Image.Image):
if isinstance(result, np.ndarray):
result = Image.fromarray(result)
else:
raise ValueError(f"Unexpected result type: {type(result)}")
-
+
# Resize to target size if needed
if result.size != (target_width, target_height):
result = result.resize((target_width, target_height), Image.LANCZOS)
-
+
return result
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
GPU-optimized HED processing using tensors
-
+
Note: controlnet_aux doesn't support direct tensor input, so we convert to PIL and back.
This is still reasonably fast due to optimized conversions in the base class.
"""
@@ -99,24 +99,18 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
pil_image = self.tensor_to_pil(image_tensor)
processed_pil = self._process_core(pil_image)
return self.pil_to_tensor(processed_pil)
-
-
+
@classmethod
- def create_optimized(cls, device: str = 'cuda', dtype: torch.dtype = torch.float16, **kwargs):
+ def create_optimized(cls, device: str = "cuda", dtype: torch.dtype = torch.float16, **kwargs):
"""
Create an optimized HED preprocessor
-
+
Args:
device: Target device ('cuda' or 'cpu')
dtype: Data type for inference
**kwargs: Additional parameters
-
+
Returns:
Optimized HEDPreprocessor instance
"""
- return cls(
- device=device,
- dtype=dtype,
- safe=True,
- **kwargs
- )
\ No newline at end of file
+ return cls(device=device, dtype=dtype, safe=True, **kwargs)
diff --git a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py
index 8b7d28a08..398119bfb 100644
--- a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py
+++ b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py
@@ -1,6 +1,8 @@
-from typing import Union, Tuple, Optional, Any
+from typing import Any, Tuple, Union
+
import torch
from PIL import Image
+
from .base import BasePreprocessor
@@ -9,53 +11,53 @@ class IPAdapterEmbeddingPreprocessor(BasePreprocessor):
Preprocessor that generates IPAdapter embeddings instead of spatial conditioning.
Leverages existing preprocessing infrastructure for parallel IPAdapter embedding generation.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "IPAdapter Embedding",
"description": "Generates IPAdapter embeddings for style transfer and image conditioning instead of spatial control maps.",
"parameters": {},
- "use_cases": ["Style transfer", "Image conditioning", "Semantic control", "Content-aware generation"]
+ "use_cases": ["Style transfer", "Image conditioning", "Semantic control", "Content-aware generation"],
}
-
+
def __init__(self, ipadapter: Any, **kwargs):
super().__init__(**kwargs)
self.ipadapter = ipadapter
# Verify the ipadapter has the required method
- if not hasattr(ipadapter, 'get_image_embeds'):
+ if not hasattr(ipadapter, "get_image_embeds"):
raise ValueError("IPAdapterEmbeddingPreprocessor: ipadapter must have 'get_image_embeds' method")
-
+
# Create dedicated CUDA stream for IPAdapter processing to avoid TensorRT conflicts
self._ipadapter_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
-
+
def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns (positive_embeds, negative_embeds) instead of processed image"""
if self._ipadapter_stream is not None:
# Use dedicated stream to avoid TensorRT stream capture conflicts
with torch.cuda.stream(self._ipadapter_stream):
image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image])
-
+
# Wait for stream completion and move tensors to default stream
self._ipadapter_stream.synchronize()
-
+
# Ensure tensors are accessible from default stream
- if hasattr(image_embeds, 'record_stream'):
+ if hasattr(image_embeds, "record_stream"):
image_embeds.record_stream(torch.cuda.current_stream())
- if hasattr(negative_embeds, 'record_stream'):
+ if hasattr(negative_embeds, "record_stream"):
negative_embeds.record_stream(torch.cuda.current_stream())
else:
# Fallback for non-CUDA environments
image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image])
-
+
return image_embeds, negative_embeds
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""GPU-optimized path for tensor inputs"""
# Convert tensor to PIL for IPAdapter processing
pil_image = self.tensor_to_pil(tensor)
return self._process_core(pil_image)
-
+
def process(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Override base process to return embeddings tuple instead of PIL Image"""
if isinstance(image, torch.Tensor):
@@ -63,9 +65,9 @@ def process(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor
else:
image = self.validate_input(image)
result = self._process_core(image)
-
+
return result
-
+
def process_tensor(self, image_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Override base process_tensor to return embeddings tuple"""
tensor = self.validate_tensor_input(image_tensor)
diff --git a/src/streamdiffusion/preprocessing/processors/latent_feedback.py b/src/streamdiffusion/preprocessing/processors/latent_feedback.py
index e5b8f8b2c..38d70961d 100644
--- a/src/streamdiffusion/preprocessing/processors/latent_feedback.py
+++ b/src/streamdiffusion/preprocessing/processors/latent_feedback.py
@@ -1,27 +1,29 @@
+from typing import Any
+
import torch
-from typing import Optional, Any
+
from .base import PipelineAwareProcessor
class LatentFeedbackPreprocessor(PipelineAwareProcessor):
"""
Latent domain feedback preprocessor
-
+
Creates a configurable blend between the current input latent and the previous frame's latent output.
This creates a feedback loop in latent space where each generated latent influences the next generation,
providing temporal consistency without the overhead of VAE encoding/decoding.
-
+
Formula: output = (1 - feedback_strength) * input_latent + feedback_strength * previous_latent
-
+
Examples:
- feedback_strength = 0.0: Pure passthrough (input only)
- feedback_strength = 0.15: Default safe blend
- feedback_strength = 0.40: Maximum safe feedback (values > 0.4 produce garbage)
-
+
The preprocessor accesses the pipeline's prev_latent_result to get the previous latent output.
For the first frame (when no previous output exists), it falls back to the input latent.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -33,20 +35,24 @@ def get_preprocessor_metadata(cls):
"default": 0.15,
"range": [0.0, 0.40],
"step": 0.01,
- "description": "Strength of latent feedback blend (0.0 = pure input, .40 = more feedback)"
+ "description": "Strength of latent feedback blend (0.0 = pure input, .40 = more feedback)",
}
},
- "use_cases": ["Latent temporal consistency", "Latent space transitions", "Efficient feedback", "Latent preprocessing", "Temporal stability"]
+ "use_cases": [
+ "Latent temporal consistency",
+ "Latent space transitions",
+ "Efficient feedback",
+ "Latent preprocessing",
+ "Temporal stability",
+ ],
}
-
- def __init__(self,
- pipeline_ref: Any,
- normalization_context: str = 'latent',
- feedback_strength: float = 0.15,
- **kwargs):
+
+ def __init__(
+ self, pipeline_ref: Any, normalization_context: str = "latent", feedback_strength: float = 0.15, **kwargs
+ ):
"""
Initialize latent feedback preprocessor
-
+
Args:
pipeline_ref: Reference to the StreamDiffusion pipeline instance (required)
normalization_context: Context for normalization handling (latent space doesn't need normalization)
@@ -57,29 +63,31 @@ def __init__(self,
pipeline_ref=pipeline_ref,
normalization_context=normalization_context,
feedback_strength=feedback_strength,
- **kwargs
+ **kwargs,
)
- self.feedback_strength = max(0.0, min(0.40, feedback_strength)) # Clamp to [0, 0.40] - values > 0.4 produce garbage
+ self.feedback_strength = max(
+ 0.0, min(0.40, feedback_strength)
+ ) # Clamp to [0, 0.40] - values > 0.4 produce garbage
self._first_frame = True
-
+
def _get_previous_data(self):
"""Get previous frame latent data from pipeline"""
if self.pipeline_ref is not None:
# Get previous OUTPUT latent (after diffusion), not input latent
# Check for prev_latent_result (the actual attribute name used by the pipeline)
- if hasattr(self.pipeline_ref, 'prev_latent_result'):
+ if hasattr(self.pipeline_ref, "prev_latent_result"):
if self.pipeline_ref.prev_latent_result is not None and not self._first_frame:
return self.pipeline_ref.prev_latent_result
return None
-
- #TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class
+
+ # TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class
def validate_tensor_input(self, latent_tensor: torch.Tensor) -> torch.Tensor:
"""
Validate latent tensor input - preserve batch dimensions for latent processing
-
+
Args:
latent_tensor: Input latent tensor in format [B, C, H/8, W/8]
-
+
Returns:
Validated latent tensor with preserved batch dimension
"""
@@ -87,18 +95,18 @@ def validate_tensor_input(self, latent_tensor: torch.Tensor) -> torch.Tensor:
# Only ensure correct device and dtype
latent_tensor = latent_tensor.to(device=self.device, dtype=self.dtype)
return latent_tensor
-
- #TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class
+
+ # TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class
def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Override base class resize logic - latent tensors should NOT be resized to image dimensions
-
+
For latent domain processing, we want to preserve the latent space dimensions,
not resize to image target dimensions like image-domain processors.
"""
# For latent feedback, just return the tensor as-is without any resizing
return tensor
-
+
def _process_core(self, image):
"""
For latent feedback, we don't process PIL images directly.
@@ -108,23 +116,23 @@ def _process_core(self, image):
"LatentFeedbackPreprocessor is designed for latent domain processing. "
"Use _process_tensor_core or process_tensor for latent tensors."
)
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Process latent tensor with feedback blending
-
+
Args:
tensor: Current input latent tensor in format [B, C, H/8, W/8]
-
+
Returns:
Blended latent tensor (blend strength controlled by feedback_strength), or input tensor for first frame
"""
# Get previous frame latent data using mixin method
prev_latent = self._get_previous_data()
-
+
if prev_latent is not None:
input_latent = tensor
-
+
# Ensure both tensors have the same batch size for element-wise blending
# If batch sizes differ, expand the smaller one to match
if prev_latent.shape[0] != input_latent.shape[0]:
@@ -139,22 +147,22 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
min_batch = min(prev_latent.shape[0], input_latent.shape[0])
prev_latent = prev_latent[:min_batch]
input_latent = input_latent[:min_batch]
-
+
# Resize spatial dimensions if they don't match (though this should be rare in latent space)
if prev_latent.shape[2:] != input_latent.shape[2:]:
target_size = input_latent.shape[2:] # Get H, W from input
prev_latent = torch.nn.functional.interpolate(
- prev_latent, size=target_size, mode='bilinear', align_corners=False
+ prev_latent, size=target_size, mode="bilinear", align_corners=False
)
-
+
# Blend current latent with previous latent for temporal consistency
# Higher feedback_strength = more influence from previous frame
blended_latent = (1 - self.feedback_strength) * input_latent + self.feedback_strength * prev_latent
-
+
# Add safety measures for latent values to prevent extreme outputs
# Clamp to reasonable range based on typical latent distributions
blended_latent = torch.clamp(blended_latent, min=-10.0, max=10.0)
-
+
# Ensure correct device and dtype
blended_latent = blended_latent.to(device=self.device, dtype=self.dtype)
return blended_latent
diff --git a/src/streamdiffusion/preprocessing/processors/lineart.py b/src/streamdiffusion/preprocessing/processors/lineart.py
index 4f0bafa81..030043c4d 100644
--- a/src/streamdiffusion/preprocessing/processors/lineart.py
+++ b/src/streamdiffusion/preprocessing/processors/lineart.py
@@ -1,29 +1,34 @@
import logging
-import numpy as np
-from PIL import Image
-from typing import Union, Optional
import time
+
+from PIL import Image
+
from .base import BasePreprocessor
+
logger = logging.getLogger(__name__)
try:
- from controlnet_aux import LineartDetector, LineartAnimeDetector
+ from controlnet_aux import LineartAnimeDetector, LineartDetector
+
CONTROLNET_AUX_AVAILABLE = True
except ImportError:
CONTROLNET_AUX_AVAILABLE = False
- raise ImportError("LineartPreprocessor: controlnet_aux is required for real-time optimization. Install with: pip install controlnet_aux")
+ raise ImportError(
+ "LineartPreprocessor: controlnet_aux is required for real-time optimization. Install with: pip install controlnet_aux"
+ )
+
-#TODO provide gpu native lineart detection
+# TODO provide gpu native lineart detection
class LineartPreprocessor(BasePreprocessor):
"""
Real-time optimized Lineart detection preprocessor for ControlNet
-
+
Extracts line art from input images using controlnet_aux line art detection models.
Supports both realistic and anime-style line art extraction.
Optimized for real-time performance - no fallbacks.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -33,26 +38,28 @@ def get_preprocessor_metadata(cls):
"coarse": {
"type": "bool",
"default": True,
- "description": "Whether to use coarse line art detection (faster but less detailed)"
+ "description": "Whether to use coarse line art detection (faster but less detailed)",
},
"anime_style": {
"type": "bool",
"default": False,
- "description": "Whether to use anime-style line art detection"
- }
+ "description": "Whether to use anime-style line art detection",
+ },
},
- "use_cases": ["Sketch to image", "Line art generation", "Clean line extraction"]
+ "use_cases": ["Sketch to image", "Line art generation", "Clean line extraction"],
}
-
- def __init__(self,
- detect_resolution: int = 512,
- image_resolution: int = 512,
- coarse: bool = True,
- anime_style: bool = False,
- **kwargs):
+
+ def __init__(
+ self,
+ detect_resolution: int = 512,
+ image_resolution: int = 512,
+ coarse: bool = True,
+ anime_style: bool = False,
+ **kwargs,
+ ):
"""
Initialize Lineart preprocessor
-
+
Args:
detect_resolution: Resolution for line art detection
image_resolution: Output image resolution
@@ -65,34 +72,34 @@ def __init__(self,
image_resolution=image_resolution,
coarse=coarse,
anime_style=anime_style,
- **kwargs
+ **kwargs,
)
self._detector = None
-
+
@property
def detector(self):
"""Lazy loading of the line art detector - controlnet_aux only"""
if self._detector is None:
start_time = time.time()
- anime_style = self.params.get('anime_style', False)
-
+ anime_style = self.params.get("anime_style", False)
+
if anime_style:
- self._detector = LineartAnimeDetector.from_pretrained('lllyasviel/Annotators')
+ self._detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
else:
- self._detector = LineartDetector.from_pretrained('lllyasviel/Annotators')
+ self._detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
load_time = time.time() - start_time
logger.info(f"Lineart detector loaded in {load_time:.3f}s")
-
+
return self._detector
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply line art detection to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 512)
- coarse = self.params.get('coarse', False)
+ detect_resolution = self.params.get("detect_resolution", 512)
+ coarse = self.params.get("coarse", False)
if image.size != (detect_resolution, detect_resolution):
image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS)
@@ -100,10 +107,7 @@ def _process_core(self, image: Image.Image) -> Image.Image:
image_resized = image
lineart_image = self.detector(
- image_resized,
- detect_resolution=detect_resolution,
- image_resolution=detect_resolution,
- coarse=coarse
+ image_resized, detect_resolution=detect_resolution, image_resolution=detect_resolution, coarse=coarse
)
- return lineart_image
\ No newline at end of file
+ return lineart_image
diff --git a/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py b/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py
index 7a09a8d31..3d7c72093 100644
--- a/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py
+++ b/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py
@@ -1,12 +1,16 @@
+from typing import List
+
+import cv2
import numpy as np
import torch
-import cv2
-from PIL import Image, ImageDraw
-from typing import Union, Optional, List, Tuple, Dict
+from PIL import Image
+
from .base import BasePreprocessor
+
try:
import mediapipe as mp
+
MEDIAPIPE_AVAILABLE = True
except ImportError:
MEDIAPIPE_AVAILABLE = False
@@ -21,49 +25,87 @@
# 10: RKnee, 11: RAnkle, 12: LHip, 13: LKnee, 14: LAnkle,
# 15: REye, 16: LEye, 17: REar, 18: LEar, 19: LBigToe,
# 20: LSmallToe, 21: LHeel, 22: RBigToe, 23: RSmallToe, 24: RHeel
-
- 0: 0, # Nose -> Nose
- 1: None, # Neck (calculated from shoulders)
+ 0: 0, # Nose -> Nose
+ 1: None, # Neck (calculated from shoulders)
2: 12, # RShoulder -> RightShoulder
- 3: 14, # RElbow -> RightElbow
+ 3: 14, # RElbow -> RightElbow
4: 16, # RWrist -> RightWrist
5: 11, # LShoulder -> LeftShoulder
6: 13, # LElbow -> LeftElbow
7: 15, # LWrist -> LeftWrist
- 8: None, # MidHip (calculated from hips)
+ 8: None, # MidHip (calculated from hips)
9: 24, # RHip -> RightHip
- 10: 26, # RKnee -> RightKnee
- 11: 28, # RAnkle -> RightAnkle
- 12: 23, # LHip -> LeftHip
- 13: 25, # LKnee -> LeftKnee
- 14: 27, # LAnkle -> LeftAnkle
+ 10: 26, # RKnee -> RightKnee
+ 11: 28, # RAnkle -> RightAnkle
+ 12: 23, # LHip -> LeftHip
+ 13: 25, # LKnee -> LeftKnee
+ 14: 27, # LAnkle -> LeftAnkle
15: 5, # REye -> RightEye
16: 2, # LEye -> LeftEye
17: 8, # REar -> RightEar
18: 7, # LEar -> LeftEar
- 19: 31, # LBigToe -> LeftFootIndex
- 20: 31, # LSmallToe -> LeftFootIndex (approximation)
- 21: 29, # LHeel -> LeftHeel
- 22: 32, # RBigToe -> RightFootIndex
- 23: 32, # RSmallToe -> RightFootIndex (approximation)
- 24: 30 # RHeel -> RightHeel
+ 19: 31, # LBigToe -> LeftFootIndex
+ 20: 31, # LSmallToe -> LeftFootIndex (approximation)
+ 21: 29, # LHeel -> LeftHeel
+ 22: 32, # RBigToe -> RightFootIndex
+ 23: 32, # RSmallToe -> RightFootIndex (approximation)
+ 24: 30, # RHeel -> RightHeel
}
# OpenPose connections for proper skeleton rendering
OPENPOSE_LIMB_SEQUENCE = [
- [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7],
- [1, 8], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13],
- [13, 14], [1, 0], [0, 15], [15, 17], [0, 16], [16, 18],
- [14, 19], [19, 20], [14, 21], [11, 22], [22, 23], [11, 24]
+ [1, 2],
+ [1, 5],
+ [2, 3],
+ [3, 4],
+ [5, 6],
+ [6, 7],
+ [1, 8],
+ [8, 9],
+ [9, 10],
+ [10, 11],
+ [8, 12],
+ [12, 13],
+ [13, 14],
+ [1, 0],
+ [0, 15],
+ [15, 17],
+ [0, 16],
+ [16, 18],
+ [14, 19],
+ [19, 20],
+ [14, 21],
+ [11, 22],
+ [22, 23],
+ [11, 24],
]
# Standard OpenPose colors (BGR format) - matching actual OpenPose output
OPENPOSE_COLORS = [
- [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
- [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
- [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
- [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0], [255, 85, 0],
- [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0]
+ [255, 0, 0],
+ [255, 85, 0],
+ [255, 170, 0],
+ [255, 255, 0],
+ [170, 255, 0],
+ [85, 255, 0],
+ [0, 255, 0],
+ [0, 255, 85],
+ [0, 255, 170],
+ [0, 255, 255],
+ [0, 170, 255],
+ [0, 85, 255],
+ [0, 0, 255],
+ [85, 0, 255],
+ [170, 0, 255],
+ [255, 0, 255],
+ [255, 0, 170],
+ [255, 0, 85],
+ [255, 0, 0],
+ [255, 85, 0],
+ [255, 170, 0],
+ [255, 255, 0],
+ [170, 255, 0],
+ [85, 255, 0],
]
# OPTIMIZATION: Vectorized mapping for MediaPipe to OpenPose conversion
@@ -76,19 +118,20 @@
OPENPOSE_COLORS_ARRAY = np.array(OPENPOSE_COLORS, dtype=np.uint8)
LIMB_SEQUENCE_ARRAY = np.array(OPENPOSE_LIMB_SEQUENCE, dtype=np.int32)
+
class MediaPipePosePreprocessor(BasePreprocessor):
"""
MediaPipe-based pose preprocessor for ControlNet that outputs OpenPose-style annotations
-
+
Converts MediaPipe's 33 keypoints to OpenPose's 25 keypoints format and renders
them in the standard OpenPose style for ControlNet compatibility.
-
+
Improvements inspired by TouchDesigner MediaPipe plugin:
- Better confidence filtering
- Temporal smoothing for jitter reduction
- Improved multi-pose support preparation
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -100,89 +143,88 @@ def get_preprocessor_metadata(cls):
"default": 0.5,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Minimum confidence for pose detection"
+ "description": "Minimum confidence for pose detection",
},
"min_tracking_confidence": {
"type": "float",
"default": 0.5,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Minimum confidence for pose tracking"
+ "description": "Minimum confidence for pose tracking",
},
"model_complexity": {
"type": "int",
"default": 1,
"range": [0, 2],
- "description": "MediaPipe model complexity (0=fastest, 2=most accurate)"
+ "description": "MediaPipe model complexity (0=fastest, 2=most accurate)",
},
"static_image_mode": {
"type": "bool",
"default": False,
- "description": "Use static image mode (slower but more accurate per frame)"
- },
- "draw_hands": {
- "type": "bool",
- "default": True,
- "description": "Whether to draw hand poses"
- },
- "draw_face": {
- "type": "bool",
- "default": False,
- "description": "Whether to draw face landmarks"
+ "description": "Use static image mode (slower but more accurate per frame)",
},
+ "draw_hands": {"type": "bool", "default": True, "description": "Whether to draw hand poses"},
+ "draw_face": {"type": "bool", "default": False, "description": "Whether to draw face landmarks"},
"line_thickness": {
"type": "int",
"default": 2,
"range": [1, 10],
- "description": "Thickness of skeleton lines"
+ "description": "Thickness of skeleton lines",
},
"circle_radius": {
"type": "int",
"default": 4,
"range": [1, 10],
- "description": "Radius of joint circles"
+ "description": "Radius of joint circles",
},
"confidence_threshold": {
"type": "float",
"default": 0.3,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Minimum confidence for rendering keypoints"
+ "description": "Minimum confidence for rendering keypoints",
},
"enable_smoothing": {
"type": "bool",
"default": True,
- "description": "Enable temporal smoothing to reduce jitter"
+ "description": "Enable temporal smoothing to reduce jitter",
},
"smoothing_factor": {
"type": "float",
"default": 0.7,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Smoothing strength (higher = more smoothing)"
- }
+ "description": "Smoothing strength (higher = more smoothing)",
+ },
},
- "use_cases": ["Detailed pose control", "Hand and face detection", "Real-time pose tracking", "Custom confidence tuning"]
+ "use_cases": [
+ "Detailed pose control",
+ "Hand and face detection",
+ "Real-time pose tracking",
+ "Custom confidence tuning",
+ ],
}
-
- def __init__(self,
- detect_resolution: int = 256, # OPTIMIZATION: Reduced from 512 for 4x speedup
- image_resolution: int = 512,
- min_detection_confidence: float = 0.5,
- min_tracking_confidence: float = 0.5,
- model_complexity: int = 1,
- static_image_mode: bool = False, # OPTIMIZATION: Video mode for tracking (3-5x faster)
- draw_hands: bool = True,
- draw_face: bool = False, # Simplified - disable face by default
- line_thickness: int = 2,
- circle_radius: int = 4,
- confidence_threshold: float = 0.3, # TouchDesigner-style confidence filtering
- enable_smoothing: bool = True, # TouchDesigner-inspired smoothing
- smoothing_factor: float = 0.7, # Smoothing strength
- **kwargs):
+
+ def __init__(
+ self,
+ detect_resolution: int = 256, # OPTIMIZATION: Reduced from 512 for 4x speedup
+ image_resolution: int = 512,
+ min_detection_confidence: float = 0.5,
+ min_tracking_confidence: float = 0.5,
+ model_complexity: int = 1,
+ static_image_mode: bool = False, # OPTIMIZATION: Video mode for tracking (3-5x faster)
+ draw_hands: bool = True,
+ draw_face: bool = False, # Simplified - disable face by default
+ line_thickness: int = 2,
+ circle_radius: int = 4,
+ confidence_threshold: float = 0.3, # TouchDesigner-style confidence filtering
+ enable_smoothing: bool = True, # TouchDesigner-inspired smoothing
+ smoothing_factor: float = 0.7, # Smoothing strength
+ **kwargs,
+ ):
"""
Initialize MediaPipe pose preprocessor with TouchDesigner-inspired improvements
-
+
Args:
detect_resolution: Resolution for pose detection
image_resolution: Output image resolution
@@ -201,10 +243,9 @@ def __init__(self,
"""
if not MEDIAPIPE_AVAILABLE:
raise ImportError(
- "MediaPipe is required for MediaPipe pose preprocessing. "
- "Install it with: pip install mediapipe"
+ "MediaPipe is required for MediaPipe pose preprocessing. Install it with: pip install mediapipe"
)
-
+
super().__init__(
detect_resolution=detect_resolution,
image_resolution=image_resolution,
@@ -219,312 +260,335 @@ def __init__(self,
confidence_threshold=confidence_threshold,
enable_smoothing=enable_smoothing,
smoothing_factor=smoothing_factor,
- **kwargs
+ **kwargs,
)
-
+
self._detector = None
self._current_options = None
# TouchDesigner-style smoothing buffers
self._smoothing_buffers = {}
-
+
@property
def detector(self):
"""Lazy loading of the MediaPipe Holistic detector with GPU optimization"""
new_options = {
- 'min_detection_confidence': self.params.get('min_detection_confidence', 0.5),
- 'min_tracking_confidence': self.params.get('min_tracking_confidence', 0.5),
- 'model_complexity': self.params.get('model_complexity', 1),
- 'static_image_mode': self.params.get('static_image_mode', False), # Video mode default
+ "min_detection_confidence": self.params.get("min_detection_confidence", 0.5),
+ "min_tracking_confidence": self.params.get("min_tracking_confidence", 0.5),
+ "model_complexity": self.params.get("model_complexity", 1),
+ "static_image_mode": self.params.get("static_image_mode", False), # Video mode default
}
-
+
# Initialize or update detector if needed
if self._detector is None or self._current_options != new_options:
if self._detector is not None:
self._detector.close()
-
+
# OPTIMIZATION: Try GPU delegate first, fallback to CPU
try:
print("MediaPipePosePreprocessor.detector: Attempting GPU delegate initialization")
-
+
# Try to create base options with GPU delegate
try:
- base_options = mp.tasks.BaseOptions(
- delegate=mp.tasks.BaseOptions.Delegate.GPU
- )
+ base_options = mp.tasks.BaseOptions(delegate=mp.tasks.BaseOptions.Delegate.GPU)
print("MediaPipePosePreprocessor.detector: GPU delegate available")
except Exception as gpu_error:
print(f"MediaPipePosePreprocessor.detector: GPU delegate failed ({gpu_error}), using CPU")
- base_options = mp.tasks.BaseOptions(
- delegate=mp.tasks.BaseOptions.Delegate.CPU
- )
-
+ base_options = mp.tasks.BaseOptions(delegate=mp.tasks.BaseOptions.Delegate.CPU)
+
# Create detector with optimized settings
- print(f"MediaPipePosePreprocessor.detector: Initializing MediaPipe Holistic (video_mode={not new_options['static_image_mode']})")
+ print(
+ f"MediaPipePosePreprocessor.detector: Initializing MediaPipe Holistic (video_mode={not new_options['static_image_mode']})"
+ )
self._detector = mp.solutions.holistic.Holistic(
- static_image_mode=new_options['static_image_mode'],
- model_complexity=new_options['model_complexity'],
+ static_image_mode=new_options["static_image_mode"],
+ model_complexity=new_options["model_complexity"],
enable_segmentation=False,
refine_face_landmarks=False, # Keep simple for speed
- min_detection_confidence=new_options['min_detection_confidence'],
- min_tracking_confidence=new_options['min_tracking_confidence'],
+ min_detection_confidence=new_options["min_detection_confidence"],
+ min_tracking_confidence=new_options["min_tracking_confidence"],
)
-
+
except Exception as e:
print(f"MediaPipePosePreprocessor.detector: Advanced options failed ({e}), using basic setup")
# Fallback to basic setup
self._detector = mp.solutions.holistic.Holistic(
- static_image_mode=new_options['static_image_mode'],
- model_complexity=new_options['model_complexity'],
+ static_image_mode=new_options["static_image_mode"],
+ model_complexity=new_options["model_complexity"],
enable_segmentation=False,
refine_face_landmarks=False,
- min_detection_confidence=new_options['min_detection_confidence'],
- min_tracking_confidence=new_options['min_tracking_confidence'],
+ min_detection_confidence=new_options["min_detection_confidence"],
+ min_tracking_confidence=new_options["min_tracking_confidence"],
)
-
+
self._current_options = new_options
-
+
return self._detector
-
+
def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str = "default") -> List[List[float]]:
"""
Apply TouchDesigner-inspired temporal smoothing - VECTORIZED
-
+
Args:
keypoints: Current frame keypoints
pose_id: Unique identifier for this pose
-
+
Returns:
Smoothed keypoints
"""
- if not self.params.get('enable_smoothing', True) or not keypoints:
+ if not self.params.get("enable_smoothing", True) or not keypoints:
return keypoints
-
- smoothing_factor = self.params.get('smoothing_factor', 0.7)
-
+
+ smoothing_factor = self.params.get("smoothing_factor", 0.7)
+
# Initialize buffer for this pose if needed
if pose_id not in self._smoothing_buffers:
self._smoothing_buffers[pose_id] = keypoints.copy()
return keypoints
-
+
# OPTIMIZATION: Vectorized exponential smoothing
current_array = np.array(keypoints, dtype=np.float32)
previous_array = np.array(self._smoothing_buffers[pose_id], dtype=np.float32)
-
+
# Create confidence mask for selective smoothing
confidence_mask = current_array[:, 2] > 0.1
-
+
# Vectorized smoothing calculation
smoothed_array = previous_array.copy()
# Apply smoothing only where confidence is good
- smoothed_array[confidence_mask, :2] = (
- previous_array[confidence_mask, :2] * smoothing_factor +
- current_array[confidence_mask, :2] * (1 - smoothing_factor)
- )
+ smoothed_array[confidence_mask, :2] = previous_array[confidence_mask, :2] * smoothing_factor + current_array[
+ confidence_mask, :2
+ ] * (1 - smoothing_factor)
# Always use current confidence values
smoothed_array[:, 2] = current_array[:, 2]
-
+
# Update buffer and return
smoothed_list = smoothed_array.tolist()
self._smoothing_buffers[pose_id] = smoothed_list
return smoothed_list
-
- def _mediapipe_to_openpose(self, mediapipe_landmarks: List, image_width: int, image_height: int) -> List[List[float]]:
+
+ def _mediapipe_to_openpose(
+ self, mediapipe_landmarks: List, image_width: int, image_height: int
+ ) -> List[List[float]]:
"""
Convert MediaPipe landmarks to OpenPose format - VECTORIZED
-
+
Args:
mediapipe_landmarks: MediaPipe pose landmarks
image_width: Image width
image_height: Image height
-
+
Returns:
OpenPose keypoints in [x, y, confidence] format
"""
if not mediapipe_landmarks:
return []
-
+
# OPTIMIZATION: Vectorized landmark conversion
# Extract all coordinates and confidences in one go
- landmarks_data = np.array([
- [lm.x * image_width, lm.y * image_height,
- lm.visibility if hasattr(lm, 'visibility') else 1.0]
- for lm in mediapipe_landmarks
- ], dtype=np.float32)
-
+ landmarks_data = np.array(
+ [
+ [lm.x * image_width, lm.y * image_height, lm.visibility if hasattr(lm, "visibility") else 1.0]
+ for lm in mediapipe_landmarks
+ ],
+ dtype=np.float32,
+ )
+
# Initialize OpenPose keypoints array (25 points x 3 values)
openpose_keypoints = np.zeros((25, 3), dtype=np.float32)
-
+
# OPTIMIZATION: Vectorized mapping using advanced indexing
# Only map valid indices that exist in landmarks_data
valid_mask = MEDIAPIPE_INDICES < len(landmarks_data)
valid_mp_indices = MEDIAPIPE_INDICES[valid_mask]
valid_op_indices = OPENPOSE_INDICES[valid_mask]
-
+
# Vectorized assignment
openpose_keypoints[valid_op_indices] = landmarks_data[valid_mp_indices]
-
+
# OPTIMIZATION: Vectorized derived point calculations
- confidence_threshold = self.params.get('confidence_threshold', 0.3)
-
+ confidence_threshold = self.params.get("confidence_threshold", 0.3)
+
# Neck (1): midpoint between shoulders (indices 11, 12)
- if (len(landmarks_data) > 12 and
- landmarks_data[11, 2] > confidence_threshold and
- landmarks_data[12, 2] > confidence_threshold):
+ if (
+ len(landmarks_data) > 12
+ and landmarks_data[11, 2] > confidence_threshold
+ and landmarks_data[12, 2] > confidence_threshold
+ ):
# Vectorized midpoint calculation
neck_point = np.mean(landmarks_data[[11, 12]], axis=0)
neck_point[2] = np.min(landmarks_data[[11, 12], 2]) # Min confidence
openpose_keypoints[1] = neck_point
-
+
# MidHip (8): midpoint between hips (indices 23, 24)
- if (len(landmarks_data) > 24 and
- landmarks_data[23, 2] > confidence_threshold and
- landmarks_data[24, 2] > confidence_threshold):
+ if (
+ len(landmarks_data) > 24
+ and landmarks_data[23, 2] > confidence_threshold
+ and landmarks_data[24, 2] > confidence_threshold
+ ):
# Vectorized midpoint calculation
midhip_point = np.mean(landmarks_data[[23, 24]], axis=0)
midhip_point[2] = np.min(landmarks_data[[23, 24], 2]) # Min confidence
openpose_keypoints[8] = midhip_point
-
+
# Convert back to list format for compatibility
return openpose_keypoints.tolist()
-
+
def _draw_openpose_skeleton(self, image: np.ndarray, keypoints: List[List[float]]) -> np.ndarray:
"""
Draw OpenPose-style skeleton on image
-
+
Args:
image: Input image
keypoints: OpenPose keypoints
-
+
Returns:
Image with skeleton drawn
"""
if not keypoints or len(keypoints) != 25:
return image
-
+
h, w = image.shape[:2]
- line_thickness = self.params.get('line_thickness', 2)
- circle_radius = self.params.get('circle_radius', 4)
- confidence_threshold = self.params.get('confidence_threshold', 0.3)
-
+ line_thickness = self.params.get("line_thickness", 2)
+ circle_radius = self.params.get("circle_radius", 4)
+ confidence_threshold = self.params.get("confidence_threshold", 0.3)
+
# OPTIMIZATION: Vectorized limb drawing with confidence filtering
keypoints_array = np.array(keypoints, dtype=np.float32)
-
+
# Draw limbs
for i, (start_idx, end_idx) in enumerate(LIMB_SEQUENCE_ARRAY):
- if (start_idx < len(keypoints_array) and end_idx < len(keypoints_array) and
- keypoints_array[start_idx, 2] > confidence_threshold and keypoints_array[end_idx, 2] > confidence_threshold):
-
+ if (
+ start_idx < len(keypoints_array)
+ and end_idx < len(keypoints_array)
+ and keypoints_array[start_idx, 2] > confidence_threshold
+ and keypoints_array[end_idx, 2] > confidence_threshold
+ ):
start_point = (int(keypoints_array[start_idx, 0]), int(keypoints_array[start_idx, 1]))
end_point = (int(keypoints_array[end_idx, 0]), int(keypoints_array[end_idx, 1]))
-
+
# Use vectorized color array
color = OPENPOSE_COLORS_ARRAY[i % len(OPENPOSE_COLORS_ARRAY)].tolist()
-
+
cv2.line(image, start_point, end_point, color, line_thickness)
-
+
# OPTIMIZATION: Vectorized keypoint drawing with confidence filtering
confidence_mask = keypoints_array[:, 2] > confidence_threshold
valid_indices = np.where(confidence_mask)[0]
-
+
for i in valid_indices:
center = (int(keypoints_array[i, 0]), int(keypoints_array[i, 1]))
color = OPENPOSE_COLORS_ARRAY[i % len(OPENPOSE_COLORS_ARRAY)].tolist()
cv2.circle(image, center, circle_radius, color, -1)
-
+
return image
-
+
def _draw_hand_keypoints(self, image: np.ndarray, hand_landmarks: List, is_left_hand: bool = True) -> np.ndarray:
"""
Draw hand keypoints in OpenPose style - FIXED coordinate mapping
-
+
Args:
image: Input image
hand_landmarks: MediaPipe hand landmarks
is_left_hand: Whether this is the left hand
-
+
Returns:
Image with hand keypoints drawn
"""
if not hand_landmarks:
return image
-
+
h, w = image.shape[:2]
- confidence_threshold = self.params.get('confidence_threshold', 0.3)
-
+ confidence_threshold = self.params.get("confidence_threshold", 0.3)
+
# Standard hand connections (21 landmarks per hand)
hand_connections = [
# Thumb
- (0, 1), (1, 2), (2, 3), (3, 4),
- # Index finger
- (0, 5), (5, 6), (6, 7), (7, 8),
+ (0, 1),
+ (1, 2),
+ (2, 3),
+ (3, 4),
+ # Index finger
+ (0, 5),
+ (5, 6),
+ (6, 7),
+ (7, 8),
# Middle finger
- (0, 9), (9, 10), (10, 11), (11, 12),
+ (0, 9),
+ (9, 10),
+ (10, 11),
+ (11, 12),
# Ring finger
- (0, 13), (13, 14), (14, 15), (15, 16),
+ (0, 13),
+ (13, 14),
+ (14, 15),
+ (15, 16),
# Pinky
- (0, 17), (17, 18), (18, 19), (19, 20),
+ (0, 17),
+ (17, 18),
+ (18, 19),
+ (19, 20),
# Palm connections
- (5, 9), (9, 13), (13, 17),
+ (5, 9),
+ (9, 13),
+ (13, 17),
]
-
+
# OPTIMIZATION: Vectorized hand coordinate conversion
landmarks_array = np.array([[lm.x * w, lm.y * h] for lm in hand_landmarks], dtype=np.int32)
hand_points = [(int(pt[0]), int(pt[1])) for pt in landmarks_array]
-
+
# Standard hand colors
hand_color = [255, 128, 0] if is_left_hand else [0, 255, 255] # Orange for left, cyan for right
-
+
# Draw connections
for start_idx, end_idx in hand_connections:
if start_idx < len(hand_points) and end_idx < len(hand_points):
cv2.line(image, hand_points[start_idx], hand_points[end_idx], hand_color, 2)
-
+
# Draw keypoints
for point in hand_points:
cv2.circle(image, point, 3, hand_color, -1)
-
+
return image
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply MediaPipe pose detection and create OpenPose-style annotation
"""
- detect_resolution = self.params.get('detect_resolution', 512)
+ detect_resolution = self.params.get("detect_resolution", 512)
image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS)
-
+
rgb_image = cv2.cvtColor(np.array(image_resized), cv2.COLOR_BGR2RGB)
-
+
results = self.detector.process(rgb_image)
-
+
pose_image = np.zeros((detect_resolution, detect_resolution, 3), dtype=np.uint8)
-
+
if results.pose_landmarks:
openpose_keypoints = self._mediapipe_to_openpose(
- results.pose_landmarks.landmark,
- detect_resolution,
- detect_resolution
+ results.pose_landmarks.landmark, detect_resolution, detect_resolution
)
-
+
openpose_keypoints = self._apply_smoothing(openpose_keypoints, "main_pose")
-
+
pose_image = self._draw_openpose_skeleton(pose_image, openpose_keypoints)
-
- draw_hands = self.params.get('draw_hands', True)
+
+ draw_hands = self.params.get("draw_hands", True)
if draw_hands:
if results.left_hand_landmarks:
pose_image = self._draw_hand_keypoints(
pose_image, results.left_hand_landmarks.landmark, is_left_hand=True
)
-
+
if results.right_hand_landmarks:
pose_image = self._draw_hand_keypoints(
pose_image, results.right_hand_landmarks.landmark, is_left_hand=False
)
-
+
pose_pil = Image.fromarray(cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB))
-
+
return pose_pil
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU to avoid unnecessary CPU transfers
@@ -532,23 +596,23 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
pil_image = self.tensor_to_pil(image_tensor)
processed_pil = self._process_core(pil_image)
return self.pil_to_tensor(processed_pil)
-
+
def reset_smoothing_buffers(self):
"""Reset smoothing buffers (useful for new sequences)"""
print("MediaPipePosePreprocessor.reset_smoothing_buffers: Clearing smoothing buffers")
self._smoothing_buffers.clear()
-
+
def reset_tracking(self):
"""Reset MediaPipe tracking for new video sequences (when using video mode)"""
print("MediaPipePosePreprocessor.reset_tracking: Resetting MediaPipe tracking state")
- if hasattr(self, '_detector') and self._detector is not None:
+ if hasattr(self, "_detector") and self._detector is not None:
# Force detector recreation to reset tracking state
self._detector.close()
self._detector = None
self._current_options = None
self.reset_smoothing_buffers()
-
+
def __del__(self):
"""Cleanup MediaPipe detector"""
- if hasattr(self, '_detector') and self._detector is not None:
- self._detector.close()
\ No newline at end of file
+ if hasattr(self, "_detector") and self._detector is not None:
+ self._detector.close()
diff --git a/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py b/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py
index 004b250c5..0f4893f23 100644
--- a/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py
+++ b/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py
@@ -1,12 +1,16 @@
+from typing import Tuple
+
+import cv2
import numpy as np
import torch
-import cv2
from PIL import Image
-from typing import Union, Optional, List, Tuple
+
from .base import BasePreprocessor
+
try:
import mediapipe as mp
+
MEDIAPIPE_AVAILABLE = True
except ImportError:
MEDIAPIPE_AVAILABLE = False
@@ -15,11 +19,11 @@
class MediaPipeSegmentationPreprocessor(BasePreprocessor):
"""
MediaPipe-based segmentation preprocessor for ControlNet
-
+
Uses MediaPipe's Selfie Segmentation model to create accurate person segmentation masks.
Outputs binary masks suitable for ControlNet conditioning.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -30,43 +34,50 @@ def get_preprocessor_metadata(cls):
"type": "int",
"default": 1,
"range": [0, 1],
- "description": "Model type (0=general/faster, 1=landscape/better quality)"
+ "description": "Model type (0=general/faster, 1=landscape/better quality)",
},
"threshold": {
"type": "float",
"default": 0.5,
"range": [0.0, 1.0],
"step": 0.01,
- "description": "Confidence threshold for segmentation"
+ "description": "Confidence threshold for segmentation",
},
"blur_radius": {
"type": "int",
"default": 0,
"range": [0, 20],
- "description": "Blur radius for mask smoothing (0=no blur)"
+ "description": "Blur radius for mask smoothing (0=no blur)",
},
"invert_mask": {
"type": "bool",
"default": False,
- "description": "Whether to invert the segmentation mask"
- }
+ "description": "Whether to invert the segmentation mask",
+ },
},
- "use_cases": ["Precise object control", "Background replacement", "Person segmentation", "Mask generation"]
+ "use_cases": [
+ "Precise object control",
+ "Background replacement",
+ "Person segmentation",
+ "Mask generation",
+ ],
}
-
- def __init__(self,
- detect_resolution: int = 512,
- image_resolution: int = 512,
- model_selection: int = 1, # 0 for general model, 1 for landscape model
- threshold: float = 0.5,
- blur_radius: int = 0,
- invert_mask: bool = False,
- output_mode: str = "binary", # "binary", "alpha", "background"
- background_color: Tuple[int, int, int] = (0, 0, 0),
- **kwargs):
+
+ def __init__(
+ self,
+ detect_resolution: int = 512,
+ image_resolution: int = 512,
+ model_selection: int = 1, # 0 for general model, 1 for landscape model
+ threshold: float = 0.5,
+ blur_radius: int = 0,
+ invert_mask: bool = False,
+ output_mode: str = "binary", # "binary", "alpha", "background"
+ background_color: Tuple[int, int, int] = (0, 0, 0),
+ **kwargs,
+ ):
"""
Initialize MediaPipe segmentation preprocessor
-
+
Args:
detect_resolution: Resolution for segmentation processing
image_resolution: Output image resolution
@@ -83,7 +94,7 @@ def __init__(self,
"MediaPipe is required for MediaPipe segmentation preprocessing. "
"Install it with: pip install mediapipe"
)
-
+
super().__init__(
detect_resolution=detect_resolution,
image_resolution=image_resolution,
@@ -93,145 +104,145 @@ def __init__(self,
invert_mask=invert_mask,
output_mode=output_mode,
background_color=background_color,
- **kwargs
+ **kwargs,
)
-
+
self._segmentor = None
self._current_options = None
-
+
@property
def segmentor(self):
"""Lazy loading of the MediaPipe Selfie Segmentation model"""
new_options = {
- 'model_selection': self.params.get('model_selection', 1),
+ "model_selection": self.params.get("model_selection", 1),
}
-
+
# Initialize or update segmentor if needed
if self._segmentor is None or self._current_options != new_options:
if self._segmentor is not None:
self._segmentor.close()
-
- print(f"MediaPipeSegmentationPreprocessor.segmentor: Initializing MediaPipe Selfie Segmentation model")
+
+ print("MediaPipeSegmentationPreprocessor.segmentor: Initializing MediaPipe Selfie Segmentation model")
self._segmentor = mp.solutions.selfie_segmentation.SelfieSegmentation(
- model_selection=new_options['model_selection']
+ model_selection=new_options["model_selection"]
)
self._current_options = new_options
-
+
return self._segmentor
-
+
def _apply_mask_smoothing(self, mask: np.ndarray) -> np.ndarray:
"""
Apply smoothing to the segmentation mask
-
+
Args:
mask: Input segmentation mask
-
+
Returns:
Smoothed mask
"""
- blur_radius = self.params.get('blur_radius', 0)
-
+ blur_radius = self.params.get("blur_radius", 0)
+
if blur_radius > 0:
# Apply Gaussian blur for smoother edges
kernel_size = blur_radius * 2 + 1
mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)
-
+
return mask
-
+
def _threshold_mask(self, mask: np.ndarray) -> np.ndarray:
"""
Apply threshold to segmentation mask
-
+
Args:
mask: Input segmentation mask (0.0-1.0)
-
+
Returns:
Binary mask
"""
- threshold = self.params.get('threshold', 0.5)
- invert_mask = self.params.get('invert_mask', False)
-
+ threshold = self.params.get("threshold", 0.5)
+ invert_mask = self.params.get("invert_mask", False)
+
# Apply threshold
binary_mask = (mask > threshold).astype(np.uint8)
-
+
# Invert if requested
if invert_mask:
binary_mask = 1 - binary_mask
-
+
return binary_mask
-
+
def _create_output_image(self, original_image: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""
Create final output image based on output mode
-
+
Args:
original_image: Original input image
mask: Segmentation mask
-
+
Returns:
Output image
"""
- output_mode = self.params.get('output_mode', 'binary')
-
- if output_mode == 'binary':
+ output_mode = self.params.get("output_mode", "binary")
+
+ if output_mode == "binary":
# Create binary black/white mask
binary_mask = self._threshold_mask(mask)
output = np.stack([binary_mask * 255] * 3, axis=-1)
-
- elif output_mode == 'alpha':
+
+ elif output_mode == "alpha":
# Create RGBA output with alpha channel
if len(original_image.shape) == 3:
alpha = (mask * 255).astype(np.uint8)
output = np.concatenate([original_image, alpha[..., np.newaxis]], axis=-1)
else:
output = original_image
-
- elif output_mode == 'background':
+
+ elif output_mode == "background":
# Replace background with solid color
- background_color = self.params.get('background_color', (0, 0, 0))
+ background_color = self.params.get("background_color", (0, 0, 0))
binary_mask = self._threshold_mask(mask)
-
+
output = original_image.copy()
# Apply background where mask is 0
for i in range(3):
output[..., i] = np.where(binary_mask, output[..., i], background_color[i])
-
+
else:
raise ValueError(f"Unknown output_mode: {output_mode}")
-
+
return output
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply MediaPipe segmentation to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 512)
+ detect_resolution = self.params.get("detect_resolution", 512)
image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS)
-
+
rgb_image = cv2.cvtColor(np.array(image_resized), cv2.COLOR_BGR2RGB)
-
+
results = self.segmentor.process(rgb_image)
-
+
if results.segmentation_mask is not None:
mask = results.segmentation_mask
-
+
mask = self._apply_mask_smoothing(mask)
-
+
output_image = self._create_output_image(rgb_image, mask)
else:
- output_mode = self.params.get('output_mode', 'binary')
- if output_mode == 'binary':
+ output_mode = self.params.get("output_mode", "binary")
+ if output_mode == "binary":
output_image = np.zeros((detect_resolution, detect_resolution, 3), dtype=np.uint8)
else:
output_image = rgb_image
-
+
if output_image.shape[-1] == 4:
- result_pil = Image.fromarray(output_image, 'RGBA')
+ result_pil = Image.fromarray(output_image, "RGBA")
else:
- result_pil = Image.fromarray(output_image, 'RGB')
-
+ result_pil = Image.fromarray(output_image, "RGB")
+
return result_pil
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU to avoid unnecessary CPU transfers
@@ -239,8 +250,8 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
pil_image = self.tensor_to_pil(image_tensor)
processed_pil = self._process_core(pil_image)
return self.pil_to_tensor(processed_pil)
-
+
def __del__(self):
"""Cleanup MediaPipe segmentor"""
- if hasattr(self, '_segmentor') and self._segmentor is not None:
- self._segmentor.close()
\ No newline at end of file
+ if hasattr(self, "_segmentor") and self._segmentor is not None:
+ self._segmentor.close()
diff --git a/src/streamdiffusion/preprocessing/processors/openpose.py b/src/streamdiffusion/preprocessing/processors/openpose.py
index a7ba15498..53ac8afe2 100644
--- a/src/streamdiffusion/preprocessing/processors/openpose.py
+++ b/src/streamdiffusion/preprocessing/processors/openpose.py
@@ -1,16 +1,18 @@
-import numpy as np
from PIL import Image, ImageDraw
-from typing import Union, Optional, List, Tuple
+
from .base import BasePreprocessor
+
try:
import cv2
+
OPENCV_AVAILABLE = True
except ImportError:
OPENCV_AVAILABLE = False
try:
from controlnet_aux import OpenposeDetector
+
CONTROLNET_AUX_AVAILABLE = True
except ImportError:
CONTROLNET_AUX_AVAILABLE = False
@@ -19,10 +21,10 @@
class OpenPosePreprocessor(BasePreprocessor):
"""
OpenPose human pose detection preprocessor for ControlNet
-
+
Detects human poses and creates stick figure representations.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -32,26 +34,28 @@ def get_preprocessor_metadata(cls):
"include_hands": {
"type": "bool",
"default": False,
- "description": "Whether to include hand keypoints in detection"
+ "description": "Whether to include hand keypoints in detection",
},
"include_face": {
"type": "bool",
"default": False,
- "description": "Whether to include face keypoints in detection"
- }
+ "description": "Whether to include face keypoints in detection",
+ },
},
- "use_cases": ["Human pose control", "Dance movements", "Character poses"]
+ "use_cases": ["Human pose control", "Dance movements", "Character poses"],
}
-
- def __init__(self,
- detect_resolution: int = 512,
- image_resolution: int = 512,
- include_hands: bool = False,
- include_face: bool = False,
- **kwargs):
+
+ def __init__(
+ self,
+ detect_resolution: int = 512,
+ image_resolution: int = 512,
+ include_hands: bool = False,
+ include_face: bool = False,
+ **kwargs,
+ ):
"""
Initialize OpenPose preprocessor
-
+
Args:
detect_resolution: Resolution for pose detection
image_resolution: Output image resolution
@@ -64,81 +68,93 @@ def __init__(self,
image_resolution=image_resolution,
include_hands=include_hands,
include_face=include_face,
- **kwargs
+ **kwargs,
)
-
+
self._detector = None
-
+
@property
def detector(self):
"""Lazy loading of the OpenPose detector"""
if self._detector is None:
if CONTROLNET_AUX_AVAILABLE:
print("Loading OpenPose detector from controlnet_aux")
- self._detector = OpenposeDetector.from_pretrained('lllyasviel/Annotators')
+ self._detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
else:
print("Warning: controlnet_aux not available, using fallback OpenPose implementation")
self._detector = self._create_fallback_detector()
return self._detector
-
+
def _create_fallback_detector(self):
"""Create a simple fallback detector if controlnet_aux is not available"""
+
class FallbackDetector:
def __call__(self, image, include_hands=False, include_face=False):
# Simple fallback: return a blank image with some basic pose lines
width, height = image.size
- pose_image = Image.new('RGB', (width, height), (0, 0, 0))
+ pose_image = Image.new("RGB", (width, height), (0, 0, 0))
draw = ImageDraw.Draw(pose_image)
-
+
# Draw a basic stick figure in the center
center_x, center_y = width // 2, height // 2
-
+
# Head
head_radius = min(width, height) // 20
- draw.ellipse([
- center_x - head_radius, center_y - height // 4 - head_radius,
- center_x + head_radius, center_y - height // 4 + head_radius
- ], outline=(255, 255, 255), width=2)
-
+ draw.ellipse(
+ [
+ center_x - head_radius,
+ center_y - height // 4 - head_radius,
+ center_x + head_radius,
+ center_y - height // 4 + head_radius,
+ ],
+ outline=(255, 255, 255),
+ width=2,
+ )
+
# Body
body_top = center_y - height // 4 + head_radius
body_bottom = center_y + height // 6
draw.line([center_x, body_top, center_x, body_bottom], fill=(255, 255, 255), width=2)
-
+
# Arms
arm_length = width // 6
arm_y = body_top + (body_bottom - body_top) // 3
draw.line([center_x - arm_length, arm_y, center_x + arm_length, arm_y], fill=(255, 255, 255), width=2)
-
+
# Legs
leg_length = height // 8
- draw.line([center_x, body_bottom, center_x - leg_length//2, body_bottom + leg_length], fill=(255, 255, 255), width=2)
- draw.line([center_x, body_bottom, center_x + leg_length//2, body_bottom + leg_length], fill=(255, 255, 255), width=2)
-
+ draw.line(
+ [center_x, body_bottom, center_x - leg_length // 2, body_bottom + leg_length],
+ fill=(255, 255, 255),
+ width=2,
+ )
+ draw.line(
+ [center_x, body_bottom, center_x + leg_length // 2, body_bottom + leg_length],
+ fill=(255, 255, 255),
+ width=2,
+ )
+
return pose_image
-
+
return FallbackDetector()
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply OpenPose detection to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 512)
+ detect_resolution = self.params.get("detect_resolution", 512)
image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS)
-
- include_hands = self.params.get('include_hands', False)
- include_face = self.params.get('include_face', False)
-
- if CONTROLNET_AUX_AVAILABLE and hasattr(self.detector, '__call__'):
+
+ include_hands = self.params.get("include_hands", False)
+ include_face = self.params.get("include_face", False)
+
+ if CONTROLNET_AUX_AVAILABLE and hasattr(self.detector, "__call__"):
try:
- pose_image = self.detector(
- image_resized,
- hand_and_face=include_hands or include_face
- )
+ pose_image = self.detector(image_resized, hand_and_face=include_hands or include_face)
except Exception as e:
print(f"Warning: OpenPose detection failed, using fallback: {e}")
pose_image = self._create_fallback_detector()(image_resized, include_hands, include_face)
else:
pose_image = self.detector(image_resized, include_hands, include_face)
-
- return pose_image
\ No newline at end of file
+
+ return pose_image
diff --git a/src/streamdiffusion/preprocessing/processors/passthrough.py b/src/streamdiffusion/preprocessing/processors/passthrough.py
index e4d1125fe..ea81de6e4 100644
--- a/src/streamdiffusion/preprocessing/processors/passthrough.py
+++ b/src/streamdiffusion/preprocessing/processors/passthrough.py
@@ -1,55 +1,47 @@
-import numpy as np
-from PIL import Image
import torch
-from typing import Union, Optional
+from PIL import Image
+
from .base import BasePreprocessor
class PassthroughPreprocessor(BasePreprocessor):
"""
Passthrough preprocessor for ControlNet
-
+
Simply passes the input image through without any processing.
Useful for ControlNets that expect the raw input image, such as:
- Tile ControlNet
- Reference ControlNet
- Custom ControlNets that don't need preprocessing
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Passthrough",
"description": "Passes the input image through with minimal processing. Used for tile ControlNet or when you want to use the input image directly.",
- "parameters": {
-
- },
- "use_cases": ["Tile ControlNet", "Image-to-image with structure preservation", "Upscaling with control"]
+ "parameters": {},
+ "use_cases": ["Tile ControlNet", "Image-to-image with structure preservation", "Upscaling with control"],
}
-
- def __init__(self,
- image_resolution: int = 512,
- **kwargs):
+
+ def __init__(self, image_resolution: int = 512, **kwargs):
"""
Initialize passthrough preprocessor
-
+
Args:
image_resolution: Output image resolution
**kwargs: Additional parameters (ignored for passthrough)
"""
- super().__init__(
- image_resolution=image_resolution,
- **kwargs
- )
-
+ super().__init__(image_resolution=image_resolution, **kwargs)
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Pass through the input image with no processing
"""
return image
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Pass through tensor with no processing
"""
- return tensor
\ No newline at end of file
+ return tensor
diff --git a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py
index 7662c37cc..58bea9479 100644
--- a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py
+++ b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py
@@ -1,19 +1,23 @@
-#NOTE: ported from https://github.com/yuvraj108c/ComfyUI-YoloNasPose-Tensorrt
+# NOTE: ported from https://github.com/yuvraj108c/ComfyUI-YoloNasPose-Tensorrt
import os
+
+import cv2
import numpy as np
import torch
import torch.nn.functional as F
-import cv2
from PIL import Image
-from typing import Union, Optional, List, Tuple
+
from .base import BasePreprocessor
+
try:
+ from collections import OrderedDict
+
import tensorrt as trt
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes
- from collections import OrderedDict
+
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
@@ -40,7 +44,7 @@
class TensorRTEngine:
"""Simplified TensorRT engine wrapper for pose estimation inference (optimized)"""
-
+
def __init__(self, engine_path):
self.engine_path = engine_path
self.engine = None
@@ -64,13 +68,11 @@ def allocate_buffers(self, device="cuda"):
name = self.engine.get_tensor_name(idx)
shape = self.context.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
-
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
self.context.set_input_shape(name, shape)
-
- tensor = torch.empty(
- tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
- ).to(device=device)
+
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
self.tensors[name] = tensor
def infer(self, feed_dict, stream=None):
@@ -78,7 +80,7 @@ def infer(self, feed_dict, stream=None):
# Use cached stream if none provided
if stream is None:
stream = self._cuda_stream
-
+
# Copy input data to tensors
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
@@ -86,23 +88,23 @@ def infer(self, feed_dict, stream=None):
# Set tensor addresses
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
-
+
# Execute inference
success = self.context.execute_async_v3(stream)
if not success:
raise ValueError("TensorRT inference failed.")
-
+
return self.tensors
class PoseVisualization:
"""Pose drawing utilities ported from ComfyUI YoloNasPose node"""
-
+
@staticmethod
def draw_skeleton(image, keypoints, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10):
"""Draw pose skeleton on image"""
overlay = image.copy()
-
+
# Draw edges/links between keypoints
for (kp1, kp2), color in zip(edge_links, edge_colors):
if kp1 < len(keypoints) and kp2 < len(keypoints):
@@ -113,32 +115,32 @@ def draw_skeleton(image, keypoints, edge_links, edge_colors, joint_thickness=10,
p1 = (int(keypoints[kp1][0]), int(keypoints[kp1][1]))
p2 = (int(keypoints[kp2][0]), int(keypoints[kp2][1]))
cv2.line(overlay, p1, p2, color=color, thickness=joint_thickness, lineType=cv2.LINE_AA)
-
+
# Draw keypoints
for keypoint in keypoints:
if len(keypoint) >= 3 and keypoint[2] > 0.5: # confidence threshold
x, y = int(keypoint[0]), int(keypoint[1])
cv2.circle(overlay, (x, y), keypoint_radius, (0, 255, 0), -1, cv2.LINE_AA)
-
+
return cv2.addWeighted(overlay, 0.75, image, 0.25, 0)
@staticmethod
def draw_poses(image, poses, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10):
"""Draw multiple poses on image"""
result = image.copy()
-
+
for pose in poses:
result = PoseVisualization.draw_skeleton(
result, pose, edge_links, edge_colors, joint_thickness, keypoint_radius
)
-
+
return result
def iterate_over_batch_predictions(predictions, batch_size):
"""Process batch predictions from TensorRT output"""
num_detections, batch_boxes, batch_scores, batch_joints = predictions
-
+
for image_index in range(batch_size):
num_detection_in_image = int(num_detections[image_index, 0])
@@ -150,35 +152,62 @@ def iterate_over_batch_predictions(predictions, batch_size):
else:
pred_scores = batch_scores[image_index, :num_detection_in_image]
pred_boxes = batch_boxes[image_index, :num_detection_in_image]
- pred_joints = batch_joints[image_index, :num_detection_in_image].reshape(
- (num_detection_in_image, -1, 3))
+ pred_joints = batch_joints[image_index, :num_detection_in_image].reshape((num_detection_in_image, -1, 3))
yield image_index, pred_boxes, pred_scores, pred_joints
+
# precompute edge links define skeleton connections (COCO format)
-edge_links = [[0, 17], [13, 15], [14, 16], [12, 14], [12, 17], [5, 6],
- [11, 13], [7, 9], [5, 7], [17, 11], [6, 8], [8, 10],
- [1, 3], [0, 1], [0, 2], [2, 4]]
+edge_links = [
+ [0, 17],
+ [13, 15],
+ [14, 16],
+ [12, 14],
+ [12, 17],
+ [5, 6],
+ [11, 13],
+ [7, 9],
+ [5, 7],
+ [17, 11],
+ [6, 8],
+ [8, 10],
+ [1, 3],
+ [0, 1],
+ [0, 2],
+ [2, 4],
+]
edge_colors = [
- [255, 0, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [85, 255, 0],
- [85, 0, 255], [255, 170, 0], [0, 177, 58], [0, 179, 119], [179, 179, 0],
- [0, 119, 179], [0, 179, 179], [119, 0, 179], [179, 0, 179], [178, 0, 118], [178, 0, 118]
+ [255, 0, 0],
+ [255, 85, 0],
+ [170, 255, 0],
+ [85, 255, 0],
+ [85, 255, 0],
+ [85, 0, 255],
+ [255, 170, 0],
+ [0, 177, 58],
+ [0, 179, 119],
+ [179, 179, 0],
+ [0, 119, 179],
+ [0, 179, 179],
+ [119, 0, 179],
+ [179, 0, 179],
+ [178, 0, 118],
+ [178, 0, 118],
]
+
+
def show_predictions_from_batch_format(predictions):
"""Convert predictions to pose visualization format"""
try:
- image_index, pred_boxes, pred_scores, pred_joints = next(
- iter(iterate_over_batch_predictions(predictions, 1)))
+ image_index, pred_boxes, pred_scores, pred_joints = next(iter(iterate_over_batch_predictions(predictions, 1)))
except Exception as e:
raise RuntimeError(f"show_predictions_from_batch_format: Error in iterate_over_batch_predictions: {e}")
-
-
# Handle case where no poses are detected
if pred_joints.shape[0] == 0:
return np.zeros((640, 640, 3))
-
+
# Add middle joint between shoulders (keypoints 5 and 6)
try:
# Calculate middle joints for all poses at once
@@ -187,49 +216,50 @@ def show_predictions_from_batch_format(predictions):
new_pred_joints = np.concatenate([pred_joints, middle_joints[:, np.newaxis]], axis=1)
except Exception as e:
raise RuntimeError(f"show_predictions_from_batch_format: Error processing poses: {e}")
-
+
# Create black background for pose visualization
black_image = np.zeros((640, 640, 3))
-
+
try:
image = PoseVisualization.draw_poses(
- image=black_image,
- poses=new_pred_joints,
- edge_links=edge_links,
- edge_colors=edge_colors,
+ image=black_image,
+ poses=new_pred_joints,
+ edge_links=edge_links,
+ edge_colors=edge_colors,
joint_thickness=10,
- keypoint_radius=10
+ keypoint_radius=10,
)
except Exception as e:
raise RuntimeError(f"show_predictions_from_batch_format: Error in pose drawing: {e}")
-
+
return image
class YoloNasPoseTensorrtPreprocessor(BasePreprocessor):
"""
YoloNas Pose TensorRT preprocessor for ControlNet
-
+
Uses TensorRT-optimized YoloNas Pose model for fast pose estimation.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Pose Detection (TensorRT)",
"description": "Fast TensorRT-optimized pose detection using YOLO-NAS Pose model. Detects human pose keypoints with high performance.",
"parameters": {},
- "use_cases": ["Human pose control", "Character animation", "Pose-guided generation", "Real-time pose detection"]
+ "use_cases": [
+ "Human pose control",
+ "Character animation",
+ "Pose-guided generation",
+ "Real-time pose detection",
+ ],
}
-
- def __init__(self,
- engine_path: str = None,
- detect_resolution: int = 640,
- image_resolution: int = 512,
- **kwargs):
+
+ def __init__(self, engine_path: str = None, detect_resolution: int = 640, image_resolution: int = 512, **kwargs):
"""
Initialize TensorRT pose preprocessor
-
+
Args:
engine_path: Path to TensorRT engine file
detect_resolution: Resolution for pose detection (should match engine input)
@@ -241,78 +271,72 @@ def __init__(self,
"TensorRT and polygraphy libraries are required for TensorRT pose preprocessing. "
"Install them with: pip install tensorrt polygraphy"
)
-
+
super().__init__(
- engine_path=engine_path,
- detect_resolution=detect_resolution,
- image_resolution=image_resolution,
- **kwargs
+ engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs
)
-
+
self._engine = None
self._device = "cuda" if torch.cuda.is_available() else "cpu"
self._is_cuda_available = torch.cuda.is_available()
-
+
@property
def engine(self):
"""Lazy loading of the TensorRT engine"""
if self._engine is None:
- engine_path = self.params.get('engine_path')
+ engine_path = self.params.get("engine_path")
if engine_path is None:
raise ValueError(
"engine_path is required for TensorRT pose preprocessing. "
"Please provide it in the preprocessor_params config."
)
-
+
if not os.path.exists(engine_path):
raise FileNotFoundError(f"TensorRT engine not found: {engine_path}")
-
+
self._engine = TensorRTEngine(engine_path)
self._engine.load()
self._engine.activate()
self._engine.allocate_buffers()
-
+
return self._engine
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply TensorRT pose estimation to the input image
"""
- detect_resolution = self.params.get('detect_resolution', 640)
-
+ detect_resolution = self.params.get("detect_resolution", 640)
+
image_tensor = torch.from_numpy(np.array(image)).float() / 255.0
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
-
+
image_resized = F.interpolate(
- image_tensor,
- size=(detect_resolution, detect_resolution),
- mode='bilinear',
- align_corners=False
+ image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False
)
-
+
image_resized_uint8 = (image_resized * 255.0).type(torch.uint8)
-
+
if self._is_cuda_available:
image_resized_uint8 = image_resized_uint8.cuda()
-
+
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.engine.infer({"input": image_resized_uint8}, cuda_stream)
-
- predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input']
-
+
+ predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"]
+
try:
pose_image = show_predictions_from_batch_format(predictions)
except Exception:
# Fallback to black image on error
pose_image = np.zeros((detect_resolution, detect_resolution, 3))
-
+
pose_image = pose_image.clip(0, 255).astype(np.uint8)
pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB)
-
+
result = Image.fromarray(pose_image)
-
+
return result
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Process tensor directly on GPU to avoid CPU transfers
@@ -321,31 +345,30 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
image_tensor = image_tensor.unsqueeze(0)
if not image_tensor.is_cuda:
image_tensor = image_tensor.cuda()
-
- detect_resolution = self.params.get('detect_resolution', 640)
-
+
+ detect_resolution = self.params.get("detect_resolution", 640)
+
image_resized = torch.nn.functional.interpolate(
- image_tensor, size=(detect_resolution, detect_resolution),
- mode='bilinear', align_corners=False
+ image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False
)
-
+
image_resized_uint8 = (image_resized * 255.0).type(torch.uint8)
-
+
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.engine.infer({"input": image_resized_uint8}, cuda_stream)
-
- predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input']
-
+
+ predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"]
+
try:
pose_image = show_predictions_from_batch_format(predictions)
pose_image = pose_image.clip(0, 255).astype(np.uint8)
pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB)
-
+
pose_tensor = torch.from_numpy(pose_image).float() / 255.0
pose_tensor = pose_tensor.permute(2, 0, 1).unsqueeze(0).cuda()
-
+
except Exception:
# Fallback to black tensor on error
pose_tensor = torch.zeros(1, 3, detect_resolution, detect_resolution).cuda()
-
- return pose_tensor
\ No newline at end of file
+
+ return pose_tensor
diff --git a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
index 18adfbb26..cdd548765 100644
--- a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
+++ b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py
@@ -1,22 +1,23 @@
# NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt
-import os
-import torch
+import logging
+from collections import OrderedDict
+from pathlib import Path
+from typing import Tuple
+
import numpy as np
-from PIL import Image
-from typing import Optional, Tuple
import requests
+import torch
+from PIL import Image
from tqdm import tqdm
-import hashlib
-import logging
-from pathlib import Path
-from collections import OrderedDict
from .base import BasePreprocessor
+
# Try to import spandrel for model loading
try:
from spandrel import ModelLoader
+
SPANDREL_AVAILABLE = True
except ImportError:
SPANDREL_AVAILABLE = False
@@ -24,9 +25,11 @@
# Try to import TensorRT dependencies
try:
import tensorrt as trt
- from streamdiffusion.acceleration.tensorrt.utilities import engine_from_bytes, bytes_from_path
+
+ from streamdiffusion.acceleration.tensorrt.utilities import bytes_from_path, engine_from_bytes
+
TRT_AVAILABLE = True
-
+
# Numpy to PyTorch dtype mapping (same as depth_tensorrt.py)
numpy_to_torch_dtype_dict = {
np.uint8: torch.uint8,
@@ -40,27 +43,28 @@
np.complex64: torch.complex64,
np.complex128: torch.complex128,
}
-
+
# Handle bool type for numpy compatibility (same as depth_tensorrt.py)
if np.version.full_version >= "1.24.0":
numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
numpy_to_torch_dtype_dict[np.bool] = torch.bool
-
+
except ImportError:
TRT_AVAILABLE = False
class RealESRGANEngine:
"""TensorRT engine wrapper for RealESRGAN inference (following depth_tensorrt pattern)"""
-
+
def __init__(self, engine_path):
self.engine_path = engine_path
self.engine = None
self.context = None
self.tensors = OrderedDict()
-
+
import threading
+
self._inference_lock = threading.Lock()
def load(self):
@@ -79,13 +83,13 @@ def allocate_buffers(self, input_shape, device="cuda"):
# Set input shape for dynamic sizing
input_name = "input"
self.context.set_input_shape(input_name, input_shape)
-
+
# Allocate tensors for all bindings
for idx in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(idx)
shape = self.context.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
-
+
# Convert numpy dtype to torch dtype
if dtype == np.float32:
torch_dtype = torch.float32
@@ -93,7 +97,7 @@ def allocate_buffers(self, input_shape, device="cuda"):
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
-
+
tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=device)
self.tensors[name] = tensor
@@ -102,7 +106,7 @@ def infer(self, feed_dict, stream=None):
# Use provided stream or current stream context
if stream is None:
stream = torch.cuda.current_stream().cuda_stream
-
+
# Copy input data to tensors
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
@@ -111,27 +115,29 @@ def infer(self, feed_dict, stream=None):
for name, tensor in self.tensors.items():
addr = tensor.data_ptr()
self.context.set_tensor_address(name, addr)
-
+
with self._inference_lock:
success = self.context.execute_async_v3(stream)
-
+
if not success:
raise RuntimeError("RealESRGANEngine: TensorRT inference failed")
-
+
torch.cuda.synchronize()
-
+
return self.tensors
+
logger = logging.getLogger(__name__)
+
class RealESRGANProcessor(BasePreprocessor):
"""
RealESRGAN 2x upscaling processor with automatic model download, ONNX export, and TensorRT acceleration.
"""
-
+
MODEL_URL = "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true"
-
- @classmethod
+
+ @classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "RealESRGAN 2x",
@@ -140,94 +146,98 @@ def get_preprocessor_metadata(cls):
"enable_tensorrt": {
"type": "bool",
"default": True,
- "description": "Use TensorRT acceleration for faster inference"
+ "description": "Use TensorRT acceleration for faster inference",
},
"force_rebuild": {
- "type": "bool",
+ "type": "bool",
"default": False,
- "description": "Force rebuild TensorRT engine even if it exists"
- }
+ "description": "Force rebuild TensorRT engine even if it exists",
+ },
},
- "use_cases": ["High-quality upscaling", "Real-time 2x enlargement", "Image enhancement"]
+ "use_cases": ["High-quality upscaling", "Real-time 2x enlargement", "Image enhancement"],
}
-
+
def __init__(self, enable_tensorrt: bool = True, force_rebuild: bool = False, **kwargs):
super().__init__(enable_tensorrt=enable_tensorrt, force_rebuild=force_rebuild, **kwargs)
self.enable_tensorrt = enable_tensorrt and TRT_AVAILABLE
self.force_rebuild = force_rebuild
self.scale_factor = 2 # RealESRGAN 2x model
-
+
# Model paths
self.models_dir = Path("models") / "realesrgan"
self.models_dir.mkdir(parents=True, exist_ok=True)
self.model_path = self.models_dir / "RealESRGAN_x2.pth"
self.onnx_path = self.models_dir / "RealESRGAN_x2.onnx"
self.engine_path = self.models_dir / f"RealESRGAN_x2_{trt.__version__ if TRT_AVAILABLE else 'notrt'}.trt"
-
+
# Model state
self.pytorch_model = None
self._engine = None # Lazy loading like depth processor
-
+
# Thread safety for engine initialization
import threading
+
self._engine_lock = threading.Lock()
-
+
# Initialize
self._ensure_model_ready()
-
+
@property
def engine(self):
"""Lazy loading of the TensorRT engine"""
if self._engine is None:
if not self.engine_path.exists():
raise FileNotFoundError(f"TensorRT engine not found: {self.engine_path}")
-
+
self._engine = RealESRGANEngine(str(self.engine_path))
self._engine.load()
self._engine.activate()
-
+
# Allocate buffers for standard input size (will be reallocated as needed)
standard_shape = (1, 3, 512, 512)
self._engine.allocate_buffers(standard_shape, device=self.device)
-
+
return self._engine
-
+
def _download_file(self, url: str, save_path: Path):
"""Download file with progress bar"""
if save_path.exists():
return
-
+
response = requests.get(url, stream=True)
response.raise_for_status()
-
- total_size = int(response.headers.get('content-length', 0))
-
- with open(save_path, 'wb') as file, tqdm(
- desc=f"Downloading {save_path.name}",
- total=total_size,
- unit='iB',
- unit_scale=True,
- unit_divisor=1024,
- colour='green'
- ) as progress_bar:
+
+ total_size = int(response.headers.get("content-length", 0))
+
+ with (
+ open(save_path, "wb") as file,
+ tqdm(
+ desc=f"Downloading {save_path.name}",
+ total=total_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ colour="green",
+ ) as progress_bar,
+ ):
for data in response.iter_content(chunk_size=1024):
size = file.write(data)
progress_bar.update(size)
-
+
def _ensure_model_ready(self):
"""Ensure PyTorch model is downloaded and loaded"""
# Download model if needed
if not self.model_path.exists():
self._download_file(self.MODEL_URL, self.model_path)
-
+
# Load PyTorch model
if self.pytorch_model is None:
self._load_pytorch_model()
-
+
# Setup TensorRT if enabled
if self.enable_tensorrt:
self._setup_tensorrt()
-
+
def _load_pytorch_model(self):
"""Load PyTorch model from file"""
if not SPANDREL_AVAILABLE:
@@ -235,92 +245,92 @@ def _load_pytorch_model(self):
state_dict = torch.load(self.model_path, map_location=self.device)
# This is a simplified approach - real implementation would need model architecture
return
-
+
model_descriptor = ModelLoader().load_from_file(str(self.model_path))
# Don't force dtype conversion as it can cause type mismatches
# Let the model keep its native dtype and convert inputs as needed
self.pytorch_model = model_descriptor.model.eval().to(device=self.device)
model_dtype = next(self.pytorch_model.parameters()).dtype
-
+
def _export_to_onnx(self):
"""Export PyTorch model to ONNX format"""
if self.onnx_path.exists() and not self.force_rebuild:
return
-
+
if self.pytorch_model is None:
self._load_pytorch_model()
-
+
if self.pytorch_model is None:
return
-
+
# Test with small input for export
test_input = torch.randn(1, 3, 256, 256).to(self.device)
-
+
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size", 2: "height", 3: "width"},
}
-
+
with torch.no_grad():
torch.onnx.export(
self.pytorch_model,
test_input,
str(self.onnx_path),
verbose=False,
- input_names=['input'],
- output_names=['output'],
+ input_names=["input"],
+ output_names=["output"],
opset_version=17,
export_params=True,
dynamic_axes=dynamic_axes,
)
-
+
def _setup_tensorrt(self):
"""Setup TensorRT engine"""
if not TRT_AVAILABLE:
return
-
+
# Export to ONNX first if needed
if not self.onnx_path.exists():
self._export_to_onnx()
-
+
# Build/load TensorRT engine
self._load_tensorrt_engine()
-
+
def _load_tensorrt_engine(self):
"""Load or build TensorRT engine"""
if self.engine_path.exists() and not self.force_rebuild:
self._load_existing_engine()
else:
self._build_tensorrt_engine()
-
+
def _load_existing_engine(self):
"""Load existing TensorRT engine (now handled by lazy loading property)"""
# Engine loading is now handled by the lazy loading 'engine' property
# This method is kept for compatibility but does nothing
pass
-
+
def _build_tensorrt_engine(self):
"""Build TensorRT engine from ONNX model"""
if not self.onnx_path.exists():
return
-
+
try:
# Create builder and network
builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x
parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING))
-
+
# Parse ONNX model
- with open(self.onnx_path, 'rb') as model:
+ with open(self.onnx_path, "rb") as model:
if not parser.parse(model.read()):
for error in range(parser.num_errors):
pass
return
-
+
# Configure builder
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16) # Enable FP16 for better performance
-
+
# Set optimization profile for dynamic shapes
profile = builder.create_optimization_profile()
min_shape = (1, 3, 256, 256)
@@ -328,87 +338,87 @@ def _build_tensorrt_engine(self):
max_shape = (1, 3, 1024, 1024)
profile.set_shape("input", min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
-
+
# Build engine
engine = builder.build_serialized_network(network, config)
-
+
if engine is None:
return
-
+
# Save engine
- with open(self.engine_path, 'wb') as f:
+ with open(self.engine_path, "wb") as f:
f.write(engine)
-
+
# Load the built engine
self._load_existing_engine()
-
- except Exception as e:
+
+ except Exception:
pass
-
+
def _process_with_tensorrt(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process tensor using TensorRT engine (following depth_tensorrt pattern)"""
batch_size, channels, height, width = tensor.shape
input_shape = (batch_size, channels, height, width)
-
+
# Ensure buffers are allocated for this input shape
- if not hasattr(self.engine, 'tensors') or len(self.engine.tensors) == 0:
+ if not hasattr(self.engine, "tensors") or len(self.engine.tensors) == 0:
self.engine.allocate_buffers(input_shape, device=self.device)
else:
# Check if we need to reallocate for different input shape
input_tensor_shape = self.engine.tensors.get("input", torch.empty(0)).shape
if input_tensor_shape != input_shape:
self.engine.allocate_buffers(input_shape, device=self.device)
-
+
# Prepare input tensor
input_tensor = tensor.contiguous()
if input_tensor.dtype != self.engine.tensors["input"].dtype:
input_tensor = input_tensor.to(dtype=self.engine.tensors["input"].dtype)
-
+
# Use engine inference with current stream context for proper synchronization
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.engine.infer({"input": input_tensor}, cuda_stream)
- output_tensor = result['output']
-
+ output_tensor = result["output"]
+
# Ensure output is properly clamped to [0, 1] range for RealESRGAN
output_tensor = torch.clamp(output_tensor, 0.0, 1.0)
-
+
return output_tensor.clone()
-
+
def _process_with_pytorch(self, tensor: torch.Tensor) -> torch.Tensor:
"""Process tensor using PyTorch model"""
if self.pytorch_model is None:
raise RuntimeError("_process_with_pytorch: PyTorch model not loaded")
-
+
# Ensure model and input tensor have compatible dtypes
model_dtype = next(self.pytorch_model.parameters()).dtype
original_dtype = tensor.dtype
if tensor.dtype != model_dtype:
tensor = tensor.to(dtype=model_dtype)
-
+
with torch.no_grad():
result = self.pytorch_model(tensor)
-
+
# Ensure output is properly clamped to [0, 1] range for RealESRGAN
result = torch.clamp(result, 0.0, 1.0)
-
+
# Convert result to the desired output dtype (self.dtype)
if result.dtype != self.dtype:
result = result.to(dtype=self.dtype)
-
+
return result
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""Core processing using PIL Image"""
# Convert to tensor for processing
tensor = self.pil_to_tensor(image)
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
-
+
# Process with available backend
if self.enable_tensorrt and TRT_AVAILABLE and self.engine_path.exists():
try:
output_tensor = self._process_with_tensorrt(tensor)
- except Exception as e:
+ except Exception:
output_tensor = self._process_with_pytorch(tensor)
elif self.pytorch_model is not None:
output_tensor = self._process_with_pytorch(tensor)
@@ -416,29 +426,29 @@ def _process_core(self, image: Image.Image) -> Image.Image:
# Fallback to simple upscaling if no model is available
target_width, target_height = self.get_target_dimensions()
return image.resize((target_width, target_height), Image.LANCZOS)
-
+
# Convert back to PIL
if output_tensor.dim() == 4:
output_tensor = output_tensor.squeeze(0)
-
+
result_image = self.tensor_to_pil(output_tensor)
-
+
return result_image
-
+
def _ensure_target_size(self, image: Image.Image) -> Image.Image:
"""
Override base class method - for upscaling, we want to keep the upscaled size
Don't resize back to original dimensions
"""
return image
-
+
def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Override base class method - for upscaling, we want to keep the upscaled size
Don't resize back to original dimensions
"""
return tensor
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""Core tensor processing"""
if tensor.dim() == 3:
@@ -446,49 +456,46 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
squeeze_output = True
else:
squeeze_output = False
-
+
# Process with available backend
if self.enable_tensorrt and TRT_AVAILABLE and self.engine_path.exists():
try:
output_tensor = self._process_with_tensorrt(tensor)
- except Exception as e:
+ except Exception:
output_tensor = self._process_with_pytorch(tensor)
elif self.pytorch_model is not None:
output_tensor = self._process_with_pytorch(tensor)
else:
# Fallback using interpolation
output_tensor = torch.nn.functional.interpolate(
- tensor,
- scale_factor=self.scale_factor,
- mode='bicubic',
- align_corners=False
+ tensor, scale_factor=self.scale_factor, mode="bicubic", align_corners=False
)
-
+
if squeeze_output:
output_tensor = output_tensor.squeeze(0)
-
+
return output_tensor
-
+
def get_target_dimensions(self) -> Tuple[int, int]:
"""Get target output dimensions (width, height) - 2x upscaled"""
- width = self.params.get('image_width')
- height = self.params.get('image_height')
-
+ width = self.params.get("image_width")
+ height = self.params.get("image_height")
+
if width is not None and height is not None:
target_dims = (width * self.scale_factor, height * self.scale_factor)
return target_dims
-
+
# Fallback to square resolution
- resolution = self.params.get('image_resolution', 512)
+ resolution = self.params.get("image_resolution", 512)
target_resolution = resolution * self.scale_factor
target_dims = (target_resolution, target_resolution)
return target_dims
-
+
def __del__(self):
"""Cleanup resources"""
- if hasattr(self, '_engine') and self._engine is not None:
+ if hasattr(self, "_engine") and self._engine is not None:
# Cleanup dedicated stream if it exists
- if hasattr(self._engine, '_dedicated_stream'):
+ if hasattr(self._engine, "_dedicated_stream"):
torch.cuda.synchronize()
del self._engine._dedicated_stream
del self._engine
diff --git a/src/streamdiffusion/preprocessing/processors/sharpen.py b/src/streamdiffusion/preprocessing/processors/sharpen.py
index 9660e1cee..05d36fc2d 100644
--- a/src/streamdiffusion/preprocessing/processors/sharpen.py
+++ b/src/streamdiffusion/preprocessing/processors/sharpen.py
@@ -1,22 +1,21 @@
import torch
import torch.nn.functional as F
-import numpy as np
from PIL import Image
-from typing import Union
+
from .base import BasePreprocessor
class SharpenPreprocessor(BasePreprocessor):
"""
GPU-heavy image sharpening preprocessor using unsharp masking and edge enhancement
-
+
Applies sophisticated sharpening using multiple Gaussian operations:
- Multi-scale unsharp masking
- Edge-preserving enhancement
- Laplacian-based detail enhancement
- All operations performed on GPU for maximum performance
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -27,52 +26,54 @@ def get_preprocessor_metadata(cls):
"type": "float",
"default": 1.5,
"range": [0.1, 5.0],
- "description": "Overall sharpening intensity. Higher values create stronger effects."
+ "description": "Overall sharpening intensity. Higher values create stronger effects.",
},
"unsharp_radius": {
"type": "float",
"default": 1.0,
"range": [0.1, 5.0],
- "description": "Radius for unsharp masking blur. Affects detail scale."
+ "description": "Radius for unsharp masking blur. Affects detail scale.",
},
"edge_enhancement": {
"type": "float",
"default": 0.5,
"range": [0.0, 2.0],
- "description": "Edge enhancement factor. Emphasizes image boundaries."
+ "description": "Edge enhancement factor. Emphasizes image boundaries.",
},
"detail_boost": {
"type": "float",
"default": 0.3,
"range": [0.0, 1.0],
- "description": "Fine detail enhancement using Laplacian filtering."
+ "description": "Fine detail enhancement using Laplacian filtering.",
},
"noise_reduction": {
"type": "float",
"default": 0.1,
"range": [0.0, 0.5],
- "description": "Mild noise reduction to prevent amplification."
+ "description": "Mild noise reduction to prevent amplification.",
},
"multi_scale": {
"type": "bool",
"default": True,
- "description": "Use multi-scale processing for better quality (more GPU intensive)."
- }
+ "description": "Use multi-scale processing for better quality (more GPU intensive).",
+ },
},
- "use_cases": ["Detail enhancement", "Photo sharpening", "Edge definition", "Clarity improvement"]
+ "use_cases": ["Detail enhancement", "Photo sharpening", "Edge definition", "Clarity improvement"],
}
-
- def __init__(self,
- sharpen_intensity: float = 1.5,
- unsharp_radius: float = 1.0,
- edge_enhancement: float = 0.5,
- detail_boost: float = 0.3,
- noise_reduction: float = 0.1,
- multi_scale: bool = True,
- **kwargs):
+
+ def __init__(
+ self,
+ sharpen_intensity: float = 1.5,
+ unsharp_radius: float = 1.0,
+ edge_enhancement: float = 0.5,
+ detail_boost: float = 0.3,
+ noise_reduction: float = 0.1,
+ multi_scale: bool = True,
+ **kwargs,
+ ):
"""
Initialize Sharpen preprocessor
-
+
Args:
sharpen_intensity: Overall sharpening strength
unsharp_radius: Blur radius for unsharp masking
@@ -89,194 +90,182 @@ def __init__(self,
detail_boost=detail_boost,
noise_reduction=noise_reduction,
multi_scale=multi_scale,
- **kwargs
+ **kwargs,
)
-
+
# Cache kernels for efficiency
self._cached_gaussian_kernels = {}
self._cached_laplacian_kernel = None
self._cached_edge_kernels = None
-
+
def _create_gaussian_kernel(self, size: int, sigma: float) -> torch.Tensor:
"""Create 2D Gaussian kernel"""
coords = torch.arange(size, dtype=self.dtype, device=self.device)
coords = coords - (size - 1) / 2
- y_grid, x_grid = torch.meshgrid(coords, coords, indexing='ij')
+ y_grid, x_grid = torch.meshgrid(coords, coords, indexing="ij")
gaussian = torch.exp(-(x_grid**2 + y_grid**2) / (2 * sigma**2))
return gaussian / gaussian.sum()
-
+
def _get_gaussian_kernel(self, sigma: float) -> torch.Tensor:
"""Get cached Gaussian kernel"""
# Calculate appropriate kernel size (6 sigma rule)
size = max(3, int(6 * sigma + 1))
if size % 2 == 0:
size += 1
-
+
key = (size, sigma)
if key not in self._cached_gaussian_kernels:
self._cached_gaussian_kernels[key] = self._create_gaussian_kernel(size, sigma)
-
+
return self._cached_gaussian_kernels[key]
-
+
def _create_laplacian_kernel(self) -> torch.Tensor:
"""Create Laplacian kernel for edge detection"""
- kernel = torch.tensor([
- [0, -1, 0],
- [-1, 4, -1],
- [0, -1, 0]
- ], dtype=self.dtype, device=self.device)
+ kernel = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=self.dtype, device=self.device)
return kernel
-
+
def _get_laplacian_kernel(self) -> torch.Tensor:
"""Get cached Laplacian kernel"""
if self._cached_laplacian_kernel is None:
self._cached_laplacian_kernel = self._create_laplacian_kernel()
return self._cached_laplacian_kernel
-
+
def _create_edge_kernels(self) -> tuple:
"""Create Sobel edge detection kernels"""
- sobel_x = torch.tensor([
- [-1, 0, 1],
- [-2, 0, 2],
- [-1, 0, 1]
- ], dtype=self.dtype, device=self.device)
-
- sobel_y = torch.tensor([
- [-1, -2, -1],
- [0, 0, 0],
- [1, 2, 1]
- ], dtype=self.dtype, device=self.device)
-
+ sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=self.dtype, device=self.device)
+
+ sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=self.dtype, device=self.device)
+
return sobel_x, sobel_y
-
+
def _get_edge_kernels(self) -> tuple:
"""Get cached edge kernels"""
if self._cached_edge_kernels is None:
self._cached_edge_kernels = self._create_edge_kernels()
return self._cached_edge_kernels
-
+
def _apply_kernel(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
"""Apply convolution kernel to image"""
num_channels = image.shape[1]
padding = kernel.shape[-1] // 2
-
+
# Expand kernel for all channels
kernel_conv = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1)
-
+
return F.conv2d(image, kernel_conv, padding=padding, groups=num_channels)
-
+
def _gaussian_blur(self, image: torch.Tensor, sigma: float) -> torch.Tensor:
"""Apply Gaussian blur"""
kernel = self._get_gaussian_kernel(sigma)
return self._apply_kernel(image, kernel)
-
+
def _unsharp_mask(self, image: torch.Tensor, radius: float, intensity: float) -> torch.Tensor:
"""Apply unsharp masking"""
# Create blurred version
blurred = self._gaussian_blur(image, radius)
-
+
# Create mask (original - blurred)
mask = image - blurred
-
+
# Apply sharpening
sharpened = image + intensity * mask
-
+
return torch.clamp(sharpened, 0, 1)
-
+
def _edge_enhancement(self, image: torch.Tensor, strength: float) -> torch.Tensor:
"""Enhance edges using Sobel operators"""
sobel_x, sobel_y = self._get_edge_kernels()
-
+
# Calculate gradients
grad_x = self._apply_kernel(image, sobel_x)
grad_y = self._apply_kernel(image, sobel_y)
-
+
# Calculate edge magnitude
edge_magnitude = torch.sqrt(grad_x**2 + grad_y**2)
-
+
# Enhance edges
enhanced = image + strength * edge_magnitude
-
+
return torch.clamp(enhanced, 0, 1)
-
+
def _detail_enhancement(self, image: torch.Tensor, strength: float) -> torch.Tensor:
"""Enhance fine details using Laplacian"""
laplacian = self._get_laplacian_kernel()
-
+
# Apply Laplacian filter
details = self._apply_kernel(image, laplacian)
-
+
# Add details back to image
enhanced = image + strength * details
-
+
return torch.clamp(enhanced, 0, 1)
-
+
def _noise_reduction_light(self, image: torch.Tensor, strength: float) -> torch.Tensor:
"""Light noise reduction using small Gaussian blur"""
if strength <= 0:
return image
-
+
# Very light blur to reduce noise
noise_reduced = self._gaussian_blur(image, 0.3)
-
+
# Blend with original
return (1 - strength) * image + strength * noise_reduced
-
+
def _multi_scale_sharpen(self, image: torch.Tensor) -> torch.Tensor:
"""Apply multi-scale sharpening for better quality"""
- sharpen_intensity = self.params.get('sharpen_intensity', 1.5)
- unsharp_radius = self.params.get('unsharp_radius', 1.0)
-
+ sharpen_intensity = self.params.get("sharpen_intensity", 1.5)
+ unsharp_radius = self.params.get("unsharp_radius", 1.0)
+
# Multiple scales for better quality
scales = [unsharp_radius * 0.5, unsharp_radius, unsharp_radius * 2.0]
weights = [0.3, 0.5, 0.2]
-
+
result = image.clone()
-
+
for scale, weight in zip(scales, weights):
# Apply unsharp mask at this scale
sharpened_scale = self._unsharp_mask(image, scale, sharpen_intensity * weight)
-
+
# Blend with result
result = result + weight * (sharpened_scale - image)
-
+
return torch.clamp(result, 0, 1)
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""Apply sharpening using PIL/numpy fallback"""
# Convert to tensor for GPU processing
tensor = self.pil_to_tensor(image)
tensor = tensor.squeeze(0) # Remove batch dimension
-
+
# Process on GPU
sharpened = self._process_tensor_core(tensor)
-
+
# Convert back to PIL
return self.tensor_to_pil(sharpened)
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""GPU-intensive sharpening processing"""
# Ensure batch dimension
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
-
+
# Ensure correct device and dtype
image_tensor = image_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Get parameters
- sharpen_intensity = self.params.get('sharpen_intensity', 1.5)
- unsharp_radius = self.params.get('unsharp_radius', 1.0)
- edge_enhancement = self.params.get('edge_enhancement', 0.5)
- detail_boost = self.params.get('detail_boost', 0.3)
- noise_reduction = self.params.get('noise_reduction', 0.1)
- multi_scale = self.params.get('multi_scale', True)
-
+ sharpen_intensity = self.params.get("sharpen_intensity", 1.5)
+ unsharp_radius = self.params.get("unsharp_radius", 1.0)
+ edge_enhancement = self.params.get("edge_enhancement", 0.5)
+ detail_boost = self.params.get("detail_boost", 0.3)
+ noise_reduction = self.params.get("noise_reduction", 0.1)
+ multi_scale = self.params.get("multi_scale", True)
+
result = image_tensor.clone()
-
+
# Step 1: Light noise reduction (prevent amplification)
if noise_reduction > 0:
result = self._noise_reduction_light(result, noise_reduction)
-
+
# Step 2: Main sharpening
if multi_scale:
# Multi-scale processing (more GPU intensive)
@@ -284,16 +273,16 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
else:
# Single-scale unsharp masking
result = self._unsharp_mask(result, unsharp_radius, sharpen_intensity)
-
+
# Step 3: Edge enhancement
if edge_enhancement > 0:
result = self._edge_enhancement(result, edge_enhancement)
-
+
# Step 4: Fine detail enhancement
if detail_boost > 0:
result = self._detail_enhancement(result, detail_boost)
-
+
# Final clamp to ensure valid range
result = torch.clamp(result, 0, 1)
-
+
return result
diff --git a/src/streamdiffusion/preprocessing/processors/soft_edge.py b/src/streamdiffusion/preprocessing/processors/soft_edge.py
index 67537982b..abbf37b88 100644
--- a/src/streamdiffusion/preprocessing/processors/soft_edge.py
+++ b/src/streamdiffusion/preprocessing/processors/soft_edge.py
@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
from PIL import Image
-from typing import Union, Optional
+
from .base import BasePreprocessor
@@ -12,95 +10,100 @@ class MultiScaleSobelOperator(nn.Module):
Real-time multi-scale Sobel edge detector optimized for soft HED-like edges
Based on the existing SobelOperator but enhanced for soft edge detection
"""
-
+
def __init__(self, device="cuda", dtype=torch.float16):
super(MultiScaleSobelOperator, self).__init__()
self.device = device
self.dtype = dtype
-
+
# Multi-scale edge detection (3 scales)
self.edge_conv_x_1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(device)
self.edge_conv_y_1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(device)
-
+
self.edge_conv_x_2 = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device)
self.edge_conv_y_2 = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device)
-
+
self.edge_conv_x_3 = nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False).to(device)
self.edge_conv_y_3 = nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False).to(device)
-
+
# Gaussian blur for soft edges
self.blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device)
-
+
self._setup_kernels()
-
+
def _setup_kernels(self):
"""Setup Sobel kernels for different scales"""
# Scale 1: Standard 3x3 Sobel
- sobel_x_3 = torch.tensor([
- [-1.0, 0.0, 1.0],
- [-2.0, 0.0, 2.0],
- [-1.0, 0.0, 1.0]
- ], device=self.device, dtype=self.dtype)
-
- sobel_y_3 = torch.tensor([
- [-1.0, -2.0, -1.0],
- [0.0, 0.0, 0.0],
- [1.0, 2.0, 1.0]
- ], device=self.device, dtype=self.dtype)
-
+ sobel_x_3 = torch.tensor(
+ [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=self.device, dtype=self.dtype
+ )
+
+ sobel_y_3 = torch.tensor(
+ [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], device=self.device, dtype=self.dtype
+ )
+
# Scale 2: 5x5 Sobel
- sobel_x_5 = torch.tensor([
- [-1, -2, 0, 2, 1],
- [-2, -3, 0, 3, 2],
- [-3, -5, 0, 5, 3],
- [-2, -3, 0, 3, 2],
- [-1, -2, 0, 2, 1]
- ], device=self.device, dtype=self.dtype) / 16.0
-
+ sobel_x_5 = (
+ torch.tensor(
+ [[-1, -2, 0, 2, 1], [-2, -3, 0, 3, 2], [-3, -5, 0, 5, 3], [-2, -3, 0, 3, 2], [-1, -2, 0, 2, 1]],
+ device=self.device,
+ dtype=self.dtype,
+ )
+ / 16.0
+ )
+
sobel_y_5 = sobel_x_5.T
-
+
# Scale 3: 7x7 Sobel (smoothed)
- sobel_x_7 = torch.tensor([
- [-1, -2, -3, 0, 3, 2, 1],
- [-2, -3, -4, 0, 4, 3, 2],
- [-3, -4, -5, 0, 5, 4, 3],
- [-4, -5, -6, 0, 6, 5, 4],
- [-3, -4, -5, 0, 5, 4, 3],
- [-2, -3, -4, 0, 4, 3, 2],
- [-1, -2, -3, 0, 3, 2, 1]
- ], device=self.device, dtype=self.dtype) / 32.0
-
+ sobel_x_7 = (
+ torch.tensor(
+ [
+ [-1, -2, -3, 0, 3, 2, 1],
+ [-2, -3, -4, 0, 4, 3, 2],
+ [-3, -4, -5, 0, 5, 4, 3],
+ [-4, -5, -6, 0, 6, 5, 4],
+ [-3, -4, -5, 0, 5, 4, 3],
+ [-2, -3, -4, 0, 4, 3, 2],
+ [-1, -2, -3, 0, 3, 2, 1],
+ ],
+ device=self.device,
+ dtype=self.dtype,
+ )
+ / 32.0
+ )
+
sobel_y_7 = sobel_x_7.T
-
+
# Gaussian kernel for smoothing
- gaussian_5 = torch.tensor([
- [1, 4, 6, 4, 1],
- [4, 16, 24, 16, 4],
- [6, 24, 36, 24, 6],
- [4, 16, 24, 16, 4],
- [1, 4, 6, 4, 1]
- ], device=self.device, dtype=self.dtype) / 256.0
-
+ gaussian_5 = (
+ torch.tensor(
+ [[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]],
+ device=self.device,
+ dtype=self.dtype,
+ )
+ / 256.0
+ )
+
# Set kernel weights
self.edge_conv_x_1.weight = nn.Parameter(sobel_x_3.view(1, 1, 3, 3))
self.edge_conv_y_1.weight = nn.Parameter(sobel_y_3.view(1, 1, 3, 3))
-
+
self.edge_conv_x_2.weight = nn.Parameter(sobel_x_5.view(1, 1, 5, 5))
self.edge_conv_y_2.weight = nn.Parameter(sobel_y_5.view(1, 1, 5, 5))
-
+
self.edge_conv_x_3.weight = nn.Parameter(sobel_x_7.view(1, 1, 7, 7))
self.edge_conv_y_3.weight = nn.Parameter(sobel_y_7.view(1, 1, 7, 7))
-
+
self.blur.weight = nn.Parameter(gaussian_5.view(1, 1, 5, 5))
@torch.no_grad()
def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Fast multi-scale soft edge detection
-
+
Args:
image_tensor: Input tensor [B, C, H, W] or [C, H, W]
-
+
Returns:
Soft edge map tensor [B, 1, H, W] or [1, H, W]
"""
@@ -109,108 +112,108 @@ def forward(self, image_tensor: torch.Tensor) -> torch.Tensor:
if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0)
squeeze_output = True
-
+
# Convert to grayscale if needed
if image_tensor.shape[1] == 3:
# RGB to grayscale
gray = 0.299 * image_tensor[:, 0:1] + 0.587 * image_tensor[:, 1:2] + 0.114 * image_tensor[:, 2:3]
else:
gray = image_tensor[:, 0:1]
-
+
# Multi-scale edge detection
# Scale 1 (fine details)
edge_x1 = self.edge_conv_x_1(gray)
edge_y1 = self.edge_conv_y_1(gray)
edge1 = torch.sqrt(edge_x1**2 + edge_y1**2)
-
+
# Scale 2 (medium details)
edge_x2 = self.edge_conv_x_2(gray)
edge_y2 = self.edge_conv_y_2(gray)
edge2 = torch.sqrt(edge_x2**2 + edge_y2**2)
-
+
# Scale 3 (coarse details)
edge_x3 = self.edge_conv_x_3(gray)
edge_y3 = self.edge_conv_y_3(gray)
edge3 = torch.sqrt(edge_x3**2 + edge_y3**2)
-
+
# Combine scales with weights (like HED side outputs)
combined_edge = 0.5 * edge1 + 0.3 * edge2 + 0.2 * edge3
-
+
# Apply Gaussian smoothing for soft edges
soft_edge = self.blur(combined_edge)
-
+
# Normalize to [0, 1] range
soft_edge = soft_edge / (soft_edge.max() + 1e-8)
-
+
# Apply soft sigmoid activation for smooth transitions
soft_edge = torch.sigmoid(soft_edge * 6.0 - 3.0) # Soft S-curve
-
+
if squeeze_output:
soft_edge = soft_edge.squeeze(0)
-
+
return soft_edge
class SoftEdgePreprocessor(BasePreprocessor):
"""
Real-time soft edge detection preprocessor - HED alternative
-
+
Uses multi-scale Sobel operations for extremely fast soft edge detection
that mimics HED output quality at 50x+ the speed.
"""
-
+
_model_cache = {}
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Soft Edge Detection",
"description": "Real-time soft edge detection optimized for smooth, artistic edge maps using multi-scale Sobel operations.",
"parameters": {},
- "use_cases": ["Artistic edge maps", "Soft stylistic control", "Real-time edge detection"]
+ "use_cases": ["Artistic edge maps", "Soft stylistic control", "Real-time edge detection"],
}
-
+
def __init__(self, **kwargs):
"""
Initialize soft edge preprocessor
-
+
Args:
**kwargs: Additional parameters
"""
super().__init__(**kwargs)
self.model = None
self._load_model()
-
+
def _load_model(self):
"""
Load multi-scale Sobel operator with caching
"""
cache_key = f"soft_edge_{self.device}_{self.dtype}"
-
+
if cache_key in self._model_cache:
self.model = self._model_cache[cache_key]
return
-
+
print("SoftEdgePreprocessor: Loading real-time multi-scale edge detector")
self.model = MultiScaleSobelOperator(device=self.device, dtype=self.dtype)
self.model.eval()
-
+
# Cache the model
self._model_cache[cache_key] = self.model
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply soft edge detection to the input image
"""
# Convert PIL to tensor for GPU processing
image_tensor = self.pil_to_tensor(image).squeeze(0) # Remove batch dim
-
+
# Process with GPU-accelerated tensor method
processed_tensor = self._process_tensor_core(image_tensor)
-
+
# Convert back to PIL
return self.tensor_to_pil(processed_tensor)
-
+
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
GPU-optimized soft edge processing using tensors
@@ -218,25 +221,25 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
# Ensure correct input format and device
image_tensor = image_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Normalize to [0, 1] if needed
if image_tensor.max() > 1.0:
image_tensor = image_tensor / 255.0
-
+
# Multi-scale edge detection
edge_map = self.model(image_tensor)
-
+
# Convert to 3-channel RGB format
if edge_map.dim() == 3:
edge_map = edge_map.repeat(3, 1, 1)
else:
edge_map = edge_map.repeat(1, 3, 1, 1).squeeze(0)
-
+
# Ensure output is in [0, 1] range
edge_map = torch.clamp(edge_map, 0.0, 1.0)
-
+
return edge_map
-
+
def get_model_info(self) -> dict:
"""
Get information about the loaded model
@@ -248,24 +251,20 @@ def get_model_info(self) -> dict:
"device": str(self.device),
"dtype": str(self.dtype),
"description": "Real-time multi-scale soft edge detection, HED quality at 50x+ speed",
- "expected_fps": "100+ FPS at 512x512"
+ "expected_fps": "100+ FPS at 512x512",
}
-
+
@classmethod
- def create_optimized(cls, device: str = 'cuda', dtype: torch.dtype = torch.float16, **kwargs):
+ def create_optimized(cls, device: str = "cuda", dtype: torch.dtype = torch.float16, **kwargs):
"""
Create an optimized soft edge preprocessor for real-time use
-
+
Args:
device: Target device ('cuda' or 'cpu')
dtype: Data type for inference
**kwargs: Additional parameters
-
+
Returns:
Optimized SoftEdgePreprocessor instance
"""
- return cls(
- device=device,
- dtype=dtype,
- **kwargs
- )
\ No newline at end of file
+ return cls(device=device, dtype=dtype, **kwargs)
diff --git a/src/streamdiffusion/preprocessing/processors/standard_lineart.py b/src/streamdiffusion/preprocessing/processors/standard_lineart.py
index bc732ea04..81a8ade1a 100644
--- a/src/streamdiffusion/preprocessing/processors/standard_lineart.py
+++ b/src/streamdiffusion/preprocessing/processors/standard_lineart.py
@@ -1,22 +1,22 @@
-import numpy as np
-import cv2
-from PIL import Image
-from typing import Union, Optional
import time
-from .base import BasePreprocessor
+
+import numpy as np
import torch
import torch.nn.functional as F
+from PIL import Image
+
+from .base import BasePreprocessor
class StandardLineartPreprocessor(BasePreprocessor):
"""
Real-time optimized Standard Lineart detection preprocessor for ControlNet
-
+
Extracts line art from input images using traditional computer vision techniques.
Uses Gaussian blur and intensity calculations to detect lines without requiring
pre-trained models. GPU-accelerated with PyTorch for optimal real-time performance.
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -28,27 +28,29 @@ def get_preprocessor_metadata(cls):
"default": 6.0,
"range": [1.0, 20.0],
"step": 0.1,
- "description": "Standard deviation for Gaussian blur (higher = smoother lines)"
+ "description": "Standard deviation for Gaussian blur (higher = smoother lines)",
},
"intensity_threshold": {
"type": "int",
"default": 8,
"range": [1, 50],
- "description": "Threshold for intensity calculation (lower = more sensitive)"
- }
+ "description": "Threshold for intensity calculation (lower = more sensitive)",
+ },
},
- "use_cases": ["Traditional line art", "Simple edge detection", "No AI model required"]
+ "use_cases": ["Traditional line art", "Simple edge detection", "No AI model required"],
}
-
- def __init__(self,
- detect_resolution: int = 512,
- image_resolution: int = 512,
- gaussian_sigma: float = 6.0,
- intensity_threshold: int = 8,
- **kwargs):
+
+ def __init__(
+ self,
+ detect_resolution: int = 512,
+ image_resolution: int = 512,
+ gaussian_sigma: float = 6.0,
+ intensity_threshold: int = 8,
+ **kwargs,
+ ):
"""
Initialize Standard Lineart preprocessor
-
+
Args:
detect_resolution: Resolution for line art detection
image_resolution: Output image resolution
@@ -56,39 +58,39 @@ def __init__(self,
intensity_threshold: Threshold for intensity calculation
**kwargs: Additional parameters
"""
-
+
super().__init__(
detect_resolution=detect_resolution,
image_resolution=image_resolution,
gaussian_sigma=gaussian_sigma,
intensity_threshold=intensity_threshold,
- **kwargs
+ **kwargs,
)
-
+
# Initialize GPU device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+
def _gaussian_kernel(self, kernel_size: int, sigma: float, device=None) -> torch.Tensor:
"""Create 2D Gaussian kernel - based on existing codebase pattern"""
x, y = torch.meshgrid(
- torch.linspace(-1, 1, kernel_size, device=device),
- torch.linspace(-1, 1, kernel_size, device=device),
- indexing="ij"
+ torch.linspace(-1, 1, kernel_size, device=device),
+ torch.linspace(-1, 1, kernel_size, device=device),
+ indexing="ij",
)
d = torch.sqrt(x * x + y * y)
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
return g / g.sum()
-
+
def _gaussian_blur_torch(self, image: torch.Tensor, sigma: float) -> torch.Tensor:
"""Apply Gaussian blur using PyTorch - GPU accelerated"""
# Calculate kernel size from sigma (odd number)
kernel_size = int(2 * torch.ceil(torch.tensor(3 * sigma)) + 1)
if kernel_size % 2 == 0:
kernel_size += 1
-
+
# Create Gaussian kernel
kernel = self._gaussian_kernel(kernel_size, sigma, device=image.device)
-
+
# Handle different input shapes
if image.dim() == 3: # HWC format
H, W, C = image.shape
@@ -100,31 +102,31 @@ def _gaussian_blur_torch(self, image: torch.Tensor, sigma: float) -> torch.Tenso
needs_reshape = False
else:
raise ValueError(f"standardlineart_gaussian_blur_torch: Unsupported image shape: {image.shape}")
-
+
# Expand kernel for all channels
kernel = kernel.repeat(image.shape[1], 1, 1).unsqueeze(1)
-
+
# Apply blur with reflection padding
padding = kernel_size // 2
- padded_image = F.pad(image, (padding, padding, padding, padding), 'reflect')
+ padded_image = F.pad(image, (padding, padding, padding, padding), "reflect")
blurred = F.conv2d(padded_image, kernel, padding=0, groups=image.shape[1])
-
+
# Convert back to original format if needed
if needs_reshape:
blurred = blurred.squeeze(0).permute(1, 2, 0) # BCHW -> HWC
-
+
return blurred
-
+
def _ensure_hwc3_torch(self, x: torch.Tensor) -> torch.Tensor:
"""Ensure image has 3 channels (HWC3 format) - PyTorch version"""
if x.dim() == 2:
x = x.unsqueeze(-1) # Add channel dimension
-
+
if x.dim() != 3:
raise ValueError(f"standardlineart_ensure_hwc3_torch: Expected 2D or 3D tensor, got {x.dim()}D")
-
+
H, W, C = x.shape
-
+
if C == 3:
return x
elif C == 1:
@@ -136,90 +138,87 @@ def _ensure_hwc3_torch(self, x: torch.Tensor) -> torch.Tensor:
return torch.clamp(y, 0, 255)
else:
raise ValueError(f"standardlineart_ensure_hwc3_torch: Unsupported channel count: {C}")
-
+
def _pad64(self, x: int) -> int:
"""Pad to nearest multiple of 64"""
return int(torch.ceil(torch.tensor(float(x) / 64.0)) * 64 - x)
-
+
def _resize_image_with_pad_torch(self, input_image: torch.Tensor, resolution: int) -> tuple:
"""Resize image with padding to target resolution - PyTorch GPU accelerated"""
img = self._ensure_hwc3_torch(input_image)
H_raw, W_raw, _ = img.shape
-
+
if resolution == 0:
return img, lambda x: x
-
+
k = float(resolution) / float(min(H_raw, W_raw))
H_target = int(torch.round(torch.tensor(float(H_raw) * k)))
W_target = int(torch.round(torch.tensor(float(W_raw) * k)))
-
+
# Convert to BCHW for interpolation
img_bchw = img.permute(2, 0, 1).unsqueeze(0) # HWC -> BCHW
-
+
# Use PyTorch's interpolate for GPU-accelerated resize
- mode = 'bicubic' if k > 1 else 'area'
+ mode = "bicubic" if k > 1 else "area"
img_resized_bchw = F.interpolate(
- img_bchw,
- size=(H_target, W_target),
- mode=mode,
- align_corners=False if mode == 'bicubic' else None
+ img_bchw, size=(H_target, W_target), mode=mode, align_corners=False if mode == "bicubic" else None
)
-
+
# Convert back to HWC
img_resized = img_resized_bchw.squeeze(0).permute(1, 2, 0)
-
+
# Apply padding
H_pad, W_pad = self._pad64(H_target), self._pad64(W_target)
- img_padded = F.pad(img_resized.permute(2, 0, 1), (0, W_pad, 0, H_pad), mode='replicate').permute(1, 2, 0)
+ img_padded = F.pad(img_resized.permute(2, 0, 1), (0, W_pad, 0, H_pad), mode="replicate").permute(1, 2, 0)
def remove_pad(x):
return x[:H_target, :W_target, ...]
return img_padded, remove_pad
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Apply standard line art detection to the input image
"""
start_time = time.time()
-
+
if isinstance(image, Image.Image):
input_image_cpu = np.array(image, dtype=np.uint8)
else:
input_image_cpu = image.astype(np.uint8)
-
+
input_image = torch.from_numpy(input_image_cpu).float().to(self.device)
-
- detect_resolution = self.params.get('detect_resolution', 512)
- gaussian_sigma = self.params.get('gaussian_sigma', 6.0)
- intensity_threshold = self.params.get('intensity_threshold', 8)
-
+
+ detect_resolution = self.params.get("detect_resolution", 512)
+ gaussian_sigma = self.params.get("gaussian_sigma", 6.0)
+ intensity_threshold = self.params.get("intensity_threshold", 8)
+
input_image, remove_pad = self._resize_image_with_pad_torch(input_image, detect_resolution)
-
+
x = input_image
-
+
g = self._gaussian_blur_torch(x, gaussian_sigma)
-
+
intensity = torch.min(g - x, dim=2)[0]
intensity = torch.clamp(intensity, 0, 255)
-
+
threshold_mask = intensity > intensity_threshold
if torch.any(threshold_mask):
median_val = torch.median(intensity[threshold_mask])
normalization_factor = max(16, float(median_val))
else:
normalization_factor = 16
-
+
intensity = intensity / normalization_factor
intensity = intensity * 127
-
+
detected_map = torch.clamp(intensity, 0, 255).byte()
detected_map = detected_map.unsqueeze(-1)
detected_map = self._ensure_hwc3_torch(detected_map.float())
-
+
detected_map = remove_pad(detected_map)
-
+
detected_map_cpu = detected_map.byte().cpu().numpy()
lineart_image = Image.fromarray(detected_map_cpu)
-
- return lineart_image
\ No newline at end of file
+
+ return lineart_image
diff --git a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
index 8151722ac..94930e1ad 100644
--- a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
+++ b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
@@ -1,26 +1,32 @@
-import torch
-import torch.nn.functional as F
-import numpy as np
-from PIL import Image
import logging
from pathlib import Path
from typing import Any
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
from .base import PipelineAwareProcessor
+
# Try to import TensorRT dependencies
try:
+ from collections import OrderedDict
+
import tensorrt as trt
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes
- from collections import OrderedDict
+
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
# Try to import torchvision for RAFT model
try:
- from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
+ from torchvision.models.optical_flow import Raft_Small_Weights, raft_small
from torchvision.utils import flow_to_image
+
TORCHVISION_AVAILABLE = True
except ImportError:
TORCHVISION_AVAILABLE = False
@@ -48,7 +54,7 @@
class TensorRTEngine:
"""TensorRT engine wrapper for RAFT optical flow inference"""
-
+
def __init__(self, engine_path):
self.engine_path = engine_path
self.engine = None
@@ -69,7 +75,7 @@ def activate(self):
def allocate_buffers(self, device="cuda", input_shape=None):
"""
Allocate input/output buffers
-
+
Args:
device: Device to allocate tensors on
input_shape: Shape for input tensors (B, C, H, W). Required for engines with dynamic shapes.
@@ -88,7 +94,6 @@ def allocate_buffers(self, device="cuda", input_shape=None):
else:
raise
-
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
# For dynamic shapes, use provided input_shape
if input_shape is not None and any(dim == -1 for dim in shape):
@@ -99,7 +104,7 @@ def allocate_buffers(self, device="cuda", input_shape=None):
else:
# For output tensors, get shape after input shapes are set
shape = self.context.get_tensor_shape(name)
-
+
# Verify shape has no dynamic dimensions
if any(dim == -1 for dim in shape):
raise RuntimeError(
@@ -114,7 +119,7 @@ def infer(self, feed_dict, stream=None):
"""Run inference with optional stream parameter"""
if stream is None:
stream = self._cuda_stream
-
+
# Check if we need to update tensor shapes for dynamic dimensions
need_realloc = False
for name, buf in feed_dict.items():
@@ -122,7 +127,7 @@ def infer(self, feed_dict, stream=None):
if self.tensors[name].shape != buf.shape:
need_realloc = True
break
-
+
# Reallocate buffers if input shape changed
if need_realloc:
# Update input shapes
@@ -134,18 +139,18 @@ def infer(self, feed_dict, stream=None):
except:
# Tensor name might not be in engine, skip
pass
-
+
# Reallocate all tensors with new shapes
for idx in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(idx)
shape = self.context.get_tensor_shape(name)
dtype = trt.nptype(self.engine.get_tensor_dtype(name))
-
- tensor = torch.empty(
- tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
- ).to(device=self.tensors[name].device)
+
+ tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(
+ device=self.tensors[name].device
+ )
self.tensors[name] = tensor
-
+
# Copy input data to tensors
for name, buf in feed_dict.items():
self.tensors[name].copy_(buf)
@@ -153,26 +158,26 @@ def infer(self, feed_dict, stream=None):
# Set tensor addresses
for name, tensor in self.tensors.items():
self.context.set_tensor_address(name, tensor.data_ptr())
-
+
# Execute inference
success = self.context.execute_async_v3(stream)
if not success:
raise ValueError("TensorRT inference failed.")
-
+
return self.tensors
class TemporalNetTensorRTPreprocessor(PipelineAwareProcessor):
"""
TensorRT-accelerated TemporalNet preprocessor for temporal consistency using optical flow visualization.
-
+
This preprocessor uses TensorRT to accelerate RAFT optical flow computation and creates a 6-channel
control tensor by concatenating the previous input frame (RGB) with a colorized optical flow
visualization (RGB) computed between the previous and current input frames.
-
+
Output: [prev_input_RGB, flow_RGB(prev_input β current_input)]
"""
-
+
@classmethod
def get_preprocessor_metadata(cls):
return {
@@ -182,53 +187,59 @@ def get_preprocessor_metadata(cls):
"engine_path": {
"type": "str",
"default": None,
- "description": "Path to pre-built TensorRT engine file. Use compile_raft_tensorrt.py to build one."
+ "description": "Path to pre-built TensorRT engine file. Use compile_raft_tensorrt.py to build one.",
},
"flow_strength": {
"type": "float",
"default": 1.0,
"range": [0.0, 2.0],
"step": 0.1,
- "description": "Strength multiplier for optical flow visualization (1.0 = normal, higher = more pronounced flow)"
+ "description": "Strength multiplier for optical flow visualization (1.0 = normal, higher = more pronounced flow)",
},
"height": {
"type": "int",
"default": 512,
"range": [256, 1024],
"step": 64,
- "description": "Height for optical flow computation (must be within engine's height range)"
+ "description": "Height for optical flow computation (must be within engine's height range)",
},
"width": {
"type": "int",
"default": 512,
"range": [256, 1024],
"step": 64,
- "description": "Width for optical flow computation (must be within engine's width range)"
+ "description": "Width for optical flow computation (must be within engine's width range)",
},
"output_format": {
- "type": "str",
+ "type": "str",
"default": "concat",
"options": ["concat", "warped_only"],
- "description": "Output format: 'concat' for 6-channel (prev_input+flow_RGB), 'warped_only' for 3-channel flow RGB only"
- }
+ "description": "Output format: 'concat' for 6-channel (prev_input+flow_RGB), 'warped_only' for 3-channel flow RGB only",
+ },
},
- "use_cases": ["High-performance video generation", "Real-time temporal consistency", "GPU-optimized motion control"]
+ "use_cases": [
+ "High-performance video generation",
+ "Real-time temporal consistency",
+ "GPU-optimized motion control",
+ ],
}
-
- def __init__(self,
- pipeline_ref: Any,
- engine_path: str = None,
- height: int = 512,
- width: int = 512,
- flow_strength: float = 1.0,
- output_format: str = "concat",
- **kwargs):
+
+ def __init__(
+ self,
+ pipeline_ref: Any,
+ engine_path: str = None,
+ height: int = 512,
+ width: int = 512,
+ flow_strength: float = 1.0,
+ output_format: str = "concat",
+ **kwargs,
+ ):
"""
Initialize TensorRT TemporalNet preprocessor
-
+
Args:
pipeline_ref: Reference to the StreamDiffusion pipeline instance (required)
- engine_path: Path to pre-built TensorRT engine file (required).
+ engine_path: Path to pre-built TensorRT engine file (required).
Build one using: python -m streamdiffusion.tools.compile_raft_tensorrt
height: Height for optical flow computation (must be within engine's height range)
width: Width for optical flow computation (must be within engine's width range)
@@ -236,13 +247,12 @@ def __init__(self,
output_format: "concat" for 6-channel [prev_input+flow_RGB], "warped_only" for 3-channel flow RGB only
**kwargs: Additional parameters passed to BasePreprocessor
"""
-
+
if not TORCHVISION_AVAILABLE:
raise ImportError(
- "torchvision is required for TemporalNet preprocessing. "
- "Install it with: pip install torchvision"
+ "torchvision is required for TemporalNet preprocessing. Install it with: pip install torchvision"
)
-
+
if not TENSORRT_AVAILABLE:
raise ImportError(
"TensorRT and polygraphy are required for TensorRT acceleration. "
@@ -255,7 +265,7 @@ def __init__(self,
" python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024 --output_dir ./models/temporal_net\n"
"Then pass the engine path to this preprocessor."
)
-
+
super().__init__(
pipeline_ref=pipeline_ref,
height=height,
@@ -263,17 +273,17 @@ def __init__(self,
engine_path=engine_path,
flow_strength=flow_strength,
output_format=output_format,
- **kwargs
+ **kwargs,
)
-
+
self.flow_strength = max(0.0, min(2.0, flow_strength))
self.height = height
self.width = width
self._first_frame = True
-
+
# Store previous input frame for flow computation
self.prev_input = None
-
+
# Engine path
self.engine_path = Path(engine_path)
if not self.engine_path.exists():
@@ -282,17 +292,17 @@ def __init__(self,
f"Build one using:\n"
f" python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution {height}x{width} --max_resolution {height}x{width} --output_dir {self.engine_path.parent}"
)
-
+
# Model state
self.trt_engine = None
-
+
# Cached tensors for performance
self._grid_cache = {}
self._tensor_cache = {}
-
+
# Load TensorRT engine
self._load_tensorrt_engine()
-
+
def _load_tensorrt_engine(self):
"""Load pre-built TensorRT engine"""
logger.info(f"_load_tensorrt_engine: Loading TensorRT engine: {self.engine_path}")
@@ -300,11 +310,11 @@ def _load_tensorrt_engine(self):
self.trt_engine = TensorRTEngine(str(self.engine_path))
self.trt_engine.load()
self.trt_engine.activate()
-
+
# For dynamic shapes, provide the input shape based on image dimensions
input_shape = (1, 3, self.height, self.width)
self.trt_engine.allocate_buffers(device=self.device, input_shape=input_shape)
-
+
logger.info(f"_load_tensorrt_engine: TensorRT engine loaded successfully from {self.engine_path}")
logger.info(f"_load_tensorrt_engine: Using resolution: {self.height}x{self.width}")
except Exception as e:
@@ -315,16 +325,14 @@ def _load_tensorrt_engine(self):
f"Make sure the engine was built with a resolution range that includes {self.height}x{self.width}.\n"
f"For example: python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024"
)
-
-
def _process_core(self, image: Image.Image) -> Image.Image:
"""
Process using TensorRT-accelerated optical flow warping
-
+
Args:
image: Current input image
-
+
Returns:
Warped previous frame for temporal guidance, or fallback for first frame
"""
@@ -332,50 +340,50 @@ def _process_core(self, image: Image.Image) -> Image.Image:
tensor = self.pil_to_tensor(image)
result_tensor = self._process_tensor_core(tensor)
return self.tensor_to_pil(result_tensor)
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Process using TensorRT-accelerated optical flow computation (GPU-optimized path)
-
+
Args:
tensor: Current input tensor
-
+
Returns:
Concatenated tensor: [prev_input_RGB, flow_RGB] for temporal guidance
"""
-
+
# Normalize input tensor
input_tensor = tensor
if input_tensor.max() > 1.0:
input_tensor = input_tensor / 255.0
-
+
# Ensure consistent format
if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
input_tensor = input_tensor[0]
-
+
# Check if we have a previous input frame
if self.prev_input is not None and not self._first_frame:
try:
# Compute optical flow between prev_input -> current_input
flow_rgb_tensor = self._compute_flow_to_rgb_tensor(self.prev_input, input_tensor)
-
+
# Check output format
- output_format = self.params.get('output_format', 'concat')
+ output_format = self.params.get("output_format", "concat")
if output_format == "concat":
# Concatenate prev_input + flow_RGB for TemporalNet2 (6 channels)
result_tensor = self._concatenate_frames_tensor(self.prev_input, flow_rgb_tensor)
else:
# Return only flow RGB (3 channels)
result_tensor = flow_rgb_tensor
-
+
# Ensure correct output format
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
-
+
result = result_tensor.to(device=self.device, dtype=self.dtype)
except Exception as e:
logger.error(f"_process_tensor_core: TensorRT optical flow failed: {e}")
- output_format = self.params.get('output_format', 'concat')
+ output_format = self.params.get("output_format", "concat")
if output_format == "concat":
# Create 6-channel fallback by concatenating prev_input with itself
result_tensor = self._concatenate_frames_tensor(self.prev_input, self.prev_input)
@@ -393,19 +401,19 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
self._first_frame = False
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
-
+
# Handle 6-channel output for first frame
- output_format = self.params.get('output_format', 'concat')
+ output_format = self.params.get("output_format", "concat")
if output_format == "concat":
# For first frame, concatenate current frame with zeros (no flow)
if tensor.dim() == 4 and tensor.shape[0] == 1:
current_tensor = tensor[0]
else:
current_tensor = tensor
-
+
# Create zero tensor for flow (same shape as current_tensor)
zero_flow = torch.zeros_like(current_tensor, device=self.device, dtype=current_tensor.dtype)
-
+
result_tensor = self._concatenate_frames_tensor(current_tensor, zero_flow)
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
@@ -420,164 +428,148 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
result = result_tensor.to(device=self.device, dtype=self.dtype)
-
+
# Store current input as previous for next frame
self.prev_input = input_tensor.clone()
-
+
return result
-
- def _compute_flow_to_rgb_tensor(self, prev_input_tensor: torch.Tensor, current_input_tensor: torch.Tensor) -> torch.Tensor:
+
+ def _compute_flow_to_rgb_tensor(
+ self, prev_input_tensor: torch.Tensor, current_input_tensor: torch.Tensor
+ ) -> torch.Tensor:
"""
Compute optical flow between prev_input -> current_input and convert to RGB visualization
-
+
Args:
prev_input_tensor: Previous input frame tensor (CHW format, [0,1]) on GPU
current_input_tensor: Current input frame tensor (CHW format, [0,1]) on GPU
-
+
Returns:
Flow visualization as RGB tensor (CHW format, [0,1]) on GPU
"""
target_width, target_height = self.get_target_dimensions()
-
+
# Convert to float32 for TensorRT processing
prev_tensor = prev_input_tensor.to(device=self.device, dtype=torch.float32)
current_tensor = current_input_tensor.to(device=self.device, dtype=torch.float32)
-
+
# Resize for flow computation if needed (keep on GPU)
if current_tensor.shape[-1] != self.width or current_tensor.shape[-2] != self.height:
prev_resized = F.interpolate(
- prev_tensor.unsqueeze(0),
- size=(self.height, self.width),
- mode='bilinear',
- align_corners=False
+ prev_tensor.unsqueeze(0), size=(self.height, self.width), mode="bilinear", align_corners=False
).squeeze(0)
current_resized = F.interpolate(
- current_tensor.unsqueeze(0),
- size=(self.height, self.width),
- mode='bilinear',
- align_corners=False
+ current_tensor.unsqueeze(0), size=(self.height, self.width), mode="bilinear", align_corners=False
).squeeze(0)
else:
prev_resized = prev_tensor
current_resized = current_tensor
-
+
# Compute optical flow using TensorRT: prev_input -> current_input
flow = self._compute_optical_flow_tensorrt(prev_resized, current_resized)
-
+
# Apply flow strength scaling (GPU operation)
- flow_strength = self.params.get('flow_strength', 1.0)
+ flow_strength = self.params.get("flow_strength", 1.0)
if flow_strength != 1.0:
flow = flow * flow_strength
-
+
# Convert flow to RGB visualization using torchvision's flow_to_image
# flow_to_image expects (2, H, W) and returns (3, H, W) in range [0, 255]
flow_rgb = flow_to_image(flow) # Returns uint8 tensor [0, 255]
-
+
# Convert to float [0, 1] range
flow_rgb = flow_rgb.float() / 255.0
-
+
# Resize back to target resolution if needed (keep on GPU)
if flow_rgb.shape[-1] != target_width or flow_rgb.shape[-2] != target_height:
flow_rgb = F.interpolate(
- flow_rgb.unsqueeze(0),
- size=(target_height, target_width),
- mode='bilinear',
- align_corners=False
+ flow_rgb.unsqueeze(0), size=(target_height, target_width), mode="bilinear", align_corners=False
).squeeze(0)
-
+
# Convert to processor's dtype only at the very end
result = flow_rgb.to(dtype=self.dtype)
-
+
return result
-
+
def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor:
"""
Compute optical flow between two frames using TensorRT-accelerated RAFT
-
+
Args:
frame1: First frame tensor (CHW format, [0,1])
frame2: Second frame tensor (CHW format, [0,1])
-
+
Returns:
Optical flow tensor (2HW format)
"""
-
+
if self.trt_engine is None:
raise RuntimeError("_compute_optical_flow_tensorrt: TensorRT engine not loaded")
-
+
# Prepare inputs for TensorRT
frame1_batch = frame1.unsqueeze(0)
frame2_batch = frame2.unsqueeze(0)
-
+
# Apply RAFT preprocessing if available
weights = Raft_Small_Weights.DEFAULT
- if hasattr(weights, 'transforms') and weights.transforms is not None:
+ if hasattr(weights, "transforms") and weights.transforms is not None:
transforms = weights.transforms()
frame1_batch, frame2_batch = transforms(frame1_batch, frame2_batch)
-
+
# Run TensorRT inference
- feed_dict = {
- 'frame1': frame1_batch,
- 'frame2': frame2_batch
- }
-
+ feed_dict = {"frame1": frame1_batch, "frame2": frame2_batch}
+
cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.trt_engine.infer(feed_dict, cuda_stream)
- flow = result['flow'][0] # Remove batch dimension
-
+ flow = result["flow"][0] # Remove batch dimension
+
return flow
-
-
def _warp_frame_tensor(self, frame: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
"""
Warp frame using optical flow with cached coordinate grids
-
+
Args:
frame: Frame to warp (CHW format)
flow: Optical flow (2HW format)
-
+
Returns:
Warped frame tensor
"""
H, W = frame.shape[-2:]
-
+
# Use cached grid if available
grid_key = (H, W)
if grid_key not in self._grid_cache:
grid_y, grid_x = torch.meshgrid(
torch.arange(H, device=self.device, dtype=torch.float32),
torch.arange(W, device=self.device, dtype=torch.float32),
- indexing='ij'
+ indexing="ij",
)
self._grid_cache[grid_key] = (grid_x, grid_y)
else:
grid_x, grid_y = self._grid_cache[grid_key]
-
+
# Apply flow to coordinates
new_x = grid_x + flow[0]
new_y = grid_y + flow[1]
-
+
# Normalize coordinates to [-1, 1] for grid_sample
new_x = 2.0 * new_x / (W - 1) - 1.0
new_y = 2.0 * new_y / (H - 1) - 1.0
-
+
# Create sampling grid (HW2 format for grid_sample)
grid = torch.stack([new_x, new_y], dim=-1).unsqueeze(0)
-
+
# Warp frame
warped_batch = F.grid_sample(
- frame.unsqueeze(0),
- grid,
- mode='bilinear',
- padding_mode='border',
- align_corners=True
+ frame.unsqueeze(0), grid, mode="bilinear", padding_mode="border", align_corners=True
)
-
+
result = warped_batch.squeeze(0)
-
+
return result
-
+
def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Image) -> Image.Image:
"""Concatenate current frame and warped previous frame for TemporalNet2 (6-channel input)"""
# Convert to tensors and use tensor concatenation for consistency
@@ -585,43 +577,43 @@ def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Im
warped_tensor = self.pil_to_tensor(warped_image).squeeze(0)
result_tensor = self._concatenate_frames_tensor(current_tensor, warped_tensor)
return self.tensor_to_pil(result_tensor)
-
+
def _concatenate_frames_tensor(self, current_tensor: torch.Tensor, warped_tensor: torch.Tensor) -> torch.Tensor:
"""
Concatenate current frame and warped previous frame tensors for TemporalNet2 (6-channel input)
-
+
Args:
current_tensor: Current input frame tensor (CHW format)
warped_tensor: Warped previous frame tensor (CHW format)
-
+
Returns:
Concatenated tensor (6CHW format)
"""
# Ensure same size
if current_tensor.shape != warped_tensor.shape:
target_width, target_height = self.get_target_dimensions()
-
+
if current_tensor.shape[-2:] != (target_height, target_width):
current_tensor = F.interpolate(
current_tensor.unsqueeze(0),
size=(target_height, target_width),
- mode='bilinear',
- align_corners=False
+ mode="bilinear",
+ align_corners=False,
).squeeze(0)
-
+
if warped_tensor.shape[-2:] != (target_height, target_width):
warped_tensor = F.interpolate(
warped_tensor.unsqueeze(0),
size=(target_height, target_width),
- mode='bilinear',
- align_corners=False
+ mode="bilinear",
+ align_corners=False,
).squeeze(0)
-
+
# Concatenate along channel dimension: [current_R, current_G, current_B, warped_R, warped_G, warped_B]
concatenated = torch.cat([current_tensor, warped_tensor], dim=0)
-
+
return concatenated
-
+
def reset(self):
"""
Reset the preprocessor state (useful for new sequences)
@@ -631,4 +623,4 @@ def reset(self):
# Clear caches to free memory
self._grid_cache.clear()
self._tensor_cache.clear()
- torch.cuda.empty_cache()
\ No newline at end of file
+ torch.cuda.empty_cache()
diff --git a/src/streamdiffusion/preprocessing/processors/upscale.py b/src/streamdiffusion/preprocessing/processors/upscale.py
index 82659b7b8..38a69d499 100644
--- a/src/streamdiffusion/preprocessing/processors/upscale.py
+++ b/src/streamdiffusion/preprocessing/processors/upscale.py
@@ -1,7 +1,9 @@
+from typing import Literal
+
import torch
import torch.nn.functional as F
from PIL import Image
-from typing import Literal
+
from .base import BasePreprocessor
@@ -10,8 +12,8 @@ class UpscalePreprocessor(BasePreprocessor):
Image upscaling preprocessor with multiple interpolation algorithms.
Supports bilinear, lanczos, bicubic, and nearest neighbor upscaling.
"""
-
- @classmethod
+
+ @classmethod
def get_preprocessor_metadata(cls):
return {
"display_name": "Upscale",
@@ -21,71 +23,71 @@ def get_preprocessor_metadata(cls):
"type": "float",
"default": 2.0,
"range": [1.0, 4.0],
- "description": "Upscaling factor"
+ "description": "Upscaling factor",
},
"algorithm": {
"type": "str",
"default": "bilinear",
"options": ["bilinear", "lanczos", "bicubic", "nearest"],
- "description": "Interpolation algorithm: bilinear (fast), lanczos (high quality), bicubic (balanced), nearest (pixel art)"
- }
+ "description": "Interpolation algorithm: bilinear (fast), lanczos (high quality), bicubic (balanced), nearest (pixel art)",
+ },
},
- "use_cases": ["Real-time upscaling", "Image enhancement", "Resolution conversion"]
+ "use_cases": ["Real-time upscaling", "Image enhancement", "Resolution conversion"],
}
-
- def __init__(self, scale_factor: float = 2.0, algorithm: Literal["bilinear", "lanczos", "bicubic", "nearest"] = "bilinear", **kwargs):
+
+ def __init__(
+ self,
+ scale_factor: float = 2.0,
+ algorithm: Literal["bilinear", "lanczos", "bicubic", "nearest"] = "bilinear",
+ **kwargs,
+ ):
super().__init__(scale_factor=scale_factor, algorithm=algorithm, **kwargs)
self.scale_factor = scale_factor
self.algorithm = algorithm
-
+
# Map algorithm names to PIL and PyTorch modes
self.pil_resample_map = {
"bilinear": Image.BILINEAR,
"lanczos": Image.LANCZOS,
"bicubic": Image.BICUBIC,
- "nearest": Image.NEAREST
+ "nearest": Image.NEAREST,
}
-
+
self.torch_mode_map = {
"bilinear": "bilinear",
"lanczos": "bicubic", # PyTorch doesn't have lanczos, use bicubic as closest
"bicubic": "bicubic",
- "nearest": "nearest"
+ "nearest": "nearest",
}
-
+
def _process_core(self, image: Image.Image) -> Image.Image:
"""PIL-based upscaling"""
target_width, target_height = self.get_target_dimensions()
resample_method = self.pil_resample_map.get(self.algorithm, Image.BILINEAR)
return image.resize((target_width, target_height), resample_method)
-
+
def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
"""Tensor-based upscaling"""
target_width, target_height = self.get_target_dimensions()
-
+
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)
-
+
mode = self.torch_mode_map.get(self.algorithm, "bilinear")
-
+
if mode in ["bilinear", "bicubic"]:
- return F.interpolate(tensor, size=(target_height, target_width),
- mode=mode, align_corners=False)
+ return F.interpolate(tensor, size=(target_height, target_width), mode=mode, align_corners=False)
else: # nearest
- return F.interpolate(tensor, size=(target_height, target_width),
- mode=mode)
-
+ return F.interpolate(tensor, size=(target_height, target_width), mode=mode)
+
def get_target_dimensions(self):
"""Handle scale factor for dimensions"""
- width = self.params.get('image_width')
- height = self.params.get('image_height')
-
+ width = self.params.get("image_width")
+ height = self.params.get("image_height")
+
if width is not None and height is not None:
return (int(width * self.scale_factor), int(height * self.scale_factor))
-
- base_resolution = self.params.get('image_resolution', 512)
+
+ base_resolution = self.params.get("image_resolution", 512)
target_resolution = int(base_resolution * self.scale_factor)
return (target_resolution, target_resolution)
-
-
-
diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py
index 8f81b6f89..a344d3351 100644
--- a/src/streamdiffusion/stream_parameter_updater.py
+++ b/src/streamdiffusion/stream_parameter_updater.py
@@ -1,15 +1,18 @@
-from typing import List, Optional, Dict, Tuple, Literal, Any, Callable
+import logging
import threading
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
import torch
import torch.nn.functional as F
-import gc
-import logging
+
logger = logging.getLogger(__name__)
from .preprocessing.orchestrator_user import OrchestratorUser
+
class CacheStats:
"""Helper class to track cache statistics"""
+
def __init__(self):
self.hits = 0
self.misses = 0
@@ -22,7 +25,13 @@ def record_miss(self):
class StreamParameterUpdater(OrchestratorUser):
- def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True):
+ def __init__(
+ self,
+ stream_diffusion,
+ wrapper=None,
+ normalize_prompt_weights: bool = True,
+ normalize_seed_weights: bool = True,
+ ):
self.stream = stream_diffusion
self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure
self.normalize_prompt_weights = normalize_prompt_weights
@@ -39,8 +48,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo
self._seed_cache: Dict[int, Dict] = {}
self._current_seed_list: List[Tuple[int, float]] = []
self._seed_cache_stats = CacheStats()
-
-
+
# Attach shared orchestrator once (lazy-creates on stream if absent)
self.attach_orchestrator(self.stream)
@@ -50,6 +58,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo
self._current_style_images: Dict[str, Any] = {}
# Use the shared orchestrator attached via OrchestratorUser
self._embedding_orchestrator = self._preprocessing_orchestrator
+
def get_cache_info(self) -> Dict:
"""Get cache statistics for monitoring performance."""
total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses
@@ -68,7 +77,7 @@ def get_cache_info(self) -> Dict:
"seed_cache_hits": self._seed_cache_stats.hits,
"seed_cache_misses": self._seed_cache_stats.misses,
"seed_hit_rate": f"{seed_hit_rate:.2%}",
- "current_seeds": len(self._current_seed_list)
+ "current_seeds": len(self._current_seed_list),
}
def clear_caches(self) -> None:
@@ -81,7 +90,7 @@ def clear_caches(self) -> None:
self._seed_cache.clear()
self._current_seed_list.clear()
self._seed_cache_stats = CacheStats()
-
+
# Clear embedding caches
self._embedding_cache.clear()
self._current_style_images.clear()
@@ -93,13 +102,13 @@ def get_normalize_prompt_weights(self) -> bool:
def get_normalize_seed_weights(self) -> bool:
"""Get the current seed weight normalization setting."""
return self.normalize_seed_weights
-
+
# Deprecated enhancer registration removed; embedding composition is handled via stream.embedding_hooks
def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: str) -> None:
"""
Register an embedding preprocessor for parallel processing.
-
+
Args:
preprocessor: IPAdapterEmbeddingPreprocessor instance
style_image_key: Unique key for the style image this preprocessor handles
@@ -108,28 +117,27 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st
# Ensure orchestrator is present
self.attach_orchestrator(self.stream)
self._embedding_orchestrator = self._preprocessing_orchestrator
-
+
self._embedding_preprocessors.append((preprocessor, style_image_key))
-
+
def unregister_embedding_preprocessor(self, style_image_key: str) -> None:
"""Unregister an embedding preprocessor by style image key."""
original_count = len(self._embedding_preprocessors)
self._embedding_preprocessors = [
- (preprocessor, key) for preprocessor, key in self._embedding_preprocessors
- if key != style_image_key
+ (preprocessor, key) for preprocessor, key in self._embedding_preprocessors if key != style_image_key
]
removed_count = original_count - len(self._embedding_preprocessors)
-
+
# Clear cached embeddings for this key
if style_image_key in self._embedding_cache:
del self._embedding_cache[style_image_key]
if style_image_key in self._current_style_images:
del self._current_style_images[style_image_key]
-
+
def update_style_image(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None:
"""
Update a style image and trigger embedding preprocessing.
-
+
Args:
style_image_key: Unique key for the style image
style_image: The style image (PIL Image, path, etc.)
@@ -138,14 +146,16 @@ def update_style_image(self, style_image_key: str, style_image: Any, is_stream:
"""
# Store the style image
self._current_style_images[style_image_key] = style_image
-
+
# Trigger preprocessing for this style image
self._preprocess_style_image_parallel(style_image_key, style_image, is_stream)
-
- def _preprocess_style_image_parallel(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None:
+
+ def _preprocess_style_image_parallel(
+ self, style_image_key: str, style_image: Any, is_stream: bool = False
+ ) -> None:
"""
Preprocessing for a specific style image with mode selection
-
+
Args:
style_image_key: Unique key for the style image
style_image: The style image to process
@@ -153,57 +163,47 @@ def _preprocess_style_image_parallel(self, style_image_key: str, style_image: An
"""
if not self._embedding_preprocessors or self._embedding_orchestrator is None:
return
-
+
# Find preprocessors for this key
relevant_preprocessors = [
- preprocessor for preprocessor, key in self._embedding_preprocessors
- if key == style_image_key
+ preprocessor for preprocessor, key in self._embedding_preprocessors if key == style_image_key
]
-
+
if not relevant_preprocessors:
return
-
+
# Choose processing mode based on is_stream parameter
try:
if is_stream:
# Pipelined processing - optimized for throughput with 1-frame lag
embedding_results = self._embedding_orchestrator.process_pipelined(
- style_image,
- relevant_preprocessors,
- None,
- self.stream.width,
- self.stream.height,
- "ipadapter"
+ style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, "ipadapter"
)
else:
# Synchronous processing - immediate results for discrete updates
embedding_results = self._embedding_orchestrator.process_sync(
- style_image,
- relevant_preprocessors,
- None,
- self.stream.width,
- self.stream.height,
- None,
- "ipadapter"
+ style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, None, "ipadapter"
)
-
+
# Cache results for this style image key
if embedding_results and embedding_results[0] is not None:
self._embedding_cache[style_image_key] = embedding_results[0]
else:
# This is an error condition - we should always have results
- raise RuntimeError(f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'")
-
- except Exception as e:
+ raise RuntimeError(
+ f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'"
+ )
+
+ except Exception:
import traceback
+
traceback.print_exc()
-
+
def get_cached_embeddings(self, style_image_key: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
"""Get cached embeddings for a style image key"""
cached_result = self._embedding_cache.get(style_image_key, None)
return cached_result
-
def _normalize_weights(self, weights: List[float], normalize: bool) -> torch.Tensor:
"""Generic weight normalization helper"""
weights_tensor = torch.tensor(weights, device=self.stream.device, dtype=self.stream.dtype)
@@ -218,7 +218,7 @@ def _validate_index(self, index: int, item_list: List, operation_name: str) -> b
return False
if index < 0 or index >= len(item_list):
- logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list)-1})")
+ logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list) - 1})")
return False
return True
@@ -281,28 +281,27 @@ def update_stream_params(
f"provided t_index_list (max index: {max_t_index}). Adjusting to {max_t_index + 1}."
)
num_inference_steps = max_t_index + 1
-
+
old_num_steps = len(self.stream.timesteps)
self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device)
self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device)
-
+
# If t_index_list wasn't explicitly provided, rescale existing t_list proportionally
if t_index_list is None and old_num_steps > 0:
# Rescale each index proportionally to the new number of steps
# e.g., if t_list = [0, 16, 32, 45] with 50 steps -> [0, 3, 6, 8] with 9 steps
scale_factor = (num_inference_steps - 1) / (old_num_steps - 1) if old_num_steps > 1 else 1.0
- t_index_list = [
- min(round(t * scale_factor), num_inference_steps - 1)
- for t in self.stream.t_list
- ]
-
+ t_index_list = [min(round(t * scale_factor), num_inference_steps - 1) for t in self.stream.t_list]
+
# Now update timestep-dependent parameters with the correct t_index_list
if t_index_list is not None:
self._recalculate_timestep_dependent_params(t_index_list)
if guidance_scale is not None:
if self.stream.cfg_type == "none" and guidance_scale > 1.0:
- logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect")
+ logger.warning(
+ "update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect"
+ )
self.stream.guidance_scale = guidance_scale
if delta is not None:
@@ -310,7 +309,7 @@ def update_stream_params(
if seed is not None:
self._update_seed(seed)
-
+
if normalize_prompt_weights is not None:
self.normalize_prompt_weights = normalize_prompt_weights
logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}")
@@ -324,44 +323,42 @@ def update_stream_params(
self._update_blended_prompts(
prompt_list=prompt_list,
negative_prompt=negative_prompt or self._current_negative_prompt,
- prompt_interpolation_method=prompt_interpolation_method
+ prompt_interpolation_method=prompt_interpolation_method,
)
# Handle seed blending if seed_list is provided
if seed_list is not None:
- self._update_blended_seeds(
- seed_list=seed_list,
- interpolation_method=seed_interpolation_method
- )
-
+ self._update_blended_seeds(seed_list=seed_list, interpolation_method=seed_interpolation_method)
# Handle ControlNet configuration updates
if controlnet_config is not None:
- #TODO: happy path for control images
+ # TODO: happy path for control images
self._update_controlnet_config(controlnet_config)
-
+
# Handle IPAdapter configuration updates
if ipadapter_config is not None:
- logger.info(f"update_stream_params: Updating IPAdapter configuration")
+ logger.info("update_stream_params: Updating IPAdapter configuration")
self._update_ipadapter_config(ipadapter_config)
-
+
# Handle Hook configuration updates
if image_preprocessing_config is not None:
- logger.info(f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors")
+ logger.info(
+ f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors"
+ )
logger.info(f"update_stream_params: image_preprocessing_config = {image_preprocessing_config}")
- self._update_hook_config('image_preprocessing', image_preprocessing_config)
-
+ self._update_hook_config("image_preprocessing", image_preprocessing_config)
+
if image_postprocessing_config is not None:
- logger.info(f"update_stream_params: Updating image postprocessing configuration")
- self._update_hook_config('image_postprocessing', image_postprocessing_config)
-
+ logger.info("update_stream_params: Updating image postprocessing configuration")
+ self._update_hook_config("image_postprocessing", image_postprocessing_config)
+
if latent_preprocessing_config is not None:
- logger.info(f"update_stream_params: Updating latent preprocessing configuration")
- self._update_hook_config('latent_preprocessing', latent_preprocessing_config)
-
+ logger.info("update_stream_params: Updating latent preprocessing configuration")
+ self._update_hook_config("latent_preprocessing", latent_preprocessing_config)
+
if latent_postprocessing_config is not None:
- logger.info(f"update_stream_params: Updating latent postprocessing configuration")
- self._update_hook_config('latent_postprocessing', latent_postprocessing_config)
+ logger.info("update_stream_params: Updating latent postprocessing configuration")
+ self._update_hook_config("latent_postprocessing", latent_postprocessing_config)
if self.stream.kvo_cache:
if cache_interval is not None:
@@ -375,9 +372,7 @@ def update_stream_params(
# runtime β resizing one-at-a-time races with TRT inference (causes "Dimensions
# with name C must be equal" errors). cache_maxframes is a logical write window.
actual_cache_size = (
- self.stream.kvo_cache[0].shape[1]
- if self.stream.kvo_cache
- else cache_maxframes
+ self.stream.kvo_cache[0].shape[1] if self.stream.kvo_cache else cache_maxframes
)
if cache_maxframes > actual_cache_size:
logger.warning(
@@ -395,9 +390,7 @@ def update_stream_params(
@torch.inference_mode()
def update_prompt_weights(
- self,
- prompt_weights: List[float],
- prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
+ self, prompt_weights: List[float], prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
) -> None:
"""Update weights for current prompt list without re-encoding prompts."""
if not self._current_prompt_list:
@@ -405,7 +398,9 @@ def update_prompt_weights(
return
if len(prompt_weights) != len(self._current_prompt_list):
- logger.warning(f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}")
+ logger.warning(
+ f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}"
+ )
return
# Update the current prompt list with new weights
@@ -420,9 +415,7 @@ def update_prompt_weights(
@torch.inference_mode()
def update_seed_weights(
- self,
- seed_weights: List[float],
- interpolation_method: Literal["linear", "slerp"] = "linear"
+ self, seed_weights: List[float], interpolation_method: Literal["linear", "slerp"] = "linear"
) -> None:
"""Update weights for current seed list without regenerating noise."""
if not self._current_seed_list:
@@ -430,7 +423,9 @@ def update_seed_weights(
return
if len(seed_weights) != len(self._current_seed_list):
- logger.warning(f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}")
+ logger.warning(
+ f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}"
+ )
return
# Update the current seed list with new weights
@@ -448,7 +443,7 @@ def _update_blended_prompts(
self,
prompt_list: List[Tuple[str, float]],
negative_prompt: str = "",
- prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
+ prompt_interpolation_method: Literal["linear", "slerp"] = "slerp",
) -> None:
"""Update prompt embeddings using multiple weighted prompts."""
# Store current state
@@ -461,14 +456,10 @@ def _update_blended_prompts(
# Apply blending
self._apply_prompt_blending(prompt_interpolation_method)
- def _cache_prompt_embeddings(
- self,
- prompt_list: List[Tuple[str, float]],
- negative_prompt: str
- ) -> None:
+ def _cache_prompt_embeddings(self, prompt_list: List[Tuple[str, float]], negative_prompt: str) -> None:
"""Cache prompt embeddings for efficient reuse."""
for idx, (prompt_text, weight) in enumerate(prompt_list):
- if idx not in self._prompt_cache or self._prompt_cache[idx]['text'] != prompt_text:
+ if idx not in self._prompt_cache or self._prompt_cache[idx]["text"] != prompt_text:
# Cache miss - encode the prompt
self._prompt_cache_stats.record_miss()
encoder_output = self.stream.pipe.encode_prompt(
@@ -482,10 +473,7 @@ def _cache_prompt_embeddings(
if len(self._prompt_cache) >= 32:
oldest_key = next(iter(self._prompt_cache))
del self._prompt_cache[oldest_key]
- self._prompt_cache[idx] = {
- 'embed': encoder_output[0],
- 'text': prompt_text
- }
+ self._prompt_cache[idx] = {"embed": encoder_output[0], "text": prompt_text}
else:
# Cache hit
self._prompt_cache_stats.record_hit()
@@ -500,7 +488,7 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear",
for idx, (prompt_text, weight) in enumerate(self._current_prompt_list):
if idx in self._prompt_cache:
- embeddings.append(self._prompt_cache[idx]['embed'])
+ embeddings.append(self._prompt_cache[idx]["embed"])
weights.append(weight)
if not embeddings:
@@ -545,13 +533,14 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear",
# No CFG, just use the blended embeddings
final_prompt_embeds = combined_embeds.repeat(self.stream.batch_size, 1, 1)
final_negative_embeds = None # Will be set by enhancers if needed
-
+
# Enhancer mechanism removed in favor of embedding_hooks
# Run embedding hooks to compose final embeddings (e.g., append IP-Adapter tokens)
try:
- if hasattr(self.stream, 'embedding_hooks') and self.stream.embedding_hooks:
+ if hasattr(self.stream, "embedding_hooks") and self.stream.embedding_hooks:
from .hooks import EmbedsCtx # local import to avoid cycles
+
embeds_ctx = EmbedsCtx(
prompt_embeds=final_prompt_embeds,
negative_prompt_embeds=final_negative_embeds,
@@ -562,8 +551,9 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear",
final_negative_embeds = embeds_ctx.negative_prompt_embeds
except Exception as e:
import logging
+
logging.getLogger(__name__).error(f"_apply_prompt_blending: embedding hook failed: {e}")
-
+
# Set final embeddings on stream
self.stream.prompt_embeds = final_prompt_embeds
if final_negative_embeds is not None:
@@ -604,9 +594,7 @@ def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch.
@torch.inference_mode()
def _update_blended_seeds(
- self,
- seed_list: List[Tuple[int, float]],
- interpolation_method: Literal["linear", "slerp"] = "linear"
+ self, seed_list: List[Tuple[int, float]], interpolation_method: Literal["linear", "slerp"] = "linear"
) -> None:
"""Update seed tensors using multiple weighted seeds."""
# Store current state
@@ -621,7 +609,7 @@ def _update_blended_seeds(
def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None:
"""Cache seed noise tensors for efficient reuse."""
for idx, (seed_value, weight) in enumerate(seed_list):
- if idx not in self._seed_cache or self._seed_cache[idx]['seed'] != seed_value:
+ if idx not in self._seed_cache or self._seed_cache[idx]["seed"] != seed_value:
# Cache miss - generate noise for the seed
self._seed_cache_stats.record_miss()
generator = torch.Generator(device=self.stream.device)
@@ -631,13 +619,10 @@ def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None:
(self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width),
generator=generator,
device=self.stream.device,
- dtype=self.stream.dtype
+ dtype=self.stream.dtype,
)
- self._seed_cache[idx] = {
- 'noise': noise,
- 'seed': seed_value
- }
+ self._seed_cache[idx] = {"noise": noise, "seed": seed_value}
else:
# Cache hit
self._seed_cache_stats.record_hit()
@@ -652,7 +637,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"])
for idx, (seed_value, weight) in enumerate(self._current_seed_list):
if idx in self._seed_cache:
- noise_tensors.append(self._seed_cache[idx]['noise'])
+ noise_tensors.append(self._seed_cache[idx]["noise"])
weights.append(weight)
if not noise_tensors:
@@ -673,7 +658,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"])
combined_noise = torch.zeros_like(noise_tensors[0])
for noise, weight in zip(noise_tensors, weights):
combined_noise += weight * noise
-
+
# Preserve noise magnitude when weights are normalized
if self.normalize_seed_weights and len(noise_tensors) > 1:
original_magnitude = torch.mean(torch.stack([torch.norm(noise) for noise in noise_tensors]))
@@ -748,6 +733,7 @@ def _update_seed(self, seed: int) -> None:
def _get_scheduler_scalings(self, timestep):
"""Get LCM/TCD-specific scaling factors for boundary conditions."""
from diffusers import LCMScheduler
+
if isinstance(self.stream.scheduler, LCMScheduler):
c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
return c_skip, c_out
@@ -765,9 +751,7 @@ def _update_timestep_calculations(self) -> None:
for t in self.stream.t_list:
self.stream.sub_timesteps.append(self.stream.timesteps[t])
- sub_timesteps_tensor = torch.tensor(
- self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device
- )
+ sub_timesteps_tensor = torch.tensor(self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device)
self.stream.sub_timesteps_tensor = torch.repeat_interleave(
sub_timesteps_tensor,
repeats=self.stream.frame_bff_size if self.stream.use_denoising_batch else 1,
@@ -793,12 +777,8 @@ def _update_timestep_calculations(self) -> None:
)
if self.stream.use_denoising_batch:
- self.stream.c_skip = torch.repeat_interleave(
- self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0
- )
- self.stream.c_out = torch.repeat_interleave(
- self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0
- )
+ self.stream.c_skip = torch.repeat_interleave(self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0)
+ self.stream.c_out = torch.repeat_interleave(self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0)
# Update alpha_prod_t_sqrt and beta_prod_t_sqrt
alpha_prod_t_sqrt_list = []
@@ -838,29 +818,25 @@ def _update_timestep_values_only(self, t_index_list: List[int]) -> None:
def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> None:
"""Recalculate all parameters that depend on t_index_list."""
-
+
# Check if this is a structural change (length) or just value change
if len(t_index_list) == len(self.stream.t_list):
# Same length - only values changed, use lightweight update (working branch behavior)
self._update_timestep_values_only(t_index_list)
return
-
+
# Length changed - do full recalculation including batch-dependent parameters (broken branch logic - but it works for this case!)
self.stream.t_list = t_index_list
self.stream.denoising_steps_num = len(self.stream.t_list)
old_batch_size = self.stream.batch_size
-
+
if self.stream.use_denoising_batch:
self.stream.batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size
if self.stream.cfg_type == "initialize":
- self.stream.trt_unet_batch_size = (
- self.stream.denoising_steps_num + 1
- ) * self.stream.frame_bff_size
+ self.stream.trt_unet_batch_size = (self.stream.denoising_steps_num + 1) * self.stream.frame_bff_size
elif self.stream.cfg_type == "full":
- self.stream.trt_unet_batch_size = (
- 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size
- )
+ self.stream.trt_unet_batch_size = 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size
else:
self.stream.trt_unet_batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size
else:
@@ -891,27 +867,33 @@ def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> Non
# Resize kvo_cache tensors if batch size changed
if self.stream.kvo_cache and old_batch_size != self.stream.batch_size:
- logger.info(f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}")
+ logger.info(
+ f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}"
+ )
for i, cache_tensor in enumerate(self.stream.kvo_cache):
# KVO cache shape: (2, cache_maxframes, batch_size, seq_length, hidden_dim)
current_shape = cache_tensor.shape
- new_shape = (current_shape[0], current_shape[1], self.stream.batch_size, current_shape[3], current_shape[4])
- new_cache_tensor = torch.zeros(
- new_shape,
- dtype=cache_tensor.dtype,
- device=cache_tensor.device
+ new_shape = (
+ current_shape[0],
+ current_shape[1],
+ self.stream.batch_size,
+ current_shape[3],
+ current_shape[4],
)
-
+ new_cache_tensor = torch.zeros(new_shape, dtype=cache_tensor.dtype, device=cache_tensor.device)
+
# Copy over as much data as possible from old cache
min_batch = min(old_batch_size, self.stream.batch_size)
new_cache_tensor[:, :, :min_batch, :, :] = cache_tensor[:, :, :min_batch, :, :]
-
+
self.stream.kvo_cache[i] = new_cache_tensor
# Drop bucketed storage refs so update_kvo_cache falls back to
# per-layer writes against the new tensors.
self.stream._kvo_buckets = None
self.stream._kvo_outputs_by_bucket = None
- logger.info(f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}")
+ logger.info(
+ f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}"
+ )
# Update timestep-dependent calculations (shared with value-only path)
self._update_timestep_calculations()
@@ -930,10 +912,7 @@ def _recalculate_controlnet_inputs(self, width: int, height: int) -> None:
@torch.inference_mode()
def update_prompt_at_index(
- self,
- index: int,
- new_prompt: str,
- prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
+ self, index: int, new_prompt: str, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
) -> None:
"""Update a single prompt at the specified index without re-encoding others."""
if not self._validate_index(index, self._current_prompt_list, "update_prompt_at_index"):
@@ -947,11 +926,11 @@ def update_prompt_at_index(
self._cache_prompt_embeddings([(new_prompt, weight)], self._current_negative_prompt)
# Update cache index to point to the new prompt
- if index in self._prompt_cache and self._prompt_cache[index]['text'] != new_prompt:
+ if index in self._prompt_cache and self._prompt_cache[index]["text"] != new_prompt:
# Find if this prompt is already cached elsewhere
existing_cache_key = None
for cache_idx, cache_data in self._prompt_cache.items():
- if cache_data['text'] == new_prompt:
+ if cache_data["text"] == new_prompt:
existing_cache_key = cache_idx
break
@@ -969,10 +948,7 @@ def update_prompt_at_index(
do_classifier_free_guidance=False,
negative_prompt=self._current_negative_prompt,
)
- self._prompt_cache[index] = {
- 'embed': encoder_output[0],
- 'text': new_prompt
- }
+ self._prompt_cache[index] = {"embed": encoder_output[0], "text": new_prompt}
# Recompute blended embeddings with updated prompt
self._apply_prompt_blending(prompt_interpolation_method)
@@ -984,16 +960,12 @@ def get_current_prompts(self) -> List[Tuple[str, float]]:
@torch.inference_mode()
def add_prompt(
- self,
- prompt: str,
- weight: float = 1.0,
- prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
+ self, prompt: str, weight: float = 1.0, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
) -> None:
"""Add a new prompt to the current list."""
new_index = len(self._current_prompt_list)
self._current_prompt_list.append((prompt, weight))
-
# Cache the new prompt
encoder_output = self.stream.pipe.encode_prompt(
prompt=prompt,
@@ -1002,10 +974,7 @@ def add_prompt(
do_classifier_free_guidance=False,
negative_prompt=self._current_negative_prompt,
)
- self._prompt_cache[new_index] = {
- 'embed': encoder_output[0],
- 'text': prompt
- }
+ self._prompt_cache[new_index] = {"embed": encoder_output[0], "text": prompt}
self._prompt_cache_stats.record_miss()
# Recompute blended embeddings
@@ -1013,9 +982,7 @@ def add_prompt(
@torch.inference_mode()
def remove_prompt_at_index(
- self,
- index: int,
- prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
+ self, index: int, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp"
) -> None:
"""Remove a prompt at the specified index."""
if not self._validate_index(index, self._current_prompt_list, "remove_prompt_at_index"):
@@ -1040,10 +1007,7 @@ def remove_prompt_at_index(
@torch.inference_mode()
def update_seed_at_index(
- self,
- index: int,
- new_seed: int,
- interpolation_method: Literal["linear", "slerp"] = "linear"
+ self, index: int, new_seed: int, interpolation_method: Literal["linear", "slerp"] = "linear"
) -> None:
"""Update a single seed at the specified index without regenerating others."""
if not self._validate_index(index, self._current_seed_list, "update_seed_at_index"):
@@ -1053,16 +1017,15 @@ def update_seed_at_index(
old_seed, weight = self._current_seed_list[index]
self._current_seed_list[index] = (new_seed, weight)
-
# Cache the new seed noise
self._cache_seed_noise([(new_seed, weight)])
# Update cache index to point to the new seed
- if index in self._seed_cache and self._seed_cache[index]['seed'] != new_seed:
+ if index in self._seed_cache and self._seed_cache[index]["seed"] != new_seed:
# Find if this seed is already cached elsewhere
existing_cache_key = None
for cache_idx, cache_data in self._seed_cache.items():
- if cache_data['seed'] == new_seed:
+ if cache_data["seed"] == new_seed:
existing_cache_key = cache_idx
break
@@ -1080,13 +1043,10 @@ def update_seed_at_index(
(self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width),
generator=generator,
device=self.stream.device,
- dtype=self.stream.dtype
+ dtype=self.stream.dtype,
)
- self._seed_cache[index] = {
- 'noise': noise,
- 'seed': new_seed
- }
+ self._seed_cache[index] = {"noise": noise, "seed": new_seed}
# Recompute blended noise with updated seed
self._apply_seed_blending(interpolation_method)
@@ -1098,10 +1058,7 @@ def get_current_seeds(self) -> List[Tuple[int, float]]:
@torch.inference_mode()
def add_seed(
- self,
- seed: int,
- weight: float = 1.0,
- interpolation_method: Literal["linear", "slerp"] = "linear"
+ self, seed: int, weight: float = 1.0, interpolation_method: Literal["linear", "slerp"] = "linear"
) -> None:
"""Add a new seed to the current list."""
new_index = len(self._current_seed_list)
@@ -1117,24 +1074,17 @@ def add_seed(
(self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width),
generator=generator,
device=self.stream.device,
- dtype=self.stream.dtype
+ dtype=self.stream.dtype,
)
- self._seed_cache[new_index] = {
- 'noise': noise,
- 'seed': seed
- }
+ self._seed_cache[new_index] = {"noise": noise, "seed": seed}
self._seed_cache_stats.record_miss()
# Recompute blended noise
self._apply_seed_blending(interpolation_method)
@torch.inference_mode()
- def remove_seed_at_index(
- self,
- index: int,
- interpolation_method: Literal["linear", "slerp"] = "linear"
- ) -> None:
+ def remove_seed_at_index(self, index: int, interpolation_method: Literal["linear", "slerp"] = "linear") -> None:
"""Remove a seed at the specified index."""
if not self._validate_index(index, self._current_seed_list, "remove_seed_at_index"):
return
@@ -1159,7 +1109,7 @@ def remove_seed_at_index(
def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> None:
"""
Update ControlNet configuration by diffing current vs desired state.
-
+
Args:
desired_config: Complete ControlNet configuration list defining the desired state.
Each dict contains: model_id, preprocessor, conditioning_scale, enabled, etc.
@@ -1167,41 +1117,47 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non
# Find the ControlNet pipeline/module (module-aware)
controlnet_pipeline = self._get_controlnet_pipeline()
if not controlnet_pipeline:
- logger.debug("_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)")
+ logger.debug(
+ "_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)"
+ )
return
-
+
current_config = self._get_current_controlnet_config()
-
+
# Simple approach: detect what changed and apply minimal updates
- current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)}
- desired_models = {cfg['model_id']: cfg for cfg in desired_config}
-
+ current_models = {
+ i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets)
+ }
+ desired_models = {cfg["model_id"]: cfg for cfg in desired_config}
+
# Reorder to match desired order (module supports stable reordering)
try:
- desired_order = [cfg['model_id'] for cfg in desired_config if 'model_id' in cfg]
- if hasattr(controlnet_pipeline, 'reorder_controlnets_by_model_ids'):
+ desired_order = [cfg["model_id"] for cfg in desired_config if "model_id" in cfg]
+ if hasattr(controlnet_pipeline, "reorder_controlnets_by_model_ids"):
controlnet_pipeline.reorder_controlnets_by_model_ids(desired_order)
except Exception:
pass
# Recompute current models after potential reorder
- current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)}
+ current_models = {
+ i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets)
+ }
# Remove controlnets not in desired config
for i in reversed(range(len(controlnet_pipeline.controlnets))):
- model_id = current_models.get(i, f'controlnet_{i}')
+ model_id = current_models.get(i, f"controlnet_{i}")
if model_id not in desired_models:
logger.info(f"_update_controlnet_config: Removing ControlNet {model_id}")
try:
controlnet_pipeline.remove_controlnet(i)
except Exception:
raise
-
+
# Add new controlnets and update existing ones
for desired_cfg in desired_config:
- model_id = desired_cfg['model_id']
+ model_id = desired_cfg["model_id"]
existing_index = next((i for i, mid in current_models.items() if mid == model_id), None)
-
+
if existing_index is None:
# Add new controlnet
logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}")
@@ -1209,15 +1165,16 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non
# Prefer module path: construct ControlNetConfig
try:
from .modules.controlnet_module import ControlNetConfig # type: ignore
+
cn_cfg = ControlNetConfig(
- model_id=desired_cfg.get('model_id'),
- preprocessor=desired_cfg.get('preprocessor'),
- conditioning_scale=desired_cfg.get('conditioning_scale', 1.0),
- enabled=desired_cfg.get('enabled', True),
- conditioning_channels=desired_cfg.get('conditioning_channels'),
- preprocessor_params=desired_cfg.get('preprocessor_params'),
+ model_id=desired_cfg.get("model_id"),
+ preprocessor=desired_cfg.get("preprocessor"),
+ conditioning_scale=desired_cfg.get("conditioning_scale", 1.0),
+ enabled=desired_cfg.get("enabled", True),
+ conditioning_channels=desired_cfg.get("conditioning_channels"),
+ preprocessor_params=desired_cfg.get("preprocessor_params"),
)
- controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get('control_image'))
+ controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get("control_image"))
except Exception:
# No fallback
raise
@@ -1225,114 +1182,136 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non
logger.error(f"_update_controlnet_config: add_controlnet failed for {model_id}: {e}")
else:
# Update existing controlnet
- if 'conditioning_scale' in desired_cfg:
- current_scale = current_config[existing_index].get('conditioning_scale', 1.0)
- desired_scale = desired_cfg['conditioning_scale']
-
+ if "conditioning_scale" in desired_cfg:
+ current_scale = current_config[existing_index].get("conditioning_scale", 1.0)
+ desired_scale = desired_cfg["conditioning_scale"]
+
if current_scale != desired_scale:
- logger.info(f"_update_controlnet_config: Updating {model_id} scale: {current_scale} β {desired_scale}")
- if hasattr(controlnet_pipeline, 'controlnet_scales') and 0 <= existing_index < len(controlnet_pipeline.controlnet_scales):
+ logger.info(
+ f"_update_controlnet_config: Updating {model_id} scale: {current_scale} β {desired_scale}"
+ )
+ if hasattr(controlnet_pipeline, "controlnet_scales") and 0 <= existing_index < len(
+ controlnet_pipeline.controlnet_scales
+ ):
controlnet_pipeline.controlnet_scales[existing_index] = float(desired_scale)
-
+
# Enable/disable toggle
- if 'enabled' in desired_cfg and hasattr(controlnet_pipeline, 'enabled_list'):
+ if "enabled" in desired_cfg and hasattr(controlnet_pipeline, "enabled_list"):
if 0 <= existing_index < len(controlnet_pipeline.enabled_list):
- controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg['enabled'])
+ controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg["enabled"])
- if 'preprocessor_params' in desired_cfg and hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[existing_index]:
+ if (
+ "preprocessor_params" in desired_cfg
+ and hasattr(controlnet_pipeline, "preprocessors")
+ and controlnet_pipeline.preprocessors[existing_index]
+ ):
preprocessor = controlnet_pipeline.preprocessors[existing_index]
- preprocessor.params.update(desired_cfg['preprocessor_params'])
- for param_name, param_value in desired_cfg['preprocessor_params'].items():
+ preprocessor.params.update(desired_cfg["preprocessor_params"])
+ for param_name, param_value in desired_cfg["preprocessor_params"].items():
if hasattr(preprocessor, param_name):
setattr(preprocessor, param_name, param_value)
-
+
# Pipeline references are now automatically managed during preprocessor creation
# No need to manually re-establish pipeline references for pipeline-aware processors
-
def _get_controlnet_pipeline(self):
"""
Get the ControlNet module or legacy pipeline from the structure (module-aware).
"""
# Module-installed path
- if hasattr(self.stream, '_controlnet_module'):
+ if hasattr(self.stream, "_controlnet_module"):
return self.stream._controlnet_module
# Legacy paths
- if hasattr(self.stream, 'controlnets'):
+ if hasattr(self.stream, "controlnets"):
return self.stream
- if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'controlnets'):
+ if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "controlnets"):
return self.stream.stream
- if self.wrapper and hasattr(self.wrapper, 'stream'):
- if hasattr(self.wrapper.stream, '_controlnet_module'):
+ if self.wrapper and hasattr(self.wrapper, "stream"):
+ if hasattr(self.wrapper.stream, "_controlnet_module"):
return self.wrapper.stream._controlnet_module
- if hasattr(self.wrapper.stream, 'controlnets'):
+ if hasattr(self.wrapper.stream, "controlnets"):
return self.wrapper.stream
- if hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'):
+ if hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "controlnets"):
return self.wrapper.stream.stream
return None
def _get_current_controlnet_config(self) -> List[Dict[str, Any]]:
"""
Get current ControlNet configuration state.
-
+
Returns:
List of current ControlNet configurations
"""
controlnet_pipeline = self._get_controlnet_pipeline()
- if not controlnet_pipeline or not hasattr(controlnet_pipeline, 'controlnets') or not controlnet_pipeline.controlnets:
+ if (
+ not controlnet_pipeline
+ or not hasattr(controlnet_pipeline, "controlnets")
+ or not controlnet_pipeline.controlnets
+ ):
return []
-
+
current_config = []
for i, controlnet in enumerate(controlnet_pipeline.controlnets):
- model_id = getattr(controlnet, 'model_id', f'controlnet_{i}')
- scale = controlnet_pipeline.controlnet_scales[i] if hasattr(controlnet_pipeline, 'controlnet_scales') and i < len(controlnet_pipeline.controlnet_scales) else 1.0
+ model_id = getattr(controlnet, "model_id", f"controlnet_{i}")
+ scale = (
+ controlnet_pipeline.controlnet_scales[i]
+ if hasattr(controlnet_pipeline, "controlnet_scales") and i < len(controlnet_pipeline.controlnet_scales)
+ else 1.0
+ )
enabled_val = True
try:
- if hasattr(controlnet_pipeline, 'enabled_list') and i < len(controlnet_pipeline.enabled_list):
+ if hasattr(controlnet_pipeline, "enabled_list") and i < len(controlnet_pipeline.enabled_list):
enabled_val = bool(controlnet_pipeline.enabled_list[i])
except Exception:
enabled_val = True
config = {
- 'model_id': model_id,
- 'conditioning_scale': scale,
- 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {},
- 'enabled': enabled_val,
+ "model_id": model_id,
+ "conditioning_scale": scale,
+ "preprocessor_params": getattr(controlnet_pipeline.preprocessors[i], "params", {})
+ if hasattr(controlnet_pipeline, "preprocessors") and controlnet_pipeline.preprocessors[i]
+ else {},
+ "enabled": enabled_val,
}
current_config.append(config)
-
+
return current_config
def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None:
"""
Update IPAdapter configuration.
-
+
Args:
- desired_config: IPAdapter configuration dict containing:
+ desired_config: IPAdapter configuration dict containing:
ipadapter_model_path, image_encoder_path, style_image, scale, enabled, etc.
"""
# Find the IPAdapter pipeline
ipadapter_pipeline = self._get_ipadapter_pipeline()
-
+
if not ipadapter_pipeline:
- logger.warning(f"_update_ipadapter_config: No IPAdapter pipeline found")
+ logger.warning("_update_ipadapter_config: No IPAdapter pipeline found")
return
-
- if 'scale' in desired_config and desired_config['scale'] is not None:
- desired_scale = float(desired_config['scale'])
+
+ if "scale" in desired_config and desired_config["scale"] is not None:
+ desired_scale = float(desired_config["scale"])
# Get current scale from IPAdapter instance
- current_scale = getattr(self.stream.ipadapter, 'scale', 1.0) if hasattr(self.stream, 'ipadapter') else 1.0
-
+ current_scale = getattr(self.stream.ipadapter, "scale", 1.0) if hasattr(self.stream, "ipadapter") else 1.0
+
if current_scale != desired_scale:
logger.info(f"_update_ipadapter_config: Updating scale: {current_scale} β {desired_scale}")
-
+
# Get weight_type from IPAdapter instance
- weight_type = getattr(self.stream.ipadapter, 'weight_type', None) if hasattr(self.stream, 'ipadapter') else None
-
+ weight_type = (
+ getattr(self.stream.ipadapter, "weight_type", None) if hasattr(self.stream, "ipadapter") else None
+ )
+
# Apply scale with weight type consideration
- if weight_type is not None and hasattr(self.stream, 'ipadapter'):
+ if weight_type is not None and hasattr(self.stream, "ipadapter"):
try:
from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights
- ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")]
+
+ ip_procs = [
+ p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")
+ ]
num_layers = len(ip_procs)
weights = build_layer_weights(num_layers, desired_scale, weight_type)
if weights is not None:
@@ -1340,47 +1319,51 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None:
else:
self.stream.ipadapter.set_scale(desired_scale)
# Update our tracking attribute
- setattr(self.stream.ipadapter, 'scale', desired_scale)
+ setattr(self.stream.ipadapter, "scale", desired_scale)
except Exception:
# Do not add fallback mechanisms
raise
else:
# Simple uniform scale
- if hasattr(self.stream, 'ipadapter'):
+ if hasattr(self.stream, "ipadapter"):
# Tell diffusers_ipadapter to set the scale
self.stream.ipadapter.set_scale(desired_scale)
# Update our tracking attribute
- setattr(self.stream.ipadapter, 'scale', desired_scale)
-
+ setattr(self.stream.ipadapter, "scale", desired_scale)
# Update enabled state if provided
- if 'enabled' in desired_config and desired_config['enabled'] is not None:
- enabled_state = bool(desired_config['enabled'])
+ if "enabled" in desired_config and desired_config["enabled"] is not None:
+ enabled_state = bool(desired_config["enabled"])
# Update IPAdapter instance
- if hasattr(self.stream, 'ipadapter'):
- current_enabled = getattr(self.stream.ipadapter, 'enabled', True)
+ if hasattr(self.stream, "ipadapter"):
+ current_enabled = getattr(self.stream.ipadapter, "enabled", True)
if current_enabled != enabled_state:
- logger.info(f"_update_ipadapter_config: Updating enabled state: {current_enabled} β {enabled_state}")
- setattr(self.stream.ipadapter, 'enabled', enabled_state)
+ logger.info(
+ f"_update_ipadapter_config: Updating enabled state: {current_enabled} β {enabled_state}"
+ )
+ setattr(self.stream.ipadapter, "enabled", enabled_state)
# Update weight type if provided (affects per-layer distribution and/or per-step factor)
- if 'weight_type' in desired_config and desired_config['weight_type'] is not None:
- weight_type = desired_config['weight_type']
+ if "weight_type" in desired_config and desired_config["weight_type"] is not None:
+ weight_type = desired_config["weight_type"]
# Update IPAdapter instance
- if hasattr(self.stream, 'ipadapter'):
- setattr(self.stream.ipadapter, 'weight_type', weight_type)
-
+ if hasattr(self.stream, "ipadapter"):
+ setattr(self.stream.ipadapter, "weight_type", weight_type)
+
# For PyTorch UNet, immediately apply a per-layer scale vector so layers reflect selection types
try:
- is_tensorrt_engine = hasattr(self.stream.unet, 'engine') and hasattr(self.stream.unet, 'stream')
+ is_tensorrt_engine = hasattr(self.stream.unet, "engine") and hasattr(self.stream.unet, "stream")
if not is_tensorrt_engine:
# Compute per-layer vector using Diffusers_IPAdapter helper
from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights
+
# Count installed IP layers by scanning processors with _ip_layer_index
- ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")]
+ ip_procs = [
+ p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")
+ ]
num_layers = len(ip_procs)
# Get base weight from IPAdapter instance
- base_weight = float(getattr(self.stream.ipadapter, 'scale', 1.0))
+ base_weight = float(getattr(self.stream.ipadapter, "scale", 1.0))
weights = build_layer_weights(num_layers, base_weight, weight_type)
# If None, keep uniform base scale; else set per-layer vector
if weights is not None:
@@ -1388,7 +1371,7 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None:
else:
self.stream.ipadapter.set_scale(base_weight)
# Keep our tracking attribute in sync
- setattr(self.stream.ipadapter, 'scale', base_weight)
+ setattr(self.stream.ipadapter, "scale", base_weight)
except Exception:
# Do not add fallback mechanisms
raise
@@ -1396,191 +1379,207 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None:
def _get_ipadapter_pipeline(self):
"""
Get the IPAdapter pipeline from the pipeline structure (following ControlNet pattern).
-
+
Returns:
IPAdapter pipeline object or None if not found
"""
# Check if stream is IPAdapter pipeline directly
- if hasattr(self.stream, 'ipadapter'):
+ if hasattr(self.stream, "ipadapter"):
return self.stream
-
+
# Check if stream has nested stream (ControlNet wrapper)
- if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'ipadapter'):
+ if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "ipadapter"):
return self.stream.stream
-
+
# Check if we have a wrapper reference and can access through it
- if self.wrapper and hasattr(self.wrapper, 'stream'):
- if hasattr(self.wrapper.stream, 'ipadapter'):
+ if self.wrapper and hasattr(self.wrapper, "stream"):
+ if hasattr(self.wrapper.stream, "ipadapter"):
return self.wrapper.stream
- elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'ipadapter'):
+ elif hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "ipadapter"):
return self.wrapper.stream.stream
-
+
return None
def _get_current_ipadapter_config(self) -> Optional[Dict[str, Any]]:
"""
Get current IPAdapter configuration by introspecting the IPAdapter instance.
-
+
Returns:
Current IPAdapter configuration dict or None if no IPAdapter
"""
# Get config from IPAdapter instance
- if hasattr(self.stream, 'ipadapter') and self.stream.ipadapter is not None:
+ if hasattr(self.stream, "ipadapter") and self.stream.ipadapter is not None:
ipadapter = self.stream.ipadapter
-
+
config = {
- 'scale': getattr(ipadapter, 'scale', 1.0),
- 'weight_type': getattr(ipadapter, 'weight_type', None),
- 'enabled': getattr(ipadapter, 'enabled', True), # Check actual enabled state
+ "scale": getattr(ipadapter, "scale", 1.0),
+ "weight_type": getattr(ipadapter, "weight_type", None),
+ "enabled": getattr(ipadapter, "enabled", True), # Check actual enabled state
}
-
+
# Add static initialization fields
- if hasattr(self.stream, '_ipadapter_module'):
+ if hasattr(self.stream, "_ipadapter_module"):
module_config = self.stream._ipadapter_module.config
- config.update({
- 'style_image_key': module_config.style_image_key,
- 'num_image_tokens': module_config.num_image_tokens,
- 'type': module_config.type.value,
- })
-
+ config.update(
+ {
+ "style_image_key": module_config.style_image_key,
+ "num_image_tokens": module_config.num_image_tokens,
+ "type": module_config.type.value,
+ }
+ )
+
# Check if style image is set
ipadapter_pipeline = self._get_ipadapter_pipeline()
- if ipadapter_pipeline and hasattr(ipadapter_pipeline, 'style_image') and ipadapter_pipeline.style_image:
- config['has_style_image'] = True
+ if ipadapter_pipeline and hasattr(ipadapter_pipeline, "style_image") and ipadapter_pipeline.style_image:
+ config["has_style_image"] = True
else:
- config['has_style_image'] = False
-
+ config["has_style_image"] = False
+
return config
-
+
# No IPAdapter instance found
return None
def _get_current_hook_config(self, hook_type: str) -> List[Dict[str, Any]]:
"""
Get current hook configuration by introspecting the hook module state.
-
+
Args:
hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.)
-
+
Returns:
List of processor configurations or empty list if no module
"""
# Get the hook module
module_attr_name = f"_{hook_type}_module"
hook_module = getattr(self.stream, module_attr_name, None)
-
+
if not hook_module:
return []
-
+
# Get processors from the module
- processors = getattr(hook_module, 'processors', [])
-
+ processors = getattr(hook_module, "processors", [])
+
config = []
for i, processor in enumerate(processors):
proc_config = {
- 'type': getattr(processor, '__class__').__name__,
- 'order': getattr(processor, 'order', i),
- 'enabled': getattr(processor, 'enabled', True),
+ "type": getattr(processor, "__class__").__name__,
+ "order": getattr(processor, "order", i),
+ "enabled": getattr(processor, "enabled", True),
}
-
+
# Try to get processor parameters
- if hasattr(processor, 'params'):
- proc_config['params'] = dict(processor.params)
-
+ if hasattr(processor, "params"):
+ proc_config["params"] = dict(processor.params)
+
config.append(proc_config)
-
+
return config
def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any]]) -> None:
"""
Update hook configuration by modifying existing processors in-place instead of recreating them.
-
+
Args:
hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.)
desired_config: List of processor configurations
"""
logger.info(f"_update_hook_config: Updating {hook_type} with {len(desired_config)} processors")
-
+
# Get or create the hook module
module_attr_name = f"_{hook_type}_module"
hook_module = getattr(self.stream, module_attr_name, None)
-
+
if not hook_module:
logger.info(f"_update_hook_config: No existing {hook_type} module, creating new one")
# Create the appropriate hook module
try:
if hook_type in ["image_preprocessing", "image_postprocessing"]:
- from streamdiffusion.modules.image_processing_module import ImagePreprocessingModule, ImagePostprocessingModule
+ from streamdiffusion.modules.image_processing_module import (
+ ImagePostprocessingModule,
+ ImagePreprocessingModule,
+ )
+
if hook_type == "image_preprocessing":
hook_module = ImagePreprocessingModule()
else:
hook_module = ImagePostprocessingModule()
elif hook_type in ["latent_preprocessing", "latent_postprocessing"]:
- from streamdiffusion.modules.latent_processing_module import LatentPreprocessingModule, LatentPostprocessingModule
+ from streamdiffusion.modules.latent_processing_module import (
+ LatentPostprocessingModule,
+ LatentPreprocessingModule,
+ )
+
if hook_type == "latent_preprocessing":
hook_module = LatentPreprocessingModule()
else:
hook_module = LatentPostprocessingModule()
else:
raise ValueError(f"Unknown hook type: {hook_type}")
-
+
# Install the module
hook_module.install(self.stream)
setattr(self.stream, module_attr_name, hook_module)
logger.info(f"_update_hook_config: Created and installed {hook_type} module")
-
+
except Exception as e:
logger.error(f"_update_hook_config: Failed to create {hook_type} module: {e}")
return
-
- logger.info(f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors")
-
+
+ logger.info(
+ f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors"
+ )
+
# Modify existing processors in-place instead of clearing and recreating
for i, proc_config in enumerate(desired_config):
- processor_type = proc_config.get('type', 'unknown')
- enabled = proc_config.get('enabled', True)
- params = proc_config.get('params', {})
-
+ processor_type = proc_config.get("type", "unknown")
+ enabled = proc_config.get("enabled", True)
+ params = proc_config.get("params", {})
+
logger.info(f"_update_hook_config: Processing config {i}: type={processor_type}, enabled={enabled}")
-
+
if i < len(hook_module.processors):
# Modify existing processor
existing_processor = hook_module.processors[i]
-
+
# Get the current processor type from registry name if available, otherwise use class name
- current_type = existing_processor.params.get('_registry_name') if hasattr(existing_processor, 'params') else None
+ current_type = (
+ existing_processor.params.get("_registry_name") if hasattr(existing_processor, "params") else None
+ )
if not current_type:
current_type = existing_processor.__class__.__name__
-
- logger.info(f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}")
-
+
+ logger.info(
+ f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}"
+ )
+
# If processor type changed, replace it
if current_type.lower() != processor_type.lower():
logger.info(f"_update_hook_config: Type changed, replacing processor {i}")
try:
from streamdiffusion.preprocessing.processors import get_preprocessor
-
+
# Determine normalization context from hook type
- if 'latent' in hook_type:
- normalization_context = 'latent'
+ if "latent" in hook_type:
+ normalization_context = "latent"
else:
# Image preprocessing/postprocessing uses 'pipeline' context
- normalization_context = 'pipeline'
-
+ normalization_context = "pipeline"
+
new_processor = get_preprocessor(
- processor_type,
- pipeline_ref=getattr(self, 'stream', None),
- normalization_context=normalization_context
+ processor_type,
+ pipeline_ref=getattr(self, "stream", None),
+ normalization_context=normalization_context,
)
-
+
# Copy attributes from old processor
- setattr(new_processor, 'order', getattr(existing_processor, 'order', i))
- setattr(new_processor, 'enabled', enabled)
-
+ setattr(new_processor, "order", getattr(existing_processor, "order", i))
+ setattr(new_processor, "enabled", enabled)
+
# Set parameters
- if hasattr(new_processor, 'params'):
+ if hasattr(new_processor, "params"):
new_processor.params.update(params)
-
+
hook_module.processors[i] = new_processor
logger.info(f"_update_hook_config: Successfully replaced processor {i} with {processor_type}")
except Exception as e:
@@ -1588,15 +1587,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any
else:
# Same type, just update attributes
logger.info(f"_update_hook_config: Same type, updating attributes for processor {i}")
- setattr(existing_processor, 'enabled', enabled)
-
+ setattr(existing_processor, "enabled", enabled)
+
# Update parameters
- if hasattr(existing_processor, 'params'):
+ if hasattr(existing_processor, "params"):
existing_processor.params.update(params)
for param_name, param_value in params.items():
if hasattr(existing_processor, param_name):
setattr(existing_processor, param_name, param_value)
-
+
logger.info(f"_update_hook_config: Updated processor {i} enabled={enabled}, params={params}")
else:
# Add new processor
@@ -1606,12 +1605,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any
logger.info(f"_update_hook_config: Successfully added processor {i}: {processor_type}")
except Exception as e:
logger.error(f"_update_hook_config: Failed to add processor {i}: {e}")
-
+
# Remove extra processors if config is shorter
while len(hook_module.processors) > len(desired_config):
removed_idx = len(hook_module.processors) - 1
removed_processor = hook_module.processors.pop()
- logger.info(f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}")
-
- logger.info(f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors")
+ logger.info(
+ f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}"
+ )
+ logger.info(
+ f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors"
+ )
diff --git a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py
index 355ad6fc1..2a7eee954 100644
--- a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py
+++ b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py
@@ -9,18 +9,19 @@
python -m streamdiffusion.tools.compile_depth_anything_tensorrt --model_size small --resolution 518
"""
-import torch
import logging
-import os
from pathlib import Path
-from typing import Optional
+
import fire
+import torch
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
try:
import tensorrt as trt
+
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
@@ -28,6 +29,7 @@
try:
import onnx
+
ONNX_AVAILABLE = True
except ImportError:
ONNX_AVAILABLE = False
@@ -50,10 +52,7 @@
def export_depth_anything_to_onnx(
- onnx_path: Path,
- model_size: str = "small",
- resolution: int = 518,
- device: str = "cuda"
+ onnx_path: Path, model_size: str = "small", resolution: int = 518, device: str = "cuda"
) -> bool:
"""Export Depth Anything model to ONNX format"""
try:
@@ -102,6 +101,7 @@ def export_depth_anything_to_onnx(
except Exception as e:
logger.error(f"Failed to export ONNX: {e}")
import traceback
+
traceback.print_exc()
return False
@@ -126,7 +126,7 @@ def build_tensorrt_engine(
parser = trt.OnnxParser(network, trt_logger)
# Parse ONNX
- with open(onnx_path, 'rb') as f:
+ with open(onnx_path, "rb") as f:
if not parser.parse(f.read()):
for i in range(parser.num_errors):
logger.error(f"ONNX parse error: {parser.get_error(i)}")
@@ -142,10 +142,12 @@ def build_tensorrt_engine(
# Set optimization profile for fixed resolution
profile = builder.create_optimization_profile()
- profile.set_shape("input",
- (1, 3, resolution, resolution), # min
- (1, 3, resolution, resolution), # opt
- (1, 3, resolution, resolution)) # max
+ profile.set_shape(
+ "input",
+ (1, 3, resolution, resolution), # min
+ (1, 3, resolution, resolution), # opt
+ (1, 3, resolution, resolution),
+ ) # max
config.add_optimization_profile(profile)
# Build engine
@@ -158,7 +160,7 @@ def build_tensorrt_engine(
# Save engine
engine_path.parent.mkdir(parents=True, exist_ok=True)
- with open(engine_path, 'wb') as f:
+ with open(engine_path, "wb") as f:
f.write(serialized_engine)
logger.info(f"TensorRT engine saved: {engine_path}")
@@ -167,6 +169,7 @@ def build_tensorrt_engine(
except Exception as e:
logger.error(f"Failed to build TensorRT engine: {e}")
import traceback
+
traceback.print_exc()
return False
@@ -206,7 +209,7 @@ def compile_depth_anything(
# Check if engine already exists
if engine_path.exists():
logger.info(f"Engine already exists: {engine_path}")
- overwrite = input("Overwrite? (y/N): ").lower().strip() == 'y'
+ overwrite = input("Overwrite? (y/N): ").lower().strip() == "y"
if not overwrite:
return
@@ -228,9 +231,9 @@ def compile_depth_anything(
logger.info("Removed intermediate ONNX file")
logger.info(f"\nSuccess! Engine saved to: {engine_path}")
- logger.info(f"\nTo use in config:")
- logger.info(f' preprocessor: "depth_tensorrt"')
- logger.info(f' preprocessor_params:')
+ logger.info("\nTo use in config:")
+ logger.info(' preprocessor: "depth_tensorrt"')
+ logger.info(" preprocessor_params:")
logger.info(f' engine_path: "{engine_path}"')
diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py
index 8734987ef..d3faef953 100644
--- a/src/streamdiffusion/tools/compile_raft_tensorrt.py
+++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py
@@ -1,21 +1,24 @@
-import torch
import logging
from pathlib import Path
-from typing import Optional
+
import fire
+import torch
+
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
try:
import tensorrt as trt
+
TENSORRT_AVAILABLE = True
except ImportError:
TENSORRT_AVAILABLE = False
logger.error("TensorRT not available. Please install it first.")
try:
- from torchvision.models.optical_flow import raft_small, Raft_Small_Weights
+ from torchvision.models.optical_flow import Raft_Small_Weights, raft_small
+
TORCHVISION_AVAILABLE = True
except ImportError:
TORCHVISION_AVAILABLE = False
@@ -28,11 +31,11 @@ def export_raft_to_onnx(
min_width: int = 512,
max_height: int = 512,
max_width: int = 512,
- device: str = "cuda"
+ device: str = "cuda",
) -> bool:
"""
Export RAFT model to ONNX format
-
+
Args:
onnx_path: Path to save the ONNX model
min_height: Minimum input height for the model
@@ -40,41 +43,41 @@ def export_raft_to_onnx(
max_height: Maximum input height for the model
max_width: Maximum input width for the model
device: Device to use for export
-
+
Returns:
True if successful, False otherwise
"""
if not TORCHVISION_AVAILABLE:
logger.error("torchvision is required but not installed")
return False
-
+
logger.info(f"Exporting RAFT model to ONNX: {onnx_path}")
logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}")
-
+
try:
# Load RAFT model
logger.info("Loading RAFT Small model...")
raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=True)
raft_model = raft_model.to(device=device)
raft_model.eval()
-
+
# Create dummy inputs using max resolution for export
dummy_frame1 = torch.randn(1, 3, max_height, max_width).to(device)
dummy_frame2 = torch.randn(1, 3, max_height, max_width).to(device)
-
+
# Apply RAFT preprocessing if available
weights = Raft_Small_Weights.DEFAULT
- if hasattr(weights, 'transforms') and weights.transforms is not None:
+ if hasattr(weights, "transforms") and weights.transforms is not None:
transforms = weights.transforms()
dummy_frame1, dummy_frame2 = transforms(dummy_frame1, dummy_frame2)
-
+
# Make batch, height, and width dimensions dynamic
dynamic_axes = {
"frame1": {0: "batch_size", 2: "height", 3: "width"},
"frame2": {0: "batch_size", 2: "height", 3: "width"},
"flow": {0: "batch_size", 2: "height", 3: "width"},
}
-
+
logger.info("Exporting to ONNX...")
with torch.no_grad():
torch.onnx.export(
@@ -82,22 +85,23 @@ def export_raft_to_onnx(
(dummy_frame1, dummy_frame2),
str(onnx_path),
verbose=False,
- input_names=['frame1', 'frame2'],
- output_names=['flow'],
+ input_names=["frame1", "frame2"],
+ output_names=["flow"],
opset_version=17,
export_params=True,
dynamic_axes=dynamic_axes,
)
-
+
del raft_model
torch.cuda.empty_cache()
-
+
logger.info(f"Successfully exported ONNX model to {onnx_path}")
return True
-
+
except Exception as e:
logger.error(f"Failed to export ONNX model: {e}")
import traceback
+
traceback.print_exc()
return False
@@ -110,11 +114,11 @@ def build_tensorrt_engine(
max_height: int = 512,
max_width: int = 512,
fp16: bool = True,
- workspace_size_gb: int = 4
+ workspace_size_gb: int = 4,
) -> bool:
"""
Build TensorRT engine from ONNX model
-
+
Args:
onnx_path: Path to the ONNX model
engine_path: Path to save the TensorRT engine
@@ -124,74 +128,74 @@ def build_tensorrt_engine(
max_width: Maximum input width for optimization
fp16: Enable FP16 precision mode
workspace_size_gb: Maximum workspace size in GB
-
+
Returns:
True if successful, False otherwise
"""
if not TENSORRT_AVAILABLE:
logger.error("TensorRT is required but not installed")
return False
-
+
if not onnx_path.exists():
logger.error(f"ONNX model not found: {onnx_path}")
return False
-
+
logger.info(f"Building TensorRT engine from ONNX model: {onnx_path}")
logger.info(f"Output path: {engine_path}")
logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}")
logger.info(f"FP16 mode: {fp16}")
logger.info("This may take several minutes...")
-
+
try:
builder = trt.Builder(trt.Logger(trt.Logger.INFO))
network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x
parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING))
-
+
logger.info("Parsing ONNX model...")
- with open(onnx_path, 'rb') as model:
+ with open(onnx_path, "rb") as model:
if not parser.parse(model.read()):
logger.error("Failed to parse ONNX model")
for error in range(parser.num_errors):
logger.error(f"Parser error: {parser.get_error(error)}")
return False
-
+
logger.info("Configuring TensorRT builder...")
config = builder.create_builder_config()
-
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size_gb * (1 << 30))
-
+
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
logger.info("FP16 mode enabled")
-
+
# Calculate optimal resolution (middle point)
opt_height = (min_height + max_height) // 2
opt_width = (min_width + max_width) // 2
-
+
profile = builder.create_optimization_profile()
min_shape = (1, 3, min_height, min_width)
opt_shape = (1, 3, opt_height, opt_width)
max_shape = (1, 3, max_height, max_width)
-
+
profile.set_shape("frame1", min_shape, opt_shape, max_shape)
profile.set_shape("frame2", min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
-
+
logger.info("Building TensorRT engine... (this will take a while)")
engine = builder.build_serialized_network(network, config)
-
+
if engine is None:
logger.error("Failed to build TensorRT engine")
return False
-
+
logger.info(f"Saving engine to {engine_path}")
engine_path.parent.mkdir(parents=True, exist_ok=True)
- with open(engine_path, 'wb') as f:
+ with open(engine_path, "wb") as f:
f.write(engine)
-
+
logger.info(f"Successfully built and saved TensorRT engine: {engine_path}")
- logger.info(f"Engine size: {engine_path.stat().st_size / (1024*1024):.2f} MB")
-
+ logger.info(f"Engine size: {engine_path.stat().st_size / (1024 * 1024):.2f} MB")
+
# Delete ONNX file after successful engine creation
try:
if onnx_path.exists():
@@ -199,12 +203,13 @@ def build_tensorrt_engine(
logger.info(f"Deleted ONNX file: {onnx_path}")
except Exception as e:
logger.warning(f"Failed to delete ONNX file: {e}")
-
+
return True
-
+
except Exception as e:
logger.error(f"Failed to build TensorRT engine: {e}")
import traceback
+
traceback.print_exc()
return False
@@ -216,11 +221,11 @@ def compile_raft(
device: str = "cuda",
fp16: bool = True,
workspace_size_gb: int = 4,
- force_rebuild: bool = False
+ force_rebuild: bool = False,
):
"""
Main function to compile RAFT model to TensorRT engine
-
+
Args:
min_resolution: Minimum input resolution as "HxW" (e.g., "512x512") (default: "512x512")
max_resolution: Maximum input resolution as "HxW" (e.g., "1024x1024") (default: "512x512")
@@ -234,46 +239,46 @@ def compile_raft(
logger.error("TensorRT is not available. Please install it first using:")
logger.error(" python -m streamdiffusion.tools.install-tensorrt")
return
-
+
if not TORCHVISION_AVAILABLE:
logger.error("torchvision is not available. Please install it first using:")
logger.error(" pip install torchvision")
return
-
+
# Parse resolution strings
try:
- min_height, min_width = map(int, min_resolution.split('x'))
+ min_height, min_width = map(int, min_resolution.split("x"))
except:
logger.error(f"Invalid min_resolution format: {min_resolution}. Expected format: HxW (e.g., 512x512)")
return
-
+
try:
- max_height, max_width = map(int, max_resolution.split('x'))
+ max_height, max_width = map(int, max_resolution.split("x"))
except:
logger.error(f"Invalid max_resolution format: {max_resolution}. Expected format: HxW (e.g., 1024x1024)")
return
-
+
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
-
+
# Add resolution suffix to filenames
onnx_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.onnx"
engine_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.engine"
-
- logger.info("="*80)
+
+ logger.info("=" * 80)
logger.info("RAFT TensorRT Compilation")
- logger.info("="*80)
+ logger.info("=" * 80)
logger.info(f"Output directory: {output_path.absolute()}")
logger.info(f"Resolution range: {min_resolution} - {max_resolution}")
logger.info(f"ONNX path: {onnx_path}")
logger.info(f"Engine path: {engine_path}")
- logger.info("="*80)
-
+ logger.info("=" * 80)
+
if engine_path.exists() and not force_rebuild:
logger.info(f"TensorRT engine already exists: {engine_path}")
logger.info("Use --force_rebuild to rebuild it")
return
-
+
if not onnx_path.exists() or force_rebuild:
logger.info("\n[Step 1/2] Exporting RAFT to ONNX...")
if not export_raft_to_onnx(onnx_path, min_height, min_width, max_height, max_width, device):
@@ -281,21 +286,22 @@ def compile_raft(
return
else:
logger.info(f"\n[Step 1/2] ONNX model already exists: {onnx_path}")
-
+
logger.info("\n[Step 2/2] Building TensorRT engine...")
- if not build_tensorrt_engine(onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb):
+ if not build_tensorrt_engine(
+ onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb
+ ):
logger.error("Failed to build TensorRT engine")
return
-
- logger.info("\n" + "="*80)
+
+ logger.info("\n" + "=" * 80)
logger.info("β Compilation completed successfully!")
- logger.info("="*80)
+ logger.info("=" * 80)
logger.info(f"Engine path: {engine_path.absolute()}")
logger.info("\nYou can now use this engine in TemporalNetTensorRTPreprocessor:")
logger.info(f' engine_path="{engine_path.absolute()}"')
- logger.info("="*80)
+ logger.info("=" * 80)
if __name__ == "__main__":
fire.Fire(compile_raft)
-
diff --git a/src/streamdiffusion/utils/__init__.py b/src/streamdiffusion/utils/__init__.py
index 00ff7cf7d..b40413d24 100644
--- a/src/streamdiffusion/utils/__init__.py
+++ b/src/streamdiffusion/utils/__init__.py
@@ -1,5 +1,6 @@
from .reporting import report_error
+
__all__ = [
"report_error",
-]
\ No newline at end of file
+]
diff --git a/src/streamdiffusion/utils/reporting.py b/src/streamdiffusion/utils/reporting.py
index 44838d9c8..25e650c6f 100644
--- a/src/streamdiffusion/utils/reporting.py
+++ b/src/streamdiffusion/utils/reporting.py
@@ -25,5 +25,3 @@ def report_error(
stacklevel=stacklevel,
extra={"report_error": True},
)
-
-
diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py
index 36abe02b0..cecfec94f 100644
--- a/src/streamdiffusion/wrapper.py
+++ b/src/streamdiffusion/wrapper.py
@@ -131,6 +131,10 @@ def __init__(
static_shapes: bool = False,
fp8_allow_fp16_fallback: bool = False,
builder_optimization_level: Optional[int] = None,
+ # CUDA IPC output (SDβTD zero-copy GPU transport via cuda-link)
+ use_cuda_ipc_output: bool = False,
+ cuda_ipc_shm_name: Optional[str] = None,
+ cuda_ipc_num_slots: int = 2,
):
"""
Initializes the StreamDiffusionWrapper.
@@ -314,6 +318,10 @@ def __init__(
self._output_pin_buf: Optional[torch.Tensor] = None # pinned CPU buffer for async D2H output
self._output_gpu_buf: Optional[torch.Tensor] = None # persistent GPU fp32 staging (avoids per-frame alloc)
self._d2h_event: Optional[torch.cuda.Event] = None # event for fine-grained D2H sync
+ self.use_cuda_ipc_output = use_cuda_ipc_output
+ self._cuda_ipc_shm_name = cuda_ipc_shm_name
+ self._cuda_ipc_num_slots = cuda_ipc_num_slots
+ self._cuda_ipc_exporter = None # lazy-init on first frame via _lazy_init_ipc_exporter
self.batch_size = len(t_index_list) * frame_buffer_size if use_denoising_batch else frame_buffer_size
self.min_batch_size = min_batch_size
self.max_batch_size = max_batch_size
@@ -907,6 +915,15 @@ def postprocess_image(
Union[Image.Image, List[Image.Image]]
The postprocessed image.
"""
+ # CUDA IPC fast-path: export to TD via zero-copy GPU IPC (cuda-link CUDAIPCExporter).
+ # Skips D2H, CPU repack, and CPU SHM write. Returns None to let the TD-side
+ # _send_output_frame early-exit (it already guards on output_image is None).
+ if self.use_cuda_ipc_output and self._cuda_ipc_shm_name:
+ bgra = self._ipc_pack_rgba(image_tensor)
+ exporter = self._lazy_init_ipc_exporter(bgra.shape[0], bgra.shape[1])
+ exporter.export_frame(bgra.data_ptr(), bgra.numel())
+ return None
+
# Fast paths for non-PIL outputs (avoid unnecessary conversions)
if output_type == "latent":
return image_tensor
@@ -947,6 +964,44 @@ def postprocess_image(
else:
return postprocess_image(image_tensor.cpu(), output_type=output_type)[0]
+ def _ipc_pack_rgba(self, image_tensor: torch.Tensor) -> torch.Tensor:
+ """Convert pipeline output to HWC uint8 BGRA on GPU for cuda-link wire contract."""
+ denorm = self._denormalize_on_gpu(image_tensor) # NCHW [0,1]
+ if denorm.dim() == 4:
+ denorm = denorm[0] # CHW [0,1]
+ rgb_u8 = (denorm * 255).clamp(0, 255).to(torch.uint8) # CHW uint8
+ rgb_hwc = rgb_u8.permute(1, 2, 0).contiguous() # HWC RGB
+ alpha = torch.full_like(rgb_hwc[..., :1], 255)
+ return torch.cat([rgb_hwc[..., 2:3], rgb_hwc[..., 1:2], rgb_hwc[..., 0:1], alpha], dim=-1).contiguous()
+
+ def _lazy_init_ipc_exporter(self, height: int, width: int):
+ """Initialize CUDAIPCExporter on first frame (lazy to defer CUDA IPC SHM creation)."""
+ if self._cuda_ipc_exporter is not None:
+ return self._cuda_ipc_exporter
+ from streamdiffusion._compat.cuda_ipc import CUDAIPCExporter
+
+ exporter = CUDAIPCExporter(
+ shm_name=self._cuda_ipc_shm_name,
+ height=height,
+ width=width,
+ channels=4,
+ dtype="uint8",
+ num_slots=self._cuda_ipc_num_slots,
+ debug=False,
+ )
+ exporter.initialize()
+ self._cuda_ipc_exporter = exporter
+ return exporter
+
+ def cleanup_cuda_ipc(self) -> None:
+ """Tear down the CUDA IPC exporter and release its SHM + GPU resources."""
+ if self._cuda_ipc_exporter is not None:
+ try:
+ self._cuda_ipc_exporter.cleanup()
+ except Exception:
+ pass
+ self._cuda_ipc_exporter = None
+
def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""
Denormalize image tensor on GPU for efficiency.
@@ -2150,7 +2205,7 @@ def _load_model(
max_batch_size=self.max_batch_size,
min_batch_size=self.min_batch_size,
cuda_stream=cuda_stream,
- use_cuda_graph=True,
+ use_cuda_graph=False, # TRT's genericReformat uses legacy stream during execute_async_v3 β incompatible with graph capture (901)
unet=None,
model_path=cfg["model_id"],
opt_image_height=self.height,
diff --git a/tests/quality/README.md b/tests/quality/README.md
new file mode 100644
index 000000000..2dd95422c
--- /dev/null
+++ b/tests/quality/README.md
@@ -0,0 +1,77 @@
+# Quality Regression Harness
+
+Compares FP8-TRT output against FP16-TRT golden reference images using SSIM + LPIPS.
+
+## First-time setup
+
+```powershell
+# Install dev deps (lpips, scikit-image)
+pip install -r requirements-dev.txt
+
+# Build FP16-TRT engines for both fixture models (if not already cached)
+python -m streamdiffusion.acceleration.tensorrt.build --model stabilityai/sd-turbo
+python -m streamdiffusion.acceleration.tensorrt.build --model stabilityai/sdxl-turbo
+
+# Generate goldens + update manifest
+python tests/quality/regenerate_golden.py --update-manifest
+
+# Build FP8-TRT engines
+python -m streamdiffusion.acceleration.tensorrt.build --fp8 --model stabilityai/sd-turbo
+python -m streamdiffusion.acceleration.tensorrt.build --fp8 --model stabilityai/sdxl-turbo
+
+# Seed thresholds from FP8 baseline (run once, check in thresholds.yaml)
+python tests/quality/run_compare.py --baseline
+```
+
+## Running the harness
+
+```powershell
+# Full comparison (both fixtures)
+python tests/quality/run_compare.py
+
+# Single fixture
+python tests/quality/run_compare.py --fixture sdxl_turbo_img2img_plain
+```
+
+## Exit codes
+
+| Code | Meaning |
+|------|---------|
+| 0 | All fixtures pass SSIM/LPIPS thresholds |
+| 1 | One or more fixtures fail thresholds |
+| 2 | Manifest version mismatch β results would be meaningless |
+| 3 | Golden PNG missing β run `regenerate_golden.py --update-manifest` first |
+
+## Refreshing goldens after a dep bump
+
+```powershell
+pip install -r requirements.txt -r requirements-dev.txt # upgrade deps
+python tests/quality/regenerate_golden.py --update-manifest # regenerate + re-hash
+python tests/quality/run_compare.py --baseline # re-seed thresholds
+```
+
+## File layout
+
+```
+tests/quality/
+βββ run_compare.py # orchestrator β runs FP8, computes metrics, checks thresholds
+βββ regenerate_golden.py # generates FP16-TRT goldens and updates manifest hashes
+βββ manifest.json # pinned dep versions + golden sha256 (abort on mismatch)
+βββ thresholds.yaml # SSIM/LPIPS floors per fixture (seed with --baseline)
+βββ fixtures/ # fixture parameter files
+β βββ sd_turbo_img2img_plain.json
+β βββ sdxl_turbo_img2img_plain.json
+βββ goldens/ # FP16-TRT reference PNGs (generated, then checked in)
+β βββ sd_turbo_img2img_plain.png
+β βββ sdxl_turbo_img2img_plain.png
+βββ outputs/ # generated on run β gitignored
+ βββ *_fp8.png # FP8 output for each fixture
+ βββ *_comparison.png # side-by-side comparison
+ βββ report.json # SSIM/LPIPS results
+```
+
+## Phase 3 gate
+
+This harness is the prerequisite for Phase 3 (feature-aware calibration + Q/DQ exclusion
+narrowing). Do not narrow exclusions until both fixtures pass at stable thresholds across
+at least two independent runs.
diff --git a/tests/quality/fixtures/sd_turbo_img2img_plain.json b/tests/quality/fixtures/sd_turbo_img2img_plain.json
new file mode 100644
index 000000000..a83b29d69
--- /dev/null
+++ b/tests/quality/fixtures/sd_turbo_img2img_plain.json
@@ -0,0 +1,15 @@
+{
+ "model_id": "stabilityai/sd-turbo",
+ "prompt": "a portrait of a young woman with long brown hair, soft studio lighting, photorealistic",
+ "negative_prompt": "",
+ "seed": 2,
+ "t_index_list": [32, 45],
+ "width": 512,
+ "height": 512,
+ "cfg_type": "none",
+ "guidance_scale": 1.0,
+ "num_inference_steps": 50,
+ "delta": 1.0,
+ "warmup": 1,
+ "use_denoising_batch": true
+}
diff --git a/tests/quality/fixtures/sdxl_turbo_img2img_plain.json b/tests/quality/fixtures/sdxl_turbo_img2img_plain.json
new file mode 100644
index 000000000..e4dc05c73
--- /dev/null
+++ b/tests/quality/fixtures/sdxl_turbo_img2img_plain.json
@@ -0,0 +1,15 @@
+{
+ "model_id": "stabilityai/sdxl-turbo",
+ "prompt": "a portrait of a young woman with long brown hair, soft studio lighting, photorealistic",
+ "negative_prompt": "",
+ "seed": 2,
+ "t_index_list": [32, 45],
+ "width": 512,
+ "height": 512,
+ "cfg_type": "none",
+ "guidance_scale": 1.0,
+ "num_inference_steps": 50,
+ "delta": 1.0,
+ "warmup": 1,
+ "use_denoising_batch": true
+}
diff --git a/tests/quality/goldens/.gitkeep b/tests/quality/goldens/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/quality/manifest.json b/tests/quality/manifest.json
new file mode 100644
index 000000000..47ca14292
--- /dev/null
+++ b/tests/quality/manifest.json
@@ -0,0 +1,29 @@
+{
+ "_note": "Version pins for the quality regression harness. Run regenerate_golden.py --update-manifest after any dep bump to refresh hashes. run_compare.py aborts if installed versions diverge.",
+ "versions": {
+ "torch": "2.8.0+cu128",
+ "tensorrt": "10.16.1.11",
+ "nvidia_modelopt": "0.43.0",
+ "diffusers": "0.38.0"
+ },
+ "fixtures": {
+ "sd_turbo_img2img_plain": {
+ "model_id": "stabilityai/sd-turbo",
+ "seed": 2,
+ "t_index_list": [
+ 32,
+ 45
+ ],
+ "golden_sha256": "9056eadcbfa8637c9ad3aaaa09766ccf69926fc87ac7730219e82908e33c8d97"
+ },
+ "sdxl_turbo_img2img_plain": {
+ "model_id": "stabilityai/sdxl-turbo",
+ "seed": 2,
+ "t_index_list": [
+ 32,
+ 45
+ ],
+ "golden_sha256": "8e6b61fde877ed4752e297c33364ff7bf616376fb3be68b55f4df311e58bdbc3"
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/quality/regenerate_golden.py b/tests/quality/regenerate_golden.py
new file mode 100644
index 000000000..7c86fd449
--- /dev/null
+++ b/tests/quality/regenerate_golden.py
@@ -0,0 +1,156 @@
+"""Generate FP16-TRT golden reference images for the quality regression harness.
+
+Usage:
+ python tests/quality/regenerate_golden.py
+ python tests/quality/regenerate_golden.py --update-manifest
+ python tests/quality/regenerate_golden.py --fixture sdxl_turbo_img2img_plain
+
+Goldens are FP16-TRT engine outputs (not bare PyTorch FP16) so they catch
+TRT-specific regressions that matter in production.
+
+--update-manifest Also updates tests/quality/manifest.json with sha256 hashes
+ of the generated goldens and current installed dep versions.
+"""
+
+import argparse
+import hashlib
+import json
+import logging
+import os
+import sys
+
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+from streamdiffusion import StreamDiffusionWrapper
+
+
+logger = logging.getLogger("quality.regenerate")
+
+TESTS_QUALITY_DIR = os.path.dirname(os.path.abspath(__file__))
+REPO_ROOT = os.path.join(TESTS_QUALITY_DIR, "..", "..")
+INPUT_IMAGE = os.path.join(REPO_ROOT, "images", "inputs", "input.png")
+GOLDENS_DIR = os.path.join(TESTS_QUALITY_DIR, "goldens")
+FIXTURES_DIR = os.path.join(TESTS_QUALITY_DIR, "fixtures")
+MANIFEST_PATH = os.path.join(TESTS_QUALITY_DIR, "manifest.json")
+THRESHOLDS_PATH = os.path.join(TESTS_QUALITY_DIR, "thresholds.yaml")
+
+
+def _sha256(path: str) -> str:
+ h = hashlib.sha256()
+ with open(path, "rb") as f:
+ for chunk in iter(lambda: f.read(65536), b""):
+ h.update(chunk)
+ return h.hexdigest()
+
+
+def _current_versions() -> dict:
+ import importlib.metadata
+
+ versions = {}
+ for pkg, key in [
+ ("torch", "torch"),
+ ("tensorrt", "tensorrt"),
+ ("nvidia-modelopt", "nvidia_modelopt"),
+ ("diffusers", "diffusers"),
+ ]:
+ try:
+ versions[key] = importlib.metadata.version(pkg)
+ except Exception:
+ # TRT on Windows is often installed via NVIDIA's custom mechanism
+ # with no .dist-info β fall back to the package's __version__.
+ if pkg == "tensorrt":
+ try:
+ import tensorrt as _trt
+
+ versions[key] = _trt.__version__
+ except Exception:
+ versions[key] = "unknown"
+ else:
+ versions[key] = "unknown"
+ return versions
+
+
+def run_fixture(fixture_name: str, fixture: dict) -> str:
+ """Run FP16-TRT inference for one fixture and return the saved golden path."""
+ golden_path = os.path.join(GOLDENS_DIR, f"{fixture_name}.png")
+ os.makedirs(GOLDENS_DIR, exist_ok=True)
+
+ logger.info(f"[{fixture_name}] Running FP16-TRT inference β {golden_path}")
+
+ stream = StreamDiffusionWrapper(
+ model_id_or_path=fixture["model_id"],
+ t_index_list=fixture["t_index_list"],
+ frame_buffer_size=1,
+ width=fixture["width"],
+ height=fixture["height"],
+ warmup=fixture.get("warmup", 1),
+ acceleration="tensorrt",
+ mode="img2img",
+ use_denoising_batch=fixture.get("use_denoising_batch", True),
+ cfg_type=fixture.get("cfg_type", "none"),
+ seed=fixture["seed"],
+ )
+ stream.prepare(
+ prompt=fixture["prompt"],
+ negative_prompt=fixture.get("negative_prompt", ""),
+ num_inference_steps=fixture.get("num_inference_steps", 50),
+ guidance_scale=fixture.get("guidance_scale", 1.0),
+ delta=fixture.get("delta", 1.0),
+ )
+
+ image_tensor = stream.preprocess_image(INPUT_IMAGE)
+ for _ in range(stream.batch_size - 1):
+ stream(image=image_tensor)
+ output_image = stream(image=image_tensor)
+ output_image.save(golden_path)
+
+ sha = _sha256(golden_path)
+ logger.info(f"[{fixture_name}] Golden saved: {golden_path} sha256={sha[:16]}...")
+ return sha
+
+
+def main():
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
+ parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+ parser.add_argument(
+ "--update-manifest", action="store_true", help="Update manifest.json with current versions + golden hashes"
+ )
+ parser.add_argument("--fixture", default=None, help="Run only this fixture (default: all)")
+ args = parser.parse_args()
+
+ fixture_files = sorted(f for f in os.listdir(FIXTURES_DIR) if f.endswith(".json"))
+ if args.fixture:
+ fixture_files = [f"{args.fixture}.json"]
+
+ results = {}
+ for fname in fixture_files:
+ name = fname[:-5]
+ with open(os.path.join(FIXTURES_DIR, fname)) as fp:
+ fixture = json.load(fp)
+ sha = run_fixture(name, fixture)
+ results[name] = sha
+
+ if args.update_manifest:
+ with open(MANIFEST_PATH) as fp:
+ manifest = json.load(fp)
+ manifest["versions"] = _current_versions()
+ for name, sha in results.items():
+ if name not in manifest["fixtures"]:
+ manifest["fixtures"][name] = {}
+ manifest["fixtures"][name]["golden_sha256"] = sha
+ manifest["fixtures"][name].pop("_note", None)
+ with open(MANIFEST_PATH, "w") as fp:
+ json.dump(manifest, fp, indent=4)
+ logger.info(f"Manifest updated: {MANIFEST_PATH}")
+
+ # Seed thresholds from FP8 baseline if engines already exist,
+ # otherwise leave placeholders and print instructions.
+ print(
+ "\nThreshold seeding: run_compare.py --baseline to measure FP8 vs these goldens "
+ "and seed thresholds.yaml with baseline - 0.02 (SSIM) / + 0.05 (LPIPS)."
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/quality/run_compare.py b/tests/quality/run_compare.py
new file mode 100644
index 000000000..cdaba96c8
--- /dev/null
+++ b/tests/quality/run_compare.py
@@ -0,0 +1,252 @@
+"""Quality regression harness: compare FP8-TRT output against FP16-TRT goldens.
+
+Exit codes:
+ 0 All fixtures pass SSIM/LPIPS thresholds
+ 1 One or more fixtures fail thresholds
+ 2 Manifest version mismatch (abort β results would be meaningless)
+ 3 Golden PNG missing β run regenerate_golden.py --update-manifest first
+
+Usage:
+ python tests/quality/run_compare.py
+ python tests/quality/run_compare.py --fixture sdxl_turbo_img2img_plain
+ python tests/quality/run_compare.py --baseline # seed thresholds.yaml
+ python tests/quality/run_compare.py --skip-manifest-check # bypass version pins
+"""
+
+import argparse
+import hashlib
+import json
+import logging
+import os
+import sys
+
+import yaml
+
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
+
+logger = logging.getLogger("quality.run_compare")
+
+TESTS_QUALITY_DIR = os.path.dirname(os.path.abspath(__file__))
+REPO_ROOT = os.path.join(TESTS_QUALITY_DIR, "..", "..")
+INPUT_IMAGE = os.path.join(REPO_ROOT, "images", "inputs", "input.png")
+GOLDENS_DIR = os.path.join(TESTS_QUALITY_DIR, "goldens")
+FIXTURES_DIR = os.path.join(TESTS_QUALITY_DIR, "fixtures")
+MANIFEST_PATH = os.path.join(TESTS_QUALITY_DIR, "manifest.json")
+THRESHOLDS_PATH = os.path.join(TESTS_QUALITY_DIR, "thresholds.yaml")
+OUTPUTS_DIR = os.path.join(TESTS_QUALITY_DIR, "outputs")
+
+
+def _sha256(path: str) -> str:
+ h = hashlib.sha256()
+ with open(path, "rb") as f:
+ for chunk in iter(lambda: f.read(65536), b""):
+ h.update(chunk)
+ return h.hexdigest()
+
+
+def _installed_version(pkg: str) -> str:
+ import importlib.metadata
+
+ try:
+ return importlib.metadata.version(pkg)
+ except Exception:
+ if pkg == "tensorrt":
+ try:
+ import tensorrt as _trt
+
+ return _trt.__version__
+ except Exception:
+ pass
+ return "unknown"
+
+
+def _check_manifest(manifest: dict) -> list[str]:
+ """Return list of version mismatch messages (empty = all match)."""
+ pkg_map = {
+ "torch": "torch",
+ "tensorrt": "tensorrt",
+ "nvidia_modelopt": "nvidia-modelopt",
+ "diffusers": "diffusers",
+ }
+ mismatches = []
+ for key, pkg in pkg_map.items():
+ pinned = manifest["versions"].get(key, "unknown")
+ installed = _installed_version(pkg)
+ if pinned != "unknown" and installed != "unknown" and pinned != installed:
+ mismatches.append(f" {key}: manifest={pinned} installed={installed}")
+ return mismatches
+
+
+def _run_fp8_inference(fixture_name: str, fixture: dict, output_path: str) -> None:
+ from streamdiffusion import StreamDiffusionWrapper
+
+ stream = StreamDiffusionWrapper(
+ model_id_or_path=fixture["model_id"],
+ t_index_list=fixture["t_index_list"],
+ frame_buffer_size=1,
+ width=fixture["width"],
+ height=fixture["height"],
+ warmup=fixture.get("warmup", 1),
+ acceleration="tensorrt",
+ mode="img2img",
+ use_denoising_batch=fixture.get("use_denoising_batch", True),
+ cfg_type=fixture.get("cfg_type", "none"),
+ seed=fixture["seed"],
+ fp8=True,
+ )
+ stream.prepare(
+ prompt=fixture["prompt"],
+ negative_prompt=fixture.get("negative_prompt", ""),
+ num_inference_steps=fixture.get("num_inference_steps", 50),
+ guidance_scale=fixture.get("guidance_scale", 1.0),
+ delta=fixture.get("delta", 1.0),
+ )
+ image_tensor = stream.preprocess_image(INPUT_IMAGE)
+ for _ in range(stream.batch_size - 1):
+ stream(image=image_tensor)
+ output_image = stream(image=image_tensor)
+ output_image.save(output_path)
+
+
+def _compute_metrics(golden_path: str, output_path: str) -> dict:
+ import numpy as np
+ from PIL import Image
+ from skimage.metrics import structural_similarity as ssim_fn
+
+ golden = np.array(Image.open(golden_path).convert("RGB"), dtype=np.float32) / 255.0
+ output = np.array(Image.open(output_path).convert("RGB"), dtype=np.float32) / 255.0
+
+ ssim_val = float(ssim_fn(golden, output, data_range=1.0, channel_axis=2))
+
+ try:
+ import lpips
+ import torch
+
+ loss_fn = lpips.LPIPS(net="alex", verbose=False)
+ g_t = torch.from_numpy(golden).permute(2, 0, 1).unsqueeze(0) * 2 - 1
+ o_t = torch.from_numpy(output).permute(2, 0, 1).unsqueeze(0) * 2 - 1
+ with torch.no_grad():
+ lpips_val = float(loss_fn(g_t, o_t).item())
+ except Exception as e:
+ logger.warning(f"LPIPS computation failed: {e}. Using placeholder 0.0.")
+ lpips_val = 0.0
+
+ return {"ssim": ssim_val, "lpips": lpips_val}
+
+
+def _make_comparison_image(golden_path: str, output_path: str, comparison_path: str) -> None:
+ from PIL import Image, ImageDraw
+
+ golden = Image.open(golden_path).convert("RGB")
+ output = Image.open(output_path).convert("RGB")
+ w, h = golden.width, golden.height
+ canvas = Image.new("RGB", (w * 2 + 10, h + 24), (30, 30, 30))
+ canvas.paste(golden, (0, 24))
+ canvas.paste(output, (w + 10, 24))
+ draw = ImageDraw.Draw(canvas)
+ draw.text((4, 4), "FP16-TRT golden", fill=(200, 200, 200))
+ draw.text((w + 14, 4), "FP8-TRT output", fill=(200, 200, 200))
+ canvas.save(comparison_path)
+
+
+def main():
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s")
+ parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
+ parser.add_argument("--fixture", default=None, help="Run only this fixture (default: all)")
+ parser.add_argument("--baseline", action="store_true", help="Seed thresholds.yaml from this run's SSIM/LPIPS")
+ parser.add_argument("--skip-manifest-check", action="store_true", help="Bypass version pin check")
+ args = parser.parse_args()
+
+ with open(MANIFEST_PATH) as fp:
+ manifest = json.load(fp)
+ with open(THRESHOLDS_PATH) as fp:
+ thresholds = yaml.safe_load(fp)
+
+ if not args.skip_manifest_check:
+ mismatches = _check_manifest(manifest)
+ if mismatches:
+ print("ERROR: Manifest version mismatch β results would be meaningless.\n")
+ print("\n".join(mismatches))
+ print("\nRun regenerate_golden.py --update-manifest to refresh the manifest.")
+ sys.exit(2)
+
+ fixture_files = sorted(f for f in os.listdir(FIXTURES_DIR) if f.endswith(".json"))
+ if args.fixture:
+ fixture_files = [f"{args.fixture}.json"]
+
+ os.makedirs(OUTPUTS_DIR, exist_ok=True)
+ results = {}
+ all_pass = True
+
+ for fname in fixture_files:
+ name = fname[:-5]
+ with open(os.path.join(FIXTURES_DIR, fname)) as fp:
+ fixture = json.load(fp)
+
+ golden_path = os.path.join(GOLDENS_DIR, f"{name}.png")
+ if not os.path.exists(golden_path):
+ print(f"ERROR: Golden not found for {name}: {golden_path}")
+ print("Run: python tests/quality/regenerate_golden.py --update-manifest")
+ sys.exit(3)
+
+ golden_sha = manifest["fixtures"].get(name, {}).get("golden_sha256")
+ if golden_sha is not None:
+ actual_sha = _sha256(golden_path)
+ if actual_sha != golden_sha:
+ print(f"WARNING: Golden sha256 mismatch for {name}. File may be corrupted or outdated.")
+
+ output_path = os.path.join(OUTPUTS_DIR, f"{name}_fp8.png")
+ logger.info(f"[{name}] Running FP8-TRT inference...")
+ _run_fp8_inference(name, fixture, output_path)
+
+ logger.info(f"[{name}] Computing metrics...")
+ metrics = _compute_metrics(golden_path, output_path)
+ results[name] = metrics
+
+ comparison_path = os.path.join(OUTPUTS_DIR, f"{name}_comparison.png")
+ _make_comparison_image(golden_path, output_path, comparison_path)
+
+ thresh = thresholds.get("fixtures", {}).get(name, {})
+ ssim_min = thresh.get("ssim_min", 0.0)
+ lpips_max = thresh.get("lpips_max", 1.0)
+ passed = metrics["ssim"] >= ssim_min and metrics["lpips"] <= lpips_max
+ status = "PASS" if passed else "FAIL"
+ if not passed:
+ all_pass = False
+
+ print(
+ f"[{name}] {status} SSIM={metrics['ssim']:.4f} (min={ssim_min}) "
+ f"LPIPS={metrics['lpips']:.4f} (max={lpips_max})"
+ )
+ logger.info(f"[{name}] Comparison: {comparison_path}")
+
+ report_path = os.path.join(OUTPUTS_DIR, "report.json")
+ with open(report_path, "w") as fp:
+ json.dump(results, fp, indent=2)
+ logger.info(f"Report: {report_path}")
+
+ if args.baseline:
+ for name, metrics in results.items():
+ if name not in thresholds.get("fixtures", {}):
+ thresholds.setdefault("fixtures", {})[name] = {}
+ thresholds["fixtures"][name]["ssim_min"] = round(metrics["ssim"] - 0.02, 4)
+ thresholds["fixtures"][name]["lpips_max"] = round(metrics["lpips"] + 0.05, 4)
+ thresholds["fixtures"][name].pop("_note", None)
+ with open(THRESHOLDS_PATH, "w") as fp:
+ yaml.dump(thresholds, fp, default_flow_style=False, sort_keys=False)
+ print(f"\nThresholds seeded β {THRESHOLDS_PATH}")
+
+ n = len(results)
+ n_pass = sum(
+ 1
+ for name, m in results.items()
+ if m["ssim"] >= thresholds.get("fixtures", {}).get(name, {}).get("ssim_min", 0.0)
+ and m["lpips"] <= thresholds.get("fixtures", {}).get(name, {}).get("lpips_max", 1.0)
+ )
+ print(f"\n{n_pass}/{n} fixtures pass thresholds.")
+ sys.exit(0 if all_pass else 1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/quality/test_fp8_calib_tile.py b/tests/quality/test_fp8_calib_tile.py
new file mode 100644
index 000000000..abe577cb3
--- /dev/null
+++ b/tests/quality/test_fp8_calib_tile.py
@@ -0,0 +1,81 @@
+"""Regression: per-input-aware calibration tiling for FP8 quantize.
+
+Reproduces the kvo_cache_in dim-0=2 vs synthesized dim-0=1 split mismatch
+hit by SDXL-Turbo + use_cached_attn + cfg_type=self configs.
+"""
+
+import math
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper
+
+
+def _make_min_onnx(path):
+ """Minimal 2-input ONNX: symbolic-batch 'sample' + static-dim0=2 'kvo_cache_in_0'."""
+ sample = helper.make_tensor_value_info("sample", TensorProto.FLOAT, ["2B", 4, 64, 64])
+ kvo = helper.make_tensor_value_info("kvo_cache_in_0", TensorProto.FLOAT, [2, 4, "2B", 64, 64])
+ out = helper.make_tensor_value_info("out", TensorProto.FLOAT, ["2B", 4, 64, 64])
+ ident = helper.make_node("Identity", inputs=["sample"], outputs=["out"])
+ g = helper.make_graph([ident], "min", [sample, kvo], [out])
+ onnx.save(helper.make_model(g, opset_imports=[helper.make_opsetid("", 17)]), path)
+
+
+def test_per_input_tile_preserves_static_dim0(tmp_path):
+ """Per-input tile: sample stays at n_itr rows, kvo stays at 2Γn_itr rows."""
+ onnx_path = str(tmp_path / "min.onnx")
+ _make_min_onnx(onnx_path)
+
+ calib = {
+ "sample": np.zeros((5, 4, 64, 64), dtype=np.float32),
+ "kvo_cache_in_0": np.zeros((10, 4, 5, 64, 64), dtype=np.float32),
+ }
+
+ from streamdiffusion.acceleration.tensorrt.fp8_quantize import _read_onnx_input_specs
+
+ specs = _read_onnx_input_specs(onnx_path)
+
+ # Reproduce the fixed logic
+ resolved_dim0 = {name: max(1, (specs[name][1][0] or 1)) for name in calib}
+ n_itr = max(arr.shape[0] // resolved_dim0[name] for name, arr in calib.items())
+ n_itr = max(1, n_itr)
+ out = {}
+ for k, arr in calib.items():
+ target = n_itr * resolved_dim0[k]
+ if arr.shape[0] != target:
+ repeats = math.ceil(target / max(1, arr.shape[0]))
+ arr = np.tile(arr, (repeats,) + (1,) * (arr.ndim - 1))[:target]
+ out[k] = arr
+
+ # n_itr=5, resolved_dim0(sample)=1 β 5 rows; resolved_dim0(kvo)=2 β 10 rows
+ assert out["sample"].shape == (5, 4, 64, 64)
+ assert out["kvo_cache_in_0"].shape == (10, 4, 5, 64, 64)
+
+ # Verify modelopt split math: n_itr chunks each of shape (resolved_dim0, ...)
+ sample_chunks = np.array_split(out["sample"], n_itr, axis=0)
+ kvo_chunks = np.array_split(out["kvo_cache_in_0"], n_itr, axis=0)
+ assert sample_chunks[0].shape[0] == 1
+ assert kvo_chunks[0].shape[0] == 2 # static dim 0 must be preserved
+
+
+def test_naive_max_rows_tile_would_break(tmp_path):
+ """Confirms the OLD naΓ―ve tile produces the 'Got 1 Expected 2' symptom."""
+ onnx_path = str(tmp_path / "min.onnx")
+ _make_min_onnx(onnx_path)
+
+ calib = {
+ "sample": np.zeros((5, 4, 64, 64), dtype=np.float32),
+ "kvo_cache_in_0": np.zeros((10, 4, 5, 64, 64), dtype=np.float32),
+ }
+
+ # Reproduce the buggy logic
+ _max_rows = max(a.shape[0] for a in calib.values())
+ for k, a in list(calib.items()):
+ if a.shape[0] < _max_rows:
+ calib[k] = np.tile(a, (math.ceil(_max_rows / a.shape[0]),) + (1,) * (a.ndim - 1))[:_max_rows]
+
+ # modelopt: n_itr = sample.shape[0] / symbolic_dim0(1) = 10
+ # splits kvo into 10 chunks β each has shape[0]=1 β ORT rejects (expected 2)
+ n_itr_bad = calib["sample"].shape[0] # 10 (doubled by naΓ―ve tile)
+ kvo_chunk = np.array_split(calib["kvo_cache_in_0"], n_itr_bad, axis=0)[0]
+ assert kvo_chunk.shape[0] == 1 # this is the "Got 1 Expected 2" symptom
diff --git a/tests/quality/thresholds.yaml b/tests/quality/thresholds.yaml
new file mode 100644
index 000000000..d4889aced
--- /dev/null
+++ b/tests/quality/thresholds.yaml
@@ -0,0 +1,7 @@
+fixtures:
+ sd_turbo_img2img_plain:
+ ssim_min: 0.9649
+ lpips_max: 0.0586
+ sdxl_turbo_img2img_plain:
+ ssim_min: 0.9604
+ lpips_max: 0.0609
diff --git a/utils/viewer.py b/utils/viewer.py
index dd6f6cad9..2bd90984b 100644
--- a/utils/viewer.py
+++ b/utils/viewer.py
@@ -3,11 +3,13 @@
import threading
import time
import tkinter as tk
-from multiprocessing import Queue
-from typing import List
+from multiprocessing import Queue
+
from PIL import Image, ImageTk
+
from streamdiffusion.image_utils import postprocess_image
+
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -28,9 +30,8 @@ def update_image(image_data: Image.Image, label: tk.Label) -> None:
label.configure(image=tk_image, width=width, height=height)
label.image = tk_image # keep a reference
-def _receive_images(
- queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label
-) -> None:
+
+def _receive_images(queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label) -> None:
"""
Continuously receive images from a queue and update the labels.
@@ -85,9 +86,7 @@ def on_closing():
root.quit() # stop event loop
return
- thread = threading.Thread(
- target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True
- )
+ thread = threading.Thread(target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True)
thread.start()
try:
@@ -95,4 +94,3 @@ def on_closing():
root.mainloop()
except KeyboardInterrupt:
return
-