diff --git a/.gitignore b/.gitignore index 3641c6c5a..b38375f3b 100644 --- a/.gitignore +++ b/.gitignore @@ -224,6 +224,12 @@ controlnet_test_* demo/realtime-img2img/uploads/ .cgw.conf +# Local-only git workflow tooling (cgw scripts, hooks, example — never committed) +scripts/git/ +hooks/cc-block-dangerous-git.sh +.githooks/ +cgw.conf.example + # Local Claude / session state (per-user, never committed) .claude/ @@ -261,3 +267,6 @@ SESSION_LOG.md # Profiling/audit CSV exports (Nsight summaries, kernel stats — generated artifacts) audit_reports/ + +# Quality harness run outputs (generated; goldens/ is committed, outputs/ is not) +tests/quality/outputs/ diff --git a/demo/realtime-img2img/app_config.py b/demo/realtime-img2img/app_config.py index 9c21e58ce..252a992a5 100644 --- a/demo/realtime-img2img/app_config.py +++ b/demo/realtime-img2img/app_config.py @@ -1,47 +1,52 @@ """ Application configuration and settings for realtime-img2img """ -import yaml + import logging from pathlib import Path +import yaml + + def load_controlnet_registry(): """Load ControlNet registry from config file""" try: registry_path = Path(__file__).parent / "controlnet_registry.yaml" - with open(registry_path, 'r') as f: + with open(registry_path, "r") as f: config_data = yaml.safe_load(f) - + # Extract the available_controlnets section - return config_data.get('available_controlnets', {}) + return config_data.get("available_controlnets", {}) except Exception as e: logging.exception(f"load_controlnet_registry: Failed to load ControlNet registry: {e}") # Fallback to empty registry return {} + def load_default_settings(): """Load default settings from YAML config file""" try: registry_path = Path(__file__).parent / "controlnet_registry.yaml" - with open(registry_path, 'r') as f: + with open(registry_path, "r") as f: config_data = yaml.safe_load(f) - - return config_data.get('defaults', {}) + + return config_data.get("defaults", {}) except Exception as e: logging.exception(f"load_default_settings: Failed to load default settings: {e}") # Fallback to hardcoded defaults return { - 'guidance_scale': 1.1, - 'delta': 0.7, - 'num_inference_steps': 50, - 'seed': 2, - 't_index_list': [35, 45], - 'ipadapter_scale': 1.0, - 'normalize_prompt_weights': True, - 'normalize_seed_weights': True, - 'prompt': "Portrait of The Joker halloween costume, face painting, with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece" + "guidance_scale": 1.1, + "delta": 0.7, + "num_inference_steps": 50, + "seed": 2, + "t_index_list": [35, 45], + "ipadapter_scale": 1.0, + "normalize_prompt_weights": True, + "normalize_seed_weights": True, + "prompt": "Portrait of The Joker halloween costume, face painting, with , glare pose, detailed, intricate, full of colour, cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details, unreal engine 5 cinematic, masterpiece", } + # Load configuration at module level AVAILABLE_CONTROLNETS = load_controlnet_registry() DEFAULT_SETTINGS = load_default_settings() diff --git a/demo/realtime-img2img/config.py b/demo/realtime-img2img/config.py index 56d774048..48e31e18f 100644 --- a/demo/realtime-img2img/config.py +++ b/demo/realtime-img2img/config.py @@ -1,6 +1,6 @@ -from typing import NamedTuple import argparse import os +from typing import NamedTuple class Args(NamedTuple): @@ -45,9 +45,7 @@ def pretty_print(self): parser.add_argument("--host", type=str, default=default_host, help="Host address") parser.add_argument("--port", type=int, default=default_port, help="Port number") parser.add_argument("--reload", action="store_true", help="Reload code on change") -parser.add_argument( - "--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img" -) +parser.add_argument("--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img") parser.add_argument( "--max-queue-size", dest="max_queue_size", diff --git a/demo/realtime-img2img/connection_manager.py b/demo/realtime-img2img/connection_manager.py index ae2072198..e35c09df7 100644 --- a/demo/realtime-img2img/connection_manager.py +++ b/demo/realtime-img2img/connection_manager.py @@ -1,10 +1,12 @@ +import asyncio +import logging +from types import SimpleNamespace from typing import Dict, Union from uuid import UUID -import asyncio + from fastapi import WebSocket from starlette.websockets import WebSocketState -import logging -from types import SimpleNamespace + Connections = Dict[UUID, Dict[str, Union[WebSocket, asyncio.Queue]]] @@ -20,9 +22,7 @@ def __init__(self): self.active_connections: Connections = {} self.latest_data: Dict[UUID, SimpleNamespace] = {} # Store latest parameters for HTTP streaming - async def connect( - self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0 - ): + async def connect(self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0): await websocket.accept() user_count = self.get_user_count() print(f"User count: {user_count}") @@ -61,7 +61,7 @@ async def get_latest_data(self, user_id: UUID) -> SimpleNamespace: return await queue.get() except asyncio.QueueEmpty: return None - + def get_latest_data_sync(self, user_id: UUID) -> SimpleNamespace: """Get the latest data without consuming it from the queue (for HTTP streaming)""" return self.latest_data.get(user_id) diff --git a/demo/realtime-img2img/img2img.py b/demo/realtime-img2img/img2img.py index 06067b342..a1e6ada8c 100644 --- a/demo/realtime-img2img/img2img.py +++ b/demo/realtime-img2img/img2img.py @@ -1,6 +1,7 @@ -import sys -import os import logging +import os +import sys + sys.path.append( os.path.join( @@ -12,10 +13,11 @@ # Config system functions are now used only in main.py + import torch -from pydantic import BaseModel, Field from PIL import Image -from typing import Optional +from pydantic import BaseModel, Field + # Default values for pipeline parameters default_negative_prompt = "black and white, blurry, low resolution, pixelated, pixel art, low quality, low fidelity" @@ -78,26 +80,22 @@ class InputParams(BaseModel): "768x512 (3:2)", "896x640 (7:5)", "1024x768 (4:3)", - "1024x576 (16:9)" - ] - ) - width: int = Field( - 512, min=2, max=15, title="Width", disabled=True, hide=True, id="width" - ) - height: int = Field( - 512, min=2, max=15, title="Height", disabled=True, hide=True, id="height" + "1024x576 (16:9)", + ], ) + width: int = Field(512, min=2, max=15, title="Width", disabled=True, hide=True, id="width") + height: int = Field(512, min=2, max=15, title="Height", disabled=True, hide=True, id="height") -#TODO update naming convention to reflect the controlnet agnostic nature of the config system (pipeline_config instead of controlnet_config for example) + # TODO update naming convention to reflect the controlnet agnostic nature of the config system (pipeline_config instead of controlnet_config for example) def __init__(self, wrapper, config): """ Initialize Pipeline with pre-created wrapper and config. - + Args: wrapper: Pre-created StreamDiffusionWrapper instance config: Configuration dictionary used to create the wrapper """ - + # IPAdapter state tracking for optimization self._last_ipadapter_source_type = None self._last_ipadapter_source_data = None @@ -106,16 +104,16 @@ def __init__(self, wrapper, config): self.stream = wrapper self.config = config self.use_config = True - + # Extract pipeline configuration from config - self.pipeline_mode = self.config.get('mode', 'img2img') - self.has_controlnet = 'controlnets' in self.config and len(self.config['controlnets']) > 0 - self.has_ipadapter = 'ipadapters' in self.config and len(self.config['ipadapters']) > 0 - + self.pipeline_mode = self.config.get("mode", "img2img") + self.has_controlnet = "controlnets" in self.config and len(self.config["controlnets"]) > 0 + self.has_ipadapter = "ipadapters" in self.config and len(self.config["ipadapters"]) > 0 + # Store config values for later use - self.negative_prompt = self.config.get('negative_prompt', default_negative_prompt) - self.guidance_scale = self.config.get('guidance_scale', 1.2) - self.num_inference_steps = self.config.get('num_inference_steps', 50) + self.negative_prompt = self.config.get("negative_prompt", default_negative_prompt) + self.guidance_scale = self.config.get("guidance_scale", 1.2) + self.num_inference_steps = self.config.get("num_inference_steps", 50) # Update input_mode based on pipeline mode self.info = self.Info() @@ -129,23 +127,23 @@ def __init__(self, wrapper, config): self.guidance_scale = 1.1 self.num_inference_steps = 50 self.negative_prompt = default_negative_prompt - + # Store output type for frame conversion - always force "pt" for optimal performance self.output_type = "pt" def predict(self, params: "Pipeline.InputParams") -> Image.Image: # Get input manager if available (passed from websocket handler) - input_manager = getattr(params, 'input_manager', None) - + input_manager = getattr(params, "input_manager", None) + # Handle different modes if self.pipeline_mode == "txt2img": # Text-to-image mode - + # Handle ControlNet updates if enabled if self.has_controlnet: try: stream_state = self.stream.get_stream_state() - current_cfg = stream_state.get('controlnet_config', []) + current_cfg = stream_state.get("controlnet_config", []) except Exception: current_cfg = [] if current_cfg: @@ -154,11 +152,11 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: control_image = self._get_controlnet_input(input_manager, i, params.image) if control_image is not None: self.stream.update_control_image(index=i, image=control_image) - + # Handle IPAdapter updates if enabled if self.has_ipadapter: self._update_ipadapter_style_image(input_manager) - + # Generate output based on what's enabled if self.has_controlnet and not self.has_ipadapter: # ControlNet only: use base input for generation @@ -176,12 +174,12 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: output_image = self.stream() else: # Image-to-image mode: use original logic - + # Handle ControlNet updates if enabled if self.has_controlnet: try: stream_state = self.stream.get_stream_state() - current_cfg = stream_state.get('controlnet_config', []) + current_cfg = stream_state.get("controlnet_config", []) except Exception: current_cfg = [] if current_cfg: @@ -190,11 +188,11 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: control_image = self._get_controlnet_input(input_manager, i, params.image) if control_image is not None: self.stream.update_control_image(index=i, image=control_image) - + # Handle IPAdapter updates if enabled if self.has_ipadapter: self._update_ipadapter_style_image(input_manager) - + # Generate output based on what's enabled if self.has_controlnet or self.has_ipadapter: # ControlNet and/or IPAdapter: use base input for img2img @@ -216,150 +214,153 @@ def predict(self, params: "Pipeline.InputParams") -> Image.Image: def _get_controlnet_input(self, input_manager, index: int, fallback_image): """ Get input image for a specific ControlNet index. - + Args: input_manager: InputSourceManager instance (can be None) index: ControlNet index fallback_image: Fallback image if no specific source is configured - + Returns: Input image for the ControlNet or fallback """ if input_manager: - frame = input_manager.get_frame('controlnet', index) + frame = input_manager.get_frame("controlnet", index) if frame is not None: return frame - + # Fallback to main image input return fallback_image - + def _get_ipadapter_input(self, input_manager): """ Get input image for IPAdapter. - + Args: input_manager: InputSourceManager instance (can be None) - + Returns: Input image for IPAdapter or None """ if input_manager: - return input_manager.get_frame('ipadapter') + return input_manager.get_frame("ipadapter") return None - + def _update_ipadapter_style_image(self, input_manager): """ Update IPAdapter style image from InputSourceManager. Only updates when source actually changes to avoid unnecessary processing. - + Args: input_manager: InputSourceManager instance (can be None) """ if not input_manager or not self.has_ipadapter: return - + try: # Get current source info to check if it changed - source_info = input_manager.get_source_info('ipadapter') - current_source_type = source_info.get('source_type') - current_source_data = source_info.get('source_data') - is_stream = source_info.get('is_stream', False) - + source_info = input_manager.get_source_info("ipadapter") + current_source_type = source_info.get("source_type") + current_source_data = source_info.get("source_data") + is_stream = source_info.get("is_stream", False) + # Check if source changed (for static images, only update when source changes) source_changed = ( - current_source_type != self._last_ipadapter_source_type or - current_source_data != self._last_ipadapter_source_data + current_source_type != self._last_ipadapter_source_type + or current_source_data != self._last_ipadapter_source_data ) - + # For streaming sources (webcam/video), always get fresh frame # For static sources (uploaded image), only update when source changes should_update = is_stream or source_changed - + if not should_update: return # No update needed - static source unchanged - + # Get IPAdapter style image from input source manager - ipadapter_frame = input_manager.get_frame('ipadapter') - + ipadapter_frame = input_manager.get_frame("ipadapter") + if ipadapter_frame is not None: import torch - + # Use tensor directly - update_style_image expects torch tensor if isinstance(ipadapter_frame, torch.Tensor): try: # Update IPAdapter with tensor and stream configuration self.stream.update_style_image(ipadapter_frame, is_stream=is_stream) - self.stream.update_stream_params(ipadapter_config={'is_stream': is_stream}) - + self.stream.update_stream_params(ipadapter_config={"is_stream": is_stream}) + # Force prompt re-encoding to apply new style image embeddings # This is critical because IPAdapter embedding hook only runs during prompt encoding try: state = self.stream.get_stream_state() - current_prompts = state.get('prompt_list', []) + current_prompts = state.get("prompt_list", []) if current_prompts: self.stream.update_prompt(current_prompts, prompt_interpolation_method="slerp") except Exception as e: - logging.exception(f"_update_ipadapter_style_image: Failed to force prompt re-encoding: {e}") - - + logging.exception( + f"_update_ipadapter_style_image: Failed to force prompt re-encoding: {e}" + ) + # Update tracking variables only on successful update self._last_ipadapter_source_type = current_source_type self._last_ipadapter_source_data = current_source_data - + except Exception as e: logging.exception(f"_update_ipadapter_style_image: Failed to update IPAdapter: {e}") else: - logging.warning("_update_ipadapter_style_image: IPAdapter frame is not a tensor, skipping style image update") + logging.warning( + "_update_ipadapter_style_image: IPAdapter frame is not a tensor, skipping style image update" + ) except Exception as e: logging.exception(f"_update_ipadapter_style_image: Error updating IPAdapter style image: {e}") - + def _get_base_input(self, input_manager, fallback_image): """ Get input image for base pipeline. - + Args: input_manager: InputSourceManager instance (can be None) fallback_image: Fallback image if no specific source is configured - + Returns: Input image for base pipeline or fallback """ if input_manager: - frame = input_manager.get_frame('base') + frame = input_manager.get_frame("base") if frame is not None: return frame - + # Fallback to main image input return fallback_image def update_ipadapter_config(self, scale: float = None, style_image: Image.Image = None) -> bool: """ Update IPAdapter configuration in real-time using direct methods - + Args: scale: New IPAdapter scale value (optional) style_image: New style image (PIL Image, optional) - + Returns: bool: True if successful, False otherwise """ if not self.has_ipadapter: return False - + if scale is None and style_image is None: return False # Nothing to update - + try: # Update scale via unified config system (no direct method needed) if scale is not None: - self.stream.update_stream_params(ipadapter_config={'scale': scale}) - + self.stream.update_stream_params(ipadapter_config={"scale": scale}) + # Update style image via direct method if style_image is not None: self.stream.update_style_image(style_image) - + return True - except Exception as e: + except Exception: return False def update_ipadapter_scale(self, scale: float) -> bool: @@ -374,21 +375,21 @@ def update_ipadapter_weight_type(self, weight_type: str) -> bool: """Update IPAdapter weight type in real-time""" if not self.has_ipadapter: return False - + try: # Use unified updater on wrapper - if hasattr(self.stream, 'update_stream_params'): - self.stream.update_stream_params(ipadapter_config={ 'weight_type': weight_type }) + if hasattr(self.stream, "update_stream_params"): + self.stream.update_stream_params(ipadapter_config={"weight_type": weight_type}) return True # Should not reach here in normal operation return False - except Exception as e: + except Exception: return False def get_ipadapter_info(self) -> dict: """ Get current IPAdapter information - + Returns: dict: IPAdapter information including scale, model info, etc. """ @@ -397,35 +398,37 @@ def get_ipadapter_info(self) -> dict: "scale": 1.0, "weight_type": "linear", "model_path": None, - "style_image_set": False + "style_image_set": False, } - - if self.has_ipadapter and self.config and 'ipadapters' in self.config: + + if self.has_ipadapter and self.config and "ipadapters" in self.config: # Get info from first IPAdapter config - if len(self.config['ipadapters']) > 0: - ipadapter_config = self.config['ipadapters'][0] - info["scale"] = ipadapter_config.get('scale', 1.0) - info["weight_type"] = ipadapter_config.get('weight_type', 'linear') - info["model_path"] = ipadapter_config.get('ipadapter_model_path') - info["style_image_set"] = 'style_image' in ipadapter_config - + if len(self.config["ipadapters"]) > 0: + ipadapter_config = self.config["ipadapters"][0] + info["scale"] = ipadapter_config.get("scale", 1.0) + info["weight_type"] = ipadapter_config.get("weight_type", "linear") + info["model_path"] = ipadapter_config.get("ipadapter_model_path") + info["style_image_set"] = "style_image" in ipadapter_config + # Get current runtime state from wrapper's public API try: - if hasattr(self.stream, 'get_stream_state'): + if hasattr(self.stream, "get_stream_state"): stream_state = self.stream.get_stream_state() - ipadapter_runtime_config = stream_state.get('ipadapter_config', {}) + ipadapter_runtime_config = stream_state.get("ipadapter_config", {}) if ipadapter_runtime_config: - info["scale"] = ipadapter_runtime_config.get('scale', info.get("scale", 1.0)) - info["weight_type"] = ipadapter_runtime_config.get('weight_type', info.get("weight_type", 'linear')) + info["scale"] = ipadapter_runtime_config.get("scale", info.get("scale", 1.0)) + info["weight_type"] = ipadapter_runtime_config.get( + "weight_type", info.get("weight_type", "linear") + ) except Exception: pass # Use defaults from config if wrapper method fails - + return info def update_stream_params(self, **kwargs): """ Update streaming parameters using the consolidated API - + Args: **kwargs: All parameters supported by StreamDiffusionWrapper.update_stream_params() including controlnet_config, guidance_scale, delta, etc. diff --git a/demo/realtime-img2img/input_control.py b/demo/realtime-img2img/input_control.py index c4f41359c..be1e3fe03 100644 --- a/demo/realtime-img2img/input_control.py +++ b/demo/realtime-img2img/input_control.py @@ -1,48 +1,48 @@ -from abc import ABC, abstractmethod -from typing import Dict, Any, Callable, Optional import asyncio +import logging import threading import time -import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional class InputControl(ABC): """Generic interface for input controls that can modify parameters""" - + def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float = 1.0): self.parameter_name = parameter_name self.min_value = min_value self.max_value = max_value self.is_active = False self.update_callback: Optional[Callable[[str, float], None]] = None - + @abstractmethod async def start(self) -> None: """Start the input control""" pass - + @abstractmethod async def stop(self) -> None: """Stop the input control""" pass - + @abstractmethod def get_current_value(self) -> float: """Get the current normalized value (0.0 to 1.0)""" pass - + def set_update_callback(self, callback: Callable[[str, float], None]) -> None: """Set callback for parameter updates""" self.update_callback = callback - + def normalize_value(self, raw_value: float) -> float: """Normalize raw input value to 0.0-1.0 range""" return max(0.0, min(1.0, raw_value)) - + def scale_to_parameter(self, normalized_value: float) -> float: """Scale normalized value to parameter range""" return self.min_value + (normalized_value * (self.max_value - self.min_value)) - + def _trigger_update(self, normalized_value: float) -> None: """Trigger parameter update if callback is set""" if self.update_callback: @@ -52,9 +52,16 @@ def _trigger_update(self, normalized_value: float) -> None: class GamepadInput(InputControl): """Gamepad input control for parameter modification""" - - def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float = 1.0, - gamepad_index: int = 0, axis_index: int = 0, deadzone: float = 0.1): + + def __init__( + self, + parameter_name: str, + min_value: float = 0.0, + max_value: float = 1.0, + gamepad_index: int = 0, + axis_index: int = 0, + deadzone: float = 0.1, + ): super().__init__(parameter_name, min_value, max_value) self.gamepad_index = gamepad_index self.axis_index = axis_index @@ -62,78 +69,78 @@ def __init__(self, parameter_name: str, min_value: float = 0.0, max_value: float self.current_value = 0.0 self._stop_event = threading.Event() self._thread = None - + async def start(self) -> None: """Start gamepad monitoring""" if self.is_active: return - + self.is_active = True self._stop_event.clear() self._thread = threading.Thread(target=self._monitor_gamepad, daemon=True) self._thread.start() logging.info(f"GamepadInput: Started monitoring gamepad {self.gamepad_index}, axis {self.axis_index}") - + async def stop(self) -> None: """Stop gamepad monitoring""" if not self.is_active: return - + self.is_active = False self._stop_event.set() - + if self._thread and self._thread.is_alive(): self._thread.join(timeout=1.0) - + logging.info(f"GamepadInput: Stopped monitoring gamepad {self.gamepad_index}, axis {self.axis_index}") - + def get_current_value(self) -> float: """Get current normalized value""" return self.current_value - + def _monitor_gamepad(self) -> None: """Monitor gamepad input in background thread""" try: import pygame - + # Initialize pygame for gamepad support pygame.init() pygame.joystick.init() - + # Check if gamepad is available if pygame.joystick.get_count() <= self.gamepad_index: logging.error(f"GamepadInput: Gamepad {self.gamepad_index} not found") return - + # Initialize the gamepad joystick = pygame.joystick.Joystick(self.gamepad_index) joystick.init() - + logging.info(f"GamepadInput: Connected to {joystick.get_name()}") - + # Monitor gamepad input while not self._stop_event.is_set(): pygame.event.pump() - + # Get axis value if self.axis_index < joystick.get_numaxes(): raw_value = joystick.get_axis(self.axis_index) - + # Apply deadzone if abs(raw_value) < self.deadzone: raw_value = 0.0 - + # Convert from [-1, 1] to [0, 1] range normalized_value = (raw_value + 1.0) / 2.0 - + # Update current value self.current_value = normalized_value - + # Trigger update callback self._trigger_update(normalized_value) - + time.sleep(0.016) # ~60 FPS polling - + except ImportError: logging.error("GamepadInput: pygame not installed. Install with: pip install pygame") except Exception as e: @@ -150,48 +157,48 @@ def _monitor_gamepad(self) -> None: class InputManager: """Manages multiple input controls""" - + def __init__(self): self.inputs: Dict[str, InputControl] = {} self.parameter_update_callback: Optional[Callable[[str, float], None]] = None - + def add_input(self, input_id: str, input_control: InputControl) -> None: """Add an input control""" input_control.set_update_callback(self._handle_parameter_update) self.inputs[input_id] = input_control logging.info(f"InputManager: Added input control {input_id} for parameter {input_control.parameter_name}") - + def remove_input(self, input_id: str) -> None: """Remove an input control""" if input_id in self.inputs: asyncio.create_task(self.inputs[input_id].stop()) del self.inputs[input_id] logging.info(f"InputManager: Removed input control {input_id}") - + async def start_input(self, input_id: str) -> None: """Start a specific input control""" if input_id in self.inputs: await self.inputs[input_id].start() - + async def stop_input(self, input_id: str) -> None: """Stop a specific input control""" if input_id in self.inputs: await self.inputs[input_id].stop() - + async def start_all(self) -> None: """Start all input controls""" for input_control in self.inputs.values(): await input_control.start() - + async def stop_all(self) -> None: """Stop all input controls""" for input_control in self.inputs.values(): await input_control.stop() - + def set_parameter_update_callback(self, callback: Callable[[str, float], None]) -> None: """Set callback for parameter updates from any input""" self.parameter_update_callback = callback - + def get_input_status(self) -> Dict[str, Dict[str, Any]]: """Get status of all inputs""" status = {} @@ -201,11 +208,11 @@ def get_input_status(self) -> Dict[str, Dict[str, Any]]: "is_active": input_control.is_active, "current_value": input_control.get_current_value(), "min_value": input_control.min_value, - "max_value": input_control.max_value + "max_value": input_control.max_value, } return status - + def _handle_parameter_update(self, parameter_name: str, value: float) -> None: """Handle parameter update from input controls""" if self.parameter_update_callback: - self.parameter_update_callback(parameter_name, value) \ No newline at end of file + self.parameter_update_callback(parameter_name, value) diff --git a/demo/realtime-img2img/input_sources.py b/demo/realtime-img2img/input_sources.py index d539845a7..b50a55b2a 100644 --- a/demo/realtime-img2img/input_sources.py +++ b/demo/realtime-img2img/input_sources.py @@ -8,19 +8,19 @@ import logging from enum import Enum -from typing import Dict, Optional, Union, Any from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np import torch from PIL import Image -import cv2 -import numpy as np - from util import bytes_to_pt from utils.video_utils import VideoFrameExtractor class InputSourceType(Enum): """Types of input sources available.""" + WEBCAM = "webcam" UPLOADED_IMAGE = "uploaded_image" UPLOADED_VIDEO = "uploaded_video" @@ -29,15 +29,15 @@ class InputSourceType(Enum): class InputSource: """ Represents an input source for a component. - + Handles different types of inputs (webcam, image, video) and provides a unified interface to get the current frame as a tensor. """ - + def __init__(self, source_type: InputSourceType, source_data: Any = None): """ Initialize an input source. - + Args: source_type: Type of input source source_data: Data for the source (PIL Image, video path, or None for webcam) @@ -48,11 +48,11 @@ def __init__(self, source_type: InputSourceType, source_data: Any = None): self._current_frame = None self._video_extractor = None self._logger = logging.getLogger(f"InputSource.{source_type.value}") - + # Initialize video extractor if needed if source_type == InputSourceType.UPLOADED_VIDEO and source_data: self._init_video_extractor() - + def _init_video_extractor(self): """Initialize video extractor for video input sources.""" if self.source_data and Path(self.source_data).exists(): @@ -64,11 +64,11 @@ def _init_video_extractor(self): self._video_extractor = None else: self._logger.error(f"Video file not found: {self.source_data}") - + def get_frame(self) -> Optional[torch.Tensor]: """ Get the current frame as a PyTorch tensor. - + Returns: torch.Tensor: Current frame with shape (C, H, W), values in [0, 1], dtype float32 None: If no frame is available @@ -77,7 +77,7 @@ def get_frame(self) -> Optional[torch.Tensor]: if self.source_type == InputSourceType.WEBCAM: # For webcam, return cached frame (will be updated externally) return self._current_frame - + elif self.source_type == InputSourceType.UPLOADED_IMAGE: # For static image, convert to tensor if not already done if self._current_frame is None and self.source_data: @@ -91,35 +91,35 @@ def get_frame(self) -> Optional[torch.Tensor]: elif isinstance(self.source_data, bytes): # Convert bytes to tensor using existing utility self._current_frame = bytes_to_pt(self.source_data) - + return self._current_frame - + elif self.source_type == InputSourceType.UPLOADED_VIDEO: # For video, get next frame return self._get_video_frame() - + except Exception as e: self._logger.error(f"Error getting frame from {self.source_type.value}: {e}") - + return None - + def _get_video_frame(self) -> Optional[torch.Tensor]: """Get the next frame from video source.""" if not self._video_extractor: return None - + return self._video_extractor.get_frame() - + def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]): """ Update the current frame for webcam sources. - + Args: frame_data: Frame data as bytes or tensor """ if self.source_type != InputSourceType.WEBCAM: return - + try: if isinstance(frame_data, bytes): self._current_frame = bytes_to_pt(frame_data) @@ -127,7 +127,7 @@ def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]): self._current_frame = frame_data except Exception as e: self._logger.error(f"Error updating webcam frame: {e}") - + def cleanup(self): """Clean up resources.""" if self._video_extractor: @@ -138,211 +138,224 @@ def cleanup(self): class InputSourceManager: """ Manages input sources for different components in the pipeline. - + Provides a centralized way to set and get input sources for: - ControlNet instances (indexed) - - IPAdapter + - IPAdapter - Base pipeline """ - + def __init__(self): """Initialize the input source manager.""" self.sources = { - 'controlnet': {}, # {index: InputSource} - 'ipadapter': None, # Single InputSource - 'base': None # Single InputSource for main pipeline + "controlnet": {}, # {index: InputSource} + "ipadapter": None, # Single InputSource + "base": None, # Single InputSource for main pipeline } self._logger = logging.getLogger("InputSourceManager") - + # Default to webcam for base pipeline - self.sources['base'] = InputSource(InputSourceType.WEBCAM) - + self.sources["base"] = InputSource(InputSourceType.WEBCAM) + # Default IPAdapter to uploaded_image with default image self._init_default_ipadapter_source() - + def set_source(self, component: str, source: InputSource, index: Optional[int] = None): """ Set input source for a component. - + Args: component: Component name ('controlnet', 'ipadapter', 'base') source: InputSource instance index: Index for ControlNet instances (required for 'controlnet') """ try: - if component == 'controlnet': + if component == "controlnet": if index is None: raise ValueError("Index is required for ControlNet components") - + # Clean up existing source if any - if index in self.sources['controlnet']: - self.sources['controlnet'][index].cleanup() - - self.sources['controlnet'][index] = source + if index in self.sources["controlnet"]: + self.sources["controlnet"][index].cleanup() + + self.sources["controlnet"][index] = source self._logger.info(f"Set ControlNet {index} input source to {source.source_type.value}") - - elif component in ['ipadapter', 'base']: + + elif component in ["ipadapter", "base"]: # Clean up existing source if any if self.sources[component]: self.sources[component].cleanup() - + self.sources[component] = source self._logger.info(f"Set {component} input source to {source.source_type.value}") - + else: raise ValueError(f"Unknown component: {component}") - + except Exception as e: self._logger.error(f"Error setting source for {component}: {e}") - + def get_frame(self, component: str, index: Optional[int] = None) -> Optional[torch.Tensor]: """ Get current frame for a component. - + Args: component: Component name ('controlnet', 'ipadapter', 'base') index: Index for ControlNet instances (required for 'controlnet') - + Returns: torch.Tensor: Current frame or None if not available """ try: - if component == 'controlnet': + if component == "controlnet": if index is None: raise ValueError("Index is required for ControlNet components") - + # Ensure ControlNet is initialized with default webcam source self._ensure_controlnet_initialized(index) - source = self.sources['controlnet'][index] - + source = self.sources["controlnet"][index] + frame = source.get_frame() if frame is not None: return frame - + # If webcam source has no frame yet, fallback to base pipeline input self._logger.debug(f"ControlNet {index} webcam has no frame yet, falling back to base") return self._get_fallback_frame() - - elif component in ['ipadapter', 'base']: + + elif component in ["ipadapter", "base"]: source = self.sources[component] if source: frame = source.get_frame() if frame is not None: return frame - + # Fallback to base pipeline input if not base itself - if component != 'base': + if component != "base": self._logger.debug(f"{component} has no input, falling back to base") return self._get_fallback_frame() - + except Exception as e: self._logger.error(f"Error getting frame for {component}: {e}") - + return None - + def _get_fallback_frame(self) -> Optional[torch.Tensor]: """Get frame from base pipeline as fallback.""" - base_source = self.sources['base'] + base_source = self.sources["base"] if base_source: return base_source.get_frame() return None - + def update_webcam_frame(self, frame_data: Union[bytes, torch.Tensor]): """ Update webcam frame for all webcam sources. - + Args: frame_data: Frame data as bytes or tensor """ # Update base pipeline if it's webcam - if (self.sources['base'] and - self.sources['base'].source_type == InputSourceType.WEBCAM): - self.sources['base'].update_webcam_frame(frame_data) - + if self.sources["base"] and self.sources["base"].source_type == InputSourceType.WEBCAM: + self.sources["base"].update_webcam_frame(frame_data) + # Update ControlNet webcam sources - for source in self.sources['controlnet'].values(): + for source in self.sources["controlnet"].values(): if source.source_type == InputSourceType.WEBCAM: source.update_webcam_frame(frame_data) - + # Update IPAdapter if it's webcam - if (self.sources['ipadapter'] and - self.sources['ipadapter'].source_type == InputSourceType.WEBCAM): - self.sources['ipadapter'].update_webcam_frame(frame_data) - + if self.sources["ipadapter"] and self.sources["ipadapter"].source_type == InputSourceType.WEBCAM: + self.sources["ipadapter"].update_webcam_frame(frame_data) + def _ensure_controlnet_initialized(self, index: int): """ Ensure a ControlNet has a default webcam source if not already set. - + Args: index: ControlNet index """ - if index not in self.sources['controlnet']: - self.sources['controlnet'][index] = InputSource(InputSourceType.WEBCAM) + if index not in self.sources["controlnet"]: + self.sources["controlnet"][index] = InputSource(InputSourceType.WEBCAM) self._logger.info(f"_ensure_controlnet_initialized: Initialized ControlNet {index} with webcam source") def get_source_info(self, component: str, index: Optional[int] = None) -> Dict[str, Any]: """ Get information about a component's input source. - + Returns: Dictionary with source type and metadata """ try: - if component == 'controlnet': + if component == "controlnet": if index is None: - return {'source_type': 'error', 'source_data': 'index_required', 'is_stream': False, 'has_data': False} - + return { + "source_type": "error", + "source_data": "index_required", + "is_stream": False, + "has_data": False, + } + # Ensure ControlNet is initialized with default webcam source self._ensure_controlnet_initialized(index) - source = self.sources['controlnet'][index] - - elif component in ['ipadapter', 'base']: + source = self.sources["controlnet"][index] + + elif component in ["ipadapter", "base"]: source = self.sources[component] if not source: - return {'source_type': 'none', 'source_data': None, 'is_stream': False, 'has_data': False} + return {"source_type": "none", "source_data": None, "is_stream": False, "has_data": False} else: - return {'source_type': 'unknown', 'source_data': None, 'is_stream': False, 'has_data': False} - + return {"source_type": "unknown", "source_data": None, "is_stream": False, "has_data": False} + return { - 'source_type': source.source_type.value, - 'source_data': source.source_data, - 'is_stream': source.is_stream, - 'has_data': source.source_data is not None + "source_type": source.source_type.value, + "source_data": source.source_data, + "is_stream": source.is_stream, + "has_data": source.source_data is not None, } - + except Exception as e: self._logger.error(f"Error getting source info for {component}: {e}") - return {'source_type': 'error', 'source_data': None, 'is_stream': False, 'has_data': False, 'error': str(e)} - + return { + "source_type": "error", + "source_data": None, + "is_stream": False, + "has_data": False, + "error": str(e), + } + def _init_default_ipadapter_source(self): """Initialize IPAdapter with default image source.""" try: import os + from PIL import Image - + # Try to load default image default_image_path = os.path.join(os.path.dirname(__file__), "..", "..", "images", "inputs", "input.png") if os.path.exists(default_image_path): default_image = Image.open(default_image_path).convert("RGB") - self.sources['ipadapter'] = InputSource(InputSourceType.UPLOADED_IMAGE, default_image) + self.sources["ipadapter"] = InputSource(InputSourceType.UPLOADED_IMAGE, default_image) self._logger.info("_init_default_ipadapter_source: Initialized IPAdapter with default image") else: - self._logger.warning("_init_default_ipadapter_source: Default image not found, IPAdapter will have no source") + self._logger.warning( + "_init_default_ipadapter_source: Default image not found, IPAdapter will have no source" + ) except Exception as e: self._logger.error(f"_init_default_ipadapter_source: Error loading default image: {e}") - + def load_config_style_image(self, style_image_path: str, base_config_path: str = None): """ Load IPAdapter style image from config file path. - + Args: style_image_path: Path to style image (can be relative) base_config_path: Base path for resolving relative paths """ try: import os + from PIL import Image - + # Handle relative paths if not os.path.isabs(style_image_path): if base_config_path: @@ -355,17 +368,19 @@ def load_config_style_image(self, style_image_path: str, base_config_path: str = if not os.path.exists(style_image_path): self._logger.warning(f"load_config_style_image: Style image not found: {style_image_path}") return - + if os.path.exists(style_image_path): style_image = Image.open(style_image_path).convert("RGB") input_source = InputSource(InputSourceType.UPLOADED_IMAGE, style_image) - self.set_source('ipadapter', input_source) - self._logger.info(f"load_config_style_image: Loaded IPAdapter style image from config: {style_image_path}") + self.set_source("ipadapter", input_source) + self._logger.info( + f"load_config_style_image: Loaded IPAdapter style image from config: {style_image_path}" + ) else: self._logger.warning(f"load_config_style_image: IPAdapter style image not found: {style_image_path}") except Exception as e: self._logger.exception(f"load_config_style_image: Error loading config style image: {e}") - + def reset_to_defaults(self): """ Reset all input sources to their default states. @@ -374,30 +389,30 @@ def reset_to_defaults(self): try: # Clean up existing sources first self.cleanup() - + # Reset to default states self.sources = { - 'controlnet': {}, # Empty - ControlNets will use fallback to base - 'ipadapter': None, # Will be re-initialized - 'base': None # Will be re-initialized + "controlnet": {}, # Empty - ControlNets will use fallback to base + "ipadapter": None, # Will be re-initialized + "base": None, # Will be re-initialized } - + # Re-initialize defaults - self.sources['base'] = InputSource(InputSourceType.WEBCAM) + self.sources["base"] = InputSource(InputSourceType.WEBCAM) self._init_default_ipadapter_source() - + self._logger.info("reset_to_defaults: Reset all input sources to defaults") - + except Exception as e: self._logger.error(f"reset_to_defaults: Error resetting input sources: {e}") - + def cleanup(self): """Clean up all sources.""" - for source in self.sources['controlnet'].values(): + for source in self.sources["controlnet"].values(): source.cleanup() - - if self.sources['ipadapter']: - self.sources['ipadapter'].cleanup() - - if self.sources['base']: - self.sources['base'].cleanup() + + if self.sources["ipadapter"]: + self.sources["ipadapter"].cleanup() + + if self.sources["base"]: + self.sources["base"].cleanup() diff --git a/demo/realtime-img2img/main.py b/demo/realtime-img2img/main.py index b1c609461..aa5fe5873 100644 --- a/demo/realtime-img2img/main.py +++ b/demo/realtime-img2img/main.py @@ -1,29 +1,16 @@ -from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect, UploadFile, File, Response -from fastapi.responses import StreamingResponse, JSONResponse -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from fastapi import Request - -import markdown2 - import logging -import uuid -import time -from types import SimpleNamespace -import asyncio -import os -import time import mimetypes -import torch -import tempfile -from pathlib import Path -import yaml +import time -from config import config, Args -from util import pil_to_frame, pt_to_frame, bytes_to_pil, bytes_to_pt -from connection_manager import ConnectionManager, ServerFullException +import torch +from config import Args, config +from connection_manager import ConnectionManager +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles from img2img import Pipeline -from input_control import InputManager, GamepadInput +from input_control import InputManager + # fix mime error on windows mimetypes.add_type("application/javascript", ".js") @@ -33,87 +20,79 @@ # Import configuration from separate file to avoid circular imports from app_config import AVAILABLE_CONTROLNETS, DEFAULT_SETTINGS + # Configure logging def setup_logging(log_level: str = "INFO"): """Setup logging configuration for the application""" # Convert string to logging level numeric_level = getattr(logging, log_level.upper(), logging.INFO) - + # Configure root logger logging.basicConfig( - level=numeric_level, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + level=numeric_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) - + # Set up logger for streamdiffusion modules - streamdiffusion_logger = logging.getLogger('streamdiffusion') + streamdiffusion_logger = logging.getLogger("streamdiffusion") streamdiffusion_logger.setLevel(numeric_level) - + # Set up logger for this application - app_logger = logging.getLogger('realtime_img2img') + app_logger = logging.getLogger("realtime_img2img") app_logger.setLevel(numeric_level) - + return app_logger + # Initialize logger logger = setup_logging(config.log_level) # Suppress uvicorn INFO messages if config.quiet: - uvicorn_logger = logging.getLogger('uvicorn') + uvicorn_logger = logging.getLogger("uvicorn") uvicorn_logger.setLevel(logging.WARNING) - uvicorn_access_logger = logging.getLogger('uvicorn.access') + uvicorn_access_logger = logging.getLogger("uvicorn.access") uvicorn_access_logger.setLevel(logging.WARNING) class AppState: """Centralized application state management - SINGLE SOURCE OF TRUTH""" - + def __init__(self): # Pipeline state self.pipeline_lifecycle = "stopped" # stopped, starting, running, error self.pipeline_active = False - - # Configuration state + + # Configuration state self.uploaded_config = None # Raw uploaded config - self.runtime_config = None # Runtime modifications to config + self.runtime_config = None # Runtime modifications to config self.config_needs_reload = False - + # Resolution state self.current_resolution = {"width": 512, "height": 512} - + # Parameter state (consolidates scattered vars from frontend) self.pipeline_params = {} - + # ControlNet state - AUTHORITATIVE SOURCE - self.controlnet_info = { - "enabled": False, - "controlnets": [] - } - - # IPAdapter state - AUTHORITATIVE SOURCE - self.ipadapter_info = { - "enabled": False, - "has_style_image": False, - "scale": 1.0, - "weight_type": "linear" - } - + self.controlnet_info = {"enabled": False, "controlnets": []} + + # IPAdapter state - AUTHORITATIVE SOURCE + self.ipadapter_info = {"enabled": False, "has_style_image": False, "scale": 1.0, "weight_type": "linear"} + # Pipeline hooks state - AUTHORITATIVE SOURCE self.pipeline_hooks = { "image_preprocessing": {"enabled": False, "processors": []}, "image_postprocessing": {"enabled": False, "processors": []}, "latent_preprocessing": {"enabled": False, "processors": []}, - "latent_postprocessing": {"enabled": False, "processors": []} + "latent_postprocessing": {"enabled": False, "processors": []}, } - + # Blending configurations self.prompt_blending = None self.seed_blending = None self.normalize_prompt_weights = True self.normalize_seed_weights = True - + # Core pipeline parameters self.guidance_scale = 1.1 self.delta = 0.7 @@ -122,98 +101,92 @@ def __init__(self): self.t_index_list = [35, 45] self.negative_prompt = "" self.skip_diffusion = False - + # UI state self.fps = 0 self.queue_size = 0 self.model_id = "" self.page_content = "" - + # Input source state self.input_sources = { - 'controlnet': {}, # {index: source_info} - 'ipadapter': None, - 'base': None + "controlnet": {}, # {index: source_info} + "ipadapter": None, + "base": None, } - + # Debug mode state self.debug_mode = False self.debug_pending_frame = False # True when a frame step is requested - + def populate_from_config(self, config_data): """Populate AppState from uploaded config - SINGLE SOURCE OF TRUTH""" if not config_data: return - + logger.info("populate_from_config: Populating AppState from config as single source of truth") - + # Store the complete uploaded config to preserve ALL parameters self.uploaded_config = config_data - + # Core parameters - self.guidance_scale = config_data.get('guidance_scale', self.guidance_scale) - self.delta = config_data.get('delta', self.delta) - self.num_inference_steps = config_data.get('num_inference_steps', self.num_inference_steps) - self.seed = config_data.get('seed', self.seed) - self.t_index_list = config_data.get('t_index_list', self.t_index_list) - self.negative_prompt = config_data.get('negative_prompt', self.negative_prompt) - self.skip_diffusion = config_data.get('skip_diffusion', self.skip_diffusion) - self.model_id = config_data.get('model_id_or_path', self.model_id) - + self.guidance_scale = config_data.get("guidance_scale", self.guidance_scale) + self.delta = config_data.get("delta", self.delta) + self.num_inference_steps = config_data.get("num_inference_steps", self.num_inference_steps) + self.seed = config_data.get("seed", self.seed) + self.t_index_list = config_data.get("t_index_list", self.t_index_list) + self.negative_prompt = config_data.get("negative_prompt", self.negative_prompt) + self.skip_diffusion = config_data.get("skip_diffusion", self.skip_diffusion) + self.model_id = config_data.get("model_id_or_path", self.model_id) + # Resolution parameters - if 'width' in config_data or 'height' in config_data: + if "width" in config_data or "height" in config_data: self.current_resolution = { - "width": config_data.get('width', self.current_resolution["width"]), - "height": config_data.get('height', self.current_resolution["height"]) + "width": config_data.get("width", self.current_resolution["width"]), + "height": config_data.get("height", self.current_resolution["height"]), } - + # Normalization settings - self.normalize_prompt_weights = config_data.get('normalize_weights', self.normalize_prompt_weights) - self.normalize_seed_weights = config_data.get('normalize_weights', self.normalize_seed_weights) - + self.normalize_prompt_weights = config_data.get("normalize_weights", self.normalize_prompt_weights) + self.normalize_seed_weights = config_data.get("normalize_weights", self.normalize_seed_weights) + # ControlNet configuration - if 'controlnets' in config_data: - self.controlnet_info = { - "enabled": True, - "controlnets": [] - } - for i, controlnet in enumerate(config_data['controlnets']): + if "controlnets" in config_data: + self.controlnet_info = {"enabled": True, "controlnets": []} + for i, controlnet in enumerate(config_data["controlnets"]): processed = dict(controlnet) - processed['index'] = i - processed['name'] = controlnet.get('model_id', '') - processed['strength'] = controlnet.get('conditioning_scale', 0.0) + processed["index"] = i + processed["name"] = controlnet.get("model_id", "") + processed["strength"] = controlnet.get("conditioning_scale", 0.0) self.controlnet_info["controlnets"].append(processed) else: self.controlnet_info = {"enabled": False, "controlnets": []} - + # IPAdapter configuration - if config_data.get('use_ipadapter', False): + if config_data.get("use_ipadapter", False): self.ipadapter_info["enabled"] = True - ipadapters = config_data.get('ipadapters', []) + ipadapters = config_data.get("ipadapters", []) if ipadapters: first = ipadapters[0] - self.ipadapter_info["scale"] = first.get('scale', 1.0) - self.ipadapter_info["weight_type"] = first.get('weight_type', 'linear') + self.ipadapter_info["scale"] = first.get("scale", 1.0) + self.ipadapter_info["weight_type"] = first.get("weight_type", "linear") # Store required model paths - self.ipadapter_info["ipadapter_model_path"] = first.get('ipadapter_model_path') - self.ipadapter_info["image_encoder_path"] = first.get('image_encoder_path') - self.ipadapter_info["type"] = first.get('type', 'regular') - self.ipadapter_info["insightface_model_name"] = first.get('insightface_model_name') - if first.get('style_image'): + self.ipadapter_info["ipadapter_model_path"] = first.get("ipadapter_model_path") + self.ipadapter_info["image_encoder_path"] = first.get("image_encoder_path") + self.ipadapter_info["type"] = first.get("type", "regular") + self.ipadapter_info["insightface_model_name"] = first.get("insightface_model_name") + if first.get("style_image"): self.ipadapter_info["has_style_image"] = True - self.ipadapter_info["style_image_path"] = first['style_image'] + self.ipadapter_info["style_image_path"] = first["style_image"] else: self.ipadapter_info = {"enabled": False, "has_style_image": False, "scale": 1.0, "weight_type": "linear"} - + # Pipeline hooks configuration for hook_type in self.pipeline_hooks.keys(): if hook_type in config_data: hook_config = config_data[hook_type] if isinstance(hook_config, dict): - self.pipeline_hooks[hook_type] = { - "enabled": hook_config.get("enabled", False), - "processors": [] - } + self.pipeline_hooks[hook_type] = {"enabled": hook_config.get("enabled", False), "processors": []} # Process processors with proper indexing for index, processor in enumerate(hook_config.get("processors", [])): if isinstance(processor, dict): @@ -223,74 +196,74 @@ def populate_from_config(self, config_data): "type": processor.get("type", "unknown"), "enabled": processor.get("enabled", False), "order": processor.get("order", index + 1), - "params": processor.get("params", {}) + "params": processor.get("params", {}), } self.pipeline_hooks[hook_type]["processors"].append(processed_processor) else: self.pipeline_hooks[hook_type] = {"enabled": False, "processors": []} - + # Blending configurations self.prompt_blending = self._normalize_prompt_config(config_data) self.seed_blending = self._normalize_seed_config(config_data) - + logger.info("populate_from_config: AppState populated successfully from config") def _normalize_prompt_config(self, config_data): """Normalize prompt configuration to always return a list format""" if not config_data: return None - + # Check for explicit prompt_blending first - if 'prompt_blending' in config_data: - prompt_blending = config_data['prompt_blending'] - if isinstance(prompt_blending, dict) and 'prompt_list' in prompt_blending: - prompt_list = prompt_blending['prompt_list'] + if "prompt_blending" in config_data: + prompt_blending = config_data["prompt_blending"] + if isinstance(prompt_blending, dict) and "prompt_list" in prompt_blending: + prompt_list = prompt_blending["prompt_list"] if isinstance(prompt_list, list) and len(prompt_list) > 0: return prompt_list elif isinstance(prompt_blending, list) and len(prompt_blending) > 0: return prompt_blending - - # Check for direct prompt_list key - if 'prompt_list' in config_data: - prompt_list = config_data['prompt_list'] + + # Check for direct prompt_list key + if "prompt_list" in config_data: + prompt_list = config_data["prompt_list"] if isinstance(prompt_list, list) and len(prompt_list) > 0: return prompt_list - + # Check for simple prompt key and convert to list format - if 'prompt' in config_data: - prompt = config_data['prompt'] + if "prompt" in config_data: + prompt = config_data["prompt"] if prompt and isinstance(prompt, str): return [(prompt, 1.0)] - + return None def _normalize_seed_config(self, config_data): """Normalize seed configuration to always return a list format""" if not config_data: return None - + # Check for explicit seed_blending first - if 'seed_blending' in config_data: - seed_blending = config_data['seed_blending'] - if isinstance(seed_blending, dict) and 'seed_list' in seed_blending: - seed_list = seed_blending['seed_list'] + if "seed_blending" in config_data: + seed_blending = config_data["seed_blending"] + if isinstance(seed_blending, dict) and "seed_list" in seed_blending: + seed_list = seed_blending["seed_list"] if isinstance(seed_list, list) and len(seed_list) > 0: return seed_list elif isinstance(seed_blending, list) and len(seed_blending) > 0: return seed_blending - - # Check for direct seed_list key - if 'seed_list' in config_data: - seed_list = config_data['seed_list'] + + # Check for direct seed_list key + if "seed_list" in config_data: + seed_list = config_data["seed_list"] if isinstance(seed_list, list) and len(seed_list) > 0: return seed_list - + # Check for simple seed key and convert to list format - if 'seed' in config_data: - seed = config_data['seed'] + if "seed" in config_data: + seed = config_data["seed"] if seed is not None and isinstance(seed, (int, float)): return [(int(seed), 1.0)] - + return None def get_complete_state(self): @@ -299,13 +272,10 @@ def get_complete_state(self): # Pipeline state "pipeline_active": self.pipeline_active, "pipeline_lifecycle": self.pipeline_lifecycle, - # Configuration "config_needs_reload": self.config_needs_reload, - # Resolution "current_resolution": self.current_resolution, - # Parameters "pipeline_params": self.pipeline_params, "controlnet": self.controlnet_info, @@ -314,7 +284,6 @@ def get_complete_state(self): "seed_blending": self.seed_blending, "normalize_prompt_weights": self.normalize_prompt_weights, "normalize_seed_weights": self.normalize_seed_weights, - # Core parameters "guidance_scale": self.guidance_scale, "delta": self.delta, @@ -323,27 +292,23 @@ def get_complete_state(self): "t_index_list": self.t_index_list, "negative_prompt": self.negative_prompt, "skip_diffusion": self.skip_diffusion, - # UI state "fps": self.fps, "queue_size": self.queue_size, "model_id": self.model_id, "page_content": self.page_content, - # Input sources "input_sources": self.input_sources, - # Debug mode state "debug_mode": self.debug_mode, "debug_pending_frame": self.debug_pending_frame, - # Pipeline hooks - AUTHORITATIVE SOURCE "image_preprocessing": self.pipeline_hooks["image_preprocessing"], "image_postprocessing": self.pipeline_hooks["image_postprocessing"], "latent_preprocessing": self.pipeline_hooks["latent_preprocessing"], "latent_postprocessing": self.pipeline_hooks["latent_postprocessing"], } - + def update_controlnet_strength(self, index: int, strength: float): """Update ControlNet strength in AppState - SINGLE SOURCE OF TRUTH""" if index < len(self.controlnet_info["controlnets"]): @@ -352,32 +317,32 @@ def update_controlnet_strength(self, index: int, strength: float): logger.debug(f"update_controlnet_strength: Updated ControlNet {index} strength to {strength}") else: logger.warning(f"update_controlnet_strength: ControlNet index {index} out of range") - + def add_controlnet(self, controlnet_config: dict): """Add ControlNet to AppState - SINGLE SOURCE OF TRUTH""" index = len(self.controlnet_info["controlnets"]) processed = dict(controlnet_config) - processed['index'] = index - processed['name'] = controlnet_config.get('model_id', '') - processed['strength'] = controlnet_config.get('conditioning_scale', 0.0) - + processed["index"] = index + processed["name"] = controlnet_config.get("model_id", "") + processed["strength"] = controlnet_config.get("conditioning_scale", 0.0) + self.controlnet_info["controlnets"].append(processed) self.controlnet_info["enabled"] = True logger.debug(f"add_controlnet: Added ControlNet at index {index}") - + def remove_controlnet(self, index: int): """Remove ControlNet from AppState - SINGLE SOURCE OF TRUTH""" if index < len(self.controlnet_info["controlnets"]): removed = self.controlnet_info["controlnets"].pop(index) # Re-index remaining controlnets for i, controlnet in enumerate(self.controlnet_info["controlnets"]): - controlnet['index'] = i + controlnet["index"] = i if not self.controlnet_info["controlnets"]: self.controlnet_info["enabled"] = False logger.debug(f"remove_controlnet: Removed ControlNet at index {index}") else: logger.warning(f"remove_controlnet: ControlNet index {index} out of range") - + def update_hook_processor(self, hook_type: str, processor_index: int, updates: dict): """Update pipeline hook processor in AppState - SINGLE SOURCE OF TRUTH""" if hook_type in self.pipeline_hooks: @@ -386,10 +351,12 @@ def update_hook_processor(self, hook_type: str, processor_index: int, updates: d processors[processor_index].update(updates) logger.debug(f"update_hook_processor: Updated {hook_type} processor {processor_index}") else: - logger.warning(f"update_hook_processor: Processor index {processor_index} out of range for {hook_type}") + logger.warning( + f"update_hook_processor: Processor index {processor_index} out of range for {hook_type}" + ) else: logger.warning(f"update_hook_processor: Unknown hook type {hook_type}") - + def add_hook_processor(self, hook_type: str, processor_config: dict): """Add pipeline hook processor to AppState - SINGLE SOURCE OF TRUTH""" if hook_type in self.pipeline_hooks: @@ -400,14 +367,14 @@ def add_hook_processor(self, hook_type: str, processor_config: dict): "type": processor_config.get("type", "unknown"), "enabled": processor_config.get("enabled", True), "order": processor_config.get("order", index + 1), - "params": processor_config.get("params", {}) + "params": processor_config.get("params", {}), } self.pipeline_hooks[hook_type]["processors"].append(processed) self.pipeline_hooks[hook_type]["enabled"] = True logger.debug(f"add_hook_processor: Added {hook_type} processor at index {index}") else: logger.warning(f"add_hook_processor: Unknown hook type {hook_type}") - + def remove_hook_processor(self, hook_type: str, processor_index: int): """Remove pipeline hook processor from AppState - SINGLE SOURCE OF TRUTH""" if hook_type in self.pipeline_hooks: @@ -416,191 +383,204 @@ def remove_hook_processor(self, hook_type: str, processor_index: int): removed = processors.pop(processor_index) # Re-index remaining processors for i, processor in enumerate(processors): - processor['index'] = i + processor["index"] = i if not processors: self.pipeline_hooks[hook_type]["enabled"] = False logger.debug(f"remove_hook_processor: Removed {hook_type} processor at index {processor_index}") else: - logger.warning(f"remove_hook_processor: Processor index {processor_index} out of range for {hook_type}") + logger.warning( + f"remove_hook_processor: Processor index {processor_index} out of range for {hook_type}" + ) else: logger.warning(f"remove_hook_processor: Unknown hook type {hook_type}") def update_parameter(self, parameter_name: str, value: float): """Update a single parameter in AppState - UNIFIED PARAMETER UPDATE""" logger.debug(f"update_parameter: Updating {parameter_name} = {value}") - + # Core pipeline parameters - if parameter_name == 'guidance_scale': + if parameter_name == "guidance_scale": self.guidance_scale = float(value) - elif parameter_name == 'delta': + elif parameter_name == "delta": self.delta = float(value) - elif parameter_name == 'num_inference_steps': + elif parameter_name == "num_inference_steps": self.num_inference_steps = int(value) - elif parameter_name == 'seed': + elif parameter_name == "seed": self.seed = int(value) - elif parameter_name == 'negative_prompt': + elif parameter_name == "negative_prompt": self.negative_prompt = str(value) - elif parameter_name == 'skip_diffusion': + elif parameter_name == "skip_diffusion": self.skip_diffusion = bool(value) - elif parameter_name == 't_index_list': + elif parameter_name == "t_index_list": if isinstance(value, list): self.t_index_list = value else: logger.warning(f"update_parameter: t_index_list must be a list, got {type(value)}") - + # IPAdapter parameters - elif parameter_name == 'ipadapter_scale': + elif parameter_name == "ipadapter_scale": self.ipadapter_info["scale"] = float(value) - elif parameter_name == 'ipadapter_weight_type': + elif parameter_name == "ipadapter_weight_type": # Convert numeric value to weight type string - weight_types = ["linear", "ease in", "ease out", "ease in-out", "reverse in-out", - "weak input", "weak output", "weak middle", "strong middle", - "style transfer", "composition", "strong style transfer", - "style and composition", "style transfer precise", "composition precise"] + weight_types = [ + "linear", + "ease in", + "ease out", + "ease in-out", + "reverse in-out", + "weak input", + "weak output", + "weak middle", + "strong middle", + "style transfer", + "composition", + "strong style transfer", + "style and composition", + "style transfer precise", + "composition precise", + ] index = int(value) % len(weight_types) self.ipadapter_info["weight_type"] = weight_types[index] - + # ControlNet strength parameters - elif parameter_name.startswith('controlnet_') and parameter_name.endswith('_strength'): + elif parameter_name.startswith("controlnet_") and parameter_name.endswith("_strength"): import re - match = re.match(r'controlnet_(\d+)_strength', parameter_name) + + match = re.match(r"controlnet_(\d+)_strength", parameter_name) if match: index = int(match.group(1)) self.update_controlnet_strength(index, float(value)) - + # ControlNet preprocessor parameters - elif parameter_name.startswith('controlnet_') and '_preprocessor_' in parameter_name: + elif parameter_name.startswith("controlnet_") and "_preprocessor_" in parameter_name: import re - match = re.match(r'controlnet_(\d+)_preprocessor_(.+)', parameter_name) + + match = re.match(r"controlnet_(\d+)_preprocessor_(.+)", parameter_name) if match: controlnet_index = int(match.group(1)) param_name = match.group(2) if controlnet_index < len(self.controlnet_info["controlnets"]): controlnet = self.controlnet_info["controlnets"][controlnet_index] - if 'preprocessor_params' not in controlnet: - controlnet['preprocessor_params'] = {} - controlnet['preprocessor_params'][param_name] = value - + if "preprocessor_params" not in controlnet: + controlnet["preprocessor_params"] = {} + controlnet["preprocessor_params"][param_name] = value + # Prompt blending weights - elif parameter_name.startswith('prompt_weight_'): + elif parameter_name.startswith("prompt_weight_"): import re - match = re.match(r'prompt_weight_(\d+)', parameter_name) + + match = re.match(r"prompt_weight_(\d+)", parameter_name) if match: index = int(match.group(1)) if self.prompt_blending and index < len(self.prompt_blending): # Update weight in prompt blending list prompt_text = self.prompt_blending[index][0] self.prompt_blending[index] = (prompt_text, float(value)) - + # Seed blending weights - elif parameter_name.startswith('seed_weight_'): + elif parameter_name.startswith("seed_weight_"): import re - match = re.match(r'seed_weight_(\d+)', parameter_name) + + match = re.match(r"seed_weight_(\d+)", parameter_name) if match: index = int(match.group(1)) if self.seed_blending and index < len(self.seed_blending): # Update weight in seed blending list seed_value = self.seed_blending[index][0] self.seed_blending[index] = (seed_value, float(value)) - + else: logger.warning(f"update_parameter: Unknown parameter {parameter_name}") return - + logger.debug(f"update_parameter: Successfully updated {parameter_name} in AppState") def generate_pipeline_config(self): """Generate pipeline configuration from AppState - PRESERVES ALL ORIGINAL CONFIG""" - logger.info("generate_pipeline_config: Generating pipeline config from AppState, preserving all original config") - + logger.info( + "generate_pipeline_config: Generating pipeline config from AppState, preserving all original config" + ) + # Start with complete original config to preserve ALL parameters config = {} if self.uploaded_config: config = dict(self.uploaded_config) - + # Only override runtime-changeable parameters from AppState - config.update({ - 'guidance_scale': self.guidance_scale, - 'delta': self.delta, - 'num_inference_steps': self.num_inference_steps, - 'seed': self.seed, - 't_index_list': self.t_index_list, - 'negative_prompt': self.negative_prompt, - 'skip_diffusion': self.skip_diffusion, - 'width': self.current_resolution["width"], - 'height': self.current_resolution["height"], - 'output_type': 'pt', # Force optimal tensor performance - }) - + config.update( + { + "guidance_scale": self.guidance_scale, + "delta": self.delta, + "num_inference_steps": self.num_inference_steps, + "seed": self.seed, + "t_index_list": self.t_index_list, + "negative_prompt": self.negative_prompt, + "skip_diffusion": self.skip_diffusion, + "width": self.current_resolution["width"], + "height": self.current_resolution["height"], + "output_type": "pt", # Force optimal tensor performance + } + ) + # Update ControlNet configurations with current AppState values if self.controlnet_info["enabled"] and self.controlnet_info["controlnets"]: - config['controlnets'] = [] + config["controlnets"] = [] for controlnet in self.controlnet_info["controlnets"]: cn_config = dict(controlnet) # Ensure conditioning_scale reflects current strength - cn_config['conditioning_scale'] = controlnet.get('strength', controlnet.get('conditioning_scale', 0.0)) - config['controlnets'].append(cn_config) - elif 'controlnets' in config: + cn_config["conditioning_scale"] = controlnet.get("strength", controlnet.get("conditioning_scale", 0.0)) + config["controlnets"].append(cn_config) + elif "controlnets" in config: # Remove controlnets if disabled - del config['controlnets'] - + del config["controlnets"] + # Update IPAdapter configurations with current AppState values if self.ipadapter_info["enabled"]: - config['use_ipadapter'] = True + config["use_ipadapter"] = True # Preserve original ipadapters config but update runtime values - if 'ipadapters' in config and config['ipadapters']: + if "ipadapters" in config and config["ipadapters"]: # Update existing config with current values - config['ipadapters'][0].update({ - 'scale': self.ipadapter_info["scale"], - 'weight_type': self.ipadapter_info["weight_type"] - }) + config["ipadapters"][0].update( + {"scale": self.ipadapter_info["scale"], "weight_type": self.ipadapter_info["weight_type"]} + ) # Add style image if available if self.ipadapter_info.get("has_style_image") and self.ipadapter_info.get("style_image_path"): - config['ipadapters'][0]['style_image'] = self.ipadapter_info["style_image_path"] - elif 'use_ipadapter' in config: + config["ipadapters"][0]["style_image"] = self.ipadapter_info["style_image_path"] + elif "use_ipadapter" in config: # Disable IPAdapter if not enabled in AppState - config['use_ipadapter'] = False - + config["use_ipadapter"] = False + # Update pipeline hooks with current AppState values for hook_type, hook_config in self.pipeline_hooks.items(): if hook_config["enabled"] and hook_config["processors"]: - config[hook_type] = { - "enabled": True, - "processors": [] - } + config[hook_type] = {"enabled": True, "processors": []} for processor in hook_config["processors"]: proc_config = { "type": processor["type"], "enabled": processor["enabled"], "order": processor["order"], - "params": processor["params"] + "params": processor["params"], } config[hook_type]["processors"].append(proc_config) elif hook_type in config: # Disable hook if not enabled in AppState config[hook_type] = {"enabled": False, "processors": []} - + # Update blending configurations with current AppState values if self.prompt_blending: - config['prompt_blending'] = { - 'prompt_list': self.prompt_blending, - 'interpolation_method': 'slerp' - } - config['normalize_weights'] = self.normalize_prompt_weights - elif 'prompt_blending' in config: - del config['prompt_blending'] - + config["prompt_blending"] = {"prompt_list": self.prompt_blending, "interpolation_method": "slerp"} + config["normalize_weights"] = self.normalize_prompt_weights + elif "prompt_blending" in config: + del config["prompt_blending"] + if self.seed_blending: - config['seed_blending'] = { - 'seed_list': self.seed_blending, - 'interpolation_method': 'linear' - } + config["seed_blending"] = {"seed_list": self.seed_blending, "interpolation_method": "linear"} # Note: seed normalization uses same normalize_weights key if not self.prompt_blending: # Only set if not already set by prompt blending - config['normalize_weights'] = self.normalize_seed_weights - elif 'seed_blending' in config: - del config['seed_blending'] - + config["normalize_weights"] = self.normalize_seed_weights + elif "seed_blending" in config: + del config["seed_blending"] + logger.info("generate_pipeline_config: Generated pipeline config preserving all original parameters") return config @@ -612,7 +592,6 @@ def update_state(self, updates): logger.debug(f"AppState update_state: Updated {key} = {value}") else: logger.warning(f"AppState update_state: Unknown state key: {key}") - class App: @@ -623,46 +602,50 @@ def __init__(self, config: Args): self.conn_manager = ConnectionManager() self.fps_counter = [] self.last_fps_update = time.time() - + # Centralized state management self.app_state = AppState() - + # Initialize input manager for controller support self.input_manager = InputManager() # Initialize input source manager for modular input routing from input_sources import InputSourceManager + self.input_source_manager = InputSourceManager() - + # Preemptively initialize input sources to avoid config upload delay self._preload_input_sources() - + self.init_app() def _preload_input_sources(self): """Preemptively initialize input sources and preprocessors to avoid delays during config upload""" try: - logger.info("_preload_input_sources: Preemptively initializing input sources and preprocessors to avoid config upload delay") - + logger.info( + "_preload_input_sources: Preemptively initializing input sources and preprocessors to avoid config upload delay" + ) + # Preload base input source - self.input_source_manager.get_source_info('base') - + self.input_source_manager.get_source_info("base") + # Preload IPAdapter input source - self.input_source_manager.get_source_info('ipadapter') - + self.input_source_manager.get_source_info("ipadapter") + # Preload potential ControlNet input sources (up to 5) for i in range(5): - self.input_source_manager.get_source_info('controlnet', index=i) - + self.input_source_manager.get_source_info("controlnet", index=i) + # Preload preprocessors to trigger controlnet_aux imports # This is what causes the delay - the first time a preprocessor is accessed, # all the controlnet_aux modules get imported logger.info("_preload_input_sources: Triggering preprocessor imports...") try: - from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class + from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors + # List all available preprocessors - this triggers the lazy imports available = list_preprocessors() logger.info(f"_preload_input_sources: Found {len(available)} preprocessors, loading metadata...") - + # Access at least one preprocessor class to ensure all imports complete if available: for processor_name in available[:3]: # Load first 3 to trigger most imports @@ -670,15 +653,15 @@ def _preload_input_sources(self): _ = get_preprocessor_class(processor_name) except Exception as e: logger.debug(f"_preload_input_sources: Could not load {processor_name}: {e}") - + logger.info("_preload_input_sources: Preprocessor imports completed") except Exception as prep_error: logger.warning(f"_preload_input_sources: Could not preload preprocessors: {prep_error}") - + logger.info("_preload_input_sources: Input sources and preprocessors preloaded successfully") except Exception as e: logger.error(f"_preload_input_sources: Error during preload: {e}") - + def cleanup(self): """Cleanup resources when app is shutting down""" logger.info("App cleanup: Starting application cleanup...") @@ -687,7 +670,7 @@ def cleanup(self): self._cleanup_pipeline(self.pipeline) self.pipeline = None self.app_state.pipeline_lifecycle = "stopped" - if hasattr(self, 'input_source_manager'): + if hasattr(self, "input_source_manager"): self.input_source_manager.cleanup() self._cleanup_temp_files() logger.info("App cleanup: Completed application cleanup") @@ -696,15 +679,17 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N """Handle parameter updates from input controls - UNIFIED THROUGH APPSTATE""" try: logger.debug(f"_handle_input_parameter_update: Updating {parameter_name} = {value} via AppState") - + # Update AppState as single source of truth self.app_state.update_parameter(parameter_name, value) - + # Sync to pipeline if active (for real-time updates) - if self.pipeline and hasattr(self.pipeline, 'stream'): + if self.pipeline and hasattr(self.pipeline, "stream"): self._sync_appstate_to_pipeline() else: - logger.debug(f"_handle_input_parameter_update: No active pipeline, parameter stored in AppState for next pipeline creation") + logger.debug( + "_handle_input_parameter_update: No active pipeline, parameter stored in AppState for next pipeline creation" + ) except Exception as e: logger.exception(f"_handle_input_parameter_update: Failed to update {parameter_name}: {e}") @@ -712,36 +697,36 @@ def _handle_input_parameter_update(self, parameter_name: str, value: float) -> N def _update_resolution(self, width: int, height: int) -> None: """Update resolution by recreating pipeline with new dimensions""" logger.info(f"_update_resolution: Updating resolution to {width}x{height}") - + # Update AppState first self.app_state.current_resolution = {"width": width, "height": height} - + # If no pipeline exists, just update state (will be used when pipeline is created) if not self.pipeline: logger.info("_update_resolution: No pipeline exists, resolution will apply on next pipeline creation") return - + # Set pipeline lifecycle state self.app_state.pipeline_lifecycle = "restarting" - + # Store reference to old pipeline for cleanup old_pipeline = self.pipeline - + # Clear current pipeline reference before cleanup self.pipeline = None - + # Cleanup old pipeline and free VRAM if old_pipeline: self._cleanup_pipeline(old_pipeline) old_pipeline = None - + # Create new pipeline with new resolution # No state restoration needed - _create_pipeline() uses AppState as single source of truth try: self.pipeline = self._create_pipeline() self.app_state.pipeline_lifecycle = "running" logger.info(f"_update_resolution: Pipeline successfully recreated with resolution {width}x{height}") - + except Exception as e: self.app_state.pipeline_lifecycle = "error" logger.error(f"_update_resolution: Failed to recreate pipeline: {e}") @@ -750,9 +735,9 @@ def _update_resolution(self, width: int, height: int) -> None: def _sync_appstate_to_pipeline(self): """Sync AppState parameters to active pipeline for real-time updates""" try: - if not self.pipeline or not hasattr(self.pipeline, 'stream'): + if not self.pipeline or not hasattr(self.pipeline, "stream"): return - + # Core parameters self.pipeline.update_stream_params( guidance_scale=self.app_state.guidance_scale, @@ -760,64 +745,57 @@ def _sync_appstate_to_pipeline(self): num_inference_steps=self.app_state.num_inference_steps, seed=self.app_state.seed, negative_prompt=self.app_state.negative_prompt, - t_index_list=self.app_state.t_index_list + t_index_list=self.app_state.t_index_list, ) - + # IPAdapter parameters if self.app_state.ipadapter_info["enabled"]: - self.pipeline.update_stream_params(ipadapter_config={ - 'scale': self.app_state.ipadapter_info["scale"] - }) - if hasattr(self.pipeline, 'update_ipadapter_weight_type'): + self.pipeline.update_stream_params(ipadapter_config={"scale": self.app_state.ipadapter_info["scale"]}) + if hasattr(self.pipeline, "update_ipadapter_weight_type"): self.pipeline.update_ipadapter_weight_type(self.app_state.ipadapter_info["weight_type"]) - + # ControlNet parameters if self.app_state.controlnet_info["enabled"] and self.app_state.controlnet_info["controlnets"]: controlnet_config = [] for cn in self.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] + config_entry["conditioning_scale"] = cn["strength"] controlnet_config.append(config_entry) self.pipeline.update_stream_params(controlnet_config=controlnet_config) - + # Prompt blending if self.app_state.prompt_blending: self.pipeline.update_stream_params(prompt_list=self.app_state.prompt_blending) - + # Seed blending if self.app_state.seed_blending: self.pipeline.update_stream_params(seed_list=self.app_state.seed_blending) - + logger.debug("_sync_appstate_to_pipeline: Successfully synced AppState to pipeline") - + except Exception as e: logger.exception(f"_sync_appstate_to_pipeline: Failed to sync AppState to pipeline: {e}") - - - - def _get_controlnet_pipeline(self): """Get the ControlNet pipeline from the main pipeline structure""" if not self.pipeline: return None - + stream = self.pipeline.stream - + # Module-aware: module installs expose preprocessors on stream - if hasattr(stream, 'preprocessors'): + if hasattr(stream, "preprocessors"): return stream - + # Check if stream has nested stream (IPAdapter wrapper) - if hasattr(stream, 'stream') and hasattr(stream.stream, 'preprocessors'): + if hasattr(stream, "stream") and hasattr(stream.stream, "preprocessors"): return stream.stream - + # New module path on stream - if hasattr(stream, '_controlnet_module'): + if hasattr(stream, "_controlnet_module"): return stream._controlnet_module return None - def init_app(self): # Enhanced CORS for API-only development mode if self.args.api_only: @@ -844,74 +822,95 @@ def init_app(self): # Register route modules self._register_routes() - + def _register_routes(self): """Register all route modules with dependency injection""" - from routes import parameters, controlnet, ipadapter, inference, pipeline_hooks, websocket, input_sources, debug - from routes.common.dependencies import get_app_instance as shared_get_app_instance, get_pipeline_class as shared_get_pipeline_class, get_default_settings as shared_get_default_settings, get_available_controlnets as shared_get_available_controlnets - + from routes import ( + controlnet, + debug, + inference, + input_sources, + ipadapter, + parameters, + pipeline_hooks, + websocket, + ) + from routes.common.dependencies import get_app_instance as shared_get_app_instance + from routes.common.dependencies import get_available_controlnets as shared_get_available_controlnets + from routes.common.dependencies import get_default_settings as shared_get_default_settings + from routes.common.dependencies import get_pipeline_class as shared_get_pipeline_class + # Create dependency overrides to inject app instance and other dependencies def get_app_instance(): return self - + def get_pipeline_class(): return Pipeline - + def get_default_settings(): return DEFAULT_SETTINGS - + def get_available_controlnets(): return AVAILABLE_CONTROLNETS - + # Include routers and set up dependency overrides on the main app - for router_module in [parameters, controlnet, ipadapter, inference, pipeline_hooks, websocket, input_sources, debug]: + for router_module in [ + parameters, + controlnet, + ipadapter, + inference, + pipeline_hooks, + websocket, + input_sources, + debug, + ]: # Include the router self.app.include_router(router_module.router) - + # Set up dependency overrides on the main app (not individual routers) self.app.dependency_overrides[shared_get_app_instance] = get_app_instance self.app.dependency_overrides[shared_get_pipeline_class] = get_pipeline_class self.app.dependency_overrides[shared_get_default_settings] = get_default_settings self.app.dependency_overrides[shared_get_available_controlnets] = get_available_controlnets - + # Set up static files if not in API-only mode if not self.args.api_only: self.app.mount("/", StaticFiles(directory="frontend/public", html=True), name="public") - def _create_pipeline(self): """Create pipeline using AppState as single source of truth""" logger.info("_create_pipeline: Creating pipeline using AppState as single source of truth") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch_dtype = torch.float16 - + # Generate pipeline config from AppState - SINGLE SOURCE OF TRUTH pipeline_config = self.app_state.generate_pipeline_config() - + # Load config style images into InputSourceManager before creating pipeline self._load_config_style_images() - + # Create wrapper using the unified config - THIS IS NOW THE SINGLE PLACE WHERE WRAPPER IS CREATED from src.streamdiffusion.config import create_wrapper_from_config - + # Create wrapper using the unified config wrapper = create_wrapper_from_config(pipeline_config) - + # Update args with config values before passing to Pipeline from config import Args + args_dict = self.args._asdict() - if 'acceleration' in pipeline_config: - args_dict['acceleration'] = pipeline_config['acceleration'] - if 'engine_dir' in pipeline_config: - args_dict['engine_dir'] = pipeline_config['engine_dir'] - if 'use_safety_checker' in pipeline_config: - args_dict['safety_checker'] = pipeline_config['use_safety_checker'] - + if "acceleration" in pipeline_config: + args_dict["acceleration"] = pipeline_config["acceleration"] + if "engine_dir" in pipeline_config: + args_dict["engine_dir"] = pipeline_config["engine_dir"] + if "use_safety_checker" in pipeline_config: + args_dict["safety_checker"] = pipeline_config["use_safety_checker"] + updated_args = Args(**args_dict) - + # Create Pipeline instance with the pre-created wrapper and config pipeline = Pipeline(wrapper=wrapper, config=pipeline_config) - + logger.info("_create_pipeline: Pipeline created successfully with pre-created wrapper") return pipeline @@ -919,24 +918,25 @@ def _load_config_style_images(self): """Load style images from config into InputSourceManager""" if not self.app_state.uploaded_config: return - + try: # Load IPAdapter style images from config - ipadapters = self.app_state.uploaded_config.get('ipadapters', []) + ipadapters = self.app_state.uploaded_config.get("ipadapters", []) if ipadapters: first_ipadapter = ipadapters[0] - style_image_path = first_ipadapter.get('style_image') + style_image_path = first_ipadapter.get("style_image") if style_image_path: # Use the config file path as base for relative paths - base_config_path = getattr(self.args, 'controlnet_config', None) + base_config_path = getattr(self.args, "controlnet_config", None) self.input_source_manager.load_config_style_image(style_image_path, base_config_path) except Exception as e: logging.exception(f"_load_config_style_images: Error loading config style images: {e}") def _cleanup_temp_files(self): """Clean up any temporary config files""" - if hasattr(self, '_temp_config_files'): + if hasattr(self, "_temp_config_files"): import os + for temp_path in self._temp_config_files: try: if os.path.exists(temp_path): @@ -945,26 +945,24 @@ def _cleanup_temp_files(self): pass self._temp_config_files.clear() - def _calculate_aspect_ratio(self, width: int, height: int) -> str: """Calculate and return aspect ratio as a string""" - import math - + def gcd(a, b): while b: a, b = b, a % b return a - + ratio_gcd = gcd(width, height) - return f"{width//ratio_gcd}:{height//ratio_gcd}" + return f"{width // ratio_gcd}:{height // ratio_gcd}" def _cleanup_pipeline(self, pipeline): """Properly cleanup a pipeline and free VRAM""" if pipeline is None: return - + try: - if hasattr(pipeline, 'cleanup'): + if hasattr(pipeline, "cleanup"): pipeline.cleanup() del pipeline torch.cuda.empty_cache() @@ -984,4 +982,4 @@ def _cleanup_pipeline(self, pipeline): reload=config.reload, ssl_certfile=config.ssl_certfile, ssl_keyfile=config.ssl_keyfile, - ) \ No newline at end of file + ) diff --git a/demo/realtime-img2img/routes/__init__.py b/demo/realtime-img2img/routes/__init__.py index 713519c91..cd671cfb6 100644 --- a/demo/realtime-img2img/routes/__init__.py +++ b/demo/realtime-img2img/routes/__init__.py @@ -1,4 +1,3 @@ """ Routes package for realtime-img2img API endpoints """ - diff --git a/demo/realtime-img2img/routes/common/api_utils.py b/demo/realtime-img2img/routes/common/api_utils.py index a1ade9fff..0a5425774 100644 --- a/demo/realtime-img2img/routes/common/api_utils.py +++ b/demo/realtime-img2img/routes/common/api_utils.py @@ -3,46 +3,43 @@ """ import logging +from typing import Any, Dict, Optional + from fastapi import HTTPException, Request from fastapi.responses import JSONResponse -from typing import Any, Dict, Optional async def handle_api_request( - request: Request, - operation_name: str, - required_params: list = None, - pipeline_required: bool = True + request: Request, operation_name: str, required_params: list = None, pipeline_required: bool = True ) -> Dict[str, Any]: """ Standard request handler for API endpoints - + Args: request: FastAPI request object operation_name: Name of the operation for logging required_params: List of required parameter names pipeline_required: Whether an active pipeline is required - + Returns: Parsed JSON data from request - + Raises: HTTPException: For validation errors """ try: data = await request.json() - + # Check required parameters if required_params: missing_params = [param for param in required_params if param not in data] if missing_params: raise HTTPException( - status_code=400, - detail=f"Missing required parameters: {', '.join(missing_params)}" + status_code=400, detail=f"Missing required parameters: {', '.join(missing_params)}" ) - + return data - + except Exception as e: logging.exception(f"{operation_name}: Failed to parse request: {e}") raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}") @@ -51,18 +48,15 @@ async def handle_api_request( def create_success_response(message: str, **extra_data) -> JSONResponse: """ Create a standardized success response - + Args: message: Success message **extra_data: Additional data to include in response - + Returns: JSONResponse with standardized format """ - response_data = { - "status": "success", - "message": message - } + response_data = {"status": "success", "message": message} response_data.update(extra_data) return JSONResponse(response_data) @@ -70,84 +64,71 @@ def create_success_response(message: str, **extra_data) -> JSONResponse: def handle_api_error(error: Exception, operation_name: str, status_code: int = 500) -> HTTPException: """ Standard error handler for API endpoints - + Args: error: The caught exception operation_name: Name of the operation for logging status_code: HTTP status code to return - + Returns: HTTPException with standardized error message """ logging.error(f"{operation_name}: Failed: {error}") - return HTTPException( - status_code=status_code, - detail=f"Failed to {operation_name.lower()}: {str(error)}" - ) + return HTTPException(status_code=status_code, detail=f"Failed to {operation_name.lower()}: {str(error)}") def validate_pipeline(pipeline: Any, operation_name: str) -> None: """ Validate that pipeline exists and is initialized - + Args: pipeline: Pipeline object to validate operation_name: Name of the operation for error messages - + Raises: HTTPException: If pipeline is not valid """ logging.info(f"validate_pipeline: {operation_name} - pipeline is: {pipeline is not None}") if not pipeline: logging.error(f"validate_pipeline: {operation_name} - Pipeline is not initialized") - raise HTTPException( - status_code=400, - detail="Pipeline is not initialized" - ) + raise HTTPException(status_code=400, detail="Pipeline is not initialized") def validate_feature_enabled(pipeline: Any, feature_name: str, feature_check_attr: str) -> None: """ Validate that a specific feature is enabled in the pipeline - + Args: pipeline: Pipeline object feature_name: Human-readable feature name (e.g., "ControlNet", "IPAdapter") feature_check_attr: Attribute name to check (e.g., "has_controlnet", "has_ipadapter") - + Raises: HTTPException: If feature is not enabled """ if not getattr(pipeline, feature_check_attr, False): - raise HTTPException( - status_code=400, - detail=f"{feature_name} is not enabled" - ) + raise HTTPException(status_code=400, detail=f"{feature_name} is not enabled") def validate_config_mode(pipeline: Any, config_check: Optional[str] = None) -> None: """ Validate that pipeline is using config mode - + Args: pipeline: Pipeline object config_check: Optional specific config key to check for existence - + Raises: HTTPException: If not in config mode or config key missing """ - logging.info(f"validate_config_mode: use_config={getattr(pipeline, 'use_config', None)}, config exists={getattr(pipeline, 'config', None) is not None}") + logging.info( + f"validate_config_mode: use_config={getattr(pipeline, 'use_config', None)}, config exists={getattr(pipeline, 'config', None) is not None}" + ) if not (pipeline.use_config and pipeline.config): - logging.error(f"validate_config_mode: Pipeline is not using configuration mode") - raise HTTPException( - status_code=400, - detail="Pipeline is not using configuration mode" - ) - + logging.error("validate_config_mode: Pipeline is not using configuration mode") + raise HTTPException(status_code=400, detail="Pipeline is not using configuration mode") + if config_check and config_check not in pipeline.config: logging.error(f"validate_config_mode: Configuration key '{config_check}' not found in pipeline config") logging.info(f"validate_config_mode: Available config keys: {list(pipeline.config.keys())}") - raise HTTPException( - status_code=400, - detail=f"Configuration missing required section: {config_check}" - ) \ No newline at end of file + raise HTTPException(status_code=400, detail=f"Configuration missing required section: {config_check}") diff --git a/demo/realtime-img2img/routes/controlnet.py b/demo/realtime-img2img/routes/controlnet.py index e25cbcbfc..9571d613f 100644 --- a/demo/realtime-img2img/routes/controlnet.py +++ b/demo/realtime-img2img/routes/controlnet.py @@ -1,19 +1,24 @@ """ ControlNet-related endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends, UploadFile, File -from fastapi.responses import JSONResponse + +import copy import logging + import yaml -import tempfile -from pathlib import Path -import copy +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile +from fastapi.responses import JSONResponse -from .common.api_utils import handle_api_request, create_success_response, handle_api_error, validate_feature_enabled, validate_config_mode +from .common.api_utils import ( + create_success_response, + handle_api_error, +) from .common.dependencies import get_app_instance, get_available_controlnets + router = APIRouter(prefix="/api", tags=["controlnet"]) + def _ensure_runtime_controlnet_config(app_instance): """Ensure runtime controlnet config is initialized from uploaded config or create minimal config""" if app_instance.app_state.runtime_config is None: @@ -22,45 +27,45 @@ def _ensure_runtime_controlnet_config(app_instance): app_instance.app_state.runtime_config = copy.deepcopy(app_instance.app_state.uploaded_config) else: # Create minimal config if no YAML exists - app_instance.app_state.runtime_config = {'controlnets': []} - + app_instance.app_state.runtime_config = {"controlnets": []} + # Ensure controlnets key exists in runtime config - if 'controlnets' not in app_instance.app_state.runtime_config: - app_instance.app_state.runtime_config['controlnets'] = [] + if "controlnets" not in app_instance.app_state.runtime_config: + app_instance.app_state.runtime_config["controlnets"] = [] @router.post("/controlnet/upload-config") async def upload_controlnet_config(file: UploadFile = File(...), app_instance=Depends(get_app_instance)): """Upload and load a new ControlNet YAML configuration""" try: - if not file.filename.endswith(('.yaml', '.yml')): + if not file.filename.endswith((".yaml", ".yml")): raise HTTPException(status_code=400, detail="File must be a YAML file") - + # Save uploaded file temporarily content = await file.read() - + # Parse YAML content try: - config_data = yaml.safe_load(content.decode('utf-8')) + config_data = yaml.safe_load(content.decode("utf-8")) except yaml.YAMLError as e: raise HTTPException(status_code=400, detail=f"Invalid YAML format: {str(e)}") - + # YAML is source of truth - completely replace any runtime modifications app_instance.app_state.uploaded_config = config_data app_instance.app_state.runtime_config = None app_instance.app_state.config_needs_reload = True - + # SINGLE SOURCE OF TRUTH: Populate AppState from config app_instance.app_state.populate_from_config(config_data) - + # RESET ALL INPUT SOURCES TO DEFAULTS WHEN NEW CONFIG IS UPLOADED - if hasattr(app_instance, 'input_source_manager'): + if hasattr(app_instance, "input_source_manager"): try: app_instance.input_source_manager.reset_to_defaults() logging.info("upload_controlnet_config: Reset all input sources to defaults") except Exception as e: logging.exception(f"upload_controlnet_config: Failed to reset input sources: {e}") - + # FORCE DESTROY ACTIVE PIPELINE TO MAKE CONFIG THE SOURCE OF TRUTH if app_instance.pipeline: logging.info("upload_controlnet_config: Destroying active pipeline to force config as source of truth") @@ -69,43 +74,43 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De app_instance.pipeline = None app_instance._cleanup_pipeline(old_pipeline) app_instance.app_state.pipeline_lifecycle = "stopped" - + # Get config prompt if available - config_prompt = config_data.get('prompt', None) + config_prompt = config_data.get("prompt", None) # Get negative prompt if available - config_negative_prompt = config_data.get('negative_prompt', None) - + config_negative_prompt = config_data.get("negative_prompt", None) + # Get t_index_list from config if available from app_config import DEFAULT_SETTINGS - t_index_list = config_data.get('t_index_list', DEFAULT_SETTINGS.get('t_index_list', [35, 45])) - + + t_index_list = config_data.get("t_index_list", DEFAULT_SETTINGS.get("t_index_list", [35, 45])) + # Get acceleration from config if available - config_acceleration = config_data.get('acceleration', app_instance.args.acceleration) - + config_acceleration = config_data.get("acceleration", app_instance.args.acceleration) + # Get width and height from config if available - config_width = config_data.get('width', None) - config_height = config_data.get('height', None) - + config_width = config_data.get("width", None) + config_height = config_data.get("height", None) + # Update resolution if width/height are specified in config if config_width is not None and config_height is not None: try: # Validate resolution if config_width % 64 != 0 or config_height % 64 != 0: raise HTTPException(status_code=400, detail="Resolution must be multiples of 64") - + if not (384 <= config_width <= 1024) or not (384 <= config_height <= 1024): raise HTTPException(status_code=400, detail="Resolution must be between 384 and 1024") - - app_instance.app_state.current_resolution = { - "width": int(config_width), - "height": int(config_height) - } - - logging.info(f"upload_controlnet_config: Updated resolution from config to {config_width}x{config_height}") - + + app_instance.app_state.current_resolution = {"width": int(config_width), "height": int(config_height)} + + logging.info( + f"upload_controlnet_config: Updated resolution from config to {config_width}x{config_height}" + ) + except (ValueError, TypeError): raise HTTPException(status_code=400, detail="Invalid width/height values in config") - + # Build current resolution string current_resolution = None if config_width and config_height: @@ -114,13 +119,13 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De aspect_ratio = app_instance._calculate_aspect_ratio(config_width, config_height) if aspect_ratio: current_resolution += f" ({aspect_ratio})" - + # Build config_values for other parameters that frontend may expect config_values = {} for key in [ - 'use_taesd', - 'cfg_type', - 'safety_checker', + "use_taesd", + "cfg_type", + "safety_checker", ]: if key in config_data: config_values[key] = config_data[key] @@ -155,46 +160,57 @@ async def upload_controlnet_config(file: UploadFile = File(...), app_instance=De "latent_preprocessing": app_instance.app_state.pipeline_hooks["latent_preprocessing"], "latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"], } - + return JSONResponse(response_data) - + except Exception as e: logging.exception(f"upload_controlnet_config: Failed to upload configuration: {e}") raise HTTPException(status_code=500, detail=f"Failed to upload configuration: {str(e)}") + @router.get("/controlnet/info") async def get_controlnet_info(app_instance=Depends(get_app_instance)): """Get current ControlNet configuration info - SINGLE SOURCE OF TRUTH""" return JSONResponse({"controlnet": app_instance.app_state.controlnet_info}) + @router.get("/blending/current") async def get_current_blending_config(app_instance=Depends(get_app_instance)): """Get current prompt and seed blending configurations""" try: - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream') and hasattr(app_instance.pipeline.stream, 'get_stream_state'): + if ( + app_instance.pipeline + and hasattr(app_instance.pipeline, "stream") + and hasattr(app_instance.pipeline.stream, "get_stream_state") + ): state = app_instance.pipeline.stream.get_stream_state(include_caches=False) - return JSONResponse({ - "prompt_blending": state.get("prompt_list", []), - "seed_blending": state.get("seed_list", []), - "normalize_prompt_weights": state.get("normalize_prompt_weights", True), - "normalize_seed_weights": state.get("normalize_seed_weights", True), - "has_config": app_instance.app_state.uploaded_config is not None, - "pipeline_active": True - }) + return JSONResponse( + { + "prompt_blending": state.get("prompt_list", []), + "seed_blending": state.get("seed_list", []), + "normalize_prompt_weights": state.get("normalize_prompt_weights", True), + "normalize_seed_weights": state.get("normalize_seed_weights", True), + "has_config": app_instance.app_state.uploaded_config is not None, + "pipeline_active": True, + } + ) # Fallback to AppState when pipeline not initialized - SINGLE SOURCE OF TRUTH - return JSONResponse({ - "prompt_blending": app_instance.app_state.prompt_blending, - "seed_blending": app_instance.app_state.seed_blending, - "normalize_prompt_weights": app_instance.app_state.normalize_prompt_weights, - "normalize_seed_weights": app_instance.app_state.normalize_seed_weights, - "has_config": app_instance.app_state.uploaded_config is not None, - "pipeline_active": False - }) - + return JSONResponse( + { + "prompt_blending": app_instance.app_state.prompt_blending, + "seed_blending": app_instance.app_state.seed_blending, + "normalize_prompt_weights": app_instance.app_state.normalize_prompt_weights, + "normalize_seed_weights": app_instance.app_state.normalize_seed_weights, + "has_config": app_instance.app_state.uploaded_config is not None, + "pipeline_active": False, + } + ) + except Exception as e: raise handle_api_error(e, "get_current_blending_config") + @router.post("/controlnet/update-strength") async def update_controlnet_strength(request: Request, app_instance=Depends(get_app_instance)): """Update ControlNet strength in real-time""" @@ -202,13 +218,13 @@ async def update_controlnet_strength(request: Request, app_instance=Depends(get_ data = await request.json() controlnet_index = data.get("index") strength = data.get("strength") - + if controlnet_index is None or strength is None: raise HTTPException(status_code=400, detail="Missing index or strength parameter") - + # Update ControlNet strength in AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.update_controlnet_strength(controlnet_index, float(strength)) - + # Update pipeline if active if app_instance.pipeline: try: @@ -216,114 +232,122 @@ async def update_controlnet_strength(request: Request, app_instance=Depends(get_ controlnet_config = [] for cn in app_instance.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale + config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale controlnet_config.append(config_entry) app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config) except Exception as e: logging.exception(f"update_controlnet_strength: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - + return create_success_response(f"Updated ControlNet {controlnet_index} strength to {strength}") - + except Exception as e: raise handle_api_error(e, "update_controlnet_strength") + @router.get("/controlnet/available") -async def get_available_controlnets_endpoint(app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)): +async def get_available_controlnets_endpoint( + app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets) +): """Get list of available ControlNets that can be added""" try: # Debug the dependency injection - + # Detect current model architecture to filter appropriate ControlNets model_type = "sd15" # Default fallback - + # Try to determine model type from pipeline config or uploaded config - if app_instance.pipeline and hasattr(app_instance.pipeline, 'config') and app_instance.pipeline.config: - model_id = app_instance.pipeline.config.get('model_id', '') - if 'sdxl' in model_id.lower() or 'xl' in model_id.lower(): + if app_instance.pipeline and hasattr(app_instance.pipeline, "config") and app_instance.pipeline.config: + model_id = app_instance.pipeline.config.get("model_id", "") + if "sdxl" in model_id.lower() or "xl" in model_id.lower(): model_type = "sdxl" elif app_instance.app_state.uploaded_config: # If no pipeline yet, try to get model type from uploaded config - model_id = app_instance.app_state.uploaded_config.get('model_id_or_path', '') - if 'sdxl' in model_id.lower() or 'xl' in model_id.lower(): + model_id = app_instance.app_state.uploaded_config.get("model_id_or_path", "") + if "sdxl" in model_id.lower() or "xl" in model_id.lower(): model_type = "sdxl" - + # Handle case where available_controlnets dependency returns None if available_controlnets is None: logging.warning("get_available_controlnets: available_controlnets dependency returned None") available = [] else: available = available_controlnets.get(model_type, []) - + # Filter out already active ControlNets current_controlnets = [] # Check runtime config first, then fall back to uploaded config - if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config: - current_controlnets = [cn.get('model_id', '') for cn in app_instance.app_state.runtime_config['controlnets']] - elif app_instance.app_state.uploaded_config and 'controlnets' in app_instance.app_state.uploaded_config: - current_controlnets = [cn.get('model_id', '') for cn in app_instance.app_state.uploaded_config['controlnets']] - + if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config: + current_controlnets = [ + cn.get("model_id", "") for cn in app_instance.app_state.runtime_config["controlnets"] + ] + elif app_instance.app_state.uploaded_config and "controlnets" in app_instance.app_state.uploaded_config: + current_controlnets = [ + cn.get("model_id", "") for cn in app_instance.app_state.uploaded_config["controlnets"] + ] + filtered_available = [] for cn in available: - if cn['model_id'] not in current_controlnets: + if cn["model_id"] not in current_controlnets: filtered_available.append(cn) - - return JSONResponse({ - "status": "success", - "available_controlnets": filtered_available, - "model_type": model_type - }) - + + return JSONResponse( + {"status": "success", "available_controlnets": filtered_available, "model_type": model_type} + ) + except Exception as e: raise handle_api_error(e, "get_available_controlnets_endpoint") + @router.post("/controlnet/add") -async def add_controlnet(request: Request, app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets)): +async def add_controlnet( + request: Request, app_instance=Depends(get_app_instance), available_controlnets=Depends(get_available_controlnets) +): """Add a ControlNet from the predefined list""" try: data = await request.json() controlnet_id = data.get("controlnet_id") conditioning_scale = data.get("conditioning_scale", None) - + if not controlnet_id: raise HTTPException(status_code=400, detail="Missing controlnet_id parameter") - + # Find the ControlNet definition controlnet_def = None for model_type, controlnets in available_controlnets.items(): for cn in controlnets: - if cn['id'] == controlnet_id: + if cn["id"] == controlnet_id: controlnet_def = cn break if controlnet_def: break - + if not controlnet_def: raise HTTPException(status_code=400, detail=f"ControlNet {controlnet_id} not found in registry") - + # Use provided scale or default if conditioning_scale is None: - conditioning_scale = controlnet_def['default_scale'] - + conditioning_scale = controlnet_def["default_scale"] + # Initialize runtime config from YAML if not already done _ensure_runtime_controlnet_config(app_instance) - + # Create new ControlNet entry new_controlnet = { - 'model_id': controlnet_def['model_id'], - 'conditioning_scale': conditioning_scale, - 'preprocessor': controlnet_def['default_preprocessor'], - 'preprocessor_params': controlnet_def.get('preprocessor_params', {}), - 'enabled': True + "model_id": controlnet_def["model_id"], + "conditioning_scale": conditioning_scale, + "preprocessor": controlnet_def["default_preprocessor"], + "preprocessor_params": controlnet_def.get("preprocessor_params", {}), + "enabled": True, } - + # Add to runtime config (not YAML) - app_instance.app_state.runtime_config['controlnets'].append(new_controlnet) - + app_instance.app_state.runtime_config["controlnets"].append(new_controlnet) + # Add to AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.add_controlnet(new_controlnet) - + # Update pipeline if active if app_instance.pipeline: try: @@ -331,79 +355,84 @@ async def add_controlnet(request: Request, app_instance=Depends(get_app_instance controlnet_config = [] for cn in app_instance.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale + config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale controlnet_config.append(config_entry) app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config) except Exception as e: logging.exception(f"add_controlnet: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - + # Return updated ControlNet info immediately - SINGLE SOURCE OF TRUTH - added_index = len(app_instance.app_state.runtime_config['controlnets']) - 1 - - return JSONResponse({ - "status": "success", - "message": f"Added {controlnet_def['name']}", - "controlnet_index": added_index, - "controlnet_info": app_instance.app_state.controlnet_info - }) - + added_index = len(app_instance.app_state.runtime_config["controlnets"]) - 1 + + return JSONResponse( + { + "status": "success", + "message": f"Added {controlnet_def['name']}", + "controlnet_index": added_index, + "controlnet_info": app_instance.app_state.controlnet_info, + } + ) + except Exception as e: raise handle_api_error(e, "add_controlnet") + @router.get("/controlnet/status") async def get_controlnet_status(app_instance=Depends(get_app_instance)): """Get the status of ControlNet configuration""" try: controlnet_pipeline = app_instance._get_controlnet_pipeline() - + if not controlnet_pipeline: - return JSONResponse({ - "status": "no_pipeline", - "message": "No ControlNet pipeline available", - "controlnet_count": 0 - }) - + return JSONResponse( + {"status": "no_pipeline", "message": "No ControlNet pipeline available", "controlnet_count": 0} + ) + # Use AppState - SINGLE SOURCE OF TRUTH controlnet_count = len(app_instance.app_state.controlnet_info["controlnets"]) - - return JSONResponse({ - "status": "ready", - "controlnet_count": controlnet_count, - "message": f"{controlnet_count} ControlNet(s) configured" if controlnet_count > 0 else "No ControlNets configured" - }) - + + return JSONResponse( + { + "status": "ready", + "controlnet_count": controlnet_count, + "message": f"{controlnet_count} ControlNet(s) configured" + if controlnet_count > 0 + else "No ControlNets configured", + } + ) + except Exception as e: raise handle_api_error(e, "get_controlnet_status") + @router.post("/controlnet/remove") async def remove_controlnet(request: Request, app_instance=Depends(get_app_instance)): """Remove a ControlNet by index""" try: data = await request.json() index = data.get("index") - + if index is None: raise HTTPException(status_code=400, detail="Missing index parameter") - + # Initialize runtime config from YAML if not already done _ensure_runtime_controlnet_config(app_instance) - - if 'controlnets' not in app_instance.app_state.runtime_config: + + if "controlnets" not in app_instance.app_state.runtime_config: raise HTTPException(status_code=400, detail="No ControlNet configuration found") - - controlnets = app_instance.app_state.runtime_config['controlnets'] - + + controlnets = app_instance.app_state.runtime_config["controlnets"] + if index < 0 or index >= len(controlnets): raise HTTPException(status_code=400, detail=f"ControlNet index {index} out of range") - + removed_controlnet = controlnets.pop(index) - + # Remove from AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.remove_controlnet(index) - + # Update pipeline if active if app_instance.pipeline: try: @@ -411,65 +440,66 @@ async def remove_controlnet(request: Request, app_instance=Depends(get_app_insta controlnet_config = [] for cn in app_instance.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] # Map strength back to conditioning_scale + config_entry["conditioning_scale"] = cn["strength"] # Map strength back to conditioning_scale controlnet_config.append(config_entry) app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config) except Exception as e: logging.exception(f"remove_controlnet: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - + # Return updated ControlNet info immediately - SINGLE SOURCE OF TRUTH - return create_success_response(f"Removed ControlNet at index {index}", controlnet_info=app_instance.app_state.controlnet_info) - + return create_success_response( + f"Removed ControlNet at index {index}", controlnet_info=app_instance.app_state.controlnet_info + ) + except Exception as e: raise handle_api_error(e, "remove_controlnet") + # Preprocessor endpoints (closely related to ControlNet) @router.get("/preprocessors/info") async def get_preprocessors_info(app_instance=Depends(get_app_instance)): """Get preprocessor information using metadata from preprocessor classes""" try: # Use the same processor registry as pipeline hooks - from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class - + from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors + available_processors = list_preprocessors() processors_info = {} - + for processor_name in available_processors: try: processor_class = get_preprocessor_class(processor_name) - if hasattr(processor_class, 'get_preprocessor_metadata'): + if hasattr(processor_class, "get_preprocessor_metadata"): metadata = processor_class.get_preprocessor_metadata() processors_info[processor_name] = { "name": metadata.get("name", processor_name), "description": metadata.get("description", ""), - "parameters": metadata.get("parameters", {}) + "parameters": metadata.get("parameters", {}), } else: processors_info[processor_name] = { "name": processor_name, "description": f"{processor_name} processor", - "parameters": {} + "parameters": {}, } except Exception as e: logging.warning(f"get_preprocessors_info: Failed to load metadata for {processor_name}: {e}") processors_info[processor_name] = { "name": processor_name, "description": f"{processor_name} processor", - "parameters": {} + "parameters": {}, } - - return JSONResponse({ - "status": "success", - "available": list(processors_info.keys()), - "preprocessors": processors_info - }) - + + return JSONResponse( + {"status": "success", "available": list(processors_info.keys()), "preprocessors": processors_info} + ) + except Exception as e: raise handle_api_error(e, "get_preprocessors_info") + @router.post("/preprocessors/switch") async def switch_preprocessor(request: Request, app_instance=Depends(get_app_instance)): """Switch preprocessor for a specific ControlNet""" @@ -478,50 +508,59 @@ async def switch_preprocessor(request: Request, app_instance=Depends(get_app_ins # Support both parameter naming conventions for compatibility controlnet_index = data.get("controlnet_index") or data.get("processor_index") preprocessor_name = data.get("preprocessor") or data.get("processor") - + if controlnet_index is None or not preprocessor_name: - raise HTTPException(status_code=400, detail="Missing controlnet_index/processor_index or preprocessor/processor parameter") - + raise HTTPException( + status_code=400, detail="Missing controlnet_index/processor_index or preprocessor/processor parameter" + ) + # Validate AppState has ControlNet configuration (pipeline not required) - if not app_instance.app_state.controlnet_info["enabled"] or not app_instance.app_state.controlnet_info["controlnets"]: - raise HTTPException(status_code=400, detail="No ControlNet configuration available. Please upload a config first.") - + if ( + not app_instance.app_state.controlnet_info["enabled"] + or not app_instance.app_state.controlnet_info["controlnets"] + ): + raise HTTPException( + status_code=400, detail="No ControlNet configuration available. Please upload a config first." + ) + # Update AppState - SINGLE SOURCE OF TRUTH (works before pipeline creation) if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]): raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range") - + # Update the preprocessor in AppState controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index] - old_preprocessor = controlnet.get('preprocessor', 'unknown') - controlnet['preprocessor'] = preprocessor_name - controlnet['preprocessor_params'] = {} # Reset parameters when switching - + old_preprocessor = controlnet.get("preprocessor", "unknown") + controlnet["preprocessor"] = preprocessor_name + controlnet["preprocessor_params"] = {} # Reset parameters when switching + # Update runtime config to keep in sync - if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config: - if controlnet_index < len(app_instance.app_state.runtime_config['controlnets']): - app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor'] = preprocessor_name - app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'] = {} - + if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config: + if controlnet_index < len(app_instance.app_state.runtime_config["controlnets"]): + app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor"] = ( + preprocessor_name + ) + app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"] = {} + # Update pipeline if active if app_instance.pipeline: try: controlnet_config = [] for cn in app_instance.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] + config_entry["conditioning_scale"] = cn["strength"] controlnet_config.append(config_entry) app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config) except Exception as e: logging.exception(f"switch_preprocessor: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - + return create_success_response(f"Switched ControlNet {controlnet_index} preprocessor to {preprocessor_name}") - + except Exception as e: raise handle_api_error(e, "switch_preprocessor") + @router.post("/preprocessors/update-params") async def update_preprocessor_params(request: Request, app_instance=Depends(get_app_instance)): """Update preprocessor parameters for a specific ControlNet""" @@ -532,107 +571,128 @@ async def update_preprocessor_params(request: Request, app_instance=Depends(get_ except Exception as json_error: logging.error(f"update_preprocessor_params: JSON parsing failed: {json_error}") raise HTTPException(status_code=400, detail=f"Invalid JSON: {json_error}") - + controlnet_index = data.get("controlnet_index") params = data.get("params", {}) - + if controlnet_index is None: - logging.error(f"update_preprocessor_params: Missing controlnet_index parameter") + logging.error("update_preprocessor_params: Missing controlnet_index parameter") raise HTTPException(status_code=400, detail="Missing controlnet_index parameter") - + # Validate AppState has ControlNet configuration (pipeline not required) - if not app_instance.app_state.controlnet_info["enabled"] or not app_instance.app_state.controlnet_info["controlnets"]: - logging.error(f"update_preprocessor_params: No ControlNet configuration available in AppState") - raise HTTPException(status_code=400, detail="No ControlNet configuration available. Please upload a config first.") - + if ( + not app_instance.app_state.controlnet_info["enabled"] + or not app_instance.app_state.controlnet_info["controlnets"] + ): + logging.error("update_preprocessor_params: No ControlNet configuration available in AppState") + raise HTTPException( + status_code=400, detail="No ControlNet configuration available. Please upload a config first." + ) + # Update AppState - SINGLE SOURCE OF TRUTH (works before pipeline creation) if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]): - logging.error(f"update_preprocessor_params: ControlNet index {controlnet_index} out of range (max: {len(app_instance.app_state.controlnet_info['controlnets'])-1})") + logging.error( + f"update_preprocessor_params: ControlNet index {controlnet_index} out of range (max: {len(app_instance.app_state.controlnet_info['controlnets']) - 1})" + ) raise HTTPException(status_code=400, detail=f"ControlNet index {controlnet_index} out of range") - + # Update preprocessor parameters in AppState controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index] - if 'preprocessor_params' not in controlnet: - controlnet['preprocessor_params'] = {} - controlnet['preprocessor_params'].update(params) - + if "preprocessor_params" not in controlnet: + controlnet["preprocessor_params"] = {} + controlnet["preprocessor_params"].update(params) + # Update runtime config to keep in sync - if app_instance.app_state.runtime_config and 'controlnets' in app_instance.app_state.runtime_config: - if controlnet_index < len(app_instance.app_state.runtime_config['controlnets']): - if 'preprocessor_params' not in app_instance.app_state.runtime_config['controlnets'][controlnet_index]: - app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'] = {} - app_instance.app_state.runtime_config['controlnets'][controlnet_index]['preprocessor_params'].update(params) - + if app_instance.app_state.runtime_config and "controlnets" in app_instance.app_state.runtime_config: + if controlnet_index < len(app_instance.app_state.runtime_config["controlnets"]): + if "preprocessor_params" not in app_instance.app_state.runtime_config["controlnets"][controlnet_index]: + app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"] = {} + app_instance.app_state.runtime_config["controlnets"][controlnet_index]["preprocessor_params"].update( + params + ) + # Update pipeline if active if app_instance.pipeline: try: controlnet_config = [] for cn in app_instance.app_state.controlnet_info["controlnets"]: config_entry = dict(cn) - config_entry['conditioning_scale'] = cn['strength'] + config_entry["conditioning_scale"] = cn["strength"] controlnet_config.append(config_entry) app_instance.pipeline.update_stream_params(controlnet_config=controlnet_config) except Exception as e: logging.exception(f"update_preprocessor_params: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - logging.debug(f"update_preprocessor_params: Updated ControlNet {controlnet_index} preprocessor params: {params}") - - return create_success_response(f"Updated ControlNet {controlnet_index} preprocessor parameters", updated_params=params) - + + logging.debug( + f"update_preprocessor_params: Updated ControlNet {controlnet_index} preprocessor params: {params}" + ) + + return create_success_response( + f"Updated ControlNet {controlnet_index} preprocessor parameters", updated_params=params + ) + except Exception as e: logging.exception(f"update_preprocessor_params: Exception occurred: {str(e)}") raise handle_api_error(e, "update_preprocessor_params") + @router.get("/preprocessors/current-params/{controlnet_index}") async def get_current_preprocessor_params(controlnet_index: int, app_instance=Depends(get_app_instance)): """Get current parameter values for a specific ControlNet preprocessor""" try: # First try to get from uploaded config if no pipeline if not app_instance.pipeline and app_instance.app_state.uploaded_config: - controlnets = app_instance.app_state.uploaded_config.get('controlnets', []) + controlnets = app_instance.app_state.uploaded_config.get("controlnets", []) if controlnet_index < len(controlnets): controlnet = controlnets[controlnet_index] - return JSONResponse({ - "status": "success", - "controlnet_index": controlnet_index, - "preprocessor": controlnet.get('preprocessor', 'unknown'), - "parameters": controlnet.get('preprocessor_params', {}), - "note": "From uploaded config" - }) - + return JSONResponse( + { + "status": "success", + "controlnet_index": controlnet_index, + "preprocessor": controlnet.get("preprocessor", "unknown"), + "parameters": controlnet.get("preprocessor_params", {}), + "note": "From uploaded config", + } + ) + # Return empty/default response if no config available if not app_instance.pipeline: - return JSONResponse({ - "status": "success", - "controlnet_index": controlnet_index, - "preprocessor": "unknown", - "parameters": {}, - "note": "Pipeline not initialized - no config available" - }) - + return JSONResponse( + { + "status": "success", + "controlnet_index": controlnet_index, + "preprocessor": "unknown", + "parameters": {}, + "note": "Pipeline not initialized - no config available", + } + ) + # Use AppState - SINGLE SOURCE OF TRUTH if controlnet_index >= len(app_instance.app_state.controlnet_info["controlnets"]): - return JSONResponse({ + return JSONResponse( + { + "status": "success", + "controlnet_index": controlnet_index, + "preprocessor": "unknown", + "parameters": {}, + "note": f"ControlNet index {controlnet_index} out of range", + } + ) + + controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index] + preprocessor = controlnet.get("preprocessor", "unknown") + preprocessor_params = controlnet.get("preprocessor_params", {}) + + return JSONResponse( + { "status": "success", "controlnet_index": controlnet_index, - "preprocessor": "unknown", - "parameters": {}, - "note": f"ControlNet index {controlnet_index} out of range" - }) - - controlnet = app_instance.app_state.controlnet_info["controlnets"][controlnet_index] - preprocessor = controlnet.get('preprocessor', 'unknown') - preprocessor_params = controlnet.get('preprocessor_params', {}) - - return JSONResponse({ - "status": "success", - "controlnet_index": controlnet_index, - "preprocessor": preprocessor, - "parameters": preprocessor_params - }) - + "preprocessor": preprocessor, + "parameters": preprocessor_params, + } + ) + except Exception as e: raise handle_api_error(e, "get_current_preprocessor_params") - diff --git a/demo/realtime-img2img/routes/debug.py b/demo/realtime-img2img/routes/debug.py index 5015ce378..4fcd9e3b7 100644 --- a/demo/realtime-img2img/routes/debug.py +++ b/demo/realtime-img2img/routes/debug.py @@ -1,75 +1,82 @@ """ Debug mode API endpoints for realtime-img2img """ -from fastapi import APIRouter, HTTPException, Depends -from pydantic import BaseModel + import logging +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel + from .common.dependencies import get_app_instance + router = APIRouter(prefix="/api/debug", tags=["debug"]) + class DebugResponse(BaseModel): success: bool message: str debug_mode: bool debug_pending_frame: bool = False + @router.post("/enable", response_model=DebugResponse) async def enable_debug_mode(app_instance=Depends(get_app_instance)): """Enable debug mode - pauses automatic frame processing""" try: app_instance.app_state.debug_mode = True app_instance.app_state.debug_pending_frame = False - + logging.info("enable_debug_mode: Debug mode enabled") - + return DebugResponse( success=True, message="Debug mode enabled. Frame processing is now paused.", debug_mode=True, - debug_pending_frame=False + debug_pending_frame=False, ) except Exception as e: logging.exception(f"enable_debug_mode: Failed to enable debug mode: {e}") raise HTTPException(status_code=500, detail=f"Failed to enable debug mode: {str(e)}") + @router.post("/disable", response_model=DebugResponse) async def disable_debug_mode(app_instance=Depends(get_app_instance)): """Disable debug mode - resumes automatic frame processing""" try: app_instance.app_state.debug_mode = False app_instance.app_state.debug_pending_frame = False - + logging.info("disable_debug_mode: Debug mode disabled") - + return DebugResponse( success=True, message="Debug mode disabled. Automatic frame processing resumed.", debug_mode=False, - debug_pending_frame=False + debug_pending_frame=False, ) except Exception as e: logging.exception(f"disable_debug_mode: Failed to disable debug mode: {e}") raise HTTPException(status_code=500, detail=f"Failed to disable debug mode: {str(e)}") + @router.post("/step", response_model=DebugResponse) async def step_frame(app_instance=Depends(get_app_instance)): """Process exactly one frame when in debug mode""" try: if not app_instance.app_state.debug_mode: raise HTTPException(status_code=400, detail="Debug mode is not enabled") - + # Set pending frame flag to allow one frame to be processed app_instance.app_state.debug_pending_frame = True - + logging.info("step_frame: Frame step requested") - + return DebugResponse( success=True, message="Frame step requested. Next frame will be processed.", debug_mode=True, - debug_pending_frame=True + debug_pending_frame=True, ) except HTTPException: raise @@ -77,6 +84,7 @@ async def step_frame(app_instance=Depends(get_app_instance)): logging.exception(f"step_frame: Failed to step frame: {e}") raise HTTPException(status_code=500, detail=f"Failed to step frame: {str(e)}") + @router.get("/status", response_model=DebugResponse) async def get_debug_status(app_instance=Depends(get_app_instance)): """Get current debug mode status""" @@ -85,7 +93,7 @@ async def get_debug_status(app_instance=Depends(get_app_instance)): success=True, message="Debug status retrieved", debug_mode=app_instance.app_state.debug_mode, - debug_pending_frame=app_instance.app_state.debug_pending_frame + debug_pending_frame=app_instance.app_state.debug_pending_frame, ) except Exception as e: logging.exception(f"get_debug_status: Failed to get debug status: {e}") diff --git a/demo/realtime-img2img/routes/inference.py b/demo/realtime-img2img/routes/inference.py index f9bce553e..f1c81dd4a 100644 --- a/demo/realtime-img2img/routes/inference.py +++ b/demo/realtime-img2img/routes/inference.py @@ -1,23 +1,27 @@ """ Inference and system status endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends -from fastapi.responses import JSONResponse, StreamingResponse + import logging import uuid + import markdown2 +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from .common.dependencies import get_app_instance, get_default_settings, get_pipeline_class -from .common.api_utils import handle_api_request, create_success_response, handle_api_error -from .common.dependencies import get_app_instance, get_pipeline_class, get_default_settings router = APIRouter(prefix="/api", tags=["inference"]) + @router.get("/queue") async def get_queue_size(app_instance=Depends(get_app_instance)): """Get current queue size""" queue_size = app_instance.conn_manager.get_user_count() return JSONResponse({"queue_size": queue_size}) + @router.get("/stream/{user_id}") async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_app_instance)): """Main streaming endpoint for inference""" @@ -29,27 +33,38 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_ app_instance.pipeline = app_instance._create_pipeline() app_instance.app_state.pipeline_lifecycle = "running" logging.info("stream: Pipeline created successfully") - + # Recreate pipeline if config changed (but not resolution - that's handled separately) - elif app_instance.app_state.config_needs_reload or (app_instance.app_state.uploaded_config and not (app_instance.pipeline.use_config and app_instance.pipeline.config and 'controlnets' in app_instance.pipeline.config)) or (app_instance.app_state.uploaded_config and not app_instance.pipeline.use_config): + elif ( + app_instance.app_state.config_needs_reload + or ( + app_instance.app_state.uploaded_config + and not ( + app_instance.pipeline.use_config + and app_instance.pipeline.config + and "controlnets" in app_instance.pipeline.config + ) + ) + or (app_instance.app_state.uploaded_config and not app_instance.pipeline.use_config) + ): if app_instance.app_state.config_needs_reload: logging.info("stream: Recreating pipeline with new ControlNet config...") else: logging.info("stream: Upgrading to ControlNet pipeline...") - + app_instance.app_state.pipeline_lifecycle = "restarting" - + # Properly cleanup the old pipeline before creating new one old_pipeline = app_instance.pipeline app_instance.pipeline = None - + if old_pipeline: app_instance._cleanup_pipeline(old_pipeline) old_pipeline = None - + # Create new pipeline app_instance.pipeline = app_instance._create_pipeline() - + app_instance.app_state.config_needs_reload = False app_instance.app_state.pipeline_lifecycle = "running" logging.info("stream: Pipeline recreated successfully") @@ -59,25 +74,31 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_ # Check for acceleration changes (requires pipeline recreation) acceleration_changed = False - if hasattr(app_instance, 'new_acceleration') and app_instance.new_acceleration != app_instance.args.acceleration: - logging.info(f"stream: Acceleration change detected: {app_instance.args.acceleration} -> {app_instance.new_acceleration}") - + if ( + hasattr(app_instance, "new_acceleration") + and app_instance.new_acceleration != app_instance.args.acceleration + ): + logging.info( + f"stream: Acceleration change detected: {app_instance.args.acceleration} -> {app_instance.new_acceleration}" + ) + # Create new Args object with updated acceleration (NamedTuple is immutable) from config import Args + args_dict = app_instance.args._asdict() - args_dict['acceleration'] = app_instance.new_acceleration + args_dict["acceleration"] = app_instance.new_acceleration app_instance.args = Args(**args_dict) - delattr(app_instance, 'new_acceleration') - + delattr(app_instance, "new_acceleration") + # Recreate pipeline with new acceleration old_pipeline = app_instance.pipeline app_instance.pipeline = None if old_pipeline: app_instance._cleanup_pipeline(old_pipeline) - + app_instance.pipeline = app_instance._create_pipeline() acceleration_changed = True - logging.info(f"stream: Pipeline recreated with new acceleration") + logging.info("stream: Pipeline recreated with new acceleration") # IPAdapter style images are now handled dynamically in pipeline.predict() # No static application needed here @@ -91,6 +112,7 @@ async def stream(user_id: uuid.UUID, request: Request, app_instance=Depends(get_ # Generate and stream frames using pipeline.predict() in a loop (like original) try: + async def generate_frames(): try: while True: @@ -105,219 +127,241 @@ async def generate_frames(): else: # Wait in debug mode without requesting frames import asyncio + await asyncio.sleep(0.1) # Small delay to prevent busy waiting continue else: # Normal mode - request new frame automatically await app_instance.conn_manager.send_json(user_id, {"status": "send_frame"}) - + # Get the latest parameters from the WebSocket connection manager # This consumes data from the queue after requesting a new frame # Get latest data from the queue (blocks until new data arrives) params = await app_instance.conn_manager.get_latest_data(user_id) if params is None: continue - + # Attach InputSourceManager to params for modular input routing - if hasattr(app_instance, 'input_source_manager'): + if hasattr(app_instance, "input_source_manager"): params.input_manager = app_instance.input_source_manager - + # Generate frame using pipeline.predict() image = app_instance.pipeline.predict(params) if image is None: logging.error("stream: predict returned None image; skipping frame") continue - + # Update FPS counter import time + current_time = time.time() - if hasattr(app_instance, 'last_frame_time'): + if hasattr(app_instance, "last_frame_time"): frame_time = current_time - app_instance.last_frame_time app_instance.fps_counter.append(frame_time) if len(app_instance.fps_counter) > 30: # Keep last 30 frames app_instance.fps_counter.pop(0) app_instance.last_frame_time = current_time - + # Convert image to frame format for streaming # Use appropriate frame conversion based on output type if app_instance.pipeline.output_type == "pt": from util import pt_to_frame + frame = pt_to_frame(image) else: from util import pil_to_frame + frame = pil_to_frame(image) yield frame - + except Exception as e: logging.exception(f"stream: Error in frame generation: {e}") return StreamingResponse( generate_frames(), media_type="multipart/x-mixed-replace; boundary=frame", - headers={"Cache-Control": "no-cache, no-store, must-revalidate"} + headers={"Cache-Control": "no-cache, no-store, must-revalidate"}, ) - + except Exception as e: raise e - + except Exception as e: logging.exception(f"stream: Error in streaming endpoint: {e}") raise HTTPException(status_code=500, detail=f"Streaming failed: {str(e)}") + @router.get("/state") -async def get_app_state(app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class), default_settings=Depends(get_default_settings)): +async def get_app_state( + app_instance=Depends(get_app_instance), + pipeline_class=Depends(get_pipeline_class), + default_settings=Depends(get_default_settings), +): """Get complete application state - replaces /api/settings with centralized state management""" try: # Update app_state with current dynamic values - SINGLE SOURCE OF TRUTH app_instance.app_state.pipeline_active = app_instance.pipeline is not None - + # Update FPS from fps_counter if len(app_instance.fps_counter) > 0: avg_frame_time = sum(app_instance.fps_counter) / len(app_instance.fps_counter) app_instance.app_state.fps = round(1.0 / avg_frame_time if avg_frame_time > 0 else 0, 1) else: app_instance.app_state.fps = 0 - + # Update queue size app_instance.app_state.queue_size = app_instance.conn_manager.get_user_count() - + # Update pipeline parameters schema app_instance.app_state.pipeline_params = pipeline_class.InputParams.schema() - + # Update page content - if app_instance.pipeline and hasattr(app_instance.pipeline, 'info'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "info"): info = app_instance.pipeline.info else: info = pipeline_class.Info() - + if info.page_content: app_instance.app_state.page_content = markdown2.markdown(info.page_content) - + # Get complete state state_data = app_instance.app_state.get_complete_state() - + # Add additional fields expected by frontend for backward compatibility - state_data.update({ - "info": pipeline_class.Info.schema(), - "input_params": app_instance.app_state.pipeline_params, - "max_queue_size": app_instance.args.max_queue_size, - "acceleration": app_instance.args.acceleration, - # Add config prompt for backward compatibility - "config_prompt": app_instance.app_state.uploaded_config.get('prompt') if app_instance.app_state.uploaded_config else None, - # Add resolution in expected format - "resolution": f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}", - }) - + state_data.update( + { + "info": pipeline_class.Info.schema(), + "input_params": app_instance.app_state.pipeline_params, + "max_queue_size": app_instance.args.max_queue_size, + "acceleration": app_instance.args.acceleration, + # Add config prompt for backward compatibility + "config_prompt": app_instance.app_state.uploaded_config.get("prompt") + if app_instance.app_state.uploaded_config + else None, + # Add resolution in expected format + "resolution": f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}", + } + ) + return JSONResponse(state_data) - + except Exception as e: logging.error(f"get_app_state: Error getting application state: {e}") raise HTTPException(status_code=500, detail=f"Failed to get application state: {str(e)}") + @router.get("/settings") -async def settings(app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class), default_settings=Depends(get_default_settings)): +async def settings( + app_instance=Depends(get_app_instance), + pipeline_class=Depends(get_pipeline_class), + default_settings=Depends(get_default_settings), +): """Get pipeline settings and configuration info""" # Use Pipeline class directly for schema info (doesn't require instance) info_schema = pipeline_class.Info.schema() - + # Get info from pipeline instance if available to get correct input_mode - if app_instance.pipeline and hasattr(app_instance.pipeline, 'info'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "info"): info = app_instance.pipeline.info else: info = pipeline_class.Info() - + page_content = "" if info.page_content: page_content = markdown2.markdown(info.page_content) input_params = pipeline_class.InputParams.schema() - + # Add ControlNet information - SINGLE SOURCE OF TRUTH controlnet_info = app_instance.app_state.controlnet_info - + # Add IPAdapter information - SINGLE SOURCE OF TRUTH ipadapter_info = app_instance.app_state.ipadapter_info - + # Include config prompt if available, otherwise use default config_prompt = None - if app_instance.app_state.uploaded_config and 'prompt' in app_instance.app_state.uploaded_config: - config_prompt = app_instance.app_state.uploaded_config['prompt'] + if app_instance.app_state.uploaded_config and "prompt" in app_instance.app_state.uploaded_config: + config_prompt = app_instance.app_state.uploaded_config["prompt"] elif not config_prompt: - config_prompt = default_settings.get('prompt') - + config_prompt = default_settings.get("prompt") + # Get current t_index_list from pipeline or config current_t_index_list = None - if app_instance.pipeline and hasattr(app_instance.pipeline.stream, 't_list'): + if app_instance.pipeline and hasattr(app_instance.pipeline.stream, "t_list"): current_t_index_list = app_instance.pipeline.stream.t_list - elif app_instance.app_state.uploaded_config and 't_index_list' in app_instance.app_state.uploaded_config: - current_t_index_list = app_instance.app_state.uploaded_config['t_index_list'] + elif app_instance.app_state.uploaded_config and "t_index_list" in app_instance.app_state.uploaded_config: + current_t_index_list = app_instance.app_state.uploaded_config["t_index_list"] else: # Default values - current_t_index_list = default_settings.get('t_index_list', [35, 45]) - + current_t_index_list = default_settings.get("t_index_list", [35, 45]) + # Get current acceleration setting current_acceleration = app_instance.args.acceleration - + # Get current resolution - current_resolution = f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}" + current_resolution = ( + f"{app_instance.app_state.current_resolution['width']}x{app_instance.app_state.current_resolution['height']}" + ) # Add aspect ratio for display - aspect_ratio = app_instance._calculate_aspect_ratio(app_instance.app_state.current_resolution['width'], app_instance.app_state.current_resolution['height']) + aspect_ratio = app_instance._calculate_aspect_ratio( + app_instance.app_state.current_resolution["width"], app_instance.app_state.current_resolution["height"] + ) if aspect_ratio: current_resolution += f" ({aspect_ratio})" - if app_instance.app_state.uploaded_config and 'acceleration' in app_instance.app_state.uploaded_config: - current_acceleration = app_instance.app_state.uploaded_config['acceleration'] - + if app_instance.app_state.uploaded_config and "acceleration" in app_instance.app_state.uploaded_config: + current_acceleration = app_instance.app_state.uploaded_config["acceleration"] + # Get current streaming parameters (default values or from pipeline if available) - current_guidance_scale = default_settings.get('guidance_scale', 1.1) - current_delta = default_settings.get('delta', 0.7) - current_num_inference_steps = default_settings.get('num_inference_steps', 50) - current_seed = default_settings.get('seed', 2) - - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + current_guidance_scale = default_settings.get("guidance_scale", 1.1) + current_delta = default_settings.get("delta", 0.7) + current_num_inference_steps = default_settings.get("num_inference_steps", 50) + current_seed = default_settings.get("seed", 2) + + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): try: state = app_instance.pipeline.stream.get_stream_state() - current_guidance_scale = state.get('guidance_scale', current_guidance_scale) - current_delta = state.get('delta', current_delta) - current_num_inference_steps = state.get('num_inference_steps', current_num_inference_steps) - current_seed = state.get('seed', current_seed) + current_guidance_scale = state.get("guidance_scale", current_guidance_scale) + current_delta = state.get("delta", current_delta) + current_num_inference_steps = state.get("num_inference_steps", current_num_inference_steps) + current_seed = state.get("seed", current_seed) except Exception as e: logging.warning(f"settings: Failed to get current stream parameters: {e}") - + # Get negative prompt if available - current_negative_prompt = default_settings.get('negative_prompt', '') - if app_instance.app_state.uploaded_config and 'negative_prompt' in app_instance.app_state.uploaded_config: - current_negative_prompt = app_instance.app_state.uploaded_config['negative_prompt'] - elif app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + current_negative_prompt = default_settings.get("negative_prompt", "") + if app_instance.app_state.uploaded_config and "negative_prompt" in app_instance.app_state.uploaded_config: + current_negative_prompt = app_instance.app_state.uploaded_config["negative_prompt"] + elif app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): try: state = app_instance.pipeline.stream.get_stream_state() - current_negative_prompt = state.get('negative_prompt', current_negative_prompt) + current_negative_prompt = state.get("negative_prompt", current_negative_prompt) except Exception: pass - + # Get prompt and seed blending configuration - SINGLE SOURCE OF TRUTH prompt_blending_config = app_instance.app_state.prompt_blending seed_blending_config = app_instance.app_state.seed_blending - + # Get normalization settings - SINGLE SOURCE OF TRUTH normalize_prompt_weights = app_instance.app_state.normalize_prompt_weights normalize_seed_weights = app_instance.app_state.normalize_seed_weights - + # Get current skip_diffusion setting - SINGLE SOURCE OF TRUTH current_skip_diffusion = app_instance.app_state.skip_diffusion - + # Determine current model id for UI badge - SINGLE SOURCE OF TRUTH model_id_for_ui = app_instance.app_state.model_id - + # Check if pipeline is active pipeline_active = app_instance.pipeline is not None - + # Build config_values for other parameters that frontend may expect config_values = {} if app_instance.app_state.uploaded_config: for key in [ - 'use_taesd', - 'cfg_type', - 'safety_checker', + "use_taesd", + "cfg_type", + "safety_checker", ]: if key in app_instance.app_state.uploaded_config: config_values[key] = app_instance.app_state.uploaded_config[key] @@ -347,9 +391,10 @@ async def settings(app_instance=Depends(get_app_instance), pipeline_class=Depend "model_id": model_id_for_ui, "config_values": config_values, } - + return JSONResponse(response_data) + @router.get("/fps") async def get_fps(app_instance=Depends(get_app_instance)): """Get current FPS""" @@ -358,8 +403,5 @@ async def get_fps(app_instance=Depends(get_app_instance)): fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0 else: fps = 0 - - return JSONResponse({"fps": round(fps, 1)}) - - + return JSONResponse({"fps": round(fps, 1)}) diff --git a/demo/realtime-img2img/routes/input_sources.py b/demo/realtime-img2img/routes/input_sources.py index 0901c56ca..1029f2398 100644 --- a/demo/realtime-img2img/routes/input_sources.py +++ b/demo/realtime-img2img/routes/input_sources.py @@ -1,19 +1,21 @@ """ Input Source Management API endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends, UploadFile, File -from fastapi.responses import JSONResponse + +import io import logging -from pathlib import Path -from typing import Optional, Any, Dict import uuid +from pathlib import Path +from typing import Optional + +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile +from input_sources import InputSource, InputSourceManager, InputSourceType from PIL import Image -import io +from utils.video_utils import is_supported_video_format, validate_video_file -from .common.api_utils import handle_api_request, create_success_response, handle_api_error +from .common.api_utils import create_success_response, handle_api_error, handle_api_request from .common.dependencies import get_app_instance -from input_sources import InputSource, InputSourceType, InputSourceManager -from utils.video_utils import validate_video_file, is_supported_video_format + router = APIRouter(prefix="/api/input-sources", tags=["input-sources"]) @@ -22,7 +24,7 @@ def _get_input_source_manager(app_instance) -> InputSourceManager: """Get or create the input source manager for the app instance.""" - if not hasattr(app_instance, 'input_source_manager'): + if not hasattr(app_instance, "input_source_manager"): app_instance.input_source_manager = InputSourceManager() return app_instance.input_source_manager @@ -31,7 +33,7 @@ def _get_input_source_manager(app_instance) -> InputSourceManager: async def set_input_source(request: Request, app_instance=Depends(get_app_instance)): """ Set input source for a component. - + Body: { component: str, # 'controlnet', 'ipadapter', 'base' index?: int, # Required for controlnet @@ -40,61 +42,60 @@ async def set_input_source(request: Request, app_instance=Depends(get_app_instan } """ try: - data = await handle_api_request(request, "set_input_source", - required_params=['component', 'source_type'], - pipeline_required=False) - - component = data['component'] - source_type_str = data['source_type'] - index = data.get('index') - source_data = data.get('source_data') - + data = await handle_api_request( + request, "set_input_source", required_params=["component", "source_type"], pipeline_required=False + ) + + component = data["component"] + source_type_str = data["source_type"] + index = data.get("index") + source_data = data.get("source_data") + # Validate component - if component not in ['controlnet', 'ipadapter', 'base']: + if component not in ["controlnet", "ipadapter", "base"]: raise HTTPException(status_code=400, detail=f"Invalid component: {component}") - + # Validate source type try: source_type = InputSourceType(source_type_str) except ValueError: raise HTTPException(status_code=400, detail=f"Invalid source type: {source_type_str}") - + # Validate index for controlnet - if component == 'controlnet' and index is None: + if component == "controlnet" and index is None: raise HTTPException(status_code=400, detail="Index is required for ControlNet components") - + # Get input source manager manager = _get_input_source_manager(app_instance) - + # Create input source input_source = InputSource(source_type, source_data) - + # Set the source manager.set_source(component, input_source, index) - + logger.info(f"set_input_source: Set {component} input source to {source_type_str}") - - return create_success_response({ - 'message': f'Input source set for {component}', - 'component': component, - 'source_type': source_type_str, - 'index': index - }) - + + return create_success_response( + { + "message": f"Input source set for {component}", + "component": component, + "source_type": source_type_str, + "index": index, + } + ) + except Exception as e: return handle_api_error(e, "set_input_source") @router.post("/upload-image/{component}") async def upload_component_image( - component: str, - file: UploadFile = File(...), - index: Optional[int] = None, - app_instance=Depends(get_app_instance) + component: str, file: UploadFile = File(...), index: Optional[int] = None, app_instance=Depends(get_app_instance) ): """ Upload image for specific component. - + Args: component: Component name ('controlnet', 'ipadapter', 'base') file: Image file to upload @@ -102,69 +103,72 @@ async def upload_component_image( """ try: # Validate component - if component not in ['controlnet', 'ipadapter', 'base']: + if component not in ["controlnet", "ipadapter", "base"]: raise HTTPException(status_code=400, detail=f"Invalid component: {component}") - + # Validate index for controlnet - if component == 'controlnet' and index is None: + if component == "controlnet" and index is None: raise HTTPException(status_code=400, detail="Index is required for ControlNet components") - + # Validate file type - if not file.content_type.startswith('image/'): + if not file.content_type.startswith("image/"): raise HTTPException(status_code=400, detail="File must be an image") - + # Generate unique filename file_id = str(uuid.uuid4()) - file_extension = Path(file.filename).suffix if file.filename else '.jpg' - filename = f"{component}_{index}_{file_id}{file_extension}" if index is not None else f"{component}_{file_id}{file_extension}" - + file_extension = Path(file.filename).suffix if file.filename else ".jpg" + filename = ( + f"{component}_{index}_{file_id}{file_extension}" + if index is not None + else f"{component}_{file_id}{file_extension}" + ) + # Save file uploads_dir = Path("uploads/images") uploads_dir.mkdir(parents=True, exist_ok=True) file_path = uploads_dir / filename - + content = await file.read() - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(content) - + # Create PIL Image for input source try: image = Image.open(io.BytesIO(content)) - image = image.convert('RGB') # Ensure RGB format + image = image.convert("RGB") # Ensure RGB format except Exception as e: # Clean up file if image processing fails file_path.unlink(missing_ok=True) raise HTTPException(status_code=400, detail=f"Invalid image file: {str(e)}") - + # Get input source manager and set source manager = _get_input_source_manager(app_instance) input_source = InputSource(InputSourceType.UPLOADED_IMAGE, image) manager.set_source(component, input_source, index) - + logger.info(f"upload_component_image: Uploaded image for {component} (index: {index})") - - return create_success_response({ - 'message': f'Image uploaded for {component}', - 'component': component, - 'index': index, - 'filename': filename, - 'file_path': str(file_path) - }) - + + return create_success_response( + { + "message": f"Image uploaded for {component}", + "component": component, + "index": index, + "filename": filename, + "file_path": str(file_path), + } + ) + except Exception as e: return handle_api_error(e, "upload_component_image") @router.post("/upload-video/{component}") async def upload_component_video( - component: str, - file: UploadFile = File(...), - index: Optional[int] = None, - app_instance=Depends(get_app_instance) + component: str, file: UploadFile = File(...), index: Optional[int] = None, app_instance=Depends(get_app_instance) ): """ Upload video for specific component. - + Args: component: Component name ('controlnet', 'ipadapter', 'base') file: Video file to upload @@ -172,91 +176,95 @@ async def upload_component_video( """ try: # Validate component - if component not in ['controlnet', 'ipadapter', 'base']: + if component not in ["controlnet", "ipadapter", "base"]: raise HTTPException(status_code=400, detail=f"Invalid component: {component}") - + # Validate index for controlnet - if component == 'controlnet' and index is None: + if component == "controlnet" and index is None: raise HTTPException(status_code=400, detail="Index is required for ControlNet components") - + # Validate file type if not file.filename or not is_supported_video_format(file.filename): raise HTTPException(status_code=400, detail="File must be a supported video format") - + # Generate unique filename file_id = str(uuid.uuid4()) file_extension = Path(file.filename).suffix - filename = f"{component}_{index}_{file_id}{file_extension}" if index is not None else f"{component}_{file_id}{file_extension}" - + filename = ( + f"{component}_{index}_{file_id}{file_extension}" + if index is not None + else f"{component}_{file_id}{file_extension}" + ) + # Save file uploads_dir = Path("uploads/videos") uploads_dir.mkdir(parents=True, exist_ok=True) file_path = uploads_dir / filename - + content = await file.read() - with open(file_path, 'wb') as f: + with open(file_path, "wb") as f: f.write(content) - + # Validate video file is_valid, error_msg = validate_video_file(str(file_path)) if not is_valid: # Clean up file if validation fails file_path.unlink(missing_ok=True) raise HTTPException(status_code=400, detail=f"Invalid video file: {error_msg}") - + # Get input source manager and set source manager = _get_input_source_manager(app_instance) input_source = InputSource(InputSourceType.UPLOADED_VIDEO, str(file_path)) manager.set_source(component, input_source, index) - + logger.info(f"upload_component_video: Uploaded video for {component} (index: {index})") - - return create_success_response({ - 'message': f'Video uploaded for {component}', - 'component': component, - 'index': index, - 'filename': filename, - 'file_path': str(file_path) - }) - + + return create_success_response( + { + "message": f"Video uploaded for {component}", + "component": component, + "index": index, + "filename": filename, + "file_path": str(file_path), + } + ) + except Exception as e: return handle_api_error(e, "upload_component_video") @router.get("/info/{component}") async def get_component_source_info( - component: str, - index: Optional[int] = None, - app_instance=Depends(get_app_instance) + component: str, index: Optional[int] = None, app_instance=Depends(get_app_instance) ): """ Get information about a component's input source. - + Args: component: Component name ('controlnet', 'ipadapter', 'base') index: Index for ControlNet (required for controlnet component) """ try: # Validate component - if component not in ['controlnet', 'ipadapter', 'base']: + if component not in ["controlnet", "ipadapter", "base"]: raise HTTPException(status_code=400, detail=f"Invalid component: {component}") - + # Validate index for controlnet - if component == 'controlnet' and index is None: + if component == "controlnet" and index is None: raise HTTPException(status_code=400, detail="Index is required for ControlNet components") - + # Get input source manager manager = _get_input_source_manager(app_instance) - + # Get source info source_info = manager.get_source_info(component, index) - + # Make sure source_data is JSON serializable (remove PIL Image objects) - if source_info and 'source_data' in source_info: - source_data = source_info['source_data'] + if source_info and "source_data" in source_info: + source_data = source_info["source_data"] # If source_data is a PIL Image, just indicate it's present rather than trying to serialize it - if hasattr(source_data, '__class__') and source_data.__class__.__name__ == 'Image': - source_info['source_data'] = 'image_present' + if hasattr(source_data, "__class__") and source_data.__class__.__name__ == "Image": + source_info["source_data"] = "image_present" elif isinstance(source_data, str): # Keep strings (file paths) as-is pass @@ -264,16 +272,13 @@ async def get_component_source_info( # For other non-serializable objects, convert to string representation try: import json + json.dumps(source_data) # Test if it's serializable except (TypeError, ValueError): - source_info['source_data'] = str(type(source_data).__name__) - - return create_success_response({ - 'component': component, - 'index': index, - 'source_info': source_info - }) - + source_info["source_data"] = str(type(source_data).__name__) + + return create_success_response({"component": component, "index": index, "source_info": source_info}) + except Exception as e: return handle_api_error(e, "get_component_source_info") @@ -284,64 +289,60 @@ async def list_all_source_info(app_instance=Depends(get_app_instance)): try: # Get input source manager manager = _get_input_source_manager(app_instance) - + # Collect all source information all_sources = { - 'base': manager.get_source_info('base'), - 'ipadapter': manager.get_source_info('ipadapter'), - 'controlnets': {} + "base": manager.get_source_info("base"), + "ipadapter": manager.get_source_info("ipadapter"), + "controlnets": {}, } - + # Get all controlnet sources - for index, source in manager.sources['controlnet'].items(): - all_sources['controlnets'][index] = manager.get_source_info('controlnet', index) - - return create_success_response({ - 'sources': all_sources - }) - + for index, source in manager.sources["controlnet"].items(): + all_sources["controlnets"][index] = manager.get_source_info("controlnet", index) + + return create_success_response({"sources": all_sources}) + except Exception as e: return handle_api_error(e, "list_all_source_info") @router.post("/reset/{component}") -async def reset_component_source( - component: str, - index: Optional[int] = None, - app_instance=Depends(get_app_instance) -): +async def reset_component_source(component: str, index: Optional[int] = None, app_instance=Depends(get_app_instance)): """ Reset a component's input source to webcam (default). - + Args: component: Component name ('controlnet', 'ipadapter', 'base') index: Index for ControlNet (required for controlnet component) """ try: # Validate component - if component not in ['controlnet', 'ipadapter', 'base']: + if component not in ["controlnet", "ipadapter", "base"]: raise HTTPException(status_code=400, detail=f"Invalid component: {component}") - + # Validate index for controlnet - if component == 'controlnet' and index is None: + if component == "controlnet" and index is None: raise HTTPException(status_code=400, detail="Index is required for ControlNet components") - + # Get input source manager manager = _get_input_source_manager(app_instance) - + # Create webcam input source webcam_source = InputSource(InputSourceType.WEBCAM) manager.set_source(component, webcam_source, index) - + logger.info(f"reset_component_source: Reset {component} to webcam (index: {index})") - - return create_success_response({ - 'message': f'Input source reset to webcam for {component}', - 'component': component, - 'index': index, - 'source_type': 'webcam' - }) - + + return create_success_response( + { + "message": f"Input source reset to webcam for {component}", + "component": component, + "index": index, + "source_type": "webcam", + } + ) + except Exception as e: return handle_api_error(e, "reset_component_source") @@ -355,20 +356,22 @@ async def reset_all_input_sources(app_instance=Depends(get_app_instance)): try: # Get input source manager manager = _get_input_source_manager(app_instance) - + # Reset all sources to defaults manager.reset_to_defaults() - + logger.info("reset_all_input_sources: Reset all input sources to defaults") - - return create_success_response({ - 'message': 'All input sources reset to defaults', - 'defaults': { - 'base': 'webcam', - 'ipadapter': 'uploaded_image (default image)', - 'controlnet': 'fallback to base pipeline' + + return create_success_response( + { + "message": "All input sources reset to defaults", + "defaults": { + "base": "webcam", + "ipadapter": "uploaded_image (default image)", + "controlnet": "fallback to base pipeline", + }, } - }) - + ) + except Exception as e: return handle_api_error(e, "reset_all_input_sources") diff --git a/demo/realtime-img2img/routes/ipadapter.py b/demo/realtime-img2img/routes/ipadapter.py index 59a58b42a..edf7a4519 100644 --- a/demo/realtime-img2img/routes/ipadapter.py +++ b/demo/realtime-img2img/routes/ipadapter.py @@ -1,107 +1,122 @@ """ IPAdapter-related endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends -from fastapi.responses import JSONResponse, Response + import logging import os -from .common.api_utils import handle_api_request, create_success_response, handle_api_error, validate_pipeline, validate_feature_enabled, validate_config_mode +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import Response + +from .common.api_utils import ( + create_success_response, + handle_api_error, + handle_api_request, + validate_config_mode, +) from .common.dependencies import get_app_instance + router = APIRouter(prefix="/api", tags=["ipadapter"]) # Legacy upload endpoint removed - use /api/input-sources/upload-image/ipadapter instead # Legacy get uploaded image endpoint removed - use InputSourceManager instead + @router.get("/default-image") async def get_default_image(): """Get the default image (input.png)""" try: default_image_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "images", "inputs", "input.png") - + if not os.path.exists(default_image_path): raise HTTPException(status_code=404, detail="Default image not found") - + # Read and return the default image file with open(default_image_path, "rb") as image_file: image_content = image_file.read() - - return Response(content=image_content, media_type="image/png", headers={"Cache-Control": "public, max-age=3600"}) - + + return Response( + content=image_content, media_type="image/png", headers={"Cache-Control": "public, max-age=3600"} + ) + except Exception as e: raise handle_api_error(e, "get_default_image") + @router.post("/ipadapter/update-scale") async def update_ipadapter_scale(request: Request, app_instance=Depends(get_app_instance)): """Update IPAdapter scale/strength in real-time""" try: data = await handle_api_request(request, "update_ipadapter_scale", ["scale"]) scale = data.get("scale") - + # Validate AppState has IPAdapter configuration (pipeline not required) if not app_instance.app_state.ipadapter_info["enabled"]: - raise HTTPException(status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first.") - + raise HTTPException( + status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first." + ) + # Update AppState as single source of truth (works before pipeline creation) app_instance.app_state.update_parameter("ipadapter_scale", float(scale)) - + # Sync to pipeline if active - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated IPAdapter scale to {scale}") - + except Exception as e: raise handle_api_error(e, "update_ipadapter_scale") + @router.post("/ipadapter/update-weight-type") async def update_ipadapter_weight_type(request: Request, app_instance=Depends(get_app_instance)): """Update IPAdapter weight type in real-time""" try: data = await handle_api_request(request, "update_ipadapter_weight_type", ["weight_type"]) weight_type = data.get("weight_type") - + # Validate AppState has IPAdapter configuration (pipeline not required) if not app_instance.app_state.ipadapter_info["enabled"]: - raise HTTPException(status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first.") - + raise HTTPException( + status_code=400, detail="IPAdapter is not enabled. Please upload a config with IPAdapter first." + ) + # Update AppState as single source of truth (works before pipeline creation) app_instance.app_state.ipadapter_info["weight_type"] = weight_type - + # Sync to pipeline if active - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated IPAdapter weight type to {weight_type}") - + except Exception as e: raise handle_api_error(e, "update_ipadapter_weight_type") + @router.post("/ipadapter/update-enabled") async def update_ipadapter_enabled(request: Request, app_instance=Depends(get_app_instance)): """Enable or disable IPAdapter in real-time""" try: data = await handle_api_request(request, "update_ipadapter_enabled", ["enabled"]) enabled = data.get("enabled") - + # Update AppState as single source of truth (works before pipeline creation) app_instance.app_state.ipadapter_info["enabled"] = bool(enabled) logging.info(f"update_ipadapter_enabled: Updated AppState ipadapter enabled to {enabled}") - + # Sync to pipeline if active if app_instance.pipeline: validate_config_mode(app_instance.pipeline, "ipadapters") - + # Update IPAdapter enabled state in the pipeline - app_instance.pipeline.stream.update_stream_params( - ipadapter_config={'enabled': bool(enabled)} - ) - logging.info(f"update_ipadapter_enabled: Synced to active pipeline") - + app_instance.pipeline.stream.update_stream_params(ipadapter_config={"enabled": bool(enabled)}) + logging.info("update_ipadapter_enabled: Synced to active pipeline") + return create_success_response(f"IPAdapter {'enabled' if enabled else 'disabled'} successfully") - + except Exception as e: raise handle_api_error(e, "update_ipadapter_enabled") - diff --git a/demo/realtime-img2img/routes/parameters.py b/demo/realtime-img2img/routes/parameters.py index 77f875592..27f4b0d06 100644 --- a/demo/realtime-img2img/routes/parameters.py +++ b/demo/realtime-img2img/routes/parameters.py @@ -1,15 +1,19 @@ """ Parameter update endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends -from fastapi.responses import JSONResponse + import logging -from .common.api_utils import handle_api_request, create_success_response, handle_api_error +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse + +from .common.api_utils import create_success_response, handle_api_error, handle_api_request from .common.dependencies import get_app_instance + router = APIRouter(prefix="/api", tags=["parameters"]) + @router.post("/params") async def update_params(request: Request, app_instance=Depends(get_app_instance)): """Update multiple streaming parameters in a single unified call""" @@ -17,7 +21,7 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance) data = await request.json() logging.info(f"update_params: Received data: {data}") logging.info(f"update_params: Pipeline exists: {app_instance.pipeline is not None}") - + # Allow updating resolution even when pipeline is not initialized. # We save the new values so they take effect on the next stream start. if "resolution" in data: @@ -25,48 +29,44 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance) logging.info("update_params: No pipeline exists, updating resolution directly") else: logging.info("update_params: Pipeline exists, resolution update may be handled differently") - + if "resolution" in data: resolution = data["resolution"] if isinstance(resolution, dict) and "width" in resolution and "height" in resolution: width, height = int(resolution["width"]), int(resolution["height"]) - + # Call the proper pipeline recreation method app_instance._update_resolution(width, height) - + message = f"Resolution updated to {width}x{height} and pipeline recreated successfully" logging.info(f"update_params: {message}") - return JSONResponse({ - "status": "success", - "message": message - }) + return JSONResponse({"status": "success", "message": message}) elif isinstance(resolution, str): # Handle string format like "512x768 (2:3)" or "512x768" - resolution_part = resolution.split(' ')[0] + resolution_part = resolution.split(" ")[0] logging.info(f"update_params: Parsing resolution string: {resolution} -> {resolution_part}") try: - width, height = map(int, resolution_part.split('x')) + width, height = map(int, resolution_part.split("x")) logging.info(f"update_params: Parsed width={width}, height={height}") - + # Call the proper pipeline recreation method app_instance._update_resolution(width, height) - + message = f"Resolution updated to {width}x{height} and pipeline recreated successfully" logging.info(f"update_params: {message}") - return JSONResponse({ - "status": "success", - "message": message - }) + return JSONResponse({"status": "success", "message": message}) except ValueError: raise HTTPException(status_code=400, detail="Invalid resolution format") else: - raise HTTPException(status_code=400, detail="Resolution must be {width: int, height: int} or 'widthxheight' string") + raise HTTPException( + status_code=400, detail="Resolution must be {width: int, height: int} or 'widthxheight' string" + ) # No pipeline validation needed - AppState updates work before pipeline creation - + # Update parameters that exist in the data params = {} - + if "guidance_scale" in data: params["guidance_scale"] = float(data["guidance_scale"]) if "delta" in data: @@ -86,65 +86,60 @@ async def update_params(request: Request, app_instance=Depends(get_app_instance) # Update AppState as single source of truth (works before pipeline creation) for param_name, param_value in params.items(): app_instance.app_state.update_parameter(param_name, param_value) - + # Sync to pipeline if active (for real-time updates) - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - - return JSONResponse({ - "status": "success", - "message": f"Updated parameters: {list(params.keys())}", - "updated_params": params - }) + + return JSONResponse( + { + "status": "success", + "message": f"Updated parameters: {list(params.keys())}", + "updated_params": params, + } + ) else: - return JSONResponse({ - "status": "success", - "message": "No valid parameters provided to update" - }) - + return JSONResponse({"status": "success", "message": "No valid parameters provided to update"}) + except Exception as e: logging.exception(f"update_params: Failed to update parameters: {e}") raise HTTPException(status_code=500, detail=f"Failed to update parameters: {str(e)}") + async def _update_single_parameter( - request: Request, - app_instance, - parameter_name: str, - value_converter: callable, - operation_name: str + request: Request, app_instance, parameter_name: str, value_converter: callable, operation_name: str ): """Generic function to update a single parameter""" try: data = await handle_api_request(request, operation_name, [parameter_name]) # No pipeline validation needed - AppState updates work before pipeline creation - + value = value_converter(data[parameter_name]) - + # Update AppState as single source of truth (works before pipeline creation) app_instance.app_state.update_parameter(parameter_name, value) - + # Sync to pipeline if active (for real-time updates) - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated {parameter_name} to {value}", **{parameter_name: value}) - + except Exception as e: raise handle_api_error(e, operation_name) + @router.post("/update-guidance-scale") async def update_guidance_scale(request: Request, app_instance=Depends(get_app_instance)): """Update guidance scale parameter""" - return await _update_single_parameter( - request, app_instance, "guidance_scale", float, "update_guidance_scale" - ) + return await _update_single_parameter(request, app_instance, "guidance_scale", float, "update_guidance_scale") + @router.post("/update-delta") async def update_delta(request: Request, app_instance=Depends(get_app_instance)): """Update delta parameter""" - return await _update_single_parameter( - request, app_instance, "delta", float, "update_delta" - ) + return await _update_single_parameter(request, app_instance, "delta", float, "update_delta") + @router.post("/update-num-inference-steps") async def update_num_inference_steps(request: Request, app_instance=Depends(get_app_instance)): @@ -153,32 +148,32 @@ async def update_num_inference_steps(request: Request, app_instance=Depends(get_ request, app_instance, "num_inference_steps", int, "update_num_inference_steps" ) + @router.post("/update-seed") async def update_seed(request: Request, app_instance=Depends(get_app_instance)): """Update seed parameter""" - return await _update_single_parameter( - request, app_instance, "seed", int, "update_seed" - ) + return await _update_single_parameter(request, app_instance, "seed", int, "update_seed") + @router.post("/blending") async def update_blending(request: Request, app_instance=Depends(get_app_instance)): """Update prompt and/or seed blending configuration in real-time""" try: data = await request.json() - + # No pipeline validation needed - AppState updates work before pipeline creation - + params = {} updated_types = [] - + # Handle prompt blending if "prompt_list" in data: prompt_list = data["prompt_list"] interpolation_method = data.get("prompt_interpolation_method", "slerp") - + if not isinstance(prompt_list, list): raise HTTPException(status_code=400, detail="prompt_list must be a list") - + # Validate and convert format prompt_tuples = [] for item in prompt_list: @@ -187,8 +182,11 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc elif isinstance(item, dict) and "prompt" in item and "weight" in item: prompt_tuples.append((str(item["prompt"]), float(item["weight"]))) else: - raise HTTPException(status_code=400, detail="Each prompt item must be [prompt, weight] or {prompt: str, weight: float}") - + raise HTTPException( + status_code=400, + detail="Each prompt item must be [prompt, weight] or {prompt: str, weight: float}", + ) + params["prompt_list"] = prompt_tuples params["prompt_interpolation_method"] = interpolation_method updated_types.append("prompt") @@ -197,10 +195,10 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc if "seed_list" in data: seed_list = data["seed_list"] interpolation_method = data.get("seed_interpolation_method", "linear") - + if not isinstance(seed_list, list): raise HTTPException(status_code=400, detail="seed_list must be a list") - + # Validate and convert format seed_tuples = [] for item in seed_list: @@ -209,8 +207,10 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc elif isinstance(item, dict) and "seed" in item and "weight" in item: seed_tuples.append((int(item["seed"]), float(item["weight"]))) else: - raise HTTPException(status_code=400, detail="Each seed item must be [seed, weight] or {seed: int, weight: float}") - + raise HTTPException( + status_code=400, detail="Each seed item must be [seed, weight] or {seed: int, weight: float}" + ) + params["seed_list"] = seed_tuples params["seed_interpolation_method"] = interpolation_method updated_types.append("seed") @@ -223,63 +223,64 @@ async def update_blending(request: Request, app_instance=Depends(get_app_instanc app_instance.app_state.prompt_blending = params["prompt_list"] if "seed_list" in params: app_instance.app_state.seed_blending = params["seed_list"] - + # Sync to pipeline if active - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated {' and '.join(updated_types)} blending", updated_types=updated_types) - + except Exception as e: raise handle_api_error(e, "update_blending") + @router.post("/blending/update-prompt-weight") async def update_prompt_weight(request: Request, app_instance=Depends(get_app_instance)): """Update a specific prompt weight in the current blending configuration""" try: data = await request.json() - index = data.get('index') - weight = data.get('weight') - + index = data.get("index") + weight = data.get("weight") + if index is None or weight is None: raise HTTPException(status_code=400, detail="Missing index or weight parameter") - + # No pipeline validation needed - AppState updates work before pipeline creation - + # Update AppState as single source of truth app_instance.app_state.update_parameter(f"prompt_weight_{index}", float(weight)) - + # Sync to pipeline if active - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated prompt weight {index} to {weight}") - + except Exception as e: raise handle_api_error(e, "update_prompt_weight") -@router.post("/blending/update-seed-weight") + +@router.post("/blending/update-seed-weight") async def update_seed_weight(request: Request, app_instance=Depends(get_app_instance)): """Update a specific seed weight in the current blending configuration""" try: data = await request.json() - index = data.get('index') - weight = data.get('weight') - + index = data.get("index") + weight = data.get("weight") + if index is None or weight is None: raise HTTPException(status_code=400, detail="Missing index or weight parameter") - + # No pipeline validation needed - AppState updates work before pipeline creation - + # Update AppState as single source of truth app_instance.app_state.update_parameter(f"seed_weight_{index}", float(weight)) - + # Sync to pipeline if active - if app_instance.pipeline and hasattr(app_instance.pipeline, 'stream'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "stream"): app_instance._sync_appstate_to_pipeline() - + return create_success_response(f"Updated seed weight {index} to {weight}") - + except Exception as e: raise handle_api_error(e, "update_seed_weight") - diff --git a/demo/realtime-img2img/routes/pipeline_hooks.py b/demo/realtime-img2img/routes/pipeline_hooks.py index 2fd8f226f..0c7184ef6 100644 --- a/demo/realtime-img2img/routes/pipeline_hooks.py +++ b/demo/realtime-img2img/routes/pipeline_hooks.py @@ -1,14 +1,16 @@ - """ Pipeline hooks endpoints for realtime-img2img """ -from fastapi import APIRouter, Request, HTTPException, Depends -from fastapi.responses import JSONResponse + import logging -from .common.api_utils import handle_api_request, create_success_response, handle_api_error +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse + +from .common.api_utils import create_success_response, handle_api_error from .common.dependencies import get_app_instance + router = APIRouter(prefix="/api", tags=["pipeline-hooks"]) @@ -22,12 +24,13 @@ async def get_pipeline_hooks_info_config(app_instance=Depends(get_app_instance)) "image_preprocessing": app_instance.app_state.pipeline_hooks["image_preprocessing"], "image_postprocessing": app_instance.app_state.pipeline_hooks["image_postprocessing"], "latent_preprocessing": app_instance.app_state.pipeline_hooks["latent_preprocessing"], - "latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"] + "latent_postprocessing": app_instance.app_state.pipeline_hooks["latent_postprocessing"], } return JSONResponse(hooks_info) except Exception as e: raise handle_api_error(e, "get_pipeline_hooks_info_config") + # Individual hook type endpoints that frontend expects @router.get("/pipeline-hooks/image_preprocessing/info-config") async def get_image_preprocessing_info_config(app_instance=Depends(get_app_instance)): @@ -35,83 +38,95 @@ async def get_image_preprocessing_info_config(app_instance=Depends(get_app_insta try: hook_info = app_instance.app_state.pipeline_hooks["image_preprocessing"] return JSONResponse({"image_preprocessing": hook_info}) - except Exception as e: + except Exception: return JSONResponse({"image_preprocessing": None}) + @router.get("/pipeline-hooks/image_postprocessing/info-config") async def get_image_postprocessing_info_config(app_instance=Depends(get_app_instance)): """Get image postprocessing hook configuration info - SINGLE SOURCE OF TRUTH""" try: hook_info = app_instance.app_state.pipeline_hooks["image_postprocessing"] return JSONResponse({"image_postprocessing": hook_info}) - except Exception as e: + except Exception: return JSONResponse({"image_postprocessing": None}) + @router.get("/pipeline-hooks/latent_preprocessing/info-config") async def get_latent_preprocessing_info_config(app_instance=Depends(get_app_instance)): """Get latent preprocessing hook configuration info - SINGLE SOURCE OF TRUTH""" try: hook_info = app_instance.app_state.pipeline_hooks["latent_preprocessing"] return JSONResponse({"latent_preprocessing": hook_info}) - except Exception as e: + except Exception: return JSONResponse({"latent_preprocessing": None}) + @router.get("/pipeline-hooks/latent_postprocessing/info-config") async def get_latent_postprocessing_info_config(app_instance=Depends(get_app_instance)): """Get latent postprocessing hook configuration info - SINGLE SOURCE OF TRUTH""" try: hook_info = app_instance.app_state.pipeline_hooks["latent_postprocessing"] return JSONResponse({"latent_postprocessing": hook_info}) - except Exception as e: + except Exception: return JSONResponse({"latent_postprocessing": None}) + @router.get("/pipeline-hooks/{hook_type}/info") async def get_hook_processors_info(hook_type: str, app_instance=Depends(get_app_instance)): """Get available processors for a specific hook type""" try: - if hook_type not in ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"]: + if hook_type not in [ + "image_preprocessing", + "image_postprocessing", + "latent_preprocessing", + "latent_postprocessing", + ]: raise HTTPException(status_code=400, detail=f"Invalid hook type: {hook_type}") - + # Use the same processor registry as ControlNet - from streamdiffusion.preprocessing.processors import list_preprocessors, get_preprocessor_class - + from streamdiffusion.preprocessing.processors import get_preprocessor_class, list_preprocessors + available_processors = list_preprocessors() processors_info = {} - + for processor_name in available_processors: try: processor_class = get_preprocessor_class(processor_name) - if hasattr(processor_class, 'get_preprocessor_metadata'): + if hasattr(processor_class, "get_preprocessor_metadata"): metadata = processor_class.get_preprocessor_metadata() processors_info[processor_name] = { "name": metadata.get("name", processor_name), "description": metadata.get("description", ""), - "parameters": metadata.get("parameters", {}) + "parameters": metadata.get("parameters", {}), } else: processors_info[processor_name] = { "name": processor_name, "description": f"{processor_name} processor", - "parameters": {} + "parameters": {}, } except Exception as e: logging.warning(f"get_hook_processors_info: Failed to load metadata for {processor_name}: {e}") processors_info[processor_name] = { "name": processor_name, "description": f"{processor_name} processor", - "parameters": {} + "parameters": {}, } - - return JSONResponse({ - "status": "success", - "hook_type": hook_type, - "available": list(processors_info.keys()), - "preprocessors": processors_info - }) - + + return JSONResponse( + { + "status": "success", + "hook_type": hook_type, + "available": list(processors_info.keys()), + "preprocessors": processors_info, + } + ) + except Exception as e: raise handle_api_error(e, "get_hook_processors_info") + @router.post("/pipeline-hooks/{hook_type}/add") async def add_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)): """Add a new processor to a hook""" @@ -119,27 +134,28 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe data = await request.json() processor_type = data.get("processor_type") processor_params = data.get("processor_params", {}) - + if not processor_type: raise HTTPException(status_code=400, detail="Missing processor_type parameter") - + # No pipeline validation needed - AppState updates work before pipeline creation - - if hook_type not in ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"]: + + if hook_type not in [ + "image_preprocessing", + "image_postprocessing", + "latent_preprocessing", + "latent_postprocessing", + ]: raise HTTPException(status_code=400, detail=f"Invalid hook type: {hook_type}") - + logging.debug(f"add_hook_processor: Adding {processor_type} to {hook_type}") - + # Create processor config - new_processor = { - "type": processor_type, - "params": processor_params, - "enabled": True - } - + new_processor = {"type": processor_type, "params": processor_params, "enabled": True} + # Add to AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.add_hook_processor(hook_type, new_processor) - + # Update pipeline if active if app_instance.pipeline: try: @@ -148,7 +164,7 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe config_entry = { "type": processor["type"], "params": processor["params"], - "enabled": processor["enabled"] + "enabled": processor["enabled"], } hook_config.append(config_entry) update_kwargs = {f"{hook_type}_config": hook_config} @@ -157,25 +173,26 @@ async def add_hook_processor(hook_type: str, request: Request, app_instance=Depe logging.exception(f"add_hook_processor: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - + logging.info(f"add_hook_processor: Successfully added {processor_type} to {hook_type}") - + return create_success_response(f"Added {processor_type} processor to {hook_type}") - + except Exception as e: raise handle_api_error(e, "add_hook_processor") + @router.delete("/pipeline-hooks/{hook_type}/remove/{processor_index}") async def remove_hook_processor(hook_type: str, processor_index: int, app_instance=Depends(get_app_instance)): """Remove a processor from a hook""" try: # No pipeline validation needed - AppState updates work before pipeline creation - + logging.debug(f"remove_hook_processor: Removing processor {processor_index} from {hook_type}") - + # Remove from AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.remove_hook_processor(hook_type, processor_index) - + # Update pipeline if active if app_instance.pipeline: try: @@ -184,7 +201,7 @@ async def remove_hook_processor(hook_type: str, processor_index: int, app_instan config_entry = { "type": processor["type"], "params": processor["params"], - "enabled": processor["enabled"] + "enabled": processor["enabled"], } hook_config.append(config_entry) update_kwargs = {f"{hook_type}_config": hook_config} @@ -193,14 +210,15 @@ async def remove_hook_processor(hook_type: str, processor_index: int, app_instan logging.exception(f"remove_hook_processor: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - + logging.info(f"remove_hook_processor: Successfully removed processor {processor_index} from {hook_type}") - + return create_success_response(f"Removed processor {processor_index} from {hook_type}") - + except Exception as e: raise handle_api_error(e, "remove_hook_processor") + @router.post("/pipeline-hooks/{hook_type}/toggle") async def toggle_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)): """Toggle a processor enabled/disabled""" @@ -208,17 +226,19 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D data = await request.json() processor_index = data.get("processor_index") enabled = data.get("enabled") - + if processor_index is None or enabled is None: raise HTTPException(status_code=400, detail="Missing processor_index or enabled parameter") - + # No pipeline validation needed - AppState updates work before pipeline creation - - logging.debug(f"toggle_hook_processor: Toggling processor {processor_index} in {hook_type} to {'enabled' if enabled else 'disabled'}") - + + logging.debug( + f"toggle_hook_processor: Toggling processor {processor_index} in {hook_type} to {'enabled' if enabled else 'disabled'}" + ) + # Update AppState - SINGLE SOURCE OF TRUTH app_instance.app_state.update_hook_processor(hook_type, processor_index, {"enabled": bool(enabled)}) - + # Update pipeline if active if app_instance.pipeline: try: @@ -227,7 +247,7 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D config_entry = { "type": processor["type"], "params": processor["params"], - "enabled": processor["enabled"] + "enabled": processor["enabled"], } hook_config.append(config_entry) update_kwargs = {f"{hook_type}_config": hook_config} @@ -236,14 +256,17 @@ async def toggle_hook_processor(hook_type: str, request: Request, app_instance=D logging.exception(f"toggle_hook_processor: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - + logging.info(f"toggle_hook_processor: Successfully toggled processor {processor_index} in {hook_type}") - - return create_success_response(f"Processor {processor_index} in {hook_type} {'enabled' if enabled else 'disabled'}") - + + return create_success_response( + f"Processor {processor_index} in {hook_type} {'enabled' if enabled else 'disabled'}" + ) + except Exception as e: raise handle_api_error(e, "toggle_hook_processor") + @router.post("/pipeline-hooks/{hook_type}/switch") async def switch_hook_processor(hook_type: str, request: Request, app_instance=Depends(get_app_instance)): """Switch a processor to a different type""" @@ -252,45 +275,53 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D processor_index = data.get("processor_index") # Support both parameter naming conventions for compatibility new_processor_type = data.get("processor_type") or data.get("processor") - + if processor_index is None or not new_processor_type: - raise HTTPException(status_code=400, detail="Missing processor_index or processor_type/processor parameter") - + raise HTTPException( + status_code=400, detail="Missing processor_index or processor_type/processor parameter" + ) + # Handle config-only mode when no pipeline is active if not app_instance.pipeline: if not app_instance.app_state.uploaded_config: raise HTTPException(status_code=400, detail="No pipeline active and no uploaded config available") - - logging.info(f"switch_hook_processor: Updating config for {hook_type} processor {processor_index} to {new_processor_type}") - + + logging.info( + f"switch_hook_processor: Updating config for {hook_type} processor {processor_index} to {new_processor_type}" + ) + # Update the uploaded config directly hook_config = app_instance.app_state.uploaded_config.get(hook_type, {"enabled": False, "processors": []}) if processor_index >= len(hook_config.get("processors", [])): - raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}") - + raise HTTPException( + status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}" + ) + # Update processor type in config hook_config["processors"][processor_index]["type"] = new_processor_type hook_config["processors"][processor_index]["params"] = {} app_instance.app_state.uploaded_config[hook_type] = hook_config - + else: # No pipeline validation needed - AppState updates work before pipeline creation - - logging.debug(f"switch_hook_processor: Switching processor {processor_index} in {hook_type} to {new_processor_type}") - + + logging.debug( + f"switch_hook_processor: Switching processor {processor_index} in {hook_type} to {new_processor_type}" + ) + # Update AppState - SINGLE SOURCE OF TRUTH processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"] - + if processor_index >= len(processors): - raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}") - + raise HTTPException( + status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}" + ) + # Update the processor type and reset params in AppState - app_instance.app_state.update_hook_processor(hook_type, processor_index, { - "type": new_processor_type, - "name": new_processor_type, - "params": {} - }) - + app_instance.app_state.update_hook_processor( + hook_type, processor_index, {"type": new_processor_type, "name": new_processor_type, "params": {}} + ) + # Update pipeline if active if app_instance.pipeline: try: @@ -299,7 +330,7 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D config_entry = { "type": processor["type"], "params": processor["params"], - "enabled": processor["enabled"] + "enabled": processor["enabled"], } hook_config.append(config_entry) update_kwargs = {f"{hook_type}_config": hook_config} @@ -308,14 +339,17 @@ async def switch_hook_processor(hook_type: str, request: Request, app_instance=D logging.exception(f"switch_hook_processor: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - logging.info(f"switch_hook_processor: Successfully switched processor {processor_index} in {hook_type} to {new_processor_type}") - + + logging.info( + f"switch_hook_processor: Successfully switched processor {processor_index} in {hook_type} to {new_processor_type}" + ) + return create_success_response(f"Switched processor {processor_index} in {hook_type} to {new_processor_type}") - + except Exception as e: raise handle_api_error(e, "switch_hook_processor") + @router.post("/pipeline-hooks/{hook_type}/update-params") async def update_hook_processor_params(hook_type: str, request: Request, app_instance=Depends(get_app_instance)): """Update parameters for a specific processor""" @@ -323,48 +357,58 @@ async def update_hook_processor_params(hook_type: str, request: Request, app_ins logging.info(f"update_hook_processor_params: ===== STARTING {hook_type} REQUEST =====") data = await request.json() logging.info(f"update_hook_processor_params: Received data: {data}") - + processor_index = data.get("processor_index") processor_params = data.get("processor_params", {}) - logging.info(f"update_hook_processor_params: processor_index={processor_index}, processor_params={processor_params}") - + logging.info( + f"update_hook_processor_params: processor_index={processor_index}, processor_params={processor_params}" + ) + if processor_index is None: - logging.error(f"update_hook_processor_params: Missing processor_index parameter") + logging.error("update_hook_processor_params: Missing processor_index parameter") raise HTTPException(status_code=400, detail="Missing processor_index parameter") - + # No pipeline validation needed - AppState updates work before pipeline creation - + logging.debug(f"update_hook_processor_params: Updating params for processor {processor_index} in {hook_type}") - + # Check if processors exist in AppState processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"] if not processors: logging.error(f"update_hook_processor_params: Hook type {hook_type} not found or empty") - raise HTTPException(status_code=400, detail=f"No processors configured for {hook_type}. Add a processor first using the 'Add {hook_type.replace('_', ' ').title()} Processor' button.") - + raise HTTPException( + status_code=400, + detail=f"No processors configured for {hook_type}. Add a processor first using the 'Add {hook_type.replace('_', ' ').title()} Processor' button.", + ) + if processor_index >= len(processors): - logging.error(f"update_hook_processor_params: Processor index {processor_index} out of range for {hook_type} (max: {len(processors)-1})") - raise HTTPException(status_code=400, detail=f"Processor index {processor_index} not found. Only {len(processors)} processors are configured for {hook_type}.") - + logging.error( + f"update_hook_processor_params: Processor index {processor_index} out of range for {hook_type} (max: {len(processors) - 1})" + ) + raise HTTPException( + status_code=400, + detail=f"Processor index {processor_index} not found. Only {len(processors)} processors are configured for {hook_type}.", + ) + # Update the processor parameters in AppState - SINGLE SOURCE OF TRUTH logging.info(f"update_hook_processor_params: Current processor config: {processors[processor_index]}") - + # Handle 'enabled' field separately as it's a top-level processor field, not a parameter updates = {} - if 'enabled' in processor_params: - enabled_value = processor_params.pop('enabled') # Remove from params dict - updates['enabled'] = bool(enabled_value) + if "enabled" in processor_params: + enabled_value = processor_params.pop("enabled") # Remove from params dict + updates["enabled"] = bool(enabled_value) logging.info(f"update_hook_processor_params: Updated enabled field to: {enabled_value}") - + # Update remaining parameters in the params field if processor_params: # Only update if there are remaining params - current_params = processors[processor_index].get('params', {}) + current_params = processors[processor_index].get("params", {}) current_params.update(processor_params) - updates['params'] = current_params - + updates["params"] = current_params + # Apply updates to AppState app_instance.app_state.update_hook_processor(hook_type, processor_index, updates) - + # Update pipeline if active if app_instance.pipeline: try: @@ -373,29 +417,36 @@ async def update_hook_processor_params(hook_type: str, request: Request, app_ins config_entry = { "type": processor["type"], "params": processor["params"], - "enabled": processor["enabled"] + "enabled": processor["enabled"], } hook_config.append(config_entry) update_kwargs = {f"{hook_type}_config": hook_config} logging.info(f"update_hook_processor_params: Calling update_stream_params with: {update_kwargs}") app_instance.pipeline.update_stream_params(**update_kwargs) - logging.info(f"update_hook_processor_params: update_stream_params completed successfully") + logging.info("update_hook_processor_params: update_stream_params completed successfully") except Exception as e: logging.exception(f"update_hook_processor_params: Failed to update pipeline: {e}") # Mark for reload as fallback app_instance.app_state.config_needs_reload = True - - logging.info(f"update_hook_processor_params: Successfully updated params for processor {processor_index} in {hook_type}") - - return create_success_response(f"Updated parameters for processor {processor_index} in {hook_type}", updated_params=processor_params) - + + logging.info( + f"update_hook_processor_params: Successfully updated params for processor {processor_index} in {hook_type}" + ) + + return create_success_response( + f"Updated parameters for processor {processor_index} in {hook_type}", updated_params=processor_params + ) + except Exception as e: logging.exception(f"update_hook_processor_params: Exception occurred: {str(e)}") logging.error(f"update_hook_processor_params: Exception type: {type(e).__name__}") raise handle_api_error(e, "update_hook_processor_params") + @router.get("/pipeline-hooks/{hook_type}/current-params/{processor_index}") -async def get_current_hook_processor_params(hook_type: str, processor_index: int, app_instance=Depends(get_app_instance)): +async def get_current_hook_processor_params( + hook_type: str, processor_index: int, app_instance=Depends(get_app_instance) +): """Get current parameters for a specific processor""" try: # First try to get from uploaded config if no pipeline @@ -404,44 +455,50 @@ async def get_current_hook_processor_params(hook_type: str, processor_index: int processors = hook_config.get("processors", []) if processor_index < len(processors): processor = processors[processor_index] - return JSONResponse({ + return JSONResponse( + { + "status": "success", + "hook_type": hook_type, + "processor_index": processor_index, + "processor_type": processor.get("type", "unknown"), + "parameters": processor.get("params", {}), + "enabled": processor.get("enabled", True), + "note": "From uploaded config", + } + ) + + # Return empty if no config available + if not app_instance.pipeline: + return JSONResponse( + { "status": "success", "hook_type": hook_type, "processor_index": processor_index, - "processor_type": processor.get('type', 'unknown'), - "parameters": processor.get('params', {}), - "enabled": processor.get('enabled', True), - "note": "From uploaded config" - }) - - # Return empty if no config available - if not app_instance.pipeline: - return JSONResponse({ - "status": "success", - "hook_type": hook_type, - "processor_index": processor_index, - "processor_type": "unknown", - "parameters": {}, - "enabled": False, - "note": "Pipeline not initialized - no config available" - }) - + "processor_type": "unknown", + "parameters": {}, + "enabled": False, + "note": "Pipeline not initialized - no config available", + } + ) + # Use AppState - SINGLE SOURCE OF TRUTH processors = app_instance.app_state.pipeline_hooks[hook_type]["processors"] - + if processor_index >= len(processors): raise HTTPException(status_code=400, detail=f"Invalid processor index {processor_index} for {hook_type}") - + processor = processors[processor_index] - - return JSONResponse({ - "status": "success", - "hook_type": hook_type, - "processor_index": processor_index, - "processor_type": processor.get('type', 'unknown'), - "parameters": processor.get('params', {}), - "enabled": processor.get('enabled', True) - }) - + + return JSONResponse( + { + "status": "success", + "hook_type": hook_type, + "processor_index": processor_index, + "processor_type": processor.get("type", "unknown"), + "parameters": processor.get("params", {}), + "enabled": processor.get("enabled", True), + } + ) + except Exception as e: - raise handle_api_error(e, "get_current_hook_processor_params") \ No newline at end of file + raise handle_api_error(e, "get_current_hook_processor_params") diff --git a/demo/realtime-img2img/routes/websocket.py b/demo/realtime-img2img/routes/websocket.py index 48f37a2d7..2a2c87a0a 100644 --- a/demo/realtime-img2img/routes/websocket.py +++ b/demo/realtime-img2img/routes/websocket.py @@ -1,33 +1,40 @@ """ WebSocket endpoints for realtime-img2img """ -from fastapi import APIRouter, WebSocket, HTTPException, Depends + import logging -import uuid import time +import uuid from types import SimpleNamespace -from util import bytes_to_pt from connection_manager import ServerFullException -from .common.dependencies import get_app_instance, get_pipeline_class +from fastapi import APIRouter, Depends, HTTPException, WebSocket from input_sources import InputSourceManager +from util import bytes_to_pt + +from .common.dependencies import get_app_instance, get_pipeline_class + router = APIRouter(prefix="/api", tags=["websocket"]) def _get_input_source_manager(app_instance) -> InputSourceManager: """Get or create the input source manager for the app instance.""" - if not hasattr(app_instance, 'input_source_manager'): + if not hasattr(app_instance, "input_source_manager"): app_instance.input_source_manager = InputSourceManager() return app_instance.input_source_manager + @router.websocket("/ws/{user_id}") -async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket, app_instance=Depends(get_app_instance), pipeline_class=Depends(get_pipeline_class)): +async def websocket_endpoint( + user_id: uuid.UUID, + websocket: WebSocket, + app_instance=Depends(get_app_instance), + pipeline_class=Depends(get_pipeline_class), +): """Main WebSocket endpoint for real-time communication""" try: - await app_instance.conn_manager.connect( - user_id, websocket, app_instance.args.max_queue_size - ) + await app_instance.conn_manager.connect(user_id, websocket, app_instance.args.max_queue_size) await handle_websocket_data(user_id, app_instance, pipeline_class) except ServerFullException as e: logging.exception(f"websocket_endpoint: Server Full: {e}") @@ -35,6 +42,7 @@ async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket, app_insta await app_instance.conn_manager.disconnect(user_id) logging.info(f"websocket_endpoint: User disconnected: {user_id}") + async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class): """Handle WebSocket data flow for a specific user""" if not app_instance.conn_manager.check_user(user_id): @@ -42,10 +50,7 @@ async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class last_time = time.time() try: while True: - if ( - app_instance.args.timeout > 0 - and time.time() - last_time > app_instance.args.timeout - ): + if app_instance.args.timeout > 0 and time.time() - last_time > app_instance.args.timeout: await app_instance.conn_manager.send_json( user_id, { @@ -62,45 +67,45 @@ async def handle_websocket_data(user_id: uuid.UUID, app_instance, pipeline_class params = await app_instance.conn_manager.receive_json(user_id) params = pipeline_class.InputParams(**params) params = SimpleNamespace(**params.dict()) - + # Check if we need image data based on pipeline need_image = True - if app_instance.pipeline and hasattr(app_instance.pipeline, 'pipeline_mode'): + if app_instance.pipeline and hasattr(app_instance.pipeline, "pipeline_mode"): # Need image for img2img OR for txt2img with ControlNets - has_controlnets = app_instance.pipeline.use_config and app_instance.pipeline.config and 'controlnets' in app_instance.pipeline.config + has_controlnets = ( + app_instance.pipeline.use_config + and app_instance.pipeline.config + and "controlnets" in app_instance.pipeline.config + ) need_image = app_instance.pipeline.pipeline_mode == "img2img" or has_controlnets - elif app_instance.app_state.uploaded_config and 'mode' in app_instance.app_state.uploaded_config: + elif app_instance.app_state.uploaded_config and "mode" in app_instance.app_state.uploaded_config: # Need image for img2img OR for txt2img with ControlNets - has_controlnets = 'controlnets' in app_instance.app_state.uploaded_config - need_image = app_instance.app_state.uploaded_config['mode'] == "img2img" or has_controlnets - + has_controlnets = "controlnets" in app_instance.app_state.uploaded_config + need_image = app_instance.app_state.uploaded_config["mode"] == "img2img" or has_controlnets + # Get input source manager input_manager = _get_input_source_manager(app_instance) - + if need_image: # Receive main webcam stream (fallback) image_data = await app_instance.conn_manager.receive_bytes(user_id) if len(image_data) == 0: - await app_instance.conn_manager.send_json( - user_id, {"status": "send_frame"} - ) + await app_instance.conn_manager.send_json(user_id, {"status": "send_frame"}) continue - + # Update webcam frame in input manager for all webcam sources input_manager.update_webcam_frame(image_data) - + # Always use direct bytes-to-tensor conversion for efficiency params.image = bytes_to_pt(image_data) else: params.image = None - + # Store the input manager reference in params for later use by img2img.py params.input_manager = input_manager - + await app_instance.conn_manager.update_data(user_id, params) except Exception as e: logging.exception(f"handle_websocket_data: Websocket Error: {e}, {user_id} ") await app_instance.conn_manager.disconnect(user_id) - - diff --git a/demo/realtime-img2img/util.py b/demo/realtime-img2img/util.py index 0ef21421b..b311c4faa 100644 --- a/demo/realtime-img2img/util.py +++ b/demo/realtime-img2img/util.py @@ -1,11 +1,10 @@ +import io from importlib import import_module from types import ModuleType -from typing import Dict, Any -from pydantic import BaseModel as PydanticBaseModel, Field -from PIL import Image -import io + import torch -from torchvision.io import encode_jpeg, decode_jpeg +from PIL import Image +from torchvision.io import decode_jpeg, encode_jpeg def get_pipeline_class(pipeline_name: str) -> ModuleType: @@ -30,22 +29,22 @@ def bytes_to_pil(image_bytes: bytes) -> Image.Image: def bytes_to_pt(image_bytes: bytes) -> torch.Tensor: """ Convert JPEG/PNG bytes directly to PyTorch tensor using torchvision - + Args: image_bytes: Raw image bytes (JPEG/PNG format) - + Returns: torch.Tensor: Image tensor with shape (C, H, W), values in [0, 1], dtype float32 """ # Convert bytes to tensor for torchvision byte_tensor = torch.frombuffer(image_bytes, dtype=torch.uint8) - + # Decode JPEG/PNG directly to tensor (C, H, W) format, uint8 [0, 255] image_tensor = decode_jpeg(byte_tensor) - + # Convert to float32 and normalize to [0, 1] image_tensor = image_tensor.float() / 255.0 - + return image_tensor @@ -65,24 +64,24 @@ def pil_to_frame(image: Image.Image) -> bytes: def pt_to_frame(tensor: torch.Tensor) -> bytes: """ Convert PyTorch tensor directly to JPEG frame bytes using torchvision - + Args: tensor: PyTorch tensor with shape (C, H, W) or (1, C, H, W), values in [0, 1] - + Returns: bytes: JPEG frame data for streaming """ # Handle batch dimension - take first image if batched if tensor.dim() == 4: tensor = tensor[0] - + # Convert to uint8 format (0-255) and ensure correct shape (C, H, W) tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8) - + # Encode directly to JPEG bytes using torchvision jpeg_bytes = encode_jpeg(tensor_uint8, quality=90) frame_data = jpeg_bytes.cpu().numpy().tobytes() - + return ( b"--frame\r\n" + b"Content-Type: image/jpeg\r\n" diff --git a/demo/realtime-img2img/utils/video_utils.py b/demo/realtime-img2img/utils/video_utils.py index ad092415f..4ba9d4924 100644 --- a/demo/realtime-img2img/utils/video_utils.py +++ b/demo/realtime-img2img/utils/video_utils.py @@ -6,25 +6,26 @@ """ import logging +from pathlib import Path +from typing import Optional, Tuple + import cv2 import numpy as np import torch -from pathlib import Path -from typing import Optional, Tuple class VideoFrameExtractor: """ Extracts frames from video files for use as input sources. - + Handles video playback, looping, and frame extraction with automatic conversion to PyTorch tensors. """ - + def __init__(self, video_path: str): """ Initialize the video frame extractor. - + Args: video_path: Path to the video file """ @@ -34,34 +35,35 @@ def __init__(self, video_path: str): self.frame_count = 0 self.current_frame_idx = 0 self._logger = logging.getLogger(f"VideoFrameExtractor.{self.video_path.name}") - + self._initialize_capture() - + def _initialize_capture(self): """Initialize the video capture object.""" if not self.video_path.exists(): self._logger.error(f"Video file not found: {self.video_path}") return - + self.cap = cv2.VideoCapture(str(self.video_path)) - + if not self.cap.isOpened(): self._logger.error(f"Failed to open video file: {self.video_path}") return - + # Get video properties self.fps = self.cap.get(cv2.CAP_PROP_FPS) self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - - self._logger.info(f"Initialized video: {self.video_path.name}, " - f"FPS: {self.fps:.2f}, Frames: {self.frame_count}") - + + self._logger.info( + f"Initialized video: {self.video_path.name}, FPS: {self.fps:.2f}, Frames: {self.frame_count}" + ) + def get_frame(self) -> Optional[torch.Tensor]: """ Extract the current frame and advance to the next frame. - + Automatically loops back to the beginning when reaching the end. - + Returns: torch.Tensor: Frame as tensor with shape (C, H, W), values in [0, 1], dtype float32 None: If frame extraction fails @@ -69,101 +71,101 @@ def get_frame(self) -> Optional[torch.Tensor]: if not self.cap or not self.cap.isOpened(): self._logger.error("Video capture not initialized or closed") return None - + ret, frame = self.cap.read() - + if not ret: # End of video, loop back to beginning self._logger.debug("End of video reached, looping back to start") self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) self.current_frame_idx = 0 ret, frame = self.cap.read() - + if not ret: self._logger.error("Failed to read frame even after reset") return None - + self.current_frame_idx += 1 - + # Convert frame to tensor return self._frame_to_tensor(frame) - + def get_frame_at_time(self, timestamp: float) -> Optional[torch.Tensor]: """ Get frame at a specific timestamp. - + Args: timestamp: Time in seconds - + Returns: torch.Tensor: Frame at the specified time or None if failed """ if not self.cap or not self.cap.isOpened(): return None - + # Convert timestamp to frame number frame_number = int(timestamp * self.fps) frame_number = max(0, min(frame_number, self.frame_count - 1)) - + # Seek to frame self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number) self.current_frame_idx = frame_number - + ret, frame = self.cap.read() if ret: return self._frame_to_tensor(frame) - + return None - + def _frame_to_tensor(self, frame: np.ndarray) -> torch.Tensor: """ Convert OpenCV frame to PyTorch tensor. - + Args: frame: OpenCV frame in BGR format - + Returns: torch.Tensor: Frame tensor in RGB format with shape (C, H, W) """ # Convert BGR to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - + # Convert to tensor and normalize frame_tensor = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0 - + return frame_tensor - + def get_video_info(self) -> dict: """ Get information about the video. - + Returns: Dictionary with video metadata """ if not self.cap: return {} - + width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) duration = self.frame_count / self.fps if self.fps > 0 else 0 - + return { - 'path': str(self.video_path), - 'fps': self.fps, - 'frame_count': self.frame_count, - 'width': width, - 'height': height, - 'duration': duration, - 'current_frame': self.current_frame_idx + "path": str(self.video_path), + "fps": self.fps, + "frame_count": self.frame_count, + "width": width, + "height": height, + "duration": duration, + "current_frame": self.current_frame_idx, } - + def reset(self): """Reset video to beginning.""" if self.cap: self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) self.current_frame_idx = 0 self._logger.debug("Video reset to beginning") - + def cleanup(self): """Release video capture resources.""" if self.cap: @@ -175,25 +177,25 @@ def cleanup(self): def get_video_thumbnail(video_path: str, timestamp: float = 0.0) -> Optional[torch.Tensor]: """ Get a thumbnail frame from a video file. - + Args: video_path: Path to the video file timestamp: Time in seconds to extract thumbnail from - + Returns: torch.Tensor: Thumbnail frame or None if failed """ try: extractor = VideoFrameExtractor(video_path) - + if timestamp > 0: thumbnail = extractor.get_frame_at_time(timestamp) else: thumbnail = extractor.get_frame() - + extractor.cleanup() return thumbnail - + except Exception as e: logging.getLogger("video_utils").error(f"Failed to get thumbnail: {e}") return None @@ -202,52 +204,63 @@ def get_video_thumbnail(video_path: str, timestamp: float = 0.0) -> Optional[tor def validate_video_file(video_path: str) -> Tuple[bool, str]: """ Validate if a file is a readable video. - + Args: video_path: Path to the video file - + Returns: Tuple of (is_valid, error_message) """ try: path = Path(video_path) - + if not path.exists(): return False, "Video file does not exist" - + # Try to open with OpenCV cap = cv2.VideoCapture(str(path)) - + if not cap.isOpened(): return False, "Cannot open video file" - + # Try to read first frame ret, frame = cap.read() cap.release() - + if not ret: return False, "Cannot read frames from video" - + return True, "Video file is valid" - + except Exception as e: return False, f"Video validation error: {str(e)}" # Supported video formats SUPPORTED_VIDEO_FORMATS = { - '.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.wmv', - '.m4v', '.3gp', '.ogv', '.ts', '.m2ts', '.mts' + ".mp4", + ".avi", + ".mov", + ".mkv", + ".webm", + ".flv", + ".wmv", + ".m4v", + ".3gp", + ".ogv", + ".ts", + ".m2ts", + ".mts", } def is_supported_video_format(filename: str) -> bool: """ Check if a file has a supported video format. - + Args: filename: Name or path of the file - + Returns: bool: True if format is supported """ diff --git a/demo/realtime-txt2img/config.py b/demo/realtime-txt2img/config.py index 354941485..f7bb12403 100644 --- a/demo/realtime-txt2img/config.py +++ b/demo/realtime-txt2img/config.py @@ -1,8 +1,9 @@ +import os from dataclasses import dataclass, field from typing import List, Literal import torch -import os + SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "False") == "True" diff --git a/demo/vid2vid/app.py b/demo/vid2vid/app.py index d03da15e2..c7f4a37ff 100644 --- a/demo/vid2vid/app.py +++ b/demo/vid2vid/app.py @@ -1,18 +1,18 @@ -import gradio as gr - import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional -import fire +import gradio as gr import torch from torchvision.io import read_video, write_video from tqdm import tqdm + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -28,7 +28,6 @@ def main( enable_similar_image_filter: bool = True, seed: int = 2, ): - """ Process for generating images based on a prompt using a specified model. @@ -106,9 +105,5 @@ def main( return output -demo = gr.Interface( - main, - gr.Video(sources=['upload', 'webcam']), - "playable_video" -) +demo = gr.Interface(main, gr.Video(sources=["upload", "webcam"]), "playable_video") demo.launch() diff --git a/docs/pr_audit_2026-05-16.md b/docs/pr_audit_2026-05-16.md new file mode 100644 index 000000000..4fcde6c3c --- /dev/null +++ b/docs/pr_audit_2026-05-16.md @@ -0,0 +1,44 @@ +# StreamDiffusion PR Audit — 2026-05-16 + +Cross-check of local dev branch vs dotsimulate/StreamDiffusion upstream to determine which open PRs are effectively landed. + +## PR status + +| PR | Title | GitHub state | Actual landed state | +|---|---|---|---| +| **#4** | perf: Inference performance & pipeline correctness | OPEN | **Effectively landed** — cherry-picked as `ef50c0d` into `dotsimulate/feat/trt10.16-fp8-perf`, then merged via PR #12 | +| **#5** | feat: IP-Adapter auto-res, VRAM offload & dep updates | OPEN | **Effectively landed** — cherry-pick `18cf3f9`, then PR #12 | +| **#6** | feat: FP8 quantization & TensorRT build infra | OPEN | **Effectively landed** — cherry-pick `85ca135` (rewritten as ONNX-level quantization), then PR #12 | +| **#7** | perf: Tier 1 hot-path elimination & TE stutter fix | OPEN | **Effectively landed** — cherry-pick `c0819aa`, then PR #12 | +| **#8** | perf: TRT engine builder — static shapes, CUDA graphs | OPEN | **Effectively landed** — cherry-pick `b4be27d`, then PR #12 | +| **#9** | chore(installer): overhaul (superseded by #10) | CLOSED | n/a | +| **#10** | chore(installer): overhaul install scripts | OPEN | **Effectively landed** — cherry-pick `ea4d3f1`, then PR #12 | +| **#11** | fix(trt): missing build_all_tactics param | MERGED | Landed cleanly — `4d0cb2b` on SDTD_031_dev, 2026-04-22 | +| **#12** | feat: TRT 10.16.1.11 + FP8 + CUDA hot-path | MERGED | Landed with conflict resolution — `70b0523` (2026-05-10) + cleanup `13b6651` (2026-05-13) | + +## Key refs + +- Squashed PR #12 source: `9610624` on `origin/dotsimulate/feat/trt10.16-fp8-perf` +- PR #12 merge into SDTD_031_dev: `70b0523` (2026-05-10) +- Formatter-noise cleanup: `13b6651` (2026-05-13) +- PR #11 merge: `4d0cb2b` (2026-04-22) +- Cherry-pick anchors (Apr 9): `ef50c0d`, `18cf3f9`, `85ca135`, `c0819aa`, `b4be27d`, `ea4d3f1` + +## Dev repo state + +- Checked-out branch: `dotsimulate/feat/trt10.16-fp8-perf` at `9610624` +- Local `forkni/feat-fp8-torch-calibration-v2`: **35 commits ahead** of origin — unsquashed Phase 1 history; content likely fully covered by the squashed `9610624` (squash was created 8 min after last commit) +- `origin/SDTD_031_dev` is stale at `4d0cb2b` — run `git fetch origin` to get `13b6651` + +## Recommended next steps + +1. `git fetch origin` in dev repo to update remote-tracking refs +2. Close PRs #4, #5, #6, #7, #8, #10 on GitHub — comment referencing `70b0523` as superseding merge +3. Before closing #6, confirm ONNX-level FP8 path (PR #12) is intended replacement for torch-mode approach in pr2 branch +4. Verify local `forkni/feat-fp8-torch-calibration-v2` 35 ahead commits are fully covered by squash: `git diff forkni/feat-fp8-torch-calibration-v2 origin/dotsimulate/feat/trt10.16-fp8-perf --stat` +5. FP8 Phase 2: open new branch off SDTD_031_dev at `13b6651` +6. StreamDiffusion-installer: commit `8c8020a` in installer repo — Phase 1 blocker resolved, safe to revisit push (needs maintainer access, 403 on direct push) + +## Process lesson + +PR #12 merge required reverting 27 formatter-noise files and re-adding 4 real params afterward. Future PRs: run a pre-PR pass with `git diff --stat` to catch quote flips/whitespace/import reorder before submitting. diff --git a/examples/benchmark/multi.py b/examples/benchmark/multi.py index bfa971cf2..443835b99 100644 --- a/examples/benchmark/multi.py +++ b/examples/benchmark/multi.py @@ -3,7 +3,7 @@ import sys import time from multiprocessing import Process, Queue -from typing import List, Literal, Optional, Dict +from typing import Dict, List, Literal, Optional import fire import PIL.Image @@ -13,6 +13,7 @@ from streamdiffusion.image_utils import postprocess_image + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from streamdiffusion import StreamDiffusionWrapper diff --git a/examples/benchmark/single.py b/examples/benchmark/single.py index 5e55fb633..133abc13b 100644 --- a/examples/benchmark/single.py +++ b/examples/benchmark/single.py @@ -1,7 +1,7 @@ import io import os import sys -from typing import List, Literal, Optional, Dict +from typing import Dict, List, Literal, Optional import fire import PIL.Image @@ -101,9 +101,7 @@ def run( delta=0.5, ) - downloaded_image = download_image("https://github.com/ddpn08.png").resize( - (width, height) - ) + downloaded_image = download_image("https://github.com/ddpn08.png").resize((width, height)) # warmup for _ in range(warmup): diff --git a/examples/config/config_ipadapter_stream_test.py b/examples/config/config_ipadapter_stream_test.py index a66c01e19..b8141e033 100644 --- a/examples/config/config_ipadapter_stream_test.py +++ b/examples/config/config_ipadapter_stream_test.py @@ -13,205 +13,217 @@ - Tests the IPAdapter stream behavior fix """ -import cv2 -import torch -import numpy as np -from PIL import Image import argparse -from pathlib import Path -import sys -import time +import json import os import shutil -import json -from collections import deque +import sys +import time +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image def tensor_to_opencv(tensor: torch.Tensor, target_width: int, target_height: int) -> np.ndarray: """ Convert a PyTorch tensor (output_type='pt') to OpenCV BGR format for video writing. Uses efficient tensor operations similar to the realtime-img2img demo. - + Args: tensor: Tensor in range [0,1] with shape [B, C, H, W] or [C, H, W] target_width: Target width for output target_height: Target height for output - + Returns: BGR numpy array ready for OpenCV """ # Handle batch dimension - take first image if batched if tensor.dim() == 4: tensor = tensor[0] - + # Convert to uint8 format (0-255) and ensure correct shape (C, H, W) tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8) - + # Convert from [C, H, W] to [H, W, C] format if tensor_uint8.dim() == 3: image_np = tensor_uint8.permute(1, 2, 0).cpu().numpy() else: raise ValueError(f"tensor_to_opencv: Unexpected tensor shape: {tensor_uint8.shape}") - + # Convert RGB to BGR for OpenCV image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) - + # Resize if needed if image_bgr.shape[:2] != (target_height, target_width): image_bgr = cv2.resize(image_bgr, (target_width, target_height)) - + return image_bgr def process_video_ipadapter_stream(config_path, input_video, static_image, output_dir, engine_only=False): """Process video using IPAdapter as primary driving force with static base image""" print(f"process_video_ipadapter_stream: Loading config from {config_path}") - + # Import here to avoid loading at module level sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) - from streamdiffusion import load_config, create_wrapper_from_config - + from streamdiffusion import create_wrapper_from_config, load_config + # Load configuration config = load_config(config_path) - + # Force tensor output for better performance - config['output_type'] = 'pt' - + config["output_type"] = "pt" + # Get width and height from config (with defaults) - width = config.get('width', 512) - height = config.get('height', 512) - + width = config.get("width", 512) + height = config.get("height", 512) + print(f"process_video_ipadapter_stream: Using dimensions: {width}x{height}") - print(f"process_video_ipadapter_stream: Using output_type='pt' for better performance") - + print("process_video_ipadapter_stream: Using output_type='pt' for better performance") + # Create output directory output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - + # Copy config, input video, and static image to output directory config_copy_path = output_dir / f"config_{Path(config_path).name}" shutil.copy2(config_path, config_copy_path) print(f"process_video_ipadapter_stream: Copied config to {config_copy_path}") - + input_copy_path = output_dir / f"input_{Path(input_video).name}" shutil.copy2(input_video, input_copy_path) print(f"process_video_ipadapter_stream: Copied input video to {input_copy_path}") - + static_copy_path = output_dir / f"static_{Path(static_image).name}" shutil.copy2(static_image, static_copy_path) print(f"process_video_ipadapter_stream: Copied static image to {static_copy_path}") - + # Create wrapper using the built-in function wrapper = create_wrapper_from_config(config) - + if engine_only: print("Engine-only mode: TensorRT engines have been built (if needed). Exiting.") return None - + # Load and prepare static image static_img = Image.open(static_image) static_img = static_img.resize((width, height), Image.Resampling.LANCZOS) print(f"process_video_ipadapter_stream: Loaded static image: {static_image}") - + # Open input video cap = cv2.VideoCapture(str(input_video)) if not cap.isOpened(): raise ValueError(f"Could not open input video: {input_video}") - + # Get video properties fps = cap.get(cv2.CAP_PROP_FPS) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + print(f"process_video_ipadapter_stream: Input video - {frame_count} frames at {fps} FPS") - + # Setup output video writer (3-panel display: input, static, generated) output_video_path = output_dir / "output_video.mp4" - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (width * 3, height)) - + # Performance tracking frame_times = [] total_start_time = time.time() - + print("process_video_ipadapter_stream: Starting IPAdapter stream processing...") print("process_video_ipadapter_stream: Using static image as base input, video frames for ControlNet + IPAdapter") - + frame_idx = 0 while True: ret, frame = cap.read() if not ret: break - + frame_start_time = time.time() - + # Resize frame frame_resized = cv2.resize(frame, (width, height)) - + # Convert frame to PIL frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB) frame_pil = Image.fromarray(frame_rgb) - + # Update ControlNet control images (structural guidance from video frames) - if hasattr(wrapper.stream, '_controlnet_module') and wrapper.stream._controlnet_module: + if hasattr(wrapper.stream, "_controlnet_module") and wrapper.stream._controlnet_module: controlnet_count = len(wrapper.stream._controlnet_module.controlnets) - print(f"process_video_ipadapter_stream: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}") + print( + f"process_video_ipadapter_stream: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}" + ) for i in range(controlnet_count): wrapper.update_control_image(i, frame_pil) else: print(f"process_video_ipadapter_stream: No ControlNet module found for frame {frame_idx}") - + # Update IPAdapter style image (style/content guidance from video frames) # This is the key part - using video frames as IPAdapter style images with is_stream=True - if hasattr(wrapper.stream, '_ipadapter_module') and wrapper.stream._ipadapter_module: - print(f"process_video_ipadapter_stream: Updating IPAdapter style image on frame {frame_idx} (is_stream=True)") + if hasattr(wrapper.stream, "_ipadapter_module") and wrapper.stream._ipadapter_module: + print( + f"process_video_ipadapter_stream: Updating IPAdapter style image on frame {frame_idx} (is_stream=True)" + ) # Update style image with is_stream=True for pipelined processing wrapper.update_style_image(frame_pil, is_stream=True) else: print(f"process_video_ipadapter_stream: No IPAdapter module found for frame {frame_idx}") - + # Process with static image as base input (this is the key difference) # The static image provides the base structure, while ControlNet and IPAdapter # provide the dynamic guidance from the video frames output_tensor = wrapper(static_img) - + # Convert tensor output to OpenCV BGR format output_bgr = tensor_to_opencv(output_tensor, width, height) - + # Convert static image to display format static_array = np.array(static_img) static_bgr = cv2.cvtColor(static_array, cv2.COLOR_RGB2BGR) - + # Create 3-panel display: Input Video | Static Base | Generated Output combined = np.hstack([frame_resized, static_bgr, output_bgr]) - + # Add labels cv2.putText(combined, "Input Video", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) cv2.putText(combined, "Static Base", (width + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - cv2.putText(combined, "Generated", (width * 2 + 10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - + cv2.putText(combined, "Generated", (width * 2 + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + # Add frame info - cv2.putText(combined, f"Frame: {frame_idx}/{frame_count}", (10, height - 20), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) - + cv2.putText( + combined, + f"Frame: {frame_idx}/{frame_count}", + (10, height - 20), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + # Write frame out.write(combined) - + # Track performance frame_time = time.time() - frame_start_time frame_times.append(frame_time) - + frame_idx += 1 if frame_idx % 10 == 0: avg_fps = len(frame_times) / sum(frame_times) if frame_times else 0 - print(f"process_video_ipadapter_stream: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})") - + print( + f"process_video_ipadapter_stream: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})" + ) + total_time = time.time() - total_start_time - + # Cleanup cap.release() out.release() - + # Calculate performance metrics if frame_times: avg_frame_time = sum(frame_times) / len(frame_times) @@ -222,7 +234,7 @@ def process_video_ipadapter_stream(config_path, input_video, static_image, outpu min_fps = 1.0 / max_frame_time else: avg_frame_time = avg_fps = min_frame_time = max_frame_time = max_fps = min_fps = 0 - + # Performance metrics metrics = { "input_video": str(input_video), @@ -238,75 +250,94 @@ def process_video_ipadapter_stream(config_path, input_video, static_image, outpu "avg_frame_time_seconds": avg_frame_time, "min_frame_time_seconds": min_frame_time, "max_frame_time_seconds": max_frame_time, - "model_id": config['model_id'], - "acceleration": config.get('acceleration', 'none'), - "frame_buffer_size": config.get('frame_buffer_size', 1), - "num_inference_steps": config.get('num_inference_steps', 50), - "guidance_scale": config.get('guidance_scale', 1.1), - "controlnets": [cn['model_id'] for cn in config.get('controlnets', [])], - "ipadapter_configs": [ip['ipadapter_model_path'] for ip in config.get('ipadapter_config', [])], + "model_id": config["model_id"], + "acceleration": config.get("acceleration", "none"), + "frame_buffer_size": config.get("frame_buffer_size", 1), + "num_inference_steps": config.get("num_inference_steps", 50), + "guidance_scale": config.get("guidance_scale", 1.1), + "controlnets": [cn["model_id"] for cn in config.get("controlnets", [])], + "ipadapter_configs": [ip["ipadapter_model_path"] for ip in config.get("ipadapter_config", [])], "test_type": "ipadapter_stream_test", "is_stream_enabled": True, "output_type": "pt", - "description": "IPAdapter as primary driving force with static base image using tensor output for performance" + "description": "IPAdapter as primary driving force with static base image using tensor output for performance", } - + # Save metrics metrics_path = output_dir / "performance_metrics.json" - with open(metrics_path, 'w') as f: + with open(metrics_path, "w") as f: json.dump(metrics, f, indent=2) - - print(f"process_video_ipadapter_stream: Processing completed!") + + print("process_video_ipadapter_stream: Processing completed!") print(f"process_video_ipadapter_stream: Output video saved to: {output_video_path}") print(f"process_video_ipadapter_stream: Performance metrics saved to: {metrics_path}") print(f"process_video_ipadapter_stream: Average FPS: {avg_fps:.2f}") print(f"process_video_ipadapter_stream: Total time: {total_time:.2f} seconds") - print(f"process_video_ipadapter_stream: Test completed - IPAdapter stream behavior verified") - + print("process_video_ipadapter_stream: Test completed - IPAdapter stream behavior verified") + return metrics def main(): - parser = argparse.ArgumentParser(description="IPAdapter Stream Test Demo - Tests IPAdapter as primary driving force") - - parser.add_argument("--config", type=str, required=True, - help="Path to configuration file (must include both ControlNet and IPAdapter configs)") - parser.add_argument("--input-video", type=str, required=True, - help="Path to input video file (used for both ControlNet and IPAdapter guidance)") - parser.add_argument("--static-image", type=str, required=True, - help="Path to static image file (used as base input to StreamDiffusion)") - parser.add_argument("--output-dir", type=str, default="output", - help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.") - parser.add_argument("--engine-only", action="store_true", - help="Only build TensorRT engines and exit (no video processing)") - + parser = argparse.ArgumentParser( + description="IPAdapter Stream Test Demo - Tests IPAdapter as primary driving force" + ) + + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to configuration file (must include both ControlNet and IPAdapter configs)", + ) + parser.add_argument( + "--input-video", + type=str, + required=True, + help="Path to input video file (used for both ControlNet and IPAdapter guidance)", + ) + parser.add_argument( + "--static-image", + type=str, + required=True, + help="Path to static image file (used as base input to StreamDiffusion)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="output", + help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.", + ) + parser.add_argument( + "--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)" + ) + args = parser.parse_args() - + # Create timestamped subdirectory within the specified parent directory timestamp = time.strftime("%Y%m%d_%H%M%S") input_name = Path(args.input_video).stem static_name = Path(args.static_image).stem config_name = Path(args.config).stem subdir_name = f"ipadapter_stream_test_{config_name}_{input_name}_{static_name}_{timestamp}" - + # Combine parent directory with generated subdirectory name final_output_dir = Path(args.output_dir) / subdir_name args.output_dir = str(final_output_dir) print(f"main: Using output directory: {args.output_dir}") - + # Validate input files if not Path(args.config).exists(): print(f"main: Error - Config file not found: {args.config}") return 1 - + if not Path(args.input_video).exists(): print(f"main: Error - Input video not found: {args.input_video}") return 1 - + if not Path(args.static_image).exists(): print(f"main: Error - Static image not found: {args.static_image}") return 1 - + print("IPAdapter Stream Test Demo") print("=" * 50) print(f"main: Config: {args.config}") @@ -321,14 +352,10 @@ def main(): print("- is_stream=True → High-throughput pipelined processing") print("- Tests IPAdapter stream behavior fix") print("=" * 50) - + try: metrics = process_video_ipadapter_stream( - args.config, - args.input_video, - args.static_image, - args.output_dir, - engine_only=args.engine_only + args.config, args.input_video, args.static_image, args.output_dir, engine_only=args.engine_only ) if args.engine_only: print("main: Engine-only mode completed successfully!") @@ -337,6 +364,7 @@ def main(): return 0 except Exception as e: import traceback + print(f"main: Error during processing: {e}") print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}") return 1 diff --git a/examples/config/config_video_test.py b/examples/config/config_video_test.py index 7ff0a0208..69a8d39e6 100644 --- a/examples/config/config_video_test.py +++ b/examples/config/config_video_test.py @@ -7,136 +7,136 @@ of the config and input video to an output directory. """ -import cv2 -import torch -import numpy as np -from PIL import Image import argparse -from pathlib import Path -import sys -import time +import json import os import shutil -import json -from collections import deque +import sys +import time +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image def tensor_to_opencv(tensor: torch.Tensor, target_width: int, target_height: int) -> np.ndarray: """ Convert a PyTorch tensor (output_type='pt') to OpenCV BGR format for video writing. Uses efficient tensor operations similar to the realtime-img2img demo. - + Args: tensor: Tensor in range [0,1] with shape [B, C, H, W] or [C, H, W] target_width: Target width for output target_height: Target height for output - + Returns: BGR numpy array ready for OpenCV """ # Handle batch dimension - take first image if batched if tensor.dim() == 4: tensor = tensor[0] - + # Convert to uint8 format (0-255) and ensure correct shape (C, H, W) tensor_uint8 = (tensor * 255).clamp(0, 255).to(torch.uint8) - + # Convert from [C, H, W] to [H, W, C] format if tensor_uint8.dim() == 3: image_np = tensor_uint8.permute(1, 2, 0).cpu().numpy() else: raise ValueError(f"tensor_to_opencv: Unexpected tensor shape: {tensor_uint8.shape}") - + # Convert RGB to BGR for OpenCV image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) - + # Resize if needed if image_bgr.shape[:2] != (target_height, target_width): image_bgr = cv2.resize(image_bgr, (target_width, target_height)) - + return image_bgr def process_video(config_path, input_video, output_dir, engine_only=False): """Process video through ControlNet pipeline""" print(f"process_video: Loading config from {config_path}") - + # Import here to avoid loading at module level sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) - from streamdiffusion import load_config, create_wrapper_from_config - + from streamdiffusion import create_wrapper_from_config, load_config + # Load configuration config = load_config(config_path) - + # Force tensor output for better performance - config['output_type'] = 'pt' - + config["output_type"] = "pt" + # Get width and height from config (with defaults) - width = config.get('width', 512) - height = config.get('height', 512) - + width = config.get("width", 512) + height = config.get("height", 512) + print(f"process_video: Using dimensions: {width}x{height}") - print(f"process_video: Using output_type='pt' for better performance") - + print("process_video: Using output_type='pt' for better performance") + # Create output directory output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - + # Copy config and input video to output directory config_copy_path = output_dir / f"config_{Path(config_path).name}" shutil.copy2(config_path, config_copy_path) print(f"process_video: Copied config to {config_copy_path}") - + input_copy_path = output_dir / f"input_{Path(input_video).name}" shutil.copy2(input_video, input_copy_path) print(f"process_video: Copied input video to {input_copy_path}") - + # Create wrapper using the built-in function (width/height from config) wrapper = create_wrapper_from_config(config) - + if engine_only: print("Engine-only mode: TensorRT engines have been built (if needed). Exiting.") return None - + # Open input video cap = cv2.VideoCapture(str(input_video)) if not cap.isOpened(): raise ValueError(f"Could not open input video: {input_video}") - + # Get video properties fps = cap.get(cv2.CAP_PROP_FPS) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + print(f"process_video: Input video - {frame_count} frames at {fps} FPS") - + # Setup output video writer output_video_path = output_dir / "output_video.mp4" - fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (width + width, height)) - + # Performance tracking frame_times = [] total_start_time = time.time() - + print("process_video: Starting video processing...") - + frame_idx = 0 while True: ret, frame = cap.read() if not ret: break - + frame_start_time = time.time() - + # Resize frame frame_resized = cv2.resize(frame, (width, height)) - + # Convert frame to PIL frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB) frame_pil = Image.fromarray(frame_rgb) - + # Update control image for all configured ControlNets - if hasattr(wrapper.stream, '_controlnet_module') and wrapper.stream._controlnet_module: + if hasattr(wrapper.stream, "_controlnet_module") and wrapper.stream._controlnet_module: controlnet_count = len(wrapper.stream._controlnet_module.controlnets) print(f"process_video: Updating control image for {controlnet_count} ControlNet(s) on frame {frame_idx}") for i in range(controlnet_count): @@ -144,36 +144,35 @@ def process_video(config_path, input_video, output_dir, engine_only=False): else: print(f"process_video: No ControlNet module found for frame {frame_idx}") output_tensor = wrapper(frame_pil) - + # Convert tensor output to OpenCV BGR format output_bgr = tensor_to_opencv(output_tensor, width, height) - + # Create side-by-side display combined = np.hstack([frame_resized, output_bgr]) - + # Add labels cv2.putText(combined, "Input", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - cv2.putText(combined, "Generated", (width + 10, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - + cv2.putText(combined, "Generated", (width + 10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + # Write frame out.write(combined) - + # Track performance frame_time = time.time() - frame_start_time frame_times.append(frame_time) - + frame_idx += 1 if frame_idx % 10 == 0: avg_fps = len(frame_times) / sum(frame_times) if frame_times else 0 print(f"process_video: Processed {frame_idx}/{frame_count} frames (Avg FPS: {avg_fps:.2f})") - + total_time = time.time() - total_start_time - + # Cleanup cap.release() out.release() - + # Calculate performance metrics if frame_times: avg_frame_time = sum(frame_times) / len(frame_times) @@ -184,7 +183,7 @@ def process_video(config_path, input_video, output_dir, engine_only=False): min_fps = 1.0 / max_frame_time else: avg_frame_time = avg_fps = min_frame_time = max_frame_time = max_fps = min_fps = 0 - + # Performance metrics metrics = { "input_video": str(input_video), @@ -199,72 +198,77 @@ def process_video(config_path, input_video, output_dir, engine_only=False): "avg_frame_time_seconds": avg_frame_time, "min_frame_time_seconds": min_frame_time, "max_frame_time_seconds": max_frame_time, - "model_id": config['model_id'], - "acceleration": config.get('acceleration', 'none'), - "frame_buffer_size": config.get('frame_buffer_size', 1), - "num_inference_steps": config.get('num_inference_steps', 50), - "guidance_scale": config.get('guidance_scale', 1.1), + "model_id": config["model_id"], + "acceleration": config.get("acceleration", "none"), + "frame_buffer_size": config.get("frame_buffer_size", 1), + "num_inference_steps": config.get("num_inference_steps", 50), + "guidance_scale": config.get("guidance_scale", 1.1), "output_type": "pt", - "controlnets": [cn['model_id'] for cn in config.get('controlnets', [])], + "controlnets": [cn["model_id"] for cn in config.get("controlnets", [])], "test_type": "controlnet_video_test", - "description": "ControlNet video processing using tensor output for performance" + "description": "ControlNet video processing using tensor output for performance", } - + # Save metrics metrics_path = output_dir / "performance_metrics.json" - with open(metrics_path, 'w') as f: + with open(metrics_path, "w") as f: json.dump(metrics, f, indent=2) - - print(f"process_video: Processing completed!") + + print("process_video: Processing completed!") print(f"process_video: Output video saved to: {output_video_path}") print(f"process_video: Performance metrics saved to: {metrics_path}") print(f"process_video: Average FPS: {avg_fps:.2f}") print(f"process_video: Total time: {total_time:.2f} seconds") - + return metrics + def main(): parser = argparse.ArgumentParser(description="ControlNet Video Test Demo") - + # Get the script directory to make paths relative to it script_dir = Path(__file__).parent default_config = script_dir.parent.parent / "configs" / "controlnet_examples" / "multi_controlnet_example.yaml" - - parser.add_argument("--config", type=str, required=True, - help="Path to ControlNet configuration file") - parser.add_argument("--input-video", type=str, required=True, - help="Path to input video file") - parser.add_argument("--output-dir", type=str, default="output", - help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.") - parser.add_argument("--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)") - + + parser.add_argument("--config", type=str, required=True, help="Path to ControlNet configuration file") + parser.add_argument("--input-video", type=str, required=True, help="Path to input video file") + parser.add_argument( + "--output-dir", + type=str, + default="output", + help="Parent directory for results (default: 'output'). Script will create a timestamped subdirectory inside this.", + ) + parser.add_argument( + "--engine-only", action="store_true", help="Only build TensorRT engines and exit (no video processing)" + ) + args = parser.parse_args() - + # Create timestamped subdirectory within the specified parent directory timestamp = time.strftime("%Y%m%d_%H%M%S") input_name = Path(args.input_video).stem config_name = Path(args.config).stem subdir_name = f"controlnet_test_{config_name}_{input_name}_{timestamp}" - + # Combine parent directory with generated subdirectory name final_output_dir = Path(args.output_dir) / subdir_name args.output_dir = str(final_output_dir) print(f"main: Using output directory: {args.output_dir}") - + # Validate input files if not Path(args.config).exists(): print(f"main: Error - Config file not found: {args.config}") return 1 - + if not Path(args.input_video).exists(): print(f"main: Error - Input video not found: {args.input_video}") return 1 - + print("ControlNet Video Test Demo") print(f"main: Config: {args.config}") print(f"main: Input video: {args.input_video}") print(f"main: Output directory: {args.output_dir}") - + try: metrics = process_video(args.config, args.input_video, args.output_dir, engine_only=args.engine_only) if args.engine_only: @@ -274,10 +278,11 @@ def main(): return 0 except Exception as e: import traceback + print(f"main: Error during processing: {e}") print(f"main: Traceback:\n{''.join(traceback.format_tb(e.__traceback__))}") return 1 if __name__ == "__main__": - exit(main()) \ No newline at end of file + exit(main()) diff --git a/examples/img2img/multi.py b/examples/img2img/multi.py index 8912d1439..af112585b 100644 --- a/examples/img2img/multi.py +++ b/examples/img2img/multi.py @@ -1,7 +1,7 @@ import glob import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional import fire @@ -10,6 +10,7 @@ from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/examples/img2img/single.py b/examples/img2img/single.py index 4be8a2185..b894bc031 100644 --- a/examples/img2img/single.py +++ b/examples/img2img/single.py @@ -1,6 +1,6 @@ import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional import fire @@ -9,6 +9,7 @@ from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/examples/optimal-performance/multi.py b/examples/optimal-performance/multi.py index 791d88b19..e154dfcc1 100644 --- a/examples/optimal-performance/multi.py +++ b/examples/optimal-performance/multi.py @@ -3,7 +3,7 @@ import threading import time import tkinter as tk -from multiprocessing import Process, Queue, get_context +from multiprocessing import Queue, get_context from typing import List, Literal import fire @@ -98,9 +98,7 @@ def image_generation_process( return -def _receive_images( - queue: Queue, fps_queue: Queue, labels: List[tk.Label], fps_label: tk.Label -) -> None: +def _receive_images(queue: Queue, fps_queue: Queue, labels: List[tk.Label], fps_label: tk.Label) -> None: """ Continuously receive images from a queue and update the labels. @@ -120,9 +118,7 @@ def _receive_images( if not queue.empty(): [ labels[0].after(0, update_image, image_data, labels) - for image_data in postprocess_image( - queue.get(block=False), output_type="pil" - ) + for image_data in postprocess_image(queue.get(block=False), output_type="pil") ] if not fps_queue.empty(): fps_label.config(text=f"FPS: {fps_queue.get(block=False):.2f}") @@ -153,9 +149,7 @@ def receive_images(queue: Queue, fps_queue: Queue) -> None: fps_label = tk.Label(root, text="FPS: 0") fps_label.grid(rows=2, columnspan=2) - thread = threading.Thread( - target=_receive_images, args=(queue, fps_queue, labels, fps_label), daemon=True - ) + thread = threading.Thread(target=_receive_images, args=(queue, fps_queue, labels, fps_label), daemon=True) thread.start() try: @@ -173,7 +167,7 @@ def main( """ Main function to start the image generation and viewer processes. """ - ctx = get_context('spawn') + ctx = get_context("spawn") queue = ctx.Queue() fps_queue = ctx.Queue() process1 = ctx.Process( @@ -188,5 +182,6 @@ def main( process1.join() process2.join() + if __name__ == "__main__": fire.Fire(main) diff --git a/examples/optimal-performance/single.py b/examples/optimal-performance/single.py index a8020bb86..ec610de72 100644 --- a/examples/optimal-performance/single.py +++ b/examples/optimal-performance/single.py @@ -1,15 +1,17 @@ import os import sys import time -from multiprocessing import Process, Queue, get_context +from multiprocessing import Queue, get_context from typing import Literal import fire + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from utils.viewer import receive_images from streamdiffusion import StreamDiffusionWrapper +from utils.viewer import receive_images + def image_generation_process( queue: Queue, @@ -63,6 +65,7 @@ def image_generation_process( print(f"fps: {fps}") return + def main( prompt: str = "cat with sunglasses and a hat, photoreal, 8K", model_id_or_path: str = "stabilityai/sd-turbo", @@ -71,7 +74,7 @@ def main( """ Main function to start the image generation and viewer processes. """ - ctx = get_context('spawn') + ctx = get_context("spawn") queue = ctx.Queue() fps_queue = ctx.Queue() process1 = ctx.Process( @@ -86,5 +89,6 @@ def main( process1.join() process2.join() + if __name__ == "__main__": fire.Fire(main) diff --git a/examples/screen/main.py b/examples/screen/main.py index 16042ed62..405ff6eec 100644 --- a/examples/screen/main.py +++ b/examples/screen/main.py @@ -1,26 +1,31 @@ import os import sys -import time import threading -from multiprocessing import Process, Queue, get_context +import time +import tkinter as tk +from multiprocessing import Queue, get_context from multiprocessing.connection import Connection -from typing import List, Literal, Dict, Optional -import torch +from typing import Dict, Literal, Optional + +import fire +import mss import PIL.Image +import torch + from streamdiffusion.image_utils import pil2tensor -import mss -import fire -import tkinter as tk + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from utils.viewer import receive_images from streamdiffusion import StreamDiffusionWrapper +from utils.viewer import receive_images + inputs = [] top = 0 left = 0 + def screen( event: threading.Event, height: int = 512, @@ -37,10 +42,12 @@ def screen( img = PIL.Image.frombytes("RGB", img.size, img.bgra, "raw", "BGRX") img.resize((height, width)) inputs.append(pil2tensor(img)) - print('exit : screen') + print("exit : screen") + + def dummy_screen( - width: int, - height: int, + width: int, + height: int, ): root = tk.Tk() root.title("Press Enter to start") @@ -48,17 +55,22 @@ def dummy_screen( root.resizable(False, False) root.attributes("-alpha", 0.8) root.configure(bg="black") + def destroy(event): root.destroy() + root.bind("", destroy) + def update_geometry(event): global top, left top = root.winfo_y() left = root.winfo_x() + root.bind("", update_geometry) root.mainloop() return {"top": top, "left": left, "width": width, "height": height} + def monitor_setting_process( width: int, height: int, @@ -67,6 +79,7 @@ def monitor_setting_process( monitor = dummy_screen(width, height) monitor_sender.send(monitor) + def image_generation_process( queue: Queue, fps_queue: Queue, @@ -88,7 +101,7 @@ def image_generation_process( enable_similar_image_filter: bool, similar_image_filter_threshold: float, similar_image_filter_max_skip_frame: float, - monitor_receiver : Connection, + monitor_receiver: Connection, ) -> None: """ Process for generating images based on a prompt using a specified model. @@ -179,7 +192,7 @@ def image_generation_process( while True: try: - if not close_queue.empty(): # closing check + if not close_queue.empty(): # closing check break if len(inputs) < frame_buffer_size: time.sleep(0.005) @@ -191,9 +204,7 @@ def image_generation_process( sampled_inputs.append(inputs[len(inputs) - index - 1]) input_batch = torch.cat(sampled_inputs) inputs.clear() - output_images = stream.stream( - input_batch.to(device=stream.device, dtype=stream.dtype) - ).cpu() + output_images = stream.stream(input_batch.to(device=stream.device, dtype=stream.dtype)).cpu() if frame_buffer_size == 1: output_images = [output_images] for output_image in output_images: @@ -205,10 +216,11 @@ def image_generation_process( break print("closing image_generation_process...") - event.set() # stop capture thread + event.set() # stop capture thread input_screen.join() print(f"fps: {fps}") + def main( model_id_or_path: str = "KBlueLeaf/kohaku-v2.1", lora_dict: Optional[Dict[str, float]] = None, @@ -231,7 +243,7 @@ def main( """ Main function to start the image generation and viewer processes. """ - ctx = get_context('spawn') + ctx = get_context("spawn") queue = ctx.Queue() fps_queue = ctx.Queue() close_queue = Queue() @@ -262,7 +274,7 @@ def main( similar_image_filter_threshold, similar_image_filter_max_skip_frame, monitor_receiver, - ), + ), ) process1.start() @@ -272,7 +284,7 @@ def main( width, height, monitor_sender, - ), + ), ) monitor_process.start() monitor_process.join() @@ -285,10 +297,10 @@ def main( print("process2 terminated.") close_queue.put(True) print("process1 terminating...") - process1.join(5) # with timeout + process1.join(5) # with timeout if process1.is_alive(): print("process1 still alive. force killing...") - process1.terminate() # force kill... + process1.terminate() # force kill... process1.join() print("process1 terminated.") diff --git a/examples/txt2img/multi.py b/examples/txt2img/multi.py index 0e50e36f8..1d3301966 100644 --- a/examples/txt2img/multi.py +++ b/examples/txt2img/multi.py @@ -1,18 +1,26 @@ import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional import fire + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) def main( - output: str = os.path.join(CURRENT_DIR, "..", "..", "images", "outputs",), + output: str = os.path.join( + CURRENT_DIR, + "..", + "..", + "images", + "outputs", + ), model_id_or_path: str = "KBlueLeaf/kohaku-v2.1", lora_dict: Optional[Dict[str, float]] = None, prompt: str = "1girl with brown dog hair, thick glasses, smiling", @@ -22,7 +30,6 @@ def main( acceleration: Literal["none", "xformers", "tensorrt"] = "xformers", seed: int = 2, ): - """ Process for generating images based on a prompt using a specified model. diff --git a/examples/txt2img/single.py b/examples/txt2img/single.py index 80f5ed238..f3c8763aa 100644 --- a/examples/txt2img/single.py +++ b/examples/txt2img/single.py @@ -1,6 +1,6 @@ import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional import fire @@ -9,6 +9,7 @@ from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -23,7 +24,6 @@ def main( use_denoising_batch: bool = False, seed: int = 2, ): - """ Process for generating images based on a prompt using a specified model. diff --git a/examples/vid2vid/main.py b/examples/vid2vid/main.py index c4860d64a..a045b29ad 100644 --- a/examples/vid2vid/main.py +++ b/examples/vid2vid/main.py @@ -1,16 +1,18 @@ import os import sys -from typing import Literal, Dict, Optional +from typing import Dict, Literal, Optional import fire import torch from torchvision.io import read_video, write_video from tqdm import tqdm + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from streamdiffusion import StreamDiffusionWrapper + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -26,7 +28,6 @@ def main( enable_similar_image_filter: bool = True, seed: int = 2, ): - """ Process for generating images based on a prompt using a specified model. diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 000000000..90f6c5a19 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,4 @@ +# Development + quality-harness dependencies (not needed for production inference) +lpips>=0.1.4 +scikit-image>=0.21 +PyYAML>=6.0 diff --git a/src/streamdiffusion/__init__.py b/src/streamdiffusion/__init__.py index 88bfdb895..17b226ce3 100644 --- a/src/streamdiffusion/__init__.py +++ b/src/streamdiffusion/__init__.py @@ -1,3 +1,4 @@ +from . import _compat # noqa: F401 — applies kvo_cache patch before any diffusers import from .config import create_wrapper_from_config, load_config, save_config from .pipeline import StreamDiffusion from .preprocessing.processors import list_preprocessors diff --git a/src/streamdiffusion/_compat/__init__.py b/src/streamdiffusion/_compat/__init__.py new file mode 100644 index 000000000..23f44fa02 --- /dev/null +++ b/src/streamdiffusion/_compat/__init__.py @@ -0,0 +1,4 @@ +from .diffusers_kvo_patch import apply as _apply_kvo_patch + + +_apply_kvo_patch() diff --git a/src/streamdiffusion/_compat/diffusers_kvo_patch.py b/src/streamdiffusion/_compat/diffusers_kvo_patch.py new file mode 100644 index 000000000..a45e82b26 --- /dev/null +++ b/src/streamdiffusion/_compat/diffusers_kvo_patch.py @@ -0,0 +1,792 @@ +"""Re-applies varshith15/diffusers@3e3b72f kvo_cache patch onto upstream diffusers. + +Source fork: https://github.com/varshith15/diffusers @ 3e3b72f557e91546894340edabc845e894f00922 +Target: diffusers >= 0.38.0 (upstream HuggingFace) + +The patch threads an optional KV-cache through the UNet2DConditionModel forward pass so +that StreamDiffusion's TRT cached-attention pipeline (unet_unified_export.py) works with +vanilla upstream diffusers. When kvo_cache=None (the default) behaviour is identical to +upstream diffusers — no performance or correctness impact for non-cached paths. + +Called automatically at ``import streamdiffusion`` via _compat/__init__.py. +""" + +from __future__ import annotations + +import inspect +import logging + + +logger = logging.getLogger(__name__) + +_PATCHED = False + +# Sentinel used by _patch_down_block so we can distinguish "caller passed kvo_cache=None +# explicitly" (patched UNet → must return 3-tuple) from "caller never passed kvo_cache" +# (ControlNet / any other diffusers module that uses CrossAttnDownBlock2D → must return the +# original 2-tuple to avoid "too many values to unpack (expected 2)"). +_KVO_NOT_GIVEN = object() + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + + +def apply() -> None: + """Apply kvo_cache patch. Idempotent — safe to call multiple times.""" + global _PATCHED + if _PATCHED: + return + if _is_patched(): + _PATCHED = True + return + + _patch_attn_processor() + _patch_attention_forward() + _patch_basic_transformer_block() + _patch_transformer2d() + _patch_mid_block() + _patch_down_block() + _patch_up_block() + _patch_unet2d() + + _PATCHED = True + logger.debug("diffusers_kvo_patch: kvo_cache patch applied") + + +def _is_patched() -> bool: + from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + + return "kvo_cache" in inspect.signature(UNet2DConditionModel.forward).parameters + + +# --------------------------------------------------------------------------- +# Individual patches +# --------------------------------------------------------------------------- + + +def _patch_attn_processor() -> None: + """AttnProcessor2_0.__call__: accept kvo_cache kwarg, return (hidden_states, kvo_cache).""" + from diffusers.models.attention_processor import AttnProcessor2_0 + + _orig = AttnProcessor2_0.__call__ + + def _call( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + kvo_cache=None, + *args, + **kwargs, + ): + result = _orig( + self, + attn, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + temb=temb, + *args, + **kwargs, + ) + return result, kvo_cache + + AttnProcessor2_0.__call__ = _call + + +def _patch_attention_forward() -> None: + """Attention.forward: accept kvo_cache, route to processor only for self-attn.""" + from diffusers.models.attention_processor import Attention + from diffusers.utils import logging as _dlog + + _dlogger = _dlog.get_logger("diffusers.models.attention_processor") + + def _forward( + self, hidden_states, encoder_hidden_states=None, attention_mask=None, kvo_cache=None, **cross_attention_kwargs + ): + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [ + k for k in cross_attention_kwargs if k not in attn_parameters and k not in quiet_attn_parameters + ] + if unused_kwargs: + _dlogger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by " + f"{self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + if encoder_hidden_states is None: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + kvo_cache=kvo_cache, + **cross_attention_kwargs, + ) + else: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + Attention.forward = _forward + + +def _patch_basic_transformer_block() -> None: + """BasicTransformerBlock.forward: thread kvo_cache through attn1, handle attn2 tuple return.""" + from diffusers.models.attention import BasicTransformerBlock + + def _forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + added_cond_kwargs=None, + kvo_cache=None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + batch_size = hidden_states.shape[0] + + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output, kvo_cache_out = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + kvo_cache=kvo_cache, + **cross_attention_kwargs, + ) + + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + if isinstance(attn_output, tuple): + attn_output = attn_output[0] + hidden_states = attn_output + hidden_states + + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + + if self.norm_type == "ada_norm_zero": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + from diffusers.models.attention import _chunked_feed_forward + + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states, kvo_cache_out + + BasicTransformerBlock.forward = _forward + + +def _patch_transformer2d() -> None: + """Transformer2DModel.forward: thread kvo_cache through transformer_blocks.""" + import torch + from diffusers.models.modeling_outputs import Transformer2DModelOutput + from diffusers.models.transformers.transformer_2d import Transformer2DModel + + def _forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + added_cond_kwargs=None, + class_labels=None, + cross_attention_kwargs=None, + attention_mask=None, + encoder_attention_mask=None, + kvo_cache=None, + return_dict=True, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is not None and attention_mask.ndim == 2: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if self.is_input_continuous: + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs( + hidden_states, encoder_hidden_states, timestep, added_cond_kwargs + ) + + kvo_cache_out = [] + for idx, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + ) + else: + block_cache_in = kvo_cache[idx] if kvo_cache else None + hidden_states, block_cache_out = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + kvo_cache=block_cache_in, + ) + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + + if self.is_input_continuous: + output = self._get_output_for_continuous_inputs( + hidden_states=hidden_states, + residual=residual, + batch_size=batch_size, + height=height, + width=width, + inner_dim=inner_dim, + ) + elif self.is_input_vectorized: + output = self._get_output_for_vectorized_inputs(hidden_states) + elif self.is_input_patches: + output = self._get_output_for_patched_inputs( + hidden_states=hidden_states, + timestep=timestep, + class_labels=class_labels, + embedded_timestep=embedded_timestep, + height=height, + width=width, + ) + + if not return_dict: + return (output, kvo_cache_out) + + return Transformer2DModelOutput(sample=output) + + Transformer2DModel.forward = _forward + + +def _patch_mid_block() -> None: + """UNetMidBlock2DCrossAttn.forward: thread kvo_cache through attention loop. + + Same sentinel trick as _patch_down_block: callers that do not pass kvo_cache explicitly + (e.g. ControlNet) get the original single-return-value (hidden_states); the patched UNet + path — which always passes kvo_cache= explicitly — gets the 2-tuple (hidden, kvo_cache_out). + """ + import torch + from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn + + def _forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + encoder_attention_mask=None, + kvo_cache=_KVO_NOT_GIVEN, + ): + _caller_passed_kvo = kvo_cache is not _KVO_NOT_GIVEN + if not _caller_passed_kvo: + kvo_cache = None + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + kvo_cache_out = [] + for idx, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + else: + block_cache_in = kvo_cache[idx] if kvo_cache else None + hidden_states, block_cache_out = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + return_dict=False, + ) + hidden_states = resnet(hidden_states, temb) + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + + if _caller_passed_kvo: + return hidden_states, kvo_cache_out + return hidden_states + + UNetMidBlock2DCrossAttn.forward = _forward + + +def _patch_down_block() -> None: + """CrossAttnDownBlock2D.forward: thread kvo_cache, return (hidden, output_states, kvo_cache_out). + + Uses _KVO_NOT_GIVEN sentinel so the patched forward stays backward-compatible with callers + (e.g. ControlNet) that never pass kvo_cache: those callers get the original 2-tuple return, + while _patch_unet2d (which always passes kvo_cache= explicitly) gets the 3-tuple. + """ + import torch + from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D + + def _forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + encoder_attention_mask=None, + additional_residuals=None, + kvo_cache=_KVO_NOT_GIVEN, + ): + _caller_passed_kvo = kvo_cache is not _KVO_NOT_GIVEN + if not _caller_passed_kvo: + kvo_cache = None + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + blocks = list(zip(self.resnets, self.attentions)) + kvo_cache_out = [] + + for i, (resnet, attn) in enumerate(blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + block_cache_in = kvo_cache[i] if kvo_cache else None + hidden_states, block_cache_out = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + return_dict=False, + ) + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + output_states = output_states + (hidden_states,) + + if _caller_passed_kvo: + return hidden_states, output_states, kvo_cache_out + return hidden_states, output_states + + CrossAttnDownBlock2D.forward = _forward + + +def _patch_up_block() -> None: + """CrossAttnUpBlock2D.forward: thread kvo_cache, return (hidden, kvo_cache_out).""" + from diffusers.models.unets.unet_2d_blocks import CrossAttnUpBlock2D + + try: + from diffusers.models.unets.unet_2d_blocks import apply_freeu + except ImportError: + apply_freeu = None + import torch + + def _forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + encoder_attention_mask=None, + kvo_cache=None, + ): + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + apply_freeu is not None + and getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + kvo_cache_out = [] + for idx, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + block_cache_in = kvo_cache[idx] if kvo_cache else None + hidden_states, block_cache_out = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + return_dict=False, + ) + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states, kvo_cache_out + + CrossAttnUpBlock2D.forward = _forward + + +def _patch_unet2d() -> None: + """UNet2DConditionModel.forward: add kvo_cache param, wire through all blocks.""" + import torch + from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput + from diffusers.utils import deprecate + + def _forward( + self, + sample, + timestep, + encoder_hidden_states, + class_labels=None, + timestep_cond=None, + attention_mask=None, + cross_attention_kwargs=None, + added_cond_kwargs=None, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + down_intrablock_additional_residuals=None, + encoder_attention_mask=None, + kvo_cache=None, + return_dict=True, + ): + default_overall_up_factor = 2**self.num_upsamplers + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + forward_upsample_size = True + break + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + sample = self.conv_in(sample) + + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = down_intrablock_additional_residuals is not None + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated " + "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used " + "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + cache_idx = 0 + kvo_cache_out = [] + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + block_cache_in = kvo_cache[cache_idx] if kvo_cache else None + sample, res_samples, block_cache_out = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + **additional_residuals, + ) + cache_idx += 1 + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + down_block_res_samples = new_down_block_res_samples + + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + block_cache_in = kvo_cache[cache_idx] if kvo_cache else None + sample, block_cache_out = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + ) + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + cache_idx += 1 + else: + sample = self.mid_block(sample, emb) + + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + block_cache_in = kvo_cache[cache_idx] if kvo_cache else None + sample, block_cache_out = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + kvo_cache=block_cache_in, + ) + cache_idx += 1 + if block_cache_out is not None: + kvo_cache_out.append(block_cache_out) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample, kvo_cache_out) + + return UNet2DConditionOutput(sample=sample) + + UNet2DConditionModel.forward = _forward diff --git a/src/streamdiffusion/_hf_tracing_patches.py b/src/streamdiffusion/_hf_tracing_patches.py index d2d611f30..422b6bf65 100644 --- a/src/streamdiffusion/_hf_tracing_patches.py +++ b/src/streamdiffusion/_hf_tracing_patches.py @@ -5,7 +5,10 @@ import torch + _ALREADY = False # idempotence guard + + # --------------------------------------------------------------------------- # # 1. UNet2DConditionModel: guard in_channels % up_factor # --------------------------------------------------------------------------- # @@ -16,11 +19,11 @@ def _patch_unet(): def patched(self, sample, *args, **kwargs): if torch.jit.is_tracing(): - dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels)) + dim = torch.as_tensor(getattr(self.config, "in_channels", self.in_channels)) up_factor = torch.as_tensor(getattr(self.config, "default_overall_up_factor", 1)) torch._assert( torch.remainder(dim, up_factor) == 0, - f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}" + f"in_channels={dim} not divisible by default_overall_up_factor={up_factor}", ) return orig_fwd(self, sample, *args, **kwargs) @@ -32,12 +35,13 @@ def patched(self, sample, *args, **kwargs): # --------------------------------------------------------------------------- # def _patch_downsample(): import diffusers.models.downsampling as d + orig_fwd = d.Downsample2D.forward def patched(self, hidden_states, *args, **kwargs): torch._assert( hidden_states.shape[1] == self.channels, - f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}" + f"[Downsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}", ) return orig_fwd(self, hidden_states, *args, **kwargs) @@ -49,12 +53,13 @@ def patched(self, hidden_states, *args, **kwargs): # --------------------------------------------------------------------------- # def _patch_upsample(): import diffusers.models.upsampling as u + orig_fwd = u.Upsample2D.forward def patched(self, hidden_states, *args, **kwargs): torch._assert( hidden_states.shape[1] == self.channels, - f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}" + f"[Upsample2D] channels mismatch: {hidden_states.shape[1]} vs {self.channels}", ) return orig_fwd(self, hidden_states, *args, **kwargs) diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py index dd3a97e0b..f79fff72e 100644 --- a/src/streamdiffusion/acceleration/tensorrt/builder.py +++ b/src/streamdiffusion/acceleration/tensorrt/builder.py @@ -286,6 +286,9 @@ def _quant_fn(): # --- FP8 Q/DQ layer count (sanity gate: < 100 means quantization is inactive) --- if fp8 and os.path.exists(engine_path): try: + import json as _json + import re as _re + import tensorrt as trt _rt = trt.Runtime(trt.Logger(trt.Logger.WARNING)) @@ -300,6 +303,28 @@ def _quant_fn(): _build_logger.warning( f"[BUILD] Low Q/DQ count ({_qdq} < 500) — FP8 quantization likely inactive or incomplete" ) + + # Fused-MHA check: count attention layers TRT fused into a single kernel. + # Pattern is empirical — FLUX uses "_gemm_mha_v2"; SDXL on Ada may differ. + # First build logs sample names so the regex can be confirmed or tightened. + _MHA_RE = _re.compile(r"mha|fmha|MultiHead|FlashAttn", _re.IGNORECASE) + try: + _layers = _json.loads(_info).get("Layers", []) + except Exception: + _layers = [] + _total = len(_layers) + _mha_names = [_l.get("Name", "") for _l in _layers if _MHA_RE.search(_l.get("Name", ""))] + _mha_count = len(_mha_names) + stats["mha_fused_kernels"] = _mha_count + stats["total_engine_layers"] = _total + _build_logger.info(f"[BUILD] FP8 engine fused MHA layers: {_mha_count} / {_total} total") + if _mha_count == 0 and _total > 0: + _build_logger.warning( + "[BUILD] No fused MHA layers detected — attention may be running decomposed " + "(slower). Sample layer names (first 5): " + str([_l.get("Name", "") for _l in _layers[:5]]) + ) + else: + _build_logger.info(f"[BUILD] Sample fused-MHA layer names: {_mha_names[:3]}") except Exception as _e: _build_logger.warning(f"[BUILD] FP8 inspector check skipped: {_e}") diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py index 13532a0c4..be5408075 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/__init__.py @@ -1,15 +1,16 @@ from .controlnet_export import SDXLControlNetExportWrapper from .unet_controlnet_export import ControlNetUNetExportWrapper, MultiControlNetUNetExportWrapper from .unet_ipadapter_export import IPAdapterUNetExportWrapper -from .unet_sdxl_export import SDXLExportWrapper, SDXLConditioningHandler +from .unet_sdxl_export import SDXLConditioningHandler, SDXLExportWrapper from .unet_unified_export import UnifiedExportWrapper + __all__ = [ "SDXLControlNetExportWrapper", "ControlNetUNetExportWrapper", - "MultiControlNetUNetExportWrapper", + "MultiControlNetUNetExportWrapper", "IPAdapterUNetExportWrapper", "SDXLExportWrapper", "SDXLConditioningHandler", "UnifiedExportWrapper", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py index 946917b1c..43809a1a7 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/controlnet_export.py @@ -1,23 +1,24 @@ import torch + class SDXLControlNetExportWrapper(torch.nn.Module): """Wrapper for SDXL ControlNet models to handle added_cond_kwargs properly during ONNX export""" - + def __init__(self, controlnet_model): super().__init__() self.controlnet = controlnet_model - + # Get device and dtype from model - if hasattr(controlnet_model, 'device'): + if hasattr(controlnet_model, "device"): self.device = controlnet_model.device else: # Try to infer from first parameter try: self.device = next(controlnet_model.parameters()).device except: - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - if hasattr(controlnet_model, 'dtype'): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if hasattr(controlnet_model, "dtype"): self.dtype = controlnet_model.dtype else: # Try to infer from first parameter @@ -25,15 +26,14 @@ def __init__(self, controlnet_model): self.dtype = next(controlnet_model.parameters()).dtype except: self.dtype = torch.float16 - - def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids): + + def forward( + self, sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale, text_embeds, time_ids + ): """Forward pass that handles SDXL ControlNet requirements and produces 9 down blocks""" # Use the provided SDXL conditioning - added_cond_kwargs = { - 'text_embeds': text_embeds, - 'time_ids': time_ids - } - + added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids} + # Call the ControlNet with proper arguments including conditioning_scale result = self.controlnet( sample=sample, @@ -42,40 +42,49 @@ def forward(self, sample, timestep, encoder_hidden_states, controlnet_cond, cond controlnet_cond=controlnet_cond, conditioning_scale=conditioning_scale, added_cond_kwargs=added_cond_kwargs, - return_dict=False + return_dict=False, ) - + # Extract down blocks and mid block from result if isinstance(result, tuple) and len(result) >= 2: down_block_res_samples, mid_block_res_sample = result[0], result[1] - elif hasattr(result, 'down_block_res_samples') and hasattr(result, 'mid_block_res_sample'): + elif hasattr(result, "down_block_res_samples") and hasattr(result, "mid_block_res_sample"): down_block_res_samples = result.down_block_res_samples mid_block_res_sample = result.mid_block_res_sample else: raise ValueError(f"Unexpected ControlNet output format: {type(result)}") - + # SDXL ControlNet should have exactly 9 down blocks if len(down_block_res_samples) != 9: raise ValueError(f"SDXL ControlNet expected 9 down blocks, got {len(down_block_res_samples)}") - + # Return 9 down blocks + 1 mid block with explicit names matching UNet pattern # Following the pattern from controlnet_wrapper.py and models.py: # down_block_00: Initial sample (320 channels) - # down_block_01-03: Block 0 residuals (320 channels) + # down_block_01-03: Block 0 residuals (320 channels) # down_block_04-06: Block 1 residuals (640 channels) # down_block_07-08: Block 2 residuals (1280 channels) down_block_00 = down_block_res_samples[0] # Initial: 320 channels, 88x88 down_block_01 = down_block_res_samples[1] # Block0: 320 channels, 88x88 - down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88 + down_block_02 = down_block_res_samples[2] # Block0: 320 channels, 88x88 down_block_03 = down_block_res_samples[3] # Block0: 320 channels, 44x44 down_block_04 = down_block_res_samples[4] # Block1: 640 channels, 44x44 down_block_05 = down_block_res_samples[5] # Block1: 640 channels, 44x44 down_block_06 = down_block_res_samples[6] # Block1: 640 channels, 22x22 down_block_07 = down_block_res_samples[7] # Block2: 1280 channels, 22x22 down_block_08 = down_block_res_samples[8] # Block2: 1280 channels, 22x22 - mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22 - + mid_block = mid_block_res_sample # Mid: 1280 channels, 22x22 + # Return as individual tensors to preserve names in ONNX - return (down_block_00, down_block_01, down_block_02, down_block_03, - down_block_04, down_block_05, down_block_06, down_block_07, - down_block_08, mid_block) \ No newline at end of file + return ( + down_block_00, + down_block_01, + down_block_02, + down_block_03, + down_block_04, + down_block_05, + down_block_06, + down_block_07, + down_block_08, + mid_block, + ) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py index e7a834bc1..fb6b5733c 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_controlnet_export.py @@ -1,30 +1,32 @@ """ControlNet-aware UNet wrapper for ONNX export""" +from typing import Dict, List, Optional + import torch -from typing import List, Optional, Dict, Any from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + from ..models.utils import convert_list_to_structure class ControlNetUNetExportWrapper(torch.nn.Module): """Wrapper that combines UNet with ControlNet inputs for ONNX export""" - + def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], kvo_cache_structure: List[int]): super().__init__() self.unet = unet self.control_input_names = control_input_names self.kvo_cache_structure = kvo_cache_structure - + self.control_names = [] for name in control_input_names: if "input_control" in name or "output_control" in name or "middle_control" in name: self.control_names.append(name) - + self.num_controlnet_args = len(self.control_names) - + # Detect if this is SDXL based on UNet config self.is_sdxl = self._detect_sdxl_architecture(unet) - + # SDXL ControlNet has different structure than SD1.5 if self.is_sdxl: # SDXL has 1 initial + 3 down blocks producing 9 control tensors total @@ -32,30 +34,30 @@ def __init__(self, unet: UNet2DConditionModel, control_input_names: List[str], k else: # SD1.5 has 12 down blocks self.expected_down_blocks = 12 - + def _detect_sdxl_architecture(self, unet): """Detect if UNet is SDXL based on architecture""" - if hasattr(unet, 'config'): + if hasattr(unet, "config"): config = unet.config # SDXL has 3 down blocks vs SD1.5's 4 down blocks - block_out_channels = getattr(config, 'block_out_channels', None) + block_out_channels = getattr(config, "block_out_channels", None) if block_out_channels and len(block_out_channels) == 3: return True return False - + def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): """Forward pass that organizes control inputs and calls UNet""" - - control_args = args[:self.num_controlnet_args] - kvo_cache = args[self.num_controlnet_args:] - + + control_args = args[: self.num_controlnet_args] + kvo_cache = args[self.num_controlnet_args :] + down_block_controls = [] mid_block_control = None - + if control_args: all_control_tensors = [] middle_tensor = None - + for tensor, name in zip(control_args, self.control_names): if "input_control" in name: if "middle" in name: @@ -64,7 +66,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): all_control_tensors.append(tensor) elif "middle_control" in name: middle_tensor = tensor - + if len(all_control_tensors) == self.expected_down_blocks: down_block_controls = all_control_tensors mid_block_control = middle_tensor @@ -73,7 +75,7 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): if len(all_control_tensors) > 0: if len(all_control_tensors) > self.expected_down_blocks: # Too many tensors - take the first expected_down_blocks - down_block_controls = all_control_tensors[:self.expected_down_blocks] + down_block_controls = all_control_tensors[: self.expected_down_blocks] else: # Too few tensors - use what we have down_block_controls = all_control_tensors @@ -82,30 +84,29 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): # No control tensors available - skip ControlNet down_block_controls = None mid_block_control = None - + formatted_kvo_cache = [] if len(kvo_cache) > 0: formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'kvo_cache': formatted_kvo_cache, - 'return_dict': False, + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "kvo_cache": formatted_kvo_cache, + "return_dict": False, } - + # Pass through all additional kwargs (for SDXL models) unet_kwargs.update(kwargs) # Auto-generate SDXL conditioning if missing and UNet requires it - if 'added_cond_kwargs' not in unet_kwargs or unet_kwargs.get('added_cond_kwargs') is None: - if (hasattr(self.unet, 'config') and - getattr(self.unet.config, 'addition_embed_type', None) == 'text_time'): + if "added_cond_kwargs" not in unet_kwargs or unet_kwargs.get("added_cond_kwargs") is None: + if hasattr(self.unet, "config") and getattr(self.unet.config, "addition_embed_type", None) == "text_time": batch_size = sample.shape[0] - unet_kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), + unet_kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), } if down_block_controls: @@ -115,30 +116,30 @@ def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs): # Control tensors are now generated in the correct order to match UNet's down_block_res_samples # For SDXL: [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22] # This directly aligns with UNet's: [initial_sample] + [block0_residuals] + [block1_residuals] + [block2_residuals] - unet_kwargs['down_block_additional_residuals'] = adapted_controls - + unet_kwargs["down_block_additional_residuals"] = adapted_controls + if mid_block_control is not None: # Adapt middle control tensor shape if needed adapted_mid_control = self._adapt_middle_control_tensor(mid_block_control, sample) - unet_kwargs['mid_block_additional_residual'] = adapted_mid_control - + unet_kwargs["mid_block_additional_residual"] = adapted_mid_control + try: res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res else: return res[0] - except Exception as e: + except Exception: raise - + def _adapt_control_tensors(self, control_tensors, sample): """Adapt control tensor shapes to match UNet expectations""" if not control_tensors: return control_tensors - + adapted_tensors = [] sample_height, sample_width = sample.shape[-2:] - + # Updated factors to match the corrected control tensor generation # SDXL: 9 tensors [88x88, 88x88, 88x88, 44x44, 44x44, 44x44, 22x22, 22x22, 22x22] # Factors: [1, 1, 1, 2, 2, 2, 4, 4, 4] to match UNet down_block_res_samples structure @@ -146,30 +147,31 @@ def _adapt_control_tensors(self, control_tensors, sample): expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4] # 9 tensors for SDXL else: expected_downsample_factors = [1, 1, 1, 2, 2, 2, 4, 4, 4, 8, 8, 8] # 12 tensors for SD1.5 - + for i, control_tensor in enumerate(control_tensors): if control_tensor is None: adapted_tensors.append(control_tensor) continue - + # Check if tensor needs spatial adaptation if len(control_tensor.shape) >= 4: control_height, control_width = control_tensor.shape[-2:] - + # Use the correct downsampling factor for this tensor index if i < len(expected_downsample_factors): downsample_factor = expected_downsample_factors[i] expected_height = sample_height // downsample_factor expected_width = sample_width // downsample_factor - + if control_height != expected_height or control_width != expected_width: # Use interpolation to adapt size import torch.nn.functional as F + adapted_tensor = F.interpolate( - control_tensor, + control_tensor, size=(expected_height, expected_width), - mode='bilinear', - align_corners=False + mode="bilinear", + align_corners=False, ) adapted_tensors.append(adapted_tensor) else: @@ -179,94 +181,94 @@ def _adapt_control_tensors(self, control_tensors, sample): adapted_tensors.append(control_tensor) else: adapted_tensors.append(control_tensor) - + return adapted_tensors - + def _adapt_middle_control_tensor(self, mid_control, sample): """Adapt middle control tensor shape to match UNet expectations""" if mid_control is None: return mid_control - + # Middle control is typically at the bottleneck, so heavily downsampled if len(mid_control.shape) >= 4 and len(sample.shape) >= 4: sample_height, sample_width = sample.shape[-2:] control_height, control_width = mid_control.shape[-2:] - + # For SDXL: middle block is at 4x downsampling (22x22 from 88x88) # For SD1.5: middle block is at 8x downsampling expected_factor = 4 if self.is_sdxl else 8 expected_height = sample_height // expected_factor expected_width = sample_width // expected_factor - + if control_height != expected_height or control_width != expected_width: import torch.nn.functional as F + adapted_tensor = F.interpolate( - mid_control, - size=(expected_height, expected_width), - mode='bilinear', - align_corners=False + mid_control, size=(expected_height, expected_width), mode="bilinear", align_corners=False ) return adapted_tensor - + return mid_control class MultiControlNetUNetExportWrapper(torch.nn.Module): """Advanced wrapper for multiple ControlNets with different scales""" - - def __init__(self, - unet: UNet2DConditionModel, - control_input_names: List[str], - kvo_cache_structure: List[int], - num_controlnets: int = 1, - conditioning_scales: Optional[List[float]] = None): + + def __init__( + self, + unet: UNet2DConditionModel, + control_input_names: List[str], + kvo_cache_structure: List[int], + num_controlnets: int = 1, + conditioning_scales: Optional[List[float]] = None, + ): super().__init__() self.unet = unet self.control_input_names = control_input_names self.num_controlnets = num_controlnets self.conditioning_scales = conditioning_scales or [1.0] * num_controlnets self.kvo_cache_structure = kvo_cache_structure - + self.control_names = [] for name in control_input_names: if "input_control" in name or "output_control" in name or "middle_control" in name: self.control_names.append(name) - + self.num_controlnet_args = len(self.control_names) self.controlnet_indices = [] controls_per_net = self.num_controlnet_args // num_controlnets - + for cn_idx in range(num_controlnets): start_idx = cn_idx * controls_per_net end_idx = start_idx + controls_per_net self.controlnet_indices.append(list(range(start_idx, end_idx))) - + def forward(self, sample, timestep, encoder_hidden_states, *args): """Forward pass for multiple ControlNets""" - control_args = args[:self.num_controlnet_args] - kvo_cache = args[self.num_controlnet_args:] + control_args = args[: self.num_controlnet_args] + kvo_cache = args[self.num_controlnet_args :] combined_down_controls = None combined_mid_control = None - + for cn_idx, indices in enumerate(self.controlnet_indices): scale = self.conditioning_scales[cn_idx] if scale == 0: continue - + cn_controls = [control_args[i] for i in indices if i < len(control_args)] - + if not cn_controls: continue - + num_down = len(cn_controls) - 1 down_controls = cn_controls[:num_down] mid_control = cn_controls[num_down] if num_down < len(cn_controls) else None - + scaled_down = [ctrl * scale for ctrl in down_controls] scaled_mid = mid_control * scale if mid_control is not None else None - + if combined_down_controls is None: combined_down_controls = scaled_down combined_mid_control = scaled_mid @@ -275,24 +277,24 @@ def forward(self, sample, timestep, encoder_hidden_states, *args): combined_down_controls[i] += scaled_down[i] if scaled_mid is not None and combined_mid_control is not None: combined_mid_control += scaled_mid - + formatted_kvo_cache = [] if len(kvo_cache) > 0: formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'kvo_cache': formatted_kvo_cache, - 'return_dict': False, + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "kvo_cache": formatted_kvo_cache, + "return_dict": False, } - + if combined_down_controls: - unet_kwargs['down_block_additional_residuals'] = list(reversed(combined_down_controls)) + unet_kwargs["down_block_additional_residuals"] = list(reversed(combined_down_controls)) if combined_mid_control is not None: - unet_kwargs['mid_block_additional_residual'] = combined_mid_control - + unet_kwargs["mid_block_additional_residual"] = combined_mid_control + res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res @@ -301,11 +303,13 @@ def forward(self, sample, timestep, encoder_hidden_states, *args): return res -def create_controlnet_wrapper(unet: UNet2DConditionModel, - control_input_names: List[str], - kvo_cache_structure: List[int], - num_controlnets: int = 1, - conditioning_scales: Optional[List[float]] = None) -> torch.nn.Module: +def create_controlnet_wrapper( + unet: UNet2DConditionModel, + control_input_names: List[str], + kvo_cache_structure: List[int], + num_controlnets: int = 1, + conditioning_scales: Optional[List[float]] = None, +) -> torch.nn.Module: """Factory function to create appropriate ControlNet wrapper""" if num_controlnets == 1: return ControlNetUNetExportWrapper(unet, control_input_names, kvo_cache_structure) @@ -315,17 +319,18 @@ def create_controlnet_wrapper(unet: UNet2DConditionModel, ) -def organize_control_tensors(control_tensors: List[torch.Tensor], - control_input_names: List[str]) -> Dict[str, List[torch.Tensor]]: +def organize_control_tensors( + control_tensors: List[torch.Tensor], control_input_names: List[str] +) -> Dict[str, List[torch.Tensor]]: """Organize control tensors by type (input, output, middle)""" - organized = {'input': [], 'output': [], 'middle': []} - + organized = {"input": [], "output": [], "middle": []} + for tensor, name in zip(control_tensors, control_input_names): if "input_control" in name: - organized['input'].append(tensor) + organized["input"].append(tensor) elif "output_control" in name: - organized['output'].append(tensor) + organized["output"].append(tensor) elif "middle_control" in name: - organized['middle'].append(tensor) - - return organized \ No newline at end of file + organized["middle"].append(tensor) + + return organized diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py index f7eb18615..93310ad87 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_ipadapter_export.py @@ -1,30 +1,37 @@ +from typing import List + import torch from diffusers import UNet2DConditionModel -from typing import Optional, Dict, Any, List - -from ....model_detection import detect_model, detect_model_from_diffusers_unet from diffusers_ipadapter.ip_adapter.attention_processor import TRTIPAttnProcessor, TRTIPAttnProcessor2_0 +from ....model_detection import detect_model_from_diffusers_unet + class IPAdapterUNetExportWrapper(torch.nn.Module): """ Wrapper that bakes IPAdapter attention processors into the UNet for ONNX export. - + This approach installs IPAdapter attention processors before ONNX export, allowing the specialized attention logic to be compiled into TensorRT. The UNet expects concatenated embeddings (text + image) as encoder_hidden_states. """ - - def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tokens: int = 4, install_processors: bool = True): + + def __init__( + self, + unet: UNet2DConditionModel, + cross_attention_dim: int, + num_tokens: int = 4, + install_processors: bool = True, + ): super().__init__() self.unet = unet self.num_image_tokens = num_tokens # 4 for standard, 16 for plus self.cross_attention_dim = cross_attention_dim # 768 for SD1.5, 2048 for SDXL self.install_processors = install_processors - + # Convert to float32 BEFORE installing processors (to avoid resetting them) self.unet = self.unet.to(dtype=torch.float32) - + # Track installed TRT processors self._ip_trt_processors: List[torch.nn.Module] = [] self.num_ip_layers: int = 0 @@ -36,8 +43,10 @@ def __init__(self, unet: UNet2DConditionModel, cross_attention_dim: int, num_tok # Install IPAdapter processors AFTER dtype conversion self._install_ipadapter_processors() else: - print("IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!") - + print( + "IPAdapterUNetExportWrapper: WARNING - UNet will not have IPAdapter functionality without processors!" + ) + def _has_ipadapter_processors(self) -> bool: """Check if the UNet already has IPAdapter processors installed""" try: @@ -45,44 +54,48 @@ def _has_ipadapter_processors(self) -> bool: for name, processor in processors.items(): # Check for IPAdapter processor class names processor_class = processor.__class__.__name__ - if 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: + if "IPAttn" in processor_class or "IPAttnProcessor" in processor_class: return True return False except Exception as e: print(f"IPAdapterUNetExportWrapper: Error checking existing processors: {e}") return False - + def _ensure_processor_dtype_consistency(self): """Ensure existing IPAdapter processors have correct dtype for ONNX export""" if hasattr(torch.nn.functional, "scaled_dot_product_attention"): from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 else: from diffusers.models.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor try: processors = self.unet.attn_processors updated_processors = {} self._ip_trt_processors = [] ip_layer_index = 0 - + for name, processor in processors.items(): processor_class = processor.__class__.__name__ - if 'TRTIPAttn' in processor_class: + if "TRTIPAttn" in processor_class: # Already TRT processors: ensure dtype and record proc = processor.to(dtype=torch.float32) proc._scale_index = ip_layer_index self._ip_trt_processors.append(proc) ip_layer_index += 1 updated_processors[name] = proc - elif 'IPAttn' in processor_class or 'IPAttnProcessor' in processor_class: + elif "IPAttn" in processor_class or "IPAttnProcessor" in processor_class: # Replace standard processors with TRT variants, preserving weights where applicable - hidden_size = getattr(processor, 'hidden_size', None) - cross_attention_dim = getattr(processor, 'cross_attention_dim', None) - num_tokens = getattr(processor, 'num_tokens', self.num_image_tokens) - proc = IPProcClass(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) + hidden_size = getattr(processor, "hidden_size", None) + cross_attention_dim = getattr(processor, "cross_attention_dim", None) + num_tokens = getattr(processor, "num_tokens", self.num_image_tokens) + proc = IPProcClass( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens + ) # Copy IP projection weights if present - if hasattr(processor, 'to_k_ip') and hasattr(processor, 'to_v_ip') and hasattr(proc, 'to_k_ip'): + if hasattr(processor, "to_k_ip") and hasattr(processor, "to_v_ip") and hasattr(proc, "to_k_ip"): with torch.no_grad(): proc.to_k_ip.weight.copy_(processor.to_k_ip.weight.to(dtype=torch.float32)) proc.to_v_ip.weight.copy_(processor.to_v_ip.weight.to(dtype=torch.float32)) @@ -93,16 +106,17 @@ def _ensure_processor_dtype_consistency(self): updated_processors[name] = proc else: updated_processors[name] = AttnProcessor() - + # Update all processors to ensure consistency self.unet.set_attn_processor(updated_processors) self.num_ip_layers = len(self._ip_trt_processors) - + except Exception as e: print(f"IPAdapterUNetExportWrapper: Error updating processor dtypes: {e}") import traceback + traceback.print_exc() - + def _install_ipadapter_processors(self): """ Install IPAdapter attention processors that will be baked into ONNX. @@ -112,19 +126,23 @@ def _install_ipadapter_processors(self): try: if hasattr(torch.nn.functional, "scaled_dot_product_attention"): from diffusers.models.attention_processor import AttnProcessor2_0 as AttnProcessor + IPProcClass = TRTIPAttnProcessor2_0 else: from diffusers.models.attention_processor import AttnProcessor + IPProcClass = TRTIPAttnProcessor - + # Install attention processors with proper configuration processor_names = list(self.unet.attn_processors.keys()) - + attn_procs = {} ip_layer_index = 0 for name in processor_names: - cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim - + cross_attention_dim = ( + None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + ) + # Determine hidden_size based on processor location hidden_size = None if name.startswith("mid_block"): @@ -138,7 +156,7 @@ def _install_ipadapter_processors(self): else: # Fallback for any unexpected processor names hidden_size = self.unet.config.block_out_channels[0] # Use first block size as fallback - + if cross_attention_dim is None: # Self-attention layers use standard processors attn_procs[name] = AttnProcessor() @@ -154,38 +172,46 @@ def _install_ipadapter_processors(self): self._ip_trt_processors.append(proc) ip_layer_index += 1 attn_procs[name] = proc - + self.unet.set_attn_processor(attn_procs) self.num_ip_layers = len(self._ip_trt_processors) - - except Exception as e: print(f"IPAdapterUNetExportWrapper: ERROR - Could not install IPAdapter processors: {e}") print(f"IPAdapterUNetExportWrapper: Exception type: {type(e).__name__}") print("IPAdapterUNetExportWrapper: IPAdapter functionality will not work without processors!") import traceback + traceback.print_exc() raise e - + def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None: """Assign per-layer scale tensor to installed TRTIPAttn processors.""" if not isinstance(ipadapter_scale, torch.Tensor): import logging - logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + + logging.getLogger(__name__).error( + f"IPAdapterUNetExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}" + ) raise TypeError("ipadapter_scale must be a torch.Tensor") if self.num_ip_layers <= 0 or not self._ip_trt_processors: raise RuntimeError("No TRTIPAttn processors installed") if ipadapter_scale.ndim != 1 or ipadapter_scale.shape[0] != self.num_ip_layers: import logging - logging.getLogger(__name__).error(f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)") + + logging.getLogger(__name__).error( + f"IPAdapterUNetExportWrapper: ipadapter_scale has wrong shape {tuple(ipadapter_scale.shape)}, expected=({self.num_ip_layers},)" + ) raise ValueError(f"ipadapter_scale must have shape [{self.num_ip_layers}]") # Ensure float32 for ONNX export stability scale_vec = ipadapter_scale.to(dtype=torch.float32) try: import logging - logging.getLogger(__name__).debug(f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}") + + logging.getLogger(__name__).debug( + f"IPAdapterUNetExportWrapper: scale_vec min={scale_vec.min()}, max={scale_vec.max()}" + ) except Exception: pass for proc in self._ip_trt_processors: @@ -194,27 +220,27 @@ def set_ipadapter_scale(self, ipadapter_scale: torch.Tensor) -> None: def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torch.Tensor = None): """ Forward pass with concatenated embeddings (text + image). - + The IPAdapter processors installed in the UNet will automatically: 1. Split the concatenated embeddings into text and image parts 2. Process image tokens with separate attention computation 3. Apply scaling and blending between text and image attention - + Args: sample: Latent input tensor - timestep: Timestep tensor + timestep: Timestep tensor encoder_hidden_states: Concatenated embeddings [text_tokens + image_tokens, cross_attention_dim] - + Returns: UNet output (noise prediction) """ # Validate input shapes batch_size, seq_len, embed_dim = encoder_hidden_states.shape - + # Check that we have the expected number of image tokens if embed_dim != self.cross_attention_dim: raise ValueError(f"Embedding dimension {embed_dim} doesn't match expected {self.cross_attention_dim}") - + # Ensure dtype consistency for ONNX export if encoder_hidden_states.dtype != torch.float32: encoder_hidden_states = encoder_hidden_states.to(torch.float32) @@ -223,29 +249,28 @@ def forward(self, sample, timestep, encoder_hidden_states, ipadapter_scale: torc if ipadapter_scale is None: raise RuntimeError("IPAdapterUNetExportWrapper.forward requires ipadapter_scale tensor") self.set_ipadapter_scale(ipadapter_scale) - + # Pass concatenated embeddings to UNet with baked-in IPAdapter processors return self.unet( - sample=sample, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - return_dict=False + sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, return_dict=False ) -def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True) -> IPAdapterUNetExportWrapper: +def create_ipadapter_wrapper( + unet: UNet2DConditionModel, num_tokens: int = 4, install_processors: bool = True +) -> IPAdapterUNetExportWrapper: """ Create an IPAdapter wrapper with automatic architecture detection and baked-in processors. - + Handles both cases: 1. UNet with pre-loaded IPAdapter processors (preserves existing weights) 2. UNet without IPAdapter processors (installs new ones if install_processors=True) - + Args: unet: UNet2DConditionModel to wrap num_tokens: Number of image tokens (4 for standard, 16 for plus) install_processors: Whether to install IPAdapter processors if none exist - + Returns: IPAdapterUNetExportWrapper with baked-in IPAdapter attention processors """ @@ -253,23 +278,21 @@ def create_ipadapter_wrapper(unet: UNet2DConditionModel, num_tokens: int = 4, in try: model_type = detect_model_from_diffusers_unet(unet) cross_attention_dim = unet.config.cross_attention_dim - + # Check if UNet already has IPAdapter processors installed existing_processors = unet.attn_processors - has_ipadapter = any('IPAttn' in proc.__class__.__name__ or 'IPAttnProcessor' in proc.__class__.__name__ - for proc in existing_processors.values()) - + has_ipadapter = any( + "IPAttn" in proc.__class__.__name__ or "IPAttnProcessor" in proc.__class__.__name__ + for proc in existing_processors.values() + ) + # Validate expected dimensions - expected_dims = { - "SD15": 768, - "SDXL": 2048, - "SD21": 1024 - } - + expected_dims = {"SD15": 768, "SDXL": 2048, "SD21": 1024} + expected_dim = expected_dims.get(model_type) - + return IPAdapterUNetExportWrapper(unet, cross_attention_dim, num_tokens, install_processors) - + except Exception as e: print(f"create_ipadapter_wrapper: Error during model detection: {e}") - return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors) \ No newline at end of file + return IPAdapterUNetExportWrapper(unet, 768, num_tokens, install_processors) diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py index fa1f0f890..078b1f913 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_sdxl_export.py @@ -4,13 +4,17 @@ conditioning parameters, and Turbo variants """ +import logging +from typing import Any, Dict + import torch -from typing import Dict, List, Optional, Tuple, Any, Union from diffusers import UNet2DConditionModel + from ....model_detection import ( detect_model, ) -import logging + + logger = logging.getLogger(__name__) # Handle different diffusers versions for CLIPTextModel import @@ -18,7 +22,7 @@ from diffusers.models.transformers.clip_text_model import CLIPTextModel except ImportError: try: - from diffusers.models.clip_text_model import CLIPTextModel + from diffusers.models.clip_text_model import CLIPTextModel except ImportError: try: from transformers import CLIPTextModel @@ -29,79 +33,81 @@ class SDXLExportWrapper(torch.nn.Module): """Wrapper for SDXL UNet to handle optional conditioning in legacy TensorRT""" - + def __init__(self, unet): super().__init__() self.unet = unet self.base_unet = self._get_base_unet(unet) self.supports_added_cond = self._test_added_cond_support() - + def _get_base_unet(self, unet): """Extract the base UNet from wrappers""" # Handle ControlNet wrapper - if hasattr(unet, 'unet_model') and hasattr(unet.unet_model, 'config'): + if hasattr(unet, "unet_model") and hasattr(unet.unet_model, "config"): return unet.unet_model - elif hasattr(unet, 'unet') and hasattr(unet.unet, 'config'): + elif hasattr(unet, "unet") and hasattr(unet.unet, "config"): return unet.unet - elif hasattr(unet, 'config'): + elif hasattr(unet, "config"): return unet else: # Fallback: try to find any attribute that has config for attr_name in dir(unet): - if not attr_name.startswith('_'): + if not attr_name.startswith("_"): attr = getattr(unet, attr_name, None) - if hasattr(attr, 'config') and hasattr(attr.config, 'addition_embed_type'): + if hasattr(attr, "config") and hasattr(attr.config, "addition_embed_type"): return attr return unet - + def _test_added_cond_support(self): """Test if this SDXL model supports added_cond_kwargs""" try: # Create minimal test inputs - sample = torch.randn(1, 4, 8, 8, device='cuda', dtype=torch.float16) - timestep = torch.tensor([0.5], device='cuda', dtype=torch.float32) - encoder_hidden_states = torch.randn(1, 77, 2048, device='cuda', dtype=torch.float16) - + sample = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + timestep = torch.tensor([0.5], device="cuda", dtype=torch.float32) + encoder_hidden_states = torch.randn(1, 77, 2048, device="cuda", dtype=torch.float16) + # Test with added_cond_kwargs test_added_cond = { - 'text_embeds': torch.randn(1, 1280, device='cuda', dtype=torch.float16), - 'time_ids': torch.randn(1, 6, device='cuda', dtype=torch.float16) + "text_embeds": torch.randn(1, 1280, device="cuda", dtype=torch.float16), + "time_ids": torch.randn(1, 6, device="cuda", dtype=torch.float16), } - + with torch.no_grad(): _ = self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=test_added_cond) - + logger.info("SDXL model supports added_cond_kwargs") return True - + except Exception as e: logger.error(f"SDXL model does not support added_cond_kwargs: {e}") return False - + def forward(self, *args, **kwargs): """Forward pass that handles SDXL conditioning gracefully""" try: # Ensure added_cond_kwargs is never None to prevent TypeError - if 'added_cond_kwargs' in kwargs and kwargs['added_cond_kwargs'] is None: - kwargs['added_cond_kwargs'] = {} - + if "added_cond_kwargs" in kwargs and kwargs["added_cond_kwargs"] is None: + kwargs["added_cond_kwargs"] = {} + # Auto-generate SDXL conditioning if missing and model needs it - if (len(args) >= 3 and 'added_cond_kwargs' not in kwargs and - hasattr(self.base_unet.config, 'addition_embed_type') and - self.base_unet.config.addition_embed_type == 'text_time'): - + if ( + len(args) >= 3 + and "added_cond_kwargs" not in kwargs + and hasattr(self.base_unet.config, "addition_embed_type") + and self.base_unet.config.addition_embed_type == "text_time" + ): sample = args[0] device = sample.device batch_size = sample.shape[0] - + logger.info("Auto-generating required SDXL conditioning...") - kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype) + kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype), } - + # If model supports added conditioning and we have the kwargs, use them - if self.supports_added_cond and 'added_cond_kwargs' in kwargs: + if self.supports_added_cond and "added_cond_kwargs" in kwargs: result = self.unet(*args, **kwargs) return result elif len(args) >= 3: @@ -110,7 +116,7 @@ def forward(self, *args, **kwargs): else: # Fallback return self.unet(*args, **kwargs) - + except (TypeError, AttributeError) as e: logger.error(f"[SDXL_WRAPPER] forward: Exception caught: {e}") if "NoneType" in str(e) or "iterable" in str(e) or "text_embeds" in str(e): @@ -120,15 +126,17 @@ def forward(self, *args, **kwargs): sample, timestep, encoder_hidden_states = args[0], args[1], args[2] device = sample.device batch_size = sample.shape[0] - + # Create minimal valid SDXL conditioning minimal_conditioning = { - 'text_embeds': torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=device, dtype=sample.dtype) + "text_embeds": torch.zeros(batch_size, 1280, device=device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=device, dtype=sample.dtype), } - + try: - return self.unet(sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning) + return self.unet( + sample, timestep, encoder_hidden_states, added_cond_kwargs=minimal_conditioning + ) except Exception as final_e: logger.info(f"Final fallback to basic call: {final_e}") return self.unet(sample, timestep, encoder_hidden_states) @@ -136,181 +144,180 @@ def forward(self, *args, **kwargs): return self.unet(*args) else: raise e - + + class SDXLConditioningHandler: """Handles SDXL conditioning parameters and dual text encoders""" - + def __init__(self, unet_info: Dict[str, Any]): self.unet_info = unet_info - self.is_sdxl = unet_info['is_sdxl'] - self.has_time_cond = unet_info['has_time_cond'] - self.has_addition_embed = unet_info['has_addition_embed'] - + self.is_sdxl = unet_info["is_sdxl"] + self.has_time_cond = unet_info["has_time_cond"] + self.has_addition_embed = unet_info["has_addition_embed"] + def get_conditioning_spec(self) -> Dict[str, Any]: """Get conditioning specification for ONNX export and TensorRT""" spec = { - 'text_encoder_dim': 768, # CLIP ViT-L - 'context_dim': 768, # Default SD1.5 - 'pooled_embeds': False, - 'time_ids': False, - 'dual_encoders': False + "text_encoder_dim": 768, # CLIP ViT-L + "context_dim": 768, # Default SD1.5 + "pooled_embeds": False, + "time_ids": False, + "dual_encoders": False, } - + if self.is_sdxl: - spec.update({ - 'text_encoder_dim': 768, # CLIP ViT-L - 'text_encoder_2_dim': 1280, # OpenCLIP ViT-bigG - 'context_dim': 2048, # Concatenated 768 + 1280 - 'pooled_embeds': True, # Pooled text embeddings - 'time_ids': self.has_time_cond, # Size/crop conditioning - 'dual_encoders': True - }) - + spec.update( + { + "text_encoder_dim": 768, # CLIP ViT-L + "text_encoder_2_dim": 1280, # OpenCLIP ViT-bigG + "context_dim": 2048, # Concatenated 768 + 1280 + "pooled_embeds": True, # Pooled text embeddings + "time_ids": self.has_time_cond, # Size/crop conditioning + "dual_encoders": True, + } + ) + return spec - - def create_sample_conditioning(self, batch_size: int = 1, device: str = 'cuda') -> Dict[str, torch.Tensor]: + + def create_sample_conditioning(self, batch_size: int = 1, device: str = "cuda") -> Dict[str, torch.Tensor]: """Create sample conditioning tensors for testing/export""" spec = self.get_conditioning_spec() dtype = torch.float16 - + conditioning = { - 'encoder_hidden_states': torch.randn( - batch_size, 77, spec['context_dim'], - device=device, dtype=dtype - ) + "encoder_hidden_states": torch.randn(batch_size, 77, spec["context_dim"], device=device, dtype=dtype) } - - if spec['pooled_embeds']: - conditioning['text_embeds'] = torch.randn( - batch_size, spec['text_encoder_2_dim'], - device=device, dtype=dtype + + if spec["pooled_embeds"]: + conditioning["text_embeds"] = torch.randn( + batch_size, spec["text_encoder_2_dim"], device=device, dtype=dtype ) - - if spec['time_ids']: - conditioning['time_ids'] = torch.randn( - batch_size, 6, # [height, width, crop_h, crop_w, target_height, target_width] - device=device, dtype=dtype + + if spec["time_ids"]: + conditioning["time_ids"] = torch.randn( + batch_size, + 6, # [height, width, crop_h, crop_w, target_height, target_width] + device=device, + dtype=dtype, ) - + return conditioning - + def test_unet_conditioning(self, unet: UNet2DConditionModel) -> Dict[str, bool]: """Test what conditioning the UNet actually supports""" - results = { - 'basic': False, - 'added_cond_kwargs': False, - 'separate_args': False - } - + results = {"basic": False, "added_cond_kwargs": False, "separate_args": False} + try: # Ensure model is on CUDA and in eval mode for testing - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" unet_test = unet.to(device).eval() - + # Create test inputs on the same device sample = torch.randn(1, 4, 8, 8, device=device, dtype=torch.float16) timestep = torch.tensor([0.5], device=device, dtype=torch.float32) conditioning = self.create_sample_conditioning(1, device=device) - + # Test basic call try: with torch.no_grad(): - _ = unet_test(sample, timestep, conditioning['encoder_hidden_states']) - results['basic'] = True + _ = unet_test(sample, timestep, conditioning["encoder_hidden_states"]) + results["basic"] = True except Exception: pass - + # Test added_cond_kwargs (standard SDXL) if self.is_sdxl: try: added_cond = {} - if 'text_embeds' in conditioning: - added_cond['text_embeds'] = conditioning['text_embeds'] - if 'time_ids' in conditioning: - added_cond['time_ids'] = conditioning['time_ids'] - + if "text_embeds" in conditioning: + added_cond["text_embeds"] = conditioning["text_embeds"] + if "time_ids" in conditioning: + added_cond["time_ids"] = conditioning["time_ids"] + with torch.no_grad(): - _ = unet_test(sample, timestep, conditioning['encoder_hidden_states'], - added_cond_kwargs=added_cond) - results['added_cond_kwargs'] = True + _ = unet_test( + sample, timestep, conditioning["encoder_hidden_states"], added_cond_kwargs=added_cond + ) + results["added_cond_kwargs"] = True except Exception: pass - + # Test separate arguments (some implementations) try: - args = [sample, timestep, conditioning['encoder_hidden_states']] - if 'text_embeds' in conditioning: - args.append(conditioning['text_embeds']) - if 'time_ids' in conditioning: - args.append(conditioning['time_ids']) - + args = [sample, timestep, conditioning["encoder_hidden_states"]] + if "text_embeds" in conditioning: + args.append(conditioning["text_embeds"]) + if "time_ids" in conditioning: + args.append(conditioning["time_ids"]) + with torch.no_grad(): _ = unet_test(*args) - results['separate_args'] = True + results["separate_args"] = True except Exception: pass - + except Exception as e: # If testing fails completely, provide safe defaults print(f"⚠️ UNet conditioning test setup failed: {e}") results = { - 'basic': True, # Assume basic call works - 'added_cond_kwargs': self.is_sdxl, # Assume SDXL models support this - 'separate_args': False + "basic": True, # Assume basic call works + "added_cond_kwargs": self.is_sdxl, # Assume SDXL models support this + "separate_args": False, } - + return results def get_onnx_export_spec(self) -> Dict[str, Any]: """Get specification for ONNX export""" spec = self.conditioning_handler.get_conditioning_spec() - + # Add export-specific details - spec.update({ - 'input_names': ['sample', 'timestep', 'encoder_hidden_states'], - 'output_names': ['noise_pred'], - 'dynamic_axes': { - 'sample': {0: 'batch_size'}, - 'timestep': {0: 'batch_size'}, - 'encoder_hidden_states': {0: 'batch_size'}, - 'noise_pred': {0: 'batch_size'} + spec.update( + { + "input_names": ["sample", "timestep", "encoder_hidden_states"], + "output_names": ["noise_pred"], + "dynamic_axes": { + "sample": {0: "batch_size"}, + "timestep": {0: "batch_size"}, + "encoder_hidden_states": {0: "batch_size"}, + "noise_pred": {0: "batch_size"}, + }, } - }) - + ) + # Add SDXL-specific inputs if supported - if self.is_sdxl and self.supported_calls['added_cond_kwargs']: - if spec['pooled_embeds']: - spec['input_names'].append('text_embeds') - spec['dynamic_axes']['text_embeds'] = {0: 'batch_size'} - - if spec['time_ids']: - spec['input_names'].append('time_ids') - spec['dynamic_axes']['time_ids'] = {0: 'batch_size'} - - return spec + if self.is_sdxl and self.supported_calls["added_cond_kwargs"]: + if spec["pooled_embeds"]: + spec["input_names"].append("text_embeds") + spec["dynamic_axes"]["text_embeds"] = {0: "batch_size"} + + if spec["time_ids"]: + spec["input_names"].append("time_ids") + spec["dynamic_axes"]["time_ids"] = {0: "batch_size"} + return spec def get_sdxl_tensorrt_config(model_path: str, unet: UNet2DConditionModel) -> Dict[str, Any]: """Get complete TensorRT configuration for SDXL model""" # Use the new detection function detection_result = detect_model(unet) - + # Create a config dict compatible with SDXLConditioningHandler config = { - 'is_sdxl': detection_result['is_sdxl'], - 'has_time_cond': detection_result['architecture_details']['has_time_conditioning'], - 'has_addition_embed': detection_result['architecture_details']['has_addition_embeds'], - 'model_type': detection_result['model_type'], - 'is_turbo': detection_result['is_turbo'], - 'is_sd3': detection_result['is_sd3'], - 'confidence': detection_result['confidence'], - 'architecture_details': detection_result['architecture_details'], - 'compatibility_info': detection_result['compatibility_info'] + "is_sdxl": detection_result["is_sdxl"], + "has_time_cond": detection_result["architecture_details"]["has_time_conditioning"], + "has_addition_embed": detection_result["architecture_details"]["has_addition_embeds"], + "model_type": detection_result["model_type"], + "is_turbo": detection_result["is_turbo"], + "is_sd3": detection_result["is_sd3"], + "confidence": detection_result["confidence"], + "architecture_details": detection_result["architecture_details"], + "compatibility_info": detection_result["compatibility_info"], } - + # Add conditioning specification conditioning_handler = SDXLConditioningHandler(config) - config['conditioning_spec'] = conditioning_handler.get_conditioning_spec() - - return config \ No newline at end of file + config["conditioning_spec"] = conditioning_handler.get_conditioning_spec() + + return config diff --git a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py index 1c87efbf9..cb4765b82 100644 --- a/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py +++ b/src/streamdiffusion/acceleration/tensorrt/export_wrappers/unet_unified_export.py @@ -1,23 +1,33 @@ +from typing import List, Optional + import torch from diffusers import UNet2DConditionModel -from typing import Optional, List + +from streamdiffusion._compat.diffusers_kvo_patch import apply as _apply_kvo_patch + + +_apply_kvo_patch() # ensure kvo_cache patch is present even if diffusers was imported first + +from ..models.utils import convert_list_to_structure from .unet_controlnet_export import create_controlnet_wrapper from .unet_ipadapter_export import create_ipadapter_wrapper -from ..models.utils import convert_list_to_structure + class UnifiedExportWrapper(torch.nn.Module): """ - Unified wrapper that composes wrappers for conditioning modules. + Unified wrapper that composes wrappers for conditioning modules. """ - - def __init__(self, - unet: UNet2DConditionModel, - use_controlnet: bool = False, - use_ipadapter: bool = False, - control_input_names: Optional[List[str]] = None, - num_tokens: int = 4, - kvo_cache_structure: List[int] = [], - **kwargs): + + def __init__( + self, + unet: UNet2DConditionModel, + use_controlnet: bool = False, + use_ipadapter: bool = False, + control_input_names: Optional[List[str]] = None, + num_tokens: int = 4, + kvo_cache_structure: List[int] = [], + **kwargs, + ): super().__init__() self.use_controlnet = use_controlnet self.use_ipadapter = use_ipadapter @@ -25,23 +35,24 @@ def __init__(self, self.ipadapter_wrapper = None self.unet = unet self.kvo_cache_structure = kvo_cache_structure - + # Apply IPAdapter first (installs processors into UNet) if use_ipadapter: - ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ['install_processors']} - if 'install_processors' not in ipadapter_kwargs: - ipadapter_kwargs['install_processors'] = True - + ipadapter_kwargs = {k: v for k, v in kwargs.items() if k in ["install_processors"]} + if "install_processors" not in ipadapter_kwargs: + ipadapter_kwargs["install_processors"] = True self.ipadapter_wrapper = create_ipadapter_wrapper(unet, num_tokens=num_tokens, **ipadapter_kwargs) self.unet = self.ipadapter_wrapper.unet - + # Apply ControlNet second (wraps whatever UNet we have) if use_controlnet and control_input_names: - controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ['num_controlnets', 'conditioning_scales']} + controlnet_kwargs = {k: v for k, v in kwargs.items() if k in ["num_controlnets", "conditioning_scales"]} + + self.controlnet_wrapper = create_controlnet_wrapper( + self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs + ) - self.controlnet_wrapper = create_controlnet_wrapper(self.unet, control_input_names, kvo_cache_structure, **controlnet_kwargs) - def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cache, **kwargs): """Basic UNet forward that passes through all parameters to handle any model type""" formatted_kvo_cache = [] @@ -49,52 +60,57 @@ def _basic_unet_forward(self, sample, timestep, encoder_hidden_states, *kvo_cach formatted_kvo_cache = convert_list_to_structure(kvo_cache, self.kvo_cache_structure) # Auto-generate SDXL conditioning if missing and UNet requires it - if 'added_cond_kwargs' not in kwargs or kwargs.get('added_cond_kwargs') is None: + if "added_cond_kwargs" not in kwargs or kwargs.get("added_cond_kwargs") is None: base_unet = self.unet - if (hasattr(base_unet, 'config') and - getattr(base_unet.config, 'addition_embed_type', None) == 'text_time'): + if hasattr(base_unet, "config") and getattr(base_unet.config, "addition_embed_type", None) == "text_time": batch_size = sample.shape[0] - kwargs['added_cond_kwargs'] = { - 'text_embeds': torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), - 'time_ids': torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), + kwargs["added_cond_kwargs"] = { + "text_embeds": torch.zeros(batch_size, 1280, device=sample.device, dtype=sample.dtype), + "time_ids": torch.zeros(batch_size, 6, device=sample.device, dtype=sample.dtype), } unet_kwargs = { - 'sample': sample, - 'timestep': timestep, - 'encoder_hidden_states': encoder_hidden_states, - 'return_dict': False, - 'kvo_cache': formatted_kvo_cache, - **kwargs # Pass through all additional parameters (SDXL, future model types, etc.) + "sample": sample, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, + "kvo_cache": formatted_kvo_cache, + **kwargs, # Pass through all additional parameters (SDXL, future model types, etc.) } res = self.unet(**unet_kwargs) if len(kvo_cache) > 0: return res else: return res[0] - - def forward(self, - sample: torch.Tensor, - timestep: torch.Tensor, - encoder_hidden_states: torch.Tensor, - *args, - **kwargs) -> torch.Tensor: + + def forward( + self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: """Forward pass that handles any UNet parameters via **kwargs passthrough""" # Handle IP-Adapter runtime scale vector as a positional argument placed before control tensors if self.use_ipadapter and self.ipadapter_wrapper is not None: # ipadapter_scale is appended as the first extra positional input after the 3 base inputs if len(args) == 0: import logging - logging.getLogger(__name__).error("UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True") + + logging.getLogger(__name__).error( + "UnifiedExportWrapper: ipadapter_scale missing; required when use_ipadapter=True" + ) raise RuntimeError("UnifiedExportWrapper: ipadapter_scale tensor is required when use_ipadapter=True") ipadapter_scale = args[0] if not isinstance(ipadapter_scale, torch.Tensor): import logging - logging.getLogger(__name__).error(f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}") + + logging.getLogger(__name__).error( + f"UnifiedExportWrapper: ipadapter_scale wrong type: {type(ipadapter_scale)}" + ) raise TypeError("ipadapter_scale must be a torch.Tensor") try: import logging - logging.getLogger(__name__).debug(f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}") + + logging.getLogger(__name__).debug( + f"UnifiedExportWrapper: ipadapter_scale shape={tuple(ipadapter_scale.shape)}, dtype={ipadapter_scale.dtype}" + ) except Exception: pass # assign per-layer scale tensors into processors @@ -107,4 +123,4 @@ def forward(self, return self.controlnet_wrapper(sample, timestep, encoder_hidden_states, *args, **kwargs) else: # Basic UNet call with all parameters passed through - return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs) \ No newline at end of file + return self._basic_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs) diff --git a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py index 762c638ac..7cd517f4f 100644 --- a/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py +++ b/src/streamdiffusion/acceleration/tensorrt/fp8_quantize.py @@ -302,6 +302,59 @@ def _read_onnx_input_specs(onnx_path: str) -> Dict[str, tuple]: return result +def _assert_finite_qdq_scales(onnx_path: str) -> None: + """Raise RuntimeError if any FP8 Q/DQ scale in onnx_path is non-finite. + + Root cause: modelopt/onnx/quantization/fp8.py computes + np_fp8_scale = (np_scale * 448.0) / 127.0 + in the source dtype (FP16). An INT8 amax > ~18,500 overflows to +inf; + the resulting Q/DQ scale produces zero output at inference. + + Checks both initializer scales AND Constant-node scales (the latter appear + on some residual-add Q nodes injected by modelopt on SDXL UNet). + On failure: raises RuntimeError listing the first 5 offending node names + and directs the user to extend _DEFAULT_EXCLUDE_PATTERNS and delete the + cached .fp8.onnx so modelopt reruns without those layers. + """ + import onnx as _onnx + from onnx import numpy_helper as _numpy_helper + + model = _onnx.load(onnx_path, load_external_data=True) + graph = model.graph + init_map = {init.name: init for init in graph.initializer} + const_map: dict = {} + for node in graph.node: + if node.op_type == "Constant" and node.output: + attr = {a.name: a for a in node.attribute} + if "value" in attr: + const_map[node.output[0]] = _numpy_helper.to_array(attr["value"].t) + + bad: list = [] + for node in graph.node: + if node.op_type not in ("QuantizeLinear", "DequantizeLinear"): + continue + if len(node.input) < 2: + continue + scale_name = node.input[1] + if scale_name in init_map: + arr = _numpy_helper.to_array(init_map[scale_name]).flatten().astype(np.float64) + elif scale_name in const_map: + arr = const_map[scale_name].flatten().astype(np.float64) + else: + continue + if not np.isfinite(arr).all(): + bad.append(node.name or scale_name) + + if bad: + names = ", ".join(bad[:5]) + ("..." if len(bad) > 5 else "") + raise RuntimeError( + f"[FP8] Non-finite Q/DQ scale in {len(bad)} node(s): {names}. " + f"Add the offending layer substring(s) to _DEFAULT_EXCLUDE_PATTERNS " + f"in fp8_quantize.py, delete the cached .fp8.onnx, and rebuild. " + f"Diagnostic: modelopt/onnx/quantization/fp8.py overflow when INT8 amax > ~18500." + ) + + def quantize_onnx_fp8( onnx_path: str, output_path: str, @@ -376,7 +429,8 @@ def quantize_onnx_fp8( # e.g. SDXL exports `sample` as FP32 even though the unet runs FP16. # modelopt's CalibrationDataProvider asserts strict count match and ORT's # inference probe rejects dtype mismatches, so filter+cast accordingly. - _onnx_inputs = {k: v[0] for k, v in _read_onnx_input_specs(onnx_path).items()} + _specs = _read_onnx_input_specs(onnx_path) # {name: (dtype, dims)} + _onnx_inputs = {k: v[0] for k, v in _specs.items()} _dropped = set(calibration_data.keys()) - set(_onnx_inputs) if _dropped: logger.info(f"[FP8] Dropping calibration keys not exposed by ONNX: {sorted(_dropped)}") @@ -389,6 +443,27 @@ def quantize_onnx_fp8( logger.info(f"[FP8] Casting calibration '{_k}': {calibration_data[_k].dtype} → {_expected}") calibration_data[_k] = calibration_data[_k].astype(_expected) + # Per-input tile: target rows = n_itr × resolved_dim0(name) so every input + # splits into exactly n_itr chunks of shape (resolved_dim0, ...). + # Mirrors modelopt CalibrationDataProvider: symbolic dims → 1, static dims kept. + # Naïve _max_rows tile breaks kvo_cache_in_* (ONNX dim0=2 static) by pumping + # sample to 2×_n_itr rows, causing modelopt to split kvo into (1,...) chunks. + import math as _math + + _resolved_dim0 = {name: max(1, (_specs[name][1][0] or 1)) for name in calibration_data} + _n_itr = max(arr.shape[0] // _resolved_dim0[name] for name, arr in calibration_data.items()) + _n_itr = max(1, _n_itr) + for _k in list(calibration_data.keys()): + _arr = calibration_data[_k] + _target_rows = _n_itr * _resolved_dim0[_k] + if _arr.shape[0] != _target_rows: + _repeats = _math.ceil(_target_rows / max(1, _arr.shape[0])) + calibration_data[_k] = np.tile(_arr, (_repeats,) + (1,) * (_arr.ndim - 1))[:_target_rows] + logger.info( + f"[FP8] Tiled '{_k}' {_arr.shape[0]} → {_target_rows} rows " + f"(n_itr={_n_itr} × resolved_dim0={_resolved_dim0[_k]})" + ) + import inspect as _inspect _params = set(_inspect.signature(modelopt_quantize).parameters.keys()) @@ -421,6 +496,8 @@ def quantize_onnx_fp8( if not os.path.exists(output_path): raise RuntimeError(f"[FP8] modelopt_quantize completed but output not found: {output_path}") + _assert_finite_qdq_scales(output_path) + size_mb = os.path.getsize(output_path) / (1024**2) logger.info(f"[FP8] FP8 ONNX written: {output_path} ({size_mb:.1f} MB)") if size_mb > 5000: diff --git a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py index f0f2c4b93..b6a1bd62d 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/__init__.py @@ -1,13 +1,14 @@ -from .models import Optimizer, BaseModel, CLIP, UNet, VAE, VAEEncoder -from .controlnet_models import ControlNetTRT, ControlNetSDXLTRT +from .controlnet_models import ControlNetSDXLTRT, ControlNetTRT +from .models import CLIP, VAE, BaseModel, Optimizer, UNet, VAEEncoder + __all__ = [ "Optimizer", - "BaseModel", + "BaseModel", "CLIP", "UNet", "VAE", "VAEEncoder", "ControlNetTRT", "ControlNetSDXLTRT", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py index 6179ffc9a..be41d502f 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/attention_processors.py @@ -2,10 +2,10 @@ import torch import torch.nn.functional as F - from diffusers.models.attention_processor import Attention from diffusers.utils import USE_PEFT_BACKEND + class CachedSTAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -28,9 +28,9 @@ def __init__(self): # clone/contiguous path. Set to True by wrapper.py after engine build. self._curr_key_buf: Optional[torch.Tensor] = None self._curr_value_buf: Optional[torch.Tensor] = None - self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key + self._cached_key_tr_buf: Optional[torch.Tensor] = None # transposed cache key self._cached_value_tr_buf: Optional[torch.Tensor] = None # transposed cache value - self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) + self._kvo_out_buf: Optional[torch.Tensor] = None # (2, 1, B, S, H) self._use_prealloc: bool = False def _ensure_buffers( diff --git a/src/streamdiffusion/acceleration/tensorrt/models/models.py b/src/streamdiffusion/acceleration/tensorrt/models/models.py index a85b2415d..d59e5a416 100644 --- a/src/streamdiffusion/acceleration/tensorrt/models/models.py +++ b/src/streamdiffusion/acceleration/tensorrt/models/models.py @@ -61,9 +61,9 @@ def infer_shapes(self, return_onnx=False): onnx_graph = gs.export_onnx(self.graph) if onnx_graph.ByteSize() > 2147483648: print( - f"⚠️ Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models" + f"[WARN] Model size ({onnx_graph.ByteSize() / (1024**3):.2f} GB) exceeds 2GB - this is normal for SDXL models" ) - print("🔧 ONNX shape inference will be skipped for large models to avoid memory issues") + print("[INFO] ONNX shape inference will be skipped for large models to avoid memory issues") # For large models like SDXL, skip shape inference to avoid memory/size issues # The model will still work with TensorRT's own shape inference during engine building else: diff --git a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py index 165c261e4..7fa98f855 100644 --- a/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py +++ b/src/streamdiffusion/acceleration/tensorrt/runtime_engines/__init__.py @@ -1,12 +1,13 @@ """Runtime TensorRT engine wrappers.""" -from .unet_engine import UNet2DConditionModelEngine, AutoencoderKLEngine -from .controlnet_engine import ControlNetModelEngine from ..engine_manager import EngineManager +from .controlnet_engine import ControlNetModelEngine +from .unet_engine import AutoencoderKLEngine, UNet2DConditionModelEngine + __all__ = [ "UNet2DConditionModelEngine", - "AutoencoderKLEngine", + "AutoencoderKLEngine", "ControlNetModelEngine", "EngineManager", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 001fda3db..4bab80c1a 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -1,91 +1,96 @@ -import os -import sys -import yaml import json -from typing import Dict, List, Optional, Union, Any, Tuple from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import yaml + def load_config(config_path: Union[str, Path]) -> Dict[str, Any]: """Load StreamDiffusion configuration from YAML or JSON file""" config_path = Path(config_path) - + if not config_path.exists(): raise FileNotFoundError(f"load_config: Configuration file not found: {config_path}") - with open(config_path, 'r', encoding='utf-8') as f: - if config_path.suffix.lower() in ['.yaml', '.yml']: + with open(config_path, "r", encoding="utf-8") as f: + if config_path.suffix.lower() in [".yaml", ".yml"]: config_data = yaml.safe_load(f) - elif config_path.suffix.lower() == '.json': + elif config_path.suffix.lower() == ".json": config_data = json.load(f) else: raise ValueError(f"load_config: Unsupported configuration file format: {config_path.suffix}") - + _validate_config(config_data) - + return config_data def save_config(config: Dict[str, Any], config_path: Union[str, Path]) -> None: """Save StreamDiffusion configuration to YAML or JSON file""" config_path = Path(config_path) - + _validate_config(config) config_path.parent.mkdir(parents=True, exist_ok=True) - with open(config_path, 'w', encoding='utf-8') as f: - if config_path.suffix.lower() in ['.yaml', '.yml']: + with open(config_path, "w", encoding="utf-8") as f: + if config_path.suffix.lower() in [".yaml", ".yml"]: yaml.dump(config, f, default_flow_style=False, indent=2) - elif config_path.suffix.lower() == '.json': + elif config_path.suffix.lower() == ".json": json.dump(config, f, indent=2) else: raise ValueError(f"save_config: Unsupported configuration file format: {config_path.suffix}") + def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: """Create StreamDiffusionWrapper from configuration dictionary - + Prompt Interface: - Legacy: Use 'prompt' field for single prompt - New: Use 'prompt_blending' with 'prompt_list' for multiple weighted prompts - If both are provided, 'prompt_blending' takes precedence and 'prompt' is ignored - negative_prompt: Currently a single string (not list) for all prompt types """ + from streamdiffusion import StreamDiffusionWrapper - import torch final_config = {**config, **overrides} wrapper_params = _extract_wrapper_params(final_config) wrapper = StreamDiffusionWrapper(**wrapper_params) - + prepare_params = _extract_prepare_params(final_config) # Handle prompt configuration with clear precedence - if 'prompt_blending' in final_config: + if "prompt_blending" in final_config: # Use prompt blending (new interface) - ignore legacy 'prompt' field - blend_config = final_config['prompt_blending'] - + blend_config = final_config["prompt_blending"] + # Prepare with prompt blending directly using unified interface - prepare_params_with_blending = {k: v for k, v in prepare_params.items() - if k not in ['prompt_blending', 'seed_blending']} - prepare_params_with_blending['prompt'] = blend_config.get('prompt_list', []) - prepare_params_with_blending['prompt_interpolation_method'] = blend_config.get('interpolation_method', 'slerp') - + prepare_params_with_blending = { + k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"] + } + prepare_params_with_blending["prompt"] = blend_config.get("prompt_list", []) + prepare_params_with_blending["prompt_interpolation_method"] = blend_config.get("interpolation_method", "slerp") + # Add seed blending if configured - if 'seed_blending' in final_config: - seed_blend_config = final_config['seed_blending'] - prepare_params_with_blending['seed_list'] = seed_blend_config.get('seed_list', []) - prepare_params_with_blending['seed_interpolation_method'] = seed_blend_config.get('interpolation_method', 'linear') - + if "seed_blending" in final_config: + seed_blend_config = final_config["seed_blending"] + prepare_params_with_blending["seed_list"] = seed_blend_config.get("seed_list", []) + prepare_params_with_blending["seed_interpolation_method"] = seed_blend_config.get( + "interpolation_method", "linear" + ) + wrapper.prepare(**prepare_params_with_blending) - elif prepare_params.get('prompt'): + elif prepare_params.get("prompt"): # Use legacy single prompt interface - clean_prepare_params = {k: v for k, v in prepare_params.items() - if k not in ['prompt_blending', 'seed_blending']} + clean_prepare_params = { + k: v for k, v in prepare_params.items() if k not in ["prompt_blending", "seed_blending"] + } wrapper.prepare(**clean_prepare_params) # Apply seed blending if configured and not already handled in prepare - if 'seed_blending' in final_config and 'prompt_blending' not in final_config: - seed_blend_config = final_config['seed_blending'] + if "seed_blending" in final_config and "prompt_blending" not in final_config: + seed_blend_config = final_config["seed_blending"] wrapper.update_stream_params( - seed_list=seed_blend_config.get('seed_list', []), - interpolation_method=seed_blend_config.get('interpolation_method', 'linear') + seed_list=seed_blend_config.get("seed_list", []), + interpolation_method=seed_blend_config.get("interpolation_method", "linear"), ) return wrapper @@ -93,256 +98,264 @@ def create_wrapper_from_config(config: Dict[str, Any], **overrides) -> Any: def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: """Extract parameters for StreamDiffusionWrapper.__init__() from config""" - import torch param_map = { - 'model_id_or_path': config.get('model_id', 'stabilityai/sd-turbo'), - 't_index_list': config.get('t_index_list', [0, 16, 32, 45]), - 'lora_dict': config.get('lora_dict'), - 'mode': config.get('mode', 'img2img'), - 'output_type': config.get('output_type', 'pil'), - 'vae_id': config.get('vae_id'), - 'device': config.get('device', 'cuda'), - 'dtype': _parse_dtype(config.get('dtype', 'float16')), - 'frame_buffer_size': config.get('frame_buffer_size', 1), - 'width': config.get('width', 512), - 'height': config.get('height', 512), - 'warmup': config.get('warmup', 10), - 'acceleration': config.get('acceleration', 'tensorrt'), - 'do_add_noise': config.get('do_add_noise', True), - 'device_ids': config.get('device_ids'), - 'use_lcm_lora': config.get('use_lcm_lora'), # Backwards compatibility - 'use_tiny_vae': config.get('use_tiny_vae', True), - 'enable_similar_image_filter': config.get('enable_similar_image_filter', False), - 'similar_image_filter_threshold': config.get('similar_image_filter_threshold', 0.98), - 'similar_image_filter_max_skip_frame': config.get('similar_image_filter_max_skip_frame', 10), - 'similar_filter_sleep_fraction': config.get('similar_filter_sleep_fraction', 0.025), - 'use_denoising_batch': config.get('use_denoising_batch', True), - 'cfg_type': config.get('cfg_type', 'self'), - 'seed': config.get('seed', 2), - 'use_safety_checker': config.get('use_safety_checker', False), - 'skip_diffusion': config.get('skip_diffusion', False), - 'engine_dir': config.get('engine_dir', 'engines'), - 'normalize_prompt_weights': config.get('normalize_prompt_weights', True), - 'normalize_seed_weights': config.get('normalize_seed_weights', True), - 'scheduler': config.get('scheduler', 'lcm'), - 'sampler': config.get('sampler', 'normal'), - 'compile_engines_only': config.get('compile_engines_only', False), - 'static_shapes': config.get('static_shapes', False), - 'fp8': config.get('fp8', False), - 'builder_optimization_level': config.get('builder_optimization_level'), - 'build_engines_if_missing': config.get('build_engines_if_missing', True), - 'fp8_allow_fp16_fallback': config.get('fp8_allow_fp16_fallback', False), + "model_id_or_path": config.get("model_id", "stabilityai/sd-turbo"), + "t_index_list": config.get("t_index_list", [0, 16, 32, 45]), + "lora_dict": config.get("lora_dict"), + "mode": config.get("mode", "img2img"), + "output_type": config.get("output_type", "pil"), + "vae_id": config.get("vae_id"), + "device": config.get("device", "cuda"), + "dtype": _parse_dtype(config.get("dtype", "float16")), + "frame_buffer_size": config.get("frame_buffer_size", 1), + "width": config.get("width", 512), + "height": config.get("height", 512), + "warmup": config.get("warmup", 10), + "acceleration": config.get("acceleration", "tensorrt"), + "do_add_noise": config.get("do_add_noise", True), + "device_ids": config.get("device_ids"), + "use_lcm_lora": config.get("use_lcm_lora"), # Backwards compatibility + "use_tiny_vae": config.get("use_tiny_vae", True), + "enable_similar_image_filter": config.get("enable_similar_image_filter", False), + "similar_image_filter_threshold": config.get("similar_image_filter_threshold", 0.98), + "similar_image_filter_max_skip_frame": config.get("similar_image_filter_max_skip_frame", 10), + "similar_filter_sleep_fraction": config.get("similar_filter_sleep_fraction", 0.025), + "use_denoising_batch": config.get("use_denoising_batch", True), + "cfg_type": config.get("cfg_type", "self"), + "seed": config.get("seed", 2), + "use_safety_checker": config.get("use_safety_checker", False), + "skip_diffusion": config.get("skip_diffusion", False), + "engine_dir": config.get("engine_dir", "engines"), + "normalize_prompt_weights": config.get("normalize_prompt_weights", True), + "normalize_seed_weights": config.get("normalize_seed_weights", True), + "scheduler": config.get("scheduler", "lcm"), + "sampler": config.get("sampler", "normal"), + "compile_engines_only": config.get("compile_engines_only", False), + "static_shapes": config.get("static_shapes", False), + "fp8": config.get("fp8", False), + "builder_optimization_level": config.get("builder_optimization_level"), + "build_engines_if_missing": config.get("build_engines_if_missing", True), + "fp8_allow_fp16_fallback": config.get("fp8_allow_fp16_fallback", False), } - if 'controlnets' in config and config['controlnets']: - param_map['use_controlnet'] = True - param_map['controlnet_config'] = _prepare_controlnet_configs(config) + if "controlnets" in config and config["controlnets"]: + param_map["use_controlnet"] = True + param_map["controlnet_config"] = _prepare_controlnet_configs(config) else: - param_map['use_controlnet'] = config.get('use_controlnet', False) - param_map['controlnet_config'] = config.get('controlnet_config') - + param_map["use_controlnet"] = config.get("use_controlnet", False) + param_map["controlnet_config"] = config.get("controlnet_config") + # Set IPAdapter usage if IPAdapters are configured - if 'ipadapters' in config and config['ipadapters']: - param_map['use_ipadapter'] = True - param_map['ipadapter_config'] = _prepare_ipadapter_configs(config) + if "ipadapters" in config and config["ipadapters"]: + param_map["use_ipadapter"] = True + param_map["ipadapter_config"] = _prepare_ipadapter_configs(config) else: - param_map['use_ipadapter'] = config.get('use_ipadapter', False) - param_map['ipadapter_config'] = config.get('ipadapter_config') - - param_map['use_cached_attn'] = config.get('use_cached_attn', False) - - param_map['cache_maxframes'] = config.get('cache_maxframes', 1) - param_map['cache_interval'] = config.get('cache_interval', 1) - + param_map["use_ipadapter"] = config.get("use_ipadapter", False) + param_map["ipadapter_config"] = config.get("ipadapter_config") + + param_map["use_cached_attn"] = config.get("use_cached_attn", False) + + param_map["cache_maxframes"] = config.get("cache_maxframes", 1) + param_map["cache_interval"] = config.get("cache_interval", 1) + # Pipeline hook configurations (Phase 4: Configuration Integration) hook_configs = _prepare_pipeline_hook_configs(config) param_map.update(hook_configs) - + return {k: v for k, v in param_map.items() if v is not None} def _extract_prepare_params(config: Dict[str, Any]) -> Dict[str, Any]: """Extract parameters for wrapper.prepare() from config""" prepare_params = { - 'prompt': config.get('prompt', ''), - 'negative_prompt': config.get('negative_prompt', ''), - 'num_inference_steps': config.get('num_inference_steps', 50), - 'guidance_scale': config.get('guidance_scale', 1.2), - 'delta': config.get('delta', 1.0), + "prompt": config.get("prompt", ""), + "negative_prompt": config.get("negative_prompt", ""), + "num_inference_steps": config.get("num_inference_steps", 50), + "guidance_scale": config.get("guidance_scale", 1.2), + "delta": config.get("delta", 1.0), } - + # Handle prompt blending configuration - if 'prompt_blending' in config: - blend_config = config['prompt_blending'] - prepare_params['prompt_blending'] = { - 'prompt_list': blend_config.get('prompt_list', []), - 'interpolation_method': blend_config.get('interpolation_method', 'slerp'), - 'enable_caching': blend_config.get('enable_caching', True) + if "prompt_blending" in config: + blend_config = config["prompt_blending"] + prepare_params["prompt_blending"] = { + "prompt_list": blend_config.get("prompt_list", []), + "interpolation_method": blend_config.get("interpolation_method", "slerp"), + "enable_caching": blend_config.get("enable_caching", True), } - + # Handle seed blending configuration - if 'seed_blending' in config: - seed_blend_config = config['seed_blending'] - prepare_params['seed_blending'] = { - 'seed_list': seed_blend_config.get('seed_list', []), - 'interpolation_method': seed_blend_config.get('interpolation_method', 'linear'), - 'enable_caching': seed_blend_config.get('enable_caching', True) + if "seed_blending" in config: + seed_blend_config = config["seed_blending"] + prepare_params["seed_blending"] = { + "seed_list": seed_blend_config.get("seed_list", []), + "interpolation_method": seed_blend_config.get("interpolation_method", "linear"), + "enable_caching": seed_blend_config.get("enable_caching", True), } - + return prepare_params + def _prepare_controlnet_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: """Prepare ControlNet configurations for wrapper""" controlnet_configs = [] - pipeline_type = config.get('pipeline_type', 'sd1.5') - for cn_config in config['controlnets']: + pipeline_type = config.get("pipeline_type", "sd1.5") + for cn_config in config["controlnets"]: controlnet_config = { - 'model_id': cn_config['model_id'], - 'preprocessor': cn_config.get('preprocessor', 'passthrough'), - 'conditioning_scale': cn_config.get('conditioning_scale', 1.0), - 'enabled': cn_config.get('enabled', True), - 'preprocessor_params': cn_config.get('preprocessor_params'), - 'conditioning_channels': cn_config.get('conditioning_channels'), - 'pipeline_type': pipeline_type, - 'control_guidance_start': cn_config.get('control_guidance_start', 0.0), - 'control_guidance_end': cn_config.get('control_guidance_end', 1.0), + "model_id": cn_config["model_id"], + "preprocessor": cn_config.get("preprocessor", "passthrough"), + "conditioning_scale": cn_config.get("conditioning_scale", 1.0), + "enabled": cn_config.get("enabled", True), + "preprocessor_params": cn_config.get("preprocessor_params"), + "conditioning_channels": cn_config.get("conditioning_channels"), + "pipeline_type": pipeline_type, + "control_guidance_start": cn_config.get("control_guidance_start", 0.0), + "control_guidance_end": cn_config.get("control_guidance_end", 1.0), } controlnet_configs.append(controlnet_config) - + return controlnet_configs def _prepare_ipadapter_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: """Prepare IPAdapter configurations for wrapper""" ipadapter_configs = [] - - for ip_config in config['ipadapters']: + + for ip_config in config["ipadapters"]: ipadapter_config = { - 'ipadapter_model_path': ip_config['ipadapter_model_path'], - 'image_encoder_path': ip_config['image_encoder_path'], - 'style_image': ip_config.get('style_image'), - 'scale': ip_config.get('scale', 1.0), - 'enabled': ip_config.get('enabled', True), + "ipadapter_model_path": ip_config["ipadapter_model_path"], + "image_encoder_path": ip_config["image_encoder_path"], + "style_image": ip_config.get("style_image"), + "scale": ip_config.get("scale", 1.0), + "enabled": ip_config.get("enabled", True), # Preserve FaceID options from config for downstream wrapper/module handling - 'type': ip_config.get('type', 'regular'), - 'insightface_model_name': ip_config.get('insightface_model_name'), + "type": ip_config.get("type", "regular"), + "insightface_model_name": ip_config.get("insightface_model_name"), } ipadapter_configs.append(ipadapter_config) - + return ipadapter_configs def _prepare_pipeline_hook_configs(config: Dict[str, Any]) -> Dict[str, Any]: """Prepare pipeline hook configurations for wrapper following ControlNet/IPAdapter pattern""" hook_configs = {} - + # Image preprocessing hooks - if 'image_preprocessing' in config and config['image_preprocessing']: - if config['image_preprocessing'].get('enabled', True): - hook_configs['image_preprocessing_config'] = _prepare_single_hook_config( - config['image_preprocessing'], 'image_preprocessing' + if "image_preprocessing" in config and config["image_preprocessing"]: + if config["image_preprocessing"].get("enabled", True): + hook_configs["image_preprocessing_config"] = _prepare_single_hook_config( + config["image_preprocessing"], "image_preprocessing" ) - - # Image postprocessing hooks - if 'image_postprocessing' in config and config['image_postprocessing']: - if config['image_postprocessing'].get('enabled', True): - hook_configs['image_postprocessing_config'] = _prepare_single_hook_config( - config['image_postprocessing'], 'image_postprocessing' + + # Image postprocessing hooks + if "image_postprocessing" in config and config["image_postprocessing"]: + if config["image_postprocessing"].get("enabled", True): + hook_configs["image_postprocessing_config"] = _prepare_single_hook_config( + config["image_postprocessing"], "image_postprocessing" ) - + # Latent preprocessing hooks - if 'latent_preprocessing' in config and config['latent_preprocessing']: - if config['latent_preprocessing'].get('enabled', True): - hook_configs['latent_preprocessing_config'] = _prepare_single_hook_config( - config['latent_preprocessing'], 'latent_preprocessing' + if "latent_preprocessing" in config and config["latent_preprocessing"]: + if config["latent_preprocessing"].get("enabled", True): + hook_configs["latent_preprocessing_config"] = _prepare_single_hook_config( + config["latent_preprocessing"], "latent_preprocessing" ) - + # Latent postprocessing hooks - if 'latent_postprocessing' in config and config['latent_postprocessing']: - if config['latent_postprocessing'].get('enabled', True): - hook_configs['latent_postprocessing_config'] = _prepare_single_hook_config( - config['latent_postprocessing'], 'latent_postprocessing' + if "latent_postprocessing" in config and config["latent_postprocessing"]: + if config["latent_postprocessing"].get("enabled", True): + hook_configs["latent_postprocessing_config"] = _prepare_single_hook_config( + config["latent_postprocessing"], "latent_postprocessing" ) - + return hook_configs def _prepare_single_hook_config(hook_config: Dict[str, Any], hook_type: str) -> Dict[str, Any]: """Prepare configuration for a single hook type""" return { - 'enabled': hook_config.get('enabled', True), - 'processors': hook_config.get('processors', []), - 'hook_type': hook_type, + "enabled": hook_config.get("enabled", True), + "processors": hook_config.get("processors", []), + "hook_type": hook_type, } def _validate_pipeline_hook_configs(config: Dict[str, Any]) -> None: """Validate pipeline hook configurations following ControlNet/IPAdapter validation pattern""" - hook_types = ['image_preprocessing', 'image_postprocessing', 'latent_preprocessing', 'latent_postprocessing'] - + hook_types = ["image_preprocessing", "image_postprocessing", "latent_preprocessing", "latent_postprocessing"] + for hook_type in hook_types: if hook_type in config: hook_config = config[hook_type] if not isinstance(hook_config, dict): raise ValueError(f"_validate_config: '{hook_type}' must be a dictionary") - + # Validate enabled field - if 'enabled' in hook_config: - enabled = hook_config['enabled'] + if "enabled" in hook_config: + enabled = hook_config["enabled"] if not isinstance(enabled, bool): raise ValueError(f"_validate_config: '{hook_type}.enabled' must be a boolean") - + # Validate processors field - if 'processors' in hook_config: - processors = hook_config['processors'] + if "processors" in hook_config: + processors = hook_config["processors"] if not isinstance(processors, list): raise ValueError(f"_validate_config: '{hook_type}.processors' must be a list") - + for i, processor in enumerate(processors): if not isinstance(processor, dict): raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' must be a dictionary") - + # Validate processor type (required) - if 'type' not in processor: - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field") - - if not isinstance(processor['type'], str): + if "type" not in processor: + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}]' missing required 'type' field" + ) + + if not isinstance(processor["type"], str): raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].type' must be a string") - + # Validate enabled field (optional, defaults to True) - if 'enabled' in processor: - enabled = processor['enabled'] + if "enabled" in processor: + enabled = processor["enabled"] if not isinstance(enabled, bool): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean") - + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].enabled' must be a boolean" + ) + # Validate order field (optional) - if 'order' in processor: - order = processor['order'] + if "order" in processor: + order = processor["order"] if not isinstance(order, int): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer") - + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].order' must be an integer" + ) + # Validate params field (optional, coerce None to empty dict) - if 'params' in processor: - if processor['params'] is None: - processor['params'] = {} - elif not isinstance(processor['params'], dict): - raise ValueError(f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary") + if "params" in processor: + if processor["params"] is None: + processor["params"] = {} + elif not isinstance(processor["params"], dict): + raise ValueError( + f"_validate_config: '{hook_type}.processors[{i}].params' must be a dictionary" + ) def create_prompt_blending_config( base_config: Dict[str, Any], prompt_list: List[Tuple[str, float]], prompt_interpolation_method: str = "slerp", - enable_caching: bool = True + enable_caching: bool = True, ) -> Dict[str, Any]: """Create a configuration with prompt blending settings""" config = base_config.copy() - - config['prompt_blending'] = { - 'prompt_list': prompt_list, - 'interpolation_method': prompt_interpolation_method, - 'enable_caching': enable_caching + + config["prompt_blending"] = { + "prompt_list": prompt_list, + "interpolation_method": prompt_interpolation_method, + "enable_caching": enable_caching, } - + return config @@ -350,150 +363,152 @@ def create_seed_blending_config( base_config: Dict[str, Any], seed_list: List[Tuple[int, float]], interpolation_method: str = "linear", - enable_caching: bool = True + enable_caching: bool = True, ) -> Dict[str, Any]: """Create a configuration with seed blending settings""" config = base_config.copy() - - config['seed_blending'] = { - 'seed_list': seed_list, - 'interpolation_method': interpolation_method, - 'enable_caching': enable_caching + + config["seed_blending"] = { + "seed_list": seed_list, + "interpolation_method": interpolation_method, + "enable_caching": enable_caching, } - + return config def set_normalize_weights_config( - base_config: Dict[str, Any], - normalize_prompt_weights: bool = True, - normalize_seed_weights: bool = True + base_config: Dict[str, Any], normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True ) -> Dict[str, Any]: """Create a configuration with separate normalize weight settings""" config = base_config.copy() - - config['normalize_prompt_weights'] = normalize_prompt_weights - config['normalize_seed_weights'] = normalize_seed_weights - + + config["normalize_prompt_weights"] = normalize_prompt_weights + config["normalize_seed_weights"] = normalize_seed_weights + return config + def _parse_dtype(dtype_str: str) -> Any: """Parse dtype string to torch dtype""" import torch - + dtype_map = { - 'float16': torch.float16, - 'float32': torch.float32, - 'half': torch.float16, - 'float': torch.float32, + "float16": torch.float16, + "float32": torch.float32, + "half": torch.float16, + "float": torch.float32, } - + if isinstance(dtype_str, str): return dtype_map.get(dtype_str.lower(), torch.float16) return dtype_str # Assume it's already a torch dtype + + def _validate_config(config: Dict[str, Any]) -> None: """Basic validation of configuration dictionary""" if not isinstance(config, dict): raise ValueError("_validate_config: Configuration must be a dictionary") - - if 'model_id' not in config: + + if "model_id" not in config: raise ValueError("_validate_config: Missing required field: model_id") - - if 'controlnets' in config: - if not isinstance(config['controlnets'], list): + + if "controlnets" in config: + if not isinstance(config["controlnets"], list): raise ValueError("_validate_config: 'controlnets' must be a list") - - for i, controlnet in enumerate(config['controlnets']): + + for i, controlnet in enumerate(config["controlnets"]): if not isinstance(controlnet, dict): raise ValueError(f"_validate_config: ControlNet {i} must be a dictionary") - - if 'model_id' not in controlnet: + + if "model_id" not in controlnet: raise ValueError(f"_validate_config: ControlNet {i} missing required 'model_id'") - + # Validate conditioning_channels if present - if 'conditioning_channels' in controlnet: - channels = controlnet['conditioning_channels'] + if "conditioning_channels" in controlnet: + channels = controlnet["conditioning_channels"] if not isinstance(channels, int) or channels <= 0: - raise ValueError(f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}") - + raise ValueError( + f"_validate_config: ControlNet {i} 'conditioning_channels' must be a positive integer, got {channels}" + ) + # Validate ipadapters if present - if 'ipadapters' in config: - if not isinstance(config['ipadapters'], list): + if "ipadapters" in config: + if not isinstance(config["ipadapters"], list): raise ValueError("_validate_config: 'ipadapters' must be a list") - - for i, ipadapter in enumerate(config['ipadapters']): + + for i, ipadapter in enumerate(config["ipadapters"]): if not isinstance(ipadapter, dict): raise ValueError(f"_validate_config: IPAdapter {i} must be a dictionary") - - if 'ipadapter_model_path' not in ipadapter: + + if "ipadapter_model_path" not in ipadapter: raise ValueError(f"_validate_config: IPAdapter {i} missing required 'ipadapter_model_path'") - - if 'image_encoder_path' not in ipadapter: + + if "image_encoder_path" not in ipadapter: raise ValueError(f"_validate_config: IPAdapter {i} missing required 'image_encoder_path'") # Validate prompt blending configuration if present - if 'prompt_blending' in config: - blend_config = config['prompt_blending'] + if "prompt_blending" in config: + blend_config = config["prompt_blending"] if not isinstance(blend_config, dict): raise ValueError("_validate_config: 'prompt_blending' must be a dictionary") - - if 'prompt_list' in blend_config: - prompt_list = blend_config['prompt_list'] + + if "prompt_list" in blend_config: + prompt_list = blend_config["prompt_list"] if not isinstance(prompt_list, list): raise ValueError("_validate_config: 'prompt_list' must be a list") - + for i, prompt_item in enumerate(prompt_list): if not isinstance(prompt_item, (list, tuple)) or len(prompt_item) != 2: raise ValueError(f"_validate_config: Prompt item {i} must be [text, weight] pair") - + text, weight = prompt_item if not isinstance(text, str): raise ValueError(f"_validate_config: Prompt text {i} must be a string") - + if not isinstance(weight, (int, float)) or weight < 0: raise ValueError(f"_validate_config: Prompt weight {i} must be a non-negative number") - - interpolation_method = blend_config.get('interpolation_method', 'slerp') - if interpolation_method not in ['linear', 'slerp']: + + interpolation_method = blend_config.get("interpolation_method", "slerp") + if interpolation_method not in ["linear", "slerp"]: raise ValueError("_validate_config: interpolation_method must be 'linear' or 'slerp'") # Validate seed blending configuration if present - if 'seed_blending' in config: - seed_blend_config = config['seed_blending'] + if "seed_blending" in config: + seed_blend_config = config["seed_blending"] if not isinstance(seed_blend_config, dict): raise ValueError("_validate_config: 'seed_blending' must be a dictionary") - - if 'seed_list' in seed_blend_config: - seed_list = seed_blend_config['seed_list'] + + if "seed_list" in seed_blend_config: + seed_list = seed_blend_config["seed_list"] if not isinstance(seed_list, list): raise ValueError("_validate_config: 'seed_list' must be a list") - + for i, seed_item in enumerate(seed_list): if not isinstance(seed_item, (list, tuple)) or len(seed_item) != 2: raise ValueError(f"_validate_config: Seed item {i} must be [seed, weight] pair") - + seed_value, weight = seed_item if not isinstance(seed_value, int) or seed_value < 0: raise ValueError(f"_validate_config: Seed value {i} must be a non-negative integer") - + if not isinstance(weight, (int, float)) or weight < 0: raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number") - - interpolation_method = seed_blend_config.get('interpolation_method', 'linear') - if interpolation_method not in ['linear', 'slerp']: + + interpolation_method = seed_blend_config.get("interpolation_method", "linear") + if interpolation_method not in ["linear", "slerp"]: raise ValueError("_validate_config: seed blending interpolation_method must be 'linear' or 'slerp'") # Validate pipeline hook configurations if present (Phase 4: Configuration Integration) _validate_pipeline_hook_configs(config) # Validate separate normalize settings if present - if 'normalize_prompt_weights' in config: - normalize_prompt_weights = config['normalize_prompt_weights'] + if "normalize_prompt_weights" in config: + normalize_prompt_weights = config["normalize_prompt_weights"] if not isinstance(normalize_prompt_weights, bool): raise ValueError("_validate_config: 'normalize_prompt_weights' must be a boolean value") - - if 'normalize_seed_weights' in config: - normalize_seed_weights = config['normalize_seed_weights'] + + if "normalize_seed_weights" in config: + normalize_seed_weights = config["normalize_seed_weights"] if not isinstance(normalize_seed_weights, bool): raise ValueError("_validate_config: 'normalize_seed_weights' must be a boolean value") - diff --git a/src/streamdiffusion/hooks.py b/src/streamdiffusion/hooks.py index ec5db4155..02f10270d 100644 --- a/src/streamdiffusion/hooks.py +++ b/src/streamdiffusion/hooks.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional + import torch @@ -13,6 +14,7 @@ class EmbedsCtx: - prompt_embeds: [batch, seq_len, dim] - negative_prompt_embeds: optional [batch, seq_len, dim] """ + prompt_embeds: torch.Tensor negative_prompt_embeds: Optional[torch.Tensor] = None @@ -28,6 +30,7 @@ class StepCtx: - guidance_mode: one of {"none","full","self","initialize"} - sdxl_cond: optional dict with SDXL micro-cond tensors """ + x_t_latent: torch.Tensor t_list: torch.Tensor step_index: Optional[int] @@ -38,6 +41,7 @@ class StepCtx: @dataclass class UnetKwargsDelta: """Delta produced by UNet hooks to augment UNet call kwargs.""" + down_block_additional_residuals: Optional[List[torch.Tensor]] = None mid_block_additional_residual: Optional[torch.Tensor] = None added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None @@ -48,37 +52,37 @@ class UnetKwargsDelta: @dataclass class ImageCtx: """Context passed to image processing hooks. - + Fields: - image: [B, C, H, W] tensor in image space - width: image width - - height: image height + - height: image height - step_index: optional step index for multi-step processing """ + image: torch.Tensor width: int height: int step_index: Optional[int] = None -@dataclass +@dataclass class LatentCtx: """Context passed to latent processing hooks. - + Fields: - latent: [B, C, H/8, W/8] tensor in latent space - timestep: optional timestep tensor for diffusion context - step_index: optional step index for multi-step processing """ + latent: torch.Tensor timestep: Optional[torch.Tensor] = None step_index: Optional[int] = None - # Type aliases for clarity EmbeddingHook = Callable[[EmbedsCtx], EmbedsCtx] UnetHook = Callable[[StepCtx], UnetKwargsDelta] ImageHook = Callable[[ImageCtx], ImageCtx] LatentHook = Callable[[LatentCtx], LatentCtx] - diff --git a/src/streamdiffusion/image_filter.py b/src/streamdiffusion/image_filter.py index 5523c8869..e975567a0 100644 --- a/src/streamdiffusion/image_filter.py +++ b/src/streamdiffusion/image_filter.py @@ -1,5 +1,5 @@ -from typing import Optional import random +from typing import Optional import torch import torch.nn.functional as F diff --git a/src/streamdiffusion/image_utils.py b/src/streamdiffusion/image_utils.py index 200295b37..77d7275c2 100644 --- a/src/streamdiffusion/image_utils.py +++ b/src/streamdiffusion/image_utils.py @@ -30,9 +30,7 @@ def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: images = (images * 255).round().astype("uint8") if images.shape[-1] == 1: # special case for grayscale (single channel) images - pil_images = [ - PIL.Image.fromarray(image.squeeze(), mode="L") for image in images - ] + pil_images = [PIL.Image.fromarray(image.squeeze(), mode="L") for image in images] else: pil_images = [PIL.Image.fromarray(image) for image in images] @@ -56,12 +54,7 @@ def postprocess_image( if do_denormalize is None: do_denormalize = [do_normalize_flg] * image.shape[0] - image = torch.stack( - [ - denormalize(image[i]) if do_denormalize[i] else image[i] - for i in range(image.shape[0]) - ] - ) + image = torch.stack([denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]) if output_type == "pt": return image @@ -91,8 +84,6 @@ def pil2tensor(image_pil: PIL.Image.Image) -> torch.Tensor: img, _ = process_image(image_pil) imgs.append(img) imgs = torch.vstack(imgs) - images = torch.nn.functional.interpolate( - imgs, size=(height, width), mode="bilinear" - ) + images = torch.nn.functional.interpolate(imgs, size=(height, width), mode="bilinear") image_tensors = images.to(torch.float16) return image_tensors diff --git a/src/streamdiffusion/model_detection.py b/src/streamdiffusion/model_detection.py index e9eef252c..fbd28933e 100644 --- a/src/streamdiffusion/model_detection.py +++ b/src/streamdiffusion/model_detection.py @@ -1,13 +1,15 @@ """Comprehensive model detection for TensorRT and pipeline support""" -from typing import Dict, Tuple, Optional, Any, List +from typing import Any, Dict, Optional import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + # Gracefully import the SD3 model class; it might not exist in older diffusers versions. try: from diffusers.models.transformers.mm_dit import MMDiTTransformer2DModel + HAS_MMDIT = True except ImportError: # Create a dummy class if the import fails to prevent runtime errors. @@ -15,6 +17,8 @@ HAS_MMDIT = False import logging + + logger = logging.getLogger(__name__) @@ -23,7 +27,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str Comprehensive and robust model detection using definitive architectural features. This function replaces heuristic-based analysis with a deterministic, - rule-based approach by first inspecting the model's class and then its key + rule-based approach by first inspecting the model's class and then its key configuration parameters that define the architecture. Args: @@ -50,9 +54,9 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str confidence = 1.0 # Differentiating SD3 vs. SD3-Turbo from the MMDiT config alone is currently # speculative. A check on the pipeline's scheduler is a reasonable proxy. - if pipe and hasattr(pipe, 'scheduler'): - scheduler_name = getattr(pipe.scheduler.config, '_class_name', '').lower() - if 'lcm' in scheduler_name or 'turbo' in scheduler_name: + if pipe and hasattr(pipe, "scheduler"): + scheduler_name = getattr(pipe.scheduler.config, "_class_name", "").lower() + if "lcm" in scheduler_name or "turbo" in scheduler_name: is_turbo = True model_type = "SD3-Turbo" else: @@ -62,7 +66,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # 2. UNet-based Model Detection (SDXL, SD2.1, SD1.5) elif isinstance(model, UNet2DConditionModel): config = model.config - + # 2a. SDXL vs. non-SDXL # The `addition_embed_type` is the clearest indicator for the SDXL architecture. if config.get("addition_embed_type") is not None: @@ -73,7 +77,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # Base SDXL has `time_cond_proj_dim` (e.g., 256), while Turbo has it set to `None`. if config.get("time_cond_proj_dim") is None: is_turbo = True - + # 2b. SD2.1 vs. SD1.5 (if not SDXL) # Differentiate based on the text encoder's projection dimension. else: @@ -90,10 +94,10 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str confidence = 0.7 # 3. ControlNet Model Detection (detect underlying architecture) - elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'): + elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"): # ControlNet models have UNet-like configs, detect their base architecture config = model.config - + # Apply same detection logic as UNet models if config.get("addition_embed_type") is not None: model_type = "SDXL" @@ -107,12 +111,12 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str model_type = "SD2.1" confidence = 0.95 elif cross_attention_dim == 768: - model_type = "SD1.5" + model_type = "SD1.5" confidence = 0.95 else: model_type = "SD-finetune" confidence = 0.7 - + else: # The model is not a known UNet or MMDiT class. confidence = 0.0 @@ -120,44 +124,46 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str # Populate architecture and compatibility details (can be expanded as needed) architecture_details = { - 'model_class': model.__class__.__name__, - 'in_channels': getattr(model.config, 'in_channels', 'N/A'), - 'cross_attention_dim': getattr(model.config, 'cross_attention_dim', 'N/A'), - 'block_out_channels': getattr(model.config, 'block_out_channels', 'N/A'), + "model_class": model.__class__.__name__, + "in_channels": getattr(model.config, "in_channels", "N/A"), + "cross_attention_dim": getattr(model.config, "cross_attention_dim", "N/A"), + "block_out_channels": getattr(model.config, "block_out_channels", "N/A"), } - + # For UNet models, add detailed characteristics that SDXL code expects if isinstance(model, UNet2DConditionModel): unet_chars = detect_unet_characteristics(model) - architecture_details.update({ - 'has_time_conditioning': unet_chars['has_time_cond'], - 'has_addition_embeds': unet_chars['has_addition_embed'], - }) - + architecture_details.update( + { + "has_time_conditioning": unet_chars["has_time_cond"], + "has_addition_embeds": unet_chars["has_addition_embed"], + } + ) + # For ControlNet models, add similar characteristics - elif hasattr(model, 'config') and hasattr(model.config, 'cross_attention_dim'): + elif hasattr(model, "config") and hasattr(model.config, "cross_attention_dim"): # ControlNet models have similar config structure to UNet config = model.config has_addition_embed = config.get("addition_embed_type") is not None - has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None - - architecture_details.update({ - 'has_time_conditioning': has_time_cond, - 'has_addition_embeds': has_addition_embed, - }) - - compatibility_info = { - 'notes': f"Detected as {model_type} with {confidence:.2f} confidence based on architecture." - } + has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None + + architecture_details.update( + { + "has_time_conditioning": has_time_cond, + "has_addition_embeds": has_addition_embed, + } + ) + + compatibility_info = {"notes": f"Detected as {model_type} with {confidence:.2f} confidence based on architecture."} result = { - 'model_type': model_type, - 'is_turbo': is_turbo, - 'is_sdxl': is_sdxl, - 'is_sd3': is_sd3, - 'confidence': confidence, - 'architecture_details': architecture_details, - 'compatibility_info': compatibility_info, + "model_type": model_type, + "is_turbo": is_turbo, + "is_sdxl": is_sdxl, + "is_sd3": is_sd3, + "confidence": confidence, + "architecture_details": architecture_details, + "compatibility_info": compatibility_info, } return result @@ -166,13 +172,13 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]: """Detect detailed UNet characteristics including SDXL-specific features""" config = unet.config - + # Get cross attention dimensions to detect model type - cross_attention_dim = getattr(config, 'cross_attention_dim', None) - + cross_attention_dim = getattr(config, "cross_attention_dim", None) + # Detect SDXL by multiple indicators is_sdxl = False - + # Check cross attention dimension if isinstance(cross_attention_dim, (list, tuple)): # SDXL typically has [1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280, 1280] @@ -180,73 +186,74 @@ def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]: elif isinstance(cross_attention_dim, int): # Single value - SDXL uses 2048 for concatenated embeddings, or 1280+ for individual encoders is_sdxl = cross_attention_dim >= 1280 - + # Check addition_embed_type for SDXL detection (strong indicator) - addition_embed_type = getattr(config, 'addition_embed_type', None) + addition_embed_type = getattr(config, "addition_embed_type", None) has_addition_embed = addition_embed_type is not None - - if addition_embed_type in ['text_time', 'text_time_guidance']: + + if addition_embed_type in ["text_time", "text_time_guidance"]: is_sdxl = True # This is a definitive SDXL indicator - + # Check if model has time conditioning projection (SDXL feature) - has_time_cond = hasattr(config, 'time_cond_proj_dim') and config.time_cond_proj_dim is not None - + has_time_cond = hasattr(config, "time_cond_proj_dim") and config.time_cond_proj_dim is not None + # Additional SDXL detection checks - if hasattr(config, 'num_class_embeds') and config.num_class_embeds is not None: + if hasattr(config, "num_class_embeds") and config.num_class_embeds is not None: is_sdxl = True # SDXL often has class embeddings - + # Check sample size (SDXL typically uses 128 vs 64 for SD1.5) - sample_size = getattr(config, 'sample_size', 64) + sample_size = getattr(config, "sample_size", 64) if sample_size >= 128: is_sdxl = True - + return { - 'is_sdxl': is_sdxl, - 'has_time_cond': has_time_cond, - 'has_addition_embed': has_addition_embed, - 'cross_attention_dim': cross_attention_dim, - 'addition_embed_type': addition_embed_type, - 'in_channels': getattr(config, 'in_channels', 4), - 'sample_size': getattr(config, 'sample_size', 64 if not is_sdxl else 128), - 'block_out_channels': tuple(getattr(config, 'block_out_channels', [])), - 'attention_head_dim': getattr(config, 'attention_head_dim', None) + "is_sdxl": is_sdxl, + "has_time_cond": has_time_cond, + "has_addition_embed": has_addition_embed, + "cross_attention_dim": cross_attention_dim, + "addition_embed_type": addition_embed_type, + "in_channels": getattr(config, "in_channels", 4), + "sample_size": getattr(config, "sample_size", 64 if not is_sdxl else 128), + "block_out_channels": tuple(getattr(config, "block_out_channels", [])), + "attention_head_dim": getattr(config, "attention_head_dim", None), } + # This is used for controlnet/ipadapter model detection - can be deprecated (along with detect_unet_characteristics) def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str: """Detect model type from diffusers UNet configuration""" characteristics = detect_unet_characteristics(unet) - - in_channels = characteristics['in_channels'] - block_out_channels = characteristics['block_out_channels'] - cross_attention_dim = characteristics['cross_attention_dim'] - is_sdxl = characteristics['is_sdxl'] - + + in_channels = characteristics["in_channels"] + block_out_channels = characteristics["block_out_channels"] + cross_attention_dim = characteristics["cross_attention_dim"] + is_sdxl = characteristics["is_sdxl"] + # Use enhanced SDXL detection if is_sdxl: return "SDXL" - + # Original detection logic for other models - if (cross_attention_dim == 768 and - block_out_channels == (320, 640, 1280, 1280) and - in_channels == 4): + if cross_attention_dim == 768 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4: return "SD15" - - elif (cross_attention_dim == 1024 and - block_out_channels == (320, 640, 1280, 1280) and - in_channels == 4): + + elif cross_attention_dim == 1024 and block_out_channels == (320, 640, 1280, 1280) and in_channels == 4: return "SD21" - + elif cross_attention_dim == 768 and in_channels == 4: return "SD15" elif cross_attention_dim == 1024 and in_channels == 4: return "SD21" - + if cross_attention_dim == 768: - print(f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15") + print( + f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15" + ) return "SD15" elif cross_attention_dim == 1024: - print(f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21") + print( + f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21" + ) return "SD21" else: raise ValueError( @@ -260,58 +267,58 @@ def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str: def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]: """ Extract UNet architecture details needed for TensorRT engine building. - + This function provides the essential architecture information needed for TensorRT engine compilation in a clean, structured format. - + Args: unet: The UNet model to analyze - + Returns: Dict with architecture parameters for TensorRT engine building """ config = unet.config - + # Basic model parameters model_channels = config.block_out_channels[0] if config.block_out_channels else 320 block_out_channels = tuple(config.block_out_channels) channel_mult = tuple(ch // model_channels for ch in block_out_channels) - + # Resolution blocks - if hasattr(config, 'layers_per_block'): + if hasattr(config, "layers_per_block"): if isinstance(config.layers_per_block, (list, tuple)): num_res_blocks = tuple(config.layers_per_block) else: num_res_blocks = tuple([config.layers_per_block] * len(block_out_channels)) else: num_res_blocks = tuple([2] * len(block_out_channels)) - + # Attention and context dimensions context_dim = config.cross_attention_dim in_channels = config.in_channels - + # Attention head configuration - attention_head_dim = getattr(config, 'attention_head_dim', 8) + attention_head_dim = getattr(config, "attention_head_dim", 8) if isinstance(attention_head_dim, (list, tuple)): attention_head_dim = attention_head_dim[0] - + # Transformer depth - transformer_depth = getattr(config, 'transformer_layers_per_block', 1) + transformer_depth = getattr(config, "transformer_layers_per_block", 1) if isinstance(transformer_depth, (list, tuple)): transformer_depth = tuple(transformer_depth) else: transformer_depth = tuple([transformer_depth] * len(block_out_channels)) - + # Time embedding - time_embed_dim = getattr(config, 'time_embedding_dim', None) + time_embed_dim = getattr(config, "time_embedding_dim", None) if time_embed_dim is None: time_embed_dim = model_channels * 4 - + # Build architecture dictionary architecture_dict = { "model_channels": model_channels, "in_channels": in_channels, - "out_channels": getattr(config, 'out_channels', in_channels), + "out_channels": getattr(config, "out_channels", in_channels), "num_res_blocks": num_res_blocks, "channel_mult": channel_mult, "context_dim": context_dim, @@ -319,48 +326,50 @@ def extract_unet_architecture(unet: UNet2DConditionModel) -> Dict[str, Any]: "transformer_depth": transformer_depth, "time_embed_dim": time_embed_dim, "block_out_channels": block_out_channels, - # Additional configuration - "use_linear_in_transformer": getattr(config, 'use_linear_in_transformer', False), - "conv_in_kernel": getattr(config, 'conv_in_kernel', 3), - "conv_out_kernel": getattr(config, 'conv_out_kernel', 3), - "resnet_time_scale_shift": getattr(config, 'resnet_time_scale_shift', 'default'), - "class_embed_type": getattr(config, 'class_embed_type', None), - "num_class_embeds": getattr(config, 'num_class_embeds', None), - + "use_linear_in_transformer": getattr(config, "use_linear_in_transformer", False), + "conv_in_kernel": getattr(config, "conv_in_kernel", 3), + "conv_out_kernel": getattr(config, "conv_out_kernel", 3), + "resnet_time_scale_shift": getattr(config, "resnet_time_scale_shift", "default"), + "class_embed_type": getattr(config, "class_embed_type", None), + "num_class_embeds": getattr(config, "num_class_embeds", None), # Block types - "down_block_types": getattr(config, 'down_block_types', []), - "up_block_types": getattr(config, 'up_block_types', []), + "down_block_types": getattr(config, "down_block_types", []), + "up_block_types": getattr(config, "up_block_types", []), } - + return architecture_dict def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[str, Any]: """ Validate and fix architecture dictionary using model type presets. - + Ensures that all required architecture parameters are present and have reasonable values for the specified model type. - + Args: arch_dict: Architecture dictionary to validate model_type: Expected model type for validation - + Returns: Validated and corrected architecture dictionary """ - + # Check for required keys required_keys = [ - "model_channels", "channel_mult", "num_res_blocks", - "context_dim", "in_channels", "block_out_channels" + "model_channels", + "channel_mult", + "num_res_blocks", + "context_dim", + "in_channels", + "block_out_channels", ] - + for key in required_keys: if key not in arch_dict: raise ValueError(f"Missing required architecture parameter: {key}") - + # Ensure tuple format for sequence parameters for key in ["channel_mult", "num_res_blocks", "transformer_depth", "block_out_channels"]: if key in arch_dict and not isinstance(arch_dict[key], tuple): @@ -371,12 +380,11 @@ def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[st arch_dict[key] = tuple(arch_dict[key]) else: arch_dict[key] = preset[key] - + # Validate sequence lengths match expected_levels = len(arch_dict["channel_mult"]) for key in ["num_res_blocks", "transformer_depth"]: if key in arch_dict and len(arch_dict[key]) != expected_levels: arch_dict[key] = preset[key] - - return arch_dict + return arch_dict diff --git a/src/streamdiffusion/modules/__init__.py b/src/streamdiffusion/modules/__init__.py index 54954961a..f3242ca59 100644 --- a/src/streamdiffusion/modules/__init__.py +++ b/src/streamdiffusion/modules/__init__.py @@ -1,22 +1,21 @@ # StreamDiffusion Modules Package from .controlnet_module import ControlNetModule +from .image_processing_module import ImagePostprocessingModule, ImagePreprocessingModule, ImageProcessingModule from .ipadapter_module import IPAdapterModule -from .image_processing_module import ImageProcessingModule, ImagePreprocessingModule, ImagePostprocessingModule -from .latent_processing_module import LatentProcessingModule, LatentPreprocessingModule, LatentPostprocessingModule +from .latent_processing_module import LatentPostprocessingModule, LatentPreprocessingModule, LatentProcessingModule + __all__ = [ # Existing modules - 'ControlNetModule', - 'IPAdapterModule', - + "ControlNetModule", + "IPAdapterModule", # Pipeline processing base classes - 'ImageProcessingModule', - 'LatentProcessingModule', - + "ImageProcessingModule", + "LatentProcessingModule", # Pipeline processing timing-specific modules - 'ImagePreprocessingModule', - 'ImagePostprocessingModule', - 'LatentPreprocessingModule', - 'LatentPostprocessingModule', + "ImagePreprocessingModule", + "ImagePostprocessingModule", + "LatentPreprocessingModule", + "LatentPostprocessingModule", ] diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py index 6c8655b13..e0a57f3b6 100644 --- a/src/streamdiffusion/modules/controlnet_module.py +++ b/src/streamdiffusion/modules/controlnet_module.py @@ -1,18 +1,18 @@ from __future__ import annotations +import logging import threading from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import torch from diffusers.models import ControlNetModel -import logging -from streamdiffusion.hooks import StepCtx, UnetKwargsDelta, UnetHook +from streamdiffusion.hooks import StepCtx, UnetHook, UnetKwargsDelta +from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser from streamdiffusion.preprocessing.preprocessing_orchestrator import ( PreprocessingOrchestrator, ) -from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser @dataclass @@ -55,17 +55,17 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) -> self._prepared_dtype: Optional[torch.dtype] = None self._prepared_batch: Optional[int] = None self._images_version: int = 0 - + # Cache expensive lookups to avoid repeated hasattr/getattr calls self._engines_by_id: Dict[str, Any] = {} self._engines_cache_valid: bool = False self._is_sdxl: Optional[bool] = None self._expected_text_len: int = 77 - + # SDXL-specific caching for performance optimization self._sdxl_conditioning_cache: Optional[Dict[str, torch.Tensor]] = None self._sdxl_conditioning_valid: bool = False - + # Cache engine type detection to avoid repeated hasattr calls self._engine_type_cache: Dict[str, bool] = {} @@ -78,9 +78,9 @@ def install(self, stream) -> None: # Register UNet hook stream.unet_hooks.append(self.build_unet_hook()) # Expose controlnet collections so existing updater can find them - setattr(stream, 'controlnets', self.controlnets) - setattr(stream, 'controlnet_scales', self.controlnet_scales) - setattr(stream, 'preprocessors', self.preprocessors) + setattr(stream, "controlnets", self.controlnets) + setattr(stream, "controlnet_scales", self.controlnet_scales) + setattr(stream, "preprocessors", self.preprocessors) # Reset prepared tensors on install self._prepared_tensors = [] self._prepared_device = None @@ -92,18 +92,26 @@ def install(self, stream) -> None: self._sdxl_conditioning_valid = False self._engine_type_cache.clear() - def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None: + def add_controlnet( + self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None + ) -> None: model = self._load_pytorch_controlnet_model(cfg.model_id, cfg.conditioning_channels) preproc = None if cfg.preprocessor: from streamdiffusion.preprocessing.processors import get_preprocessor - preproc = get_preprocessor(cfg.preprocessor, pipeline_ref=self._stream, normalization_context='controlnet', params=cfg.preprocessor_params) + + preproc = get_preprocessor( + cfg.preprocessor, + pipeline_ref=self._stream, + normalization_context="controlnet", + params=cfg.preprocessor_params, + ) # Apply provided parameters to the preprocessor instance if cfg.preprocessor_params: params = cfg.preprocessor_params or {} # If the preprocessor exposes a 'params' dict, update it - if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): + if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), dict): preproc.params.update(params) # Also set attributes directly when they exist for name, value in params.items(): @@ -113,16 +121,15 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st except Exception: pass - # Align preprocessor target size with stream resolution once (avoid double-resize later) try: - if hasattr(preproc, 'params') and isinstance(getattr(preproc, 'params'), dict): - preproc.params['image_width'] = int(self._stream.width) - preproc.params['image_height'] = int(self._stream.height) - if hasattr(preproc, 'image_width'): - setattr(preproc, 'image_width', int(self._stream.width)) - if hasattr(preproc, 'image_height'): - setattr(preproc, 'image_height', int(self._stream.height)) + if hasattr(preproc, "params") and isinstance(getattr(preproc, "params"), dict): + preproc.params["image_width"] = int(self._stream.width) + preproc.params["image_height"] = int(self._stream.height) + if hasattr(preproc, "image_width"): + setattr(preproc, "image_width", int(self._stream.width)) + if hasattr(preproc, "image_height"): + setattr(preproc, "image_height", int(self._stream.height)) except Exception: pass @@ -142,7 +149,9 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st # Invalidate SDXL conditioning cache when ControlNet configuration changes self._sdxl_conditioning_valid = False - def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None: + def update_control_image_efficient( + self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None + ) -> None: if self._preprocessing_orchestrator is None: return with self._collections_lock: @@ -150,23 +159,15 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te return total = len(self.controlnets) # Build active scales, respecting enabled_list if present - scales = [ - (self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) - for i in range(total) - ] - if hasattr(self, 'enabled_list') and self.enabled_list and len(self.enabled_list) == total: + scales = [(self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0) for i in range(total)] + if hasattr(self, "enabled_list") and self.enabled_list and len(self.enabled_list) == total: scales = [sc if bool(self.enabled_list[i]) else 0.0 for i, sc in enumerate(scales)] preprocessors = [self.preprocessors[i] if i < len(self.preprocessors) else None for i in range(total)] # Single-index fast path if index is not None: results = self._preprocessing_orchestrator.process_sync( - control_image, - preprocessors, - scales, - self._stream.width, - self._stream.height, - index + control_image, preprocessors, scales, self._stream.width, self._stream.height, index ) processed = results[index] if results and len(results) > index else None with self._collections_lock: @@ -182,11 +183,7 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync) processed_images = self._preprocessing_orchestrator.process_pipelined( - control_image, - preprocessors, - scales, - self._stream.width, - self._stream.height + control_image, preprocessors, scales, self._stream.width, self._stream.height ) # If orchestrator returns empty list, it indicates no update needed for this frame @@ -243,7 +240,7 @@ def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None # Build current mapping from model_id to index current_ids: List[str] = [] for i, cn in enumerate(self.controlnets): - model_id = getattr(cn, 'model_id', f'controlnet_{i}') + model_id = getattr(cn, "model_id", f"controlnet_{i}") current_ids.append(model_id) # Compute new index order @@ -275,23 +272,29 @@ def get_current_config(self) -> List[Dict[str, Any]]: cfg: List[Dict[str, Any]] = [] with self._collections_lock: for i, cn in enumerate(self.controlnets): - model_id = getattr(cn, 'model_id', f'controlnet_{i}') + model_id = getattr(cn, "model_id", f"controlnet_{i}") scale = self.controlnet_scales[i] if i < len(self.controlnet_scales) else 1.0 - preproc_params = getattr(self.preprocessors[i], 'params', {}) if i < len(self.preprocessors) and self.preprocessors[i] else {} - cfg.append({ - 'model_id': model_id, - 'conditioning_scale': scale, - 'preprocessor_params': preproc_params, - 'enabled': (self.enabled_list[i] if i < len(self.enabled_list) else True), - }) + preproc_params = ( + getattr(self.preprocessors[i], "params", {}) + if i < len(self.preprocessors) and self.preprocessors[i] + else {} + ) + cfg.append( + { + "model_id": model_id, + "conditioning_scale": scale, + "preprocessor_params": preproc_params, + "enabled": (self.enabled_list[i] if i < len(self.enabled_list) else True), + } + ) return cfg def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None: """Prepare control image tensors for the current frame. - + This method is called once per frame to prepare all control images with the correct device, dtype, and batch size. This avoids redundant operations during each denoising step. - + Args: device: Target device for tensors dtype: Target dtype for tensors @@ -300,22 +303,22 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_ with self._collections_lock: # Check if we need to re-prepare tensors cache_valid = ( - self._prepared_device == device and - self._prepared_dtype == dtype and - self._prepared_batch == batch_size and - len(self._prepared_tensors) == len(self.controlnet_images) + self._prepared_device == device + and self._prepared_dtype == dtype + and self._prepared_batch == batch_size + and len(self._prepared_tensors) == len(self.controlnet_images) ) - + if cache_valid: return - + # Prepare tensors for current frame self._prepared_tensors = [] for img in self.controlnet_images: if img is None: self._prepared_tensors.append(None) continue - + # Prepare tensor with correct batch size prepared = img if prepared.dim() == 4 and prepared.shape[0] != batch_size: @@ -324,63 +327,62 @@ def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_ else: repeat_factor = max(1, batch_size // prepared.shape[0]) prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size] - + # Move to correct device and dtype prepared = prepared.to(device=device, dtype=dtype) self._prepared_tensors.append(prepared) - + # Update cache state self._prepared_device = device self._prepared_dtype = dtype self._prepared_batch = batch_size - def _get_cached_sdxl_conditioning(self, ctx: 'StepCtx') -> Optional[Dict[str, torch.Tensor]]: + def _get_cached_sdxl_conditioning(self, ctx: "StepCtx") -> Optional[Dict[str, torch.Tensor]]: """Get cached SDXL conditioning to avoid repeated preparation""" if not self._is_sdxl or ctx.sdxl_cond is None: return None - + # Check if cache is valid if self._sdxl_conditioning_valid and self._sdxl_conditioning_cache is not None: cached = self._sdxl_conditioning_cache # Verify batch size matches current context - if ('text_embeds' in cached and - cached['text_embeds'].shape[0] == ctx.x_t_latent.shape[0]): + if "text_embeds" in cached and cached["text_embeds"].shape[0] == ctx.x_t_latent.shape[0]: return cached - + # Cache miss or invalid - prepare new conditioning try: conditioning = {} - if 'text_embeds' in ctx.sdxl_cond: - text_embeds = ctx.sdxl_cond['text_embeds'] + if "text_embeds" in ctx.sdxl_cond: + text_embeds = ctx.sdxl_cond["text_embeds"] batch_size = ctx.x_t_latent.shape[0] - + # Optimize batch expansion for SDXL text embeddings if text_embeds.shape[0] != batch_size: if text_embeds.shape[0] == 1: - conditioning['text_embeds'] = text_embeds.repeat(batch_size, 1) + conditioning["text_embeds"] = text_embeds.repeat(batch_size, 1) else: - conditioning['text_embeds'] = text_embeds[:batch_size] + conditioning["text_embeds"] = text_embeds[:batch_size] else: - conditioning['text_embeds'] = text_embeds - - if 'time_ids' in ctx.sdxl_cond: - time_ids = ctx.sdxl_cond['time_ids'] + conditioning["text_embeds"] = text_embeds + + if "time_ids" in ctx.sdxl_cond: + time_ids = ctx.sdxl_cond["time_ids"] batch_size = ctx.x_t_latent.shape[0] - + # Optimize batch expansion for SDXL time IDs if time_ids.shape[0] != batch_size: if time_ids.shape[0] == 1: - conditioning['time_ids'] = time_ids.repeat(batch_size, 1) + conditioning["time_ids"] = time_ids.repeat(batch_size, 1) else: - conditioning['time_ids'] = time_ids[:batch_size] + conditioning["time_ids"] = time_ids[:batch_size] else: - conditioning['time_ids'] = time_ids - + conditioning["time_ids"] = time_ids + # Cache the prepared conditioning self._sdxl_conditioning_cache = conditioning self._sdxl_conditioning_valid = True return conditioning - + except Exception: # Fallback to original conditioning on any error return ctx.sdxl_cond @@ -399,8 +401,10 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Single pass to collect active ControlNet data active_data = [] enabled_flags = self.enabled_list if len(self.enabled_list) == len(self.controlnets) else None - - for i, (cn, img, scale) in enumerate(zip(self.controlnets, self.controlnet_images, self.controlnet_scales)): + + for i, (cn, img, scale) in enumerate( + zip(self.controlnets, self.controlnet_images, self.controlnet_scales) + ): if cn is not None and img is not None and scale > 0: enabled = enabled_flags[i] if enabled_flags else True if enabled: @@ -413,9 +417,11 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if not self._engines_cache_valid: self._engines_by_id.clear() try: - if hasattr(self._stream, 'controlnet_engines') and isinstance(self._stream.controlnet_engines, list): + if hasattr(self._stream, "controlnet_engines") and isinstance( + self._stream.controlnet_engines, list + ): for eng in self._stream.controlnet_engines: - mid = getattr(eng, 'model_id', None) + mid = getattr(eng, "model_id", None) if mid: self._engines_by_id[mid] = eng self._engines_cache_valid = True @@ -425,17 +431,17 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Cache SDXL detection to avoid repeated hasattr calls if self._is_sdxl is None: try: - self._is_sdxl = getattr(self._stream, 'is_sdxl', False) + self._is_sdxl = getattr(self._stream, "is_sdxl", False) except Exception: self._is_sdxl = False - encoder_hidden_states = self._stream.prompt_embeds[:, :self._expected_text_len, :] + encoder_hidden_states = self._stream.prompt_embeds[:, : self._expected_text_len, :] base_kwargs: Dict[str, Any] = { - 'sample': x_t, - 'timestep': t_list, - 'encoder_hidden_states': encoder_hidden_states, - 'return_dict': False, + "sample": x_t, + "timestep": t_list, + "encoder_hidden_states": encoder_hidden_states, + "return_dict": False, } down_samples_list: List[List[torch.Tensor]] = [] @@ -443,20 +449,22 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # Ensure tensors are prepared for this frame # This should have been called earlier, but we call it here as a safety net - if (self._prepared_device != x_t.device or - self._prepared_dtype != x_t.dtype or - self._prepared_batch != x_t.shape[0]): + if ( + self._prepared_device != x_t.device + or self._prepared_dtype != x_t.dtype + or self._prepared_batch != x_t.shape[0] + ): self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0]) - + # Use pre-prepared tensors prepared_images = self._prepared_tensors for cn, img, scale, idx_i in active_data: # Swap to TRT engine if available for this model_id (use cached lookup) - model_id = getattr(cn, 'model_id', None) + model_id = getattr(cn, "model_id", None) if model_id and model_id in self._engines_by_id: cn = self._engines_by_id[model_id] - + # Use pre-prepared tensor current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img if current_img is None: @@ -467,12 +475,12 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if cache_key in self._engine_type_cache: is_trt_engine = self._engine_type_cache[cache_key] else: - is_trt_engine = hasattr(cn, 'engine') and hasattr(cn, 'stream') + is_trt_engine = hasattr(cn, "engine") and hasattr(cn, "stream") self._engine_type_cache[cache_key] = is_trt_engine - + # Get optimized SDXL conditioning (uses caching to avoid repeated tensor operations) added_cond_kwargs = self._get_cached_sdxl_conditioning(ctx) - + try: if is_trt_engine: # TensorRT engine path @@ -483,7 +491,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, conditioning_scale=float(scale), - **added_cond_kwargs + **added_cond_kwargs, ) else: down_samples, mid_sample = cn( @@ -491,7 +499,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: timestep=t_list, encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, - conditioning_scale=float(scale) + conditioning_scale=float(scale), ) else: # PyTorch ControlNet path @@ -503,7 +511,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: controlnet_cond=current_img, conditioning_scale=float(scale), return_dict=False, - added_cond_kwargs=added_cond_kwargs + added_cond_kwargs=added_cond_kwargs, ) else: down_samples, mid_sample = cn( @@ -512,21 +520,30 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: encoder_hidden_states=encoder_hidden_states, controlnet_cond=current_img, conditioning_scale=float(scale), - return_dict=False + return_dict=False, ) except Exception as e: import traceback - __import__('logging').getLogger(__name__).error("ControlNetModule: controlnet forward failed: %s", e) + + __import__("logging").getLogger(__name__).error( + "ControlNetModule: controlnet forward failed: %s", e + ) try: - __import__('logging').getLogger(__name__).error("ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", - (tuple(encoder_hidden_states.shape) if isinstance(encoder_hidden_states, torch.Tensor) else None), - (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), - scale, - self._is_sdxl, - is_trt_engine) + __import__("logging").getLogger(__name__).error( + "ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", + ( + tuple(encoder_hidden_states.shape) + if isinstance(encoder_hidden_states, torch.Tensor) + else None + ), + (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), + scale, + self._is_sdxl, + is_trt_engine, + ) except Exception: pass - __import__('logging').getLogger(__name__).error(traceback.format_exc()) + __import__("logging").getLogger(__name__).error(traceback.format_exc()) continue down_samples_list.append(down_samples) mid_samples_list.append(mid_sample) @@ -555,48 +572,53 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: return _unet_hook - def _prepare_control_image(self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any]) -> torch.Tensor: + def _prepare_control_image( + self, control_image: Union[str, Any, torch.Tensor], preprocessor: Optional[Any] + ) -> torch.Tensor: if self._preprocessing_orchestrator is None: raise RuntimeError("ControlNetModule: preprocessing orchestrator is not initialized") # Reuse orchestrator API used by BaseControlNetPipeline images = self._preprocessing_orchestrator.process_sync( - control_image, - [preprocessor], - [1.0], - self._stream.width, - self._stream.height, - 0 + control_image, [preprocessor], [1.0], self._stream.width, self._stream.height, 0 ) # API returns a list; pick first if present return images[0] if images else None - #FIXME: more robust model management is needed in general. - def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: Optional[int] = None) -> ControlNetModel: - from pathlib import Path - import logging + # FIXME: more robust model management is needed in general. + def _load_pytorch_controlnet_model( + self, model_id: str, conditioning_channels: Optional[int] = None + ) -> ControlNetModel: import os + from pathlib import Path + logger = logging.getLogger(__name__) - + try: # Prepare loading kwargs load_kwargs = {"torch_dtype": self.dtype} if conditioning_channels is not None: load_kwargs["conditioning_channels"] = conditioning_channels - + # Check if offline mode is enabled via environment variables - is_offline = os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" - + is_offline = ( + os.environ.get("HF_HUB_OFFLINE", "0") == "1" or os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1" + ) + if Path(model_id).exists(): model_path = Path(model_id) - + # Check if it's a direct file path to a safetensors/ckpt file - if model_path.is_file() and model_path.suffix in ['.safetensors', '.ckpt', '.bin']: - logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})") + if model_path.is_file() and model_path.suffix in [".safetensors", ".ckpt", ".bin"]: + logger.info( + f"ControlNetModule._load_pytorch_controlnet_model: Loading ControlNet from single file: {model_path} (channels={conditioning_channels})" + ) # Try loading from single file (works for most ControlNet models) try: controlnet = ControlNetModel.from_single_file(str(model_path), **load_kwargs) except Exception as e: - logger.warning(f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}") + logger.warning( + f"ControlNetModule._load_pytorch_controlnet_model: Single file loading failed: {e}" + ) # Fallback: try pretrained loading in case it's in a proper directory structure load_kwargs["local_files_only"] = True controlnet = ControlNetModel.from_pretrained(str(model_path.parent), **load_kwargs) @@ -608,29 +630,27 @@ def _load_pytorch_controlnet_model(self, model_id: str, conditioning_channels: O # Loading from HuggingFace Hub - respect offline mode if is_offline: load_kwargs["local_files_only"] = True - logger.info(f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only") - + logger.info( + f"ControlNetModule._load_pytorch_controlnet_model: Offline mode enabled, loading '{model_id}' from cache only" + ) + if "/" in model_id and model_id.count("/") > 1: parts = model_id.split("/") repo_id = "/".join(parts[:2]) subfolder = "/".join(parts[2:]) - controlnet = ControlNetModel.from_pretrained( - repo_id, subfolder=subfolder, **load_kwargs - ) + controlnet = ControlNetModel.from_pretrained(repo_id, subfolder=subfolder, **load_kwargs) else: controlnet = ControlNetModel.from_pretrained(model_id, **load_kwargs) controlnet = controlnet.to(device=self.device, dtype=self.dtype) # Track model_id for updater diffing try: - setattr(controlnet, 'model_id', model_id) + setattr(controlnet, "model_id", model_id) except Exception: pass return controlnet except Exception as e: import traceback + logger.error(f"ControlNetModule: failed to load model '{model_id}': {e}") logger.error(traceback.format_exc()) raise - - - diff --git a/src/streamdiffusion/modules/image_processing_module.py b/src/streamdiffusion/modules/image_processing_module.py index ffea6e5fe..b96f0c0b6 100644 --- a/src/streamdiffusion/modules/image_processing_module.py +++ b/src/streamdiffusion/modules/image_processing_module.py @@ -1,55 +1,57 @@ -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List + import torch -from ..preprocessing.orchestrator_user import OrchestratorUser -from ..preprocessing.pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator from ..hooks import ImageCtx, ImageHook +from ..preprocessing.orchestrator_user import OrchestratorUser class ImageProcessingModule(OrchestratorUser): """ Shared base class for image domain processing modules. - + Handles sequential chain execution for both preprocessing and postprocessing timing variants. Processing domain is always image tensors. """ - + def __init__(self): """Initialize image processing module.""" self.processors = [] - + def _process_image_chain(self, input_image: torch.Tensor) -> torch.Tensor: """Execute sequential chain of processors in image domain. - + Uses the shared orchestrator's sequential chain processing. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() return self._preprocessing_orchestrator.execute_pipeline_chain( input_image, ordered_processors, processing_domain="image" ) - + def add_processor(self, proc_config: Dict[str, Any]) -> None: """Add a processor using the existing registry, following ControlNet pattern.""" from streamdiffusion.preprocessing.processors import get_preprocessor - - processor_type = proc_config.get('type') + + processor_type = proc_config.get("type") if not processor_type: raise ValueError("Processor config missing 'type' field") - + # Check if processor is enabled (default to True, same as ControlNet) - enabled = proc_config.get('enabled', True) - + enabled = proc_config.get("enabled", True) + # Create processor using existing registry (same as ControlNet) # ImageProcessingModule uses 'pipeline' normalization context - processor = get_preprocessor(processor_type, pipeline_ref=getattr(self, '_stream', None), normalization_context='pipeline') - + processor = get_preprocessor( + processor_type, pipeline_ref=getattr(self, "_stream", None), normalization_context="pipeline" + ) + # Apply parameters (same pattern as ControlNet) - processor_params = proc_config.get('params', {}) + processor_params = proc_config.get("params", {}) if processor_params: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): processor.params.update(processor_params) for name, value in processor_params.items(): try: @@ -57,109 +59,109 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None: setattr(processor, name, value) except Exception: pass - + # Set order for sequential execution - order = proc_config.get('order', len(self.processors)) - setattr(processor, 'order', order) - + order = proc_config.get("order", len(self.processors)) + setattr(processor, "order", order) + # Set enabled state - setattr(processor, 'enabled', enabled) - + setattr(processor, "enabled", enabled) + # Align preprocessor target size with stream resolution (same as ControlNet) - if hasattr(self, '_stream'): + if hasattr(self, "_stream"): try: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): - processor.params['image_width'] = int(self._stream.width) - processor.params['image_height'] = int(self._stream.height) - if hasattr(processor, 'image_width'): - setattr(processor, 'image_width', int(self._stream.width)) - if hasattr(processor, 'image_height'): - setattr(processor, 'image_height', int(self._stream.height)) + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): + processor.params["image_width"] = int(self._stream.width) + processor.params["image_height"] = int(self._stream.height) + if hasattr(processor, "image_width"): + setattr(processor, "image_width", int(self._stream.width)) + if hasattr(processor, "image_height"): + setattr(processor, "image_height", int(self._stream.height)) except Exception: pass - + self.processors.append(processor) - + def _get_ordered_processors(self) -> List[Any]: """Return enabled processors in execution order based on their order attribute.""" # Filter for enabled processors first, then sort by order - enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)] - return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0)) + enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)] + return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0)) class ImagePreprocessingModule(ImageProcessingModule): """ Image domain preprocessing module - executes before VAE encoding. - + Timing: After image_processor.preprocess(), before similar_image_filter Uses pipelined processing for performance optimization. """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrators.""" self._stream = stream # Store stream reference for dimension access self.attach_orchestrator(stream) # For sequential chain processing (fallback) self.attach_pipeline_preprocessing_orchestrator(stream) # For pipelined processing stream.image_preprocessing_hooks.append(self.build_image_hook()) - + def build_image_hook(self) -> ImageHook: """Build hook function that processes image context with pipelined processing.""" + def hook(ctx: ImageCtx) -> ImageCtx: ctx.image = self._process_image_pipelined(ctx.image) return ctx + return hook - + def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor: """Execute pipelined processing of preprocessors for performance. - + Uses PipelinePreprocessingOrchestrator for Frame N-1 results while starting Frame N processing. Falls back to synchronous processing when needed. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() - + # Use pipelined pipeline preprocessing orchestrator for performance - return self._pipeline_preprocessing_orchestrator.process_pipelined( - input_image, ordered_processors - ) + return self._pipeline_preprocessing_orchestrator.process_pipelined(input_image, ordered_processors) class ImagePostprocessingModule(ImageProcessingModule): """ Image domain postprocessing module - executes after VAE decoding. - + Timing: After decode_image(), before returning final output Uses pipelined processing for performance optimization. """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrators.""" self._stream = stream # Store stream reference for dimension access self.attach_preprocessing_orchestrator(stream) # For sequential chain processing (fallback) self.attach_postprocessing_orchestrator(stream) # For pipelined processing stream.image_postprocessing_hooks.append(self.build_image_hook()) - + def build_image_hook(self) -> ImageHook: """Build hook function that processes image context with pipelined processing.""" + def hook(ctx: ImageCtx) -> ImageCtx: ctx.image = self._process_image_pipelined(ctx.image) return ctx + return hook - + def _process_image_pipelined(self, input_image: torch.Tensor) -> torch.Tensor: """Execute pipelined processing of postprocessors for performance. - + Uses PostprocessingOrchestrator for Frame N-1 results while starting Frame N processing. Falls back to synchronous processing when needed. """ if not self.processors: return input_image - + ordered_processors = self._get_ordered_processors() - + # Use pipelined postprocessing orchestrator for performance - return self._postprocessing_orchestrator.process_pipelined( - input_image, ordered_processors - ) + return self._postprocessing_orchestrator.process_pipelined(input_image, ordered_processors) diff --git a/src/streamdiffusion/modules/ipadapter_module.py b/src/streamdiffusion/modules/ipadapter_module.py index b283799f3..4e3b95ead 100644 --- a/src/streamdiffusion/modules/ipadapter_module.py +++ b/src/streamdiffusion/modules/ipadapter_module.py @@ -1,16 +1,18 @@ from __future__ import annotations +import logging +import os from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Any from enum import Enum +from typing import Any, Dict, Optional, Tuple + import torch -from streamdiffusion.hooks import EmbedsCtx, EmbeddingHook, StepCtx, UnetKwargsDelta, UnetHook -import os +from streamdiffusion.hooks import EmbeddingHook, EmbedsCtx, StepCtx, UnetHook, UnetKwargsDelta from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser -import logging from streamdiffusion.utils.reporting import report_error + logger = logging.getLogger(__name__) @@ -27,6 +29,7 @@ class IPAdapterConfig: This module focuses only on embedding composition (step 2 of migration). Runtime installation and wrapper wiring will come in later steps. """ + style_image_key: Optional[str] = None num_image_tokens: int = 4 # e.g., 4 for standard, 16 for plus ipadapter_model_path: Optional[str] = None @@ -59,7 +62,7 @@ class IPAdapterConfig: "image_encoder_path": "h94/IP-Adapter/models/image_encoder", }, ("SD2.1", IPAdapterType.REGULAR): None, # not available from h94 (ip-adapter_sd21.bin was never released) - ("SD2.1", IPAdapterType.PLUS): None, # not available from h94 + ("SD2.1", IPAdapterType.PLUS): None, # not available from h94 ("SD2.1", IPAdapterType.FACEID): None, # not available from h94 ("SDXL", IPAdapterType.REGULAR): { "model_path": "h94/IP-Adapter/sdxl_models/ip-adapter_sdxl.bin", @@ -78,15 +81,15 @@ class IPAdapterConfig: # Set of all known HF model paths — used to distinguish known vs custom paths. # Custom/local paths are never overridden. _KNOWN_IPADAPTER_PATHS: frozenset = frozenset( - entry["model_path"] - for entry in IPADAPTER_MODEL_MAP.values() - if entry is not None + entry["model_path"] for entry in IPADAPTER_MODEL_MAP.values() if entry is not None ) -_KNOWN_ENCODER_PATHS: frozenset = frozenset({ - "h94/IP-Adapter/models/image_encoder", - "h94/IP-Adapter/sdxl_models/image_encoder", -}) +_KNOWN_ENCODER_PATHS: frozenset = frozenset( + { + "h94/IP-Adapter/models/image_encoder", + "h94/IP-Adapter/sdxl_models/image_encoder", + } +) def _normalize_model_type(detected_model_type: str, is_sdxl: bool) -> Optional[str]: @@ -183,10 +186,7 @@ def resolve_ipadapter_paths( # Resolve encoder path (only if it's a known HF encoder — custom encoders untouched) if current_encoder_path in _KNOWN_ENCODER_PATHS and current_encoder_path != correct_encoder_path: - logger.info( - f"IP-Adapter: resolving image encoder " - f"'{current_encoder_path}' → '{correct_encoder_path}'." - ) + logger.info(f"IP-Adapter: resolving image encoder '{current_encoder_path}' → '{correct_encoder_path}'.") cfg["image_encoder_path"] = correct_encoder_path return cfg @@ -209,7 +209,9 @@ def build_embedding_hook(self, stream) -> EmbeddingHook: def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: # Fetch cached image token embeddings (prompt, negative) - cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings(style_key) + cached: Optional[Tuple[torch.Tensor, torch.Tensor]] = stream._param_updater.get_cached_embeddings( + style_key + ) image_prompt_tokens: Optional[torch.Tensor] = None image_negative_tokens: Optional[torch.Tensor] = None if cached is not None: @@ -220,7 +222,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: batch_size = ctx.prompt_embeds.shape[0] if image_prompt_tokens is None: image_prompt_tokens = torch.zeros( - (batch_size, num_tokens, hidden_dim), dtype=ctx.prompt_embeds.dtype, device=ctx.prompt_embeds.device + (batch_size, num_tokens, hidden_dim), + dtype=ctx.prompt_embeds.dtype, + device=ctx.prompt_embeds.device, ) else: if image_prompt_tokens.shape[1] != num_tokens: @@ -242,7 +246,9 @@ def _embedding_hook(ctx: EmbedsCtx) -> EmbedsCtx: if neg_with_image is not None: if image_negative_tokens is None: image_negative_tokens = torch.zeros( - (neg_with_image.shape[0], num_tokens, hidden_dim), dtype=neg_with_image.dtype, device=neg_with_image.device + (neg_with_image.shape[0], num_tokens, hidden_dim), + dtype=neg_with_image.dtype, + device=neg_with_image.device, ) else: if image_negative_tokens.shape[0] != neg_with_image.shape[0]: @@ -291,14 +297,14 @@ def install(self, stream) -> None: # Create IP-Adapter and install processors into UNet (FaceID-aware) ip_kwargs = { - 'pipe': stream.pipe, - 'ipadapter_ckpt_path': resolved_ip_path, - 'image_encoder_path': resolved_encoder_path, - 'device': stream.device, - 'dtype': stream.dtype, + "pipe": stream.pipe, + "ipadapter_ckpt_path": resolved_ip_path, + "image_encoder_path": resolved_encoder_path, + "device": stream.device, + "dtype": stream.dtype, } if self.config.type == IPAdapterType.FACEID and self.config.insightface_model_name: - ip_kwargs['insightface_model_name'] = self.config.insightface_model_name + ip_kwargs["insightface_model_name"] = self.config.insightface_model_name print( f"IPAdapterModule.install: Initializing FaceID IP-Adapter with InsightFace model: {self.config.insightface_model_name}" ) @@ -311,6 +317,7 @@ def install(self, stream) -> None: # AttnProcessor2_0 which accepts kvo_cache and returns (hidden_states, kvo_cache). try: from diffusers.models.attention_processor import AttnProcessor2_0 as NativeAttnProcessor2_0 + attn_procs = stream.pipe.unet.attn_processors for name in attn_procs: if name.endswith("attn1.processor"): @@ -324,6 +331,7 @@ def install(self, stream) -> None: if self.config.type == IPAdapterType.FACEID: try: from streamdiffusion.preprocessing.processors.faceid_embedding import FaceIDEmbeddingPreprocessor + embedding_preprocessor = FaceIDEmbeddingPreprocessor( ipadapter=ipadapter, device=stream.device, @@ -357,11 +365,11 @@ def install(self, stream) -> None: # Expose IPAdapter instance as single source of truth try: - setattr(stream, 'ipadapter', ipadapter) + setattr(stream, "ipadapter", ipadapter) # Extend IPAdapter with our custom attributes since diffusers IPAdapter doesn't expose current state - setattr(ipadapter, 'weight_type', self.config.weight_type) # For build_layer_weights - setattr(ipadapter, 'scale', float(self.config.scale)) # Track current scale - setattr(ipadapter, 'enabled', bool(self.config.enabled)) # Track enabled state + setattr(ipadapter, "weight_type", self.config.weight_type) # For build_layer_weights + setattr(ipadapter, "scale", float(self.config.scale)) # Track current scale + setattr(ipadapter, "enabled", bool(self.config.enabled)) # Track enabled state except Exception: pass @@ -389,7 +397,10 @@ def _resolve_model_path(self, model_path: Optional[str]) -> str: from huggingface_hub import hf_hub_download, snapshot_download except Exception as e: import logging - logging.getLogger(__name__).error(f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}") + + logging.getLogger(__name__).error( + f"IPAdapterModule: huggingface_hub required to resolve '{model_path}': {e}" + ) raise parts = model_path.split("/") @@ -419,28 +430,28 @@ def build_unet_hook(self, stream) -> UnetHook: - For PyTorch UNet with installed IP processors, modulate per-layer processor scale by time factor """ _last_enabled_state = None # Track previous enabled state to avoid redundant updates - + def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # If no IP-Adapter installed, do nothing - if not hasattr(stream, 'ipadapter') or stream.ipadapter is None: + if not hasattr(stream, "ipadapter") or stream.ipadapter is None: return UnetKwargsDelta() # Check if IPAdapter is enabled - enabled = getattr(stream.ipadapter, 'enabled', True) + enabled = getattr(stream.ipadapter, "enabled", True) # Read base weight and weight type from IPAdapter instance try: - base_weight = float(getattr(stream.ipadapter, 'scale', 1.0)) if enabled else 0.0 + base_weight = float(getattr(stream.ipadapter, "scale", 1.0)) if enabled else 0.0 except Exception: base_weight = 0.0 if not enabled else 1.0 - weight_type = getattr(stream.ipadapter, 'weight_type', None) + weight_type = getattr(stream.ipadapter, "weight_type", None) # Determine total steps and current step index for time scheduling total_steps = None try: - if hasattr(stream, 'denoising_steps_num') and isinstance(stream.denoising_steps_num, int): + if hasattr(stream, "denoising_steps_num") and isinstance(stream.denoising_steps_num, int): total_steps = int(stream.denoising_steps_num) - elif hasattr(stream, 't_list') and stream.t_list is not None: + elif hasattr(stream, "t_list") and stream.t_list is not None: total_steps = len(stream.t_list) except Exception: total_steps = None @@ -449,6 +460,7 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if total_steps is not None and ctx.step_index is not None: try: from diffusers_ipadapter.ip_adapter.attention_processor import build_time_weight_factor + time_factor = float(build_time_weight_factor(weight_type, int(ctx.step_index), int(total_steps))) except Exception: # Do not add fallback mechanisms @@ -456,18 +468,20 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: # TensorRT engine path: supply ipadapter_scale vector via extra kwargs try: - is_trt_unet = hasattr(stream, 'unet') and hasattr(stream.unet, 'engine') and hasattr(stream.unet, 'stream') + is_trt_unet = ( + hasattr(stream, "unet") and hasattr(stream.unet, "engine") and hasattr(stream.unet, "stream") + ) except Exception: is_trt_unet = False - if is_trt_unet and getattr(stream.unet, 'use_ipadapter', False): + if is_trt_unet and getattr(stream.unet, "use_ipadapter", False): try: from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights except Exception: # If helper unavailable, do not construct weights here build_layer_weights = None # type: ignore - num_ip_layers = getattr(stream.unet, 'num_ip_layers', None) + num_ip_layers = getattr(stream.unet, "num_ip_layers", None) if isinstance(num_ip_layers, int) and num_ip_layers > 0: weights_tensor = None try: @@ -476,24 +490,26 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: except Exception: weights_tensor = None if weights_tensor is None: - weights_tensor = torch.full((num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device) + weights_tensor = torch.full( + (num_ip_layers,), float(base_weight), dtype=torch.float32, device=stream.device + ) # Apply per-step time factor try: weights_tensor = weights_tensor * float(time_factor) except Exception: pass - return UnetKwargsDelta(extra_unet_kwargs={'ipadapter_scale': weights_tensor}) + return UnetKwargsDelta(extra_unet_kwargs={"ipadapter_scale": weights_tensor}) # PyTorch UNet path: modulate installed processor scales by time factor and enabled state try: nonlocal _last_enabled_state # Only process if we need to make changes (time scaling or state transition) - needs_update = (time_factor != 1.0 or enabled != _last_enabled_state) - if needs_update and hasattr(stream.pipe, 'unet') and hasattr(stream.pipe.unet, 'attn_processors'): + needs_update = time_factor != 1.0 or enabled != _last_enabled_state + if needs_update and hasattr(stream.pipe, "unet") and hasattr(stream.pipe.unet, "attn_processors"): _last_enabled_state = enabled for proc in stream.pipe.unet.attn_processors.values(): - if hasattr(proc, 'scale') and hasattr(proc, '_ip_layer_index'): - base_val = getattr(proc, '_base_scale', proc.scale) + if hasattr(proc, "scale") and hasattr(proc, "_ip_layer_index"): + base_val = getattr(proc, "_base_scale", proc.scale) # Apply both enabled state and time factor final_scale = float(base_val) * float(time_factor) if enabled else 0.0 proc.scale = final_scale @@ -503,4 +519,3 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: return UnetKwargsDelta() return _unet_hook - diff --git a/src/streamdiffusion/modules/latent_processing_module.py b/src/streamdiffusion/modules/latent_processing_module.py index 256c66f0e..78edf1b37 100644 --- a/src/streamdiffusion/modules/latent_processing_module.py +++ b/src/streamdiffusion/modules/latent_processing_module.py @@ -1,54 +1,55 @@ -from typing import List, Optional, Any, Dict +from typing import Any, Dict, List + import torch -from ..preprocessing.orchestrator_user import OrchestratorUser from ..hooks import LatentCtx, LatentHook +from ..preprocessing.orchestrator_user import OrchestratorUser class LatentProcessingModule(OrchestratorUser): """ Shared base class for latent domain processing modules. - + Handles sequential chain execution for both preprocessing and postprocessing timing variants. Processing domain is always latent tensors. """ - + def __init__(self): """Initialize latent processing module.""" self.processors = [] - + def _process_latent_chain(self, input_latent: torch.Tensor) -> torch.Tensor: """Execute sequential chain of processors in latent domain. - + Uses the shared orchestrator's sequential chain processing. """ if not self.processors: return input_latent - + ordered_processors = self._get_ordered_processors() return self._preprocessing_orchestrator.execute_pipeline_chain( input_latent, ordered_processors, processing_domain="latent" ) - + def add_processor(self, proc_config: Dict[str, Any]) -> None: """Add a processor using the existing registry, following ControlNet pattern.""" from streamdiffusion.preprocessing.processors import get_preprocessor - - processor_type = proc_config.get('type') + + processor_type = proc_config.get("type") if not processor_type: raise ValueError("Processor config missing 'type' field") - + # Check if processor is enabled (default to True, same as ControlNet) - enabled = proc_config.get('enabled', True) - + enabled = proc_config.get("enabled", True) + # Create processor using existing registry (same as ControlNet) # LatentProcessingModule uses 'latent' normalization context (works in latent space) - processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context='latent') - + processor = get_preprocessor(processor_type, pipeline_ref=self._stream, normalization_context="latent") + # Apply parameters (same pattern as ControlNet) - processor_params = proc_config.get('params', {}) + processor_params = proc_config.get("params", {}) if processor_params: - if hasattr(processor, 'params') and isinstance(getattr(processor, 'params'), dict): + if hasattr(processor, "params") and isinstance(getattr(processor, "params"), dict): processor.params.update(processor_params) for name, value in processor_params.items(): try: @@ -56,62 +57,66 @@ def add_processor(self, proc_config: Dict[str, Any]) -> None: setattr(processor, name, value) except Exception: pass - + # Set order for sequential execution - order = proc_config.get('order', len(self.processors)) - setattr(processor, 'order', order) - + order = proc_config.get("order", len(self.processors)) + setattr(processor, "order", order) + # Set enabled state - setattr(processor, 'enabled', enabled) - + setattr(processor, "enabled", enabled) + # Pipeline reference is now automatically handled by the factory function - + self.processors.append(processor) - + def _get_ordered_processors(self) -> List[Any]: """Return enabled processors in execution order based on their order attribute.""" # Filter for enabled processors first, then sort by order - enabled_processors = [p for p in self.processors if getattr(p, 'enabled', True)] - return sorted(enabled_processors, key=lambda p: getattr(p, 'order', 0)) + enabled_processors = [p for p in self.processors if getattr(p, "enabled", True)] + return sorted(enabled_processors, key=lambda p: getattr(p, "order", 0)) class LatentPreprocessingModule(LatentProcessingModule): """ Latent domain preprocessing module - executes after VAE encoding, before diffusion. - + Timing: After encode_image(), before predict_x0_batch() """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrator.""" self.attach_orchestrator(stream) self._stream = stream # Store stream reference like ControlNet module does stream.latent_preprocessing_hooks.append(self.build_latent_hook()) - + def build_latent_hook(self) -> LatentHook: """Build hook function that processes latent context.""" + def hook(ctx: LatentCtx) -> LatentCtx: ctx.latent = self._process_latent_chain(ctx.latent) return ctx + return hook class LatentPostprocessingModule(LatentProcessingModule): """ Latent domain postprocessing module - executes after diffusion, before VAE decoding. - + Timing: After predict_x0_batch(), before decode_image() """ - + def install(self, stream) -> None: """Install module by registering hook with stream and attaching orchestrator.""" self.attach_orchestrator(stream) self._stream = stream # Store stream reference like ControlNet module does stream.latent_postprocessing_hooks.append(self.build_latent_hook()) - + def build_latent_hook(self) -> LatentHook: """Build hook function that processes latent context.""" + def hook(ctx: LatentCtx) -> LatentCtx: ctx.latent = self._process_latent_chain(ctx.latent) return ctx + return hook diff --git a/src/streamdiffusion/pip_utils.py b/src/streamdiffusion/pip_utils.py index 6ae3f11cd..4a28c0a0e 100644 --- a/src/streamdiffusion/pip_utils.py +++ b/src/streamdiffusion/pip_utils.py @@ -27,13 +27,16 @@ def _check_torch_installed(): raise RuntimeError(msg) if not torch.version.cuda: - raise RuntimeError("Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package.") + raise RuntimeError( + "Detected CPU-only PyTorch. Install CUDA-enabled torch/vision/audio before installing this package." + ) def get_cuda_version() -> str | None: _check_torch_installed() import torch + return torch.version.cuda @@ -66,7 +69,7 @@ def is_installed(package: str) -> bool: def run_python(command: str, env: Dict[str, str] | None = None) -> str: run_kwargs = { - "args": f"\"{python}\" {command}", + "args": f'"{python}" {command}', "shell": True, "env": os.environ if env is None else env, "encoding": "utf8", diff --git a/src/streamdiffusion/preprocessing/__init__.py b/src/streamdiffusion/preprocessing/__init__.py index 4228ee69b..c52a8a2ed 100644 --- a/src/streamdiffusion/preprocessing/__init__.py +++ b/src/streamdiffusion/preprocessing/__init__.py @@ -1,13 +1,14 @@ -from .preprocessing_orchestrator import PreprocessingOrchestrator -from .postprocessing_orchestrator import PostprocessingOrchestrator -from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator from .base_orchestrator import BaseOrchestrator from .orchestrator_user import OrchestratorUser +from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator +from .postprocessing_orchestrator import PostprocessingOrchestrator +from .preprocessing_orchestrator import PreprocessingOrchestrator + __all__ = [ "PreprocessingOrchestrator", "PostprocessingOrchestrator", "PipelinePreprocessingOrchestrator", "BaseOrchestrator", - "OrchestratorUser" + "OrchestratorUser", ] diff --git a/src/streamdiffusion/preprocessing/base_orchestrator.py b/src/streamdiffusion/preprocessing/base_orchestrator.py index d6d86bf2b..e5f6c1b22 100644 --- a/src/streamdiffusion/preprocessing/base_orchestrator.py +++ b/src/streamdiffusion/preprocessing/base_orchestrator.py @@ -1,144 +1,148 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple, Callable, TypeVar, Generic -from abc import ABC, abstractmethod -import numpy as np import concurrent.futures import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Optional, TypeVar + +import torch + logger = logging.getLogger(__name__) # Type variables for generic orchestrator -T = TypeVar('T') # Input type (e.g., ControlImage for preprocessing) -R = TypeVar('R') # Result type (e.g., List[torch.Tensor] for preprocessing) +T = TypeVar("T") # Input type (e.g., ControlImage for preprocessing) +R = TypeVar("R") # Result type (e.g., List[torch.Tensor] for preprocessing) class BaseOrchestrator(Generic[T, R], ABC): """ Generic base orchestrator for parallelized and pipelined processing. - + Handles thread pool management, pipeline state, and inter-frame pipelining while leaving domain-specific processing logic to subclasses. - + Type Parameters: T: Input type for processing operations R: Result type returned from processing operations """ - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, timeout_ms: float = 10.0, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + timeout_ms: float = 10.0, + pipeline_ref: Optional[Any] = None, + ): self.device = device self.dtype = dtype self.timeout_ms = timeout_ms self.pipeline_ref = pipeline_ref self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) - + # Pipeline state for pipelined processing self._next_frame_future = None self._next_frame_result = None - + # CUDA stream for background processing to avoid GPU contention self._background_stream = None device_str = str(device) if device_str.startswith("cuda") and torch.cuda.is_available(): self._background_stream = torch.cuda.Stream() - - def cleanup(self) -> None: """Cleanup thread pool and CUDA stream resources""" - if hasattr(self, '_executor'): + if hasattr(self, "_executor"): self._executor.shutdown(wait=True) - + # Cleanup CUDA stream if it exists - if hasattr(self, '_background_stream') and self._background_stream is not None: + if hasattr(self, "_background_stream") and self._background_stream is not None: # Synchronize the stream before cleanup torch.cuda.synchronize() self._background_stream = None - + def __del__(self): """Cleanup on destruction""" try: self.cleanup() except: pass - + @abstractmethod def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Determine if synchronous processing should be used instead of pipelined. - + Subclasses implement domain-specific logic (e.g., feedback preprocessor detection). - + Returns: True if sync processing should be used, False for pipelined processing """ pass - + @abstractmethod def _process_frame_background(self, *args, **kwargs) -> Dict[str, Any]: """ Process a frame in the background thread. - + Subclasses implement their specific processing logic here. - + Returns: Dictionary containing processing results and status """ pass - + def process_pipelined(self, input_data: T, *args, **kwargs) -> R: """ Process input with intelligent pipelining. - + Automatically falls back to sync processing when required by domain logic, otherwise uses pipelined processing for performance. - + Args: input_data: Input data to process *args, **kwargs: Additional arguments passed to processing methods - + Returns: Processing results """ # Check if sync processing is required (domain-specific logic) if self._should_use_sync_processing(*args, **kwargs): return self.process_sync(input_data, *args, **kwargs) - + # Use pipelined processing # Wait for previous frame processing; non-blocking with short timeout self._wait_for_previous_processing() - + # Start next frame processing in background self._start_next_frame_processing(input_data, *args, **kwargs) - + # Apply current frame processing results if available; otherwise signal no update return self._apply_current_frame_processing(*args, **kwargs) - + @abstractmethod def process_sync(self, input_data: T, *args, **kwargs) -> R: """ Process input synchronously. - + Subclasses implement their specific synchronous processing logic. - + Args: input_data: Input data to process *args, **kwargs: Additional arguments passed to processing methods - + Returns: Processing results """ pass - + def _start_next_frame_processing(self, input_data: T, *args, **kwargs) -> None: """Start processing for next frame in background thread""" # Submit background processing - self._next_frame_future = self._executor.submit( - self._process_frame_background, input_data, *args, **kwargs - ) - + self._next_frame_future = self._executor.submit(self._process_frame_background, input_data, *args, **kwargs) + def _wait_for_previous_processing(self) -> None: """Wait for previous frame processing with configurable timeout""" - if hasattr(self, '_next_frame_future') and self._next_frame_future is not None: + if hasattr(self, "_next_frame_future") and self._next_frame_future is not None: try: # Use configurable timeout based on orchestrator type self._next_frame_result = self._next_frame_future.result(timeout=self.timeout_ms / 1000.0) @@ -150,52 +154,52 @@ def _wait_for_previous_processing(self) -> None: self._next_frame_result = None else: self._next_frame_result = None - + def _apply_current_frame_processing(self, processors=None, *args, **kwargs) -> R: """ Apply processing results from previous iteration. - + Default implementation provides common fallback logic for tensor-to-tensor orchestrators. Subclasses can override this method for specialized behavior. - + Args: processors: List of processors/postprocessors to apply (parameter name varies by subclass) *args, **kwargs: Additional arguments - + Returns: Processing results, or processed current input if no results available """ - if not hasattr(self, '_next_frame_result') or self._next_frame_result is None: + if not hasattr(self, "_next_frame_result") or self._next_frame_result is None: # First frame or no background results - process current input synchronously - if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None: + if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None: if processors: return self.process_sync(self._current_input_tensor, processors) else: return self._current_input_tensor - + # If we don't have current input stored, we have an issue class_name = self.__class__.__name__ logger.error(f"{class_name}: No background results and no current input tensor available") raise RuntimeError(f"{class_name}: No processing results available") - + result = self._next_frame_result - if result['status'] != 'success': + if result["status"] != "success": class_name = self.__class__.__name__ logger.warning(f"{class_name}: Background processing failed: {result.get('error', 'Unknown error')}") # Process current input synchronously on error - if hasattr(self, '_current_input_tensor') and self._current_input_tensor is not None: + if hasattr(self, "_current_input_tensor") and self._current_input_tensor is not None: if processors: return self.process_sync(self._current_input_tensor, processors) else: return self._current_input_tensor raise RuntimeError(f"{class_name}: Background processing failed and no fallback available") - - return result['result'] - + + return result["result"] + def _set_background_stream_context(self): """ Set CUDA stream context for background processing. - + Returns: The original stream to restore later, or None if no background stream """ @@ -204,11 +208,11 @@ def _set_background_stream_context(self): torch.cuda.set_stream(self._background_stream) return original_stream return None - + def _restore_stream_context(self, original_stream): """ Restore the original CUDA stream context. - + Args: original_stream: The stream to restore, or None to do nothing """ diff --git a/src/streamdiffusion/preprocessing/orchestrator_user.py b/src/streamdiffusion/preprocessing/orchestrator_user.py index 2503c14e3..e731540b9 100644 --- a/src/streamdiffusion/preprocessing/orchestrator_user.py +++ b/src/streamdiffusion/preprocessing/orchestrator_user.py @@ -2,9 +2,9 @@ from typing import Optional -from .preprocessing_orchestrator import PreprocessingOrchestrator -from .postprocessing_orchestrator import PostprocessingOrchestrator from .pipeline_preprocessing_orchestrator import PipelinePreprocessingOrchestrator +from .postprocessing_orchestrator import PostprocessingOrchestrator +from .preprocessing_orchestrator import PreprocessingOrchestrator class OrchestratorUser: @@ -20,32 +20,36 @@ class OrchestratorUser: def attach_orchestrator(self, stream) -> None: """Attach preprocessing orchestrator (backward compatibility).""" self.attach_preprocessing_orchestrator(stream) - + def attach_preprocessing_orchestrator(self, stream) -> None: """Attach shared preprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'preprocessing_orchestrator', None) + orchestrator = getattr(stream, "preprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'preprocessing_orchestrator', orchestrator) + orchestrator = PreprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream + ) + setattr(stream, "preprocessing_orchestrator", orchestrator) self._preprocessing_orchestrator = orchestrator - + def attach_postprocessing_orchestrator(self, stream) -> None: """Attach shared postprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'postprocessing_orchestrator', None) + orchestrator = getattr(stream, "postprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PostprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'postprocessing_orchestrator', orchestrator) + orchestrator = PostprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream + ) + setattr(stream, "postprocessing_orchestrator", orchestrator) self._postprocessing_orchestrator = orchestrator - + def attach_pipeline_preprocessing_orchestrator(self, stream) -> None: """Attach shared pipeline preprocessing orchestrator from stream.""" - orchestrator = getattr(stream, 'pipeline_preprocessing_orchestrator', None) + orchestrator = getattr(stream, "pipeline_preprocessing_orchestrator", None) if orchestrator is None: # Lazy-create on stream once, on first user that needs it - orchestrator = PipelinePreprocessingOrchestrator(device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream) - setattr(stream, 'pipeline_preprocessing_orchestrator', orchestrator) + orchestrator = PipelinePreprocessingOrchestrator( + device=stream.device, dtype=stream.dtype, max_workers=4, pipeline_ref=stream + ) + setattr(stream, "pipeline_preprocessing_orchestrator", orchestrator) self._pipeline_preprocessing_orchestrator = orchestrator - - diff --git a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py index 8cf4e7171..382e874fb 100644 --- a/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/pipeline_preprocessing_orchestrator.py @@ -1,32 +1,42 @@ -import torch -from typing import List, Dict, Any, Optional import logging +from typing import Any, Dict, List, Optional + +import torch + from .base_orchestrator import BaseOrchestrator + logger = logging.getLogger(__name__) + class PipelinePreprocessingOrchestrator(BaseOrchestrator[torch.Tensor, torch.Tensor]): """ Orchestrates pipeline input preprocessing with parallelization and pipelining. - + Handles preprocessing of input tensors before they enter the diffusion pipeline. - + Tensor ranges: - Input: Receives [-1, 1] tensors from image_processor.preprocess() - Processors: Work in [-1, 1] space when normalization_context='pipeline' - Output: Returns [-1, 1] tensors for pipeline processing - + Note: Processors created with normalization_context='pipeline' expect and preserve [-1, 1] range. No automatic conversion happens in this orchestrator. """ - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + pipeline_ref: Optional[Any] = None, + ): # Pipeline preprocessing: 10ms timeout for responsive processing super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref) - + # Pipeline preprocessing specific state self._current_input_tensor = None # For BaseOrchestrator fallback logic - + def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Determine if synchronous processing should be used instead of pipelined. @@ -44,123 +54,102 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool: if not processors: return False for proc in processors: - if proc is not None and getattr(proc, 'requires_sync_processing', False): + if proc is not None and getattr(proc, "requires_sync_processing", False): return True return False - - def process_pipelined(self, - input_tensor: torch.Tensor, - processors: List[Any], - *args, **kwargs) -> torch.Tensor: + + def process_pipelined(self, input_tensor: torch.Tensor, processors: List[Any], *args, **kwargs) -> torch.Tensor: """ Process input with intelligent pipelining. - + Overrides base method to store current input tensor for fallback logic. """ # Store current input for fallback logic self._current_input_tensor = input_tensor - + # RACE CONDITION FIX: Check if there are actually enabled processors # Filter to only enabled processors (same logic as _get_ordered_processors) - enabled_processors = [p for p in processors if getattr(p, 'enabled', True)] if processors else [] - + enabled_processors = [p for p in processors if getattr(p, "enabled", True)] if processors else [] + if not enabled_processors: return input_tensor - + # Call parent implementation return super().process_pipelined(input_tensor, processors, *args, **kwargs) - - def process_sync(self, - input_tensor: torch.Tensor, - processors: List[Any]) -> torch.Tensor: + + def process_sync(self, input_tensor: torch.Tensor, processors: List[Any]) -> torch.Tensor: """ Process pipeline input tensor synchronously through preprocessors. - + Implementation of BaseOrchestrator.process_sync for pipeline preprocessing. - + Args: input_tensor: Input tensor to preprocess (already normalized) processors: List of preprocessor instances - + Returns: Preprocessed tensor ready for pipeline processing """ if not processors: return input_tensor - + # Sequential application of processors current_tensor = input_tensor for processor in processors: if processor is not None: current_tensor = self._apply_single_processor(current_tensor, processor) - + return current_tensor - - def _process_frame_background(self, - input_tensor: torch.Tensor, - processors: List[Any]) -> Dict[str, Any]: + + def _process_frame_background(self, input_tensor: torch.Tensor, processors: List[Any]) -> Dict[str, Any]: """ Process a frame in the background thread. - + Implementation of BaseOrchestrator._process_frame_background for pipeline preprocessing. - + Returns: Dictionary containing processing results and status """ try: # Set CUDA stream for background processing original_stream = self._set_background_stream_context() - + if not processors: - return { - 'result': input_tensor, - 'status': 'success' - } - + return {"result": input_tensor, "status": "success"} + # Process processors sequentially (most pipeline preprocessing is dependent) current_tensor = input_tensor for processor in processors: if processor is not None: current_tensor = self._apply_single_processor(current_tensor, processor) - - return { - 'result': current_tensor, - 'status': 'success' - } - + + return {"result": current_tensor, "status": "success"} + except Exception as e: logger.error(f"PipelinePreprocessingOrchestrator: Background processing failed: {e}") # Return original input tensor on error - return { - 'result': input_tensor, - 'error': str(e), - 'status': 'error' - } + return {"result": input_tensor, "error": str(e), "status": "error"} finally: # Restore original CUDA stream self._restore_stream_context(original_stream) - - - - def _apply_single_processor(self, - input_tensor: torch.Tensor, - processor: Any) -> torch.Tensor: + + def _apply_single_processor(self, input_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """ Apply a single processor to the input tensor. - + Args: input_tensor: Input tensor to process processor: Processor instance - + Returns: Processed tensor """ try: # Apply processor - if hasattr(processor, 'process_tensor'): + if hasattr(processor, "process_tensor"): # Prefer tensor processing method result = processor.process_tensor(input_tensor) - elif hasattr(processor, 'process'): + elif hasattr(processor, "process"): # Use general process method result = processor.process(input_tensor) elif callable(processor): @@ -169,18 +158,18 @@ def _apply_single_processor(self, else: logger.warning(f"PipelinePreprocessingOrchestrator: Unknown processor type: {type(processor)}") return input_tensor - + # Ensure result is a tensor if isinstance(result, torch.Tensor): return result else: logger.warning(f"PipelinePreprocessingOrchestrator: Processor returned non-tensor: {type(result)}") return input_tensor - + except Exception as e: logger.error(f"PipelinePreprocessingOrchestrator: Processor failed: {e}") return input_tensor # Return original on error - + def clear_cache(self) -> None: """Clear preprocessing cache""" pass diff --git a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py index 742a80e60..ef5dceb32 100644 --- a/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/postprocessing_orchestrator.py @@ -73,7 +73,7 @@ def _should_use_sync_processing(self, *args, **kwargs) -> bool: if not processors: return False for proc in processors: - if proc is not None and getattr(proc, 'requires_sync_processing', False): + if proc is not None and getattr(proc, "requires_sync_processing", False): return True return False diff --git a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py index cf05fbfaa..fd554ac6f 100644 --- a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py @@ -1,12 +1,15 @@ -import torch -from typing import List, Optional, Union, Dict, Any, Tuple -from PIL import Image -import numpy as np import logging -from diffusers.utils import load_image +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch import torchvision.transforms as transforms +from diffusers.utils import load_image +from PIL import Image + from .base_orchestrator import BaseOrchestrator + logger = logging.getLogger(__name__) # Type alias for control image input @@ -16,40 +19,45 @@ class PreprocessingOrchestrator(BaseOrchestrator[ControlImage, List[Optional[torch.Tensor]]]): """ Orchestrates module preprocessing with typical orchestrator pipelining, but with additional intraframe parallelization, caching, and optimization. - Modules (IPAdapter, Controlnet) share intraframe parallelism. + Modules (IPAdapter, Controlnet) share intraframe parallelism. Handles image format conversion (while most are GPU native,some preprocessors are CPU only), preprocessor execution, and result caching. """ - - def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16, max_workers: int = 4, pipeline_ref: Optional[Any] = None): + + def __init__( + self, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + max_workers: int = 4, + pipeline_ref: Optional[Any] = None, + ): # Preprocessing: 10ms timeout for fast frame-skipping behavior super().__init__(device, dtype, max_workers, timeout_ms=10.0, pipeline_ref=pipeline_ref) - + # Caching self._preprocessed_cache: Dict[str, torch.Tensor] = {} self._last_input_frame = None - + # Optimized transforms self._cached_transform = transforms.ToTensor() - + # Cache pipelining decision to avoid hot path checks self._preprocessors_cache_key = None self._has_feedback_cache = False - - - - - #Abstract method implementations - def process_sync(self, - control_image: ControlImage, - preprocessors: List[Optional[Any]], - scales: List[float] = None, - stream_width: int = None, - stream_height: int = None, - index: Optional[int] = None, - processing_type: str = "controlnet") -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]: + + # Abstract method implementations + def process_sync( + self, + control_image: ControlImage, + preprocessors: List[Optional[Any]], + scales: List[float] = None, + stream_width: int = None, + stream_height: int = None, + index: Optional[int] = None, + processing_type: str = "controlnet", + ) -> Union[List[Optional[torch.Tensor]], List[Tuple[torch.Tensor, torch.Tensor]]]: """ Process images synchronously for ControlNet or IPAdapter preprocessing. - + Args: control_image: Input image to process preprocessors: List of preprocessor instances @@ -58,7 +66,7 @@ def process_sync(self, stream_height: Target height for processing index: If specified, only process this single ControlNet index (ControlNet only) processing_type: "controlnet" or "ipadapter" to specify processing mode - + Returns: ControlNet: List of processed tensors for each ControlNet IPAdapter: List of (positive_embeds, negative_embeds) tuples @@ -79,410 +87,390 @@ def process_sync(self, ) else: raise ValueError(f"Invalid processing_type: {processing_type}. Must be 'controlnet' or 'ipadapter'") - + def _should_use_sync_processing(self, *args, **kwargs) -> bool: """ Check for pipeline-aware preprocessors that require sync processing. - - Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing + + Pipeline-aware preprocessors (feedback, temporal, etc.) need synchronous processing to avoid temporal artifacts and ensure access to previous pipeline outputs. - + Args: *args: Arguments from process_pipelined call (preprocessors, scales, stream_width, stream_height) **kwargs: Keyword arguments - + Returns: True if pipeline-aware preprocessors detected, False otherwise """ # Extract preprocessors from args - they're the first argument after control_image if len(args) < 1: return False - + preprocessors = args[0] # preprocessors is first arg after control_image return self._check_pipeline_aware_cached(preprocessors) - def _process_frame_background(self, - control_image: ControlImage, - *args, **kwargs) -> Dict[str, Any]: + def _process_frame_background(self, control_image: ControlImage, *args, **kwargs) -> Dict[str, Any]: """ Process a frame in the background thread. - + Implementation of BaseOrchestrator._process_frame_background for ControlNet preprocessing. Automatically detects processing mode based on current state. - + Returns: Dictionary containing processing results and status """ try: # Set CUDA stream for background processing original_stream = self._set_background_stream_context() - + # Check if last argument is "ipadapter" processing type if args and len(args) >= 5 and args[4] == "ipadapter": # Handle embedding preprocessing embedding_preprocessors = args[0] - stream_width = args[2] + stream_width = args[2] stream_height = args[3] - + # Prepare processing data control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using existing IPAdapter logic try: results = self._process_ipadapter_preprocessors_parallel( embedding_preprocessors, control_variants, stream_width, stream_height ) - return { - 'results': results, - 'status': 'success' - } + return {"results": results, "status": "success"} except Exception as e: import traceback + traceback.print_exc() - return { - 'error': str(e), - 'status': 'error' - } - elif hasattr(self, '_current_processing_mode') and self._current_processing_mode == "embedding": + return {"error": str(e), "status": "error"} + elif hasattr(self, "_current_processing_mode") and self._current_processing_mode == "embedding": # Handle embedding preprocessing (legacy path) embedding_preprocessors = args[0] - stream_width = args[2] + stream_width = args[2] stream_height = args[3] - + # Prepare processing data control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using existing IPAdapter logic try: results = self._process_ipadapter_preprocessors_parallel( embedding_preprocessors, control_variants, stream_width, stream_height ) - return { - 'results': results, - 'status': 'success' - } + return {"results": results, "status": "success"} except Exception as e: import traceback + traceback.print_exc() - return { - 'error': str(e), - 'status': 'error' - } + return {"error": str(e), "status": "error"} else: # Handle ControlNet preprocessing (default mode) preprocessors = args[0] scales = args[1] stream_width = args[2] stream_height = args[3] - + # Check if any processing is needed if not any(scale > 0 for scale in scales): - return {'status': 'success', 'results': [None] * len(preprocessors)} - #TODO: can we reuse similarity filter here? - if (self._last_input_frame is not None and - isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and - control_image is self._last_input_frame): - return {'status': 'success', 'results': []} # Signal no update needed - + return {"status": "success", "results": [None] * len(preprocessors)} + # TODO: can we reuse similarity filter here? + if ( + self._last_input_frame is not None + and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) + and control_image is self._last_input_frame + ): + return {"status": "success", "results": []} # Signal no update needed + self._last_input_frame = control_image - + # Prepare processing data preprocessor_groups = self._group_preprocessors(preprocessors, scales) active_indices = [i for i, scale in enumerate(scales) if scale > 0] - + if not active_indices: - return {'status': 'success', 'results': [None] * len(preprocessors)} - + return {"status": "success", "results": [None] * len(preprocessors)} + # Optimize input preparation control_variants = self._prepare_input_variants(control_image, thread_safe=True) - + # Process using unified parallel logic processed_images = self._process_controlnet_preprocessors_parallel( preprocessor_groups, control_variants, stream_width, stream_height, preprocessors ) - - return { - 'results': processed_images, - 'status': 'success' - } - + + return {"results": processed_images, "status": "success"} + except Exception as e: logger.error(f"PreprocessingOrchestrator: Background processing failed: {e}") - return { - 'error': str(e), - 'status': 'error' - } + return {"error": str(e), "status": "error"} finally: # Restore original CUDA stream self._restore_stream_context(original_stream) - - def _apply_current_frame_processing(self, - preprocessors: List[Optional[Any]] = None, - scales: List[float] = None, - *args, **kwargs) -> List[Optional[torch.Tensor]]: + + def _apply_current_frame_processing( + self, preprocessors: List[Optional[Any]] = None, scales: List[float] = None, *args, **kwargs + ) -> List[Optional[torch.Tensor]]: """ Apply processing results from previous iteration. - + Overrides BaseOrchestrator._apply_current_frame_processing for module preprocessing. - + Returns: List of processed tensors, or empty list to signal no update needed """ - if not hasattr(self, '_next_frame_result') or self._next_frame_result is None: + if not hasattr(self, "_next_frame_result") or self._next_frame_result is None: # Return empty list to signal no update needed return [] - + # Handle case where preprocessors is None if preprocessors is None: return [] - + processed_images = [None] * len(preprocessors) - + result = self._next_frame_result - if result['status'] != 'success': + if result["status"] != "success": # Return empty list to signal no update needed on error return [] - + # Handle case where no update is needed (cached input) - if 'results' in result and len(result['results']) == 0: + if "results" in result and len(result["results"]) == 0: return [] - + # Get the processed results directly - processed_images = result.get('results', []) + processed_images = result.get("results", []) if not processed_images: return [] - + return processed_images - - #Controlnet methods - def prepare_control_image(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessor: Optional[Any], - target_width: int, - target_height: int) -> torch.Tensor: + + # Controlnet methods + def prepare_control_image( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessor: Optional[Any], + target_width: int, + target_height: int, + ) -> torch.Tensor: """ Prepare a single control image for ControlNet input with format conversion and preprocessing. - + Args: control_image: Input image in various formats preprocessor: Optional preprocessor to apply target_width: Target width for the output tensor target_height: Target height for the output tensor - + Returns: Processed tensor ready for ControlNet """ # Load image if path if isinstance(control_image, str): control_image = load_image(control_image) - + # Fast tensor processing path if isinstance(control_image, torch.Tensor): return self._process_tensor_input(control_image, preprocessor, target_width, target_height) - + # Apply preprocessor to non-tensor inputs if preprocessor is not None: control_image = preprocessor.process(control_image) - + # Convert to tensor return self._convert_to_tensor(control_image, target_width, target_height) - - def _process_multiple_controlnets_sync(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessors: List[Optional[Any]], - scales: List[float], - stream_width: int, - stream_height: int) -> List[Optional[torch.Tensor]]: + + def _process_multiple_controlnets_sync( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessors: List[Optional[Any]], + scales: List[float], + stream_width: int, + stream_height: int, + ) -> List[Optional[torch.Tensor]]: """Process multiple ControlNets synchronously with parallel execution""" # Check if any processing is needed if not any(scale > 0 for scale in scales): return [None] * len(preprocessors) - - #TODO: can we reuse similarity filter here? + + # TODO: can we reuse similarity filter here? # Check cache for same input - return early without changing anything - if (self._last_input_frame is not None and - isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) and - control_image is self._last_input_frame): + if ( + self._last_input_frame is not None + and isinstance(control_image, (torch.Tensor, np.ndarray, Image.Image)) + and control_image is self._last_input_frame + ): # Return empty list to signal no update needed return [] - + self._last_input_frame = control_image self.clear_cache() - + # Prepare input variants for optimal processing control_variants = self._prepare_input_variants(control_image, stream_width, stream_height) - + # Group preprocessors to avoid duplicate work preprocessor_groups = self._group_preprocessors(preprocessors, scales) - + if not preprocessor_groups: return [None] * len(preprocessors) - + # Process groups using parallel logic (efficient for 1 or many items) return self._process_controlnet_preprocessors_parallel( preprocessor_groups, control_variants, stream_width, stream_height, preprocessors ) - - def _process_single_controlnet(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - preprocessors: List[Optional[Any]], - scales: List[float], - stream_width: int, - stream_height: int, - index: int) -> List[Optional[torch.Tensor]]: + + def _process_single_controlnet( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + preprocessors: List[Optional[Any]], + scales: List[float], + stream_width: int, + stream_height: int, + index: int, + ) -> List[Optional[torch.Tensor]]: """Process a single ControlNet by index""" if not (0 <= index < len(preprocessors)): raise IndexError(f"ControlNet index {index} out of range") - + if scales[index] == 0: return [None] * len(preprocessors) - + processed_images = [None] * len(preprocessors) - processed_image = self.prepare_control_image( - control_image, preprocessors[index], stream_width, stream_height - ) + processed_image = self.prepare_control_image(control_image, preprocessors[index], stream_width, stream_height) processed_images[index] = processed_image - + return processed_images - - def _process_controlnet_preprocessors_parallel(self, - preprocessor_groups: Dict[str, Dict[str, Any]], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int, - preprocessors: List[Optional[Any]]) -> List[Optional[torch.Tensor]]: + + def _process_controlnet_preprocessors_parallel( + self, + preprocessor_groups: Dict[str, Dict[str, Any]], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + preprocessors: List[Optional[Any]], + ) -> List[Optional[torch.Tensor]]: """Process ControlNet preprocessor groups in parallel""" futures = [ self._executor.submit( - self._process_single_preprocessor_group, - prep_key, group, control_variants, stream_width, stream_height + self._process_single_preprocessor_group, prep_key, group, control_variants, stream_width, stream_height ) for prep_key, group in preprocessor_groups.items() ] - + processed_images = [None] * len(preprocessors) - + for future in futures: result = future.result() - if result and result['processed_image'] is not None: - prep_key = result['prep_key'] - processed_image = result['processed_image'] - indices = result['indices'] - + if result and result["processed_image"] is not None: + prep_key = result["prep_key"] + processed_image = result["processed_image"] + indices = result["indices"] + # Cache and assign cache_key = f"prep_{prep_key}" self._preprocessed_cache[cache_key] = processed_image for index in indices: processed_images[index] = processed_image - + return processed_images - - #IPAdapter methods - def _process_multiple_ipadapters_sync(self, - control_image: ControlImage, - preprocessors: List[Optional[Any]], - stream_width: int, - stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: + + # IPAdapter methods + def _process_multiple_ipadapters_sync( + self, control_image: ControlImage, preprocessors: List[Optional[Any]], stream_width: int, stream_height: int + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """ Process IPAdapter preprocessors synchronously. - + This is the implementation that was previously in process_ipadapter_preprocessors(). """ if not preprocessors: return [] - + # For IPAdapter preprocessing, we don't skip on cache hits - we need the actual embeddings # (Unlike spatial preprocessing where empty list means "no update needed") - + # Prepare input variants for processing control_variants = self._prepare_input_variants(control_image, stream_width, stream_height) - + # Process using parallel logic (efficient for 1 or many items) results = self._process_ipadapter_preprocessors_parallel( preprocessors, control_variants, stream_width, stream_height ) - + return results - - def _process_ipadapter_preprocessors_parallel(self, - ipadapter_preprocessors: List[Any], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: + + def _process_ipadapter_preprocessors_parallel( + self, + ipadapter_preprocessors: List[Any], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Process multiple IPAdapter preprocessors in parallel""" futures = [ self._executor.submit( - self._process_single_ipadapter, - i, preprocessor, control_variants, stream_width, stream_height + self._process_single_ipadapter, i, preprocessor, control_variants, stream_width, stream_height ) for i, preprocessor in enumerate(ipadapter_preprocessors) ] - + results = [None] * len(ipadapter_preprocessors) - + for future in futures: result = future.result() - if result and result['embeddings'] is not None: - index = result['index'] - embeddings = result['embeddings'] + if result and result["embeddings"] is not None: + index = result["index"] + embeddings = result["embeddings"] results[index] = embeddings - + return results - - def _process_single_ipadapter(self, - index: int, - preprocessor: Any, - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> Optional[Dict[str, Any]]: + + def _process_single_ipadapter( + self, index: int, preprocessor: Any, control_variants: Dict[str, Any], stream_width: int, stream_height: int + ) -> Optional[Dict[str, Any]]: """Process a single IPAdapter preprocessor""" try: # Use tensor processing if available and input is tensor - if (hasattr(preprocessor, 'process_tensor') and - control_variants['tensor'] is not None): - embeddings = preprocessor.process_tensor(control_variants['tensor']) - return { - 'index': index, - 'embeddings': embeddings - } - + if hasattr(preprocessor, "process_tensor") and control_variants["tensor"] is not None: + embeddings = preprocessor.process_tensor(control_variants["tensor"]) + return {"index": index, "embeddings": embeddings} + # Use PIL processing for non-tensor inputs - if control_variants['image'] is not None: - embeddings = preprocessor.process(control_variants['image']) - return { - 'index': index, - 'embeddings': embeddings - } - + if control_variants["image"] is not None: + embeddings = preprocessor.process(control_variants["image"]) + return {"index": index, "embeddings": embeddings} + return None - - except Exception as e: + + except Exception: import traceback + traceback.print_exc() return None - #Helper methods + # Helper methods def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bool: """ Efficiently check for pipeline-aware preprocessors using caching - + Only performs expensive isinstance checks when preprocessor list actually changes. """ # Create cache key from preprocessor identities cache_key = tuple(id(p) for p in preprocessors) - + # Return cached result if preprocessors haven't changed if cache_key == self._preprocessors_cache_key: return self._has_feedback_cache # Reuse cache variable for backward compatibility - + # Preprocessors changed - recompute and cache self._preprocessors_cache_key = cache_key self._has_feedback_cache = False - + try: # Check for the mixin or class attribute first for prep in preprocessors: - if prep is not None and getattr(prep, 'requires_sync_processing', False): + if prep is not None and getattr(prep, "requires_sync_processing", False): self._has_feedback_cache = True break except Exception: @@ -490,6 +478,7 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo try: from .processors.feedback import FeedbackPreprocessor from .processors.temporal_net import TemporalNetPreprocessor + for prep in preprocessors: if isinstance(prep, (FeedbackPreprocessor, TemporalNetPreprocessor)): self._has_feedback_cache = True @@ -499,44 +488,43 @@ def _check_pipeline_aware_cached(self, preprocessors: List[Optional[Any]]) -> bo for prep in preprocessors: if prep is not None: class_name = prep.__class__.__name__.lower() - if any(name in class_name for name in ['feedback', 'temporal']): + if any(name in class_name for name in ["feedback", "temporal"]): self._has_feedback_cache = True break - + return self._has_feedback_cache def clear_cache(self) -> None: """Clear preprocessing cache""" self._preprocessed_cache.clear() self._last_input_frame = None - + # ========================================================================= # Pipeline Chain Processing Methods (For Hook System Compatibility) # ========================================================================= - - def execute_pipeline_chain(self, - input_data: torch.Tensor, - processors: List[Any], - processing_domain: str = "image") -> torch.Tensor: + + def execute_pipeline_chain( + self, input_data: torch.Tensor, processors: List[Any], processing_domain: str = "image" + ) -> torch.Tensor: """Execute ordered sequential chain of processors for pipeline hooks. - + This method provides compatibility with the hook system modules that expect sequential processor execution rather than pipelined processing. - + Args: input_data: Input tensor (image or latent domain) processors: List of processor instances to execute in sequence processing_domain: "image" or "latent" to determine processing path - + Returns: Processed tensor in same domain as input """ if not processors: return input_data - + result = input_data ordered_processors = self._order_processors(processors) - + for processor in ordered_processors: try: if processing_domain == "image": @@ -549,58 +537,58 @@ def execute_pipeline_chain(self, logger.error(f"execute_pipeline_chain: Processor {type(processor).__name__} failed: {e}") # Continue with next processor rather than failing entire chain continue - + return result - + def _order_processors(self, processors: List[Any]) -> List[Any]: """Order processors based on their configuration. - + Processors can define an 'order' attribute to control execution sequence. """ - return sorted(processors, key=lambda p: getattr(p, 'order', 0)) - + return sorted(processors, key=lambda p: getattr(p, "order", 0)) + def _process_image_processor_chain(self, image_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """Process single image processor in chain, handling tensor<->PIL conversion. - + Leverages existing format conversion and processing logic. """ # Convert tensor to PIL for processor (reuse existing conversion logic) try: # Use existing tensor to PIL conversion from prepare_control_image logic pil_image = self._tensor_to_pil_safe(image_tensor) - + # Process using existing processor execution pattern - if hasattr(processor, 'process'): + if hasattr(processor, "process"): processed_pil = processor.process(pil_image) else: processed_pil = processor(pil_image) - + # Convert back to tensor (reuse existing PIL to tensor logic) result_tensor = self._pil_to_tensor_safe(processed_pil, image_tensor.device, image_tensor.dtype) return result_tensor - + except Exception as e: logger.error(f"_process_image_processor_chain: Failed processing {type(processor).__name__}: {e}") return image_tensor # Return input unchanged on failure - + def _process_latent_processor_chain(self, latent_tensor: torch.Tensor, processor: Any) -> torch.Tensor: """Process single latent processor in chain. - + Direct tensor processing - no format conversion needed for latent domain. """ try: # Latent processors work directly on tensors - if hasattr(processor, 'process_tensor'): + if hasattr(processor, "process_tensor"): return processor.process_tensor(latent_tensor) - elif hasattr(processor, 'process'): + elif hasattr(processor, "process"): return processor.process(latent_tensor) else: return processor(latent_tensor) - + except Exception as e: logger.error(f"_process_latent_processor_chain: Failed processing {type(processor).__name__}: {e}") return latent_tensor # Return input unchanged on failure - + def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image: """Convert tensor to PIL Image safely (reuse existing conversion logic).""" # Leverage existing tensor conversion from prepare_control_image @@ -609,50 +597,48 @@ def _tensor_to_pil_safe(self, tensor: torch.Tensor) -> Image.Image: if tensor.dim() == 3 and tensor.shape[0] == 3: # Convert from CHW to HWC tensor = tensor.permute(1, 2, 0) - + # CRITICAL FIX: Handle VAE output range [-1, 1] -> [0, 1] -> [0, 255] # VAE decode_image() outputs in [-1, 1] range, need to convert to [0, 1] first if tensor.min() < 0: - logger.debug(f"_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]") + logger.debug("_tensor_to_pil_safe: Converting from VAE range [-1, 1] to [0, 1]") tensor = (tensor / 2.0 + 0.5).clamp(0, 1) # Convert [-1, 1] -> [0, 1] - + # Ensure proper range [0, 1] -> [0, 255] if tensor.max() <= 1.0: tensor = tensor * 255.0 - + # Convert to numpy and then PIL numpy_image = tensor.detach().cpu().numpy().astype(np.uint8) return Image.fromarray(numpy_image) - + def _pil_to_tensor_safe(self, pil_image: Image.Image, device: str, dtype: torch.dtype) -> torch.Tensor: """Convert PIL Image to tensor safely (reuse existing conversion logic).""" # Convert PIL to numpy numpy_image = np.array(pil_image) - + # Convert to tensor and normalize to [0, 1] tensor = torch.from_numpy(numpy_image).float() / 255.0 - + # Convert HWC to CHW if tensor.dim() == 3: tensor = tensor.permute(2, 0, 1) - + # Add batch dimension and move to device tensor = tensor.unsqueeze(0).to(device=device, dtype=dtype) - + # CRITICAL: Convert back to VAE input range [-1, 1] for postprocessing # VAE expects inputs in [-1, 1] range, so convert [0, 1] -> [-1, 1] tensor = (tensor - 0.5) * 2.0 # Convert [0, 1] -> [-1, 1] - + return tensor - - def _process_tensor_input(self, - control_tensor: torch.Tensor, - preprocessor: Optional[Any], - target_width: int, - target_height: int) -> torch.Tensor: + + def _process_tensor_input( + self, control_tensor: torch.Tensor, preprocessor: Optional[Any], target_width: int, target_height: int + ) -> torch.Tensor: """Process tensor input with GPU acceleration when possible""" # Fast path for tensor input with GPU preprocessor - if preprocessor is not None and hasattr(preprocessor, 'process_tensor'): + if preprocessor is not None and hasattr(preprocessor, "process_tensor"): try: processed_tensor = preprocessor.process_tensor(control_tensor) # Ensure NCHW shape @@ -661,155 +647,139 @@ def _process_tensor_input(self, return processed_tensor.to(device=self.device, dtype=self.dtype) except Exception: pass # Fall through to standard processing - + # Direct tensor passthrough (no preprocessor) - preprocessors handle their own sizing if preprocessor is None: # For passthrough, we still need basic format handling if control_tensor.dim() == 3: control_tensor = control_tensor.unsqueeze(0) return control_tensor.to(device=self.device, dtype=self.dtype) - + # Convert to PIL for preprocessor, then back to tensor if control_tensor.dim() == 4: control_tensor = control_tensor[0] if control_tensor.dim() == 3 and control_tensor.shape[0] in [1, 3]: control_tensor = control_tensor.permute(1, 2, 0) - + if control_tensor.is_cuda: control_tensor = control_tensor.cpu() - + control_array = control_tensor.numpy() if control_array.max() <= 1.0: control_array = (control_array * 255).astype(np.uint8) - + control_image = Image.fromarray(control_array.astype(np.uint8)) return self.prepare_control_image(control_image, preprocessor, target_width, target_height) - - def _convert_to_tensor(self, - control_image: Union[Image.Image, np.ndarray], - target_width: int, - target_height: int) -> torch.Tensor: + + def _convert_to_tensor( + self, control_image: Union[Image.Image, np.ndarray], target_width: int, target_height: int + ) -> torch.Tensor: """Convert PIL Image or numpy array to tensor - preprocessors handle their own sizing""" # Handle PIL Images - no resizing here, preprocessors handle their target size if isinstance(control_image, Image.Image): control_tensor = self._cached_transform(control_image).unsqueeze(0) return control_tensor.to(device=self.device, dtype=self.dtype) - + # Handle numpy arrays if isinstance(control_image, np.ndarray): if control_image.max() <= 1.0: control_image = (control_image * 255).astype(np.uint8) control_image = Image.fromarray(control_image) return self._convert_to_tensor(control_image, target_width, target_height) - + raise ValueError(f"Unsupported control image type: {type(control_image)}") - + def _to_tensor_safe(self, image: Image.Image) -> torch.Tensor: """Thread-safe tensor conversion from PIL Image""" return self._cached_transform(image).unsqueeze(0).to(device=self.device, dtype=self.dtype) - - def _prepare_input_variants(self, - control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], - stream_width: int = None, - stream_height: int = None, - thread_safe: bool = False) -> Dict[str, Any]: + + def _prepare_input_variants( + self, + control_image: Union[str, Image.Image, np.ndarray, torch.Tensor], + stream_width: int = None, + stream_height: int = None, + thread_safe: bool = False, + ) -> Dict[str, Any]: """Prepare optimized input variants for different processing paths - + Args: control_image: Input image in various formats stream_width: Target width (unused, kept for backward compatibility) stream_height: Target height (unused, kept for backward compatibility) thread_safe: If True, use thread-safe key naming for background processing - + Returns: Dictionary with 'tensor' and 'image'/'image_safe' keys """ - image_key = 'image_safe' if thread_safe else 'image' - + image_key = "image_safe" if thread_safe else "image" + if isinstance(control_image, torch.Tensor): return { - 'tensor': control_image, - image_key: None # Will create if needed + "tensor": control_image, + image_key: None, # Will create if needed } elif isinstance(control_image, Image.Image): image_copy = control_image.copy() - return { - image_key: image_copy, - 'tensor': self._to_tensor_safe(image_copy) - } + return {image_key: image_copy, "tensor": self._to_tensor_safe(image_copy)} elif isinstance(control_image, str): image_loaded = load_image(control_image) - return { - image_key: image_loaded, - 'tensor': self._to_tensor_safe(image_loaded) - } + return {image_key: image_loaded, "tensor": self._to_tensor_safe(image_loaded)} else: - return { - image_key: control_image, - 'tensor': None - } - - def _group_preprocessors(self, - preprocessors: List[Optional[Any]], - scales: List[float]) -> Dict[str, Dict[str, Any]]: + return {image_key: control_image, "tensor": None} + + def _group_preprocessors( + self, preprocessors: List[Optional[Any]], scales: List[float] + ) -> Dict[str, Dict[str, Any]]: """Group preprocessors by type to avoid duplicate processing""" preprocessor_groups = {} - + for i, scale in enumerate(scales): if scale > 0: preprocessor = preprocessors[i] - preprocessor_key = id(preprocessor) if preprocessor is not None else 'passthrough' - + preprocessor_key = id(preprocessor) if preprocessor is not None else "passthrough" + if preprocessor_key not in preprocessor_groups: - preprocessor_groups[preprocessor_key] = { - 'preprocessor': preprocessor, - 'indices': [] - } - preprocessor_groups[preprocessor_key]['indices'].append(i) - + preprocessor_groups[preprocessor_key] = {"preprocessor": preprocessor, "indices": []} + preprocessor_groups[preprocessor_key]["indices"].append(i) + return preprocessor_groups - def _process_single_preprocessor_group(self, - prep_key: str, - group: Dict[str, Any], - control_variants: Dict[str, Any], - stream_width: int, - stream_height: int) -> Optional[Dict[str, Any]]: + def _process_single_preprocessor_group( + self, + prep_key: str, + group: Dict[str, Any], + control_variants: Dict[str, Any], + stream_width: int, + stream_height: int, + ) -> Optional[Dict[str, Any]]: """Process a single preprocessor group with optimal input selection""" try: - preprocessor = group['preprocessor'] - indices = group['indices'] - + preprocessor = group["preprocessor"] + indices = group["indices"] + # Try tensor processing first (fastest path) - if (preprocessor is not None and - hasattr(preprocessor, 'process_tensor') and - control_variants['tensor'] is not None): + if ( + preprocessor is not None + and hasattr(preprocessor, "process_tensor") + and control_variants["tensor"] is not None + ): try: processed_image = self.prepare_control_image( - control_variants['tensor'], preprocessor, stream_width, stream_height + control_variants["tensor"], preprocessor, stream_width, stream_height ) - return { - 'prep_key': prep_key, - 'indices': indices, - 'processed_image': processed_image - } + return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image} except Exception: pass # Fall through to PIL processing - + # PIL processing fallback - if control_variants['image'] is not None: + if control_variants["image"] is not None: processed_image = self.prepare_control_image( - control_variants['image'], preprocessor, stream_width, stream_height + control_variants["image"], preprocessor, stream_width, stream_height ) - return { - 'prep_key': prep_key, - 'indices': indices, - 'processed_image': processed_image - } - + return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image} + return None - + except Exception as e: logger.error(f"PreprocessingOrchestrator: Preprocessor {prep_key} failed: {e}") return None - diff --git a/src/streamdiffusion/preprocessing/processors/__init__.py b/src/streamdiffusion/preprocessing/processors/__init__.py index 5674ab4af..26f00e732 100644 --- a/src/streamdiffusion/preprocessing/processors/__init__.py +++ b/src/streamdiffusion/preprocessing/processors/__init__.py @@ -1,26 +1,29 @@ -from .base import BasePreprocessor, PipelineAwareProcessor from typing import Any + +from .base import BasePreprocessor, PipelineAwareProcessor +from .blur import BlurPreprocessor from .canny import CannyPreprocessor from .depth import DepthPreprocessor -from .openpose import OpenPosePreprocessor -from .lineart import LineartPreprocessor -from .standard_lineart import StandardLineartPreprocessor -from .passthrough import PassthroughPreprocessor from .external import ExternalPreprocessor -from .soft_edge import SoftEdgePreprocessor -from .hed import HEDPreprocessor -from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor from .faceid_embedding import FaceIDEmbeddingPreprocessor from .feedback import FeedbackPreprocessor +from .hed import HEDPreprocessor +from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor from .latent_feedback import LatentFeedbackPreprocessor +from .lineart import LineartPreprocessor +from .openpose import OpenPosePreprocessor +from .passthrough import PassthroughPreprocessor +from .realesrgan_trt import RealESRGANProcessor from .sharpen import SharpenPreprocessor +from .soft_edge import SoftEdgePreprocessor +from .standard_lineart import StandardLineartPreprocessor from .upscale import UpscalePreprocessor -from .blur import BlurPreprocessor -from .realesrgan_trt import RealESRGANProcessor + # Try to import TensorRT preprocessors - might not be available on all systems try: from .depth_tensorrt import DepthAnythingTensorrtPreprocessor + DEPTH_TENSORRT_AVAILABLE = True except ImportError: DepthAnythingTensorrtPreprocessor = None @@ -28,6 +31,7 @@ try: from .pose_tensorrt import YoloNasPoseTensorrtPreprocessor + POSE_TENSORRT_AVAILABLE = True except ImportError: YoloNasPoseTensorrtPreprocessor = None @@ -35,6 +39,7 @@ try: from .temporal_net_tensorrt import TemporalNetTensorRTPreprocessor + TEMPORAL_NET_TENSORRT_AVAILABLE = True except ImportError: TemporalNetTensorRTPreprocessor = None @@ -42,6 +47,7 @@ try: from .mediapipe_pose import MediaPipePosePreprocessor + MEDIAPIPE_POSE_AVAILABLE = True except ImportError: MediaPipePosePreprocessor = None @@ -49,6 +55,7 @@ try: from .mediapipe_segmentation import MediaPipeSegmentationPreprocessor + MEDIAPIPE_SEGMENTATION_AVAILABLE = True except ImportError: MediaPipeSegmentationPreprocessor = None @@ -71,7 +78,7 @@ "upscale": UpscalePreprocessor, "blur": BlurPreprocessor, "realesrgan_trt": RealESRGANProcessor, -} +} # Add TensorRT preprocessors if available if DEPTH_TENSORRT_AVAILABLE: @@ -94,27 +101,29 @@ def get_preprocessor_class(name: str) -> type: """ Get a preprocessor class by name - + Args: name: Name of the preprocessor - + Returns: Preprocessor class - + Raises: ValueError: If preprocessor name is not found """ if name not in _preprocessor_registry: available = ", ".join(_preprocessor_registry.keys()) raise ValueError(f"Unknown preprocessor '{name}'. Available: {available}") - + return _preprocessor_registry[name] -def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: str = 'controlnet', params: Any = None) -> BasePreprocessor: +def get_preprocessor( + name: str, pipeline_ref: Any = None, normalization_context: str = "controlnet", params: Any = None +) -> BasePreprocessor: """ Get a preprocessor by name - + Args: name: Name of the preprocessor pipeline_ref: Pipeline reference for pipeline-aware processors (required for some processors) @@ -122,20 +131,25 @@ def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: - 'controlnet': Expects/produces [0,1] range for ControlNet conditioning - 'pipeline': Expects/produces [-1,1] range for pipeline image processing - 'latent': Works in latent space (no normalization needed) - + Returns: Preprocessor instance - + Raises: ValueError: If preprocessor name is not found or pipeline_ref missing for pipeline-aware processor """ processor_class = get_preprocessor_class(name) - + # Check if this is a pipeline-aware processor - if hasattr(processor_class, 'requires_sync_processing') and processor_class.requires_sync_processing: + if hasattr(processor_class, "requires_sync_processing") and processor_class.requires_sync_processing: if pipeline_ref is None: raise ValueError(f"Processor '{name}' requires a pipeline_ref") - return processor_class(pipeline_ref=pipeline_ref, normalization_context=normalization_context, _registry_name=name, **(params or {})) + return processor_class( + pipeline_ref=pipeline_ref, + normalization_context=normalization_context, + _registry_name=name, + **(params or {}), + ) else: return processor_class(normalization_context=normalization_context, _registry_name=name, **(params or {})) @@ -143,7 +157,7 @@ def get_preprocessor(name: str, pipeline_ref: Any = None, normalization_context: def register_preprocessor(name: str, preprocessor_class): """ Register a new preprocessor - + Args: name: Name to register under preprocessor_class: Preprocessor class @@ -160,7 +174,7 @@ def list_preprocessors(): "BasePreprocessor", "PipelineAwareProcessor", "CannyPreprocessor", - "DepthPreprocessor", + "DepthPreprocessor", "OpenPosePreprocessor", "LineartPreprocessor", "StandardLineartPreprocessor", @@ -195,14 +209,16 @@ def list_preprocessors(): # region Custom Processor Discovery -import logging -import os import importlib.util import inspect +import logging +import os from pathlib import Path + _logger = logging.getLogger(__name__) + def _discover_custom_processors(): """Auto-discover custom processors from repo_root/custom_processors/ folder.""" if os.getenv("STREAMDIFFUSION_DISABLE_CUSTOM_PROCESSORS") == "1": @@ -216,7 +232,7 @@ def _discover_custom_processors(): return _logger.info("Scanning custom_processors/ for custom processors...") for item in custom_dir.iterdir(): - if not item.is_dir() or item.name.startswith(('.', '_')): + if not item.is_dir() or item.name.startswith((".", "_")): continue manifest_file = item / "processors.yaml" if manifest_file.exists(): @@ -226,20 +242,22 @@ def _discover_custom_processors(): except Exception as e: _logger.error(f"Custom processor discovery failed: {e}") + def _load_processor_collection(collection_dir, manifest_file): """Load processors from a collection with processors.yaml manifest.""" import yaml + try: - with open(manifest_file, 'r') as f: + with open(manifest_file, "r") as f: manifest = yaml.safe_load(f) - processor_files = manifest.get('processors', []) + processor_files = manifest.get("processors", []) if not processor_files: _logger.warning(f"Collection '{collection_dir.name}' has empty processors list") return _logger.info(f"Loading collection '{collection_dir.name}' ({len(processor_files)} processors)") for proc_file in processor_files: if isinstance(proc_file, dict): - filename, enabled = proc_file.get('file'), proc_file.get('enabled', True) + filename, enabled = proc_file.get("file"), proc_file.get("enabled", True) if not enabled: continue else: @@ -252,24 +270,28 @@ def _load_processor_collection(collection_dir, manifest_file): except Exception as e: _logger.error(f"Failed to load collection {collection_dir.name}: {e}") + def _load_processor_folder_auto(folder): """Auto-discover processors by scanning for .py files (no manifest).""" _logger.info(f"Auto-scanning folder: {folder.name}") for py_file in folder.glob("*.py"): - if py_file.name.startswith('_') or py_file.name in ['base.py', 'setup.py']: + if py_file.name.startswith("_") or py_file.name in ["base.py", "setup.py"]: continue _load_processor_from_file(py_file, py_file.stem) + def _load_processor_from_file(file_path, proc_name): """Load and register a processor class from a Python file.""" try: spec = importlib.util.spec_from_file_location( - f"custom_processors.{file_path.parent.name}.{file_path.stem}", file_path) + f"custom_processors.{file_path.parent.name}.{file_path.stem}", file_path + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) found_classes = [ - (name, obj) for name, obj in inspect.getmembers(module, inspect.isclass) + (name, obj) + for name, obj in inspect.getmembers(module, inspect.isclass) if issubclass(obj, (BasePreprocessor, PipelineAwareProcessor)) and obj not in [BasePreprocessor, PipelineAwareProcessor] ] @@ -284,5 +306,6 @@ def _load_processor_from_file(file_path, proc_name): except Exception as e: _logger.error(f" Failed to load {file_path.name}: {e}") + _discover_custom_processors() -# endregion \ No newline at end of file +# endregion diff --git a/src/streamdiffusion/preprocessing/processors/base.py b/src/streamdiffusion/preprocessing/processors/base.py index 218a459f5..155dfddde 100644 --- a/src/streamdiffusion/preprocessing/processors/base.py +++ b/src/streamdiffusion/preprocessing/processors/base.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Union, Dict, Any, Tuple, Optional +from typing import Any, Dict, Tuple, Union + +import numpy as np import torch import torch.nn.functional as F -import numpy as np from PIL import Image @@ -10,12 +11,11 @@ class BasePreprocessor(ABC): """ Base class for ControlNet preprocessors with template method pattern """ - - - def __init__(self, normalization_context: str = 'controlnet', **kwargs): + + def __init__(self, normalization_context: str = "controlnet", **kwargs): """ Initialize the preprocessor - + Args: normalization_context: Context for normalization handling. - 'controlnet': Expects/produces [0,1] range for ControlNet conditioning @@ -25,15 +25,15 @@ def __init__(self, normalization_context: str = 'controlnet', **kwargs): """ self.params = kwargs self.normalization_context = normalization_context - self.device = kwargs.get('device', 'cuda' if torch.cuda.is_available() else 'cpu') - self.dtype = kwargs.get('dtype', torch.float16) - + self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") + self.dtype = kwargs.get("dtype", torch.float16) + @classmethod def get_preprocessor_metadata(cls) -> Dict[str, Any]: """ Get comprehensive metadata for this preprocessor. Subclasses should override this to define their specific metadata. - + Returns: Dictionary containing: - display_name: Human-readable name @@ -45,9 +45,9 @@ def get_preprocessor_metadata(cls) -> Dict[str, Any]: "display_name": cls.__name__.replace("Preprocessor", ""), "description": f"Preprocessor for {cls.__name__.replace('Preprocessor', '').lower()}", "parameters": {}, - "use_cases": [] + "use_cases": [], } - + def process(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.Image: """ Template method - handles all common operations @@ -55,7 +55,7 @@ def process(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image. image = self.validate_input(image) processed = self._process_core(image) return self._ensure_target_size(processed) - + def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Template method for GPU tensor processing @@ -63,14 +63,14 @@ def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor: tensor = self.validate_tensor_input(image_tensor) processed = self._process_tensor_core(tensor) return self._ensure_target_size_tensor(processed) - + @abstractmethod def _process_core(self, image: Image.Image) -> Image.Image: """ Subclasses implement ONLY their specific algorithm """ pass - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Optional GPU processing (fallback to PIL if not overridden) @@ -78,7 +78,7 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: pil_image = self.tensor_to_pil(tensor) processed_pil = self._process_core(pil_image) return self.pil_to_tensor(processed_pil) - + def _ensure_target_size(self, image: Image.Image) -> Image.Image: """ Centralized PIL resize logic @@ -87,7 +87,7 @@ def _ensure_target_size(self, image: Image.Image) -> Image.Image: if image.size != (target_width, target_height): return image.resize((target_width, target_height), Image.LANCZOS) return image - + def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: """ Centralized tensor resize logic @@ -95,54 +95,54 @@ def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: target_width, target_height = self.get_target_dimensions() current_size = tensor.shape[-2:] target_size = (target_height, target_width) - + if current_size != target_size: if tensor.dim() == 3: tensor = tensor.unsqueeze(0) - tensor = F.interpolate(tensor, size=target_size, mode='bilinear', align_corners=False) + tensor = F.interpolate(tensor, size=target_size, mode="bilinear", align_corners=False) if tensor.shape[0] == 1: tensor = tensor.squeeze(0) return tensor - + def validate_tensor_input(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Validate and normalize tensor input for processing - + Args: image_tensor: Input tensor - + Returns: Tensor in CHW format, on correct device Range: [0,1] if input was [0,255], otherwise preserves input range - + Note: This preserves [-1,1] tensors (from pipeline) since max() <= 1.0 """ # Handle batch dimension if image_tensor.dim() == 4: image_tensor = image_tensor[0] # Take first image from batch - + # Convert to CHW format if needed if image_tensor.dim() == 3 and image_tensor.shape[0] not in [1, 3]: # Likely HWC format, convert to CHW image_tensor = image_tensor.permute(2, 0, 1) - + # Ensure correct device and dtype image_tensor = image_tensor.to(device=self.device, dtype=self.dtype) - + # Normalize to [0,1] range only if tensor is in [0,255] uint8 range # Preserves [-1,1] and [0,1] ranges (max <= 1.0) if image_tensor.max() > 1.0: image_tensor = image_tensor / 255.0 - + return image_tensor - + def tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image: """ Convert tensor to PIL Image (minimize CPU transfers) - + Args: tensor: Input tensor - + Returns: PIL Image """ @@ -151,39 +151,39 @@ def tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image: tensor = tensor[0] if tensor.dim() == 3 and tensor.shape[0] in [1, 3]: tensor = tensor.permute(1, 2, 0) - + # Convert to numpy (unavoidable for PIL) if tensor.is_cuda: tensor = tensor.cpu() - + # Convert to uint8 if tensor.max() <= 1.0: tensor = (tensor * 255).clamp(0, 255).to(torch.uint8) else: tensor = tensor.clamp(0, 255).to(torch.uint8) - + array = tensor.numpy() - + if array.shape[-1] == 3: - return Image.fromarray(array, 'RGB') + return Image.fromarray(array, "RGB") elif array.shape[-1] == 1: - return Image.fromarray(array.squeeze(-1), 'L') + return Image.fromarray(array.squeeze(-1), "L") else: return Image.fromarray(array) - + def pil_to_tensor(self, image: Image.Image) -> torch.Tensor: """ Convert PIL Image to tensor on GPU - + Args: image: PIL Image - + Returns: Tensor on correct device """ # Convert to numpy first array = np.array(image) - + # Convert to tensor if len(array.shape) == 2: # Grayscale tensor = torch.from_numpy(array).float() / 255.0 @@ -191,25 +191,25 @@ def pil_to_tensor(self, image: Image.Image) -> torch.Tensor: else: # RGB tensor = torch.from_numpy(array).float() / 255.0 tensor = tensor.permute(2, 0, 1) # HWC to CHW - + # Move to device tensor = tensor.to(device=self.device, dtype=self.dtype) return tensor.unsqueeze(0) # Add batch dimension - + def validate_input(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image.Image: """ Convert input to PIL Image for processing - + Args: image: Input image in various formats - + Returns: PIL Image """ if isinstance(image, torch.Tensor): # Use tensor_to_pil method for better handling return self.tensor_to_pil(image) - + if isinstance(image, np.ndarray): # Ensure uint8 format if image.dtype != np.uint8: @@ -217,83 +217,83 @@ def validate_input(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> image = (image * 255).astype(np.uint8) else: image = image.astype(np.uint8) - + # Convert to PIL Image if len(image.shape) == 3: - image = Image.fromarray(image, 'RGB') + image = Image.fromarray(image, "RGB") else: - image = Image.fromarray(image, 'L') - + image = Image.fromarray(image, "L") + if not isinstance(image, Image.Image): raise ValueError(f"Unsupported image type: {type(image)}") - + return image - + def get_target_dimensions(self) -> Tuple[int, int]: """ Get target output dimensions (width, height) """ # Check for explicit width/height parameters first - width = self.params.get('image_width') - height = self.params.get('image_height') - + width = self.params.get("image_width") + height = self.params.get("image_height") + if width is not None and height is not None: return (width, height) - + # Fallback to square resolution for backwards compatibility - resolution = self.params.get('image_resolution', 512) + resolution = self.params.get("image_resolution", 512) return (resolution, resolution) - + def __call__(self, image: Union[Image.Image, np.ndarray, torch.Tensor], **kwargs) -> Image.Image: """ Process an image (convenience method) - + Args: image: Input image **kwargs: Additional parameters to override defaults - + Returns: Processed PIL Image """ # Update parameters for this call params = {**self.params, **kwargs} - + # Store original params and update original_params = self.params self.params = params - + try: result = self.process(image) finally: # Restore original params self.params = original_params - + return result class PipelineAwareProcessor(BasePreprocessor): """ Abstract base class for processors that need access to pipeline state (previous outputs). - - This base class marks processors as requiring synchronous processing to avoid + + This base class marks processors as requiring synchronous processing to avoid temporal artifacts and ensures they have access to pipeline references. - + Usage: class MyProcessor(PipelineAwareProcessor): pass - + Examples: - FeedbackPreprocessor: Needs previous diffusion output - TemporalNetPreprocessor: Needs previous frame for optical flow """ - + # Class attribute to mark processors as requiring sync processing requires_sync_processing = True - - def __init__(self, pipeline_ref: Any, normalization_context: str = 'controlnet', **kwargs): + + def __init__(self, pipeline_ref: Any, normalization_context: str = "controlnet", **kwargs): """ Initialize pipeline-aware functionality - + Args: pipeline_ref: Reference to the StreamDiffusion pipeline instance (required) normalization_context: Context for normalization handling @@ -302,4 +302,4 @@ def __init__(self, pipeline_ref: Any, normalization_context: str = 'controlnet', if pipeline_ref is None: raise ValueError(f"{self.__class__.__name__} requires a pipeline_ref") super().__init__(normalization_context=normalization_context, **kwargs) - self.pipeline_ref = pipeline_ref \ No newline at end of file + self.pipeline_ref = pipeline_ref diff --git a/src/streamdiffusion/preprocessing/processors/blur.py b/src/streamdiffusion/preprocessing/processors/blur.py index 2e12c8e6a..694d9eb06 100644 --- a/src/streamdiffusion/preprocessing/processors/blur.py +++ b/src/streamdiffusion/preprocessing/processors/blur.py @@ -1,19 +1,18 @@ import torch import torch.nn.functional as F -import numpy as np from PIL import Image -from typing import Union + from .base import BasePreprocessor class BlurPreprocessor(BasePreprocessor): """ Gaussian blur preprocessor for ControlNet - + Applies Gaussian blur to the input image using GPU-accelerated operations. Useful for creating soft, dreamy effects or reducing image detail. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -24,22 +23,22 @@ def get_preprocessor_metadata(cls): "type": "float", "default": 2.0, "range": [0.1, 10.0], - "description": "Intensity of the blur effect. Higher values create stronger blur." + "description": "Intensity of the blur effect. Higher values create stronger blur.", }, "kernel_size": { "type": "int", "default": 15, "range": [3, 51], - "description": "Size of the blur kernel. Must be odd. Larger values create smoother blur." - } + "description": "Size of the blur kernel. Must be odd. Larger values create smoother blur.", + }, }, - "use_cases": ["Soft focus effects", "Background blur", "Artistic rendering", "Detail reduction"] + "use_cases": ["Soft focus effects", "Background blur", "Artistic rendering", "Detail reduction"], } - + def __init__(self, blur_intensity: float = 2.0, kernel_size: int = 15, **kwargs): """ Initialize Blur preprocessor - + Args: blur_intensity: Standard deviation for Gaussian kernel (higher = more blur) kernel_size: Size of the blur kernel (must be odd) @@ -48,58 +47,55 @@ def __init__(self, blur_intensity: float = 2.0, kernel_size: int = 15, **kwargs) # Ensure kernel_size is odd if kernel_size % 2 == 0: kernel_size += 1 - - super().__init__( - blur_intensity=blur_intensity, - kernel_size=kernel_size, - **kwargs - ) - + + super().__init__(blur_intensity=blur_intensity, kernel_size=kernel_size, **kwargs) + # Cache the Gaussian kernel for efficiency self._cached_kernel = None self._cached_kernel_size = None self._cached_intensity = None - + def _create_gaussian_kernel(self, kernel_size: int, intensity: float) -> torch.Tensor: """ Create a 2D Gaussian kernel for blurring - + Args: kernel_size: Size of the kernel (must be odd) intensity: Standard deviation of the Gaussian - + Returns: 2D Gaussian kernel tensor """ # Create coordinate grids coords = torch.arange(kernel_size, dtype=self.dtype, device=self.device) coords = coords - (kernel_size - 1) / 2 - + # Create 2D coordinate grids - y_grid, x_grid = torch.meshgrid(coords, coords, indexing='ij') - + y_grid, x_grid = torch.meshgrid(coords, coords, indexing="ij") + # Calculate Gaussian values gaussian = torch.exp(-(x_grid**2 + y_grid**2) / (2 * intensity**2)) - + # Normalize to sum to 1 gaussian = gaussian / gaussian.sum() - + return gaussian - + def _get_gaussian_kernel(self, kernel_size: int, intensity: float) -> torch.Tensor: """ Get cached Gaussian kernel or create new one """ - if (self._cached_kernel is None or - self._cached_kernel_size != kernel_size or - self._cached_intensity != intensity): - + if ( + self._cached_kernel is None + or self._cached_kernel_size != kernel_size + or self._cached_intensity != intensity + ): self._cached_kernel = self._create_gaussian_kernel(kernel_size, intensity) self._cached_kernel_size = kernel_size self._cached_intensity = intensity - + return self._cached_kernel - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply Gaussian blur to the input image using PIL/numpy fallback @@ -107,46 +103,46 @@ def _process_core(self, image: Image.Image) -> Image.Image: # Convert to tensor for processing tensor = self.pil_to_tensor(image) tensor = tensor.squeeze(0) # Remove batch dimension - + # Process on GPU blurred = self._process_tensor_core(tensor) - + # Convert back to PIL return self.tensor_to_pil(blurred) - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU for Gaussian blur """ - blur_intensity = self.params.get('blur_intensity', 2.0) - kernel_size = self.params.get('kernel_size', 15) - + blur_intensity = self.params.get("blur_intensity", 2.0) + kernel_size = self.params.get("kernel_size", 15) + # Ensure kernel_size is odd if kernel_size % 2 == 0: kernel_size += 1 - + # Get the Gaussian kernel kernel = self._get_gaussian_kernel(kernel_size, blur_intensity) - + # Ensure tensor has batch dimension if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) - + # Ensure tensor is on the correct device and dtype image_tensor = image_tensor.to(device=self.device, dtype=self.dtype) - + # Reshape kernel for conv2d: (out_channels, in_channels/groups, H, W) # We'll apply the same kernel to each channel separately num_channels = image_tensor.shape[1] kernel_conv = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1) - + # Apply Gaussian blur using conv2d with groups=num_channels for per-channel convolution padding = kernel_size // 2 blurred = F.conv2d( image_tensor, kernel_conv, padding=padding, - groups=num_channels # Apply kernel separately to each channel + groups=num_channels, # Apply kernel separately to each channel ) - + return blurred diff --git a/src/streamdiffusion/preprocessing/processors/canny.py b/src/streamdiffusion/preprocessing/processors/canny.py index 7c25e9ab4..90e47241e 100644 --- a/src/streamdiffusion/preprocessing/processors/canny.py +++ b/src/streamdiffusion/preprocessing/processors/canny.py @@ -1,18 +1,19 @@ import cv2 import numpy as np -from PIL import Image import torch -from typing import Union +from PIL import Image + from .base import BasePreprocessor -#TODO provide gpu native edge detection + +# TODO provide gpu native edge detection class CannyPreprocessor(BasePreprocessor): """ Canny edge detection preprocessor for ControlNet - + Detects edges in the input image using the Canny edge detection algorithm. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -23,52 +24,48 @@ def get_preprocessor_metadata(cls): "type": "int", "default": 100, "range": [1, 255], - "description": "Lower threshold for edge detection. Lower values detect more edges." + "description": "Lower threshold for edge detection. Lower values detect more edges.", }, "high_threshold": { - "type": "int", + "type": "int", "default": 200, "range": [1, 255], - "description": "Upper threshold for edge detection. Higher values are more selective." - } + "description": "Upper threshold for edge detection. Higher values are more selective.", + }, }, - "use_cases": ["Line art", "Architecture", "Technical drawings", "Clean edge detection"] + "use_cases": ["Line art", "Architecture", "Technical drawings", "Clean edge detection"], } - + def __init__(self, low_threshold: int = 100, high_threshold: int = 200, **kwargs): """ Initialize Canny preprocessor - + Args: low_threshold: Lower threshold for edge detection high_threshold: Upper threshold for edge detection **kwargs: Additional parameters """ - super().__init__( - low_threshold=low_threshold, - high_threshold=high_threshold, - **kwargs - ) - + super().__init__(low_threshold=low_threshold, high_threshold=high_threshold, **kwargs) + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply Canny edge detection to the input image """ image_np = np.array(image) - + if len(image_np.shape) == 3: gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY) else: gray = image_np - - low_threshold = self.params.get('low_threshold', 100) - high_threshold = self.params.get('high_threshold', 200) - + + low_threshold = self.params.get("low_threshold", 100) + high_threshold = self.params.get("high_threshold", 200) + edges = cv2.Canny(gray, low_threshold, high_threshold) edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) - + return Image.fromarray(edges_rgb) - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU for Canny edge detection @@ -77,18 +74,18 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: gray_tensor = 0.299 * image_tensor[0] + 0.587 * image_tensor[1] + 0.114 * image_tensor[2] else: gray_tensor = image_tensor[0] if image_tensor.shape[0] == 1 else image_tensor - + gray_cpu = gray_tensor.cpu() gray_np = (gray_cpu * 255).clamp(0, 255).to(torch.uint8).numpy() - - low_threshold = self.params.get('low_threshold', 100) - high_threshold = self.params.get('high_threshold', 200) - + + low_threshold = self.params.get("low_threshold", 100) + high_threshold = self.params.get("high_threshold", 200) + edges = cv2.Canny(gray_np, low_threshold, high_threshold) - + edges_tensor = torch.from_numpy(edges).float() / 255.0 edges_tensor = edges_tensor.to(device=self.device, dtype=self.dtype) - + edges_rgb = edges_tensor.unsqueeze(0).repeat(3, 1, 1) - - return edges_rgb \ No newline at end of file + + return edges_rgb diff --git a/src/streamdiffusion/preprocessing/processors/depth.py b/src/streamdiffusion/preprocessing/processors/depth.py index fbf57dc83..b7287ffa0 100644 --- a/src/streamdiffusion/preprocessing/processors/depth.py +++ b/src/streamdiffusion/preprocessing/processors/depth.py @@ -1,12 +1,14 @@ import numpy as np -from PIL import Image import torch -from typing import Union, Optional +from PIL import Image + from .base import BasePreprocessor + try: import torch from transformers import pipeline + TRANSFORMERS_AVAILABLE = True except ImportError: TRANSFORMERS_AVAILABLE = False @@ -15,29 +17,25 @@ class DepthPreprocessor(BasePreprocessor): """ Depth estimation preprocessor for ControlNet using MiDaS - + Estimates depth maps from input images using the MiDaS depth estimation model. """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Depth Estimation", "description": "Estimates depth from the input image using MiDaS. Good for adding depth-based control to generation.", - "parameters": { - - }, - "use_cases": ["3D-aware generation", "Depth preservation", "Scene understanding"] + "parameters": {}, + "use_cases": ["3D-aware generation", "Depth preservation", "Scene understanding"], } - - def __init__(self, - model_name: str = "Intel/dpt-large", - detect_resolution: int = 512, - image_resolution: int = 512, - **kwargs): + + def __init__( + self, model_name: str = "Intel/dpt-large", detect_resolution: int = 512, image_resolution: int = 512, **kwargs + ): """ Initialize depth preprocessor - + Args: model_name: Name of the depth estimation model to use detect_resolution: Resolution for depth detection @@ -46,102 +44,94 @@ def __init__(self, """ if not TRANSFORMERS_AVAILABLE: raise ImportError( - "transformers library is required for depth preprocessing. " - "Install it with: pip install transformers" + "transformers library is required for depth preprocessing. Install it with: pip install transformers" ) - + super().__init__( - model_name=model_name, - detect_resolution=detect_resolution, - image_resolution=image_resolution, - **kwargs + model_name=model_name, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs ) - + self._depth_estimator = None - + @property def depth_estimator(self): """Lazy loading of the depth estimation model""" if self._depth_estimator is None: - model_name = self.params.get('model_name', 'Intel/dpt-large') + model_name = self.params.get("model_name", "Intel/dpt-large") print(f"Loading depth estimation model: {model_name}") self._depth_estimator = pipeline( - 'depth-estimation', - model=model_name, - device=0 if torch.cuda.is_available() else -1 + "depth-estimation", model=model_name, device=0 if torch.cuda.is_available() else -1 ) return self._depth_estimator - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply depth estimation to the input image """ - detect_resolution = self.params.get('detect_resolution', 512) + detect_resolution = self.params.get("detect_resolution", 512) image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) - + depth_result = self.depth_estimator(image_resized) - depth_map = depth_result['depth'] - - if hasattr(depth_map, 'cpu'): + depth_map = depth_result["depth"] + + if hasattr(depth_map, "cpu"): depth_np = depth_map.cpu().numpy() else: depth_np = np.array(depth_map) - + depth_min = depth_np.min() depth_max = depth_np.max() if depth_max > depth_min: depth_normalized = ((depth_np - depth_min) / (depth_max - depth_min) * 255).astype(np.uint8) else: depth_normalized = np.zeros_like(depth_np, dtype=np.uint8) - + depth_rgb = np.stack([depth_normalized] * 3, axis=-1) return Image.fromarray(depth_rgb) - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU for depth estimation """ - detect_resolution = self.params.get('detect_resolution', 512) + detect_resolution = self.params.get("detect_resolution", 512) current_size = image_tensor.shape[-2:] - + if current_size != (detect_resolution, detect_resolution): import torch.nn.functional as F + if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) - + resized_tensor = F.interpolate( - image_tensor, - size=(detect_resolution, detect_resolution), - mode='bilinear', - align_corners=False + image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False ) - + if image_tensor.shape[0] == 1: resized_tensor = resized_tensor.squeeze(0) else: resized_tensor = image_tensor - + pil_image = self.tensor_to_pil(resized_tensor) - + depth_result = self.depth_estimator(pil_image) - depth_map = depth_result['depth'] - - if hasattr(depth_map, 'to'): + depth_map = depth_result["depth"] + + if hasattr(depth_map, "to"): depth_tensor = depth_map.to(device=self.device, dtype=self.dtype) else: depth_np = np.array(depth_map) depth_tensor = torch.from_numpy(depth_np).to(device=self.device, dtype=self.dtype) - + depth_min = depth_tensor.min() depth_max = depth_tensor.max() if depth_max > depth_min: depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min) else: depth_normalized = torch.zeros_like(depth_tensor) - + if depth_normalized.dim() == 2: depth_rgb = depth_normalized.unsqueeze(0).repeat(3, 1, 1) else: depth_rgb = depth_normalized - - return depth_rgb \ No newline at end of file + + return depth_rgb diff --git a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py index 993ee242d..70ce9dbce 100644 --- a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py @@ -1,19 +1,23 @@ -#NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt +# NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt import os + +import cv2 import numpy as np import torch import torch.nn.functional as F -import cv2 from PIL import Image -from typing import Union, Optional + from .base import BasePreprocessor + try: + from collections import OrderedDict + import tensorrt as trt from polygraphy.backend.common import bytes_from_path from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False @@ -40,7 +44,7 @@ class TensorRTEngine: """Simplified TensorRT engine wrapper for depth estimation inference (optimized)""" - + def __init__(self, engine_path): self.engine_path = engine_path self.engine = None @@ -65,13 +69,11 @@ def allocate_buffers(self, device="cuda"): name = self.engine.get_tensor_name(idx) shape = self.context.get_tensor_shape(name) dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) + + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) self.tensors[name] = tensor def infer(self, feed_dict, stream=None): @@ -79,7 +81,7 @@ def infer(self, feed_dict, stream=None): # Use cached stream if none provided if stream is None: stream = self._cuda_stream - + # Copy input data to tensors for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) @@ -87,39 +89,35 @@ def infer(self, feed_dict, stream=None): # Set tensor addresses for name, tensor in self.tensors.items(): self.context.set_tensor_address(name, tensor.data_ptr()) - + # Execute inference success = self.context.execute_async_v3(stream) if not success: raise ValueError("ERROR: TensorRT inference failed.") - + return self.tensors class DepthAnythingTensorrtPreprocessor(BasePreprocessor): """ Depth Anything TensorRT preprocessor for ControlNet - + Uses TensorRT-optimized Depth Anything model for fast depth estimation. """ + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Depth Estimation (TensorRT)", "description": "Fast TensorRT-optimized depth estimation using Depth Anything model. Significantly faster than standard depth estimation.", - "parameters": { - - }, - "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"] + "parameters": {}, + "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"], } - def __init__(self, - engine_path: str = None, - detect_resolution: int = 518, - image_resolution: int = 512, - **kwargs): + + def __init__(self, engine_path: str = None, detect_resolution: int = 518, image_resolution: int = 512, **kwargs): """ Initialize TensorRT depth preprocessor - + Args: engine_path: Path to TensorRT engine file detect_resolution: Resolution for depth detection (should match engine input) @@ -131,74 +129,68 @@ def __init__(self, "TensorRT and polygraphy libraries are required for TensorRT depth preprocessing. " "Install them with: pip install tensorrt polygraphy" ) - + super().__init__( - engine_path=engine_path, - detect_resolution=detect_resolution, - image_resolution=image_resolution, - **kwargs + engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs ) - + self._engine = None - + @property def engine(self): """Lazy loading of the TensorRT engine""" if self._engine is None: - engine_path = self.params.get('engine_path') + engine_path = self.params.get("engine_path") if engine_path is None: raise ValueError( "engine_path is required for TensorRT depth preprocessing. " "Please provide it in the preprocessor_params config." ) - + if not os.path.exists(engine_path): raise FileNotFoundError(f"TensorRT engine not found: {engine_path}") - + print(f"Loading TensorRT depth estimation engine: {engine_path}") - + self._engine = TensorRTEngine(engine_path) self._engine.load() self._engine.activate() self._engine.allocate_buffers() - + return self._engine - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply TensorRT depth estimation to the input image """ - detect_resolution = self.params.get('detect_resolution', 518) - + detect_resolution = self.params.get("detect_resolution", 518) + image_tensor = torch.from_numpy(np.array(image)).float() / 255.0 image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) - + image_resized = F.interpolate( - image_tensor, - size=(detect_resolution, detect_resolution), - mode='bilinear', - align_corners=False + image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False ) - + if torch.cuda.is_available(): image_resized = image_resized.cuda() - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized}, cuda_stream) - depth = result['output'] - + depth = result["output"] + depth = np.reshape(depth.cpu().numpy(), (detect_resolution, detect_resolution)) depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 depth = depth.astype(np.uint8) - + original_size = image.size depth = cv2.resize(depth, original_size) - + depth_rgb = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB) result = Image.fromarray(depth_rgb) - + return result - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid CPU transfers @@ -207,20 +199,19 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: image_tensor = image_tensor.unsqueeze(0) if not image_tensor.is_cuda: image_tensor = image_tensor.cuda() - - detect_resolution = self.params.get('detect_resolution', 518) - + + detect_resolution = self.params.get("detect_resolution", 518) + image_resized = torch.nn.functional.interpolate( - image_tensor, size=(detect_resolution, detect_resolution), - mode='bilinear', align_corners=False + image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False ) - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized}, cuda_stream) - depth_tensor = result['output'] - + depth_tensor = result["output"] + depth_tensor = depth_tensor.squeeze() if depth_tensor.dim() > 2 else depth_tensor depth_min, depth_max = depth_tensor.min(), depth_tensor.max() depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min) - - return depth_normalized.repeat(3, 1, 1).unsqueeze(0) \ No newline at end of file + + return depth_normalized.repeat(3, 1, 1).unsqueeze(0) diff --git a/src/streamdiffusion/preprocessing/processors/external.py b/src/streamdiffusion/preprocessing/processors/external.py index 80bd7fe8d..3a205b132 100644 --- a/src/streamdiffusion/preprocessing/processors/external.py +++ b/src/streamdiffusion/preprocessing/processors/external.py @@ -1,95 +1,87 @@ +from typing import Union + import numpy as np import torch from PIL import Image -from typing import Union, Optional, Dict, Any + from .base import BasePreprocessor class ExternalPreprocessor(BasePreprocessor): """ External source preprocessor for client-processed control data - + """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "External", "description": "Allows using external preprocessing tools or custom processing pipelines.", - "parameters": { - - }, - "use_cases": ["Custom processing", "Third-party tools integration", "Pre-processed control images"] + "parameters": {}, + "use_cases": ["Custom processing", "Third-party tools integration", "Pre-processed control images"], } - - def __init__(self, - image_resolution: int = 512, - validate_input: bool = True, - **kwargs): + + def __init__(self, image_resolution: int = 512, validate_input: bool = True, **kwargs): """ Initialize external source preprocessor - + Args: image_resolution: Target output resolution validate_input: Whether to validate the control image format **kwargs: Additional parameters """ - super().__init__( - image_resolution=image_resolution, - validate_input=validate_input, - **kwargs - ) - + super().__init__(image_resolution=image_resolution, validate_input=validate_input, **kwargs) + def _process_core(self, image: Image.Image) -> Image.Image: """ Process client-preprocessed control image - + Applies minimal server-side validation to control images that have already been processed by external sources. """ # Optional validation of control image format - if self.params.get('validate_input', True): + if self.params.get("validate_input", True): image = self._validate_control_image(image) - + return image - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly (optimized path for external sources) - + For external sources, tensor input likely comes from client WebGL/Canvas processing, so minimal processing needed. """ return tensor - + def _validate_control_image(self, image: Image.Image) -> Image.Image: """ Validate that the control image is in proper format """ # Convert to RGB if needed - if image.mode != 'RGB': - image = image.convert('RGB') - + if image.mode != "RGB": + image = image.convert("RGB") + # Basic validation - check if image has content # (not completely black, which might indicate processing failure) img_array = np.array(image) brightness = np.mean(img_array) - + if brightness < 1.0: # Very dark image, might be processing error print("ExternalPreprocessor._validate_control_image: Warning - control image appears very dark") - + return image - - + def __call__(self, image: Union[Image.Image, np.ndarray, torch.Tensor], **kwargs) -> Image.Image: """ Process control image (convenience method) """ # Store any client metadata if provided - client_metadata = kwargs.get('client_metadata', {}) + client_metadata = kwargs.get("client_metadata", {}) if client_metadata: - source = client_metadata.get('source', 'unknown') - control_type = client_metadata.get('type', 'unknown') + source = client_metadata.get("source", "unknown") + control_type = client_metadata.get("type", "unknown") print(f"ExternalPreprocessor: Received {control_type} control from {source}") - - return super().__call__(image, **kwargs) \ No newline at end of file + + return super().__call__(image, **kwargs) diff --git a/src/streamdiffusion/preprocessing/processors/faceid_embedding.py b/src/streamdiffusion/preprocessing/processors/faceid_embedding.py index c6421897a..31bb042ae 100644 --- a/src/streamdiffusion/preprocessing/processors/faceid_embedding.py +++ b/src/streamdiffusion/preprocessing/processors/faceid_embedding.py @@ -1,9 +1,12 @@ -from typing import Tuple, Any +from typing import Any, Tuple + import torch from PIL import Image -from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor + from streamdiffusion.utils.reporting import report_error +from .ipadapter_embedding import IPAdapterEmbeddingPreprocessor + class FaceIDEmbeddingPreprocessor(IPAdapterEmbeddingPreprocessor): """ @@ -45,9 +48,7 @@ def __init__(self, ipadapter: Any, faceid_v2_weight: float = 1.0, **kwargs): self.faceid_v2_weight = float(faceid_v2_weight) if not hasattr(ipadapter, "insightface_model") or ipadapter.insightface_model is None: - raise ValueError( - "FaceIDEmbeddingPreprocessor: ipadapter must have an initialized InsightFace model" - ) + raise ValueError("FaceIDEmbeddingPreprocessor: ipadapter must have an initialized InsightFace model") def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -78,8 +79,4 @@ def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor] def update_faceid_v2_weight(self, weight: float) -> None: self.faceid_v2_weight = float(weight) - print( - f"FaceIDEmbeddingPreprocessor.update_faceid_v2_weight: Updated weight to {self.faceid_v2_weight}" - ) - - + print(f"FaceIDEmbeddingPreprocessor.update_faceid_v2_weight: Updated weight to {self.faceid_v2_weight}") diff --git a/src/streamdiffusion/preprocessing/processors/feedback.py b/src/streamdiffusion/preprocessing/processors/feedback.py index 72a37a7f0..05cce45fa 100644 --- a/src/streamdiffusion/preprocessing/processors/feedback.py +++ b/src/streamdiffusion/preprocessing/processors/feedback.py @@ -1,28 +1,30 @@ +from typing import Any + import torch from PIL import Image -from typing import Union, Optional, Any + from .base import PipelineAwareProcessor class FeedbackPreprocessor(PipelineAwareProcessor): """ Feedback preprocessor for ControlNet - + Creates a configurable blend between the current input image and the previous frame's diffusion output. This creates a feedback loop where each generated frame influences the next generation, while allowing control over the blend strength for stability and creative effects. - + Formula: output = (1 - feedback_strength) * input_image + feedback_strength * previous_output - + Examples: - feedback_strength = 0.0: Pure passthrough (input only) - feedback_strength = 0.5: 50/50 blend (default) - feedback_strength = 1.0: Pure feedback (previous output only) - + The preprocessor accesses the pipeline's prev_image_result to get the previous output. For the first frame (when no previous output exists), it falls back to the input image. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -34,21 +36,29 @@ def get_preprocessor_metadata(cls): "default": 0.5, "range": [0.0, 1.0], "step": 0.01, - "description": "Strength of feedback blend (0.0 = pure input, 1.0 = pure feedback)" + "description": "Strength of feedback blend (0.0 = pure input, 1.0 = pure feedback)", } }, - "use_cases": ["Temporal consistency", "Video-like generation", "Smooth transitions", "Deforum", "Blast off"] + "use_cases": [ + "Temporal consistency", + "Video-like generation", + "Smooth transitions", + "Deforum", + "Blast off", + ], } - - def __init__(self, - pipeline_ref: Any, - normalization_context: str = 'controlnet', - image_resolution: int = 512, - feedback_strength: float = 0.5, - **kwargs): + + def __init__( + self, + pipeline_ref: Any, + normalization_context: str = "controlnet", + image_resolution: int = 512, + feedback_strength: float = 0.5, + **kwargs, + ): """ Initialize feedback preprocessor - + Args: pipeline_ref: Reference to the StreamDiffusion pipeline instance (required) normalization_context: Context for normalization handling @@ -61,41 +71,42 @@ def __init__(self, normalization_context=normalization_context, image_resolution=image_resolution, feedback_strength=feedback_strength, - **kwargs + **kwargs, ) self.feedback_strength = max(0.0, min(1.0, feedback_strength)) # Clamp to [0, 1] self._first_frame = True - + def reset(self): """Reset the processor state (useful for new sequences)""" self._first_frame = True - + def _process_core(self, image: Image.Image) -> Image.Image: """ Process using configurable blend of input image + previous frame output - + Args: image: Current input image - + Returns: Blended PIL Image (blend strength controlled by feedback_strength), or input image for first frame """ # Check if we have a pipeline reference and previous output - if (self.pipeline_ref is not None and - hasattr(self.pipeline_ref, 'prev_image_result') and - self.pipeline_ref.prev_image_result is not None and - not self._first_frame): - + if ( + self.pipeline_ref is not None + and hasattr(self.pipeline_ref, "prev_image_result") + and self.pipeline_ref.prev_image_result is not None + and not self._first_frame + ): prev_output_tensor = self.pipeline_ref.prev_image_result # Convert previous output tensor to PIL Image if prev_output_tensor.dim() == 4: prev_output_tensor = prev_output_tensor[0] # Remove batch dimension - + # Context-aware normalization handling - if self.normalization_context == 'controlnet': + if self.normalization_context == "controlnet": # ControlNet context: Convert from [-1, 1] (VAE output) to [0, 1] (ControlNet input) prev_output_tensor = (prev_output_tensor / 2.0 + 0.5).clamp(0, 1) - elif self.normalization_context == 'pipeline': + elif self.normalization_context == "pipeline": # Pipeline context: prev_output is already [-1, 1], but pil_to_tensor produces [0, 1] # So we need to convert input to [-1, 1] to match prev_output # Convert prev_output to [0, 1] for blending in standard image space @@ -103,15 +114,15 @@ def _process_core(self, image: Image.Image) -> Image.Image: else: # Unknown context - assume controlnet for backward compatibility prev_output_tensor = (prev_output_tensor / 2.0 + 0.5).clamp(0, 1) - + # Convert both to tensors for blending prev_output_pil = self.tensor_to_pil(prev_output_tensor) input_tensor = self.pil_to_tensor(image).squeeze(0) # Remove batch dim for blending prev_tensor = self.pil_to_tensor(prev_output_pil).squeeze(0) - + # Blend with configurable strength (both tensors now in [0, 1] range) blended_tensor = (1 - self.feedback_strength) * input_tensor + self.feedback_strength * prev_tensor - + # Convert back to PIL blended_pil = self.tensor_to_pil(blended_tensor) return blended_pil @@ -119,35 +130,36 @@ def _process_core(self, image: Image.Image) -> Image.Image: # First frame, no pipeline ref, or no previous output available - use input image self._first_frame = False return image - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Process using configurable blend of input tensor + previous frame output (GPU-optimized path) - + Args: tensor: Current input tensor - + Returns: Blended tensor (blend strength controlled by feedback_strength), or input tensor for first frame """ # Check if we have a pipeline reference and previous output - if (self.pipeline_ref is not None and - hasattr(self.pipeline_ref, 'prev_image_result') and - self.pipeline_ref.prev_image_result is not None and - not self._first_frame): - + if ( + self.pipeline_ref is not None + and hasattr(self.pipeline_ref, "prev_image_result") + and self.pipeline_ref.prev_image_result is not None + and not self._first_frame + ): prev_output = self.pipeline_ref.prev_image_result input_tensor = tensor - + # Context-aware normalization handling - if self.normalization_context == 'controlnet': + if self.normalization_context == "controlnet": # ControlNet context: prev_output is [-1, 1] from VAE, input is [0, 1] # Convert prev_output from [-1, 1] to [0, 1] to match input prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1) # Normalize input tensor to [0, 1] if needed if input_tensor.max() > 1.0: input_tensor = input_tensor / 255.0 - elif self.normalization_context == 'pipeline': + elif self.normalization_context == "pipeline": # Pipeline context: both prev_output and input_tensor are in [-1, 1] range # - prev_output comes from VAE decode (always [-1, 1]) # - input_tensor arrives as [-1, 1] from image_processor.preprocess() @@ -157,17 +169,20 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: else: # Unknown context - log warning and assume controlnet behavior for backward compatibility import logging - logging.warning(f"FeedbackPreprocessor: Unknown normalization_context '{self.normalization_context}', using controlnet behavior") + + logging.warning( + f"FeedbackPreprocessor: Unknown normalization_context '{self.normalization_context}', using controlnet behavior" + ) prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1) if input_tensor.max() > 1.0: input_tensor = input_tensor / 255.0 - + # Ensure both tensors have same format for blending if prev_output.dim() == 4 and prev_output.shape[0] == 1: prev_output = prev_output[0] # Remove batch dimension if input_tensor.dim() == 4 and input_tensor.shape[0] == 1: input_tensor = input_tensor[0] # Remove batch dimension - + # Resize if dimensions don't match if prev_output.shape != input_tensor.shape: # Use the input tensor's shape as target @@ -176,18 +191,18 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: if prev_output.dim() == 3: prev_output = prev_output.unsqueeze(0) prev_output = torch.nn.functional.interpolate( - prev_output, size=target_size, mode='bilinear', align_corners=False + prev_output, size=target_size, mode="bilinear", align_corners=False ) if prev_output.shape[0] == 1: prev_output = prev_output.squeeze(0) - + # Blend with configurable strength blended_tensor = (1 - self.feedback_strength) * input_tensor + self.feedback_strength * prev_output - + # Ensure correct output format if blended_tensor.dim() == 3: blended_tensor = blended_tensor.unsqueeze(0) # Add batch dimension back - + # Ensure correct device and dtype blended_tensor = blended_tensor.to(device=self.device, dtype=self.dtype) return blended_tensor diff --git a/src/streamdiffusion/preprocessing/processors/hed.py b/src/streamdiffusion/preprocessing/processors/hed.py index 78c770878..0c5ef9a6b 100644 --- a/src/streamdiffusion/preprocessing/processors/hed.py +++ b/src/streamdiffusion/preprocessing/processors/hed.py @@ -1,11 +1,13 @@ -import torch import numpy as np +import torch from PIL import Image -from typing import Union, Optional + from .base import BasePreprocessor + try: from controlnet_aux import HEDdetector + CONTROLNET_AUX_AVAILABLE = True except ImportError: CONTROLNET_AUX_AVAILABLE = False @@ -15,83 +17,81 @@ class HEDPreprocessor(BasePreprocessor): """ HED (Holistically-Nested Edge Detection) preprocessor - + Uses controlnet_aux HEDdetector for high-quality edge detection. """ - + _model_cache = {} - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "HED Edge Detection", "description": "Holistically-Nested Edge Detection for clean, structured edge maps.", "parameters": { - "safe": { - "type": "bool", - "default": True, - "description": "Whether to use safe mode for edge detection" - } + "safe": {"type": "bool", "default": True, "description": "Whether to use safe mode for edge detection"} }, - "use_cases": ["Structured edge detection", "Clean architectural edges", "Line art generation"] + "use_cases": ["Structured edge detection", "Clean architectural edges", "Line art generation"], } - + def __init__(self, safe: bool = True, **kwargs): if not CONTROLNET_AUX_AVAILABLE: - raise ImportError("controlnet_aux is required for HED preprocessor. Install with: pip install controlnet_aux") - + raise ImportError( + "controlnet_aux is required for HED preprocessor. Install with: pip install controlnet_aux" + ) + super().__init__(**kwargs) self.safe = safe self.model = None self._load_model() - + def _load_model(self): """Load controlnet_aux HED model with caching""" cache_key = f"hed_{self.device}" - + if cache_key in self._model_cache: self.model = self._model_cache[cache_key] return - + print("HEDPreprocessor: Loading controlnet_aux HED model") try: # Initialize HED detector self.model = HEDdetector.from_pretrained("lllyasviel/Annotators") - if hasattr(self.model, 'to'): + if hasattr(self.model, "to"): self.model = self.model.to(self.device) - + # Cache the model self._model_cache[cache_key] = self.model print(f"HEDPreprocessor: Successfully loaded model on {self.device}") - + except Exception as e: raise RuntimeError(f"Failed to load HED model: {e}") - + def _process_core(self, image: Image.Image) -> Image.Image: """Apply HED edge detection to the input image""" # Get target dimensions target_width, target_height = self.get_target_dimensions() - + # Process with controlnet_aux result = self.model(image, output_type="pil") - + # Ensure result is PIL Image if not isinstance(result, Image.Image): if isinstance(result, np.ndarray): result = Image.fromarray(result) else: raise ValueError(f"Unexpected result type: {type(result)}") - + # Resize to target size if needed if result.size != (target_width, target_height): result = result.resize((target_width, target_height), Image.LANCZOS) - + return result - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ GPU-optimized HED processing using tensors - + Note: controlnet_aux doesn't support direct tensor input, so we convert to PIL and back. This is still reasonably fast due to optimized conversions in the base class. """ @@ -99,24 +99,18 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: pil_image = self.tensor_to_pil(image_tensor) processed_pil = self._process_core(pil_image) return self.pil_to_tensor(processed_pil) - - + @classmethod - def create_optimized(cls, device: str = 'cuda', dtype: torch.dtype = torch.float16, **kwargs): + def create_optimized(cls, device: str = "cuda", dtype: torch.dtype = torch.float16, **kwargs): """ Create an optimized HED preprocessor - + Args: device: Target device ('cuda' or 'cpu') dtype: Data type for inference **kwargs: Additional parameters - + Returns: Optimized HEDPreprocessor instance """ - return cls( - device=device, - dtype=dtype, - safe=True, - **kwargs - ) \ No newline at end of file + return cls(device=device, dtype=dtype, safe=True, **kwargs) diff --git a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py index 8b7d28a08..398119bfb 100644 --- a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py +++ b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py @@ -1,6 +1,8 @@ -from typing import Union, Tuple, Optional, Any +from typing import Any, Tuple, Union + import torch from PIL import Image + from .base import BasePreprocessor @@ -9,53 +11,53 @@ class IPAdapterEmbeddingPreprocessor(BasePreprocessor): Preprocessor that generates IPAdapter embeddings instead of spatial conditioning. Leverages existing preprocessing infrastructure for parallel IPAdapter embedding generation. """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "IPAdapter Embedding", "description": "Generates IPAdapter embeddings for style transfer and image conditioning instead of spatial control maps.", "parameters": {}, - "use_cases": ["Style transfer", "Image conditioning", "Semantic control", "Content-aware generation"] + "use_cases": ["Style transfer", "Image conditioning", "Semantic control", "Content-aware generation"], } - + def __init__(self, ipadapter: Any, **kwargs): super().__init__(**kwargs) self.ipadapter = ipadapter # Verify the ipadapter has the required method - if not hasattr(ipadapter, 'get_image_embeds'): + if not hasattr(ipadapter, "get_image_embeds"): raise ValueError("IPAdapterEmbeddingPreprocessor: ipadapter must have 'get_image_embeds' method") - + # Create dedicated CUDA stream for IPAdapter processing to avoid TensorRT conflicts self._ipadapter_stream = torch.cuda.Stream() if torch.cuda.is_available() else None - + def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]: """Returns (positive_embeds, negative_embeds) instead of processed image""" if self._ipadapter_stream is not None: # Use dedicated stream to avoid TensorRT stream capture conflicts with torch.cuda.stream(self._ipadapter_stream): image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) - + # Wait for stream completion and move tensors to default stream self._ipadapter_stream.synchronize() - + # Ensure tensors are accessible from default stream - if hasattr(image_embeds, 'record_stream'): + if hasattr(image_embeds, "record_stream"): image_embeds.record_stream(torch.cuda.current_stream()) - if hasattr(negative_embeds, 'record_stream'): + if hasattr(negative_embeds, "record_stream"): negative_embeds.record_stream(torch.cuda.current_stream()) else: # Fallback for non-CUDA environments image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) - + return image_embeds, negative_embeds - + def _process_tensor_core(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """GPU-optimized path for tensor inputs""" # Convert tensor to PIL for IPAdapter processing pil_image = self.tensor_to_pil(tensor) return self._process_core(pil_image) - + def process(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """Override base process to return embeddings tuple instead of PIL Image""" if isinstance(image, torch.Tensor): @@ -63,9 +65,9 @@ def process(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor else: image = self.validate_input(image) result = self._process_core(image) - + return result - + def process_tensor(self, image_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Override base process_tensor to return embeddings tuple""" tensor = self.validate_tensor_input(image_tensor) diff --git a/src/streamdiffusion/preprocessing/processors/latent_feedback.py b/src/streamdiffusion/preprocessing/processors/latent_feedback.py index e5b8f8b2c..38d70961d 100644 --- a/src/streamdiffusion/preprocessing/processors/latent_feedback.py +++ b/src/streamdiffusion/preprocessing/processors/latent_feedback.py @@ -1,27 +1,29 @@ +from typing import Any + import torch -from typing import Optional, Any + from .base import PipelineAwareProcessor class LatentFeedbackPreprocessor(PipelineAwareProcessor): """ Latent domain feedback preprocessor - + Creates a configurable blend between the current input latent and the previous frame's latent output. This creates a feedback loop in latent space where each generated latent influences the next generation, providing temporal consistency without the overhead of VAE encoding/decoding. - + Formula: output = (1 - feedback_strength) * input_latent + feedback_strength * previous_latent - + Examples: - feedback_strength = 0.0: Pure passthrough (input only) - feedback_strength = 0.15: Default safe blend - feedback_strength = 0.40: Maximum safe feedback (values > 0.4 produce garbage) - + The preprocessor accesses the pipeline's prev_latent_result to get the previous latent output. For the first frame (when no previous output exists), it falls back to the input latent. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -33,20 +35,24 @@ def get_preprocessor_metadata(cls): "default": 0.15, "range": [0.0, 0.40], "step": 0.01, - "description": "Strength of latent feedback blend (0.0 = pure input, .40 = more feedback)" + "description": "Strength of latent feedback blend (0.0 = pure input, .40 = more feedback)", } }, - "use_cases": ["Latent temporal consistency", "Latent space transitions", "Efficient feedback", "Latent preprocessing", "Temporal stability"] + "use_cases": [ + "Latent temporal consistency", + "Latent space transitions", + "Efficient feedback", + "Latent preprocessing", + "Temporal stability", + ], } - - def __init__(self, - pipeline_ref: Any, - normalization_context: str = 'latent', - feedback_strength: float = 0.15, - **kwargs): + + def __init__( + self, pipeline_ref: Any, normalization_context: str = "latent", feedback_strength: float = 0.15, **kwargs + ): """ Initialize latent feedback preprocessor - + Args: pipeline_ref: Reference to the StreamDiffusion pipeline instance (required) normalization_context: Context for normalization handling (latent space doesn't need normalization) @@ -57,29 +63,31 @@ def __init__(self, pipeline_ref=pipeline_ref, normalization_context=normalization_context, feedback_strength=feedback_strength, - **kwargs + **kwargs, ) - self.feedback_strength = max(0.0, min(0.40, feedback_strength)) # Clamp to [0, 0.40] - values > 0.4 produce garbage + self.feedback_strength = max( + 0.0, min(0.40, feedback_strength) + ) # Clamp to [0, 0.40] - values > 0.4 produce garbage self._first_frame = True - + def _get_previous_data(self): """Get previous frame latent data from pipeline""" if self.pipeline_ref is not None: # Get previous OUTPUT latent (after diffusion), not input latent # Check for prev_latent_result (the actual attribute name used by the pipeline) - if hasattr(self.pipeline_ref, 'prev_latent_result'): + if hasattr(self.pipeline_ref, "prev_latent_result"): if self.pipeline_ref.prev_latent_result is not None and not self._first_frame: return self.pipeline_ref.prev_latent_result return None - - #TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class + + # TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class def validate_tensor_input(self, latent_tensor: torch.Tensor) -> torch.Tensor: """ Validate latent tensor input - preserve batch dimensions for latent processing - + Args: latent_tensor: Input latent tensor in format [B, C, H/8, W/8] - + Returns: Validated latent tensor with preserved batch dimension """ @@ -87,18 +95,18 @@ def validate_tensor_input(self, latent_tensor: torch.Tensor) -> torch.Tensor: # Only ensure correct device and dtype latent_tensor = latent_tensor.to(device=self.device, dtype=self.dtype) return latent_tensor - - #TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class + + # TODO: eventually, these processors should be divided by input and output domain rather than overriding image-first basec class def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: """ Override base class resize logic - latent tensors should NOT be resized to image dimensions - + For latent domain processing, we want to preserve the latent space dimensions, not resize to image target dimensions like image-domain processors. """ # For latent feedback, just return the tensor as-is without any resizing return tensor - + def _process_core(self, image): """ For latent feedback, we don't process PIL images directly. @@ -108,23 +116,23 @@ def _process_core(self, image): "LatentFeedbackPreprocessor is designed for latent domain processing. " "Use _process_tensor_core or process_tensor for latent tensors." ) - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Process latent tensor with feedback blending - + Args: tensor: Current input latent tensor in format [B, C, H/8, W/8] - + Returns: Blended latent tensor (blend strength controlled by feedback_strength), or input tensor for first frame """ # Get previous frame latent data using mixin method prev_latent = self._get_previous_data() - + if prev_latent is not None: input_latent = tensor - + # Ensure both tensors have the same batch size for element-wise blending # If batch sizes differ, expand the smaller one to match if prev_latent.shape[0] != input_latent.shape[0]: @@ -139,22 +147,22 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: min_batch = min(prev_latent.shape[0], input_latent.shape[0]) prev_latent = prev_latent[:min_batch] input_latent = input_latent[:min_batch] - + # Resize spatial dimensions if they don't match (though this should be rare in latent space) if prev_latent.shape[2:] != input_latent.shape[2:]: target_size = input_latent.shape[2:] # Get H, W from input prev_latent = torch.nn.functional.interpolate( - prev_latent, size=target_size, mode='bilinear', align_corners=False + prev_latent, size=target_size, mode="bilinear", align_corners=False ) - + # Blend current latent with previous latent for temporal consistency # Higher feedback_strength = more influence from previous frame blended_latent = (1 - self.feedback_strength) * input_latent + self.feedback_strength * prev_latent - + # Add safety measures for latent values to prevent extreme outputs # Clamp to reasonable range based on typical latent distributions blended_latent = torch.clamp(blended_latent, min=-10.0, max=10.0) - + # Ensure correct device and dtype blended_latent = blended_latent.to(device=self.device, dtype=self.dtype) return blended_latent diff --git a/src/streamdiffusion/preprocessing/processors/lineart.py b/src/streamdiffusion/preprocessing/processors/lineart.py index 4f0bafa81..030043c4d 100644 --- a/src/streamdiffusion/preprocessing/processors/lineart.py +++ b/src/streamdiffusion/preprocessing/processors/lineart.py @@ -1,29 +1,34 @@ import logging -import numpy as np -from PIL import Image -from typing import Union, Optional import time + +from PIL import Image + from .base import BasePreprocessor + logger = logging.getLogger(__name__) try: - from controlnet_aux import LineartDetector, LineartAnimeDetector + from controlnet_aux import LineartAnimeDetector, LineartDetector + CONTROLNET_AUX_AVAILABLE = True except ImportError: CONTROLNET_AUX_AVAILABLE = False - raise ImportError("LineartPreprocessor: controlnet_aux is required for real-time optimization. Install with: pip install controlnet_aux") + raise ImportError( + "LineartPreprocessor: controlnet_aux is required for real-time optimization. Install with: pip install controlnet_aux" + ) + -#TODO provide gpu native lineart detection +# TODO provide gpu native lineart detection class LineartPreprocessor(BasePreprocessor): """ Real-time optimized Lineart detection preprocessor for ControlNet - + Extracts line art from input images using controlnet_aux line art detection models. Supports both realistic and anime-style line art extraction. Optimized for real-time performance - no fallbacks. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -33,26 +38,28 @@ def get_preprocessor_metadata(cls): "coarse": { "type": "bool", "default": True, - "description": "Whether to use coarse line art detection (faster but less detailed)" + "description": "Whether to use coarse line art detection (faster but less detailed)", }, "anime_style": { "type": "bool", "default": False, - "description": "Whether to use anime-style line art detection" - } + "description": "Whether to use anime-style line art detection", + }, }, - "use_cases": ["Sketch to image", "Line art generation", "Clean line extraction"] + "use_cases": ["Sketch to image", "Line art generation", "Clean line extraction"], } - - def __init__(self, - detect_resolution: int = 512, - image_resolution: int = 512, - coarse: bool = True, - anime_style: bool = False, - **kwargs): + + def __init__( + self, + detect_resolution: int = 512, + image_resolution: int = 512, + coarse: bool = True, + anime_style: bool = False, + **kwargs, + ): """ Initialize Lineart preprocessor - + Args: detect_resolution: Resolution for line art detection image_resolution: Output image resolution @@ -65,34 +72,34 @@ def __init__(self, image_resolution=image_resolution, coarse=coarse, anime_style=anime_style, - **kwargs + **kwargs, ) self._detector = None - + @property def detector(self): """Lazy loading of the line art detector - controlnet_aux only""" if self._detector is None: start_time = time.time() - anime_style = self.params.get('anime_style', False) - + anime_style = self.params.get("anime_style", False) + if anime_style: - self._detector = LineartAnimeDetector.from_pretrained('lllyasviel/Annotators') + self._detector = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") else: - self._detector = LineartDetector.from_pretrained('lllyasviel/Annotators') + self._detector = LineartDetector.from_pretrained("lllyasviel/Annotators") load_time = time.time() - start_time logger.info(f"Lineart detector loaded in {load_time:.3f}s") - + return self._detector - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply line art detection to the input image """ - detect_resolution = self.params.get('detect_resolution', 512) - coarse = self.params.get('coarse', False) + detect_resolution = self.params.get("detect_resolution", 512) + coarse = self.params.get("coarse", False) if image.size != (detect_resolution, detect_resolution): image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) @@ -100,10 +107,7 @@ def _process_core(self, image: Image.Image) -> Image.Image: image_resized = image lineart_image = self.detector( - image_resized, - detect_resolution=detect_resolution, - image_resolution=detect_resolution, - coarse=coarse + image_resized, detect_resolution=detect_resolution, image_resolution=detect_resolution, coarse=coarse ) - return lineart_image \ No newline at end of file + return lineart_image diff --git a/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py b/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py index 7a09a8d31..3d7c72093 100644 --- a/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py +++ b/src/streamdiffusion/preprocessing/processors/mediapipe_pose.py @@ -1,12 +1,16 @@ +from typing import List + +import cv2 import numpy as np import torch -import cv2 -from PIL import Image, ImageDraw -from typing import Union, Optional, List, Tuple, Dict +from PIL import Image + from .base import BasePreprocessor + try: import mediapipe as mp + MEDIAPIPE_AVAILABLE = True except ImportError: MEDIAPIPE_AVAILABLE = False @@ -21,49 +25,87 @@ # 10: RKnee, 11: RAnkle, 12: LHip, 13: LKnee, 14: LAnkle, # 15: REye, 16: LEye, 17: REar, 18: LEar, 19: LBigToe, # 20: LSmallToe, 21: LHeel, 22: RBigToe, 23: RSmallToe, 24: RHeel - - 0: 0, # Nose -> Nose - 1: None, # Neck (calculated from shoulders) + 0: 0, # Nose -> Nose + 1: None, # Neck (calculated from shoulders) 2: 12, # RShoulder -> RightShoulder - 3: 14, # RElbow -> RightElbow + 3: 14, # RElbow -> RightElbow 4: 16, # RWrist -> RightWrist 5: 11, # LShoulder -> LeftShoulder 6: 13, # LElbow -> LeftElbow 7: 15, # LWrist -> LeftWrist - 8: None, # MidHip (calculated from hips) + 8: None, # MidHip (calculated from hips) 9: 24, # RHip -> RightHip - 10: 26, # RKnee -> RightKnee - 11: 28, # RAnkle -> RightAnkle - 12: 23, # LHip -> LeftHip - 13: 25, # LKnee -> LeftKnee - 14: 27, # LAnkle -> LeftAnkle + 10: 26, # RKnee -> RightKnee + 11: 28, # RAnkle -> RightAnkle + 12: 23, # LHip -> LeftHip + 13: 25, # LKnee -> LeftKnee + 14: 27, # LAnkle -> LeftAnkle 15: 5, # REye -> RightEye 16: 2, # LEye -> LeftEye 17: 8, # REar -> RightEar 18: 7, # LEar -> LeftEar - 19: 31, # LBigToe -> LeftFootIndex - 20: 31, # LSmallToe -> LeftFootIndex (approximation) - 21: 29, # LHeel -> LeftHeel - 22: 32, # RBigToe -> RightFootIndex - 23: 32, # RSmallToe -> RightFootIndex (approximation) - 24: 30 # RHeel -> RightHeel + 19: 31, # LBigToe -> LeftFootIndex + 20: 31, # LSmallToe -> LeftFootIndex (approximation) + 21: 29, # LHeel -> LeftHeel + 22: 32, # RBigToe -> RightFootIndex + 23: 32, # RSmallToe -> RightFootIndex (approximation) + 24: 30, # RHeel -> RightHeel } # OpenPose connections for proper skeleton rendering OPENPOSE_LIMB_SEQUENCE = [ - [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], - [1, 8], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], - [13, 14], [1, 0], [0, 15], [15, 17], [0, 16], [16, 18], - [14, 19], [19, 20], [14, 21], [11, 22], [22, 23], [11, 24] + [1, 2], + [1, 5], + [2, 3], + [3, 4], + [5, 6], + [6, 7], + [1, 8], + [8, 9], + [9, 10], + [10, 11], + [8, 12], + [12, 13], + [13, 14], + [1, 0], + [0, 15], + [15, 17], + [0, 16], + [16, 18], + [14, 19], + [19, 20], + [14, 21], + [11, 22], + [22, 23], + [11, 24], ] # Standard OpenPose colors (BGR format) - matching actual OpenPose output OPENPOSE_COLORS = [ - [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], - [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], - [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], - [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0], [255, 85, 0], - [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0] + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], + [0, 255, 0], + [0, 255, 85], + [0, 255, 170], + [0, 255, 255], + [0, 170, 255], + [0, 85, 255], + [0, 0, 255], + [85, 0, 255], + [170, 0, 255], + [255, 0, 255], + [255, 0, 170], + [255, 0, 85], + [255, 0, 0], + [255, 85, 0], + [255, 170, 0], + [255, 255, 0], + [170, 255, 0], + [85, 255, 0], ] # OPTIMIZATION: Vectorized mapping for MediaPipe to OpenPose conversion @@ -76,19 +118,20 @@ OPENPOSE_COLORS_ARRAY = np.array(OPENPOSE_COLORS, dtype=np.uint8) LIMB_SEQUENCE_ARRAY = np.array(OPENPOSE_LIMB_SEQUENCE, dtype=np.int32) + class MediaPipePosePreprocessor(BasePreprocessor): """ MediaPipe-based pose preprocessor for ControlNet that outputs OpenPose-style annotations - + Converts MediaPipe's 33 keypoints to OpenPose's 25 keypoints format and renders them in the standard OpenPose style for ControlNet compatibility. - + Improvements inspired by TouchDesigner MediaPipe plugin: - Better confidence filtering - Temporal smoothing for jitter reduction - Improved multi-pose support preparation """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -100,89 +143,88 @@ def get_preprocessor_metadata(cls): "default": 0.5, "range": [0.0, 1.0], "step": 0.01, - "description": "Minimum confidence for pose detection" + "description": "Minimum confidence for pose detection", }, "min_tracking_confidence": { "type": "float", "default": 0.5, "range": [0.0, 1.0], "step": 0.01, - "description": "Minimum confidence for pose tracking" + "description": "Minimum confidence for pose tracking", }, "model_complexity": { "type": "int", "default": 1, "range": [0, 2], - "description": "MediaPipe model complexity (0=fastest, 2=most accurate)" + "description": "MediaPipe model complexity (0=fastest, 2=most accurate)", }, "static_image_mode": { "type": "bool", "default": False, - "description": "Use static image mode (slower but more accurate per frame)" - }, - "draw_hands": { - "type": "bool", - "default": True, - "description": "Whether to draw hand poses" - }, - "draw_face": { - "type": "bool", - "default": False, - "description": "Whether to draw face landmarks" + "description": "Use static image mode (slower but more accurate per frame)", }, + "draw_hands": {"type": "bool", "default": True, "description": "Whether to draw hand poses"}, + "draw_face": {"type": "bool", "default": False, "description": "Whether to draw face landmarks"}, "line_thickness": { "type": "int", "default": 2, "range": [1, 10], - "description": "Thickness of skeleton lines" + "description": "Thickness of skeleton lines", }, "circle_radius": { "type": "int", "default": 4, "range": [1, 10], - "description": "Radius of joint circles" + "description": "Radius of joint circles", }, "confidence_threshold": { "type": "float", "default": 0.3, "range": [0.0, 1.0], "step": 0.01, - "description": "Minimum confidence for rendering keypoints" + "description": "Minimum confidence for rendering keypoints", }, "enable_smoothing": { "type": "bool", "default": True, - "description": "Enable temporal smoothing to reduce jitter" + "description": "Enable temporal smoothing to reduce jitter", }, "smoothing_factor": { "type": "float", "default": 0.7, "range": [0.0, 1.0], "step": 0.01, - "description": "Smoothing strength (higher = more smoothing)" - } + "description": "Smoothing strength (higher = more smoothing)", + }, }, - "use_cases": ["Detailed pose control", "Hand and face detection", "Real-time pose tracking", "Custom confidence tuning"] + "use_cases": [ + "Detailed pose control", + "Hand and face detection", + "Real-time pose tracking", + "Custom confidence tuning", + ], } - - def __init__(self, - detect_resolution: int = 256, # OPTIMIZATION: Reduced from 512 for 4x speedup - image_resolution: int = 512, - min_detection_confidence: float = 0.5, - min_tracking_confidence: float = 0.5, - model_complexity: int = 1, - static_image_mode: bool = False, # OPTIMIZATION: Video mode for tracking (3-5x faster) - draw_hands: bool = True, - draw_face: bool = False, # Simplified - disable face by default - line_thickness: int = 2, - circle_radius: int = 4, - confidence_threshold: float = 0.3, # TouchDesigner-style confidence filtering - enable_smoothing: bool = True, # TouchDesigner-inspired smoothing - smoothing_factor: float = 0.7, # Smoothing strength - **kwargs): + + def __init__( + self, + detect_resolution: int = 256, # OPTIMIZATION: Reduced from 512 for 4x speedup + image_resolution: int = 512, + min_detection_confidence: float = 0.5, + min_tracking_confidence: float = 0.5, + model_complexity: int = 1, + static_image_mode: bool = False, # OPTIMIZATION: Video mode for tracking (3-5x faster) + draw_hands: bool = True, + draw_face: bool = False, # Simplified - disable face by default + line_thickness: int = 2, + circle_radius: int = 4, + confidence_threshold: float = 0.3, # TouchDesigner-style confidence filtering + enable_smoothing: bool = True, # TouchDesigner-inspired smoothing + smoothing_factor: float = 0.7, # Smoothing strength + **kwargs, + ): """ Initialize MediaPipe pose preprocessor with TouchDesigner-inspired improvements - + Args: detect_resolution: Resolution for pose detection image_resolution: Output image resolution @@ -201,10 +243,9 @@ def __init__(self, """ if not MEDIAPIPE_AVAILABLE: raise ImportError( - "MediaPipe is required for MediaPipe pose preprocessing. " - "Install it with: pip install mediapipe" + "MediaPipe is required for MediaPipe pose preprocessing. Install it with: pip install mediapipe" ) - + super().__init__( detect_resolution=detect_resolution, image_resolution=image_resolution, @@ -219,312 +260,335 @@ def __init__(self, confidence_threshold=confidence_threshold, enable_smoothing=enable_smoothing, smoothing_factor=smoothing_factor, - **kwargs + **kwargs, ) - + self._detector = None self._current_options = None # TouchDesigner-style smoothing buffers self._smoothing_buffers = {} - + @property def detector(self): """Lazy loading of the MediaPipe Holistic detector with GPU optimization""" new_options = { - 'min_detection_confidence': self.params.get('min_detection_confidence', 0.5), - 'min_tracking_confidence': self.params.get('min_tracking_confidence', 0.5), - 'model_complexity': self.params.get('model_complexity', 1), - 'static_image_mode': self.params.get('static_image_mode', False), # Video mode default + "min_detection_confidence": self.params.get("min_detection_confidence", 0.5), + "min_tracking_confidence": self.params.get("min_tracking_confidence", 0.5), + "model_complexity": self.params.get("model_complexity", 1), + "static_image_mode": self.params.get("static_image_mode", False), # Video mode default } - + # Initialize or update detector if needed if self._detector is None or self._current_options != new_options: if self._detector is not None: self._detector.close() - + # OPTIMIZATION: Try GPU delegate first, fallback to CPU try: print("MediaPipePosePreprocessor.detector: Attempting GPU delegate initialization") - + # Try to create base options with GPU delegate try: - base_options = mp.tasks.BaseOptions( - delegate=mp.tasks.BaseOptions.Delegate.GPU - ) + base_options = mp.tasks.BaseOptions(delegate=mp.tasks.BaseOptions.Delegate.GPU) print("MediaPipePosePreprocessor.detector: GPU delegate available") except Exception as gpu_error: print(f"MediaPipePosePreprocessor.detector: GPU delegate failed ({gpu_error}), using CPU") - base_options = mp.tasks.BaseOptions( - delegate=mp.tasks.BaseOptions.Delegate.CPU - ) - + base_options = mp.tasks.BaseOptions(delegate=mp.tasks.BaseOptions.Delegate.CPU) + # Create detector with optimized settings - print(f"MediaPipePosePreprocessor.detector: Initializing MediaPipe Holistic (video_mode={not new_options['static_image_mode']})") + print( + f"MediaPipePosePreprocessor.detector: Initializing MediaPipe Holistic (video_mode={not new_options['static_image_mode']})" + ) self._detector = mp.solutions.holistic.Holistic( - static_image_mode=new_options['static_image_mode'], - model_complexity=new_options['model_complexity'], + static_image_mode=new_options["static_image_mode"], + model_complexity=new_options["model_complexity"], enable_segmentation=False, refine_face_landmarks=False, # Keep simple for speed - min_detection_confidence=new_options['min_detection_confidence'], - min_tracking_confidence=new_options['min_tracking_confidence'], + min_detection_confidence=new_options["min_detection_confidence"], + min_tracking_confidence=new_options["min_tracking_confidence"], ) - + except Exception as e: print(f"MediaPipePosePreprocessor.detector: Advanced options failed ({e}), using basic setup") # Fallback to basic setup self._detector = mp.solutions.holistic.Holistic( - static_image_mode=new_options['static_image_mode'], - model_complexity=new_options['model_complexity'], + static_image_mode=new_options["static_image_mode"], + model_complexity=new_options["model_complexity"], enable_segmentation=False, refine_face_landmarks=False, - min_detection_confidence=new_options['min_detection_confidence'], - min_tracking_confidence=new_options['min_tracking_confidence'], + min_detection_confidence=new_options["min_detection_confidence"], + min_tracking_confidence=new_options["min_tracking_confidence"], ) - + self._current_options = new_options - + return self._detector - + def _apply_smoothing(self, keypoints: List[List[float]], pose_id: str = "default") -> List[List[float]]: """ Apply TouchDesigner-inspired temporal smoothing - VECTORIZED - + Args: keypoints: Current frame keypoints pose_id: Unique identifier for this pose - + Returns: Smoothed keypoints """ - if not self.params.get('enable_smoothing', True) or not keypoints: + if not self.params.get("enable_smoothing", True) or not keypoints: return keypoints - - smoothing_factor = self.params.get('smoothing_factor', 0.7) - + + smoothing_factor = self.params.get("smoothing_factor", 0.7) + # Initialize buffer for this pose if needed if pose_id not in self._smoothing_buffers: self._smoothing_buffers[pose_id] = keypoints.copy() return keypoints - + # OPTIMIZATION: Vectorized exponential smoothing current_array = np.array(keypoints, dtype=np.float32) previous_array = np.array(self._smoothing_buffers[pose_id], dtype=np.float32) - + # Create confidence mask for selective smoothing confidence_mask = current_array[:, 2] > 0.1 - + # Vectorized smoothing calculation smoothed_array = previous_array.copy() # Apply smoothing only where confidence is good - smoothed_array[confidence_mask, :2] = ( - previous_array[confidence_mask, :2] * smoothing_factor + - current_array[confidence_mask, :2] * (1 - smoothing_factor) - ) + smoothed_array[confidence_mask, :2] = previous_array[confidence_mask, :2] * smoothing_factor + current_array[ + confidence_mask, :2 + ] * (1 - smoothing_factor) # Always use current confidence values smoothed_array[:, 2] = current_array[:, 2] - + # Update buffer and return smoothed_list = smoothed_array.tolist() self._smoothing_buffers[pose_id] = smoothed_list return smoothed_list - - def _mediapipe_to_openpose(self, mediapipe_landmarks: List, image_width: int, image_height: int) -> List[List[float]]: + + def _mediapipe_to_openpose( + self, mediapipe_landmarks: List, image_width: int, image_height: int + ) -> List[List[float]]: """ Convert MediaPipe landmarks to OpenPose format - VECTORIZED - + Args: mediapipe_landmarks: MediaPipe pose landmarks image_width: Image width image_height: Image height - + Returns: OpenPose keypoints in [x, y, confidence] format """ if not mediapipe_landmarks: return [] - + # OPTIMIZATION: Vectorized landmark conversion # Extract all coordinates and confidences in one go - landmarks_data = np.array([ - [lm.x * image_width, lm.y * image_height, - lm.visibility if hasattr(lm, 'visibility') else 1.0] - for lm in mediapipe_landmarks - ], dtype=np.float32) - + landmarks_data = np.array( + [ + [lm.x * image_width, lm.y * image_height, lm.visibility if hasattr(lm, "visibility") else 1.0] + for lm in mediapipe_landmarks + ], + dtype=np.float32, + ) + # Initialize OpenPose keypoints array (25 points x 3 values) openpose_keypoints = np.zeros((25, 3), dtype=np.float32) - + # OPTIMIZATION: Vectorized mapping using advanced indexing # Only map valid indices that exist in landmarks_data valid_mask = MEDIAPIPE_INDICES < len(landmarks_data) valid_mp_indices = MEDIAPIPE_INDICES[valid_mask] valid_op_indices = OPENPOSE_INDICES[valid_mask] - + # Vectorized assignment openpose_keypoints[valid_op_indices] = landmarks_data[valid_mp_indices] - + # OPTIMIZATION: Vectorized derived point calculations - confidence_threshold = self.params.get('confidence_threshold', 0.3) - + confidence_threshold = self.params.get("confidence_threshold", 0.3) + # Neck (1): midpoint between shoulders (indices 11, 12) - if (len(landmarks_data) > 12 and - landmarks_data[11, 2] > confidence_threshold and - landmarks_data[12, 2] > confidence_threshold): + if ( + len(landmarks_data) > 12 + and landmarks_data[11, 2] > confidence_threshold + and landmarks_data[12, 2] > confidence_threshold + ): # Vectorized midpoint calculation neck_point = np.mean(landmarks_data[[11, 12]], axis=0) neck_point[2] = np.min(landmarks_data[[11, 12], 2]) # Min confidence openpose_keypoints[1] = neck_point - + # MidHip (8): midpoint between hips (indices 23, 24) - if (len(landmarks_data) > 24 and - landmarks_data[23, 2] > confidence_threshold and - landmarks_data[24, 2] > confidence_threshold): + if ( + len(landmarks_data) > 24 + and landmarks_data[23, 2] > confidence_threshold + and landmarks_data[24, 2] > confidence_threshold + ): # Vectorized midpoint calculation midhip_point = np.mean(landmarks_data[[23, 24]], axis=0) midhip_point[2] = np.min(landmarks_data[[23, 24], 2]) # Min confidence openpose_keypoints[8] = midhip_point - + # Convert back to list format for compatibility return openpose_keypoints.tolist() - + def _draw_openpose_skeleton(self, image: np.ndarray, keypoints: List[List[float]]) -> np.ndarray: """ Draw OpenPose-style skeleton on image - + Args: image: Input image keypoints: OpenPose keypoints - + Returns: Image with skeleton drawn """ if not keypoints or len(keypoints) != 25: return image - + h, w = image.shape[:2] - line_thickness = self.params.get('line_thickness', 2) - circle_radius = self.params.get('circle_radius', 4) - confidence_threshold = self.params.get('confidence_threshold', 0.3) - + line_thickness = self.params.get("line_thickness", 2) + circle_radius = self.params.get("circle_radius", 4) + confidence_threshold = self.params.get("confidence_threshold", 0.3) + # OPTIMIZATION: Vectorized limb drawing with confidence filtering keypoints_array = np.array(keypoints, dtype=np.float32) - + # Draw limbs for i, (start_idx, end_idx) in enumerate(LIMB_SEQUENCE_ARRAY): - if (start_idx < len(keypoints_array) and end_idx < len(keypoints_array) and - keypoints_array[start_idx, 2] > confidence_threshold and keypoints_array[end_idx, 2] > confidence_threshold): - + if ( + start_idx < len(keypoints_array) + and end_idx < len(keypoints_array) + and keypoints_array[start_idx, 2] > confidence_threshold + and keypoints_array[end_idx, 2] > confidence_threshold + ): start_point = (int(keypoints_array[start_idx, 0]), int(keypoints_array[start_idx, 1])) end_point = (int(keypoints_array[end_idx, 0]), int(keypoints_array[end_idx, 1])) - + # Use vectorized color array color = OPENPOSE_COLORS_ARRAY[i % len(OPENPOSE_COLORS_ARRAY)].tolist() - + cv2.line(image, start_point, end_point, color, line_thickness) - + # OPTIMIZATION: Vectorized keypoint drawing with confidence filtering confidence_mask = keypoints_array[:, 2] > confidence_threshold valid_indices = np.where(confidence_mask)[0] - + for i in valid_indices: center = (int(keypoints_array[i, 0]), int(keypoints_array[i, 1])) color = OPENPOSE_COLORS_ARRAY[i % len(OPENPOSE_COLORS_ARRAY)].tolist() cv2.circle(image, center, circle_radius, color, -1) - + return image - + def _draw_hand_keypoints(self, image: np.ndarray, hand_landmarks: List, is_left_hand: bool = True) -> np.ndarray: """ Draw hand keypoints in OpenPose style - FIXED coordinate mapping - + Args: image: Input image hand_landmarks: MediaPipe hand landmarks is_left_hand: Whether this is the left hand - + Returns: Image with hand keypoints drawn """ if not hand_landmarks: return image - + h, w = image.shape[:2] - confidence_threshold = self.params.get('confidence_threshold', 0.3) - + confidence_threshold = self.params.get("confidence_threshold", 0.3) + # Standard hand connections (21 landmarks per hand) hand_connections = [ # Thumb - (0, 1), (1, 2), (2, 3), (3, 4), - # Index finger - (0, 5), (5, 6), (6, 7), (7, 8), + (0, 1), + (1, 2), + (2, 3), + (3, 4), + # Index finger + (0, 5), + (5, 6), + (6, 7), + (7, 8), # Middle finger - (0, 9), (9, 10), (10, 11), (11, 12), + (0, 9), + (9, 10), + (10, 11), + (11, 12), # Ring finger - (0, 13), (13, 14), (14, 15), (15, 16), + (0, 13), + (13, 14), + (14, 15), + (15, 16), # Pinky - (0, 17), (17, 18), (18, 19), (19, 20), + (0, 17), + (17, 18), + (18, 19), + (19, 20), # Palm connections - (5, 9), (9, 13), (13, 17), + (5, 9), + (9, 13), + (13, 17), ] - + # OPTIMIZATION: Vectorized hand coordinate conversion landmarks_array = np.array([[lm.x * w, lm.y * h] for lm in hand_landmarks], dtype=np.int32) hand_points = [(int(pt[0]), int(pt[1])) for pt in landmarks_array] - + # Standard hand colors hand_color = [255, 128, 0] if is_left_hand else [0, 255, 255] # Orange for left, cyan for right - + # Draw connections for start_idx, end_idx in hand_connections: if start_idx < len(hand_points) and end_idx < len(hand_points): cv2.line(image, hand_points[start_idx], hand_points[end_idx], hand_color, 2) - + # Draw keypoints for point in hand_points: cv2.circle(image, point, 3, hand_color, -1) - + return image - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply MediaPipe pose detection and create OpenPose-style annotation """ - detect_resolution = self.params.get('detect_resolution', 512) + detect_resolution = self.params.get("detect_resolution", 512) image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) - + rgb_image = cv2.cvtColor(np.array(image_resized), cv2.COLOR_BGR2RGB) - + results = self.detector.process(rgb_image) - + pose_image = np.zeros((detect_resolution, detect_resolution, 3), dtype=np.uint8) - + if results.pose_landmarks: openpose_keypoints = self._mediapipe_to_openpose( - results.pose_landmarks.landmark, - detect_resolution, - detect_resolution + results.pose_landmarks.landmark, detect_resolution, detect_resolution ) - + openpose_keypoints = self._apply_smoothing(openpose_keypoints, "main_pose") - + pose_image = self._draw_openpose_skeleton(pose_image, openpose_keypoints) - - draw_hands = self.params.get('draw_hands', True) + + draw_hands = self.params.get("draw_hands", True) if draw_hands: if results.left_hand_landmarks: pose_image = self._draw_hand_keypoints( pose_image, results.left_hand_landmarks.landmark, is_left_hand=True ) - + if results.right_hand_landmarks: pose_image = self._draw_hand_keypoints( pose_image, results.right_hand_landmarks.landmark, is_left_hand=False ) - + pose_pil = Image.fromarray(cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB)) - + return pose_pil - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid unnecessary CPU transfers @@ -532,23 +596,23 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: pil_image = self.tensor_to_pil(image_tensor) processed_pil = self._process_core(pil_image) return self.pil_to_tensor(processed_pil) - + def reset_smoothing_buffers(self): """Reset smoothing buffers (useful for new sequences)""" print("MediaPipePosePreprocessor.reset_smoothing_buffers: Clearing smoothing buffers") self._smoothing_buffers.clear() - + def reset_tracking(self): """Reset MediaPipe tracking for new video sequences (when using video mode)""" print("MediaPipePosePreprocessor.reset_tracking: Resetting MediaPipe tracking state") - if hasattr(self, '_detector') and self._detector is not None: + if hasattr(self, "_detector") and self._detector is not None: # Force detector recreation to reset tracking state self._detector.close() self._detector = None self._current_options = None self.reset_smoothing_buffers() - + def __del__(self): """Cleanup MediaPipe detector""" - if hasattr(self, '_detector') and self._detector is not None: - self._detector.close() \ No newline at end of file + if hasattr(self, "_detector") and self._detector is not None: + self._detector.close() diff --git a/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py b/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py index 004b250c5..0f4893f23 100644 --- a/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py +++ b/src/streamdiffusion/preprocessing/processors/mediapipe_segmentation.py @@ -1,12 +1,16 @@ +from typing import Tuple + +import cv2 import numpy as np import torch -import cv2 from PIL import Image -from typing import Union, Optional, List, Tuple + from .base import BasePreprocessor + try: import mediapipe as mp + MEDIAPIPE_AVAILABLE = True except ImportError: MEDIAPIPE_AVAILABLE = False @@ -15,11 +19,11 @@ class MediaPipeSegmentationPreprocessor(BasePreprocessor): """ MediaPipe-based segmentation preprocessor for ControlNet - + Uses MediaPipe's Selfie Segmentation model to create accurate person segmentation masks. Outputs binary masks suitable for ControlNet conditioning. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -30,43 +34,50 @@ def get_preprocessor_metadata(cls): "type": "int", "default": 1, "range": [0, 1], - "description": "Model type (0=general/faster, 1=landscape/better quality)" + "description": "Model type (0=general/faster, 1=landscape/better quality)", }, "threshold": { "type": "float", "default": 0.5, "range": [0.0, 1.0], "step": 0.01, - "description": "Confidence threshold for segmentation" + "description": "Confidence threshold for segmentation", }, "blur_radius": { "type": "int", "default": 0, "range": [0, 20], - "description": "Blur radius for mask smoothing (0=no blur)" + "description": "Blur radius for mask smoothing (0=no blur)", }, "invert_mask": { "type": "bool", "default": False, - "description": "Whether to invert the segmentation mask" - } + "description": "Whether to invert the segmentation mask", + }, }, - "use_cases": ["Precise object control", "Background replacement", "Person segmentation", "Mask generation"] + "use_cases": [ + "Precise object control", + "Background replacement", + "Person segmentation", + "Mask generation", + ], } - - def __init__(self, - detect_resolution: int = 512, - image_resolution: int = 512, - model_selection: int = 1, # 0 for general model, 1 for landscape model - threshold: float = 0.5, - blur_radius: int = 0, - invert_mask: bool = False, - output_mode: str = "binary", # "binary", "alpha", "background" - background_color: Tuple[int, int, int] = (0, 0, 0), - **kwargs): + + def __init__( + self, + detect_resolution: int = 512, + image_resolution: int = 512, + model_selection: int = 1, # 0 for general model, 1 for landscape model + threshold: float = 0.5, + blur_radius: int = 0, + invert_mask: bool = False, + output_mode: str = "binary", # "binary", "alpha", "background" + background_color: Tuple[int, int, int] = (0, 0, 0), + **kwargs, + ): """ Initialize MediaPipe segmentation preprocessor - + Args: detect_resolution: Resolution for segmentation processing image_resolution: Output image resolution @@ -83,7 +94,7 @@ def __init__(self, "MediaPipe is required for MediaPipe segmentation preprocessing. " "Install it with: pip install mediapipe" ) - + super().__init__( detect_resolution=detect_resolution, image_resolution=image_resolution, @@ -93,145 +104,145 @@ def __init__(self, invert_mask=invert_mask, output_mode=output_mode, background_color=background_color, - **kwargs + **kwargs, ) - + self._segmentor = None self._current_options = None - + @property def segmentor(self): """Lazy loading of the MediaPipe Selfie Segmentation model""" new_options = { - 'model_selection': self.params.get('model_selection', 1), + "model_selection": self.params.get("model_selection", 1), } - + # Initialize or update segmentor if needed if self._segmentor is None or self._current_options != new_options: if self._segmentor is not None: self._segmentor.close() - - print(f"MediaPipeSegmentationPreprocessor.segmentor: Initializing MediaPipe Selfie Segmentation model") + + print("MediaPipeSegmentationPreprocessor.segmentor: Initializing MediaPipe Selfie Segmentation model") self._segmentor = mp.solutions.selfie_segmentation.SelfieSegmentation( - model_selection=new_options['model_selection'] + model_selection=new_options["model_selection"] ) self._current_options = new_options - + return self._segmentor - + def _apply_mask_smoothing(self, mask: np.ndarray) -> np.ndarray: """ Apply smoothing to the segmentation mask - + Args: mask: Input segmentation mask - + Returns: Smoothed mask """ - blur_radius = self.params.get('blur_radius', 0) - + blur_radius = self.params.get("blur_radius", 0) + if blur_radius > 0: # Apply Gaussian blur for smoother edges kernel_size = blur_radius * 2 + 1 mask = cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0) - + return mask - + def _threshold_mask(self, mask: np.ndarray) -> np.ndarray: """ Apply threshold to segmentation mask - + Args: mask: Input segmentation mask (0.0-1.0) - + Returns: Binary mask """ - threshold = self.params.get('threshold', 0.5) - invert_mask = self.params.get('invert_mask', False) - + threshold = self.params.get("threshold", 0.5) + invert_mask = self.params.get("invert_mask", False) + # Apply threshold binary_mask = (mask > threshold).astype(np.uint8) - + # Invert if requested if invert_mask: binary_mask = 1 - binary_mask - + return binary_mask - + def _create_output_image(self, original_image: np.ndarray, mask: np.ndarray) -> np.ndarray: """ Create final output image based on output mode - + Args: original_image: Original input image mask: Segmentation mask - + Returns: Output image """ - output_mode = self.params.get('output_mode', 'binary') - - if output_mode == 'binary': + output_mode = self.params.get("output_mode", "binary") + + if output_mode == "binary": # Create binary black/white mask binary_mask = self._threshold_mask(mask) output = np.stack([binary_mask * 255] * 3, axis=-1) - - elif output_mode == 'alpha': + + elif output_mode == "alpha": # Create RGBA output with alpha channel if len(original_image.shape) == 3: alpha = (mask * 255).astype(np.uint8) output = np.concatenate([original_image, alpha[..., np.newaxis]], axis=-1) else: output = original_image - - elif output_mode == 'background': + + elif output_mode == "background": # Replace background with solid color - background_color = self.params.get('background_color', (0, 0, 0)) + background_color = self.params.get("background_color", (0, 0, 0)) binary_mask = self._threshold_mask(mask) - + output = original_image.copy() # Apply background where mask is 0 for i in range(3): output[..., i] = np.where(binary_mask, output[..., i], background_color[i]) - + else: raise ValueError(f"Unknown output_mode: {output_mode}") - + return output - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply MediaPipe segmentation to the input image """ - detect_resolution = self.params.get('detect_resolution', 512) + detect_resolution = self.params.get("detect_resolution", 512) image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) - + rgb_image = cv2.cvtColor(np.array(image_resized), cv2.COLOR_BGR2RGB) - + results = self.segmentor.process(rgb_image) - + if results.segmentation_mask is not None: mask = results.segmentation_mask - + mask = self._apply_mask_smoothing(mask) - + output_image = self._create_output_image(rgb_image, mask) else: - output_mode = self.params.get('output_mode', 'binary') - if output_mode == 'binary': + output_mode = self.params.get("output_mode", "binary") + if output_mode == "binary": output_image = np.zeros((detect_resolution, detect_resolution, 3), dtype=np.uint8) else: output_image = rgb_image - + if output_image.shape[-1] == 4: - result_pil = Image.fromarray(output_image, 'RGBA') + result_pil = Image.fromarray(output_image, "RGBA") else: - result_pil = Image.fromarray(output_image, 'RGB') - + result_pil = Image.fromarray(output_image, "RGB") + return result_pil - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid unnecessary CPU transfers @@ -239,8 +250,8 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: pil_image = self.tensor_to_pil(image_tensor) processed_pil = self._process_core(pil_image) return self.pil_to_tensor(processed_pil) - + def __del__(self): """Cleanup MediaPipe segmentor""" - if hasattr(self, '_segmentor') and self._segmentor is not None: - self._segmentor.close() \ No newline at end of file + if hasattr(self, "_segmentor") and self._segmentor is not None: + self._segmentor.close() diff --git a/src/streamdiffusion/preprocessing/processors/openpose.py b/src/streamdiffusion/preprocessing/processors/openpose.py index a7ba15498..53ac8afe2 100644 --- a/src/streamdiffusion/preprocessing/processors/openpose.py +++ b/src/streamdiffusion/preprocessing/processors/openpose.py @@ -1,16 +1,18 @@ -import numpy as np from PIL import Image, ImageDraw -from typing import Union, Optional, List, Tuple + from .base import BasePreprocessor + try: import cv2 + OPENCV_AVAILABLE = True except ImportError: OPENCV_AVAILABLE = False try: from controlnet_aux import OpenposeDetector + CONTROLNET_AUX_AVAILABLE = True except ImportError: CONTROLNET_AUX_AVAILABLE = False @@ -19,10 +21,10 @@ class OpenPosePreprocessor(BasePreprocessor): """ OpenPose human pose detection preprocessor for ControlNet - + Detects human poses and creates stick figure representations. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -32,26 +34,28 @@ def get_preprocessor_metadata(cls): "include_hands": { "type": "bool", "default": False, - "description": "Whether to include hand keypoints in detection" + "description": "Whether to include hand keypoints in detection", }, "include_face": { "type": "bool", "default": False, - "description": "Whether to include face keypoints in detection" - } + "description": "Whether to include face keypoints in detection", + }, }, - "use_cases": ["Human pose control", "Dance movements", "Character poses"] + "use_cases": ["Human pose control", "Dance movements", "Character poses"], } - - def __init__(self, - detect_resolution: int = 512, - image_resolution: int = 512, - include_hands: bool = False, - include_face: bool = False, - **kwargs): + + def __init__( + self, + detect_resolution: int = 512, + image_resolution: int = 512, + include_hands: bool = False, + include_face: bool = False, + **kwargs, + ): """ Initialize OpenPose preprocessor - + Args: detect_resolution: Resolution for pose detection image_resolution: Output image resolution @@ -64,81 +68,93 @@ def __init__(self, image_resolution=image_resolution, include_hands=include_hands, include_face=include_face, - **kwargs + **kwargs, ) - + self._detector = None - + @property def detector(self): """Lazy loading of the OpenPose detector""" if self._detector is None: if CONTROLNET_AUX_AVAILABLE: print("Loading OpenPose detector from controlnet_aux") - self._detector = OpenposeDetector.from_pretrained('lllyasviel/Annotators') + self._detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators") else: print("Warning: controlnet_aux not available, using fallback OpenPose implementation") self._detector = self._create_fallback_detector() return self._detector - + def _create_fallback_detector(self): """Create a simple fallback detector if controlnet_aux is not available""" + class FallbackDetector: def __call__(self, image, include_hands=False, include_face=False): # Simple fallback: return a blank image with some basic pose lines width, height = image.size - pose_image = Image.new('RGB', (width, height), (0, 0, 0)) + pose_image = Image.new("RGB", (width, height), (0, 0, 0)) draw = ImageDraw.Draw(pose_image) - + # Draw a basic stick figure in the center center_x, center_y = width // 2, height // 2 - + # Head head_radius = min(width, height) // 20 - draw.ellipse([ - center_x - head_radius, center_y - height // 4 - head_radius, - center_x + head_radius, center_y - height // 4 + head_radius - ], outline=(255, 255, 255), width=2) - + draw.ellipse( + [ + center_x - head_radius, + center_y - height // 4 - head_radius, + center_x + head_radius, + center_y - height // 4 + head_radius, + ], + outline=(255, 255, 255), + width=2, + ) + # Body body_top = center_y - height // 4 + head_radius body_bottom = center_y + height // 6 draw.line([center_x, body_top, center_x, body_bottom], fill=(255, 255, 255), width=2) - + # Arms arm_length = width // 6 arm_y = body_top + (body_bottom - body_top) // 3 draw.line([center_x - arm_length, arm_y, center_x + arm_length, arm_y], fill=(255, 255, 255), width=2) - + # Legs leg_length = height // 8 - draw.line([center_x, body_bottom, center_x - leg_length//2, body_bottom + leg_length], fill=(255, 255, 255), width=2) - draw.line([center_x, body_bottom, center_x + leg_length//2, body_bottom + leg_length], fill=(255, 255, 255), width=2) - + draw.line( + [center_x, body_bottom, center_x - leg_length // 2, body_bottom + leg_length], + fill=(255, 255, 255), + width=2, + ) + draw.line( + [center_x, body_bottom, center_x + leg_length // 2, body_bottom + leg_length], + fill=(255, 255, 255), + width=2, + ) + return pose_image - + return FallbackDetector() - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply OpenPose detection to the input image """ - detect_resolution = self.params.get('detect_resolution', 512) + detect_resolution = self.params.get("detect_resolution", 512) image_resized = image.resize((detect_resolution, detect_resolution), Image.LANCZOS) - - include_hands = self.params.get('include_hands', False) - include_face = self.params.get('include_face', False) - - if CONTROLNET_AUX_AVAILABLE and hasattr(self.detector, '__call__'): + + include_hands = self.params.get("include_hands", False) + include_face = self.params.get("include_face", False) + + if CONTROLNET_AUX_AVAILABLE and hasattr(self.detector, "__call__"): try: - pose_image = self.detector( - image_resized, - hand_and_face=include_hands or include_face - ) + pose_image = self.detector(image_resized, hand_and_face=include_hands or include_face) except Exception as e: print(f"Warning: OpenPose detection failed, using fallback: {e}") pose_image = self._create_fallback_detector()(image_resized, include_hands, include_face) else: pose_image = self.detector(image_resized, include_hands, include_face) - - return pose_image \ No newline at end of file + + return pose_image diff --git a/src/streamdiffusion/preprocessing/processors/passthrough.py b/src/streamdiffusion/preprocessing/processors/passthrough.py index e4d1125fe..ea81de6e4 100644 --- a/src/streamdiffusion/preprocessing/processors/passthrough.py +++ b/src/streamdiffusion/preprocessing/processors/passthrough.py @@ -1,55 +1,47 @@ -import numpy as np -from PIL import Image import torch -from typing import Union, Optional +from PIL import Image + from .base import BasePreprocessor class PassthroughPreprocessor(BasePreprocessor): """ Passthrough preprocessor for ControlNet - + Simply passes the input image through without any processing. Useful for ControlNets that expect the raw input image, such as: - Tile ControlNet - Reference ControlNet - Custom ControlNets that don't need preprocessing """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Passthrough", "description": "Passes the input image through with minimal processing. Used for tile ControlNet or when you want to use the input image directly.", - "parameters": { - - }, - "use_cases": ["Tile ControlNet", "Image-to-image with structure preservation", "Upscaling with control"] + "parameters": {}, + "use_cases": ["Tile ControlNet", "Image-to-image with structure preservation", "Upscaling with control"], } - - def __init__(self, - image_resolution: int = 512, - **kwargs): + + def __init__(self, image_resolution: int = 512, **kwargs): """ Initialize passthrough preprocessor - + Args: image_resolution: Output image resolution **kwargs: Additional parameters (ignored for passthrough) """ - super().__init__( - image_resolution=image_resolution, - **kwargs - ) - + super().__init__(image_resolution=image_resolution, **kwargs) + def _process_core(self, image: Image.Image) -> Image.Image: """ Pass through the input image with no processing """ return image - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Pass through tensor with no processing """ - return tensor \ No newline at end of file + return tensor diff --git a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py index 7662c37cc..58bea9479 100644 --- a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py @@ -1,19 +1,23 @@ -#NOTE: ported from https://github.com/yuvraj108c/ComfyUI-YoloNasPose-Tensorrt +# NOTE: ported from https://github.com/yuvraj108c/ComfyUI-YoloNasPose-Tensorrt import os + +import cv2 import numpy as np import torch import torch.nn.functional as F -import cv2 from PIL import Image -from typing import Union, Optional, List, Tuple + from .base import BasePreprocessor + try: + from collections import OrderedDict + import tensorrt as trt from polygraphy.backend.common import bytes_from_path from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False @@ -40,7 +44,7 @@ class TensorRTEngine: """Simplified TensorRT engine wrapper for pose estimation inference (optimized)""" - + def __init__(self, engine_path): self.engine_path = engine_path self.engine = None @@ -64,13 +68,11 @@ def allocate_buffers(self, device="cuda"): name = self.engine.get_tensor_name(idx) shape = self.context.get_tensor_shape(name) dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) + + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) self.tensors[name] = tensor def infer(self, feed_dict, stream=None): @@ -78,7 +80,7 @@ def infer(self, feed_dict, stream=None): # Use cached stream if none provided if stream is None: stream = self._cuda_stream - + # Copy input data to tensors for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) @@ -86,23 +88,23 @@ def infer(self, feed_dict, stream=None): # Set tensor addresses for name, tensor in self.tensors.items(): self.context.set_tensor_address(name, tensor.data_ptr()) - + # Execute inference success = self.context.execute_async_v3(stream) if not success: raise ValueError("TensorRT inference failed.") - + return self.tensors class PoseVisualization: """Pose drawing utilities ported from ComfyUI YoloNasPose node""" - + @staticmethod def draw_skeleton(image, keypoints, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10): """Draw pose skeleton on image""" overlay = image.copy() - + # Draw edges/links between keypoints for (kp1, kp2), color in zip(edge_links, edge_colors): if kp1 < len(keypoints) and kp2 < len(keypoints): @@ -113,32 +115,32 @@ def draw_skeleton(image, keypoints, edge_links, edge_colors, joint_thickness=10, p1 = (int(keypoints[kp1][0]), int(keypoints[kp1][1])) p2 = (int(keypoints[kp2][0]), int(keypoints[kp2][1])) cv2.line(overlay, p1, p2, color=color, thickness=joint_thickness, lineType=cv2.LINE_AA) - + # Draw keypoints for keypoint in keypoints: if len(keypoint) >= 3 and keypoint[2] > 0.5: # confidence threshold x, y = int(keypoint[0]), int(keypoint[1]) cv2.circle(overlay, (x, y), keypoint_radius, (0, 255, 0), -1, cv2.LINE_AA) - + return cv2.addWeighted(overlay, 0.75, image, 0.25, 0) @staticmethod def draw_poses(image, poses, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10): """Draw multiple poses on image""" result = image.copy() - + for pose in poses: result = PoseVisualization.draw_skeleton( result, pose, edge_links, edge_colors, joint_thickness, keypoint_radius ) - + return result def iterate_over_batch_predictions(predictions, batch_size): """Process batch predictions from TensorRT output""" num_detections, batch_boxes, batch_scores, batch_joints = predictions - + for image_index in range(batch_size): num_detection_in_image = int(num_detections[image_index, 0]) @@ -150,35 +152,62 @@ def iterate_over_batch_predictions(predictions, batch_size): else: pred_scores = batch_scores[image_index, :num_detection_in_image] pred_boxes = batch_boxes[image_index, :num_detection_in_image] - pred_joints = batch_joints[image_index, :num_detection_in_image].reshape( - (num_detection_in_image, -1, 3)) + pred_joints = batch_joints[image_index, :num_detection_in_image].reshape((num_detection_in_image, -1, 3)) yield image_index, pred_boxes, pred_scores, pred_joints + # precompute edge links define skeleton connections (COCO format) -edge_links = [[0, 17], [13, 15], [14, 16], [12, 14], [12, 17], [5, 6], - [11, 13], [7, 9], [5, 7], [17, 11], [6, 8], [8, 10], - [1, 3], [0, 1], [0, 2], [2, 4]] +edge_links = [ + [0, 17], + [13, 15], + [14, 16], + [12, 14], + [12, 17], + [5, 6], + [11, 13], + [7, 9], + [5, 7], + [17, 11], + [6, 8], + [8, 10], + [1, 3], + [0, 1], + [0, 2], + [2, 4], +] edge_colors = [ - [255, 0, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [85, 255, 0], - [85, 0, 255], [255, 170, 0], [0, 177, 58], [0, 179, 119], [179, 179, 0], - [0, 119, 179], [0, 179, 179], [119, 0, 179], [179, 0, 179], [178, 0, 118], [178, 0, 118] + [255, 0, 0], + [255, 85, 0], + [170, 255, 0], + [85, 255, 0], + [85, 255, 0], + [85, 0, 255], + [255, 170, 0], + [0, 177, 58], + [0, 179, 119], + [179, 179, 0], + [0, 119, 179], + [0, 179, 179], + [119, 0, 179], + [179, 0, 179], + [178, 0, 118], + [178, 0, 118], ] + + def show_predictions_from_batch_format(predictions): """Convert predictions to pose visualization format""" try: - image_index, pred_boxes, pred_scores, pred_joints = next( - iter(iterate_over_batch_predictions(predictions, 1))) + image_index, pred_boxes, pred_scores, pred_joints = next(iter(iterate_over_batch_predictions(predictions, 1))) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error in iterate_over_batch_predictions: {e}") - - # Handle case where no poses are detected if pred_joints.shape[0] == 0: return np.zeros((640, 640, 3)) - + # Add middle joint between shoulders (keypoints 5 and 6) try: # Calculate middle joints for all poses at once @@ -187,49 +216,50 @@ def show_predictions_from_batch_format(predictions): new_pred_joints = np.concatenate([pred_joints, middle_joints[:, np.newaxis]], axis=1) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error processing poses: {e}") - + # Create black background for pose visualization black_image = np.zeros((640, 640, 3)) - + try: image = PoseVisualization.draw_poses( - image=black_image, - poses=new_pred_joints, - edge_links=edge_links, - edge_colors=edge_colors, + image=black_image, + poses=new_pred_joints, + edge_links=edge_links, + edge_colors=edge_colors, joint_thickness=10, - keypoint_radius=10 + keypoint_radius=10, ) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error in pose drawing: {e}") - + return image class YoloNasPoseTensorrtPreprocessor(BasePreprocessor): """ YoloNas Pose TensorRT preprocessor for ControlNet - + Uses TensorRT-optimized YoloNas Pose model for fast pose estimation. """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Pose Detection (TensorRT)", "description": "Fast TensorRT-optimized pose detection using YOLO-NAS Pose model. Detects human pose keypoints with high performance.", "parameters": {}, - "use_cases": ["Human pose control", "Character animation", "Pose-guided generation", "Real-time pose detection"] + "use_cases": [ + "Human pose control", + "Character animation", + "Pose-guided generation", + "Real-time pose detection", + ], } - - def __init__(self, - engine_path: str = None, - detect_resolution: int = 640, - image_resolution: int = 512, - **kwargs): + + def __init__(self, engine_path: str = None, detect_resolution: int = 640, image_resolution: int = 512, **kwargs): """ Initialize TensorRT pose preprocessor - + Args: engine_path: Path to TensorRT engine file detect_resolution: Resolution for pose detection (should match engine input) @@ -241,78 +271,72 @@ def __init__(self, "TensorRT and polygraphy libraries are required for TensorRT pose preprocessing. " "Install them with: pip install tensorrt polygraphy" ) - + super().__init__( - engine_path=engine_path, - detect_resolution=detect_resolution, - image_resolution=image_resolution, - **kwargs + engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs ) - + self._engine = None self._device = "cuda" if torch.cuda.is_available() else "cpu" self._is_cuda_available = torch.cuda.is_available() - + @property def engine(self): """Lazy loading of the TensorRT engine""" if self._engine is None: - engine_path = self.params.get('engine_path') + engine_path = self.params.get("engine_path") if engine_path is None: raise ValueError( "engine_path is required for TensorRT pose preprocessing. " "Please provide it in the preprocessor_params config." ) - + if not os.path.exists(engine_path): raise FileNotFoundError(f"TensorRT engine not found: {engine_path}") - + self._engine = TensorRTEngine(engine_path) self._engine.load() self._engine.activate() self._engine.allocate_buffers() - + return self._engine - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply TensorRT pose estimation to the input image """ - detect_resolution = self.params.get('detect_resolution', 640) - + detect_resolution = self.params.get("detect_resolution", 640) + image_tensor = torch.from_numpy(np.array(image)).float() / 255.0 image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) - + image_resized = F.interpolate( - image_tensor, - size=(detect_resolution, detect_resolution), - mode='bilinear', - align_corners=False + image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False ) - + image_resized_uint8 = (image_resized * 255.0).type(torch.uint8) - + if self._is_cuda_available: image_resized_uint8 = image_resized_uint8.cuda() - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized_uint8}, cuda_stream) - - predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input'] - + + predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"] + try: pose_image = show_predictions_from_batch_format(predictions) except Exception: # Fallback to black image on error pose_image = np.zeros((detect_resolution, detect_resolution, 3)) - + pose_image = pose_image.clip(0, 255).astype(np.uint8) pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB) - + result = Image.fromarray(pose_image) - + return result - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid CPU transfers @@ -321,31 +345,30 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: image_tensor = image_tensor.unsqueeze(0) if not image_tensor.is_cuda: image_tensor = image_tensor.cuda() - - detect_resolution = self.params.get('detect_resolution', 640) - + + detect_resolution = self.params.get("detect_resolution", 640) + image_resized = torch.nn.functional.interpolate( - image_tensor, size=(detect_resolution, detect_resolution), - mode='bilinear', align_corners=False + image_tensor, size=(detect_resolution, detect_resolution), mode="bilinear", align_corners=False ) - + image_resized_uint8 = (image_resized * 255.0).type(torch.uint8) - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized_uint8}, cuda_stream) - - predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input'] - + + predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"] + try: pose_image = show_predictions_from_batch_format(predictions) pose_image = pose_image.clip(0, 255).astype(np.uint8) pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB) - + pose_tensor = torch.from_numpy(pose_image).float() / 255.0 pose_tensor = pose_tensor.permute(2, 0, 1).unsqueeze(0).cuda() - + except Exception: # Fallback to black tensor on error pose_tensor = torch.zeros(1, 3, detect_resolution, detect_resolution).cuda() - - return pose_tensor \ No newline at end of file + + return pose_tensor diff --git a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py index 18adfbb26..cdd548765 100644 --- a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py +++ b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py @@ -1,22 +1,23 @@ # NOTE: ported from https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt -import os -import torch +import logging +from collections import OrderedDict +from pathlib import Path +from typing import Tuple + import numpy as np -from PIL import Image -from typing import Optional, Tuple import requests +import torch +from PIL import Image from tqdm import tqdm -import hashlib -import logging -from pathlib import Path -from collections import OrderedDict from .base import BasePreprocessor + # Try to import spandrel for model loading try: from spandrel import ModelLoader + SPANDREL_AVAILABLE = True except ImportError: SPANDREL_AVAILABLE = False @@ -24,9 +25,11 @@ # Try to import TensorRT dependencies try: import tensorrt as trt - from streamdiffusion.acceleration.tensorrt.utilities import engine_from_bytes, bytes_from_path + + from streamdiffusion.acceleration.tensorrt.utilities import bytes_from_path, engine_from_bytes + TRT_AVAILABLE = True - + # Numpy to PyTorch dtype mapping (same as depth_tensorrt.py) numpy_to_torch_dtype_dict = { np.uint8: torch.uint8, @@ -40,27 +43,28 @@ np.complex64: torch.complex64, np.complex128: torch.complex128, } - + # Handle bool type for numpy compatibility (same as depth_tensorrt.py) if np.version.full_version >= "1.24.0": numpy_to_torch_dtype_dict[np.bool_] = torch.bool else: numpy_to_torch_dtype_dict[np.bool] = torch.bool - + except ImportError: TRT_AVAILABLE = False class RealESRGANEngine: """TensorRT engine wrapper for RealESRGAN inference (following depth_tensorrt pattern)""" - + def __init__(self, engine_path): self.engine_path = engine_path self.engine = None self.context = None self.tensors = OrderedDict() - + import threading + self._inference_lock = threading.Lock() def load(self): @@ -79,13 +83,13 @@ def allocate_buffers(self, input_shape, device="cuda"): # Set input shape for dynamic sizing input_name = "input" self.context.set_input_shape(input_name, input_shape) - + # Allocate tensors for all bindings for idx in range(self.engine.num_io_tensors): name = self.engine.get_tensor_name(idx) shape = self.context.get_tensor_shape(name) dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - + # Convert numpy dtype to torch dtype if dtype == np.float32: torch_dtype = torch.float32 @@ -93,7 +97,7 @@ def allocate_buffers(self, input_shape, device="cuda"): torch_dtype = torch.float16 else: torch_dtype = torch.float32 - + tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=device) self.tensors[name] = tensor @@ -102,7 +106,7 @@ def infer(self, feed_dict, stream=None): # Use provided stream or current stream context if stream is None: stream = torch.cuda.current_stream().cuda_stream - + # Copy input data to tensors for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) @@ -111,27 +115,29 @@ def infer(self, feed_dict, stream=None): for name, tensor in self.tensors.items(): addr = tensor.data_ptr() self.context.set_tensor_address(name, addr) - + with self._inference_lock: success = self.context.execute_async_v3(stream) - + if not success: raise RuntimeError("RealESRGANEngine: TensorRT inference failed") - + torch.cuda.synchronize() - + return self.tensors + logger = logging.getLogger(__name__) + class RealESRGANProcessor(BasePreprocessor): """ RealESRGAN 2x upscaling processor with automatic model download, ONNX export, and TensorRT acceleration. """ - + MODEL_URL = "https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealESRGAN_x2.pth?download=true" - - @classmethod + + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "RealESRGAN 2x", @@ -140,94 +146,98 @@ def get_preprocessor_metadata(cls): "enable_tensorrt": { "type": "bool", "default": True, - "description": "Use TensorRT acceleration for faster inference" + "description": "Use TensorRT acceleration for faster inference", }, "force_rebuild": { - "type": "bool", + "type": "bool", "default": False, - "description": "Force rebuild TensorRT engine even if it exists" - } + "description": "Force rebuild TensorRT engine even if it exists", + }, }, - "use_cases": ["High-quality upscaling", "Real-time 2x enlargement", "Image enhancement"] + "use_cases": ["High-quality upscaling", "Real-time 2x enlargement", "Image enhancement"], } - + def __init__(self, enable_tensorrt: bool = True, force_rebuild: bool = False, **kwargs): super().__init__(enable_tensorrt=enable_tensorrt, force_rebuild=force_rebuild, **kwargs) self.enable_tensorrt = enable_tensorrt and TRT_AVAILABLE self.force_rebuild = force_rebuild self.scale_factor = 2 # RealESRGAN 2x model - + # Model paths self.models_dir = Path("models") / "realesrgan" self.models_dir.mkdir(parents=True, exist_ok=True) self.model_path = self.models_dir / "RealESRGAN_x2.pth" self.onnx_path = self.models_dir / "RealESRGAN_x2.onnx" self.engine_path = self.models_dir / f"RealESRGAN_x2_{trt.__version__ if TRT_AVAILABLE else 'notrt'}.trt" - + # Model state self.pytorch_model = None self._engine = None # Lazy loading like depth processor - + # Thread safety for engine initialization import threading + self._engine_lock = threading.Lock() - + # Initialize self._ensure_model_ready() - + @property def engine(self): """Lazy loading of the TensorRT engine""" if self._engine is None: if not self.engine_path.exists(): raise FileNotFoundError(f"TensorRT engine not found: {self.engine_path}") - + self._engine = RealESRGANEngine(str(self.engine_path)) self._engine.load() self._engine.activate() - + # Allocate buffers for standard input size (will be reallocated as needed) standard_shape = (1, 3, 512, 512) self._engine.allocate_buffers(standard_shape, device=self.device) - + return self._engine - + def _download_file(self, url: str, save_path: Path): """Download file with progress bar""" if save_path.exists(): return - + response = requests.get(url, stream=True) response.raise_for_status() - - total_size = int(response.headers.get('content-length', 0)) - - with open(save_path, 'wb') as file, tqdm( - desc=f"Downloading {save_path.name}", - total=total_size, - unit='iB', - unit_scale=True, - unit_divisor=1024, - colour='green' - ) as progress_bar: + + total_size = int(response.headers.get("content-length", 0)) + + with ( + open(save_path, "wb") as file, + tqdm( + desc=f"Downloading {save_path.name}", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + colour="green", + ) as progress_bar, + ): for data in response.iter_content(chunk_size=1024): size = file.write(data) progress_bar.update(size) - + def _ensure_model_ready(self): """Ensure PyTorch model is downloaded and loaded""" # Download model if needed if not self.model_path.exists(): self._download_file(self.MODEL_URL, self.model_path) - + # Load PyTorch model if self.pytorch_model is None: self._load_pytorch_model() - + # Setup TensorRT if enabled if self.enable_tensorrt: self._setup_tensorrt() - + def _load_pytorch_model(self): """Load PyTorch model from file""" if not SPANDREL_AVAILABLE: @@ -235,92 +245,92 @@ def _load_pytorch_model(self): state_dict = torch.load(self.model_path, map_location=self.device) # This is a simplified approach - real implementation would need model architecture return - + model_descriptor = ModelLoader().load_from_file(str(self.model_path)) # Don't force dtype conversion as it can cause type mismatches # Let the model keep its native dtype and convert inputs as needed self.pytorch_model = model_descriptor.model.eval().to(device=self.device) model_dtype = next(self.pytorch_model.parameters()).dtype - + def _export_to_onnx(self): """Export PyTorch model to ONNX format""" if self.onnx_path.exists() and not self.force_rebuild: return - + if self.pytorch_model is None: self._load_pytorch_model() - + if self.pytorch_model is None: return - + # Test with small input for export test_input = torch.randn(1, 3, 256, 256).to(self.device) - + dynamic_axes = { "input": {0: "batch_size", 2: "height", 3: "width"}, "output": {0: "batch_size", 2: "height", 3: "width"}, } - + with torch.no_grad(): torch.onnx.export( self.pytorch_model, test_input, str(self.onnx_path), verbose=False, - input_names=['input'], - output_names=['output'], + input_names=["input"], + output_names=["output"], opset_version=17, export_params=True, dynamic_axes=dynamic_axes, ) - + def _setup_tensorrt(self): """Setup TensorRT engine""" if not TRT_AVAILABLE: return - + # Export to ONNX first if needed if not self.onnx_path.exists(): self._export_to_onnx() - + # Build/load TensorRT engine self._load_tensorrt_engine() - + def _load_tensorrt_engine(self): """Load or build TensorRT engine""" if self.engine_path.exists() and not self.force_rebuild: self._load_existing_engine() else: self._build_tensorrt_engine() - + def _load_existing_engine(self): """Load existing TensorRT engine (now handled by lazy loading property)""" # Engine loading is now handled by the lazy loading 'engine' property # This method is kept for compatibility but does nothing pass - + def _build_tensorrt_engine(self): """Build TensorRT engine from ONNX model""" if not self.onnx_path.exists(): return - + try: # Create builder and network builder = trt.Builder(trt.Logger(trt.Logger.WARNING)) network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING)) - + # Parse ONNX model - with open(self.onnx_path, 'rb') as model: + with open(self.onnx_path, "rb") as model: if not parser.parse(model.read()): for error in range(parser.num_errors): pass return - + # Configure builder config = builder.create_builder_config() config.set_flag(trt.BuilderFlag.FP16) # Enable FP16 for better performance - + # Set optimization profile for dynamic shapes profile = builder.create_optimization_profile() min_shape = (1, 3, 256, 256) @@ -328,87 +338,87 @@ def _build_tensorrt_engine(self): max_shape = (1, 3, 1024, 1024) profile.set_shape("input", min_shape, opt_shape, max_shape) config.add_optimization_profile(profile) - + # Build engine engine = builder.build_serialized_network(network, config) - + if engine is None: return - + # Save engine - with open(self.engine_path, 'wb') as f: + with open(self.engine_path, "wb") as f: f.write(engine) - + # Load the built engine self._load_existing_engine() - - except Exception as e: + + except Exception: pass - + def _process_with_tensorrt(self, tensor: torch.Tensor) -> torch.Tensor: """Process tensor using TensorRT engine (following depth_tensorrt pattern)""" batch_size, channels, height, width = tensor.shape input_shape = (batch_size, channels, height, width) - + # Ensure buffers are allocated for this input shape - if not hasattr(self.engine, 'tensors') or len(self.engine.tensors) == 0: + if not hasattr(self.engine, "tensors") or len(self.engine.tensors) == 0: self.engine.allocate_buffers(input_shape, device=self.device) else: # Check if we need to reallocate for different input shape input_tensor_shape = self.engine.tensors.get("input", torch.empty(0)).shape if input_tensor_shape != input_shape: self.engine.allocate_buffers(input_shape, device=self.device) - + # Prepare input tensor input_tensor = tensor.contiguous() if input_tensor.dtype != self.engine.tensors["input"].dtype: input_tensor = input_tensor.to(dtype=self.engine.tensors["input"].dtype) - + # Use engine inference with current stream context for proper synchronization cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": input_tensor}, cuda_stream) - output_tensor = result['output'] - + output_tensor = result["output"] + # Ensure output is properly clamped to [0, 1] range for RealESRGAN output_tensor = torch.clamp(output_tensor, 0.0, 1.0) - + return output_tensor.clone() - + def _process_with_pytorch(self, tensor: torch.Tensor) -> torch.Tensor: """Process tensor using PyTorch model""" if self.pytorch_model is None: raise RuntimeError("_process_with_pytorch: PyTorch model not loaded") - + # Ensure model and input tensor have compatible dtypes model_dtype = next(self.pytorch_model.parameters()).dtype original_dtype = tensor.dtype if tensor.dtype != model_dtype: tensor = tensor.to(dtype=model_dtype) - + with torch.no_grad(): result = self.pytorch_model(tensor) - + # Ensure output is properly clamped to [0, 1] range for RealESRGAN result = torch.clamp(result, 0.0, 1.0) - + # Convert result to the desired output dtype (self.dtype) if result.dtype != self.dtype: result = result.to(dtype=self.dtype) - + return result - + def _process_core(self, image: Image.Image) -> Image.Image: """Core processing using PIL Image""" # Convert to tensor for processing tensor = self.pil_to_tensor(image) if tensor.dim() == 3: tensor = tensor.unsqueeze(0) - + # Process with available backend if self.enable_tensorrt and TRT_AVAILABLE and self.engine_path.exists(): try: output_tensor = self._process_with_tensorrt(tensor) - except Exception as e: + except Exception: output_tensor = self._process_with_pytorch(tensor) elif self.pytorch_model is not None: output_tensor = self._process_with_pytorch(tensor) @@ -416,29 +426,29 @@ def _process_core(self, image: Image.Image) -> Image.Image: # Fallback to simple upscaling if no model is available target_width, target_height = self.get_target_dimensions() return image.resize((target_width, target_height), Image.LANCZOS) - + # Convert back to PIL if output_tensor.dim() == 4: output_tensor = output_tensor.squeeze(0) - + result_image = self.tensor_to_pil(output_tensor) - + return result_image - + def _ensure_target_size(self, image: Image.Image) -> Image.Image: """ Override base class method - for upscaling, we want to keep the upscaled size Don't resize back to original dimensions """ return image - + def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: """ Override base class method - for upscaling, we want to keep the upscaled size Don't resize back to original dimensions """ return tensor - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """Core tensor processing""" if tensor.dim() == 3: @@ -446,49 +456,46 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: squeeze_output = True else: squeeze_output = False - + # Process with available backend if self.enable_tensorrt and TRT_AVAILABLE and self.engine_path.exists(): try: output_tensor = self._process_with_tensorrt(tensor) - except Exception as e: + except Exception: output_tensor = self._process_with_pytorch(tensor) elif self.pytorch_model is not None: output_tensor = self._process_with_pytorch(tensor) else: # Fallback using interpolation output_tensor = torch.nn.functional.interpolate( - tensor, - scale_factor=self.scale_factor, - mode='bicubic', - align_corners=False + tensor, scale_factor=self.scale_factor, mode="bicubic", align_corners=False ) - + if squeeze_output: output_tensor = output_tensor.squeeze(0) - + return output_tensor - + def get_target_dimensions(self) -> Tuple[int, int]: """Get target output dimensions (width, height) - 2x upscaled""" - width = self.params.get('image_width') - height = self.params.get('image_height') - + width = self.params.get("image_width") + height = self.params.get("image_height") + if width is not None and height is not None: target_dims = (width * self.scale_factor, height * self.scale_factor) return target_dims - + # Fallback to square resolution - resolution = self.params.get('image_resolution', 512) + resolution = self.params.get("image_resolution", 512) target_resolution = resolution * self.scale_factor target_dims = (target_resolution, target_resolution) return target_dims - + def __del__(self): """Cleanup resources""" - if hasattr(self, '_engine') and self._engine is not None: + if hasattr(self, "_engine") and self._engine is not None: # Cleanup dedicated stream if it exists - if hasattr(self._engine, '_dedicated_stream'): + if hasattr(self._engine, "_dedicated_stream"): torch.cuda.synchronize() del self._engine._dedicated_stream del self._engine diff --git a/src/streamdiffusion/preprocessing/processors/sharpen.py b/src/streamdiffusion/preprocessing/processors/sharpen.py index 9660e1cee..05d36fc2d 100644 --- a/src/streamdiffusion/preprocessing/processors/sharpen.py +++ b/src/streamdiffusion/preprocessing/processors/sharpen.py @@ -1,22 +1,21 @@ import torch import torch.nn.functional as F -import numpy as np from PIL import Image -from typing import Union + from .base import BasePreprocessor class SharpenPreprocessor(BasePreprocessor): """ GPU-heavy image sharpening preprocessor using unsharp masking and edge enhancement - + Applies sophisticated sharpening using multiple Gaussian operations: - Multi-scale unsharp masking - Edge-preserving enhancement - Laplacian-based detail enhancement - All operations performed on GPU for maximum performance """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -27,52 +26,54 @@ def get_preprocessor_metadata(cls): "type": "float", "default": 1.5, "range": [0.1, 5.0], - "description": "Overall sharpening intensity. Higher values create stronger effects." + "description": "Overall sharpening intensity. Higher values create stronger effects.", }, "unsharp_radius": { "type": "float", "default": 1.0, "range": [0.1, 5.0], - "description": "Radius for unsharp masking blur. Affects detail scale." + "description": "Radius for unsharp masking blur. Affects detail scale.", }, "edge_enhancement": { "type": "float", "default": 0.5, "range": [0.0, 2.0], - "description": "Edge enhancement factor. Emphasizes image boundaries." + "description": "Edge enhancement factor. Emphasizes image boundaries.", }, "detail_boost": { "type": "float", "default": 0.3, "range": [0.0, 1.0], - "description": "Fine detail enhancement using Laplacian filtering." + "description": "Fine detail enhancement using Laplacian filtering.", }, "noise_reduction": { "type": "float", "default": 0.1, "range": [0.0, 0.5], - "description": "Mild noise reduction to prevent amplification." + "description": "Mild noise reduction to prevent amplification.", }, "multi_scale": { "type": "bool", "default": True, - "description": "Use multi-scale processing for better quality (more GPU intensive)." - } + "description": "Use multi-scale processing for better quality (more GPU intensive).", + }, }, - "use_cases": ["Detail enhancement", "Photo sharpening", "Edge definition", "Clarity improvement"] + "use_cases": ["Detail enhancement", "Photo sharpening", "Edge definition", "Clarity improvement"], } - - def __init__(self, - sharpen_intensity: float = 1.5, - unsharp_radius: float = 1.0, - edge_enhancement: float = 0.5, - detail_boost: float = 0.3, - noise_reduction: float = 0.1, - multi_scale: bool = True, - **kwargs): + + def __init__( + self, + sharpen_intensity: float = 1.5, + unsharp_radius: float = 1.0, + edge_enhancement: float = 0.5, + detail_boost: float = 0.3, + noise_reduction: float = 0.1, + multi_scale: bool = True, + **kwargs, + ): """ Initialize Sharpen preprocessor - + Args: sharpen_intensity: Overall sharpening strength unsharp_radius: Blur radius for unsharp masking @@ -89,194 +90,182 @@ def __init__(self, detail_boost=detail_boost, noise_reduction=noise_reduction, multi_scale=multi_scale, - **kwargs + **kwargs, ) - + # Cache kernels for efficiency self._cached_gaussian_kernels = {} self._cached_laplacian_kernel = None self._cached_edge_kernels = None - + def _create_gaussian_kernel(self, size: int, sigma: float) -> torch.Tensor: """Create 2D Gaussian kernel""" coords = torch.arange(size, dtype=self.dtype, device=self.device) coords = coords - (size - 1) / 2 - y_grid, x_grid = torch.meshgrid(coords, coords, indexing='ij') + y_grid, x_grid = torch.meshgrid(coords, coords, indexing="ij") gaussian = torch.exp(-(x_grid**2 + y_grid**2) / (2 * sigma**2)) return gaussian / gaussian.sum() - + def _get_gaussian_kernel(self, sigma: float) -> torch.Tensor: """Get cached Gaussian kernel""" # Calculate appropriate kernel size (6 sigma rule) size = max(3, int(6 * sigma + 1)) if size % 2 == 0: size += 1 - + key = (size, sigma) if key not in self._cached_gaussian_kernels: self._cached_gaussian_kernels[key] = self._create_gaussian_kernel(size, sigma) - + return self._cached_gaussian_kernels[key] - + def _create_laplacian_kernel(self) -> torch.Tensor: """Create Laplacian kernel for edge detection""" - kernel = torch.tensor([ - [0, -1, 0], - [-1, 4, -1], - [0, -1, 0] - ], dtype=self.dtype, device=self.device) + kernel = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=self.dtype, device=self.device) return kernel - + def _get_laplacian_kernel(self) -> torch.Tensor: """Get cached Laplacian kernel""" if self._cached_laplacian_kernel is None: self._cached_laplacian_kernel = self._create_laplacian_kernel() return self._cached_laplacian_kernel - + def _create_edge_kernels(self) -> tuple: """Create Sobel edge detection kernels""" - sobel_x = torch.tensor([ - [-1, 0, 1], - [-2, 0, 2], - [-1, 0, 1] - ], dtype=self.dtype, device=self.device) - - sobel_y = torch.tensor([ - [-1, -2, -1], - [0, 0, 0], - [1, 2, 1] - ], dtype=self.dtype, device=self.device) - + sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=self.dtype, device=self.device) + + sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=self.dtype, device=self.device) + return sobel_x, sobel_y - + def _get_edge_kernels(self) -> tuple: """Get cached edge kernels""" if self._cached_edge_kernels is None: self._cached_edge_kernels = self._create_edge_kernels() return self._cached_edge_kernels - + def _apply_kernel(self, image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor: """Apply convolution kernel to image""" num_channels = image.shape[1] padding = kernel.shape[-1] // 2 - + # Expand kernel for all channels kernel_conv = kernel.unsqueeze(0).unsqueeze(0).repeat(num_channels, 1, 1, 1) - + return F.conv2d(image, kernel_conv, padding=padding, groups=num_channels) - + def _gaussian_blur(self, image: torch.Tensor, sigma: float) -> torch.Tensor: """Apply Gaussian blur""" kernel = self._get_gaussian_kernel(sigma) return self._apply_kernel(image, kernel) - + def _unsharp_mask(self, image: torch.Tensor, radius: float, intensity: float) -> torch.Tensor: """Apply unsharp masking""" # Create blurred version blurred = self._gaussian_blur(image, radius) - + # Create mask (original - blurred) mask = image - blurred - + # Apply sharpening sharpened = image + intensity * mask - + return torch.clamp(sharpened, 0, 1) - + def _edge_enhancement(self, image: torch.Tensor, strength: float) -> torch.Tensor: """Enhance edges using Sobel operators""" sobel_x, sobel_y = self._get_edge_kernels() - + # Calculate gradients grad_x = self._apply_kernel(image, sobel_x) grad_y = self._apply_kernel(image, sobel_y) - + # Calculate edge magnitude edge_magnitude = torch.sqrt(grad_x**2 + grad_y**2) - + # Enhance edges enhanced = image + strength * edge_magnitude - + return torch.clamp(enhanced, 0, 1) - + def _detail_enhancement(self, image: torch.Tensor, strength: float) -> torch.Tensor: """Enhance fine details using Laplacian""" laplacian = self._get_laplacian_kernel() - + # Apply Laplacian filter details = self._apply_kernel(image, laplacian) - + # Add details back to image enhanced = image + strength * details - + return torch.clamp(enhanced, 0, 1) - + def _noise_reduction_light(self, image: torch.Tensor, strength: float) -> torch.Tensor: """Light noise reduction using small Gaussian blur""" if strength <= 0: return image - + # Very light blur to reduce noise noise_reduced = self._gaussian_blur(image, 0.3) - + # Blend with original return (1 - strength) * image + strength * noise_reduced - + def _multi_scale_sharpen(self, image: torch.Tensor) -> torch.Tensor: """Apply multi-scale sharpening for better quality""" - sharpen_intensity = self.params.get('sharpen_intensity', 1.5) - unsharp_radius = self.params.get('unsharp_radius', 1.0) - + sharpen_intensity = self.params.get("sharpen_intensity", 1.5) + unsharp_radius = self.params.get("unsharp_radius", 1.0) + # Multiple scales for better quality scales = [unsharp_radius * 0.5, unsharp_radius, unsharp_radius * 2.0] weights = [0.3, 0.5, 0.2] - + result = image.clone() - + for scale, weight in zip(scales, weights): # Apply unsharp mask at this scale sharpened_scale = self._unsharp_mask(image, scale, sharpen_intensity * weight) - + # Blend with result result = result + weight * (sharpened_scale - image) - + return torch.clamp(result, 0, 1) - + def _process_core(self, image: Image.Image) -> Image.Image: """Apply sharpening using PIL/numpy fallback""" # Convert to tensor for GPU processing tensor = self.pil_to_tensor(image) tensor = tensor.squeeze(0) # Remove batch dimension - + # Process on GPU sharpened = self._process_tensor_core(tensor) - + # Convert back to PIL return self.tensor_to_pil(sharpened) - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """GPU-intensive sharpening processing""" # Ensure batch dimension if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) - + # Ensure correct device and dtype image_tensor = image_tensor.to(device=self.device, dtype=self.dtype) - + # Get parameters - sharpen_intensity = self.params.get('sharpen_intensity', 1.5) - unsharp_radius = self.params.get('unsharp_radius', 1.0) - edge_enhancement = self.params.get('edge_enhancement', 0.5) - detail_boost = self.params.get('detail_boost', 0.3) - noise_reduction = self.params.get('noise_reduction', 0.1) - multi_scale = self.params.get('multi_scale', True) - + sharpen_intensity = self.params.get("sharpen_intensity", 1.5) + unsharp_radius = self.params.get("unsharp_radius", 1.0) + edge_enhancement = self.params.get("edge_enhancement", 0.5) + detail_boost = self.params.get("detail_boost", 0.3) + noise_reduction = self.params.get("noise_reduction", 0.1) + multi_scale = self.params.get("multi_scale", True) + result = image_tensor.clone() - + # Step 1: Light noise reduction (prevent amplification) if noise_reduction > 0: result = self._noise_reduction_light(result, noise_reduction) - + # Step 2: Main sharpening if multi_scale: # Multi-scale processing (more GPU intensive) @@ -284,16 +273,16 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: else: # Single-scale unsharp masking result = self._unsharp_mask(result, unsharp_radius, sharpen_intensity) - + # Step 3: Edge enhancement if edge_enhancement > 0: result = self._edge_enhancement(result, edge_enhancement) - + # Step 4: Fine detail enhancement if detail_boost > 0: result = self._detail_enhancement(result, detail_boost) - + # Final clamp to ensure valid range result = torch.clamp(result, 0, 1) - + return result diff --git a/src/streamdiffusion/preprocessing/processors/soft_edge.py b/src/streamdiffusion/preprocessing/processors/soft_edge.py index 67537982b..abbf37b88 100644 --- a/src/streamdiffusion/preprocessing/processors/soft_edge.py +++ b/src/streamdiffusion/preprocessing/processors/soft_edge.py @@ -1,9 +1,7 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import numpy as np from PIL import Image -from typing import Union, Optional + from .base import BasePreprocessor @@ -12,95 +10,100 @@ class MultiScaleSobelOperator(nn.Module): Real-time multi-scale Sobel edge detector optimized for soft HED-like edges Based on the existing SobelOperator but enhanced for soft edge detection """ - + def __init__(self, device="cuda", dtype=torch.float16): super(MultiScaleSobelOperator, self).__init__() self.device = device self.dtype = dtype - + # Multi-scale edge detection (3 scales) self.edge_conv_x_1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(device) self.edge_conv_y_1 = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False).to(device) - + self.edge_conv_x_2 = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device) self.edge_conv_y_2 = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device) - + self.edge_conv_x_3 = nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False).to(device) self.edge_conv_y_3 = nn.Conv2d(1, 1, kernel_size=7, padding=3, bias=False).to(device) - + # Gaussian blur for soft edges self.blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, bias=False).to(device) - + self._setup_kernels() - + def _setup_kernels(self): """Setup Sobel kernels for different scales""" # Scale 1: Standard 3x3 Sobel - sobel_x_3 = torch.tensor([ - [-1.0, 0.0, 1.0], - [-2.0, 0.0, 2.0], - [-1.0, 0.0, 1.0] - ], device=self.device, dtype=self.dtype) - - sobel_y_3 = torch.tensor([ - [-1.0, -2.0, -1.0], - [0.0, 0.0, 0.0], - [1.0, 2.0, 1.0] - ], device=self.device, dtype=self.dtype) - + sobel_x_3 = torch.tensor( + [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=self.device, dtype=self.dtype + ) + + sobel_y_3 = torch.tensor( + [[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]], device=self.device, dtype=self.dtype + ) + # Scale 2: 5x5 Sobel - sobel_x_5 = torch.tensor([ - [-1, -2, 0, 2, 1], - [-2, -3, 0, 3, 2], - [-3, -5, 0, 5, 3], - [-2, -3, 0, 3, 2], - [-1, -2, 0, 2, 1] - ], device=self.device, dtype=self.dtype) / 16.0 - + sobel_x_5 = ( + torch.tensor( + [[-1, -2, 0, 2, 1], [-2, -3, 0, 3, 2], [-3, -5, 0, 5, 3], [-2, -3, 0, 3, 2], [-1, -2, 0, 2, 1]], + device=self.device, + dtype=self.dtype, + ) + / 16.0 + ) + sobel_y_5 = sobel_x_5.T - + # Scale 3: 7x7 Sobel (smoothed) - sobel_x_7 = torch.tensor([ - [-1, -2, -3, 0, 3, 2, 1], - [-2, -3, -4, 0, 4, 3, 2], - [-3, -4, -5, 0, 5, 4, 3], - [-4, -5, -6, 0, 6, 5, 4], - [-3, -4, -5, 0, 5, 4, 3], - [-2, -3, -4, 0, 4, 3, 2], - [-1, -2, -3, 0, 3, 2, 1] - ], device=self.device, dtype=self.dtype) / 32.0 - + sobel_x_7 = ( + torch.tensor( + [ + [-1, -2, -3, 0, 3, 2, 1], + [-2, -3, -4, 0, 4, 3, 2], + [-3, -4, -5, 0, 5, 4, 3], + [-4, -5, -6, 0, 6, 5, 4], + [-3, -4, -5, 0, 5, 4, 3], + [-2, -3, -4, 0, 4, 3, 2], + [-1, -2, -3, 0, 3, 2, 1], + ], + device=self.device, + dtype=self.dtype, + ) + / 32.0 + ) + sobel_y_7 = sobel_x_7.T - + # Gaussian kernel for smoothing - gaussian_5 = torch.tensor([ - [1, 4, 6, 4, 1], - [4, 16, 24, 16, 4], - [6, 24, 36, 24, 6], - [4, 16, 24, 16, 4], - [1, 4, 6, 4, 1] - ], device=self.device, dtype=self.dtype) / 256.0 - + gaussian_5 = ( + torch.tensor( + [[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [6, 24, 36, 24, 6], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]], + device=self.device, + dtype=self.dtype, + ) + / 256.0 + ) + # Set kernel weights self.edge_conv_x_1.weight = nn.Parameter(sobel_x_3.view(1, 1, 3, 3)) self.edge_conv_y_1.weight = nn.Parameter(sobel_y_3.view(1, 1, 3, 3)) - + self.edge_conv_x_2.weight = nn.Parameter(sobel_x_5.view(1, 1, 5, 5)) self.edge_conv_y_2.weight = nn.Parameter(sobel_y_5.view(1, 1, 5, 5)) - + self.edge_conv_x_3.weight = nn.Parameter(sobel_x_7.view(1, 1, 7, 7)) self.edge_conv_y_3.weight = nn.Parameter(sobel_y_7.view(1, 1, 7, 7)) - + self.blur.weight = nn.Parameter(gaussian_5.view(1, 1, 5, 5)) @torch.no_grad() def forward(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Fast multi-scale soft edge detection - + Args: image_tensor: Input tensor [B, C, H, W] or [C, H, W] - + Returns: Soft edge map tensor [B, 1, H, W] or [1, H, W] """ @@ -109,108 +112,108 @@ def forward(self, image_tensor: torch.Tensor) -> torch.Tensor: if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) squeeze_output = True - + # Convert to grayscale if needed if image_tensor.shape[1] == 3: # RGB to grayscale gray = 0.299 * image_tensor[:, 0:1] + 0.587 * image_tensor[:, 1:2] + 0.114 * image_tensor[:, 2:3] else: gray = image_tensor[:, 0:1] - + # Multi-scale edge detection # Scale 1 (fine details) edge_x1 = self.edge_conv_x_1(gray) edge_y1 = self.edge_conv_y_1(gray) edge1 = torch.sqrt(edge_x1**2 + edge_y1**2) - + # Scale 2 (medium details) edge_x2 = self.edge_conv_x_2(gray) edge_y2 = self.edge_conv_y_2(gray) edge2 = torch.sqrt(edge_x2**2 + edge_y2**2) - + # Scale 3 (coarse details) edge_x3 = self.edge_conv_x_3(gray) edge_y3 = self.edge_conv_y_3(gray) edge3 = torch.sqrt(edge_x3**2 + edge_y3**2) - + # Combine scales with weights (like HED side outputs) combined_edge = 0.5 * edge1 + 0.3 * edge2 + 0.2 * edge3 - + # Apply Gaussian smoothing for soft edges soft_edge = self.blur(combined_edge) - + # Normalize to [0, 1] range soft_edge = soft_edge / (soft_edge.max() + 1e-8) - + # Apply soft sigmoid activation for smooth transitions soft_edge = torch.sigmoid(soft_edge * 6.0 - 3.0) # Soft S-curve - + if squeeze_output: soft_edge = soft_edge.squeeze(0) - + return soft_edge class SoftEdgePreprocessor(BasePreprocessor): """ Real-time soft edge detection preprocessor - HED alternative - + Uses multi-scale Sobel operations for extremely fast soft edge detection that mimics HED output quality at 50x+ the speed. """ - + _model_cache = {} - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Soft Edge Detection", "description": "Real-time soft edge detection optimized for smooth, artistic edge maps using multi-scale Sobel operations.", "parameters": {}, - "use_cases": ["Artistic edge maps", "Soft stylistic control", "Real-time edge detection"] + "use_cases": ["Artistic edge maps", "Soft stylistic control", "Real-time edge detection"], } - + def __init__(self, **kwargs): """ Initialize soft edge preprocessor - + Args: **kwargs: Additional parameters """ super().__init__(**kwargs) self.model = None self._load_model() - + def _load_model(self): """ Load multi-scale Sobel operator with caching """ cache_key = f"soft_edge_{self.device}_{self.dtype}" - + if cache_key in self._model_cache: self.model = self._model_cache[cache_key] return - + print("SoftEdgePreprocessor: Loading real-time multi-scale edge detector") self.model = MultiScaleSobelOperator(device=self.device, dtype=self.dtype) self.model.eval() - + # Cache the model self._model_cache[cache_key] = self.model - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply soft edge detection to the input image """ # Convert PIL to tensor for GPU processing image_tensor = self.pil_to_tensor(image).squeeze(0) # Remove batch dim - + # Process with GPU-accelerated tensor method processed_tensor = self._process_tensor_core(image_tensor) - + # Convert back to PIL return self.tensor_to_pil(processed_tensor) - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ GPU-optimized soft edge processing using tensors @@ -218,25 +221,25 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: with torch.no_grad(): # Ensure correct input format and device image_tensor = image_tensor.to(device=self.device, dtype=self.dtype) - + # Normalize to [0, 1] if needed if image_tensor.max() > 1.0: image_tensor = image_tensor / 255.0 - + # Multi-scale edge detection edge_map = self.model(image_tensor) - + # Convert to 3-channel RGB format if edge_map.dim() == 3: edge_map = edge_map.repeat(3, 1, 1) else: edge_map = edge_map.repeat(1, 3, 1, 1).squeeze(0) - + # Ensure output is in [0, 1] range edge_map = torch.clamp(edge_map, 0.0, 1.0) - + return edge_map - + def get_model_info(self) -> dict: """ Get information about the loaded model @@ -248,24 +251,20 @@ def get_model_info(self) -> dict: "device": str(self.device), "dtype": str(self.dtype), "description": "Real-time multi-scale soft edge detection, HED quality at 50x+ speed", - "expected_fps": "100+ FPS at 512x512" + "expected_fps": "100+ FPS at 512x512", } - + @classmethod - def create_optimized(cls, device: str = 'cuda', dtype: torch.dtype = torch.float16, **kwargs): + def create_optimized(cls, device: str = "cuda", dtype: torch.dtype = torch.float16, **kwargs): """ Create an optimized soft edge preprocessor for real-time use - + Args: device: Target device ('cuda' or 'cpu') dtype: Data type for inference **kwargs: Additional parameters - + Returns: Optimized SoftEdgePreprocessor instance """ - return cls( - device=device, - dtype=dtype, - **kwargs - ) \ No newline at end of file + return cls(device=device, dtype=dtype, **kwargs) diff --git a/src/streamdiffusion/preprocessing/processors/standard_lineart.py b/src/streamdiffusion/preprocessing/processors/standard_lineart.py index bc732ea04..81a8ade1a 100644 --- a/src/streamdiffusion/preprocessing/processors/standard_lineart.py +++ b/src/streamdiffusion/preprocessing/processors/standard_lineart.py @@ -1,22 +1,22 @@ -import numpy as np -import cv2 -from PIL import Image -from typing import Union, Optional import time -from .base import BasePreprocessor + +import numpy as np import torch import torch.nn.functional as F +from PIL import Image + +from .base import BasePreprocessor class StandardLineartPreprocessor(BasePreprocessor): """ Real-time optimized Standard Lineart detection preprocessor for ControlNet - + Extracts line art from input images using traditional computer vision techniques. Uses Gaussian blur and intensity calculations to detect lines without requiring pre-trained models. GPU-accelerated with PyTorch for optimal real-time performance. """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -28,27 +28,29 @@ def get_preprocessor_metadata(cls): "default": 6.0, "range": [1.0, 20.0], "step": 0.1, - "description": "Standard deviation for Gaussian blur (higher = smoother lines)" + "description": "Standard deviation for Gaussian blur (higher = smoother lines)", }, "intensity_threshold": { "type": "int", "default": 8, "range": [1, 50], - "description": "Threshold for intensity calculation (lower = more sensitive)" - } + "description": "Threshold for intensity calculation (lower = more sensitive)", + }, }, - "use_cases": ["Traditional line art", "Simple edge detection", "No AI model required"] + "use_cases": ["Traditional line art", "Simple edge detection", "No AI model required"], } - - def __init__(self, - detect_resolution: int = 512, - image_resolution: int = 512, - gaussian_sigma: float = 6.0, - intensity_threshold: int = 8, - **kwargs): + + def __init__( + self, + detect_resolution: int = 512, + image_resolution: int = 512, + gaussian_sigma: float = 6.0, + intensity_threshold: int = 8, + **kwargs, + ): """ Initialize Standard Lineart preprocessor - + Args: detect_resolution: Resolution for line art detection image_resolution: Output image resolution @@ -56,39 +58,39 @@ def __init__(self, intensity_threshold: Threshold for intensity calculation **kwargs: Additional parameters """ - + super().__init__( detect_resolution=detect_resolution, image_resolution=image_resolution, gaussian_sigma=gaussian_sigma, intensity_threshold=intensity_threshold, - **kwargs + **kwargs, ) - + # Initialize GPU device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + def _gaussian_kernel(self, kernel_size: int, sigma: float, device=None) -> torch.Tensor: """Create 2D Gaussian kernel - based on existing codebase pattern""" x, y = torch.meshgrid( - torch.linspace(-1, 1, kernel_size, device=device), - torch.linspace(-1, 1, kernel_size, device=device), - indexing="ij" + torch.linspace(-1, 1, kernel_size, device=device), + torch.linspace(-1, 1, kernel_size, device=device), + indexing="ij", ) d = torch.sqrt(x * x + y * y) g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) return g / g.sum() - + def _gaussian_blur_torch(self, image: torch.Tensor, sigma: float) -> torch.Tensor: """Apply Gaussian blur using PyTorch - GPU accelerated""" # Calculate kernel size from sigma (odd number) kernel_size = int(2 * torch.ceil(torch.tensor(3 * sigma)) + 1) if kernel_size % 2 == 0: kernel_size += 1 - + # Create Gaussian kernel kernel = self._gaussian_kernel(kernel_size, sigma, device=image.device) - + # Handle different input shapes if image.dim() == 3: # HWC format H, W, C = image.shape @@ -100,31 +102,31 @@ def _gaussian_blur_torch(self, image: torch.Tensor, sigma: float) -> torch.Tenso needs_reshape = False else: raise ValueError(f"standardlineart_gaussian_blur_torch: Unsupported image shape: {image.shape}") - + # Expand kernel for all channels kernel = kernel.repeat(image.shape[1], 1, 1).unsqueeze(1) - + # Apply blur with reflection padding padding = kernel_size // 2 - padded_image = F.pad(image, (padding, padding, padding, padding), 'reflect') + padded_image = F.pad(image, (padding, padding, padding, padding), "reflect") blurred = F.conv2d(padded_image, kernel, padding=0, groups=image.shape[1]) - + # Convert back to original format if needed if needs_reshape: blurred = blurred.squeeze(0).permute(1, 2, 0) # BCHW -> HWC - + return blurred - + def _ensure_hwc3_torch(self, x: torch.Tensor) -> torch.Tensor: """Ensure image has 3 channels (HWC3 format) - PyTorch version""" if x.dim() == 2: x = x.unsqueeze(-1) # Add channel dimension - + if x.dim() != 3: raise ValueError(f"standardlineart_ensure_hwc3_torch: Expected 2D or 3D tensor, got {x.dim()}D") - + H, W, C = x.shape - + if C == 3: return x elif C == 1: @@ -136,90 +138,87 @@ def _ensure_hwc3_torch(self, x: torch.Tensor) -> torch.Tensor: return torch.clamp(y, 0, 255) else: raise ValueError(f"standardlineart_ensure_hwc3_torch: Unsupported channel count: {C}") - + def _pad64(self, x: int) -> int: """Pad to nearest multiple of 64""" return int(torch.ceil(torch.tensor(float(x) / 64.0)) * 64 - x) - + def _resize_image_with_pad_torch(self, input_image: torch.Tensor, resolution: int) -> tuple: """Resize image with padding to target resolution - PyTorch GPU accelerated""" img = self._ensure_hwc3_torch(input_image) H_raw, W_raw, _ = img.shape - + if resolution == 0: return img, lambda x: x - + k = float(resolution) / float(min(H_raw, W_raw)) H_target = int(torch.round(torch.tensor(float(H_raw) * k))) W_target = int(torch.round(torch.tensor(float(W_raw) * k))) - + # Convert to BCHW for interpolation img_bchw = img.permute(2, 0, 1).unsqueeze(0) # HWC -> BCHW - + # Use PyTorch's interpolate for GPU-accelerated resize - mode = 'bicubic' if k > 1 else 'area' + mode = "bicubic" if k > 1 else "area" img_resized_bchw = F.interpolate( - img_bchw, - size=(H_target, W_target), - mode=mode, - align_corners=False if mode == 'bicubic' else None + img_bchw, size=(H_target, W_target), mode=mode, align_corners=False if mode == "bicubic" else None ) - + # Convert back to HWC img_resized = img_resized_bchw.squeeze(0).permute(1, 2, 0) - + # Apply padding H_pad, W_pad = self._pad64(H_target), self._pad64(W_target) - img_padded = F.pad(img_resized.permute(2, 0, 1), (0, W_pad, 0, H_pad), mode='replicate').permute(1, 2, 0) + img_padded = F.pad(img_resized.permute(2, 0, 1), (0, W_pad, 0, H_pad), mode="replicate").permute(1, 2, 0) def remove_pad(x): return x[:H_target, :W_target, ...] return img_padded, remove_pad - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply standard line art detection to the input image """ start_time = time.time() - + if isinstance(image, Image.Image): input_image_cpu = np.array(image, dtype=np.uint8) else: input_image_cpu = image.astype(np.uint8) - + input_image = torch.from_numpy(input_image_cpu).float().to(self.device) - - detect_resolution = self.params.get('detect_resolution', 512) - gaussian_sigma = self.params.get('gaussian_sigma', 6.0) - intensity_threshold = self.params.get('intensity_threshold', 8) - + + detect_resolution = self.params.get("detect_resolution", 512) + gaussian_sigma = self.params.get("gaussian_sigma", 6.0) + intensity_threshold = self.params.get("intensity_threshold", 8) + input_image, remove_pad = self._resize_image_with_pad_torch(input_image, detect_resolution) - + x = input_image - + g = self._gaussian_blur_torch(x, gaussian_sigma) - + intensity = torch.min(g - x, dim=2)[0] intensity = torch.clamp(intensity, 0, 255) - + threshold_mask = intensity > intensity_threshold if torch.any(threshold_mask): median_val = torch.median(intensity[threshold_mask]) normalization_factor = max(16, float(median_val)) else: normalization_factor = 16 - + intensity = intensity / normalization_factor intensity = intensity * 127 - + detected_map = torch.clamp(intensity, 0, 255).byte() detected_map = detected_map.unsqueeze(-1) detected_map = self._ensure_hwc3_torch(detected_map.float()) - + detected_map = remove_pad(detected_map) - + detected_map_cpu = detected_map.byte().cpu().numpy() lineart_image = Image.fromarray(detected_map_cpu) - - return lineart_image \ No newline at end of file + + return lineart_image diff --git a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py index 8151722ac..94930e1ad 100644 --- a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py @@ -1,26 +1,32 @@ -import torch -import torch.nn.functional as F -import numpy as np -from PIL import Image import logging from pathlib import Path from typing import Any + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + from .base import PipelineAwareProcessor + # Try to import TensorRT dependencies try: + from collections import OrderedDict + import tensorrt as trt from polygraphy.backend.common import bytes_from_path from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False # Try to import torchvision for RAFT model try: - from torchvision.models.optical_flow import raft_small, Raft_Small_Weights + from torchvision.models.optical_flow import Raft_Small_Weights, raft_small from torchvision.utils import flow_to_image + TORCHVISION_AVAILABLE = True except ImportError: TORCHVISION_AVAILABLE = False @@ -48,7 +54,7 @@ class TensorRTEngine: """TensorRT engine wrapper for RAFT optical flow inference""" - + def __init__(self, engine_path): self.engine_path = engine_path self.engine = None @@ -69,7 +75,7 @@ def activate(self): def allocate_buffers(self, device="cuda", input_shape=None): """ Allocate input/output buffers - + Args: device: Device to allocate tensors on input_shape: Shape for input tensors (B, C, H, W). Required for engines with dynamic shapes. @@ -88,7 +94,6 @@ def allocate_buffers(self, device="cuda", input_shape=None): else: raise - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: # For dynamic shapes, use provided input_shape if input_shape is not None and any(dim == -1 for dim in shape): @@ -99,7 +104,7 @@ def allocate_buffers(self, device="cuda", input_shape=None): else: # For output tensors, get shape after input shapes are set shape = self.context.get_tensor_shape(name) - + # Verify shape has no dynamic dimensions if any(dim == -1 for dim in shape): raise RuntimeError( @@ -114,7 +119,7 @@ def infer(self, feed_dict, stream=None): """Run inference with optional stream parameter""" if stream is None: stream = self._cuda_stream - + # Check if we need to update tensor shapes for dynamic dimensions need_realloc = False for name, buf in feed_dict.items(): @@ -122,7 +127,7 @@ def infer(self, feed_dict, stream=None): if self.tensors[name].shape != buf.shape: need_realloc = True break - + # Reallocate buffers if input shape changed if need_realloc: # Update input shapes @@ -134,18 +139,18 @@ def infer(self, feed_dict, stream=None): except: # Tensor name might not be in engine, skip pass - + # Reallocate all tensors with new shapes for idx in range(self.engine.num_io_tensors): name = self.engine.get_tensor_name(idx) shape = self.context.get_tensor_shape(name) dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=self.tensors[name].device) + + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to( + device=self.tensors[name].device + ) self.tensors[name] = tensor - + # Copy input data to tensors for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) @@ -153,26 +158,26 @@ def infer(self, feed_dict, stream=None): # Set tensor addresses for name, tensor in self.tensors.items(): self.context.set_tensor_address(name, tensor.data_ptr()) - + # Execute inference success = self.context.execute_async_v3(stream) if not success: raise ValueError("TensorRT inference failed.") - + return self.tensors class TemporalNetTensorRTPreprocessor(PipelineAwareProcessor): """ TensorRT-accelerated TemporalNet preprocessor for temporal consistency using optical flow visualization. - + This preprocessor uses TensorRT to accelerate RAFT optical flow computation and creates a 6-channel control tensor by concatenating the previous input frame (RGB) with a colorized optical flow visualization (RGB) computed between the previous and current input frames. - + Output: [prev_input_RGB, flow_RGB(prev_input → current_input)] """ - + @classmethod def get_preprocessor_metadata(cls): return { @@ -182,53 +187,59 @@ def get_preprocessor_metadata(cls): "engine_path": { "type": "str", "default": None, - "description": "Path to pre-built TensorRT engine file. Use compile_raft_tensorrt.py to build one." + "description": "Path to pre-built TensorRT engine file. Use compile_raft_tensorrt.py to build one.", }, "flow_strength": { "type": "float", "default": 1.0, "range": [0.0, 2.0], "step": 0.1, - "description": "Strength multiplier for optical flow visualization (1.0 = normal, higher = more pronounced flow)" + "description": "Strength multiplier for optical flow visualization (1.0 = normal, higher = more pronounced flow)", }, "height": { "type": "int", "default": 512, "range": [256, 1024], "step": 64, - "description": "Height for optical flow computation (must be within engine's height range)" + "description": "Height for optical flow computation (must be within engine's height range)", }, "width": { "type": "int", "default": 512, "range": [256, 1024], "step": 64, - "description": "Width for optical flow computation (must be within engine's width range)" + "description": "Width for optical flow computation (must be within engine's width range)", }, "output_format": { - "type": "str", + "type": "str", "default": "concat", "options": ["concat", "warped_only"], - "description": "Output format: 'concat' for 6-channel (prev_input+flow_RGB), 'warped_only' for 3-channel flow RGB only" - } + "description": "Output format: 'concat' for 6-channel (prev_input+flow_RGB), 'warped_only' for 3-channel flow RGB only", + }, }, - "use_cases": ["High-performance video generation", "Real-time temporal consistency", "GPU-optimized motion control"] + "use_cases": [ + "High-performance video generation", + "Real-time temporal consistency", + "GPU-optimized motion control", + ], } - - def __init__(self, - pipeline_ref: Any, - engine_path: str = None, - height: int = 512, - width: int = 512, - flow_strength: float = 1.0, - output_format: str = "concat", - **kwargs): + + def __init__( + self, + pipeline_ref: Any, + engine_path: str = None, + height: int = 512, + width: int = 512, + flow_strength: float = 1.0, + output_format: str = "concat", + **kwargs, + ): """ Initialize TensorRT TemporalNet preprocessor - + Args: pipeline_ref: Reference to the StreamDiffusion pipeline instance (required) - engine_path: Path to pre-built TensorRT engine file (required). + engine_path: Path to pre-built TensorRT engine file (required). Build one using: python -m streamdiffusion.tools.compile_raft_tensorrt height: Height for optical flow computation (must be within engine's height range) width: Width for optical flow computation (must be within engine's width range) @@ -236,13 +247,12 @@ def __init__(self, output_format: "concat" for 6-channel [prev_input+flow_RGB], "warped_only" for 3-channel flow RGB only **kwargs: Additional parameters passed to BasePreprocessor """ - + if not TORCHVISION_AVAILABLE: raise ImportError( - "torchvision is required for TemporalNet preprocessing. " - "Install it with: pip install torchvision" + "torchvision is required for TemporalNet preprocessing. Install it with: pip install torchvision" ) - + if not TENSORRT_AVAILABLE: raise ImportError( "TensorRT and polygraphy are required for TensorRT acceleration. " @@ -255,7 +265,7 @@ def __init__(self, " python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024 --output_dir ./models/temporal_net\n" "Then pass the engine path to this preprocessor." ) - + super().__init__( pipeline_ref=pipeline_ref, height=height, @@ -263,17 +273,17 @@ def __init__(self, engine_path=engine_path, flow_strength=flow_strength, output_format=output_format, - **kwargs + **kwargs, ) - + self.flow_strength = max(0.0, min(2.0, flow_strength)) self.height = height self.width = width self._first_frame = True - + # Store previous input frame for flow computation self.prev_input = None - + # Engine path self.engine_path = Path(engine_path) if not self.engine_path.exists(): @@ -282,17 +292,17 @@ def __init__(self, f"Build one using:\n" f" python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution {height}x{width} --max_resolution {height}x{width} --output_dir {self.engine_path.parent}" ) - + # Model state self.trt_engine = None - + # Cached tensors for performance self._grid_cache = {} self._tensor_cache = {} - + # Load TensorRT engine self._load_tensorrt_engine() - + def _load_tensorrt_engine(self): """Load pre-built TensorRT engine""" logger.info(f"_load_tensorrt_engine: Loading TensorRT engine: {self.engine_path}") @@ -300,11 +310,11 @@ def _load_tensorrt_engine(self): self.trt_engine = TensorRTEngine(str(self.engine_path)) self.trt_engine.load() self.trt_engine.activate() - + # For dynamic shapes, provide the input shape based on image dimensions input_shape = (1, 3, self.height, self.width) self.trt_engine.allocate_buffers(device=self.device, input_shape=input_shape) - + logger.info(f"_load_tensorrt_engine: TensorRT engine loaded successfully from {self.engine_path}") logger.info(f"_load_tensorrt_engine: Using resolution: {self.height}x{self.width}") except Exception as e: @@ -315,16 +325,14 @@ def _load_tensorrt_engine(self): f"Make sure the engine was built with a resolution range that includes {self.height}x{self.width}.\n" f"For example: python -m streamdiffusion.tools.compile_raft_tensorrt --min_resolution 512x512 --max_resolution 1024x1024" ) - - def _process_core(self, image: Image.Image) -> Image.Image: """ Process using TensorRT-accelerated optical flow warping - + Args: image: Current input image - + Returns: Warped previous frame for temporal guidance, or fallback for first frame """ @@ -332,50 +340,50 @@ def _process_core(self, image: Image.Image) -> Image.Image: tensor = self.pil_to_tensor(image) result_tensor = self._process_tensor_core(tensor) return self.tensor_to_pil(result_tensor) - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ Process using TensorRT-accelerated optical flow computation (GPU-optimized path) - + Args: tensor: Current input tensor - + Returns: Concatenated tensor: [prev_input_RGB, flow_RGB] for temporal guidance """ - + # Normalize input tensor input_tensor = tensor if input_tensor.max() > 1.0: input_tensor = input_tensor / 255.0 - + # Ensure consistent format if input_tensor.dim() == 4 and input_tensor.shape[0] == 1: input_tensor = input_tensor[0] - + # Check if we have a previous input frame if self.prev_input is not None and not self._first_frame: try: # Compute optical flow between prev_input -> current_input flow_rgb_tensor = self._compute_flow_to_rgb_tensor(self.prev_input, input_tensor) - + # Check output format - output_format = self.params.get('output_format', 'concat') + output_format = self.params.get("output_format", "concat") if output_format == "concat": # Concatenate prev_input + flow_RGB for TemporalNet2 (6 channels) result_tensor = self._concatenate_frames_tensor(self.prev_input, flow_rgb_tensor) else: # Return only flow RGB (3 channels) result_tensor = flow_rgb_tensor - + # Ensure correct output format if result_tensor.dim() == 3: result_tensor = result_tensor.unsqueeze(0) - + result = result_tensor.to(device=self.device, dtype=self.dtype) except Exception as e: logger.error(f"_process_tensor_core: TensorRT optical flow failed: {e}") - output_format = self.params.get('output_format', 'concat') + output_format = self.params.get("output_format", "concat") if output_format == "concat": # Create 6-channel fallback by concatenating prev_input with itself result_tensor = self._concatenate_frames_tensor(self.prev_input, self.prev_input) @@ -393,19 +401,19 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: self._first_frame = False if tensor.dim() == 3: tensor = tensor.unsqueeze(0) - + # Handle 6-channel output for first frame - output_format = self.params.get('output_format', 'concat') + output_format = self.params.get("output_format", "concat") if output_format == "concat": # For first frame, concatenate current frame with zeros (no flow) if tensor.dim() == 4 and tensor.shape[0] == 1: current_tensor = tensor[0] else: current_tensor = tensor - + # Create zero tensor for flow (same shape as current_tensor) zero_flow = torch.zeros_like(current_tensor, device=self.device, dtype=current_tensor.dtype) - + result_tensor = self._concatenate_frames_tensor(current_tensor, zero_flow) if result_tensor.dim() == 3: result_tensor = result_tensor.unsqueeze(0) @@ -420,164 +428,148 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: if result_tensor.dim() == 3: result_tensor = result_tensor.unsqueeze(0) result = result_tensor.to(device=self.device, dtype=self.dtype) - + # Store current input as previous for next frame self.prev_input = input_tensor.clone() - + return result - - def _compute_flow_to_rgb_tensor(self, prev_input_tensor: torch.Tensor, current_input_tensor: torch.Tensor) -> torch.Tensor: + + def _compute_flow_to_rgb_tensor( + self, prev_input_tensor: torch.Tensor, current_input_tensor: torch.Tensor + ) -> torch.Tensor: """ Compute optical flow between prev_input -> current_input and convert to RGB visualization - + Args: prev_input_tensor: Previous input frame tensor (CHW format, [0,1]) on GPU current_input_tensor: Current input frame tensor (CHW format, [0,1]) on GPU - + Returns: Flow visualization as RGB tensor (CHW format, [0,1]) on GPU """ target_width, target_height = self.get_target_dimensions() - + # Convert to float32 for TensorRT processing prev_tensor = prev_input_tensor.to(device=self.device, dtype=torch.float32) current_tensor = current_input_tensor.to(device=self.device, dtype=torch.float32) - + # Resize for flow computation if needed (keep on GPU) if current_tensor.shape[-1] != self.width or current_tensor.shape[-2] != self.height: prev_resized = F.interpolate( - prev_tensor.unsqueeze(0), - size=(self.height, self.width), - mode='bilinear', - align_corners=False + prev_tensor.unsqueeze(0), size=(self.height, self.width), mode="bilinear", align_corners=False ).squeeze(0) current_resized = F.interpolate( - current_tensor.unsqueeze(0), - size=(self.height, self.width), - mode='bilinear', - align_corners=False + current_tensor.unsqueeze(0), size=(self.height, self.width), mode="bilinear", align_corners=False ).squeeze(0) else: prev_resized = prev_tensor current_resized = current_tensor - + # Compute optical flow using TensorRT: prev_input -> current_input flow = self._compute_optical_flow_tensorrt(prev_resized, current_resized) - + # Apply flow strength scaling (GPU operation) - flow_strength = self.params.get('flow_strength', 1.0) + flow_strength = self.params.get("flow_strength", 1.0) if flow_strength != 1.0: flow = flow * flow_strength - + # Convert flow to RGB visualization using torchvision's flow_to_image # flow_to_image expects (2, H, W) and returns (3, H, W) in range [0, 255] flow_rgb = flow_to_image(flow) # Returns uint8 tensor [0, 255] - + # Convert to float [0, 1] range flow_rgb = flow_rgb.float() / 255.0 - + # Resize back to target resolution if needed (keep on GPU) if flow_rgb.shape[-1] != target_width or flow_rgb.shape[-2] != target_height: flow_rgb = F.interpolate( - flow_rgb.unsqueeze(0), - size=(target_height, target_width), - mode='bilinear', - align_corners=False + flow_rgb.unsqueeze(0), size=(target_height, target_width), mode="bilinear", align_corners=False ).squeeze(0) - + # Convert to processor's dtype only at the very end result = flow_rgb.to(dtype=self.dtype) - + return result - + def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor: """ Compute optical flow between two frames using TensorRT-accelerated RAFT - + Args: frame1: First frame tensor (CHW format, [0,1]) frame2: Second frame tensor (CHW format, [0,1]) - + Returns: Optical flow tensor (2HW format) """ - + if self.trt_engine is None: raise RuntimeError("_compute_optical_flow_tensorrt: TensorRT engine not loaded") - + # Prepare inputs for TensorRT frame1_batch = frame1.unsqueeze(0) frame2_batch = frame2.unsqueeze(0) - + # Apply RAFT preprocessing if available weights = Raft_Small_Weights.DEFAULT - if hasattr(weights, 'transforms') and weights.transforms is not None: + if hasattr(weights, "transforms") and weights.transforms is not None: transforms = weights.transforms() frame1_batch, frame2_batch = transforms(frame1_batch, frame2_batch) - + # Run TensorRT inference - feed_dict = { - 'frame1': frame1_batch, - 'frame2': frame2_batch - } - + feed_dict = {"frame1": frame1_batch, "frame2": frame2_batch} + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.trt_engine.infer(feed_dict, cuda_stream) - flow = result['flow'][0] # Remove batch dimension - + flow = result["flow"][0] # Remove batch dimension + return flow - - def _warp_frame_tensor(self, frame: torch.Tensor, flow: torch.Tensor) -> torch.Tensor: """ Warp frame using optical flow with cached coordinate grids - + Args: frame: Frame to warp (CHW format) flow: Optical flow (2HW format) - + Returns: Warped frame tensor """ H, W = frame.shape[-2:] - + # Use cached grid if available grid_key = (H, W) if grid_key not in self._grid_cache: grid_y, grid_x = torch.meshgrid( torch.arange(H, device=self.device, dtype=torch.float32), torch.arange(W, device=self.device, dtype=torch.float32), - indexing='ij' + indexing="ij", ) self._grid_cache[grid_key] = (grid_x, grid_y) else: grid_x, grid_y = self._grid_cache[grid_key] - + # Apply flow to coordinates new_x = grid_x + flow[0] new_y = grid_y + flow[1] - + # Normalize coordinates to [-1, 1] for grid_sample new_x = 2.0 * new_x / (W - 1) - 1.0 new_y = 2.0 * new_y / (H - 1) - 1.0 - + # Create sampling grid (HW2 format for grid_sample) grid = torch.stack([new_x, new_y], dim=-1).unsqueeze(0) - + # Warp frame warped_batch = F.grid_sample( - frame.unsqueeze(0), - grid, - mode='bilinear', - padding_mode='border', - align_corners=True + frame.unsqueeze(0), grid, mode="bilinear", padding_mode="border", align_corners=True ) - + result = warped_batch.squeeze(0) - + return result - + def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Image) -> Image.Image: """Concatenate current frame and warped previous frame for TemporalNet2 (6-channel input)""" # Convert to tensors and use tensor concatenation for consistency @@ -585,43 +577,43 @@ def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Im warped_tensor = self.pil_to_tensor(warped_image).squeeze(0) result_tensor = self._concatenate_frames_tensor(current_tensor, warped_tensor) return self.tensor_to_pil(result_tensor) - + def _concatenate_frames_tensor(self, current_tensor: torch.Tensor, warped_tensor: torch.Tensor) -> torch.Tensor: """ Concatenate current frame and warped previous frame tensors for TemporalNet2 (6-channel input) - + Args: current_tensor: Current input frame tensor (CHW format) warped_tensor: Warped previous frame tensor (CHW format) - + Returns: Concatenated tensor (6CHW format) """ # Ensure same size if current_tensor.shape != warped_tensor.shape: target_width, target_height = self.get_target_dimensions() - + if current_tensor.shape[-2:] != (target_height, target_width): current_tensor = F.interpolate( current_tensor.unsqueeze(0), size=(target_height, target_width), - mode='bilinear', - align_corners=False + mode="bilinear", + align_corners=False, ).squeeze(0) - + if warped_tensor.shape[-2:] != (target_height, target_width): warped_tensor = F.interpolate( warped_tensor.unsqueeze(0), size=(target_height, target_width), - mode='bilinear', - align_corners=False + mode="bilinear", + align_corners=False, ).squeeze(0) - + # Concatenate along channel dimension: [current_R, current_G, current_B, warped_R, warped_G, warped_B] concatenated = torch.cat([current_tensor, warped_tensor], dim=0) - + return concatenated - + def reset(self): """ Reset the preprocessor state (useful for new sequences) @@ -631,4 +623,4 @@ def reset(self): # Clear caches to free memory self._grid_cache.clear() self._tensor_cache.clear() - torch.cuda.empty_cache() \ No newline at end of file + torch.cuda.empty_cache() diff --git a/src/streamdiffusion/preprocessing/processors/upscale.py b/src/streamdiffusion/preprocessing/processors/upscale.py index 82659b7b8..38a69d499 100644 --- a/src/streamdiffusion/preprocessing/processors/upscale.py +++ b/src/streamdiffusion/preprocessing/processors/upscale.py @@ -1,7 +1,9 @@ +from typing import Literal + import torch import torch.nn.functional as F from PIL import Image -from typing import Literal + from .base import BasePreprocessor @@ -10,8 +12,8 @@ class UpscalePreprocessor(BasePreprocessor): Image upscaling preprocessor with multiple interpolation algorithms. Supports bilinear, lanczos, bicubic, and nearest neighbor upscaling. """ - - @classmethod + + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Upscale", @@ -21,71 +23,71 @@ def get_preprocessor_metadata(cls): "type": "float", "default": 2.0, "range": [1.0, 4.0], - "description": "Upscaling factor" + "description": "Upscaling factor", }, "algorithm": { "type": "str", "default": "bilinear", "options": ["bilinear", "lanczos", "bicubic", "nearest"], - "description": "Interpolation algorithm: bilinear (fast), lanczos (high quality), bicubic (balanced), nearest (pixel art)" - } + "description": "Interpolation algorithm: bilinear (fast), lanczos (high quality), bicubic (balanced), nearest (pixel art)", + }, }, - "use_cases": ["Real-time upscaling", "Image enhancement", "Resolution conversion"] + "use_cases": ["Real-time upscaling", "Image enhancement", "Resolution conversion"], } - - def __init__(self, scale_factor: float = 2.0, algorithm: Literal["bilinear", "lanczos", "bicubic", "nearest"] = "bilinear", **kwargs): + + def __init__( + self, + scale_factor: float = 2.0, + algorithm: Literal["bilinear", "lanczos", "bicubic", "nearest"] = "bilinear", + **kwargs, + ): super().__init__(scale_factor=scale_factor, algorithm=algorithm, **kwargs) self.scale_factor = scale_factor self.algorithm = algorithm - + # Map algorithm names to PIL and PyTorch modes self.pil_resample_map = { "bilinear": Image.BILINEAR, "lanczos": Image.LANCZOS, "bicubic": Image.BICUBIC, - "nearest": Image.NEAREST + "nearest": Image.NEAREST, } - + self.torch_mode_map = { "bilinear": "bilinear", "lanczos": "bicubic", # PyTorch doesn't have lanczos, use bicubic as closest "bicubic": "bicubic", - "nearest": "nearest" + "nearest": "nearest", } - + def _process_core(self, image: Image.Image) -> Image.Image: """PIL-based upscaling""" target_width, target_height = self.get_target_dimensions() resample_method = self.pil_resample_map.get(self.algorithm, Image.BILINEAR) return image.resize((target_width, target_height), resample_method) - + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """Tensor-based upscaling""" target_width, target_height = self.get_target_dimensions() - + if tensor.dim() == 3: tensor = tensor.unsqueeze(0) - + mode = self.torch_mode_map.get(self.algorithm, "bilinear") - + if mode in ["bilinear", "bicubic"]: - return F.interpolate(tensor, size=(target_height, target_width), - mode=mode, align_corners=False) + return F.interpolate(tensor, size=(target_height, target_width), mode=mode, align_corners=False) else: # nearest - return F.interpolate(tensor, size=(target_height, target_width), - mode=mode) - + return F.interpolate(tensor, size=(target_height, target_width), mode=mode) + def get_target_dimensions(self): """Handle scale factor for dimensions""" - width = self.params.get('image_width') - height = self.params.get('image_height') - + width = self.params.get("image_width") + height = self.params.get("image_height") + if width is not None and height is not None: return (int(width * self.scale_factor), int(height * self.scale_factor)) - - base_resolution = self.params.get('image_resolution', 512) + + base_resolution = self.params.get("image_resolution", 512) target_resolution = int(base_resolution * self.scale_factor) return (target_resolution, target_resolution) - - - diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 8f81b6f89..a344d3351 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -1,15 +1,18 @@ -from typing import List, Optional, Dict, Tuple, Literal, Any, Callable +import logging import threading +from typing import Any, Dict, List, Literal, Optional, Tuple + import torch import torch.nn.functional as F -import gc -import logging + logger = logging.getLogger(__name__) from .preprocessing.orchestrator_user import OrchestratorUser + class CacheStats: """Helper class to track cache statistics""" + def __init__(self): self.hits = 0 self.misses = 0 @@ -22,7 +25,13 @@ def record_miss(self): class StreamParameterUpdater(OrchestratorUser): - def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: bool = True, normalize_seed_weights: bool = True): + def __init__( + self, + stream_diffusion, + wrapper=None, + normalize_prompt_weights: bool = True, + normalize_seed_weights: bool = True, + ): self.stream = stream_diffusion self.wrapper = wrapper # Reference to wrapper for accessing pipeline structure self.normalize_prompt_weights = normalize_prompt_weights @@ -39,8 +48,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._seed_cache: Dict[int, Dict] = {} self._current_seed_list: List[Tuple[int, float]] = [] self._seed_cache_stats = CacheStats() - - + # Attach shared orchestrator once (lazy-creates on stream if absent) self.attach_orchestrator(self.stream) @@ -50,6 +58,7 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._current_style_images: Dict[str, Any] = {} # Use the shared orchestrator attached via OrchestratorUser self._embedding_orchestrator = self._preprocessing_orchestrator + def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses @@ -68,7 +77,7 @@ def get_cache_info(self) -> Dict: "seed_cache_hits": self._seed_cache_stats.hits, "seed_cache_misses": self._seed_cache_stats.misses, "seed_hit_rate": f"{seed_hit_rate:.2%}", - "current_seeds": len(self._current_seed_list) + "current_seeds": len(self._current_seed_list), } def clear_caches(self) -> None: @@ -81,7 +90,7 @@ def clear_caches(self) -> None: self._seed_cache.clear() self._current_seed_list.clear() self._seed_cache_stats = CacheStats() - + # Clear embedding caches self._embedding_cache.clear() self._current_style_images.clear() @@ -93,13 +102,13 @@ def get_normalize_prompt_weights(self) -> bool: def get_normalize_seed_weights(self) -> bool: """Get the current seed weight normalization setting.""" return self.normalize_seed_weights - + # Deprecated enhancer registration removed; embedding composition is handled via stream.embedding_hooks def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: str) -> None: """ Register an embedding preprocessor for parallel processing. - + Args: preprocessor: IPAdapterEmbeddingPreprocessor instance style_image_key: Unique key for the style image this preprocessor handles @@ -108,28 +117,27 @@ def register_embedding_preprocessor(self, preprocessor: Any, style_image_key: st # Ensure orchestrator is present self.attach_orchestrator(self.stream) self._embedding_orchestrator = self._preprocessing_orchestrator - + self._embedding_preprocessors.append((preprocessor, style_image_key)) - + def unregister_embedding_preprocessor(self, style_image_key: str) -> None: """Unregister an embedding preprocessor by style image key.""" original_count = len(self._embedding_preprocessors) self._embedding_preprocessors = [ - (preprocessor, key) for preprocessor, key in self._embedding_preprocessors - if key != style_image_key + (preprocessor, key) for preprocessor, key in self._embedding_preprocessors if key != style_image_key ] removed_count = original_count - len(self._embedding_preprocessors) - + # Clear cached embeddings for this key if style_image_key in self._embedding_cache: del self._embedding_cache[style_image_key] if style_image_key in self._current_style_images: del self._current_style_images[style_image_key] - + def update_style_image(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: """ Update a style image and trigger embedding preprocessing. - + Args: style_image_key: Unique key for the style image style_image: The style image (PIL Image, path, etc.) @@ -138,14 +146,16 @@ def update_style_image(self, style_image_key: str, style_image: Any, is_stream: """ # Store the style image self._current_style_images[style_image_key] = style_image - + # Trigger preprocessing for this style image self._preprocess_style_image_parallel(style_image_key, style_image, is_stream) - - def _preprocess_style_image_parallel(self, style_image_key: str, style_image: Any, is_stream: bool = False) -> None: + + def _preprocess_style_image_parallel( + self, style_image_key: str, style_image: Any, is_stream: bool = False + ) -> None: """ Preprocessing for a specific style image with mode selection - + Args: style_image_key: Unique key for the style image style_image: The style image to process @@ -153,57 +163,47 @@ def _preprocess_style_image_parallel(self, style_image_key: str, style_image: An """ if not self._embedding_preprocessors or self._embedding_orchestrator is None: return - + # Find preprocessors for this key relevant_preprocessors = [ - preprocessor for preprocessor, key in self._embedding_preprocessors - if key == style_image_key + preprocessor for preprocessor, key in self._embedding_preprocessors if key == style_image_key ] - + if not relevant_preprocessors: return - + # Choose processing mode based on is_stream parameter try: if is_stream: # Pipelined processing - optimized for throughput with 1-frame lag embedding_results = self._embedding_orchestrator.process_pipelined( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, "ipadapter" ) else: # Synchronous processing - immediate results for discrete updates embedding_results = self._embedding_orchestrator.process_sync( - style_image, - relevant_preprocessors, - None, - self.stream.width, - self.stream.height, - None, - "ipadapter" + style_image, relevant_preprocessors, None, self.stream.width, self.stream.height, None, "ipadapter" ) - + # Cache results for this style image key if embedding_results and embedding_results[0] is not None: self._embedding_cache[style_image_key] = embedding_results[0] else: # This is an error condition - we should always have results - raise RuntimeError(f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'") - - except Exception as e: + raise RuntimeError( + f"_preprocess_style_image_parallel: Failed to generate embeddings for style image '{style_image_key}'" + ) + + except Exception: import traceback + traceback.print_exc() - + def get_cached_embeddings(self, style_image_key: str) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: """Get cached embeddings for a style image key""" cached_result = self._embedding_cache.get(style_image_key, None) return cached_result - def _normalize_weights(self, weights: List[float], normalize: bool) -> torch.Tensor: """Generic weight normalization helper""" weights_tensor = torch.tensor(weights, device=self.stream.device, dtype=self.stream.dtype) @@ -218,7 +218,7 @@ def _validate_index(self, index: int, item_list: List, operation_name: str) -> b return False if index < 0 or index >= len(item_list): - logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list)-1})") + logger.warning(f"{operation_name}: Warning: Index {index} out of range (0-{len(item_list) - 1})") return False return True @@ -281,28 +281,27 @@ def update_stream_params( f"provided t_index_list (max index: {max_t_index}). Adjusting to {max_t_index + 1}." ) num_inference_steps = max_t_index + 1 - + old_num_steps = len(self.stream.timesteps) self.stream.scheduler.set_timesteps(num_inference_steps, self.stream.device) self.stream.timesteps = self.stream.scheduler.timesteps.to(self.stream.device) - + # If t_index_list wasn't explicitly provided, rescale existing t_list proportionally if t_index_list is None and old_num_steps > 0: # Rescale each index proportionally to the new number of steps # e.g., if t_list = [0, 16, 32, 45] with 50 steps -> [0, 3, 6, 8] with 9 steps scale_factor = (num_inference_steps - 1) / (old_num_steps - 1) if old_num_steps > 1 else 1.0 - t_index_list = [ - min(round(t * scale_factor), num_inference_steps - 1) - for t in self.stream.t_list - ] - + t_index_list = [min(round(t * scale_factor), num_inference_steps - 1) for t in self.stream.t_list] + # Now update timestep-dependent parameters with the correct t_index_list if t_index_list is not None: self._recalculate_timestep_dependent_params(t_index_list) if guidance_scale is not None: if self.stream.cfg_type == "none" and guidance_scale > 1.0: - logger.warning("update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect") + logger.warning( + "update_stream_params: Warning: guidance_scale > 1.0 with cfg_type='none' will have no effect" + ) self.stream.guidance_scale = guidance_scale if delta is not None: @@ -310,7 +309,7 @@ def update_stream_params( if seed is not None: self._update_seed(seed) - + if normalize_prompt_weights is not None: self.normalize_prompt_weights = normalize_prompt_weights logger.info(f"update_stream_params: Prompt weight normalization set to {normalize_prompt_weights}") @@ -324,44 +323,42 @@ def update_stream_params( self._update_blended_prompts( prompt_list=prompt_list, negative_prompt=negative_prompt or self._current_negative_prompt, - prompt_interpolation_method=prompt_interpolation_method + prompt_interpolation_method=prompt_interpolation_method, ) # Handle seed blending if seed_list is provided if seed_list is not None: - self._update_blended_seeds( - seed_list=seed_list, - interpolation_method=seed_interpolation_method - ) - + self._update_blended_seeds(seed_list=seed_list, interpolation_method=seed_interpolation_method) # Handle ControlNet configuration updates if controlnet_config is not None: - #TODO: happy path for control images + # TODO: happy path for control images self._update_controlnet_config(controlnet_config) - + # Handle IPAdapter configuration updates if ipadapter_config is not None: - logger.info(f"update_stream_params: Updating IPAdapter configuration") + logger.info("update_stream_params: Updating IPAdapter configuration") self._update_ipadapter_config(ipadapter_config) - + # Handle Hook configuration updates if image_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors") + logger.info( + f"update_stream_params: Updating image preprocessing configuration with {len(image_preprocessing_config)} processors" + ) logger.info(f"update_stream_params: image_preprocessing_config = {image_preprocessing_config}") - self._update_hook_config('image_preprocessing', image_preprocessing_config) - + self._update_hook_config("image_preprocessing", image_preprocessing_config) + if image_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating image postprocessing configuration") - self._update_hook_config('image_postprocessing', image_postprocessing_config) - + logger.info("update_stream_params: Updating image postprocessing configuration") + self._update_hook_config("image_postprocessing", image_postprocessing_config) + if latent_preprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent preprocessing configuration") - self._update_hook_config('latent_preprocessing', latent_preprocessing_config) - + logger.info("update_stream_params: Updating latent preprocessing configuration") + self._update_hook_config("latent_preprocessing", latent_preprocessing_config) + if latent_postprocessing_config is not None: - logger.info(f"update_stream_params: Updating latent postprocessing configuration") - self._update_hook_config('latent_postprocessing', latent_postprocessing_config) + logger.info("update_stream_params: Updating latent postprocessing configuration") + self._update_hook_config("latent_postprocessing", latent_postprocessing_config) if self.stream.kvo_cache: if cache_interval is not None: @@ -375,9 +372,7 @@ def update_stream_params( # runtime — resizing one-at-a-time races with TRT inference (causes "Dimensions # with name C must be equal" errors). cache_maxframes is a logical write window. actual_cache_size = ( - self.stream.kvo_cache[0].shape[1] - if self.stream.kvo_cache - else cache_maxframes + self.stream.kvo_cache[0].shape[1] if self.stream.kvo_cache else cache_maxframes ) if cache_maxframes > actual_cache_size: logger.warning( @@ -395,9 +390,7 @@ def update_stream_params( @torch.inference_mode() def update_prompt_weights( - self, - prompt_weights: List[float], - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt_weights: List[float], prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update weights for current prompt list without re-encoding prompts.""" if not self._current_prompt_list: @@ -405,7 +398,9 @@ def update_prompt_weights( return if len(prompt_weights) != len(self._current_prompt_list): - logger.warning(f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}") + logger.warning( + f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}" + ) return # Update the current prompt list with new weights @@ -420,9 +415,7 @@ def update_prompt_weights( @torch.inference_mode() def update_seed_weights( - self, - seed_weights: List[float], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_weights: List[float], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update weights for current seed list without regenerating noise.""" if not self._current_seed_list: @@ -430,7 +423,9 @@ def update_seed_weights( return if len(seed_weights) != len(self._current_seed_list): - logger.warning(f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}") + logger.warning( + f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}" + ) return # Update the current seed list with new weights @@ -448,7 +443,7 @@ def _update_blended_prompts( self, prompt_list: List[Tuple[str, float]], negative_prompt: str = "", - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", ) -> None: """Update prompt embeddings using multiple weighted prompts.""" # Store current state @@ -461,14 +456,10 @@ def _update_blended_prompts( # Apply blending self._apply_prompt_blending(prompt_interpolation_method) - def _cache_prompt_embeddings( - self, - prompt_list: List[Tuple[str, float]], - negative_prompt: str - ) -> None: + def _cache_prompt_embeddings(self, prompt_list: List[Tuple[str, float]], negative_prompt: str) -> None: """Cache prompt embeddings for efficient reuse.""" for idx, (prompt_text, weight) in enumerate(prompt_list): - if idx not in self._prompt_cache or self._prompt_cache[idx]['text'] != prompt_text: + if idx not in self._prompt_cache or self._prompt_cache[idx]["text"] != prompt_text: # Cache miss - encode the prompt self._prompt_cache_stats.record_miss() encoder_output = self.stream.pipe.encode_prompt( @@ -482,10 +473,7 @@ def _cache_prompt_embeddings( if len(self._prompt_cache) >= 32: oldest_key = next(iter(self._prompt_cache)) del self._prompt_cache[oldest_key] - self._prompt_cache[idx] = { - 'embed': encoder_output[0], - 'text': prompt_text - } + self._prompt_cache[idx] = {"embed": encoder_output[0], "text": prompt_text} else: # Cache hit self._prompt_cache_stats.record_hit() @@ -500,7 +488,7 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", for idx, (prompt_text, weight) in enumerate(self._current_prompt_list): if idx in self._prompt_cache: - embeddings.append(self._prompt_cache[idx]['embed']) + embeddings.append(self._prompt_cache[idx]["embed"]) weights.append(weight) if not embeddings: @@ -545,13 +533,14 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", # No CFG, just use the blended embeddings final_prompt_embeds = combined_embeds.repeat(self.stream.batch_size, 1, 1) final_negative_embeds = None # Will be set by enhancers if needed - + # Enhancer mechanism removed in favor of embedding_hooks # Run embedding hooks to compose final embeddings (e.g., append IP-Adapter tokens) try: - if hasattr(self.stream, 'embedding_hooks') and self.stream.embedding_hooks: + if hasattr(self.stream, "embedding_hooks") and self.stream.embedding_hooks: from .hooks import EmbedsCtx # local import to avoid cycles + embeds_ctx = EmbedsCtx( prompt_embeds=final_prompt_embeds, negative_prompt_embeds=final_negative_embeds, @@ -562,8 +551,9 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", final_negative_embeds = embeds_ctx.negative_prompt_embeds except Exception as e: import logging + logging.getLogger(__name__).error(f"_apply_prompt_blending: embedding hook failed: {e}") - + # Set final embeddings on stream self.stream.prompt_embeds = final_prompt_embeds if final_negative_embeds is not None: @@ -604,9 +594,7 @@ def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch. @torch.inference_mode() def _update_blended_seeds( - self, - seed_list: List[Tuple[int, float]], - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed_list: List[Tuple[int, float]], interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update seed tensors using multiple weighted seeds.""" # Store current state @@ -621,7 +609,7 @@ def _update_blended_seeds( def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: """Cache seed noise tensors for efficient reuse.""" for idx, (seed_value, weight) in enumerate(seed_list): - if idx not in self._seed_cache or self._seed_cache[idx]['seed'] != seed_value: + if idx not in self._seed_cache or self._seed_cache[idx]["seed"] != seed_value: # Cache miss - generate noise for the seed self._seed_cache_stats.record_miss() generator = torch.Generator(device=self.stream.device) @@ -631,13 +619,10 @@ def _cache_seed_noise(self, seed_list: List[Tuple[int, float]]) -> None: (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[idx] = { - 'noise': noise, - 'seed': seed_value - } + self._seed_cache[idx] = {"noise": noise, "seed": seed_value} else: # Cache hit self._seed_cache_stats.record_hit() @@ -652,7 +637,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) for idx, (seed_value, weight) in enumerate(self._current_seed_list): if idx in self._seed_cache: - noise_tensors.append(self._seed_cache[idx]['noise']) + noise_tensors.append(self._seed_cache[idx]["noise"]) weights.append(weight) if not noise_tensors: @@ -673,7 +658,7 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) combined_noise = torch.zeros_like(noise_tensors[0]) for noise, weight in zip(noise_tensors, weights): combined_noise += weight * noise - + # Preserve noise magnitude when weights are normalized if self.normalize_seed_weights and len(noise_tensors) > 1: original_magnitude = torch.mean(torch.stack([torch.norm(noise) for noise in noise_tensors])) @@ -748,6 +733,7 @@ def _update_seed(self, seed: int) -> None: def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" from diffusers import LCMScheduler + if isinstance(self.stream.scheduler, LCMScheduler): c_skip, c_out = self.stream.scheduler.get_scalings_for_boundary_condition_discrete(timestep) return c_skip, c_out @@ -765,9 +751,7 @@ def _update_timestep_calculations(self) -> None: for t in self.stream.t_list: self.stream.sub_timesteps.append(self.stream.timesteps[t]) - sub_timesteps_tensor = torch.tensor( - self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device - ) + sub_timesteps_tensor = torch.tensor(self.stream.sub_timesteps, dtype=torch.long, device=self.stream.device) self.stream.sub_timesteps_tensor = torch.repeat_interleave( sub_timesteps_tensor, repeats=self.stream.frame_bff_size if self.stream.use_denoising_batch else 1, @@ -793,12 +777,8 @@ def _update_timestep_calculations(self) -> None: ) if self.stream.use_denoising_batch: - self.stream.c_skip = torch.repeat_interleave( - self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0 - ) - self.stream.c_out = torch.repeat_interleave( - self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0 - ) + self.stream.c_skip = torch.repeat_interleave(self.stream.c_skip, repeats=self.stream.frame_bff_size, dim=0) + self.stream.c_out = torch.repeat_interleave(self.stream.c_out, repeats=self.stream.frame_bff_size, dim=0) # Update alpha_prod_t_sqrt and beta_prod_t_sqrt alpha_prod_t_sqrt_list = [] @@ -838,29 +818,25 @@ def _update_timestep_values_only(self, t_index_list: List[int]) -> None: def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> None: """Recalculate all parameters that depend on t_index_list.""" - + # Check if this is a structural change (length) or just value change if len(t_index_list) == len(self.stream.t_list): # Same length - only values changed, use lightweight update (working branch behavior) self._update_timestep_values_only(t_index_list) return - + # Length changed - do full recalculation including batch-dependent parameters (broken branch logic - but it works for this case!) self.stream.t_list = t_index_list self.stream.denoising_steps_num = len(self.stream.t_list) old_batch_size = self.stream.batch_size - + if self.stream.use_denoising_batch: self.stream.batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size if self.stream.cfg_type == "initialize": - self.stream.trt_unet_batch_size = ( - self.stream.denoising_steps_num + 1 - ) * self.stream.frame_bff_size + self.stream.trt_unet_batch_size = (self.stream.denoising_steps_num + 1) * self.stream.frame_bff_size elif self.stream.cfg_type == "full": - self.stream.trt_unet_batch_size = ( - 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size - ) + self.stream.trt_unet_batch_size = 2 * self.stream.denoising_steps_num * self.stream.frame_bff_size else: self.stream.trt_unet_batch_size = self.stream.denoising_steps_num * self.stream.frame_bff_size else: @@ -891,27 +867,33 @@ def _recalculate_timestep_dependent_params(self, t_index_list: List[int]) -> Non # Resize kvo_cache tensors if batch size changed if self.stream.kvo_cache and old_batch_size != self.stream.batch_size: - logger.info(f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}") + logger.info( + f"_recalculate_timestep_dependent_params: Resizing kvo_cache tensors from batch_size {old_batch_size} to {self.stream.batch_size}" + ) for i, cache_tensor in enumerate(self.stream.kvo_cache): # KVO cache shape: (2, cache_maxframes, batch_size, seq_length, hidden_dim) current_shape = cache_tensor.shape - new_shape = (current_shape[0], current_shape[1], self.stream.batch_size, current_shape[3], current_shape[4]) - new_cache_tensor = torch.zeros( - new_shape, - dtype=cache_tensor.dtype, - device=cache_tensor.device + new_shape = ( + current_shape[0], + current_shape[1], + self.stream.batch_size, + current_shape[3], + current_shape[4], ) - + new_cache_tensor = torch.zeros(new_shape, dtype=cache_tensor.dtype, device=cache_tensor.device) + # Copy over as much data as possible from old cache min_batch = min(old_batch_size, self.stream.batch_size) new_cache_tensor[:, :, :min_batch, :, :] = cache_tensor[:, :, :min_batch, :, :] - + self.stream.kvo_cache[i] = new_cache_tensor # Drop bucketed storage refs so update_kvo_cache falls back to # per-layer writes against the new tensors. self.stream._kvo_buckets = None self.stream._kvo_outputs_by_bucket = None - logger.info(f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}") + logger.info( + f"_recalculate_timestep_dependent_params: KVO cache tensors resized to new batch_size {self.stream.batch_size}" + ) # Update timestep-dependent calculations (shared with value-only path) self._update_timestep_calculations() @@ -930,10 +912,7 @@ def _recalculate_controlnet_inputs(self, width: int, height: int) -> None: @torch.inference_mode() def update_prompt_at_index( - self, - index: int, - new_prompt: str, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, new_prompt: str, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Update a single prompt at the specified index without re-encoding others.""" if not self._validate_index(index, self._current_prompt_list, "update_prompt_at_index"): @@ -947,11 +926,11 @@ def update_prompt_at_index( self._cache_prompt_embeddings([(new_prompt, weight)], self._current_negative_prompt) # Update cache index to point to the new prompt - if index in self._prompt_cache and self._prompt_cache[index]['text'] != new_prompt: + if index in self._prompt_cache and self._prompt_cache[index]["text"] != new_prompt: # Find if this prompt is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._prompt_cache.items(): - if cache_data['text'] == new_prompt: + if cache_data["text"] == new_prompt: existing_cache_key = cache_idx break @@ -969,10 +948,7 @@ def update_prompt_at_index( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[index] = { - 'embed': encoder_output[0], - 'text': new_prompt - } + self._prompt_cache[index] = {"embed": encoder_output[0], "text": new_prompt} # Recompute blended embeddings with updated prompt self._apply_prompt_blending(prompt_interpolation_method) @@ -984,16 +960,12 @@ def get_current_prompts(self) -> List[Tuple[str, float]]: @torch.inference_mode() def add_prompt( - self, - prompt: str, - weight: float = 1.0, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, prompt: str, weight: float = 1.0, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Add a new prompt to the current list.""" new_index = len(self._current_prompt_list) self._current_prompt_list.append((prompt, weight)) - # Cache the new prompt encoder_output = self.stream.pipe.encode_prompt( prompt=prompt, @@ -1002,10 +974,7 @@ def add_prompt( do_classifier_free_guidance=False, negative_prompt=self._current_negative_prompt, ) - self._prompt_cache[new_index] = { - 'embed': encoder_output[0], - 'text': prompt - } + self._prompt_cache[new_index] = {"embed": encoder_output[0], "text": prompt} self._prompt_cache_stats.record_miss() # Recompute blended embeddings @@ -1013,9 +982,7 @@ def add_prompt( @torch.inference_mode() def remove_prompt_at_index( - self, - index: int, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + self, index: int, prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" ) -> None: """Remove a prompt at the specified index.""" if not self._validate_index(index, self._current_prompt_list, "remove_prompt_at_index"): @@ -1040,10 +1007,7 @@ def remove_prompt_at_index( @torch.inference_mode() def update_seed_at_index( - self, - index: int, - new_seed: int, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, index: int, new_seed: int, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Update a single seed at the specified index without regenerating others.""" if not self._validate_index(index, self._current_seed_list, "update_seed_at_index"): @@ -1053,16 +1017,15 @@ def update_seed_at_index( old_seed, weight = self._current_seed_list[index] self._current_seed_list[index] = (new_seed, weight) - # Cache the new seed noise self._cache_seed_noise([(new_seed, weight)]) # Update cache index to point to the new seed - if index in self._seed_cache and self._seed_cache[index]['seed'] != new_seed: + if index in self._seed_cache and self._seed_cache[index]["seed"] != new_seed: # Find if this seed is already cached elsewhere existing_cache_key = None for cache_idx, cache_data in self._seed_cache.items(): - if cache_data['seed'] == new_seed: + if cache_data["seed"] == new_seed: existing_cache_key = cache_idx break @@ -1080,13 +1043,10 @@ def update_seed_at_index( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[index] = { - 'noise': noise, - 'seed': new_seed - } + self._seed_cache[index] = {"noise": noise, "seed": new_seed} # Recompute blended noise with updated seed self._apply_seed_blending(interpolation_method) @@ -1098,10 +1058,7 @@ def get_current_seeds(self) -> List[Tuple[int, float]]: @torch.inference_mode() def add_seed( - self, - seed: int, - weight: float = 1.0, - interpolation_method: Literal["linear", "slerp"] = "linear" + self, seed: int, weight: float = 1.0, interpolation_method: Literal["linear", "slerp"] = "linear" ) -> None: """Add a new seed to the current list.""" new_index = len(self._current_seed_list) @@ -1117,24 +1074,17 @@ def add_seed( (self.stream.batch_size, 4, self.stream.latent_height, self.stream.latent_width), generator=generator, device=self.stream.device, - dtype=self.stream.dtype + dtype=self.stream.dtype, ) - self._seed_cache[new_index] = { - 'noise': noise, - 'seed': seed - } + self._seed_cache[new_index] = {"noise": noise, "seed": seed} self._seed_cache_stats.record_miss() # Recompute blended noise self._apply_seed_blending(interpolation_method) @torch.inference_mode() - def remove_seed_at_index( - self, - index: int, - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: + def remove_seed_at_index(self, index: int, interpolation_method: Literal["linear", "slerp"] = "linear") -> None: """Remove a seed at the specified index.""" if not self._validate_index(index, self._current_seed_list, "remove_seed_at_index"): return @@ -1159,7 +1109,7 @@ def remove_seed_at_index( def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> None: """ Update ControlNet configuration by diffing current vs desired state. - + Args: desired_config: Complete ControlNet configuration list defining the desired state. Each dict contains: model_id, preprocessor, conditioning_scale, enabled, etc. @@ -1167,41 +1117,47 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Find the ControlNet pipeline/module (module-aware) controlnet_pipeline = self._get_controlnet_pipeline() if not controlnet_pipeline: - logger.debug("_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)") + logger.debug( + "_update_controlnet_config: No ControlNet pipeline found (expected when ControlNet not loaded)" + ) return - + current_config = self._get_current_controlnet_config() - + # Simple approach: detect what changed and apply minimal updates - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} - desired_models = {cfg['model_id']: cfg for cfg in desired_config} - + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } + desired_models = {cfg["model_id"]: cfg for cfg in desired_config} + # Reorder to match desired order (module supports stable reordering) try: - desired_order = [cfg['model_id'] for cfg in desired_config if 'model_id' in cfg] - if hasattr(controlnet_pipeline, 'reorder_controlnets_by_model_ids'): + desired_order = [cfg["model_id"] for cfg in desired_config if "model_id" in cfg] + if hasattr(controlnet_pipeline, "reorder_controlnets_by_model_ids"): controlnet_pipeline.reorder_controlnets_by_model_ids(desired_order) except Exception: pass # Recompute current models after potential reorder - current_models = {i: getattr(cn, 'model_id', f'controlnet_{i}') for i, cn in enumerate(controlnet_pipeline.controlnets)} + current_models = { + i: getattr(cn, "model_id", f"controlnet_{i}") for i, cn in enumerate(controlnet_pipeline.controlnets) + } # Remove controlnets not in desired config for i in reversed(range(len(controlnet_pipeline.controlnets))): - model_id = current_models.get(i, f'controlnet_{i}') + model_id = current_models.get(i, f"controlnet_{i}") if model_id not in desired_models: logger.info(f"_update_controlnet_config: Removing ControlNet {model_id}") try: controlnet_pipeline.remove_controlnet(i) except Exception: raise - + # Add new controlnets and update existing ones for desired_cfg in desired_config: - model_id = desired_cfg['model_id'] + model_id = desired_cfg["model_id"] existing_index = next((i for i, mid in current_models.items() if mid == model_id), None) - + if existing_index is None: # Add new controlnet logger.info(f"_update_controlnet_config: Adding ControlNet {model_id}") @@ -1209,15 +1165,16 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non # Prefer module path: construct ControlNetConfig try: from .modules.controlnet_module import ControlNetConfig # type: ignore + cn_cfg = ControlNetConfig( - model_id=desired_cfg.get('model_id'), - preprocessor=desired_cfg.get('preprocessor'), - conditioning_scale=desired_cfg.get('conditioning_scale', 1.0), - enabled=desired_cfg.get('enabled', True), - conditioning_channels=desired_cfg.get('conditioning_channels'), - preprocessor_params=desired_cfg.get('preprocessor_params'), + model_id=desired_cfg.get("model_id"), + preprocessor=desired_cfg.get("preprocessor"), + conditioning_scale=desired_cfg.get("conditioning_scale", 1.0), + enabled=desired_cfg.get("enabled", True), + conditioning_channels=desired_cfg.get("conditioning_channels"), + preprocessor_params=desired_cfg.get("preprocessor_params"), ) - controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get('control_image')) + controlnet_pipeline.add_controlnet(cn_cfg, desired_cfg.get("control_image")) except Exception: # No fallback raise @@ -1225,114 +1182,136 @@ def _update_controlnet_config(self, desired_config: List[Dict[str, Any]]) -> Non logger.error(f"_update_controlnet_config: add_controlnet failed for {model_id}: {e}") else: # Update existing controlnet - if 'conditioning_scale' in desired_cfg: - current_scale = current_config[existing_index].get('conditioning_scale', 1.0) - desired_scale = desired_cfg['conditioning_scale'] - + if "conditioning_scale" in desired_cfg: + current_scale = current_config[existing_index].get("conditioning_scale", 1.0) + desired_scale = desired_cfg["conditioning_scale"] + if current_scale != desired_scale: - logger.info(f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}") - if hasattr(controlnet_pipeline, 'controlnet_scales') and 0 <= existing_index < len(controlnet_pipeline.controlnet_scales): + logger.info( + f"_update_controlnet_config: Updating {model_id} scale: {current_scale} → {desired_scale}" + ) + if hasattr(controlnet_pipeline, "controlnet_scales") and 0 <= existing_index < len( + controlnet_pipeline.controlnet_scales + ): controlnet_pipeline.controlnet_scales[existing_index] = float(desired_scale) - + # Enable/disable toggle - if 'enabled' in desired_cfg and hasattr(controlnet_pipeline, 'enabled_list'): + if "enabled" in desired_cfg and hasattr(controlnet_pipeline, "enabled_list"): if 0 <= existing_index < len(controlnet_pipeline.enabled_list): - controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg['enabled']) + controlnet_pipeline.enabled_list[existing_index] = bool(desired_cfg["enabled"]) - if 'preprocessor_params' in desired_cfg and hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[existing_index]: + if ( + "preprocessor_params" in desired_cfg + and hasattr(controlnet_pipeline, "preprocessors") + and controlnet_pipeline.preprocessors[existing_index] + ): preprocessor = controlnet_pipeline.preprocessors[existing_index] - preprocessor.params.update(desired_cfg['preprocessor_params']) - for param_name, param_value in desired_cfg['preprocessor_params'].items(): + preprocessor.params.update(desired_cfg["preprocessor_params"]) + for param_name, param_value in desired_cfg["preprocessor_params"].items(): if hasattr(preprocessor, param_name): setattr(preprocessor, param_name, param_value) - + # Pipeline references are now automatically managed during preprocessor creation # No need to manually re-establish pipeline references for pipeline-aware processors - def _get_controlnet_pipeline(self): """ Get the ControlNet module or legacy pipeline from the structure (module-aware). """ # Module-installed path - if hasattr(self.stream, '_controlnet_module'): + if hasattr(self.stream, "_controlnet_module"): return self.stream._controlnet_module # Legacy paths - if hasattr(self.stream, 'controlnets'): + if hasattr(self.stream, "controlnets"): return self.stream - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'controlnets'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "controlnets"): return self.stream.stream - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, '_controlnet_module'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "_controlnet_module"): return self.wrapper.stream._controlnet_module - if hasattr(self.wrapper.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "controlnets"): return self.wrapper.stream - if hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'controlnets'): + if hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "controlnets"): return self.wrapper.stream.stream return None def _get_current_controlnet_config(self) -> List[Dict[str, Any]]: """ Get current ControlNet configuration state. - + Returns: List of current ControlNet configurations """ controlnet_pipeline = self._get_controlnet_pipeline() - if not controlnet_pipeline or not hasattr(controlnet_pipeline, 'controlnets') or not controlnet_pipeline.controlnets: + if ( + not controlnet_pipeline + or not hasattr(controlnet_pipeline, "controlnets") + or not controlnet_pipeline.controlnets + ): return [] - + current_config = [] for i, controlnet in enumerate(controlnet_pipeline.controlnets): - model_id = getattr(controlnet, 'model_id', f'controlnet_{i}') - scale = controlnet_pipeline.controlnet_scales[i] if hasattr(controlnet_pipeline, 'controlnet_scales') and i < len(controlnet_pipeline.controlnet_scales) else 1.0 + model_id = getattr(controlnet, "model_id", f"controlnet_{i}") + scale = ( + controlnet_pipeline.controlnet_scales[i] + if hasattr(controlnet_pipeline, "controlnet_scales") and i < len(controlnet_pipeline.controlnet_scales) + else 1.0 + ) enabled_val = True try: - if hasattr(controlnet_pipeline, 'enabled_list') and i < len(controlnet_pipeline.enabled_list): + if hasattr(controlnet_pipeline, "enabled_list") and i < len(controlnet_pipeline.enabled_list): enabled_val = bool(controlnet_pipeline.enabled_list[i]) except Exception: enabled_val = True config = { - 'model_id': model_id, - 'conditioning_scale': scale, - 'preprocessor_params': getattr(controlnet_pipeline.preprocessors[i], 'params', {}) if hasattr(controlnet_pipeline, 'preprocessors') and controlnet_pipeline.preprocessors[i] else {}, - 'enabled': enabled_val, + "model_id": model_id, + "conditioning_scale": scale, + "preprocessor_params": getattr(controlnet_pipeline.preprocessors[i], "params", {}) + if hasattr(controlnet_pipeline, "preprocessors") and controlnet_pipeline.preprocessors[i] + else {}, + "enabled": enabled_val, } current_config.append(config) - + return current_config def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: """ Update IPAdapter configuration. - + Args: - desired_config: IPAdapter configuration dict containing: + desired_config: IPAdapter configuration dict containing: ipadapter_model_path, image_encoder_path, style_image, scale, enabled, etc. """ # Find the IPAdapter pipeline ipadapter_pipeline = self._get_ipadapter_pipeline() - + if not ipadapter_pipeline: - logger.warning(f"_update_ipadapter_config: No IPAdapter pipeline found") + logger.warning("_update_ipadapter_config: No IPAdapter pipeline found") return - - if 'scale' in desired_config and desired_config['scale'] is not None: - desired_scale = float(desired_config['scale']) + + if "scale" in desired_config and desired_config["scale"] is not None: + desired_scale = float(desired_config["scale"]) # Get current scale from IPAdapter instance - current_scale = getattr(self.stream.ipadapter, 'scale', 1.0) if hasattr(self.stream, 'ipadapter') else 1.0 - + current_scale = getattr(self.stream.ipadapter, "scale", 1.0) if hasattr(self.stream, "ipadapter") else 1.0 + if current_scale != desired_scale: logger.info(f"_update_ipadapter_config: Updating scale: {current_scale} → {desired_scale}") - + # Get weight_type from IPAdapter instance - weight_type = getattr(self.stream.ipadapter, 'weight_type', None) if hasattr(self.stream, 'ipadapter') else None - + weight_type = ( + getattr(self.stream.ipadapter, "weight_type", None) if hasattr(self.stream, "ipadapter") else None + ) + # Apply scale with weight type consideration - if weight_type is not None and hasattr(self.stream, 'ipadapter'): + if weight_type is not None and hasattr(self.stream, "ipadapter"): try: from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) weights = build_layer_weights(num_layers, desired_scale, weight_type) if weights is not None: @@ -1340,47 +1319,51 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) + setattr(self.stream.ipadapter, "scale", desired_scale) except Exception: # Do not add fallback mechanisms raise else: # Simple uniform scale - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): # Tell diffusers_ipadapter to set the scale self.stream.ipadapter.set_scale(desired_scale) # Update our tracking attribute - setattr(self.stream.ipadapter, 'scale', desired_scale) - + setattr(self.stream.ipadapter, "scale", desired_scale) # Update enabled state if provided - if 'enabled' in desired_config and desired_config['enabled'] is not None: - enabled_state = bool(desired_config['enabled']) + if "enabled" in desired_config and desired_config["enabled"] is not None: + enabled_state = bool(desired_config["enabled"]) # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - current_enabled = getattr(self.stream.ipadapter, 'enabled', True) + if hasattr(self.stream, "ipadapter"): + current_enabled = getattr(self.stream.ipadapter, "enabled", True) if current_enabled != enabled_state: - logger.info(f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}") - setattr(self.stream.ipadapter, 'enabled', enabled_state) + logger.info( + f"_update_ipadapter_config: Updating enabled state: {current_enabled} → {enabled_state}" + ) + setattr(self.stream.ipadapter, "enabled", enabled_state) # Update weight type if provided (affects per-layer distribution and/or per-step factor) - if 'weight_type' in desired_config and desired_config['weight_type'] is not None: - weight_type = desired_config['weight_type'] + if "weight_type" in desired_config and desired_config["weight_type"] is not None: + weight_type = desired_config["weight_type"] # Update IPAdapter instance - if hasattr(self.stream, 'ipadapter'): - setattr(self.stream.ipadapter, 'weight_type', weight_type) - + if hasattr(self.stream, "ipadapter"): + setattr(self.stream.ipadapter, "weight_type", weight_type) + # For PyTorch UNet, immediately apply a per-layer scale vector so layers reflect selection types try: - is_tensorrt_engine = hasattr(self.stream.unet, 'engine') and hasattr(self.stream.unet, 'stream') + is_tensorrt_engine = hasattr(self.stream.unet, "engine") and hasattr(self.stream.unet, "stream") if not is_tensorrt_engine: # Compute per-layer vector using Diffusers_IPAdapter helper from diffusers_ipadapter.ip_adapter.attention_processor import build_layer_weights + # Count installed IP layers by scanning processors with _ip_layer_index - ip_procs = [p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index")] + ip_procs = [ + p for p in self.stream.pipe.unet.attn_processors.values() if hasattr(p, "_ip_layer_index") + ] num_layers = len(ip_procs) # Get base weight from IPAdapter instance - base_weight = float(getattr(self.stream.ipadapter, 'scale', 1.0)) + base_weight = float(getattr(self.stream.ipadapter, "scale", 1.0)) weights = build_layer_weights(num_layers, base_weight, weight_type) # If None, keep uniform base scale; else set per-layer vector if weights is not None: @@ -1388,7 +1371,7 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: else: self.stream.ipadapter.set_scale(base_weight) # Keep our tracking attribute in sync - setattr(self.stream.ipadapter, 'scale', base_weight) + setattr(self.stream.ipadapter, "scale", base_weight) except Exception: # Do not add fallback mechanisms raise @@ -1396,191 +1379,207 @@ def _update_ipadapter_config(self, desired_config: Dict[str, Any]) -> None: def _get_ipadapter_pipeline(self): """ Get the IPAdapter pipeline from the pipeline structure (following ControlNet pattern). - + Returns: IPAdapter pipeline object or None if not found """ # Check if stream is IPAdapter pipeline directly - if hasattr(self.stream, 'ipadapter'): + if hasattr(self.stream, "ipadapter"): return self.stream - + # Check if stream has nested stream (ControlNet wrapper) - if hasattr(self.stream, 'stream') and hasattr(self.stream.stream, 'ipadapter'): + if hasattr(self.stream, "stream") and hasattr(self.stream.stream, "ipadapter"): return self.stream.stream - + # Check if we have a wrapper reference and can access through it - if self.wrapper and hasattr(self.wrapper, 'stream'): - if hasattr(self.wrapper.stream, 'ipadapter'): + if self.wrapper and hasattr(self.wrapper, "stream"): + if hasattr(self.wrapper.stream, "ipadapter"): return self.wrapper.stream - elif hasattr(self.wrapper.stream, 'stream') and hasattr(self.wrapper.stream.stream, 'ipadapter'): + elif hasattr(self.wrapper.stream, "stream") and hasattr(self.wrapper.stream.stream, "ipadapter"): return self.wrapper.stream.stream - + return None def _get_current_ipadapter_config(self) -> Optional[Dict[str, Any]]: """ Get current IPAdapter configuration by introspecting the IPAdapter instance. - + Returns: Current IPAdapter configuration dict or None if no IPAdapter """ # Get config from IPAdapter instance - if hasattr(self.stream, 'ipadapter') and self.stream.ipadapter is not None: + if hasattr(self.stream, "ipadapter") and self.stream.ipadapter is not None: ipadapter = self.stream.ipadapter - + config = { - 'scale': getattr(ipadapter, 'scale', 1.0), - 'weight_type': getattr(ipadapter, 'weight_type', None), - 'enabled': getattr(ipadapter, 'enabled', True), # Check actual enabled state + "scale": getattr(ipadapter, "scale", 1.0), + "weight_type": getattr(ipadapter, "weight_type", None), + "enabled": getattr(ipadapter, "enabled", True), # Check actual enabled state } - + # Add static initialization fields - if hasattr(self.stream, '_ipadapter_module'): + if hasattr(self.stream, "_ipadapter_module"): module_config = self.stream._ipadapter_module.config - config.update({ - 'style_image_key': module_config.style_image_key, - 'num_image_tokens': module_config.num_image_tokens, - 'type': module_config.type.value, - }) - + config.update( + { + "style_image_key": module_config.style_image_key, + "num_image_tokens": module_config.num_image_tokens, + "type": module_config.type.value, + } + ) + # Check if style image is set ipadapter_pipeline = self._get_ipadapter_pipeline() - if ipadapter_pipeline and hasattr(ipadapter_pipeline, 'style_image') and ipadapter_pipeline.style_image: - config['has_style_image'] = True + if ipadapter_pipeline and hasattr(ipadapter_pipeline, "style_image") and ipadapter_pipeline.style_image: + config["has_style_image"] = True else: - config['has_style_image'] = False - + config["has_style_image"] = False + return config - + # No IPAdapter instance found return None def _get_current_hook_config(self, hook_type: str) -> List[Dict[str, Any]]: """ Get current hook configuration by introspecting the hook module state. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) - + Returns: List of processor configurations or empty list if no module """ # Get the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: return [] - + # Get processors from the module - processors = getattr(hook_module, 'processors', []) - + processors = getattr(hook_module, "processors", []) + config = [] for i, processor in enumerate(processors): proc_config = { - 'type': getattr(processor, '__class__').__name__, - 'order': getattr(processor, 'order', i), - 'enabled': getattr(processor, 'enabled', True), + "type": getattr(processor, "__class__").__name__, + "order": getattr(processor, "order", i), + "enabled": getattr(processor, "enabled", True), } - + # Try to get processor parameters - if hasattr(processor, 'params'): - proc_config['params'] = dict(processor.params) - + if hasattr(processor, "params"): + proc_config["params"] = dict(processor.params) + config.append(proc_config) - + return config def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any]]) -> None: """ Update hook configuration by modifying existing processors in-place instead of recreating them. - + Args: hook_type: Type of hook (image_preprocessing, image_postprocessing, etc.) desired_config: List of processor configurations """ logger.info(f"_update_hook_config: Updating {hook_type} with {len(desired_config)} processors") - + # Get or create the hook module module_attr_name = f"_{hook_type}_module" hook_module = getattr(self.stream, module_attr_name, None) - + if not hook_module: logger.info(f"_update_hook_config: No existing {hook_type} module, creating new one") # Create the appropriate hook module try: if hook_type in ["image_preprocessing", "image_postprocessing"]: - from streamdiffusion.modules.image_processing_module import ImagePreprocessingModule, ImagePostprocessingModule + from streamdiffusion.modules.image_processing_module import ( + ImagePostprocessingModule, + ImagePreprocessingModule, + ) + if hook_type == "image_preprocessing": hook_module = ImagePreprocessingModule() else: hook_module = ImagePostprocessingModule() elif hook_type in ["latent_preprocessing", "latent_postprocessing"]: - from streamdiffusion.modules.latent_processing_module import LatentPreprocessingModule, LatentPostprocessingModule + from streamdiffusion.modules.latent_processing_module import ( + LatentPostprocessingModule, + LatentPreprocessingModule, + ) + if hook_type == "latent_preprocessing": hook_module = LatentPreprocessingModule() else: hook_module = LatentPostprocessingModule() else: raise ValueError(f"Unknown hook type: {hook_type}") - + # Install the module hook_module.install(self.stream) setattr(self.stream, module_attr_name, hook_module) logger.info(f"_update_hook_config: Created and installed {hook_type} module") - + except Exception as e: logger.error(f"_update_hook_config: Failed to create {hook_type} module: {e}") return - - logger.info(f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors") - + + logger.info( + f"_update_hook_config: Found existing {hook_type} module with {len(hook_module.processors)} processors" + ) + # Modify existing processors in-place instead of clearing and recreating for i, proc_config in enumerate(desired_config): - processor_type = proc_config.get('type', 'unknown') - enabled = proc_config.get('enabled', True) - params = proc_config.get('params', {}) - + processor_type = proc_config.get("type", "unknown") + enabled = proc_config.get("enabled", True) + params = proc_config.get("params", {}) + logger.info(f"_update_hook_config: Processing config {i}: type={processor_type}, enabled={enabled}") - + if i < len(hook_module.processors): # Modify existing processor existing_processor = hook_module.processors[i] - + # Get the current processor type from registry name if available, otherwise use class name - current_type = existing_processor.params.get('_registry_name') if hasattr(existing_processor, 'params') else None + current_type = ( + existing_processor.params.get("_registry_name") if hasattr(existing_processor, "params") else None + ) if not current_type: current_type = existing_processor.__class__.__name__ - - logger.info(f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}") - + + logger.info( + f"_update_hook_config: Modifying existing processor {i}: {current_type} -> {processor_type}" + ) + # If processor type changed, replace it if current_type.lower() != processor_type.lower(): logger.info(f"_update_hook_config: Type changed, replacing processor {i}") try: from streamdiffusion.preprocessing.processors import get_preprocessor - + # Determine normalization context from hook type - if 'latent' in hook_type: - normalization_context = 'latent' + if "latent" in hook_type: + normalization_context = "latent" else: # Image preprocessing/postprocessing uses 'pipeline' context - normalization_context = 'pipeline' - + normalization_context = "pipeline" + new_processor = get_preprocessor( - processor_type, - pipeline_ref=getattr(self, 'stream', None), - normalization_context=normalization_context + processor_type, + pipeline_ref=getattr(self, "stream", None), + normalization_context=normalization_context, ) - + # Copy attributes from old processor - setattr(new_processor, 'order', getattr(existing_processor, 'order', i)) - setattr(new_processor, 'enabled', enabled) - + setattr(new_processor, "order", getattr(existing_processor, "order", i)) + setattr(new_processor, "enabled", enabled) + # Set parameters - if hasattr(new_processor, 'params'): + if hasattr(new_processor, "params"): new_processor.params.update(params) - + hook_module.processors[i] = new_processor logger.info(f"_update_hook_config: Successfully replaced processor {i} with {processor_type}") except Exception as e: @@ -1588,15 +1587,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any else: # Same type, just update attributes logger.info(f"_update_hook_config: Same type, updating attributes for processor {i}") - setattr(existing_processor, 'enabled', enabled) - + setattr(existing_processor, "enabled", enabled) + # Update parameters - if hasattr(existing_processor, 'params'): + if hasattr(existing_processor, "params"): existing_processor.params.update(params) for param_name, param_value in params.items(): if hasattr(existing_processor, param_name): setattr(existing_processor, param_name, param_value) - + logger.info(f"_update_hook_config: Updated processor {i} enabled={enabled}, params={params}") else: # Add new processor @@ -1606,12 +1605,15 @@ def _update_hook_config(self, hook_type: str, desired_config: List[Dict[str, Any logger.info(f"_update_hook_config: Successfully added processor {i}: {processor_type}") except Exception as e: logger.error(f"_update_hook_config: Failed to add processor {i}: {e}") - + # Remove extra processors if config is shorter while len(hook_module.processors) > len(desired_config): removed_idx = len(hook_module.processors) - 1 removed_processor = hook_module.processors.pop() - logger.info(f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}") - - logger.info(f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors") + logger.info( + f"_update_hook_config: Removed extra processor {removed_idx}: {removed_processor.__class__.__name__}" + ) + logger.info( + f"_update_hook_config: Finished updating {hook_type}, now has {len(hook_module.processors)} processors" + ) diff --git a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py index 355ad6fc1..2a7eee954 100644 --- a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py +++ b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py @@ -9,18 +9,19 @@ python -m streamdiffusion.tools.compile_depth_anything_tensorrt --model_size small --resolution 518 """ -import torch import logging -import os from pathlib import Path -from typing import Optional + import fire +import torch -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) try: import tensorrt as trt + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False @@ -28,6 +29,7 @@ try: import onnx + ONNX_AVAILABLE = True except ImportError: ONNX_AVAILABLE = False @@ -50,10 +52,7 @@ def export_depth_anything_to_onnx( - onnx_path: Path, - model_size: str = "small", - resolution: int = 518, - device: str = "cuda" + onnx_path: Path, model_size: str = "small", resolution: int = 518, device: str = "cuda" ) -> bool: """Export Depth Anything model to ONNX format""" try: @@ -102,6 +101,7 @@ def export_depth_anything_to_onnx( except Exception as e: logger.error(f"Failed to export ONNX: {e}") import traceback + traceback.print_exc() return False @@ -126,7 +126,7 @@ def build_tensorrt_engine( parser = trt.OnnxParser(network, trt_logger) # Parse ONNX - with open(onnx_path, 'rb') as f: + with open(onnx_path, "rb") as f: if not parser.parse(f.read()): for i in range(parser.num_errors): logger.error(f"ONNX parse error: {parser.get_error(i)}") @@ -142,10 +142,12 @@ def build_tensorrt_engine( # Set optimization profile for fixed resolution profile = builder.create_optimization_profile() - profile.set_shape("input", - (1, 3, resolution, resolution), # min - (1, 3, resolution, resolution), # opt - (1, 3, resolution, resolution)) # max + profile.set_shape( + "input", + (1, 3, resolution, resolution), # min + (1, 3, resolution, resolution), # opt + (1, 3, resolution, resolution), + ) # max config.add_optimization_profile(profile) # Build engine @@ -158,7 +160,7 @@ def build_tensorrt_engine( # Save engine engine_path.parent.mkdir(parents=True, exist_ok=True) - with open(engine_path, 'wb') as f: + with open(engine_path, "wb") as f: f.write(serialized_engine) logger.info(f"TensorRT engine saved: {engine_path}") @@ -167,6 +169,7 @@ def build_tensorrt_engine( except Exception as e: logger.error(f"Failed to build TensorRT engine: {e}") import traceback + traceback.print_exc() return False @@ -206,7 +209,7 @@ def compile_depth_anything( # Check if engine already exists if engine_path.exists(): logger.info(f"Engine already exists: {engine_path}") - overwrite = input("Overwrite? (y/N): ").lower().strip() == 'y' + overwrite = input("Overwrite? (y/N): ").lower().strip() == "y" if not overwrite: return @@ -228,9 +231,9 @@ def compile_depth_anything( logger.info("Removed intermediate ONNX file") logger.info(f"\nSuccess! Engine saved to: {engine_path}") - logger.info(f"\nTo use in config:") - logger.info(f' preprocessor: "depth_tensorrt"') - logger.info(f' preprocessor_params:') + logger.info("\nTo use in config:") + logger.info(' preprocessor: "depth_tensorrt"') + logger.info(" preprocessor_params:") logger.info(f' engine_path: "{engine_path}"') diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py index 8734987ef..d3faef953 100644 --- a/src/streamdiffusion/tools/compile_raft_tensorrt.py +++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py @@ -1,21 +1,24 @@ -import torch import logging from pathlib import Path -from typing import Optional + import fire +import torch + -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) try: import tensorrt as trt + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False logger.error("TensorRT not available. Please install it first.") try: - from torchvision.models.optical_flow import raft_small, Raft_Small_Weights + from torchvision.models.optical_flow import Raft_Small_Weights, raft_small + TORCHVISION_AVAILABLE = True except ImportError: TORCHVISION_AVAILABLE = False @@ -28,11 +31,11 @@ def export_raft_to_onnx( min_width: int = 512, max_height: int = 512, max_width: int = 512, - device: str = "cuda" + device: str = "cuda", ) -> bool: """ Export RAFT model to ONNX format - + Args: onnx_path: Path to save the ONNX model min_height: Minimum input height for the model @@ -40,41 +43,41 @@ def export_raft_to_onnx( max_height: Maximum input height for the model max_width: Maximum input width for the model device: Device to use for export - + Returns: True if successful, False otherwise """ if not TORCHVISION_AVAILABLE: logger.error("torchvision is required but not installed") return False - + logger.info(f"Exporting RAFT model to ONNX: {onnx_path}") logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}") - + try: # Load RAFT model logger.info("Loading RAFT Small model...") raft_model = raft_small(weights=Raft_Small_Weights.DEFAULT, progress=True) raft_model = raft_model.to(device=device) raft_model.eval() - + # Create dummy inputs using max resolution for export dummy_frame1 = torch.randn(1, 3, max_height, max_width).to(device) dummy_frame2 = torch.randn(1, 3, max_height, max_width).to(device) - + # Apply RAFT preprocessing if available weights = Raft_Small_Weights.DEFAULT - if hasattr(weights, 'transforms') and weights.transforms is not None: + if hasattr(weights, "transforms") and weights.transforms is not None: transforms = weights.transforms() dummy_frame1, dummy_frame2 = transforms(dummy_frame1, dummy_frame2) - + # Make batch, height, and width dimensions dynamic dynamic_axes = { "frame1": {0: "batch_size", 2: "height", 3: "width"}, "frame2": {0: "batch_size", 2: "height", 3: "width"}, "flow": {0: "batch_size", 2: "height", 3: "width"}, } - + logger.info("Exporting to ONNX...") with torch.no_grad(): torch.onnx.export( @@ -82,22 +85,23 @@ def export_raft_to_onnx( (dummy_frame1, dummy_frame2), str(onnx_path), verbose=False, - input_names=['frame1', 'frame2'], - output_names=['flow'], + input_names=["frame1", "frame2"], + output_names=["flow"], opset_version=17, export_params=True, dynamic_axes=dynamic_axes, ) - + del raft_model torch.cuda.empty_cache() - + logger.info(f"Successfully exported ONNX model to {onnx_path}") return True - + except Exception as e: logger.error(f"Failed to export ONNX model: {e}") import traceback + traceback.print_exc() return False @@ -110,11 +114,11 @@ def build_tensorrt_engine( max_height: int = 512, max_width: int = 512, fp16: bool = True, - workspace_size_gb: int = 4 + workspace_size_gb: int = 4, ) -> bool: """ Build TensorRT engine from ONNX model - + Args: onnx_path: Path to the ONNX model engine_path: Path to save the TensorRT engine @@ -124,74 +128,74 @@ def build_tensorrt_engine( max_width: Maximum input width for optimization fp16: Enable FP16 precision mode workspace_size_gb: Maximum workspace size in GB - + Returns: True if successful, False otherwise """ if not TENSORRT_AVAILABLE: logger.error("TensorRT is required but not installed") return False - + if not onnx_path.exists(): logger.error(f"ONNX model not found: {onnx_path}") return False - + logger.info(f"Building TensorRT engine from ONNX model: {onnx_path}") logger.info(f"Output path: {engine_path}") logger.info(f"Resolution range: {min_height}x{min_width} - {max_height}x{max_width}") logger.info(f"FP16 mode: {fp16}") logger.info("This may take several minutes...") - + try: builder = trt.Builder(trt.Logger(trt.Logger.INFO)) network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING)) - + logger.info("Parsing ONNX model...") - with open(onnx_path, 'rb') as model: + with open(onnx_path, "rb") as model: if not parser.parse(model.read()): logger.error("Failed to parse ONNX model") for error in range(parser.num_errors): logger.error(f"Parser error: {parser.get_error(error)}") return False - + logger.info("Configuring TensorRT builder...") config = builder.create_builder_config() - + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_size_gb * (1 << 30)) - + if fp16: config.set_flag(trt.BuilderFlag.FP16) logger.info("FP16 mode enabled") - + # Calculate optimal resolution (middle point) opt_height = (min_height + max_height) // 2 opt_width = (min_width + max_width) // 2 - + profile = builder.create_optimization_profile() min_shape = (1, 3, min_height, min_width) opt_shape = (1, 3, opt_height, opt_width) max_shape = (1, 3, max_height, max_width) - + profile.set_shape("frame1", min_shape, opt_shape, max_shape) profile.set_shape("frame2", min_shape, opt_shape, max_shape) config.add_optimization_profile(profile) - + logger.info("Building TensorRT engine... (this will take a while)") engine = builder.build_serialized_network(network, config) - + if engine is None: logger.error("Failed to build TensorRT engine") return False - + logger.info(f"Saving engine to {engine_path}") engine_path.parent.mkdir(parents=True, exist_ok=True) - with open(engine_path, 'wb') as f: + with open(engine_path, "wb") as f: f.write(engine) - + logger.info(f"Successfully built and saved TensorRT engine: {engine_path}") - logger.info(f"Engine size: {engine_path.stat().st_size / (1024*1024):.2f} MB") - + logger.info(f"Engine size: {engine_path.stat().st_size / (1024 * 1024):.2f} MB") + # Delete ONNX file after successful engine creation try: if onnx_path.exists(): @@ -199,12 +203,13 @@ def build_tensorrt_engine( logger.info(f"Deleted ONNX file: {onnx_path}") except Exception as e: logger.warning(f"Failed to delete ONNX file: {e}") - + return True - + except Exception as e: logger.error(f"Failed to build TensorRT engine: {e}") import traceback + traceback.print_exc() return False @@ -216,11 +221,11 @@ def compile_raft( device: str = "cuda", fp16: bool = True, workspace_size_gb: int = 4, - force_rebuild: bool = False + force_rebuild: bool = False, ): """ Main function to compile RAFT model to TensorRT engine - + Args: min_resolution: Minimum input resolution as "HxW" (e.g., "512x512") (default: "512x512") max_resolution: Maximum input resolution as "HxW" (e.g., "1024x1024") (default: "512x512") @@ -234,46 +239,46 @@ def compile_raft( logger.error("TensorRT is not available. Please install it first using:") logger.error(" python -m streamdiffusion.tools.install-tensorrt") return - + if not TORCHVISION_AVAILABLE: logger.error("torchvision is not available. Please install it first using:") logger.error(" pip install torchvision") return - + # Parse resolution strings try: - min_height, min_width = map(int, min_resolution.split('x')) + min_height, min_width = map(int, min_resolution.split("x")) except: logger.error(f"Invalid min_resolution format: {min_resolution}. Expected format: HxW (e.g., 512x512)") return - + try: - max_height, max_width = map(int, max_resolution.split('x')) + max_height, max_width = map(int, max_resolution.split("x")) except: logger.error(f"Invalid max_resolution format: {max_resolution}. Expected format: HxW (e.g., 1024x1024)") return - + output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) - + # Add resolution suffix to filenames onnx_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.onnx" engine_path = output_path / f"raft_small_min_{min_resolution}_max_{max_resolution}.engine" - - logger.info("="*80) + + logger.info("=" * 80) logger.info("RAFT TensorRT Compilation") - logger.info("="*80) + logger.info("=" * 80) logger.info(f"Output directory: {output_path.absolute()}") logger.info(f"Resolution range: {min_resolution} - {max_resolution}") logger.info(f"ONNX path: {onnx_path}") logger.info(f"Engine path: {engine_path}") - logger.info("="*80) - + logger.info("=" * 80) + if engine_path.exists() and not force_rebuild: logger.info(f"TensorRT engine already exists: {engine_path}") logger.info("Use --force_rebuild to rebuild it") return - + if not onnx_path.exists() or force_rebuild: logger.info("\n[Step 1/2] Exporting RAFT to ONNX...") if not export_raft_to_onnx(onnx_path, min_height, min_width, max_height, max_width, device): @@ -281,21 +286,22 @@ def compile_raft( return else: logger.info(f"\n[Step 1/2] ONNX model already exists: {onnx_path}") - + logger.info("\n[Step 2/2] Building TensorRT engine...") - if not build_tensorrt_engine(onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb): + if not build_tensorrt_engine( + onnx_path, engine_path, min_height, min_width, max_height, max_width, fp16, workspace_size_gb + ): logger.error("Failed to build TensorRT engine") return - - logger.info("\n" + "="*80) + + logger.info("\n" + "=" * 80) logger.info("✓ Compilation completed successfully!") - logger.info("="*80) + logger.info("=" * 80) logger.info(f"Engine path: {engine_path.absolute()}") logger.info("\nYou can now use this engine in TemporalNetTensorRTPreprocessor:") logger.info(f' engine_path="{engine_path.absolute()}"') - logger.info("="*80) + logger.info("=" * 80) if __name__ == "__main__": fire.Fire(compile_raft) - diff --git a/src/streamdiffusion/utils/__init__.py b/src/streamdiffusion/utils/__init__.py index 00ff7cf7d..b40413d24 100644 --- a/src/streamdiffusion/utils/__init__.py +++ b/src/streamdiffusion/utils/__init__.py @@ -1,5 +1,6 @@ from .reporting import report_error + __all__ = [ "report_error", -] \ No newline at end of file +] diff --git a/src/streamdiffusion/utils/reporting.py b/src/streamdiffusion/utils/reporting.py index 44838d9c8..25e650c6f 100644 --- a/src/streamdiffusion/utils/reporting.py +++ b/src/streamdiffusion/utils/reporting.py @@ -25,5 +25,3 @@ def report_error( stacklevel=stacklevel, extra={"report_error": True}, ) - - diff --git a/tests/quality/README.md b/tests/quality/README.md new file mode 100644 index 000000000..2dd95422c --- /dev/null +++ b/tests/quality/README.md @@ -0,0 +1,77 @@ +# Quality Regression Harness + +Compares FP8-TRT output against FP16-TRT golden reference images using SSIM + LPIPS. + +## First-time setup + +```powershell +# Install dev deps (lpips, scikit-image) +pip install -r requirements-dev.txt + +# Build FP16-TRT engines for both fixture models (if not already cached) +python -m streamdiffusion.acceleration.tensorrt.build --model stabilityai/sd-turbo +python -m streamdiffusion.acceleration.tensorrt.build --model stabilityai/sdxl-turbo + +# Generate goldens + update manifest +python tests/quality/regenerate_golden.py --update-manifest + +# Build FP8-TRT engines +python -m streamdiffusion.acceleration.tensorrt.build --fp8 --model stabilityai/sd-turbo +python -m streamdiffusion.acceleration.tensorrt.build --fp8 --model stabilityai/sdxl-turbo + +# Seed thresholds from FP8 baseline (run once, check in thresholds.yaml) +python tests/quality/run_compare.py --baseline +``` + +## Running the harness + +```powershell +# Full comparison (both fixtures) +python tests/quality/run_compare.py + +# Single fixture +python tests/quality/run_compare.py --fixture sdxl_turbo_img2img_plain +``` + +## Exit codes + +| Code | Meaning | +|------|---------| +| 0 | All fixtures pass SSIM/LPIPS thresholds | +| 1 | One or more fixtures fail thresholds | +| 2 | Manifest version mismatch — results would be meaningless | +| 3 | Golden PNG missing — run `regenerate_golden.py --update-manifest` first | + +## Refreshing goldens after a dep bump + +```powershell +pip install -r requirements.txt -r requirements-dev.txt # upgrade deps +python tests/quality/regenerate_golden.py --update-manifest # regenerate + re-hash +python tests/quality/run_compare.py --baseline # re-seed thresholds +``` + +## File layout + +``` +tests/quality/ +├── run_compare.py # orchestrator — runs FP8, computes metrics, checks thresholds +├── regenerate_golden.py # generates FP16-TRT goldens and updates manifest hashes +├── manifest.json # pinned dep versions + golden sha256 (abort on mismatch) +├── thresholds.yaml # SSIM/LPIPS floors per fixture (seed with --baseline) +├── fixtures/ # fixture parameter files +│ ├── sd_turbo_img2img_plain.json +│ └── sdxl_turbo_img2img_plain.json +├── goldens/ # FP16-TRT reference PNGs (generated, then checked in) +│ ├── sd_turbo_img2img_plain.png +│ └── sdxl_turbo_img2img_plain.png +└── outputs/ # generated on run — gitignored + ├── *_fp8.png # FP8 output for each fixture + ├── *_comparison.png # side-by-side comparison + └── report.json # SSIM/LPIPS results +``` + +## Phase 3 gate + +This harness is the prerequisite for Phase 3 (feature-aware calibration + Q/DQ exclusion +narrowing). Do not narrow exclusions until both fixtures pass at stable thresholds across +at least two independent runs. diff --git a/tests/quality/fixtures/sd_turbo_img2img_plain.json b/tests/quality/fixtures/sd_turbo_img2img_plain.json new file mode 100644 index 000000000..a83b29d69 --- /dev/null +++ b/tests/quality/fixtures/sd_turbo_img2img_plain.json @@ -0,0 +1,15 @@ +{ + "model_id": "stabilityai/sd-turbo", + "prompt": "a portrait of a young woman with long brown hair, soft studio lighting, photorealistic", + "negative_prompt": "", + "seed": 2, + "t_index_list": [32, 45], + "width": 512, + "height": 512, + "cfg_type": "none", + "guidance_scale": 1.0, + "num_inference_steps": 50, + "delta": 1.0, + "warmup": 1, + "use_denoising_batch": true +} diff --git a/tests/quality/fixtures/sdxl_turbo_img2img_plain.json b/tests/quality/fixtures/sdxl_turbo_img2img_plain.json new file mode 100644 index 000000000..e4dc05c73 --- /dev/null +++ b/tests/quality/fixtures/sdxl_turbo_img2img_plain.json @@ -0,0 +1,15 @@ +{ + "model_id": "stabilityai/sdxl-turbo", + "prompt": "a portrait of a young woman with long brown hair, soft studio lighting, photorealistic", + "negative_prompt": "", + "seed": 2, + "t_index_list": [32, 45], + "width": 512, + "height": 512, + "cfg_type": "none", + "guidance_scale": 1.0, + "num_inference_steps": 50, + "delta": 1.0, + "warmup": 1, + "use_denoising_batch": true +} diff --git a/tests/quality/goldens/.gitkeep b/tests/quality/goldens/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/tests/quality/manifest.json b/tests/quality/manifest.json new file mode 100644 index 000000000..47ca14292 --- /dev/null +++ b/tests/quality/manifest.json @@ -0,0 +1,29 @@ +{ + "_note": "Version pins for the quality regression harness. Run regenerate_golden.py --update-manifest after any dep bump to refresh hashes. run_compare.py aborts if installed versions diverge.", + "versions": { + "torch": "2.8.0+cu128", + "tensorrt": "10.16.1.11", + "nvidia_modelopt": "0.43.0", + "diffusers": "0.38.0" + }, + "fixtures": { + "sd_turbo_img2img_plain": { + "model_id": "stabilityai/sd-turbo", + "seed": 2, + "t_index_list": [ + 32, + 45 + ], + "golden_sha256": "9056eadcbfa8637c9ad3aaaa09766ccf69926fc87ac7730219e82908e33c8d97" + }, + "sdxl_turbo_img2img_plain": { + "model_id": "stabilityai/sdxl-turbo", + "seed": 2, + "t_index_list": [ + 32, + 45 + ], + "golden_sha256": "8e6b61fde877ed4752e297c33364ff7bf616376fb3be68b55f4df311e58bdbc3" + } + } +} \ No newline at end of file diff --git a/tests/quality/regenerate_golden.py b/tests/quality/regenerate_golden.py new file mode 100644 index 000000000..7c86fd449 --- /dev/null +++ b/tests/quality/regenerate_golden.py @@ -0,0 +1,156 @@ +"""Generate FP16-TRT golden reference images for the quality regression harness. + +Usage: + python tests/quality/regenerate_golden.py + python tests/quality/regenerate_golden.py --update-manifest + python tests/quality/regenerate_golden.py --fixture sdxl_turbo_img2img_plain + +Goldens are FP16-TRT engine outputs (not bare PyTorch FP16) so they catch +TRT-specific regressions that matter in production. + +--update-manifest Also updates tests/quality/manifest.json with sha256 hashes + of the generated goldens and current installed dep versions. +""" + +import argparse +import hashlib +import json +import logging +import os +import sys + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from streamdiffusion import StreamDiffusionWrapper + + +logger = logging.getLogger("quality.regenerate") + +TESTS_QUALITY_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.join(TESTS_QUALITY_DIR, "..", "..") +INPUT_IMAGE = os.path.join(REPO_ROOT, "images", "inputs", "input.png") +GOLDENS_DIR = os.path.join(TESTS_QUALITY_DIR, "goldens") +FIXTURES_DIR = os.path.join(TESTS_QUALITY_DIR, "fixtures") +MANIFEST_PATH = os.path.join(TESTS_QUALITY_DIR, "manifest.json") +THRESHOLDS_PATH = os.path.join(TESTS_QUALITY_DIR, "thresholds.yaml") + + +def _sha256(path: str) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +def _current_versions() -> dict: + import importlib.metadata + + versions = {} + for pkg, key in [ + ("torch", "torch"), + ("tensorrt", "tensorrt"), + ("nvidia-modelopt", "nvidia_modelopt"), + ("diffusers", "diffusers"), + ]: + try: + versions[key] = importlib.metadata.version(pkg) + except Exception: + # TRT on Windows is often installed via NVIDIA's custom mechanism + # with no .dist-info — fall back to the package's __version__. + if pkg == "tensorrt": + try: + import tensorrt as _trt + + versions[key] = _trt.__version__ + except Exception: + versions[key] = "unknown" + else: + versions[key] = "unknown" + return versions + + +def run_fixture(fixture_name: str, fixture: dict) -> str: + """Run FP16-TRT inference for one fixture and return the saved golden path.""" + golden_path = os.path.join(GOLDENS_DIR, f"{fixture_name}.png") + os.makedirs(GOLDENS_DIR, exist_ok=True) + + logger.info(f"[{fixture_name}] Running FP16-TRT inference → {golden_path}") + + stream = StreamDiffusionWrapper( + model_id_or_path=fixture["model_id"], + t_index_list=fixture["t_index_list"], + frame_buffer_size=1, + width=fixture["width"], + height=fixture["height"], + warmup=fixture.get("warmup", 1), + acceleration="tensorrt", + mode="img2img", + use_denoising_batch=fixture.get("use_denoising_batch", True), + cfg_type=fixture.get("cfg_type", "none"), + seed=fixture["seed"], + ) + stream.prepare( + prompt=fixture["prompt"], + negative_prompt=fixture.get("negative_prompt", ""), + num_inference_steps=fixture.get("num_inference_steps", 50), + guidance_scale=fixture.get("guidance_scale", 1.0), + delta=fixture.get("delta", 1.0), + ) + + image_tensor = stream.preprocess_image(INPUT_IMAGE) + for _ in range(stream.batch_size - 1): + stream(image=image_tensor) + output_image = stream(image=image_tensor) + output_image.save(golden_path) + + sha = _sha256(golden_path) + logger.info(f"[{fixture_name}] Golden saved: {golden_path} sha256={sha[:16]}...") + return sha + + +def main(): + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument( + "--update-manifest", action="store_true", help="Update manifest.json with current versions + golden hashes" + ) + parser.add_argument("--fixture", default=None, help="Run only this fixture (default: all)") + args = parser.parse_args() + + fixture_files = sorted(f for f in os.listdir(FIXTURES_DIR) if f.endswith(".json")) + if args.fixture: + fixture_files = [f"{args.fixture}.json"] + + results = {} + for fname in fixture_files: + name = fname[:-5] + with open(os.path.join(FIXTURES_DIR, fname)) as fp: + fixture = json.load(fp) + sha = run_fixture(name, fixture) + results[name] = sha + + if args.update_manifest: + with open(MANIFEST_PATH) as fp: + manifest = json.load(fp) + manifest["versions"] = _current_versions() + for name, sha in results.items(): + if name not in manifest["fixtures"]: + manifest["fixtures"][name] = {} + manifest["fixtures"][name]["golden_sha256"] = sha + manifest["fixtures"][name].pop("_note", None) + with open(MANIFEST_PATH, "w") as fp: + json.dump(manifest, fp, indent=4) + logger.info(f"Manifest updated: {MANIFEST_PATH}") + + # Seed thresholds from FP8 baseline if engines already exist, + # otherwise leave placeholders and print instructions. + print( + "\nThreshold seeding: run_compare.py --baseline to measure FP8 vs these goldens " + "and seed thresholds.yaml with baseline - 0.02 (SSIM) / + 0.05 (LPIPS)." + ) + + +if __name__ == "__main__": + main() diff --git a/tests/quality/run_compare.py b/tests/quality/run_compare.py new file mode 100644 index 000000000..cdaba96c8 --- /dev/null +++ b/tests/quality/run_compare.py @@ -0,0 +1,252 @@ +"""Quality regression harness: compare FP8-TRT output against FP16-TRT goldens. + +Exit codes: + 0 All fixtures pass SSIM/LPIPS thresholds + 1 One or more fixtures fail thresholds + 2 Manifest version mismatch (abort — results would be meaningless) + 3 Golden PNG missing — run regenerate_golden.py --update-manifest first + +Usage: + python tests/quality/run_compare.py + python tests/quality/run_compare.py --fixture sdxl_turbo_img2img_plain + python tests/quality/run_compare.py --baseline # seed thresholds.yaml + python tests/quality/run_compare.py --skip-manifest-check # bypass version pins +""" + +import argparse +import hashlib +import json +import logging +import os +import sys + +import yaml + + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +logger = logging.getLogger("quality.run_compare") + +TESTS_QUALITY_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.join(TESTS_QUALITY_DIR, "..", "..") +INPUT_IMAGE = os.path.join(REPO_ROOT, "images", "inputs", "input.png") +GOLDENS_DIR = os.path.join(TESTS_QUALITY_DIR, "goldens") +FIXTURES_DIR = os.path.join(TESTS_QUALITY_DIR, "fixtures") +MANIFEST_PATH = os.path.join(TESTS_QUALITY_DIR, "manifest.json") +THRESHOLDS_PATH = os.path.join(TESTS_QUALITY_DIR, "thresholds.yaml") +OUTPUTS_DIR = os.path.join(TESTS_QUALITY_DIR, "outputs") + + +def _sha256(path: str) -> str: + h = hashlib.sha256() + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + h.update(chunk) + return h.hexdigest() + + +def _installed_version(pkg: str) -> str: + import importlib.metadata + + try: + return importlib.metadata.version(pkg) + except Exception: + if pkg == "tensorrt": + try: + import tensorrt as _trt + + return _trt.__version__ + except Exception: + pass + return "unknown" + + +def _check_manifest(manifest: dict) -> list[str]: + """Return list of version mismatch messages (empty = all match).""" + pkg_map = { + "torch": "torch", + "tensorrt": "tensorrt", + "nvidia_modelopt": "nvidia-modelopt", + "diffusers": "diffusers", + } + mismatches = [] + for key, pkg in pkg_map.items(): + pinned = manifest["versions"].get(key, "unknown") + installed = _installed_version(pkg) + if pinned != "unknown" and installed != "unknown" and pinned != installed: + mismatches.append(f" {key}: manifest={pinned} installed={installed}") + return mismatches + + +def _run_fp8_inference(fixture_name: str, fixture: dict, output_path: str) -> None: + from streamdiffusion import StreamDiffusionWrapper + + stream = StreamDiffusionWrapper( + model_id_or_path=fixture["model_id"], + t_index_list=fixture["t_index_list"], + frame_buffer_size=1, + width=fixture["width"], + height=fixture["height"], + warmup=fixture.get("warmup", 1), + acceleration="tensorrt", + mode="img2img", + use_denoising_batch=fixture.get("use_denoising_batch", True), + cfg_type=fixture.get("cfg_type", "none"), + seed=fixture["seed"], + fp8=True, + ) + stream.prepare( + prompt=fixture["prompt"], + negative_prompt=fixture.get("negative_prompt", ""), + num_inference_steps=fixture.get("num_inference_steps", 50), + guidance_scale=fixture.get("guidance_scale", 1.0), + delta=fixture.get("delta", 1.0), + ) + image_tensor = stream.preprocess_image(INPUT_IMAGE) + for _ in range(stream.batch_size - 1): + stream(image=image_tensor) + output_image = stream(image=image_tensor) + output_image.save(output_path) + + +def _compute_metrics(golden_path: str, output_path: str) -> dict: + import numpy as np + from PIL import Image + from skimage.metrics import structural_similarity as ssim_fn + + golden = np.array(Image.open(golden_path).convert("RGB"), dtype=np.float32) / 255.0 + output = np.array(Image.open(output_path).convert("RGB"), dtype=np.float32) / 255.0 + + ssim_val = float(ssim_fn(golden, output, data_range=1.0, channel_axis=2)) + + try: + import lpips + import torch + + loss_fn = lpips.LPIPS(net="alex", verbose=False) + g_t = torch.from_numpy(golden).permute(2, 0, 1).unsqueeze(0) * 2 - 1 + o_t = torch.from_numpy(output).permute(2, 0, 1).unsqueeze(0) * 2 - 1 + with torch.no_grad(): + lpips_val = float(loss_fn(g_t, o_t).item()) + except Exception as e: + logger.warning(f"LPIPS computation failed: {e}. Using placeholder 0.0.") + lpips_val = 0.0 + + return {"ssim": ssim_val, "lpips": lpips_val} + + +def _make_comparison_image(golden_path: str, output_path: str, comparison_path: str) -> None: + from PIL import Image, ImageDraw + + golden = Image.open(golden_path).convert("RGB") + output = Image.open(output_path).convert("RGB") + w, h = golden.width, golden.height + canvas = Image.new("RGB", (w * 2 + 10, h + 24), (30, 30, 30)) + canvas.paste(golden, (0, 24)) + canvas.paste(output, (w + 10, 24)) + draw = ImageDraw.Draw(canvas) + draw.text((4, 4), "FP16-TRT golden", fill=(200, 200, 200)) + draw.text((w + 14, 4), "FP8-TRT output", fill=(200, 200, 200)) + canvas.save(comparison_path) + + +def main(): + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s: %(message)s") + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("--fixture", default=None, help="Run only this fixture (default: all)") + parser.add_argument("--baseline", action="store_true", help="Seed thresholds.yaml from this run's SSIM/LPIPS") + parser.add_argument("--skip-manifest-check", action="store_true", help="Bypass version pin check") + args = parser.parse_args() + + with open(MANIFEST_PATH) as fp: + manifest = json.load(fp) + with open(THRESHOLDS_PATH) as fp: + thresholds = yaml.safe_load(fp) + + if not args.skip_manifest_check: + mismatches = _check_manifest(manifest) + if mismatches: + print("ERROR: Manifest version mismatch — results would be meaningless.\n") + print("\n".join(mismatches)) + print("\nRun regenerate_golden.py --update-manifest to refresh the manifest.") + sys.exit(2) + + fixture_files = sorted(f for f in os.listdir(FIXTURES_DIR) if f.endswith(".json")) + if args.fixture: + fixture_files = [f"{args.fixture}.json"] + + os.makedirs(OUTPUTS_DIR, exist_ok=True) + results = {} + all_pass = True + + for fname in fixture_files: + name = fname[:-5] + with open(os.path.join(FIXTURES_DIR, fname)) as fp: + fixture = json.load(fp) + + golden_path = os.path.join(GOLDENS_DIR, f"{name}.png") + if not os.path.exists(golden_path): + print(f"ERROR: Golden not found for {name}: {golden_path}") + print("Run: python tests/quality/regenerate_golden.py --update-manifest") + sys.exit(3) + + golden_sha = manifest["fixtures"].get(name, {}).get("golden_sha256") + if golden_sha is not None: + actual_sha = _sha256(golden_path) + if actual_sha != golden_sha: + print(f"WARNING: Golden sha256 mismatch for {name}. File may be corrupted or outdated.") + + output_path = os.path.join(OUTPUTS_DIR, f"{name}_fp8.png") + logger.info(f"[{name}] Running FP8-TRT inference...") + _run_fp8_inference(name, fixture, output_path) + + logger.info(f"[{name}] Computing metrics...") + metrics = _compute_metrics(golden_path, output_path) + results[name] = metrics + + comparison_path = os.path.join(OUTPUTS_DIR, f"{name}_comparison.png") + _make_comparison_image(golden_path, output_path, comparison_path) + + thresh = thresholds.get("fixtures", {}).get(name, {}) + ssim_min = thresh.get("ssim_min", 0.0) + lpips_max = thresh.get("lpips_max", 1.0) + passed = metrics["ssim"] >= ssim_min and metrics["lpips"] <= lpips_max + status = "PASS" if passed else "FAIL" + if not passed: + all_pass = False + + print( + f"[{name}] {status} SSIM={metrics['ssim']:.4f} (min={ssim_min}) " + f"LPIPS={metrics['lpips']:.4f} (max={lpips_max})" + ) + logger.info(f"[{name}] Comparison: {comparison_path}") + + report_path = os.path.join(OUTPUTS_DIR, "report.json") + with open(report_path, "w") as fp: + json.dump(results, fp, indent=2) + logger.info(f"Report: {report_path}") + + if args.baseline: + for name, metrics in results.items(): + if name not in thresholds.get("fixtures", {}): + thresholds.setdefault("fixtures", {})[name] = {} + thresholds["fixtures"][name]["ssim_min"] = round(metrics["ssim"] - 0.02, 4) + thresholds["fixtures"][name]["lpips_max"] = round(metrics["lpips"] + 0.05, 4) + thresholds["fixtures"][name].pop("_note", None) + with open(THRESHOLDS_PATH, "w") as fp: + yaml.dump(thresholds, fp, default_flow_style=False, sort_keys=False) + print(f"\nThresholds seeded → {THRESHOLDS_PATH}") + + n = len(results) + n_pass = sum( + 1 + for name, m in results.items() + if m["ssim"] >= thresholds.get("fixtures", {}).get(name, {}).get("ssim_min", 0.0) + and m["lpips"] <= thresholds.get("fixtures", {}).get(name, {}).get("lpips_max", 1.0) + ) + print(f"\n{n_pass}/{n} fixtures pass thresholds.") + sys.exit(0 if all_pass else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/quality/test_fp8_calib_tile.py b/tests/quality/test_fp8_calib_tile.py new file mode 100644 index 000000000..abe577cb3 --- /dev/null +++ b/tests/quality/test_fp8_calib_tile.py @@ -0,0 +1,81 @@ +"""Regression: per-input-aware calibration tiling for FP8 quantize. + +Reproduces the kvo_cache_in dim-0=2 vs synthesized dim-0=1 split mismatch +hit by SDXL-Turbo + use_cached_attn + cfg_type=self configs. +""" + +import math + +import numpy as np +import onnx +from onnx import TensorProto, helper + + +def _make_min_onnx(path): + """Minimal 2-input ONNX: symbolic-batch 'sample' + static-dim0=2 'kvo_cache_in_0'.""" + sample = helper.make_tensor_value_info("sample", TensorProto.FLOAT, ["2B", 4, 64, 64]) + kvo = helper.make_tensor_value_info("kvo_cache_in_0", TensorProto.FLOAT, [2, 4, "2B", 64, 64]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, ["2B", 4, 64, 64]) + ident = helper.make_node("Identity", inputs=["sample"], outputs=["out"]) + g = helper.make_graph([ident], "min", [sample, kvo], [out]) + onnx.save(helper.make_model(g, opset_imports=[helper.make_opsetid("", 17)]), path) + + +def test_per_input_tile_preserves_static_dim0(tmp_path): + """Per-input tile: sample stays at n_itr rows, kvo stays at 2×n_itr rows.""" + onnx_path = str(tmp_path / "min.onnx") + _make_min_onnx(onnx_path) + + calib = { + "sample": np.zeros((5, 4, 64, 64), dtype=np.float32), + "kvo_cache_in_0": np.zeros((10, 4, 5, 64, 64), dtype=np.float32), + } + + from streamdiffusion.acceleration.tensorrt.fp8_quantize import _read_onnx_input_specs + + specs = _read_onnx_input_specs(onnx_path) + + # Reproduce the fixed logic + resolved_dim0 = {name: max(1, (specs[name][1][0] or 1)) for name in calib} + n_itr = max(arr.shape[0] // resolved_dim0[name] for name, arr in calib.items()) + n_itr = max(1, n_itr) + out = {} + for k, arr in calib.items(): + target = n_itr * resolved_dim0[k] + if arr.shape[0] != target: + repeats = math.ceil(target / max(1, arr.shape[0])) + arr = np.tile(arr, (repeats,) + (1,) * (arr.ndim - 1))[:target] + out[k] = arr + + # n_itr=5, resolved_dim0(sample)=1 → 5 rows; resolved_dim0(kvo)=2 → 10 rows + assert out["sample"].shape == (5, 4, 64, 64) + assert out["kvo_cache_in_0"].shape == (10, 4, 5, 64, 64) + + # Verify modelopt split math: n_itr chunks each of shape (resolved_dim0, ...) + sample_chunks = np.array_split(out["sample"], n_itr, axis=0) + kvo_chunks = np.array_split(out["kvo_cache_in_0"], n_itr, axis=0) + assert sample_chunks[0].shape[0] == 1 + assert kvo_chunks[0].shape[0] == 2 # static dim 0 must be preserved + + +def test_naive_max_rows_tile_would_break(tmp_path): + """Confirms the OLD naïve tile produces the 'Got 1 Expected 2' symptom.""" + onnx_path = str(tmp_path / "min.onnx") + _make_min_onnx(onnx_path) + + calib = { + "sample": np.zeros((5, 4, 64, 64), dtype=np.float32), + "kvo_cache_in_0": np.zeros((10, 4, 5, 64, 64), dtype=np.float32), + } + + # Reproduce the buggy logic + _max_rows = max(a.shape[0] for a in calib.values()) + for k, a in list(calib.items()): + if a.shape[0] < _max_rows: + calib[k] = np.tile(a, (math.ceil(_max_rows / a.shape[0]),) + (1,) * (a.ndim - 1))[:_max_rows] + + # modelopt: n_itr = sample.shape[0] / symbolic_dim0(1) = 10 + # splits kvo into 10 chunks → each has shape[0]=1 → ORT rejects (expected 2) + n_itr_bad = calib["sample"].shape[0] # 10 (doubled by naïve tile) + kvo_chunk = np.array_split(calib["kvo_cache_in_0"], n_itr_bad, axis=0)[0] + assert kvo_chunk.shape[0] == 1 # this is the "Got 1 Expected 2" symptom diff --git a/tests/quality/thresholds.yaml b/tests/quality/thresholds.yaml new file mode 100644 index 000000000..d4889aced --- /dev/null +++ b/tests/quality/thresholds.yaml @@ -0,0 +1,7 @@ +fixtures: + sd_turbo_img2img_plain: + ssim_min: 0.9649 + lpips_max: 0.0586 + sdxl_turbo_img2img_plain: + ssim_min: 0.9604 + lpips_max: 0.0609 diff --git a/utils/viewer.py b/utils/viewer.py index dd6f6cad9..2bd90984b 100644 --- a/utils/viewer.py +++ b/utils/viewer.py @@ -3,11 +3,13 @@ import threading import time import tkinter as tk -from multiprocessing import Queue -from typing import List +from multiprocessing import Queue + from PIL import Image, ImageTk + from streamdiffusion.image_utils import postprocess_image + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) @@ -28,9 +30,8 @@ def update_image(image_data: Image.Image, label: tk.Label) -> None: label.configure(image=tk_image, width=width, height=height) label.image = tk_image # keep a reference -def _receive_images( - queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label -) -> None: + +def _receive_images(queue: Queue, fps_queue: Queue, label: tk.Label, fps_label: tk.Label) -> None: """ Continuously receive images from a queue and update the labels. @@ -85,9 +86,7 @@ def on_closing(): root.quit() # stop event loop return - thread = threading.Thread( - target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True - ) + thread = threading.Thread(target=_receive_images, args=(queue, fps_queue, label, fps_label), daemon=True) thread.start() try: @@ -95,4 +94,3 @@ def on_closing(): root.mainloop() except KeyboardInterrupt: return -