From a8296fbac9fbed5dc8a7d53d9f0ac932659b5043 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:13:16 +0000 Subject: [PATCH 01/27] Add Cosmos3 action generation support --- examples/cosmos3/inference_cosmos3.py | 106 +++- .../transformers/transformer_cosmos3.py | 132 +++- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 584 ++++++++++++++++-- 3 files changed, 755 insertions(+), 67 deletions(-) diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index fd0d0537cb0e..675ead892c2f 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -23,13 +23,15 @@ """ import argparse +import json import pathlib +import urllib.request import torch from huggingface_hub import snapshot_download from diffusers import Cosmos3OmniPipeline -from diffusers.utils import encode_video, export_to_video, load_image +from diffusers.utils import encode_video, export_to_video, load_image, load_video HF_REPOS = { @@ -38,6 +40,22 @@ } +def _load_action(path: str | None): + if path is None: + raise ValueError("--action-path is required for forward_dynamics mode.") + if path.startswith(("http://", "https://")): + with urllib.request.urlopen(path) as response: + action = json.loads(response.read().decode("utf-8")) + else: + action = json.loads(pathlib.Path(path).read_text()) + tensor = torch.as_tensor(action, dtype=torch.float32) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2: + raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.") + return tensor + + def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--prompt", required=True, help="Text prompt.") @@ -50,7 +68,7 @@ def main(): parser.add_argument( "--vision-path", default=None, - help="Optional URL or local path for an image-conditioning frame (enables image-to-video).", + 
help="Optional URL or local path for an image-conditioning frame, or an action conditioning video.", ) parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") parser.add_argument("--height", type=int, default=720) @@ -62,12 +80,26 @@ def main(): help="Number of frames to generate. Use 1 for text-to-image; defaults to 189 for video (≈ 7.9s @ 24 FPS).", ) parser.add_argument("--fps", type=float, default=24.0) + parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") + parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") + parser.add_argument("--flow-shift", type=float, default=None, help="Scheduler flow shift.") + parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", action="store_true", default=False, help="Generate sound alongside video (requires a sound-capable checkpoint).", ) + parser.add_argument( + "--action-mode", + choices=["forward_dynamics", "inverse_dynamics", "policy"], + default=None, + help="Enable Cosmos3 action generation with a loaded conditioning video.", + ) + parser.add_argument("--action-path", default=None, help="JSON action path for forward_dynamics mode.") + parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.") + parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.") + parser.add_argument("--raw-action-dim", type=int, default=None, help="Slice predicted action output to this size.") parser.add_argument( "--no-duration-template", dest="add_duration_template", @@ -110,21 +142,54 @@ def main(): output_dir = pathlib.Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) - - image = load_image(args.vision_path) if args.vision_path is not None else None - - result = pipeline( - prompt=args.prompt, - 
image=image, - num_frames=args.num_frames, - height=args.height, - width=args.width, - fps=args.fps, - enable_sound=args.enable_sound, - add_resolution_template=args.add_resolution_template, - add_duration_template=args.add_duration_template, - enable_safety_check=not args.no_safety_check, - ) + generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None + + if args.action_mode is not None: + if args.vision_path is None: + raise ValueError("--vision-path must point to a video for action modes.") + if args.action_chunk_size is None: + raise ValueError("--action-chunk-size is required for action modes.") + video = load_video(args.vision_path) + action = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None + result = pipeline( + prompt=args.prompt, + video=video, + num_frames=args.action_chunk_size + 1, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + flow_shift=args.flow_shift, + action_mode=args.action_mode, + action=action, + action_chunk_size=args.action_chunk_size, + domain_name=args.domain_name, + raw_action_dim=args.raw_action_dim, + guidance_scale=args.guidance_scale, + generator=generator, + use_system_prompt=False, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) + else: + image = load_image(args.vision_path) if args.vision_path is not None else None + result = pipeline( + prompt=args.prompt, + image=image, + num_frames=args.num_frames, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + flow_shift=args.flow_shift, + enable_sound=args.enable_sound, + guidance_scale=args.guidance_scale, + generator=generator, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) if 
class DomainAwareLinear(nn.Module):
    """Linear projection with one weight/bias pair per embodiment domain.

    The weights and biases of every domain are stored as rows of two ``nn.Embedding`` tables, so selecting
    the projection for a batch element is an embedding lookup followed by a batched matmul.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int) -> None:
        super().__init__()
        self.input_size = int(input_size)
        self.output_size = int(output_size)
        self.num_domains = int(num_domains)
        # Row ``d`` of ``fc`` is the flattened ``[input_size, output_size]`` weight matrix of domain ``d``.
        self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size)
        self.bias = nn.Embedding(self.num_domains, self.output_size)
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.bias.weight)

    def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the weight/bias selected by ``domain_id``.

        Args:
            x: Rank-2 ``[B, input_size]`` or rank-3 ``[B, N, input_size]`` input.
            domain_id: One domain index per leading-dimension entry of ``x``; a 0-d tensor is treated as a
                single-element batch.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by ``output_size``.

        Raises:
            ValueError: If ``x`` is not rank 2 or 3, if the number of domain ids differs from ``x.shape[0]``,
                or if any domain id lies outside ``[0, num_domains)``.
        """
        # Validate the rank first: it is the cheapest check, it avoids gathering per-domain weights for an
        # input that will be rejected anyway, and it keeps 0-d inputs from failing on ``x.shape[0]``.
        if x.ndim not in (2, 3):
            raise ValueError(f"Cosmos3 DomainAwareLinear expected rank-2 or rank-3 input, got {tuple(x.shape)}.")

        if domain_id.ndim == 0:
            domain_id = domain_id.unsqueeze(0)
        domain_id = domain_id.to(device=x.device, dtype=torch.long).reshape(-1)
        batch = domain_id.shape[0]
        if x.shape[0] != batch:
            raise ValueError(
                "Cosmos3 action domain_id batch size must match action tokens: "
                f"tokens={x.shape[0]}, domain_id={batch}."
            )
        if torch.any((domain_id < 0) | (domain_id >= self.num_domains)):
            raise ValueError(f"Cosmos3 action domain_id must be in [0, {self.num_domains}), got {domain_id.tolist()}.")

        weight = self.fc(domain_id).view(batch, self.input_size, self.output_size)
        bias = self.bias(domain_id).view(batch, self.output_size)
        if x.ndim == 2:
            # bmm needs a rank-3 left operand; add and then drop a singleton token axis.
            return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias
        return torch.bmm(x, weight) + bias.unsqueeze(1)
def _pack_action_latents(
    self,
    tokens_action: list[torch.Tensor],
    token_shapes_action: list[tuple[int, int, int]],
    domain_ids_action: list[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate per-sample action tokens and expand each sample's domain id to one id per token.

    Args:
        tokens_action: One tensor per sample; the first ``shape[0]`` rows of each are packed.
        token_shapes_action: One shape tuple per sample; only its first entry (the token count) is read.
        domain_ids_action: One single-element (or 0-d) domain-id tensor per sample.

    Returns:
        ``(packed_tokens, per_token_domain_ids)``, both concatenated along dimension 0 in sample order.

    Raises:
        ValueError: If the three lists differ in length, or a sample holds fewer rows than its declared
            token count (either would otherwise silently desynchronise tokens and domain ids).
    """
    if not (len(tokens_action) == len(token_shapes_action) == len(domain_ids_action)):
        raise ValueError(
            "Cosmos3 action packing expects one token tensor, shape and domain id per sample, got "
            f"{len(tokens_action)} token tensors, {len(token_shapes_action)} shapes and "
            f"{len(domain_ids_action)} domain ids."
        )

    packed: list[torch.Tensor] = []
    domain_ids: list[torch.Tensor] = []
    for action, shape, domain_id in zip(tokens_action, token_shapes_action, domain_ids_action):
        token_count = shape[0]
        if action.shape[0] < token_count:
            raise ValueError(
                f"Cosmos3 action tensor has {action.shape[0]} tokens but its shape entry declares {token_count}."
            )
        packed.append(action[:token_count])
        # expand() creates a broadcast view, so repeating the id per token allocates nothing extra here.
        domain_ids.append(domain_id.reshape(1).expand(token_count))
    return torch.cat(packed, dim=0), torch.cat(domain_ids, dim=0)


def _unpack_action_latents(
    self,
    packed_preds: torch.Tensor,
    token_shapes_action: list[tuple[int, int, int]],
    noisy_frame_indexes_action: list[torch.Tensor],
) -> list[torch.Tensor]:
    """Scatter packed predictions for noisy tokens back into zero-filled per-sample tensors.

    Args:
        packed_preds: Predictions for the noisy tokens of all samples, concatenated in sample order.
        token_shapes_action: One shape tuple per sample; only its first entry (the token count ``T``) is read.
        noisy_frame_indexes_action: Per sample, the token positions that ``packed_preds`` rows belong to.

    Returns:
        One ``[T, self.action_dim]`` tensor per sample; positions not listed as noisy stay zero.

    Raises:
        ValueError: If the two per-sample lists differ in length.
    """
    if len(token_shapes_action) != len(noisy_frame_indexes_action):
        raise ValueError(
            "Cosmos3 action unpacking expects one shape and one noisy-index tensor per sample, got "
            f"{len(token_shapes_action)} shapes and {len(noisy_frame_indexes_action)} index tensors."
        )

    unpacked: list[torch.Tensor] = []
    start_idx = 0
    for shape, noisy_idxs in zip(token_shapes_action, noisy_frame_indexes_action):
        num_tokens = shape[0]
        output = torch.zeros((num_tokens, self.action_dim), device=packed_preds.device, dtype=packed_preds.dtype)
        num_noisy = len(noisy_idxs)
        if num_noisy > 0:
            output[noisy_idxs] = packed_preds[start_idx : start_idx + num_noisy]
            start_idx += num_noisy
        unpacked.append(output)
    return unpacked
# ------------------------------------------------------------------------- def forward( @@ -488,7 +565,14 @@ def forward( sound_mse_loss_indexes: torch.Tensor | None = None, sound_timesteps: torch.Tensor | None = None, sound_noisy_frame_indexes: list[torch.Tensor] | None = None, - ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]: + action_tokens: list[torch.Tensor] | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_sequence_indexes: torch.Tensor | None = None, + action_mse_loss_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_domain_ids: list[torch.Tensor] | None = None, + ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]: """Run a full denoising-step forward pass. Args: @@ -511,10 +595,11 @@ def forward( sound_noisy_frame_indexes: Optional noisy frame indices per sound item. Returns: - ``(preds_vision, preds_sound)`` — list of per-modality latents (``preds_sound`` is ``None`` when the model - has no sound branch or sound inputs are omitted). + ``(preds_vision, preds_sound, preds_action)`` — lists of per-modality predictions. Optional modalities + return ``None`` when their inputs are omitted. """ has_sound = sound_tokens is not None and sound_sequence_indexes is not None + has_action = action_tokens is not None and action_sequence_indexes is not None # Embed text tokens into the joint hidden_states buffer at their sequence positions. packed_text_embedding = self.embed_tokens(input_ids) @@ -551,6 +636,27 @@ def forward( ) hidden_states[sound_sequence_indexes] = packed_tokens_sound + # Pack + project action latents (when present). Domain ids select the action head weights. 
+ if has_action: + packed_tokens_action, per_token_domain_ids = self._pack_action_latents( + action_tokens, action_token_shapes, action_domain_ids + ) + packed_tokens_action = packed_tokens_action.to(target_dtype) + per_token_domain_ids = per_token_domain_ids.to(device=packed_tokens_action.device) + packed_tokens_action = self.action_proj_in(packed_tokens_action, per_token_domain_ids) + packed_tokens_action = packed_tokens_action + self.action_modality_embed + if action_mse_loss_indexes.numel() > 0: + timesteps_action = action_timesteps * self.config.timestep_scale + packed_timestep_embeds_action = self.time_embedder(self.time_proj(timesteps_action)) + packed_timestep_embeds_action = packed_timestep_embeds_action.to(target_dtype) + packed_tokens_action = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed_tokens_action, + packed_timestep_embeds=packed_timestep_embeds_action, + noisy_frame_indexes=action_noisy_frame_indexes, + token_shapes=action_token_shapes, + ) + hidden_states[action_sequence_indexes] = packed_tokens_action + # Compute rotary embeddings once for the joint sequence, then slice into und/gen halves. 
_meta_tensor = torch.tensor([], dtype=hidden_states.dtype, device=hidden_states.device) cos, sin = self.rotary_emb( @@ -590,4 +696,18 @@ def forward( preds_sound_packed = self.audio_proj_out(last_hidden_state[sound_mse_loss_indexes]) preds_sound = self._unpack_sound_latents(preds_sound_packed, sound_token_shapes, sound_noisy_frame_indexes) - return preds_vision, preds_sound + preds_action: list[torch.Tensor] | None = None + if has_action: + per_noisy_domain_ids = [ + domain_id.reshape(1).expand(len(noisy_idxs)) + for domain_id, noisy_idxs in zip(action_domain_ids, action_noisy_frame_indexes) + ] + per_noisy_domain_ids = torch.cat(per_noisy_domain_ids, dim=0).to(device=last_hidden_state.device) + preds_action_packed = self.action_proj_out( + last_hidden_state[action_mse_loss_indexes], per_noisy_domain_ids + ) + preds_action = self._unpack_action_latents( + preds_action_packed, action_token_shapes, action_noisy_frame_indexes + ) + + return preds_vision, preds_sound, preds_action diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 7225cce6ac9b..2bf831c7dd6d 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F from PIL import Image from transformers import AutoTokenizer, BatchEncoding @@ -130,6 +131,62 @@ def get_3d_mrope_ids_vae_tokens( _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." 
# Names of the supported Cosmos3 action modes.
_ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics"
_ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics"
_ACTION_MODE_POLICY = "policy"
_ACTION_MODES = {
    _ACTION_MODE_FORWARD_DYNAMICS,
    _ACTION_MODE_INVERSE_DYNAMICS,
    _ACTION_MODE_POLICY,
}

# Resolution bins for action runs, keyed first by a resolution name and then by an aspect ratio written as a
# string. In every entry the ratio key equals ``first / second`` of the size tuple it maps to.
# NOTE(review): the tuples are presumably (height, width) — confirm against the bin classifier that consumes them.
_ACTION_RESOLUTION_BINS = {
    "256": {
        "1.0": (256, 256),
        "0.8": (256, 320),
        "1.25": (320, 256),
        "0.6": (192, 320),
        "1.6666666666666667": (320, 192),
    },
    "480": {
        "1.0": (640, 640),
        "0.7391304347826086": (544, 736),
        "1.3529411764705883": (736, 544),
        "0.5769230769230769": (480, 832),
        "1.7333333333333334": (832, 480),
    },
    "704": {
        "1.0": (960, 960),
        "0.7647058823529411": (832, 1088),
        "1.3076923076923077": (1088, 832),
        "0.55": (704, 1280),
        "1.8181818181818181": (1280, 704),
    },
    "720": {
        "1.0": (960, 960),
        "0.7536231884057971": (832, 1104),
        "1.3269230769230769": (1104, 832),
        "0.5625": (720, 1280),
        "1.7777777777777777": (1280, 720),
    },
}

# Embodiment (dataset / robot) name -> integer domain id. Several names map to the same id
# (8 and 15 are each used by more than one name).
_EMBODIMENT_TO_DOMAIN_ID = {
    "no_action": 0,
    "av": 1,
    "camera_pose": 2,
    "hand_pose": 3,
    "pusht": 4,
    "libero": 5,
    "umi": 6,
    "bridge_orig_lerobot": 7,
    # id 8
    "droid_lerobot": 8,
    "robomind-franka": 8,
    "galbot": 9,
    "robomind-franka-dual": 12,
    "robomind-ur": 13,
    # id 15
    "agibotworld": 15,
    "agibot_gear_gripper": 15,
    "agibot_gear_gripper_ext": 15,
    "fractal": 20,
}
""" video: Any sound: torch.Tensor | None = None + action: list[torch.Tensor] | None = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -308,6 +367,7 @@ def _prepare_vision_segment( vision_fps: float | None, curr: int, device: torch.device | str, + condition_frame_indexes: list[int] | None = None, ) -> dict[str, Any]: """Build the static portion of the vision segment of the joint sequence. @@ -322,12 +382,16 @@ def _prepare_vision_segment( patch_w = math.ceil(latent_w / latent_patch_size) num_vision_tokens = latent_t * patch_h * patch_w - noisy_start = 1 if has_image_condition else 0 - noisy_frame_indexes = torch.arange(noisy_start, latent_t, device=device, dtype=torch.long) + if condition_frame_indexes is None: + condition_frame_indexes = [0] if has_image_condition else [] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < latent_t} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long + ) frame_token_stride = patch_h * patch_w mse_loss_indexes: list[int] = [] - for frame_idx in range(noisy_start, latent_t): + for frame_idx in noisy_frame_indexes.tolist(): frame_start = curr + frame_idx * frame_token_stride mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) @@ -352,7 +416,7 @@ def _prepare_vision_segment( # Assembly helpers (consumed inline before the transformer call). 
"vision_mrope_ids": vision_mrope_ids.to(device), "num_vision_tokens": num_vision_tokens, - "num_noisy_vision_tokens": (latent_t - noisy_start) * frame_token_stride, + "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, } def _prepare_sound_segment( @@ -396,37 +460,163 @@ def _prepare_sound_segment( "sound_len": sound_len, } + def _pack_action_tokens( + self, + input_action_tokens: torch.Tensor, + condition_frame_indexes: list[int], + mrope_offset: int | float, + action_fps: float | None, + curr: int, + device: torch.device | str, + ) -> dict[str, Any]: + """Build the static action segment; per-step tokens/timesteps are spliced in the denoising loop.""" + config = self.transformer.config + action_len = input_action_tokens.shape[0] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < action_len} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(action_len) if idx not in cond_frames], device=device, dtype=torch.long + ) + + effective_fps = action_fps if config.enable_fps_modulation else None + action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=action_len, + grid_h=1, + grid_w=1, + temporal_offset=mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=1, + start_frame_offset=1, + ) + + sequence_indexes = torch.arange(curr, curr + action_len, dtype=torch.long, device=device) + return { + "action_token_shapes": [(action_len, 1, 1)], + "action_sequence_indexes": sequence_indexes, + "action_mse_loss_indexes": sequence_indexes[noisy_frame_indexes], + "action_noisy_frame_indexes": [noisy_frame_indexes], + "action_mrope_ids": action_mrope_ids.to(device), + "action_len": action_len, + "num_noisy_action_tokens": len(noisy_frame_indexes), + } + + def _get_action_target_size( + self, + source_height: int, + source_width: int, + requested_height: int, + requested_width: int, + ) -> tuple[int, int]: + 
resolution_key = str(min(requested_height, requested_width)) + if resolution_key not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Cosmos3 action resolution binning only supports {sorted(_ACTION_RESOLUTION_BINS)}, " + f"got height={requested_height}, width={requested_width}." + ) + return self.video_processor.classify_height_width_bin( + source_height, + source_width, + ratios=_ACTION_RESOLUTION_BINS[resolution_key], + ) + + def _prepare_action_video_conditioning( + self, + video: Any, + height: int, + width: int, + num_frames: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, int, int]: + frames = self.video_processor.preprocess_video(video).to(device=device, dtype=dtype) + source_h, source_w = frames.shape[-2:] + target_h, target_w = self._get_action_target_size(source_h, source_w, height, width) + + if frames.shape[2] < num_frames: + frames = torch.cat([frames, frames[:, :, -1:].expand(-1, -1, num_frames - frames.shape[2], -1, -1)], dim=2) + else: + frames = frames[:, :, :num_frames] + + _, _, _, frame_h, frame_w = frames.shape + scale = min(target_w / frame_w, target_h / frame_h, 1.0) + content_h = max(1, int(scale * frame_h + 0.5)) + content_w = max(1, int(scale * frame_w + 0.5)) + + frames_t = frames.permute(0, 2, 1, 3, 4).reshape(-1, frames.shape[1], frame_h, frame_w) + if content_h != frame_h or content_w != frame_w: + frames_t = F.interpolate( + frames_t, + size=(content_h, content_w), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pad_right = target_w - content_w + pad_bottom = target_h - content_h + if pad_right or pad_bottom: + pad_mode = "replicate" if pad_right >= content_w or pad_bottom >= content_h else "reflect" + frames_t = F.pad(frames_t, (0, pad_right, 0, pad_bottom), mode=pad_mode) + frames = frames_t.reshape(frames.shape[0], num_frames, frames.shape[1], target_h, target_w).permute( + 0, 2, 1, 3, 4 + ) + image_size = torch.tensor([target_h, target_w, content_h, content_w], 
device=device, dtype=torch.float32) + return frames.to(dtype=dtype), image_size, target_h, target_w + + def _remove_action_video_padding_from_latent( + self, latents: torch.Tensor, image_size: torch.Tensor + ) -> torch.Tensor: + content_h = int(image_size[2].item()) + content_w = int(image_size[3].item()) + content_h_latent = max(content_h // self.vae_scale_factor_spatial, 1) + content_w_latent = max(content_w // self.vae_scale_factor_spatial, 1) + return latents[:, :, :, :content_h_latent, :content_w_latent].contiguous() + + def _remove_action_video_padding_from_video(self, video: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor: + content_h = int(image_size[2].item()) + content_w = int(image_size[3].item()) + return video[:, :, :, :content_h, :content_w].contiguous() + def prepare_latents( self, image: torch.Tensor | None = None, + video: Any | None = None, num_frames: int = 189, height: int = 720, width: int = 1280, fps: float = 24.0, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, generator: torch.Generator | None = None, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, + action_mode: str | None = None, + action: torch.Tensor | None = None, + action_chunk_size: int | None = None, + domain_name: str | None = None, + raw_action_dim: int | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, + torch.Tensor | None, float, float | None, torch.Tensor, torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int | None, ]: """Build conditioning + initial noise for a single sample. Returns: - ``(vision_latents, sound_latents, fps_vision, fps_sound)``. ``vision_latents`` is the noisy vision tensor; - ``sound_latents`` is the noisy sound tensor (``None`` unless ``enable_sound`` was set). 
The FPS scalars - feed the per-step :meth:`_prepare_vision_segment` / :meth:`_prepare_sound_segment` calls in the denoising - loop. + Initial noisy tensors plus condition masks/metadata for vision, sound, and optional action modalities. """ is_image = num_frames == 1 - has_image_condition = image is not None and not is_image + has_image_condition = (image is not None and not is_image) or action_mode is not None # video_processor.preprocess handles PIL/np/tensor → [1, 3, H, W] in [-1, 1], resized to (height, width). conditioning_frame_2d: torch.Tensor | None = None @@ -435,8 +625,41 @@ def prepare_latents( device=device, dtype=dtype ) + action_domain_id: torch.Tensor | None = None + action_condition_mask: torch.Tensor | None = None + raw_action_dim_resolved: int | None = int(raw_action_dim) if raw_action_dim is not None else None + action_condition_frames: list[int] = [] + action_condition_frame_indexes: list[int] = [] + action_image_size: torch.Tensor | None = None + vision_condition_frames: list[int] | None = None + # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). - if is_image: + if action_mode is not None: + if action_chunk_size is None: + raise ValueError("action_mode requires action_chunk_size.") + if video is None: + raise ValueError(f"action_mode={action_mode!r} requires loaded video conditioning.") + target_frames = action_chunk_size + 1 + if num_frames != target_frames: + raise ValueError( + "Action runs require num_frames to equal action_chunk_size + 1; " + f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." 
+ ) + vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( + video, height, width, target_frames, device=device, dtype=dtype + ) + if action_mode == _ACTION_MODE_FORWARD_DYNAMICS: + vision_condition_frames = [0] + action_condition_frames = list(range(action_chunk_size)) + elif action_mode == _ACTION_MODE_POLICY: + vision_condition_frames = [0] + elif action_mode == _ACTION_MODE_INVERSE_DYNAMICS: + latent_frames = (target_frames - 1) // self.vae.config.scale_factor_temporal + 1 + vision_condition_frames = list(range(latent_frames)) + else: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") + action_condition_frame_indexes = action_condition_frames + elif is_image: vision_tensor = ( conditioning_frame_2d.unsqueeze(2) # [1, 3, 1, H, W] if conditioning_frame_2d is not None @@ -451,6 +674,8 @@ def prepare_latents( vision_tensor[:, :, 1:] = conditioning_frame_2d.unsqueeze(2).expand(-1, -1, num_frames - 1, -1, -1) x0_tokens_vision = self._encode_video(vision_tensor).contiguous().float() + if action_image_size is not None: + x0_tokens_vision = self._remove_action_video_padding_from_latent(x0_tokens_vision, action_image_size) vision_shape = tuple(x0_tokens_vision.shape) x0_tokens_sound: torch.Tensor | None = None @@ -463,9 +688,60 @@ def prepare_latents( T_sound = (n_audio_samples + hop_size - 1) // hop_size x0_tokens_sound = torch.zeros(sound_dim, T_sound, device=device, dtype=dtype) + x0_tokens_action: torch.Tensor | None = None + if action_mode is not None: + assert action_chunk_size is not None + action_dim = self.transformer.action_dim + if action_mode == "forward_dynamics": + if action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + action = action.to(device=device, dtype=dtype) + if action.shape[0] == 0: + raise ValueError("action_mode='forward_dynamics' requires at least one action token.") + + # Action chunks describe transitions, 
so action length must match action_chunk_size + # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. + if action.shape[0] < action_chunk_size: + action = torch.cat( + [action, action[-1:].expand(action_chunk_size - action.shape[0], -1)], + dim=0, + ) + action = action[:action_chunk_size] + + # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. + if action.shape[-1] > action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + ) + if action.shape[-1] < action_dim: + action_padding = torch.zeros( + action.shape[0], + action_dim - action.shape[-1], + dtype=action.dtype, + device=action.device, + ) + action = torch.cat([action, action_padding], dim=-1) + x0_tokens_action = action + else: + x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) + if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + action_domain_id = torch.tensor( + [_EMBODIMENT_TO_DOMAIN_ID[domain_name]], + dtype=torch.long, + device=device, + ) + # Vision conditioning mask [latent_t, 1, 1]: frame 0 anchored when image-conditioning, rest noisy. 
vision_condition_mask = torch.zeros((x0_tokens_vision.shape[2], 1, 1), device=device, dtype=dtype) - if has_image_condition: + if vision_condition_frames is not None: + for frame_idx in vision_condition_frames: + if 0 <= frame_idx < vision_condition_mask.shape[0]: + vision_condition_mask[frame_idx, 0, 0] = 1.0 + elif has_image_condition: vision_condition_mask[0, 0, 0] = 1.0 if latents is None: @@ -491,17 +767,55 @@ def prepare_latents( else: sound_latents = sound_latents.to(device=device, dtype=dtype) - return latents, sound_latents, fps, fps_sound, vision_condition_mask, sound_condition_mask + if action_mode is not None and x0_tokens_action is not None: + action_condition_mask = torch.zeros((x0_tokens_action.shape[0], 1), device=device, dtype=dtype) + for frame_idx in action_condition_frames: + if 0 <= frame_idx < action_condition_mask.shape[0]: + action_condition_mask[frame_idx, 0] = 1.0 + if action_latents is None: + pure_noise_action = randn_tensor( + tuple(x0_tokens_action.shape), generator=generator, device=device, dtype=dtype + ) + action_latents = ( + action_condition_mask * x0_tokens_action + (1.0 - action_condition_mask) * pure_noise_action + ) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + else: + action_latents = action_latents.to(device=device, dtype=dtype) + + return ( + latents, + sound_latents, + action_latents, + fps, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) def check_inputs( self, prompt, negative_prompt, + image, + video, height: int, width: int, num_frames: int, + guidance_scale: float, enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], + action_mode: str | None, + action: torch.Tensor | None, + action_chunk_size: int | None, + domain_name: str | None, + raw_action_dim: int | None, ) -> None: if not isinstance(prompt, (str, list)) or 
( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) @@ -526,6 +840,31 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) + if action_mode is not None: + if not getattr(self.transformer.config, "action_gen", False): + raise ValueError("action_mode requires a transformer trained with action_gen=True.") + if image is not None: + raise ValueError("Use `video`, not `image`, for Cosmos3 action conditioning.") + if video is None: + raise ValueError(f"action_mode={action_mode!r} requires a loaded conditioning video.") + if action_chunk_size is None: + raise ValueError("action_mode requires action_chunk_size.") + if num_frames != action_chunk_size + 1: + raise ValueError( + "Action runs require num_frames to equal action_chunk_size + 1; " + f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." + ) + if domain_name is None: + raise ValueError("action_mode requires domain_name.") + if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + if action_mode == "forward_dynamics" and action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: + raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") def tokenize_prompt( self, @@ -538,6 +877,7 @@ def tokenize_prompt( use_system_prompt: bool = True, add_resolution_template: bool = True, add_duration_template: bool = True, + action_mode: str | None = None, ) -> tuple[list[int], list[int]]: """Apply prompt-augmentation templates and tokenize cond/uncond prompts via the Qwen2 chat template. 
@@ -606,7 +946,10 @@ def _mask_velocity_predictions( preds_sound: list[torch.Tensor] | None, vision_condition_mask: list[torch.Tensor], sound_condition_mask: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + preds_action: list[torch.Tensor] | None = None, + action_condition_mask: list[torch.Tensor] | None = None, + raw_action_dim: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """Zero out conditioning positions in the transformer's velocity predictions. ``preds_vision`` / ``preds_sound`` are returned per-sample by the transformer; the pipeline runs batch=1, so we @@ -625,7 +968,16 @@ def _mask_velocity_predictions( noisy_mask_s = (1.0 - cond_mask_s).T.to(dtype=pred_s.dtype, device=pred_s.device) velocity_sound = pred_s * noisy_mask_s if noisy_mask_s.sum() > 0 else torch.zeros_like(pred_s) - return velocity_vision, velocity_sound + velocity_action: torch.Tensor | None = None + if preds_action is not None and action_condition_mask is not None: + pred_a = preds_action[0] + cond_mask_a = action_condition_mask[0] + noisy_mask_a = (1.0 - cond_mask_a).to(dtype=pred_a.dtype, device=pred_a.device) + velocity_action = pred_a * noisy_mask_a if noisy_mask_a.sum() > 0 else torch.zeros_like(pred_a) + if raw_action_dim is not None: + velocity_action[:, raw_action_dim:] = 0 + + return velocity_vision, velocity_sound, velocity_action def _apply_video_safety_check(self, video: Any, output_type: str, device: torch.device) -> Any: """Run the Cosmos video guardrail on a postprocessed video and return it in the same format. 
@@ -676,16 +1028,24 @@ def __call__( prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, + video: Any | None = None, num_frames: int = 189, height: int = 720, width: int = 1280, fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, + flow_shift: float | None = None, enable_sound: bool = False, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, + action_mode: str | None = None, + action: torch.Tensor | None = None, + action_chunk_size: int | None = None, + domain_name: str | None = None, + raw_action_dim: int | None = None, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -770,9 +1130,28 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if action_mode is not None and action_mode not in _ACTION_MODES: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") + if action_mode is not None and action_chunk_size is not None: + num_frames = action_chunk_size + 1 + # 1. Check inputs self.check_inputs( - prompt, negative_prompt, height, width, num_frames, enable_sound, callback_on_step_end_tensor_inputs + prompt, + negative_prompt, + image, + video, + height, + width, + num_frames, + guidance_scale, + enable_sound, + callback_on_step_end_tensor_inputs, + action_mode, + action, + action_chunk_size, + domain_name, + raw_action_dim, ) self._current_timestep = None @@ -809,6 +1188,7 @@ def __call__( use_system_prompt=use_system_prompt, add_resolution_template=add_resolution_template, add_duration_template=add_duration_template, + action_mode=action_mode, ) # 3. 
Pre-pack the text segment for each prompt — text packing is invariant @@ -817,22 +1197,42 @@ def __call__( uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) # 4. Prepare latents (initial noise per modality + pack metadata) - has_image_condition = image is not None and num_frames > 1 - latents, sound_latents, fps_vision, fps_sound, vision_condition_mask, sound_condition_mask = ( - self.prepare_latents( - image=image, - num_frames=num_frames, - height=height, - width=width, - fps=fps, - latents=latents, - sound_latents=sound_latents, - generator=generator, - device=device, - dtype=dtype, - enable_sound=enable_sound, - ) + ( + latents, + sound_latents, + action_latents, + fps_vision, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) = self.prepare_latents( + image=image, + video=video, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + latents=latents, + sound_latents=sound_latents, + action_latents=action_latents, + generator=generator, + device=device, + dtype=dtype, + enable_sound=enable_sound, + action_mode=action_mode, + action=action, + action_chunk_size=action_chunk_size, + domain_name=domain_name, + raw_action_dim=raw_action_dim, ) + vision_condition_indexes_for_pack = torch.nonzero(vision_condition_mask[:, 0, 0] > 0, as_tuple=False).flatten() + vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] + has_image_condition = bool(vision_condition_indexes_for_pack) # 5. Pre-pack the static per-prompt vision / sound sequence segments. 
The only # fields that vary across denoising steps are the modality token tensors and the @@ -846,6 +1246,7 @@ def __call__( vision_fps=fps_vision, curr=cond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) cond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -856,17 +1257,33 @@ def __call__( curr=cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], device=device, ) + cond_action_segment: dict[str, Any] = {} + if action_latents is not None: + cond_action_segment = self._pack_action_tokens( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0), + device=device, + ) cond_mrope_segments = [cond_text_segment["text_mrope_ids"], cond_vision_segment["vision_mrope_ids"]] if cond_sound_segment: cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) + if cond_action_segment: + cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) cond_packed_static = { **cond_text_segment, **cond_vision_segment, **cond_sound_segment, + **cond_action_segment, "position_ids": torch.cat(cond_mrope_segments, dim=1), "sequence_length": cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"] - + cond_sound_segment.get("sound_len", 0), + + cond_sound_segment.get("sound_len", 0) + + cond_action_segment.get("action_len", 0), } uncond_vision_segment = self._prepare_vision_segment( @@ -876,6 +1293,7 @@ def __call__( vision_fps=fps_vision, curr=uncond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) uncond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -886,29 +1304,58 @@ def __call__( curr=uncond_text_segment["und_len"] + 
uncond_vision_segment["num_vision_tokens"], device=device, ) + uncond_action_segment: dict[str, Any] = {} + if action_latents is not None: + uncond_action_segment = self._pack_action_tokens( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0), + device=device, + ) uncond_mrope_segments = [uncond_text_segment["text_mrope_ids"], uncond_vision_segment["vision_mrope_ids"]] if uncond_sound_segment: uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) + if uncond_action_segment: + uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) uncond_packed_static = { **uncond_text_segment, **uncond_vision_segment, **uncond_sound_segment, + **uncond_action_segment, "position_ids": torch.cat(uncond_mrope_segments, dim=1), "sequence_length": uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"] - + uncond_sound_segment.get("sound_len", 0), + + uncond_sound_segment.get("sound_len", 0) + + uncond_action_segment.get("action_len", 0), } num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] sound_len = cond_sound_segment.get("sound_len") + action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so audio gets its own copy. - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None + # model_outputs history) on the instance, so audio/action gets its own copy. 
+ inference_scheduler = copy.deepcopy(self.scheduler) + if flow_shift is not None: + inference_scheduler.register_to_config( + use_flow_sigmas=True, + use_karras_sigmas=False, + use_exponential_sigmas=False, + use_beta_sigmas=False, + flow_shift=flow_shift, + shift_terminal=None, + final_sigmas_type="zero", + ) + inference_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = inference_scheduler.timesteps + sound_scheduler = copy.deepcopy(inference_scheduler) if sound_latents is not None else None + action_scheduler = copy.deepcopy(inference_scheduler) if action_latents is not None else None # 7. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * inference_scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -922,15 +1369,19 @@ def __call__( # noisy tokens before packing so the modality tokens enter the model in the right dtype. vision_tokens = latents.to(device=device, dtype=dtype) sound_tokens = sound_latents.to(device=device, dtype=dtype) if sound_latents is not None else None + action_tokens = action_latents.to(device=device, dtype=dtype) if action_latents is not None else None # The static packs both report the same num_noisy_vision_tokens / sound_len, so a # single per-step timestep tensor per modality is shared by the cond / uncond passes. 
vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) sound_timesteps = ( torch.full((sound_len,), timestep, device=device) if sound_tokens is not None else None ) + action_timesteps = ( + torch.full((action_noisy_len,), timestep, device=device) if action_tokens is not None else None + ) # --- Conditional pass --- - preds_vision, preds_sound = self.transformer( + preds_vision, preds_sound, preds_action = self.transformer( input_ids=cond_packed_static["input_ids"], text_indexes=cond_packed_static["text_indexes"], position_ids=cond_packed_static["position_ids"], @@ -948,17 +1399,28 @@ def __call__( sound_mse_loss_indexes=cond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=cond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=cond_packed_static.get("action_token_shapes"), + action_sequence_indexes=cond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=cond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=cond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - cond_v_vision, cond_v_sound = self._mask_velocity_predictions( + cond_v_vision, cond_v_sound, cond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- Unconditional pass (Skip if not using CFG) --- + uncond_v_vision = uncond_v_sound = uncond_v_action = None if guidance_scale != 1.0: - preds_vision, preds_sound = self.transformer( + preds_vision, 
preds_sound, preds_action = self.transformer( input_ids=uncond_packed_static["input_ids"], text_indexes=uncond_packed_static["text_indexes"], position_ids=uncond_packed_static["position_ids"], @@ -976,12 +1438,22 @@ def __call__( sound_mse_loss_indexes=uncond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=uncond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=uncond_packed_static.get("action_token_shapes"), + action_sequence_indexes=uncond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=uncond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=uncond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - uncond_v_vision, uncond_v_sound = self._mask_velocity_predictions( + uncond_v_vision, uncond_v_sound, uncond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- CFG combine + per-modality scheduler step --- @@ -994,7 +1466,7 @@ def __call__( else: velocity_vision = cond_v_vision - latents = self.scheduler.step( + latents = inference_scheduler.step( velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False )[0].squeeze(0) @@ -1008,18 +1480,40 @@ def __call__( velocity_sound.unsqueeze(0), t, sound_latents.unsqueeze(0), return_dict=False )[0].squeeze(0) + has_noisy_action = ( + action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() + ) + if action_scheduler is not None and 
has_noisy_action and cond_v_action is not None: + if guidance_scale != 1.0: + velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) + else: + velocity_action = cond_v_action + action_latents = action_scheduler.step( + velocity_action.unsqueeze(0), t, action_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + if callback_on_step_end is not None: callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0 + ): progress_bar.update() self._current_timestep = None # 8. Postprocess + decode sound = self.decode_sound(sound_latents) if sound_latents is not None else None + action_output = None + if action_mode in {"inverse_dynamics", "policy"} and action_latents is not None: + action_output = action_latents + if raw_action_dim_resolved is not None: + action_output = action_output[:, :raw_action_dim_resolved] + action_output = [action_output.detach().cpu()] if output_type == "latent": video = latents else: @@ -1037,5 +1531,7 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: + if action_mode is not None: + return (video, sound, action_output) return (video, sound) - return Cosmos3OmniPipelineOutput(video=video, sound=sound) + return Cosmos3OmniPipelineOutput(video=video, sound=sound, action=action_output) From 2fcef5b47c01ef7569691f037604dd8b1b3966b9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:40:25 +0000 Subject: [PATCH 02/27] Add README action examples --- examples/cosmos3/README.md | 103 +++++++++++++++++++++++++++++++++++++ 1 file changed, 103 
insertions(+) diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index 7a4cb277aa07..98cf30eac6d9 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -48,6 +48,104 @@ python examples/cosmos3/inference_cosmos3.py \ --enable-sound ``` +Action forward dynamics, robot domain (predict video from an observation video and a provided action chunk): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.json" \ + --action-chunk-size 16 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_robot +``` + +Action forward dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. This video is captured from a first-person perspective looking at the scene." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_action_25.json" \ + --action-chunk-size 60 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_av +``` + +Action inverse dynamics, robot domain (predict actions from an observed video): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 16 \ + --raw-action-dim 10 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_robot +``` + +Action inverse dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. This video is captured from a first-person perspective looking at the scene." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 60 \ + --raw-action-dim 9 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_av +``` + +Action policy, robot domain (predict both future video and actions from the first observation frame): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode policy \ + --action-chunk-size 16 \ + --raw-action-dim 10 \ + --domain-name bridge_orig_lerobot \ + --height 480 --width 832 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_policy_robot +``` + +Action policy, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. Please go backward. This video is captured from a first-person perspective looking at the scene." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode policy \ + --action-chunk-size 60 \ + --raw-action-dim 9 \ + --domain-name av \ + --height 480 --width 832 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ + --output results/cosmos3_policy_av +``` + +Action modes use `action_chunk_size + 1` video frames. 
`forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample-*_action.json` in model-normalized action space. The upstream camera-pose forward-dynamics sample uses a still image (`mountain_720.png`), while this wrapper currently expects `--vision-path` to load as video for action modes. + ### Useful flags | Flag | Default | Description | @@ -58,6 +156,11 @@ python examples/cosmos3/inference_cosmos3.py \ | `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). | | `--fps` | `24.0` | Frame rate of the generated video. | | `--enable-sound` | off | Generate a synchronized audio track. | +| `--action-mode` | `None` | Enable action conditioning/generation. One of `forward_dynamics`, `inverse_dynamics`, or `policy`. | +| `--action-path` | `None` | URL or local JSON action path for `forward_dynamics`. | +| `--action-chunk-size` | `None` | Number of action tokens. Action runs generate/use `action_chunk_size + 1` video frames. | +| `--domain-name` | `None` | Action embodiment domain, for example `bridge_orig_lerobot` or `av`. | +| `--raw-action-dim` | `None` | Slice predicted action output to the unpadded action dimension. Required for `inverse_dynamics` and `policy`. | | `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1`. | | `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. | | `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. 
| From 40ea9732ce41b52f27b9a0f49ea04131c9f1d682 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 11:59:01 +0000 Subject: [PATCH 03/27] Use do_classifier_free_guidance property --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 2bf831c7dd6d..460b0786e4a1 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1022,6 +1022,10 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def do_classifier_free_guidance(self): + return self._guidance_scale != 1.0 + @torch.no_grad() def __call__( self, @@ -1156,6 +1160,7 @@ def __call__( self._current_timestep = None self._interrupt = False + self._guidance_scale = guidance_scale # Pipeline supports a single sample at a time; collapse list-style inputs to a single string. if isinstance(prompt, list): @@ -1419,7 +1424,7 @@ def __call__( # --- Unconditional pass (Skip if not using CFG) --- uncond_v_vision = uncond_v_sound = uncond_v_action = None - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: preds_vision, preds_sound, preds_action = self.transformer( input_ids=uncond_packed_static["input_ids"], text_indexes=uncond_packed_static["text_indexes"], @@ -1461,7 +1466,7 @@ def __call__( # to carry a batch dim; per-modality latents have no batch axis, so wrap for the step. 
# Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_vision = uncond_v_vision + guidance_scale * (cond_v_vision - uncond_v_vision) else: velocity_vision = cond_v_vision @@ -1472,7 +1477,7 @@ def __call__( if sound_scheduler is not None and cond_v_sound is not None: # Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_sound = uncond_v_sound + guidance_scale * (cond_v_sound - uncond_v_sound) else: velocity_sound = cond_v_sound @@ -1484,7 +1489,7 @@ def __call__( action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() ) if action_scheduler is not None and has_noisy_action and cond_v_action is not None: - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) else: velocity_action = cond_v_action From 591cd4d062597331b982686ec446029a20f81c20 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 12:03:12 +0000 Subject: [PATCH 04/27] Remove unused method --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 460b0786e4a1..bf9edba8a265 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -572,11 +572,6 @@ def _remove_action_video_padding_from_latent( content_w_latent = max(content_w // self.vae_scale_factor_spatial, 1) return latents[:, :, :, :content_h_latent, :content_w_latent].contiguous() - def _remove_action_video_padding_from_video(self, video: torch.Tensor, image_size: torch.Tensor) -> torch.Tensor: - content_h = int(image_size[2].item()) - content_w = int(image_size[3].item()) - return video[:, :, :, :content_h, :content_w].contiguous() - 
def prepare_latents( self, image: torch.Tensor | None = None, From 04efd90e0f9b82ba92291907a9f8dc7ddd4d62ea Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Thu, 28 May 2026 16:24:14 +0000 Subject: [PATCH 05/27] Add action policy example to pipelines doc --- docs/source/en/api/pipelines/cosmos3.md | 48 +++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index ce26ee0c36ef..291c1dab75e4 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -459,6 +459,54 @@ encode_video( +## Action policy + +Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" +) + +prompt = ( + "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking " + "at the scene." +) +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" +) + +result = pipe( + prompt=prompt, + video=video, + num_frames=17, + height=480, + width=832, + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + flow_shift=5.0, + action_mode="policy", + action_chunk_size=16, + raw_action_dim=10, + domain_name="bridge_orig_lerobot", + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). 
+export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). From 362b6ebcb5251e53826339474fd37f1156bcd320 Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Thu, 28 May 2026 11:45:46 -0700 Subject: [PATCH 06/27] Adding model selection for action example doc. --- docs/source/en/api/pipelines/cosmos3.md | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 291c1dab75e4..e58c7a95f796 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -463,6 +463,9 @@ encode_video( Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. + + + ```python import json @@ -507,6 +510,56 @@ if result.action is not None: json.dump(result.action[0].tolist(), f) ``` + + + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" +) + +prompt = ( + "Put the pot to the left of the purple item. This video is captured from a first-person perspective looking " + "at the scene." 
+) +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" +) + +result = pipe( + prompt=prompt, + video=video, + num_frames=17, + height=480, + width=832, + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + flow_shift=5.0, + action_mode="policy", + action_chunk_size=16, + raw_action_dim=10, + domain_name="bridge_orig_lerobot", + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + + + + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). 
From 5ff2ea92b495c3320b97516134bf4394d18aff22 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:29:00 +0000 Subject: [PATCH 07/27] Remove redundant casts --- .../models/transformers/transformer_cosmos3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index 54fbe066ac33..29cfc127d253 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -151,9 +151,9 @@ class DomainAwareLinear(nn.Module): def __init__(self, input_size: int, output_size: int, num_domains: int) -> None: super().__init__() - self.input_size = int(input_size) - self.output_size = int(output_size) - self.num_domains = int(num_domains) + self.input_size = input_size + self.output_size = output_size + self.num_domains = num_domains self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size) self.bias = nn.Embedding(self.num_domains, self.output_size) nn.init.xavier_uniform_(self.fc.weight) @@ -370,8 +370,8 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size) self.action_gen = action_gen - self.action_dim = int(32 if action_dim is None else action_dim) - self.num_embodiment_domains = int(num_embodiment_domains) + self.action_dim = 32 if action_dim is None else action_dim + self.num_embodiment_domains = num_embodiment_domains if action_gen: self.action_proj_in = DomainAwareLinear(self.action_dim, hidden_size, self.num_embodiment_domains) self.action_proj_out = DomainAwareLinear(hidden_size, self.action_dim, self.num_embodiment_domains) From a01c1c908fe9fa9d2094cbd7747aef6670550c92 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:30:53 +0000 Subject: [PATCH 08/27] Rename 
_pack_action_tokens to _prepare_action_segment --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index bf9edba8a265..6a8771a2c197 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -460,7 +460,7 @@ def _prepare_sound_segment( "sound_len": sound_len, } - def _pack_action_tokens( + def _prepare_action_segment( self, input_action_tokens: torch.Tensor, condition_frame_indexes: list[int], @@ -1259,7 +1259,7 @@ def __call__( ) cond_action_segment: dict[str, Any] = {} if action_latents is not None: - cond_action_segment = self._pack_action_tokens( + cond_action_segment = self._prepare_action_segment( input_action_tokens=action_latents, condition_frame_indexes=action_condition_frame_indexes, mrope_offset=cond_text_segment["vision_start_temporal_offset"], @@ -1306,7 +1306,7 @@ def __call__( ) uncond_action_segment: dict[str, Any] = {} if action_latents is not None: - uncond_action_segment = self._pack_action_tokens( + uncond_action_segment = self._prepare_action_segment( input_action_tokens=action_latents, condition_frame_indexes=action_condition_frame_indexes, mrope_offset=uncond_text_segment["vision_start_temporal_offset"], From c12e6b162b8542d7a4ef3f84822798a8f3c22ef0 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:35:39 +0000 Subject: [PATCH 09/27] Move validation checks to check_inputs --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 6a8771a2c197..f556a6810598 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ 
b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -630,16 +630,8 @@ def prepare_latents( # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). if action_mode is not None: - if action_chunk_size is None: - raise ValueError("action_mode requires action_chunk_size.") - if video is None: - raise ValueError(f"action_mode={action_mode!r} requires loaded video conditioning.") + assert action_chunk_size is not None target_frames = action_chunk_size + 1 - if num_frames != target_frames: - raise ValueError( - "Action runs require num_frames to equal action_chunk_size + 1; " - f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." - ) vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( video, height, width, target_frames, device=device, dtype=dtype ) @@ -691,8 +683,6 @@ def prepare_latents( if action is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") action = action.to(device=device, dtype=dtype) - if action.shape[0] == 0: - raise ValueError("action_mode='forward_dynamics' requires at least one action token.") # Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. @@ -704,10 +694,6 @@ def prepare_latents( action = action[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. - if action.shape[-1] > action_dim: - raise ValueError( - f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." - ) if action.shape[-1] < action_dim: action_padding = torch.zeros( action.shape[0], @@ -856,8 +842,16 @@ def check_inputs( f"Unknown Cosmos3 action domain_name={domain_name!r}; " f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." 
) - if action_mode == "forward_dynamics" and action is None: - raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action_mode == "forward_dynamics": + if action is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + if action.shape[0] == 0: + raise ValueError("action_mode='forward_dynamics' requires at least one action token.") + action_dim = self.transformer.action_dim + if action.shape[-1] > action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + ) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") From fbcd0777f4118c8fec4f2be306c936334ab862f1 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:37:52 +0000 Subject: [PATCH 10/27] Add action arguments in the __call__ docstring --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index f556a6810598..b0b6ca8485c1 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1088,6 +1088,28 @@ def __call__( sound_latents (`torch.Tensor`, *optional*): Pre-generated sound latents to start denoising from. Only consulted when `enable_sound=True`; when `None`, fresh Gaussian noise is sampled. + action_latents (`torch.Tensor`, *optional*): + Pre-generated action latents to start the action stream's denoising from. Only consulted when an action + run is configured via `action_mode`; when `None`, fresh Gaussian noise is sampled for the action tokens. + action_mode (`str`, *optional*): + Selects the action-conditioned generation task and requires a transformer trained with + `action_gen=True`. 
One of `"forward_dynamics"` (predict the future video from an initial frame and a + given `action` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), + or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning + must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. + action (`torch.Tensor`, *optional*): + Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. Sequences shorter + than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's + `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + action_chunk_size (`int`, *optional*): + Number of action transition steps in the chunk. Required for every `action_mode`; the paired video has + `action_chunk_size + 1` frames and `num_frames` is overwritten accordingly. + domain_name (`str`, *optional*): + Embodiment domain that selects the domain-aware action projection weights. Required for action runs and + must be one of the registered Cosmos 3 embodiment domains. + raw_action_dim (`int`, *optional*): + Number of meaningful (unpadded) action channels to keep when slicing predicted actions. Required for + `action_mode="inverse_dynamics"` and `action_mode="policy"`. output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). 
From 7c4e2f488dac6de9a0aa1df98c160042c8193c13 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:40:18 +0000 Subject: [PATCH 11/27] Move action mode check to check_inputs --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index b0b6ca8485c1..09b5c52fee44 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -822,6 +822,8 @@ def check_inputs( f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if action_mode is not None: + if action_mode not in _ACTION_MODES: + raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") if not getattr(self.transformer.config, "action_gen", False): raise ValueError("action_mode requires a transformer trained with action_gen=True.") if image is not None: @@ -1145,8 +1147,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if action_mode is not None and action_mode not in _ACTION_MODES: - raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") if action_mode is not None and action_chunk_size is not None: num_frames = action_chunk_size + 1 From 57d3d07d43d13ffbdea926db6ac598ddc0cb9bb5 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:42:45 +0000 Subject: [PATCH 12/27] Rename action to action_tokens --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 
09b5c52fee44..1f31b378e5f8 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -588,7 +588,7 @@ def prepare_latents( dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, action_mode: str | None = None, - action: torch.Tensor | None = None, + action_tokens: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -680,29 +680,29 @@ def prepare_latents( assert action_chunk_size is not None action_dim = self.transformer.action_dim if action_mode == "forward_dynamics": - if action is None: + if action_tokens is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - action = action.to(device=device, dtype=dtype) + action_tokens = action_tokens.to(device=device, dtype=dtype) # Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. - if action.shape[0] < action_chunk_size: - action = torch.cat( - [action, action[-1:].expand(action_chunk_size - action.shape[0], -1)], + if action_tokens.shape[0] < action_chunk_size: + action_tokens = torch.cat( + [action_tokens, action_tokens[-1:].expand(action_chunk_size - action_tokens.shape[0], -1)], dim=0, ) - action = action[:action_chunk_size] + action_tokens = action_tokens[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. 
- if action.shape[-1] < action_dim: + if action_tokens.shape[-1] < action_dim: action_padding = torch.zeros( - action.shape[0], - action_dim - action.shape[-1], - dtype=action.dtype, - device=action.device, + action_tokens.shape[0], + action_dim - action_tokens.shape[-1], + dtype=action_tokens.dtype, + device=action_tokens.device, ) - action = torch.cat([action, action_padding], dim=-1) - x0_tokens_action = action + action_tokens = torch.cat([action_tokens, action_padding], dim=-1) + x0_tokens_action = action_tokens else: x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: @@ -793,7 +793,7 @@ def check_inputs( enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], action_mode: str | None, - action: torch.Tensor | None, + action_tokens: torch.Tensor | None, action_chunk_size: int | None, domain_name: str | None, raw_action_dim: int | None, @@ -845,14 +845,14 @@ def check_inputs( f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." ) if action_mode == "forward_dynamics": - if action is None: + if action_tokens is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - if action.shape[0] == 0: + if action_tokens.shape[0] == 0: raise ValueError("action_mode='forward_dynamics' requires at least one action token.") action_dim = self.transformer.action_dim - if action.shape[-1] > action_dim: + if action_tokens.shape[-1] > action_dim: raise ValueError( - f"Cosmos3 action dimension {action.shape[-1]} exceeds model action_dim={action_dim}." + f"Cosmos3 action dimension {action_tokens.shape[-1]} exceeds model action_dim={action_dim}." 
) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") @@ -1037,7 +1037,7 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action_mode: str | None = None, - action: torch.Tensor | None = None, + action_tokens: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -1096,10 +1096,11 @@ def __call__( action_mode (`str`, *optional*): Selects the action-conditioned generation task and requires a transformer trained with `action_gen=True`. One of `"forward_dynamics"` (predict the future video from an initial frame and a - given `action` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), - or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning - must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. - action (`torch.Tensor`, *optional*): + given `action_tokens` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning + frames), or `"policy"` (jointly roll out future video and actions from the first frame). When set, + conditioning must be supplied via `video` (not `image`) and `num_frames` is forced to + `action_chunk_size + 1`. + action_tokens (`torch.Tensor`, *optional*): Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. Sequences shorter than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. 
@@ -1163,7 +1164,7 @@ def __call__( enable_sound, callback_on_step_end_tensor_inputs, action_mode, - action, + action_tokens, action_chunk_size, domain_name, raw_action_dim, @@ -1241,7 +1242,7 @@ def __call__( dtype=dtype, enable_sound=enable_sound, action_mode=action_mode, - action=action, + action_tokens=action_tokens, action_chunk_size=action_chunk_size, domain_name=domain_name, raw_action_dim=raw_action_dim, From a6e204011fa1cd5e13871244611d7ccdf9407c6b Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 12:55:34 +0000 Subject: [PATCH 13/27] Add warning for num_frames overwrite attempt --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 1f31b378e5f8..158e68642272 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -30,12 +30,15 @@ Cosmos3OmniTransformer, ) from ...schedulers import UniPCMultistepScheduler -from ...utils import BaseOutput, is_cosmos_guardrail_available +from ...utils import BaseOutput, is_cosmos_guardrail_available, logging from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + if is_cosmos_guardrail_available(): from cosmos_guardrail import CosmosSafetyChecker else: @@ -1149,7 +1152,13 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs if action_mode is not None and action_chunk_size is not None: - num_frames = action_chunk_size + 1 + target_num_frames = action_chunk_size + 1 + if num_frames != target_num_frames: + logger.warning( + f"`num_frames={num_frames}` is ignored for action runs and overwritten to " + f"`action_chunk_size + 1 =
{target_num_frames}`." + ) + num_frames = target_num_frames # 1. Check inputs self.check_inputs( From 913c24f35d2f0af80d46405822856c3ff0f4c1f2 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 13:07:30 +0000 Subject: [PATCH 14/27] Rename action_tokens to raw_actions --- examples/cosmos3/inference_cosmos3.py | 2 +- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 54 +++++++++---------- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 675ead892c2f..737b57d30df5 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -161,7 +161,7 @@ def main(): num_inference_steps=args.num_inference_steps, flow_shift=args.flow_shift, action_mode=args.action_mode, - action=action, + raw_actions=action, action_chunk_size=args.action_chunk_size, domain_name=args.domain_name, raw_action_dim=args.raw_action_dim, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 158e68642272..b7d6e00226b5 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -591,7 +591,7 @@ def prepare_latents( dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, action_mode: str | None = None, - action_tokens: torch.Tensor | None = None, + raw_actions: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -683,29 +683,29 @@ def prepare_latents( assert action_chunk_size is not None action_dim = self.transformer.action_dim if action_mode == "forward_dynamics": - if action_tokens is None: + if raw_actions is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - action_tokens = action_tokens.to(device=device, dtype=dtype) + raw_actions = raw_actions.to(device=device, dtype=dtype) 
# Action chunks describe transitions, so action length must match action_chunk_size # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. - if action_tokens.shape[0] < action_chunk_size: - action_tokens = torch.cat( - [action_tokens, action_tokens[-1:].expand(action_chunk_size - action_tokens.shape[0], -1)], + if raw_actions.shape[0] < action_chunk_size: + raw_actions = torch.cat( + [raw_actions, raw_actions[-1:].expand(action_chunk_size - raw_actions.shape[0], -1)], dim=0, ) - action_tokens = action_tokens[:action_chunk_size] + raw_actions = raw_actions[:action_chunk_size] # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. - if action_tokens.shape[-1] < action_dim: + if raw_actions.shape[-1] < action_dim: action_padding = torch.zeros( - action_tokens.shape[0], - action_dim - action_tokens.shape[-1], - dtype=action_tokens.dtype, - device=action_tokens.device, + raw_actions.shape[0], + action_dim - raw_actions.shape[-1], + dtype=raw_actions.dtype, + device=raw_actions.device, ) - action_tokens = torch.cat([action_tokens, action_padding], dim=-1) - x0_tokens_action = action_tokens + raw_actions = torch.cat([raw_actions, action_padding], dim=-1) + x0_tokens_action = raw_actions else: x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: @@ -796,7 +796,7 @@ def check_inputs( enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], action_mode: str | None, - action_tokens: torch.Tensor | None, + raw_actions: torch.Tensor | None, action_chunk_size: int | None, domain_name: str | None, raw_action_dim: int | None, @@ -848,14 +848,14 @@ def check_inputs( f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." 
) if action_mode == "forward_dynamics": - if action_tokens is None: + if raw_actions is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - if action_tokens.shape[0] == 0: + if raw_actions.shape[0] == 0: raise ValueError("action_mode='forward_dynamics' requires at least one action token.") action_dim = self.transformer.action_dim - if action_tokens.shape[-1] > action_dim: + if raw_actions.shape[-1] > action_dim: raise ValueError( - f"Cosmos3 action dimension {action_tokens.shape[-1]} exceeds model action_dim={action_dim}." + f"Cosmos3 action dimension {raw_actions.shape[-1]} exceeds model action_dim={action_dim}." ) if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") @@ -1040,7 +1040,7 @@ def __call__( sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, action_mode: str | None = None, - action_tokens: torch.Tensor | None = None, + raw_actions: torch.Tensor | None = None, action_chunk_size: int | None = None, domain_name: str | None = None, raw_action_dim: int | None = None, @@ -1099,14 +1099,14 @@ def __call__( action_mode (`str`, *optional*): Selects the action-conditioned generation task and requires a transformer trained with `action_gen=True`. One of `"forward_dynamics"` (predict the future video from an initial frame and a - given `action_tokens` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning + given `raw_actions` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), or `"policy"` (jointly roll out future video and actions from the first frame). When set, conditioning must be supplied via `video` (not `image`) and `num_frames` is forced to `action_chunk_size + 1`. - action_tokens (`torch.Tensor`, *optional*): - Raw action tokens of shape `[T, action_dim]` driving `action_mode="forward_dynamics"`. 
Sequences shorter - than `action_chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's - `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + raw_actions (`torch.Tensor`, *optional*): + Raw domain action vectors of shape `[T, raw_action_dim]` driving `action_mode="forward_dynamics"`. + Sequences shorter than `action_chunk_size` repeat the last action; longer ones are truncated. Channels + beyond the model's `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. action_chunk_size (`int`, *optional*): Number of action transition steps in the chunk. Required for every `action_mode`; the paired video has `action_chunk_size + 1` frames and `num_frames` is overwritten accordingly. @@ -1173,7 +1173,7 @@ def __call__( enable_sound, callback_on_step_end_tensor_inputs, action_mode, - action_tokens, + raw_actions, action_chunk_size, domain_name, raw_action_dim, @@ -1251,7 +1251,7 @@ def __call__( dtype=dtype, enable_sound=enable_sound, action_mode=action_mode, - action_tokens=action_tokens, + raw_actions=raw_actions, action_chunk_size=action_chunk_size, domain_name=domain_name, raw_action_dim=raw_action_dim, From 3bce946d550952798c13874a5ae5c91eec2616de Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Fri, 29 May 2026 13:33:36 +0000 Subject: [PATCH 15/27] Remove scheduler config override --- docs/source/en/api/pipelines/cosmos3.md | 2 -- examples/cosmos3/inference_cosmos3.py | 3 -- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 30 +++++-------------- 3 files changed, 8 insertions(+), 27 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index e58c7a95f796..301c78f58e62 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -494,7 +494,6 @@ result = pipe( fps=5, num_inference_steps=30, guidance_scale=1.0, - flow_shift=5.0, action_mode="policy", action_chunk_size=16, 
raw_action_dim=10, @@ -541,7 +540,6 @@ result = pipe( fps=5, num_inference_steps=30, guidance_scale=1.0, - flow_shift=5.0, action_mode="policy", action_chunk_size=16, raw_action_dim=10, diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 737b57d30df5..18297927e09d 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -82,7 +82,6 @@ def main(): parser.add_argument("--fps", type=float, default=24.0) parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") - parser.add_argument("--flow-shift", type=float, default=None, help="Scheduler flow shift.") parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", @@ -159,7 +158,6 @@ def main(): width=args.width, fps=args.fps, num_inference_steps=args.num_inference_steps, - flow_shift=args.flow_shift, action_mode=args.action_mode, raw_actions=action, action_chunk_size=args.action_chunk_size, @@ -182,7 +180,6 @@ def main(): width=args.width, fps=args.fps, num_inference_steps=args.num_inference_steps, - flow_shift=args.flow_shift, enable_sound=args.enable_sound, guidance_scale=args.guidance_scale, generator=generator, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index b7d6e00226b5..e47062e4a0f4 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1033,7 +1033,6 @@ def __call__( fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, - flow_shift: float | None = None, enable_sound: bool = False, generator: torch.Generator | None = None, latents: torch.Tensor | None = None, @@ -1363,25 +1362,14 @@ def __call__( action_noisy_len = 
cond_action_segment.get("num_noisy_action_tokens") # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so audio/action gets its own copy. - inference_scheduler = copy.deepcopy(self.scheduler) - if flow_shift is not None: - inference_scheduler.register_to_config( - use_flow_sigmas=True, - use_karras_sigmas=False, - use_exponential_sigmas=False, - use_beta_sigmas=False, - flow_shift=flow_shift, - shift_terminal=None, - final_sigmas_type="zero", - ) - inference_scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = inference_scheduler.timesteps - sound_scheduler = copy.deepcopy(inference_scheduler) if sound_latents is not None else None - action_scheduler = copy.deepcopy(inference_scheduler) if action_latents is not None else None + # model_outputs history) on the instance, so sound/action each get their own copy. + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None + action_scheduler = copy.deepcopy(self.scheduler) if action_latents is not None else None # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * inference_scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -1492,7 +1480,7 @@ def __call__( else: velocity_vision = cond_v_vision - latents = inference_scheduler.step( + latents = self.scheduler.step( velocity_vision.unsqueeze(0), t, latents.unsqueeze(0), return_dict=False )[0].squeeze(0) @@ -1525,9 +1513,7 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() self._current_timestep = None From efc4a3efda067bbb05f173d18a1265fa20a91ccb Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Mon, 1 Jun 2026 15:26:28 +0000 Subject: [PATCH 16/27] Refactor action to use CosmosActionCondition --- docs/source/en/api/pipelines/cosmos3.md | 46 ++- examples/cosmos3/README.md | 25 +- examples/cosmos3/inference_cosmos3.py | 48 ++- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/cosmos/__init__.py | 2 + .../pipelines/cosmos/pipeline_cosmos3_omni.py | 325 +++++++++++------- .../dummy_torch_and_transformers_objects.py | 15 + 8 files changed, 295 insertions(+), 170 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 301c78f58e62..2b5ca69043b5 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -459,7 +459,11 @@ encode_video( -## Action policy +## Action-conditioned generation + +Action runs group every action-specific input into a 
[`CosmosActionCondition`] passed via the `action` argument instead of the top-level `image` / `video` / `height` / `width` arguments. Set `resolution_tier` (`256`/`480`/`704`/`720`) close to the input video's native resolution; it selects the conditioning canvas. Cosmos 3 supports three action modes — `policy`, `forward_dynamics`, and `inverse_dynamics`. `policy` and `forward_dynamics` condition only on the first frame (so an `image` or a `video` both work), while `inverse_dynamics` requires a `video`. The conditioning video for an action run is set on `action.video` (or `action.image`), not on the pipeline's top-level `video` argument. + +### Action policy Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. @@ -470,7 +474,7 @@ Action policy generation predicts future video and action tokens from the first import json import torch -from diffusers import Cosmos3OmniPipeline +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition from diffusers.utils import export_to_video, load_video pipe = Cosmos3OmniPipeline.from_pretrained( @@ -487,17 +491,17 @@ video = load_video( result = pipe( prompt=prompt, - video=video, - num_frames=17, - height=480, - width=832, + action=CosmosActionCondition( + mode="policy", + chunk_size=16, + domain_name="bridge_orig_lerobot", + raw_action_dim=10, + resolution_tier=480, + video=video, + ), fps=5, num_inference_steps=30, guidance_scale=1.0, - action_mode="policy", - action_chunk_size=16, - raw_action_dim=10, - domain_name="bridge_orig_lerobot", use_system_prompt=False, ) @@ -516,7 +520,7 @@ if result.action is not None: import json import torch -from diffusers import Cosmos3OmniPipeline +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition from diffusers.utils import export_to_video, load_video pipe = 
Cosmos3OmniPipeline.from_pretrained( @@ -533,17 +537,17 @@ video = load_video( result = pipe( prompt=prompt, - video=video, - num_frames=17, - height=480, - width=832, + action=CosmosActionCondition( + mode="policy", + chunk_size=16, + domain_name="bridge_orig_lerobot", + raw_action_dim=10, + resolution_tier=480, + video=video, + ), fps=5, num_inference_steps=30, guidance_scale=1.0, - action_mode="policy", - action_chunk_size=16, - raw_action_dim=10, - domain_name="bridge_orig_lerobot", use_system_prompt=False, ) @@ -636,6 +640,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( - all - __call__ +## CosmosActionCondition + +[[autodoc]] CosmosActionCondition + ## Cosmos3OmniPipelineOutput [[autodoc]] pipelines.cosmos.pipeline_cosmos3_omni.Cosmos3OmniPipelineOutput \ No newline at end of file diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index 98cf30eac6d9..02f609e2ac0c 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -59,7 +59,7 @@ python examples/cosmos3/inference_cosmos3.py \ --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.json" \ --action-chunk-size 16 \ --domain-name bridge_orig_lerobot \ - --height 480 --width 832 --fps 5 \ + --resolution-tier 480 --fps 5 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_forward_dynamics_robot ``` @@ -75,7 +75,7 @@ python examples/cosmos3/inference_cosmos3.py \ --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_action_25.json" \ --action-chunk-size 60 \ --domain-name av \ - --height 480 --width 832 --fps 10 \ + --resolution-tier 480 --fps 10 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_forward_dynamics_av ``` @@ -91,7 +91,7 @@ python examples/cosmos3/inference_cosmos3.py \ --action-chunk-size 16 \ --raw-action-dim 10 \ --domain-name 
bridge_orig_lerobot \ - --height 480 --width 832 --fps 5 \ + --resolution-tier 480 --fps 5 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_inverse_dynamics_robot ``` @@ -107,7 +107,7 @@ python examples/cosmos3/inference_cosmos3.py \ --action-chunk-size 60 \ --raw-action-dim 9 \ --domain-name av \ - --height 480 --width 832 --fps 10 \ + --resolution-tier 480 --fps 10 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_inverse_dynamics_av ``` @@ -123,7 +123,7 @@ python examples/cosmos3/inference_cosmos3.py \ --action-chunk-size 16 \ --raw-action-dim 10 \ --domain-name bridge_orig_lerobot \ - --height 480 --width 832 --fps 5 \ + --resolution-tier 480 --fps 5 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_policy_robot ``` @@ -139,21 +139,26 @@ python examples/cosmos3/inference_cosmos3.py \ --action-chunk-size 60 \ --raw-action-dim 9 \ --domain-name av \ - --height 480 --width 832 --fps 10 \ + --resolution-tier 480 --fps 10 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ --output results/cosmos3_policy_av ``` -Action modes use `action_chunk_size + 1` video frames. `forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample-*_action.json` in model-normalized action space. The upstream camera-pose forward-dynamics sample uses a still image (`mountain_720.png`), while this wrapper currently expects `--vision-path` to load as video for action modes. +Action modes use `action_chunk_size + 1` conditioning frames. `forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample_action.json` in model-normalized action space. 
This script loads `--vision-path` as a video for all action modes; `policy` and `forward_dynamics` condition only on the first frame, while `inverse_dynamics` uses the whole video. + +`--resolution-tier` is a resolution *tier* (`256`/`480`/`704`/`720`). The tier keys a table of predefined aspect-ratio canvases; the one closest to the input aspect ratio becomes the padded conditioning canvas. It is not the output frame size: the input is downscaled (never upscaled) and padded to fill the canvas, then the padding is cropped from the latents so the decoded output follows the downscaled input content. `--height` / `--width` (and `--num-frames`) are ignored for action modes. + +Pick the tier that matches the native resolution of your conditioning input (`480` for ~480p, `720` for ~720p). A tier below your input downscales it and discards detail; a tier above your input gains no resolution (content is never upscaled), wastes compute on padding, and is a train/inference distribution mismatch that can degrade quality. ### Useful flags | Flag | Default | Description | |---|---|---| | `--prompt` | (required) | Text prompt. | -| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video). | -| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). | -| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). | +| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video), or the image/video conditioning for action modes. | +| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). Ignored for action modes (derived from `--action-chunk-size`). | +| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). Ignored for action modes; use `--resolution-tier`. 
| +| `--resolution-tier` | `480` | Action resolution tier (`256`/`480`/`704`/`720`): selects the aspect bin / padded conditioning canvas, not the output size. | | `--fps` | `24.0` | Frame rate of the generated video. | | `--enable-sound` | off | Generate a synchronized audio track. | | `--action-mode` | `None` | Enable action conditioning/generation. One of `forward_dynamics`, `inverse_dynamics`, or `policy`. | diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index 18297927e09d..d354de85165f 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -30,7 +30,7 @@ import torch from huggingface_hub import snapshot_download -from diffusers import Cosmos3OmniPipeline +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition from diffusers.utils import encode_video, export_to_video, load_image, load_video @@ -71,8 +71,18 @@ def main(): help="Optional URL or local path for an image-conditioning frame, or an action conditioning video.", ) parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") - parser.add_argument("--height", type=int, default=720) - parser.add_argument("--width", type=int, default=1280) + parser.add_argument( + "--height", + type=int, + default=None, + help="Output height in pixels (default 720). Ignored for action modes; use --resolution-tier instead.", + ) + parser.add_argument( + "--width", + type=int, + default=None, + help="Output width in pixels (default 1280). 
Ignored for action modes; use --resolution-tier instead.", + ) parser.add_argument( "--num-frames", type=int, @@ -99,6 +109,16 @@ def main(): parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.") parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.") parser.add_argument("--raw-action-dim", type=int, default=None, help="Slice predicted action output to this size.") + parser.add_argument( + "--resolution-tier", + type=int, + default=480, + choices=[256, 480, 704, 720], + help=( + "Action resolution tier (256/480/704/720). Selects the aspect bin / padded conditioning canvas, " + "not the output frame size." + ), + ) parser.add_argument( "--no-duration-template", dest="add_duration_template", @@ -145,24 +165,24 @@ def main(): if args.action_mode is not None: if args.vision_path is None: - raise ValueError("--vision-path must point to a video for action modes.") + raise ValueError("--vision-path must point to a conditioning video for action modes.") if args.action_chunk_size is None: raise ValueError("--action-chunk-size is required for action modes.") video = load_video(args.vision_path) - action = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None + raw_actions = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None result = pipeline( prompt=args.prompt, - video=video, - num_frames=args.action_chunk_size + 1, - height=args.height, - width=args.width, + action=CosmosActionCondition( + mode=args.action_mode, + chunk_size=args.action_chunk_size, + domain_name=args.domain_name, + raw_action_dim=args.raw_action_dim, + resolution_tier=args.resolution_tier, + raw_actions=raw_actions, + video=video, + ), fps=args.fps, num_inference_steps=args.num_inference_steps, - action_mode=args.action_mode, - raw_actions=action, - action_chunk_size=args.action_chunk_size, - domain_name=args.domain_name, - 
raw_action_dim=args.raw_action_dim, guidance_scale=args.guidance_scale, generator=generator, use_system_prompt=False, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d75e2d9a5010..59fab3e689b3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -546,6 +546,7 @@ "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "Cosmos3OmniPipeline", + "CosmosActionCondition", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", @@ -1357,6 +1358,7 @@ Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, Cosmos3OmniPipeline, + CosmosActionCondition, CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, CycleDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 720548e38fd4..1bddb095ad3a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -214,6 +214,7 @@ "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "Cosmos3OmniPipeline", + "CosmosActionCondition", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", ] @@ -652,6 +653,7 @@ Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, Cosmos3OmniPipeline, + CosmosActionCondition, CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, ) diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 0f828933be09..54d841f5b998 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -34,6 +34,7 @@ _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"] _import_structure["pipeline_cosmos3_omni"] = [ "Cosmos3OmniPipeline", + "CosmosActionCondition", ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -52,6 +53,7 @@ from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos3_omni import ( Cosmos3OmniPipeline, + CosmosActionCondition, ) from .pipeline_cosmos_text2world import 
CosmosTextToWorldPipeline from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index e47062e4a0f4..64334b5e0381 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -15,7 +15,7 @@ import copy import math from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, Literal import numpy as np import torch @@ -210,6 +210,124 @@ class Cosmos3OmniPipelineOutput(BaseOutput): action: list[torch.Tensor] | None = None +@dataclass +class CosmosActionCondition: + """Groups every input required for a Cosmos 3 action-conditioned generation task. + + Pass this to [`Cosmos3OmniPipeline.__call__`] via the `action` argument instead of the top-level `image` / `video` + / `height` / `width` arguments, which are reserved for text-to-video and image-to-video runs. + + Attributes: + mode (`str`): + The action task. One of `"forward_dynamics"` (roll out future video from a first frame and a given + `raw_actions` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), or + `"policy"` (jointly roll out future video and actions from the first frame). + chunk_size (`int`): + Number of action transition steps in the chunk. The paired conditioning video spans `chunk_size + 1` + frames (see [`~CosmosActionCondition.num_frames`]). + domain_name (`str`): + Embodiment domain selecting the domain-aware action projection weights. Must be one of the registered + Cosmos 3 embodiment domains. + resolution_tier (`int`, defaults to `480`): + Action conditioning resolution *tier* (one of `256`, `480`, `704`, `720`). The tier picks a predefined + canvas whose aspect ratio is closest to the input; the input is downscaled (never upscaled) and padded + into it for conditioning.
This is not the output frame size, which tracks the input content. Match the + tier to the input's native resolution: a lower tier discards detail, while a higher tier adds no + resolution (no upscaling), wastes compute on padding, and is a train/inference mismatch that can hurt + quality. + raw_action_dim (`int`, *optional*): + Number of meaningful (unpadded) action channels to keep when slicing predicted actions. Required for + `"policy"` and `"inverse_dynamics"`. + raw_actions (`torch.Tensor`, *optional*): + Raw domain action vectors of shape `[T, raw_action_dim]` driving `"forward_dynamics"`. Sequences shorter + than `chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's + `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + image (`PIL.Image.Image`, `np.ndarray`, or `torch.Tensor`, *optional*): + Conditioning frame for `"policy"` / `"forward_dynamics"`. Mutually exclusive with `video`. + video (`list`, `np.ndarray`, or `torch.Tensor`, *optional*): + Conditioning video, required for `"inverse_dynamics"`. For `"policy"` / `"forward_dynamics"` only its + first frame is used. Mutually exclusive with `image`. 
+ """ + + mode: Literal["policy", "forward_dynamics", "inverse_dynamics"] + chunk_size: int + domain_name: str + resolution_tier: int = 480 + raw_action_dim: int | None = None + raw_actions: torch.Tensor | None = None + image: Image.Image | np.ndarray | torch.Tensor | None = None + video: list | np.ndarray | torch.Tensor | None = None + + @property + def num_frames(self) -> int: + """Number of conditioning frames the paired video must provide (`chunk_size + 1`).""" + return self.chunk_size + 1 + + @property + def conditioning_clip(self) -> Any: + """Conditioning clip as a single source: `image` wrapped as a one-frame clip, else `video`.""" + return [self.image] if self.image is not None else self.video + + def target_size(self, source_height: int, source_width: int) -> tuple[int, int]: + """Padded conditioning canvas `(height, width)` for this `resolution_tier` and the source frame size. + + The tier selects a set of predefined aspect-ratio canvases; the one closest to the source aspect ratio is + returned. The source content is later downscaled (never upscaled) and padded into this canvas. + """ + resolution_key = str(self.resolution_tier) + if resolution_key not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Unsupported action resolution_tier={self.resolution_tier!r}; " + f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." + ) + return VideoProcessor.classify_height_width_bin( + source_height, source_width, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) + + def validate(self, *, action_dim: int | None = None) -> None: + """Validate every action-specific field. Called by [`Cosmos3OmniPipeline.check_inputs`]. + + Args: + action_dim (`int`, *optional*): + The model's action head width. When provided, `raw_actions` channels are checked against it. 
+ """ + if self.mode not in _ACTION_MODES: + raise ValueError(f"Unsupported action mode={self.mode!r}; expected one of {sorted(_ACTION_MODES)}.") + if self.chunk_size is None or self.chunk_size < 1: + raise ValueError(f"action `chunk_size` must be >= 1, got {self.chunk_size}.") + if self.domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={self.domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + if str(self.resolution_tier) not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Unsupported action resolution_tier={self.resolution_tier!r}; " + f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." + ) + if self.image is not None and self.video is not None: + raise ValueError("Provide either `image` or `video` for the action condition, not both.") + if self.mode == _ACTION_MODE_INVERSE_DYNAMICS: + if self.video is None: + raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") + else: + if self.image is None and self.video is None: + raise ValueError(f"action mode={self.mode!r} requires `image` or `video` conditioning.") + if self.mode in {_ACTION_MODE_POLICY, _ACTION_MODE_INVERSE_DYNAMICS} and self.raw_action_dim is None: + raise ValueError(f"action mode={self.mode!r} requires `raw_action_dim` for output slicing.") + if self.mode == _ACTION_MODE_FORWARD_DYNAMICS: + if self.raw_actions is None: + raise ValueError("action mode='forward_dynamics' requires `raw_actions`.") + if self.raw_actions.ndim != 2: + raise ValueError(f"`raw_actions` must have shape [T, D], got {tuple(self.raw_actions.shape)}.") + if self.raw_actions.shape[0] < 1: + raise ValueError("action mode='forward_dynamics' requires at least one action token.") + if action_dim is not None and self.raw_actions.shape[-1] > action_dim: + raise ValueError( + f"Cosmos3 action dimension {self.raw_actions.shape[-1]} exceeds model action_dim={action_dim}." 
+ ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" @@ -504,37 +622,16 @@ def _prepare_action_segment( "num_noisy_action_tokens": len(noisy_frame_indexes), } - def _get_action_target_size( - self, - source_height: int, - source_width: int, - requested_height: int, - requested_width: int, - ) -> tuple[int, int]: - resolution_key = str(min(requested_height, requested_width)) - if resolution_key not in _ACTION_RESOLUTION_BINS: - raise ValueError( - f"Cosmos3 action resolution binning only supports {sorted(_ACTION_RESOLUTION_BINS)}, " - f"got height={requested_height}, width={requested_width}." - ) - return self.video_processor.classify_height_width_bin( - source_height, - source_width, - ratios=_ACTION_RESOLUTION_BINS[resolution_key], - ) - def _prepare_action_video_conditioning( self, - video: Any, - height: int, - width: int, + action: "CosmosActionCondition", num_frames: int, device: torch.device | str, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor, int, int]: - frames = self.video_processor.preprocess_video(video).to(device=device, dtype=dtype) + frames = self.video_processor.preprocess_video(action.conditioning_clip).to(device=device, dtype=dtype) source_h, source_w = frames.shape[-2:] - target_h, target_w = self._get_action_target_size(source_h, source_w, height, width) + target_h, target_w = action.target_size(source_h, source_w) if frames.shape[2] < num_frames: frames = torch.cat([frames, frames[:, :, -1:].expand(-1, -1, num_frames - frames.shape[2], -1, -1)], dim=2) @@ -590,11 +687,7 @@ def prepare_latents( device: str = "cuda", dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, - action_mode: str | None = None, - raw_actions: torch.Tensor | None = None, - action_chunk_size: int | None = None, - domain_name: str | None = None, - raw_action_dim: int | None = 
None, + action: "CosmosActionCondition | None" = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -613,6 +706,7 @@ def prepare_latents( Returns: Initial noisy tensors plus condition masks/metadata for vision, sound, and optional action modalities. """ + action_mode = action.mode if action is not None else None is_image = num_frames == 1 has_image_condition = (image is not None and not is_image) or action_mode is not None @@ -625,22 +719,23 @@ def prepare_latents( action_domain_id: torch.Tensor | None = None action_condition_mask: torch.Tensor | None = None - raw_action_dim_resolved: int | None = int(raw_action_dim) if raw_action_dim is not None else None + raw_action_dim_resolved: int | None = ( + int(action.raw_action_dim) if action is not None and action.raw_action_dim is not None else None + ) action_condition_frames: list[int] = [] action_condition_frame_indexes: list[int] = [] action_image_size: torch.Tensor | None = None vision_condition_frames: list[int] | None = None # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). 
- if action_mode is not None: - assert action_chunk_size is not None - target_frames = action_chunk_size + 1 + if action is not None: + target_frames = action.chunk_size + 1 vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( - video, height, width, target_frames, device=device, dtype=dtype + action, target_frames, device=device, dtype=dtype ) if action_mode == _ACTION_MODE_FORWARD_DYNAMICS: vision_condition_frames = [0] - action_condition_frames = list(range(action_chunk_size)) + action_condition_frames = list(range(action.chunk_size)) elif action_mode == _ACTION_MODE_POLICY: vision_condition_frames = [0] elif action_mode == _ACTION_MODE_INVERSE_DYNAMICS: @@ -679,10 +774,11 @@ def prepare_latents( x0_tokens_sound = torch.zeros(sound_dim, T_sound, device=device, dtype=dtype) x0_tokens_action: torch.Tensor | None = None - if action_mode is not None: - assert action_chunk_size is not None + if action is not None: + action_chunk_size = action.chunk_size action_dim = self.transformer.action_dim if action_mode == "forward_dynamics": + raw_actions = action.raw_actions if raw_actions is None: raise ValueError("action_mode='forward_dynamics' requires an action tensor.") raw_actions = raw_actions.to(device=device, dtype=dtype) @@ -708,13 +804,13 @@ def prepare_latents( x0_tokens_action = raw_actions else: x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) - if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + if action.domain_name not in _EMBODIMENT_TO_DOMAIN_ID: raise ValueError( - f"Unknown Cosmos3 action domain_name={domain_name!r}; " + f"Unknown Cosmos3 action domain_name={action.domain_name!r}; " f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." 
) action_domain_id = torch.tensor( - [_EMBODIMENT_TO_DOMAIN_ID[domain_name]], + [_EMBODIMENT_TO_DOMAIN_ID[action.domain_name]], dtype=torch.long, device=device, ) @@ -795,11 +891,7 @@ def check_inputs( guidance_scale: float, enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], - action_mode: str | None, - raw_actions: torch.Tensor | None, - action_chunk_size: int | None, - domain_name: str | None, - raw_action_dim: int | None, + action: "CosmosActionCondition | None" = None, ) -> None: if not isinstance(prompt, (str, list)) or ( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) @@ -811,9 +903,6 @@ def check_inputs( ) if num_frames < 1: raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") - sf = int(self.vae.config.scale_factor_spatial) - if height % sf != 0 or width % sf != 0: - raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") if enable_sound: if self.sound_tokenizer is None: raise ValueError("`enable_sound=True` requires a sound-capable checkpoint with a `sound_tokenizer`.") @@ -824,41 +913,21 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if action_mode is not None: - if action_mode not in _ACTION_MODES: - raise ValueError(f"Unsupported action_mode={action_mode!r}; expected one of {sorted(_ACTION_MODES)}.") - if not getattr(self.transformer.config, "action_gen", False): - raise ValueError("action_mode requires a transformer trained with action_gen=True.") - if image is not None: - raise ValueError("Use `video`, not `image`, for Cosmos3 action conditioning.") - if video is None: - raise ValueError(f"action_mode={action_mode!r} requires a loaded conditioning video.") - if action_chunk_size is None: - raise ValueError("action_mode requires action_chunk_size.") - if num_frames != action_chunk_size + 1: - 
raise ValueError( - "Action runs require num_frames to equal action_chunk_size + 1; " - f"got num_frames={num_frames}, action_chunk_size={action_chunk_size}." - ) - if domain_name is None: - raise ValueError("action_mode requires domain_name.") - if domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + + if action is not None: + # API-conflict checks live here; all action-field validation is delegated to action.validate(). + if image is not None or video is not None: raise ValueError( - f"Unknown Cosmos3 action domain_name={domain_name!r}; " - f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + "Pass action conditioning via `action.image` / `action.video`, not the top-level " + "`image` / `video` arguments." ) - if action_mode == "forward_dynamics": - if raw_actions is None: - raise ValueError("action_mode='forward_dynamics' requires an action tensor.") - if raw_actions.shape[0] == 0: - raise ValueError("action_mode='forward_dynamics' requires at least one action token.") - action_dim = self.transformer.action_dim - if raw_actions.shape[-1] > action_dim: - raise ValueError( - f"Cosmos3 action dimension {raw_actions.shape[-1]} exceeds model action_dim={action_dim}." 
- ) - if action_mode in {"inverse_dynamics", "policy"} and raw_action_dim is None: - raise ValueError(f"action_mode={action_mode!r} requires raw_action_dim for output slicing.") + if not getattr(self.transformer.config, "action_gen", False): + raise ValueError("`action` requires a transformer trained with action_gen=True.") + action.validate(action_dim=self.transformer.action_dim) + else: + sf = int(self.vae.config.scale_factor_spatial) + if height % sf != 0 or width % sf != 0: + raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") def tokenize_prompt( self, @@ -1038,11 +1107,7 @@ def __call__( latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, action_latents: torch.Tensor | None = None, - action_mode: str | None = None, - raw_actions: torch.Tensor | None = None, - action_chunk_size: int | None = None, - domain_name: str | None = None, - raw_action_dim: int | None = None, + action: CosmosActionCondition | None = None, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -1067,13 +1132,20 @@ def __call__( The negative prompt used for classifier-free guidance. When `None`, the empty string is used. image (`torch.Tensor` or `PIL.Image.Image`, *optional*): Optional conditioning frame for image-to-video. The pipeline anchors frame 0 to this image and denoises - the remaining frames. Ignored when `num_frames == 1`. + the remaining frames. Ignored when `num_frames == 1`. Not used for action runs (pass `action` instead). + video (`list`, `np.ndarray`, or `torch.Tensor`, *optional*): + Reserved for video-to-video conditioning. Video-to-video is not yet supported, so this argument is + currently accepted but unused. Action conditioning video is provided through `action` (see + [`CosmosActionCondition`]), not this argument. num_frames (`int`, *optional*, defaults to `189`): Number of frames to generate. Use `1` for text-to-image; the default produces ≈ 7.9 s at 24 FPS. 
+ Ignored for action runs, where it is derived from `action.num_frames`. height (`int`, *optional*, defaults to `720`): - Output height in pixels. + Output height in pixels. Ignored for action runs, which size via + `action.resolution_tier`. width (`int`, *optional*, defaults to `1280`): - Output width in pixels. + Output width in pixels. Ignored for action runs, which size via + `action.resolution_tier`. fps (`float`, *optional*, defaults to `24.0`): Target frame rate, also injected into the mRoPE temporal modulation and into the duration metadata template. @@ -1094,27 +1166,13 @@ def __call__( `None`, fresh Gaussian noise is sampled. action_latents (`torch.Tensor`, *optional*): Pre-generated action latents to start the action stream's denoising from. Only consulted when an action - run is configured via `action_mode`; when `None`, fresh Gaussian noise is sampled for the action tokens. - action_mode (`str`, *optional*): - Selects the action-conditioned generation task and requires a transformer trained with - `action_gen=True`. One of `"forward_dynamics"` (predict the future video from an initial frame and a - given `raw_actions` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning - frames), or `"policy"` (jointly roll out future video and actions from the first frame). When set, - conditioning must be supplied via `video` (not `image`) and `num_frames` is forced to - `action_chunk_size + 1`. - raw_actions (`torch.Tensor`, *optional*): - Raw domain action vectors of shape `[T, raw_action_dim]` driving `action_mode="forward_dynamics"`. - Sequences shorter than `action_chunk_size` repeat the last action; longer ones are truncated. Channels - beyond the model's `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. - action_chunk_size (`int`, *optional*): - Number of action transition steps in the chunk. 
Required for every `action_mode`; the paired video has - `action_chunk_size + 1` frames and `num_frames` is overwritten accordingly. - domain_name (`str`, *optional*): - Embodiment domain that selects the domain-aware action projection weights. Required for action runs and - must be one of the registered Cosmos 3 embodiment domains. - raw_action_dim (`int`, *optional*): - Number of meaningful (unpadded) action channels to keep when slicing predicted actions. Required for - `action_mode="inverse_dynamics"` and `action_mode="policy"`. + run is configured via `action`; when `None`, fresh Gaussian noise is sampled for the action tokens. + action (`CosmosActionCondition`, *optional*): + Bundles every input for an action-conditioned run (mode, chunk size, embodiment domain, resolution + tier, raw actions, and the conditioning image/video), and requires a transformer trained with + `action_gen=True`. When set, passing the top-level `image` / `video` arguments raises; `height` / + `width` / `num_frames` are ignored (a warning is logged) since resolution comes from + `action.resolution_tier` and the frame count from `action.chunk_size`. See [`CosmosActionCondition`]. output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). @@ -1150,14 +1208,24 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - if action_mode is not None and action_chunk_size is not None: - target_num_frames = action_chunk_size + 1 - if num_frames != target_num_frames: + # Action runs size via `action.resolution_tier` and derive their frame count from `action.chunk_size`, so + # `height` / `width` / `num_frames` are ignored. 
Warn (rather than fail) when the caller passed non-default + # values so a shared config can still be reused across action and non-action runs. + if action is not None: + ignored = [ + name + for name, value, default in ( + ("num_frames", num_frames, 189), + ("height", height, 720), + ("width", width, 1280), + ) + if value != default + ] + if ignored: logger.warning( - f"`num_frames={num_frames}` is ignored for action runs and overwritten to " - f"`action_chunk_size + 1 = {target_num_frames}`." + "Action runs derive resolution from `action.resolution_tier` and frame count from " + f"`action.chunk_size`; ignoring {', '.join(ignored)}." ) - num_frames = target_num_frames # 1. Check inputs self.check_inputs( @@ -1171,13 +1239,20 @@ def __call__( guidance_scale, enable_sound, callback_on_step_end_tensor_inputs, - action_mode, - raw_actions, - action_chunk_size, - domain_name, - raw_action_dim, + action, ) + # `action_mode` is the only action field consumed directly in __call__ (prompt template + output slicing); + # all other action fields are read from `action` at their point of use (e.g. in prepare_latents). + action_mode = action.mode if action is not None else None + + if action is not None: + num_frames = action.num_frames + # Resolve the padded conditioning canvas from the tier + input aspect *before* tokenization, so the + # resolution prompt template matches the canvas the model is actually conditioned on. 
+ probe = self.video_processor.preprocess_video(action.conditioning_clip) + height, width = action.target_size(int(probe.shape[-2]), int(probe.shape[-1])) + self._current_timestep = None self._interrupt = False self._guidance_scale = guidance_scale @@ -1249,11 +1324,7 @@ def __call__( device=device, dtype=dtype, enable_sound=enable_sound, - action_mode=action_mode, - raw_actions=raw_actions, - action_chunk_size=action_chunk_size, - domain_name=domain_name, - raw_action_dim=raw_action_dim, + action=action, ) vision_condition_indexes_for_pack = torch.nonzero(vision_condition_mask[:, 0, 0] > 0, as_tuple=False).flatten() vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 934dcc5ebf2d..569fa932d88c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1337,6 +1337,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CosmosActionCondition(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CosmosTextToWorldPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From b0aa0269b55355e9e3253a78db7fdceea97cdfb9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 07:55:45 +0000 Subject: [PATCH 17/27] Fix examples script to support flow_shift arg --- examples/cosmos3/inference_cosmos3.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git 
a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index d354de85165f..ad0e5affa3ae 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -30,7 +30,7 @@ import torch from huggingface_hub import snapshot_download -from diffusers import Cosmos3OmniPipeline, CosmosActionCondition +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition, UniPCMultistepScheduler from diffusers.utils import encode_video, export_to_video, load_image, load_video @@ -92,6 +92,12 @@ def main(): parser.add_argument("--fps", type=float, default=24.0) parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") + parser.add_argument( + "--flow-shift", + type=float, + default=None, + help="Override the scheduler's flow-matching shift (UniPCMultistepScheduler.flow_shift).", + ) parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", @@ -159,6 +165,10 @@ def main(): ) print("Pipeline loaded successfully.") + if args.flow_shift is not None: + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=args.flow_shift) + print(f"Scheduler flow_shift set to {args.flow_shift}.") + output_dir = pathlib.Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None From 9aadc2649cffb853e9d6db8472867854fddf586d Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 08:53:25 +0000 Subject: [PATCH 18/27] Apply styling fixes --- src/diffusers/models/transformers/transformer_cosmos3.py | 6 ++---- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git 
a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index 29cfc127d253..9c1d5e1c17c7 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -156,8 +156,6 @@ def __init__(self, input_size: int, output_size: int, num_domains: int) -> None: self.num_domains = num_domains self.fc = nn.Embedding(self.num_domains, self.output_size * self.input_size) self.bias = nn.Embedding(self.num_domains, self.output_size) - nn.init.xavier_uniform_(self.fc.weight) - nn.init.zeros_(self.bias.weight) def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor: if domain_id.ndim == 0: @@ -324,7 +322,7 @@ def __init__( rms_norm_eps: float = 1e-6, rope_scaling: dict | None = None, rope_theta: float = 5000000.0, - action_dim: int | None = None, + action_dim: int = 32, action_gen: bool = False, num_embodiment_domains: int = 32, sound_dim: int | None = None, @@ -370,7 +368,7 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=hidden_size) self.action_gen = action_gen - self.action_dim = 32 if action_dim is None else action_dim + self.action_dim = action_dim self.num_embodiment_domains = num_embodiment_domains if action_gen: self.action_proj_in = DomainAwareLinear(self.action_dim, hidden_size, self.num_embodiment_domains) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 64334b5e0381..45c8ca99ebcf 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -137,7 +137,7 @@ def get_3d_mrope_ids_vae_tokens( _ACTION_MODE_FORWARD_DYNAMICS = "forward_dynamics" _ACTION_MODE_INVERSE_DYNAMICS = "inverse_dynamics" _ACTION_MODE_POLICY = "policy" -_ACTION_MODES = 
{_ACTION_MODE_FORWARD_DYNAMICS, _ACTION_MODE_INVERSE_DYNAMICS, _ACTION_MODE_POLICY} +_ACTION_MODES = {"forward_dynamics", "inverse_dynamics", "policy"} _ACTION_RESOLUTION_BINS = { "256": { @@ -291,7 +291,7 @@ def validate(self, *, action_dim: int | None = None) -> None: action_dim (`int`, *optional*): The model's action head width. When provided, `raw_actions` channels are checked against it. """ - if self.mode not in _ACTION_MODES: + if self.mode not in ["policy", "forward_dynamics", "inverse_dynamics"]: raise ValueError(f"Unsupported action mode={self.mode!r}; expected one of {sorted(_ACTION_MODES)}.") if self.chunk_size is None or self.chunk_size < 1: raise ValueError(f"action `chunk_size` must be >= 1, got {self.chunk_size}.") @@ -307,7 +307,7 @@ def validate(self, *, action_dim: int | None = None) -> None: ) if self.image is not None and self.video is not None: raise ValueError("Provide either `image` or `video` for the action condition, not both.") - if self.mode == _ACTION_MODE_INVERSE_DYNAMICS: + if self.mode == "inverse_dynamics": if self.video is None: raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") else: @@ -315,7 +315,7 @@ def validate(self, *, action_dim: int | None = None) -> None: raise ValueError(f"action mode={self.mode!r} requires `image` or `video` conditioning.") if self.mode in {_ACTION_MODE_POLICY, _ACTION_MODE_INVERSE_DYNAMICS} and self.raw_action_dim is None: raise ValueError(f"action mode={self.mode!r} requires `raw_action_dim` for output slicing.") - if self.mode == _ACTION_MODE_FORWARD_DYNAMICS: + if self.mode == "forward_dynamics": if self.raw_actions is None: raise ValueError("action mode='forward_dynamics' requires `raw_actions`.") if self.raw_actions.ndim != 2: From 76acf3920c9c888731a85b4810a22e36d2796922 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:07:22 +0000 Subject: [PATCH 19/27] Remove CosmosActionCondition properties, move to pipeline --- 
.../pipelines/cosmos/pipeline_cosmos3_omni.py | 73 ++++++++----------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 45c8ca99ebcf..a882b84a643c 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -224,7 +224,7 @@ class CosmosActionCondition: `"policy"` (jointly roll out future video and actions from the first frame). chunk_size (`int`): Number of action transition steps in the chunk. The paired conditioning video spans `chunk_size + 1` - frames (see [`~CosmosActionCondition.num_frames`]). + frames. domain_name (`str`): Embodiment domain selecting the domain-aware action projection weights. Must be one of the registered Cosmos 3 embodiment domains. @@ -258,32 +258,6 @@ class CosmosActionCondition: image: Image.Image | np.ndarray | torch.Tensor | None = None video: list | np.ndarray | torch.Tensor | None = None - @property - def num_frames(self) -> int: - """Number of conditioning frames the paired video must provide (`chunk_size + 1`).""" - return self.chunk_size + 1 - - @property - def conditioning_clip(self) -> Any: - """Conditioning clip as a single source: `image` wrapped as a one-frame clip, else `video`.""" - return [self.image] if self.image is not None else self.video - - def target_size(self, source_height: int, source_width: int) -> tuple[int, int]: - """Padded conditioning canvas `(height, width)` for this `resolution_tier` and the source frame size. - - The tier selects a set of predefined aspect-ratio canvases; the one closest to the source aspect ratio is - returned. The source content is later downscaled (never upscaled) and padded into this canvas. 
- """ - resolution_key = str(self.resolution_tier) - if resolution_key not in _ACTION_RESOLUTION_BINS: - raise ValueError( - f"Unsupported action resolution_tier={self.resolution_tier!r}; " - f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." - ) - return VideoProcessor.classify_height_width_bin( - source_height, source_width, ratios=_ACTION_RESOLUTION_BINS[resolution_key] - ) - def validate(self, *, action_dim: int | None = None) -> None: """Validate every action-specific field. Called by [`Cosmos3OmniPipeline.check_inputs`]. @@ -293,7 +267,7 @@ def validate(self, *, action_dim: int | None = None) -> None: """ if self.mode not in ["policy", "forward_dynamics", "inverse_dynamics"]: raise ValueError(f"Unsupported action mode={self.mode!r}; expected one of {sorted(_ACTION_MODES)}.") - if self.chunk_size is None or self.chunk_size < 1: + if self.chunk_size < 1: raise ValueError(f"action `chunk_size` must be >= 1, got {self.chunk_size}.") if self.domain_name not in _EMBODIMENT_TO_DOMAIN_ID: raise ValueError( @@ -307,13 +281,11 @@ def validate(self, *, action_dim: int | None = None) -> None: ) if self.image is not None and self.video is not None: raise ValueError("Provide either `image` or `video` for the action condition, not both.") - if self.mode == "inverse_dynamics": - if self.video is None: - raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") - else: - if self.image is None and self.video is None: - raise ValueError(f"action mode={self.mode!r} requires `image` or `video` conditioning.") - if self.mode in {_ACTION_MODE_POLICY, _ACTION_MODE_INVERSE_DYNAMICS} and self.raw_action_dim is None: + elif self.image is None and self.video is None: + raise ValueError("`image` and `video` cannot both be None") + if self.mode == "inverse_dynamics" and self.video is None: + raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") + if self.mode in {"policy", "inverse_dynamics"} and self.raw_action_dim is 
None: raise ValueError(f"action mode={self.mode!r} requires `raw_action_dim` for output slicing.") if self.mode == "forward_dynamics": if self.raw_actions is None: @@ -624,14 +596,23 @@ def _prepare_action_segment( def _prepare_action_video_conditioning( self, - action: "CosmosActionCondition", + conditioning_clip: Any, + resolution_tier: int, num_frames: int, device: torch.device | str, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor, int, int]: - frames = self.video_processor.preprocess_video(action.conditioning_clip).to(device=device, dtype=dtype) + frames = self.video_processor.preprocess_video(conditioning_clip).to(device=device, dtype=dtype) source_h, source_w = frames.shape[-2:] - target_h, target_w = action.target_size(source_h, source_w) + resolution_key = str(resolution_tier) + if resolution_key not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Unsupported action resolution_tier={resolution_tier!r}; " + f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." + ) + target_h, target_w = VideoProcessor.classify_height_width_bin( + source_h, source_w, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) if frames.shape[2] < num_frames: frames = torch.cat([frames, frames[:, :, -1:].expand(-1, -1, num_frames - frames.shape[2], -1, -1)], dim=2) @@ -730,8 +711,9 @@ def prepare_latents( # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). if action is not None: target_frames = action.chunk_size + 1 + conditioning_clip = [action.image] if action.image is not None else action.video vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( - action, target_frames, device=device, dtype=dtype + conditioning_clip, action.resolution_tier, target_frames, device=device, dtype=dtype ) if action_mode == _ACTION_MODE_FORWARD_DYNAMICS: vision_condition_frames = [0] @@ -1139,7 +1121,7 @@ def __call__( [`CosmosActionCondition`]), not this argument. 
num_frames (`int`, *optional*, defaults to `189`): Number of frames to generate. Use `1` for text-to-image; the default produces ≈ 7.9 s at 24 FPS. - Ignored for action runs, where it is derived from `action.num_frames`. + Ignored for action runs, where it is derived from `action.chunk_size + 1`. height (`int`, *optional*, defaults to `720`): Output height in pixels. Ignored for action runs, which size via `action.resolution_tier`. @@ -1247,11 +1229,16 @@ def __call__( action_mode = action.mode if action is not None else None if action is not None: - num_frames = action.num_frames + num_frames = action.chunk_size + 1 # Resolve the padded conditioning canvas from the tier + input aspect *before* tokenization, so the # resolution prompt template matches the canvas the model is actually conditioned on. - probe = self.video_processor.preprocess_video(action.conditioning_clip) - height, width = action.target_size(int(probe.shape[-2]), int(probe.shape[-1])) + conditioning_clip = [action.image] if action.image is not None else action.video + probe = self.video_processor.preprocess_video(conditioning_clip) + source_h, source_w = int(probe.shape[-2]), int(probe.shape[-1]) + resolution_key = str(action.resolution_tier) + height, width = VideoProcessor.classify_height_width_bin( + source_h, source_w, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) self._current_timestep = None self._interrupt = False From df5b239e847d21ebb5ed139adb6a5a7139f91db9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:18:38 +0000 Subject: [PATCH 20/27] Replace validate wiht post init --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index a882b84a643c..e881fe8d1a1b 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ 
b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -258,13 +258,8 @@ class CosmosActionCondition: image: Image.Image | np.ndarray | torch.Tensor | None = None video: list | np.ndarray | torch.Tensor | None = None - def validate(self, *, action_dim: int | None = None) -> None: - """Validate every action-specific field. Called by [`Cosmos3OmniPipeline.check_inputs`]. - - Args: - action_dim (`int`, *optional*): - The model's action head width. When provided, `raw_actions` channels are checked against it. - """ + def __post_init__(self) -> None: + """Validate self-contained action fields at construction time.""" if self.mode not in ["policy", "forward_dynamics", "inverse_dynamics"]: raise ValueError(f"Unsupported action mode={self.mode!r}; expected one of {sorted(_ACTION_MODES)}.") if self.chunk_size < 1: @@ -294,10 +289,6 @@ def validate(self, *, action_dim: int | None = None) -> None: raise ValueError(f"`raw_actions` must have shape [T, D], got {tuple(self.raw_actions.shape)}.") if self.raw_actions.shape[0] < 1: raise ValueError("action mode='forward_dynamics' requires at least one action token.") - if action_dim is not None and self.raw_actions.shape[-1] > action_dim: - raise ValueError( - f"Cosmos3 action dimension {self.raw_actions.shape[-1]} exceeds model action_dim={action_dim}." - ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -897,7 +888,7 @@ def check_inputs( ) if action is not None: - # API-conflict checks live here; all action-field validation is delegated to action.validate(). + # API-conflict + model-dependent checks live here. 
if image is not None or video is not None: raise ValueError( "Pass action conditioning via `action.image` / `action.video`, not the top-level " @@ -905,7 +896,12 @@ def check_inputs( ) if not getattr(self.transformer.config, "action_gen", False): raise ValueError("`action` requires a transformer trained with action_gen=True.") - action.validate(action_dim=self.transformer.action_dim) + if action.mode == "forward_dynamics" and action.raw_actions is not None: + if action.raw_actions.shape[-1] > self.transformer.config.action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.raw_actions.shape[-1]} exceeds model action_dim=" + f"{self.transformer.config.action_dim}." + ) else: sf = int(self.vae.config.scale_factor_spatial) if height % sf != 0 or width % sf != 0: From ad562935fb0ad7049b4ad51eaafdfd77adacbbf1 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:41:04 +0000 Subject: [PATCH 21/27] Set height/width/num_frames to None, raise error if set for action --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 72 +++++++++---------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index e881fe8d1a1b..2e8d07edf19f 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -648,9 +648,9 @@ def prepare_latents( self, image: torch.Tensor | None = None, video: Any | None = None, - num_frames: int = 189, - height: int = 720, - width: int = 1280, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, fps: float = 24.0, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, @@ -858,9 +858,9 @@ def check_inputs( negative_prompt, image, video, - height: int, - width: int, - num_frames: int, + height: int | None, + width: int | None, + num_frames: int | None, guidance_scale: float,
enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], @@ -874,8 +874,6 @@ def check_inputs( raise ValueError( f"`negative_prompt` must be a str, list of str, or None, got {type(negative_prompt).__name__}." ) - if num_frames < 1: - raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") if enable_sound: if self.sound_tokenizer is None: raise ValueError("`enable_sound=True` requires a sound-capable checkpoint with a `sound_tokenizer`.") @@ -889,6 +887,10 @@ def check_inputs( if action is not None: # API-conflict + model-dependent checks live here. + if num_frames is not None: + raise ValueError("`num_frames` has to be None if action is not None") + if height is not None or width is not None: + raise ValueError("`height` and `width` have to be None if action is not None") if image is not None or video is not None: raise ValueError( "Pass action conditioning via `action.image` / `action.video`, not the top-level " @@ -903,6 +905,12 @@ def check_inputs( f"{self.transformer.config.action_dim}." ) else: + if num_frames is None: + raise ValueError("`num_frames` must be provided when `action` is None.") + if height is None or width is None: + raise ValueError("`height` and `width` must be provided when `action` is None.") + if num_frames < 1: + raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") sf = int(self.vae.config.scale_factor_spatial) if height % sf != 0 or width % sf != 0: raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") @@ -1115,15 +1123,16 @@ def __call__( Reserved for video-to-video conditioning. Video-to-video is not yet supported, so this argument is currently accepted but unused. Action conditioning video is provided through `action` (see [`CosmosActionCondition`]), not this argument. - num_frames (`int`, *optional*, defaults to `189`): - Number of frames to generate. Use `1` for text-to-image; the default produces ≈ 7.9 s at 24 FPS. 
- Ignored for action runs, where it is derived from `action.chunk_size + 1`. - height (`int`, *optional*, defaults to `720`): - Output height in pixels. Ignored for action runs, which size via - `action.resolution_tier`. - width (`int`, *optional*, defaults to `1280`): - Output width in pixels. Ignored for action runs, which size via - `action.resolution_tier`. + num_frames (`int`, *optional*, defaults to `None`): + Number of frames to generate. Use `1` for text-to-image. Defaults to `189` (≈ 7.9 s at 24 FPS) for + non-action modes when omitted (`None`). Must be `None` for action runs, where frame count is derived + from `action.chunk_size + 1`. + height (`int`, *optional*, defaults to `None`): + Output height in pixels. Defaults to `720` for non-action modes when omitted (`None`). Must be `None` + for action runs, which size via `action.resolution_tier`. + width (`int`, *optional*, defaults to `None`): + Output width in pixels. Defaults to `1280` for non-action modes when omitted (`None`). Must be + `None` for action runs, which size via `action.resolution_tier`. fps (`float`, *optional*, defaults to `24.0`): Target frame rate, also injected into the mRoPE temporal modulation and into the duration metadata template. @@ -1149,8 +1158,8 @@ def __call__( Bundles every input for an action-conditioned run (mode, chunk size, embodiment domain, resolution tier, raw actions, and the conditioning image/video), and requires a transformer trained with `action_gen=True`. When set, passing the top-level `image` / `video` arguments raises; `height` / - `width` / `num_frames` are ignored (a warning is logged) since resolution comes from - `action.resolution_tier` and the frame count from `action.chunk_size`. See [`CosmosActionCondition`]. + `width` / `num_frames` must be `None`, since resolution comes from `action.resolution_tier` and + frame count from `action.chunk_size`. See [`CosmosActionCondition`]. 
output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). @@ -1186,24 +1195,13 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # Action runs size via `action.resolution_tier` and derive their frame count from `action.chunk_size`, so - # `height` / `width` / `num_frames` are ignored. Warn (rather than fail) when the caller passed non-default - # values so a shared config can still be reused across action and non-action runs. - if action is not None: - ignored = [ - name - for name, value, default in ( - ("num_frames", num_frames, 189), - ("height", height, 720), - ("width", width, 1280), - ) - if value != default - ] - if ignored: - logger.warning( - "Action runs derive resolution from `action.resolution_tier` and frame count from " - f"`action.chunk_size`; ignoring {', '.join(ignored)}." - ) + if action is None: + if num_frames is None: + num_frames = 189 + if height is None: + height = 720 + if width is None: + width = 1280 # 1. 
Check inputs self.check_inputs( From 4fefb133b9d51ead41b251baba247ff556bebcb9 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:42:17 +0000 Subject: [PATCH 22/27] Fix action_dim default setting --- src/diffusers/models/transformers/transformer_cosmos3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index 9c1d5e1c17c7..67b3a18576ec 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -322,7 +322,7 @@ def __init__( rms_norm_eps: float = 1e-6, rope_scaling: dict | None = None, rope_theta: float = 5000000.0, - action_dim: int = 32, + action_dim: int | None = None, action_gen: bool = False, num_embodiment_domains: int = 32, sound_dim: int | None = None, @@ -371,6 +371,8 @@ def __init__( self.action_dim = action_dim self.num_embodiment_domains = num_embodiment_domains if action_gen: + if self.action_dim is None: + raise ValueError("`action_dim` must be provided when `action_gen=True`.") self.action_proj_in = DomainAwareLinear(self.action_dim, hidden_size, self.num_embodiment_domains) self.action_proj_out = DomainAwareLinear(hidden_size, self.action_dim, self.num_embodiment_domains) self.action_modality_embed = nn.Parameter(torch.zeros(hidden_size)) From 01a8c46ac15c082cace4ff6a6a3cd9a4920c4fdd Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:46:20 +0000 Subject: [PATCH 23/27] Remove video argument before v2v is added --- .../pipelines/cosmos/pipeline_cosmos3_omni.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 2e8d07edf19f..bfcb4a7e7084 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ 
b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -214,8 +214,8 @@ class Cosmos3OmniPipelineOutput(BaseOutput): class CosmosActionCondition: """Groups every input required for a Cosmos 3 action-conditioned generation task. - Pass this to [`Cosmos3OmniPipeline.__call__`] via the `action` argument instead of the top-level `image` / `video` - / `height` / `width` arguments, which are reserved for t2v, i2v runs. + Pass this to [`Cosmos3OmniPipeline.__call__`] via the `action` argument instead of the top-level `image` / + `height` / `width` arguments, which are reserved for t2v, i2v runs. Attributes: mode (`str`): @@ -647,7 +647,6 @@ def _remove_action_video_padding_from_latent( def prepare_latents( self, image: torch.Tensor | None = None, - video: Any | None = None, num_frames: int | None = None, height: int | None = None, width: int | None = None, @@ -857,7 +856,6 @@ def check_inputs( prompt, negative_prompt, image, - video, height: int | None, width: int | None, num_frames: int | None, @@ -891,10 +889,9 @@ def check_inputs( raise ValueError("`num_frames` has to be None if action is not None") if height is not None or width is not None: raise ValueError("`height` and `width` have to be None if action is not None") - if image is not None or video is not None: + if image is not None: raise ValueError( - "Pass action conditioning via `action.image` / `action.video`, not the top-level " - "`image` / `video` arguments." + "Pass action conditioning via `action.image` / `action.video`, not the top-level `image` argument." 
) if not getattr(self.transformer.config, "action_gen", False): raise ValueError("`action` requires a transformer trained with action_gen=True.") @@ -1081,7 +1078,6 @@ def __call__( prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, - video: Any | None = None, num_frames: int = 189, height: int = 720, width: int = 1280, @@ -1119,10 +1115,6 @@ def __call__( image (`torch.Tensor` or `PIL.Image.Image`, *optional*): Optional conditioning frame for image-to-video. The pipeline anchors frame 0 to this image and denoises the remaining frames. Ignored when `num_frames == 1`. Not used for action runs (pass `action` instead). - video (`list`, `np.ndarray`, or `torch.Tensor`, *optional*): - Reserved for video-to-video conditioning. Video-to-video is not yet supported, so this argument is - currently accepted but unused. Action conditioning video is provided through `action` (see - [`CosmosActionCondition`]), not this argument. num_frames (`int`, *optional*, defaults to `None`): Number of frames to generate. Use `1` for text-to-image. Defaults to `189` (≈ 7.9 s at 24 FPS) for non-action modes when omitted (`None`). Must be `None` for action runs, where frame count is derived @@ -1157,7 +1149,7 @@ def __call__( action (`CosmosActionCondition`, *optional*): Bundles every input for an action-conditioned run (mode, chunk size, embodiment domain, resolution tier, raw actions, and the conditioning image/video), and requires a transformer trained with - `action_gen=True`. When set, passing the top-level `image` / `video` arguments raises; `height` / + `action_gen=True`. When set, passing the top-level `image` argument raises; `height` / `width` / `num_frames` must be `None`, since resolution comes from `action.resolution_tier` and frame count from `action.chunk_size`. See [`CosmosActionCondition`]. 
output_type (`str`, *optional*, defaults to `"pil"`): @@ -1208,7 +1200,6 @@ def __call__( prompt, negative_prompt, image, - video, height, width, num_frames, @@ -1293,7 +1284,6 @@ def __call__( action_condition_frame_indexes, ) = self.prepare_latents( image=image, - video=video, num_frames=num_frames, height=height, width=width, From ef8208a686f13eb95aab2d75a3989c191f1c07b3 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 09:48:27 +0000 Subject: [PATCH 24/27] Fix None args --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index bfcb4a7e7084..850675fc9e2c 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -1078,9 +1078,9 @@ def __call__( prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, - num_frames: int = 189, - height: int = 720, - width: int = 1280, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, From b5fa4f779a8b19b355ed9d5308e332abcae46ebb Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 15:41:42 +0000 Subject: [PATCH 25/27] Add _EMBODIMENT_TO_RAW_ACTION_DIM mapping --- docs/source/en/api/pipelines/cosmos3.md | 2 - examples/cosmos3/inference_cosmos3.py | 2 - .../pipelines/cosmos/pipeline_cosmos3_omni.py | 62 ++++++++++++++++--- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 2b5ca69043b5..ab8cd34a9b3d 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -495,7 +495,6 @@ result = pipe( mode="policy", chunk_size=16, 
domain_name="bridge_orig_lerobot", - raw_action_dim=10, resolution_tier=480, video=video, ), @@ -541,7 +540,6 @@ result = pipe( mode="policy", chunk_size=16, domain_name="bridge_orig_lerobot", - raw_action_dim=10, resolution_tier=480, video=video, ), diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index ad0e5affa3ae..9f9da0a5ea64 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -114,7 +114,6 @@ def main(): parser.add_argument("--action-path", default=None, help="JSON action path for forward_dynamics mode.") parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.") parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.") - parser.add_argument("--raw-action-dim", type=int, default=None, help="Slice predicted action output to this size.") parser.add_argument( "--resolution-tier", type=int, @@ -186,7 +185,6 @@ def main(): mode=args.action_mode, chunk_size=args.action_chunk_size, domain_name=args.domain_name, - raw_action_dim=args.raw_action_dim, resolution_tier=args.resolution_tier, raw_actions=raw_actions, video=video, diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 850675fc9e2c..8f25fa7a3352 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -190,6 +190,38 @@ def get_3d_mrope_ids_vae_tokens( "fractal": 20, } +# Canonical (unpadded) action width per embodiment. The width is fixed per embodiment and resolved from +# `domain_name` via this table. +# +# Widths come from the Cosmos 3 unified action representation (paper Fig. 3), which composes a few shared geometric +# building blocks: a 9D pose (3D translation + 6D rotation, the over-parameterized rotation of Zhou et al. 
2019), a +# 1D grasp state (gripper open/close), and a 15D grasp state (fingertip positions, 3D x 5 fingers). Each embodiment +# concatenates these blocks, so its width is just their sum. For example: +# * av / camera_pose -> 9 : a single ego/effector 9D pose. +# * bridge / droid / fractal / umi -> 10 : one arm = 9D effector pose + 1D gripper. +# * robomind-franka-dual -> 20 : two arms = 2 x (9D + 1D). +# * agibotworld / agibot_gear_gripper -> 29 : humanoid = 9D ego + 2 x (9D arm + 1D gripper). +# * galbot -> 30 : humanoid-style stack with an extra pose block. +# +# TODO: support the configuration-dependent domains (`libero`, `hand_pose`), whose width is not fixed per embodiment +# (it depends on the dataset's rotation/keypoint configuration) and so is absent here. +_EMBODIMENT_TO_RAW_ACTION_DIM = { + "av": 9, + "camera_pose": 9, + "pusht": 2, + "umi": 10, + "bridge_orig_lerobot": 10, + "droid_lerobot": 10, + "robomind-franka": 10, + "robomind-franka-dual": 20, + "robomind-ur": 10, + "galbot": 30, + "agibotworld": 29, + "agibot_gear_gripper": 29, + "agibot_gear_gripper_ext": 29, + "fractal": 10, +} + @dataclass class Cosmos3OmniPipelineOutput(BaseOutput): @@ -227,7 +259,8 @@ class CosmosActionCondition: frames. domain_name (`str`): Embodiment domain selecting the domain-aware action projection weights. Must be one of the registered - Cosmos 3 embodiment domains. + Cosmos 3 embodiment domains. It also fixes the unpadded action width used to slice predicted actions, + resolved internally from this name (see `_EMBODIMENT_TO_RAW_ACTION_DIM`). resolution_tier (`int`, defaults to `480`): Action conditioning resolution *tier* (one of `256`, `480`, `704`, `720`). 
The tier picks a predefined canvas whose aspect ratio is closest to the input; the input is downscaled (never upscaled) and padded @@ -235,9 +268,6 @@ class CosmosActionCondition: tier to the input's native resolution: a lower tier discards detail, while a higher tier adds no resolution (no upscaling), wastes compute on padding, and is a train/inference mismatch that can hurt quality. - raw_action_dim (`int`, *optional*): - Number of meaningful (unpadded) action channels to keep when slicing predicted actions. Required for - `"policy"` and `"inverse_dynamics"`. raw_actions (`torch.Tensor`, *optional*): Raw domain action vectors of shape `[T, raw_action_dim]` driving `"forward_dynamics"`. Sequences shorter than `chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's @@ -253,11 +283,11 @@ class CosmosActionCondition: chunk_size: int domain_name: str resolution_tier: int = 480 - raw_action_dim: int | None = None raw_actions: torch.Tensor | None = None image: Image.Image | np.ndarray | torch.Tensor | None = None video: list | np.ndarray | torch.Tensor | None = None + def __post_init__(self) -> None: """Validate self-contained action fields at construction time.""" if self.mode not in ["policy", "forward_dynamics", "inverse_dynamics"]: @@ -280,8 +310,15 @@ def __post_init__(self) -> None: raise ValueError("`image` and `video` cannot both be None") if self.mode == "inverse_dynamics" and self.video is None: raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") - if self.mode in {"policy", "inverse_dynamics"} and self.raw_action_dim is None: - raise ValueError(f"action mode={self.mode!r} requires `raw_action_dim` for output slicing.") + # Resolve the unpadded action width from the embodiment: the width is fixed per embodiment and looked up from + # the table. Domains absent from the table are unsupported for action inference in all modes. 
+ # TODO: support the configuration-dependent domains (libero, hand_pose), whose width is set per-dataset. + if self.domain_name not in _EMBODIMENT_TO_RAW_ACTION_DIM: + raise ValueError( + f"domain_name={self.domain_name!r} is not supported for action inference: it has no canonical action " + f"width. Supported domains: {sorted(_EMBODIMENT_TO_RAW_ACTION_DIM)}." + ) + self.raw_action_dim = _EMBODIMENT_TO_RAW_ACTION_DIM[self.domain_name] if self.mode == "forward_dynamics": if self.raw_actions is None: raise ValueError("action mode='forward_dynamics' requires `raw_actions`.") @@ -289,6 +326,12 @@ def __post_init__(self) -> None: raise ValueError(f"`raw_actions` must have shape [T, D], got {tuple(self.raw_actions.shape)}.") if self.raw_actions.shape[0] < 1: raise ValueError("action mode='forward_dynamics' requires at least one action token.") + # The supplied action width must match the embodiment's expected width. + if self.raw_actions.shape[1] != self.raw_action_dim: + raise ValueError( + f"`raw_actions` width ({self.raw_actions.shape[1]}) does not match the expected action width " + f"({self.raw_action_dim}) for domain_name={self.domain_name!r}." + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -693,6 +736,11 @@ def prepare_latents( raw_action_dim_resolved: int | None = ( int(action.raw_action_dim) if action is not None and action.raw_action_dim is not None else None ) + if raw_action_dim_resolved is not None and raw_action_dim_resolved > self.transformer.config.action_dim: + raise ValueError( + f"raw_action_dim={raw_action_dim_resolved} exceeds the model's trained action_dim=" + f"{self.transformer.config.action_dim}; this checkpoint cannot represent that action width." 
+ ) action_condition_frames: list[int] = [] action_condition_frame_indexes: list[int] = [] action_image_size: torch.Tensor | None = None From 35530b5bd5521d7029d5ca568da4feddc5447743 Mon Sep 17 00:00:00 2001 From: Yuliya Zhautouskaya Date: Tue, 2 Jun 2026 15:51:51 +0000 Subject: [PATCH 26/27] Remove --raw-action-dim from README.md --- examples/cosmos3/README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index 02f609e2ac0c..96d346791eec 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -89,7 +89,6 @@ python examples/cosmos3/inference_cosmos3.py \ --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ --action-mode inverse_dynamics \ --action-chunk-size 16 \ - --raw-action-dim 10 \ --domain-name bridge_orig_lerobot \ --resolution-tier 480 --fps 5 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ @@ -105,7 +104,6 @@ python examples/cosmos3/inference_cosmos3.py \ --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ --action-mode inverse_dynamics \ --action-chunk-size 60 \ - --raw-action-dim 9 \ --domain-name av \ --resolution-tier 480 --fps 10 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ @@ -121,7 +119,6 @@ python examples/cosmos3/inference_cosmos3.py \ --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ --action-mode policy \ --action-chunk-size 16 \ - --raw-action-dim 10 \ --domain-name bridge_orig_lerobot \ --resolution-tier 480 --fps 5 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ @@ -137,7 +134,6 @@ python examples/cosmos3/inference_cosmos3.py \ --vision-path 
"https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ --action-mode policy \ --action-chunk-size 60 \ - --raw-action-dim 9 \ --domain-name av \ --resolution-tier 480 --fps 10 \ --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 5.0 --seed 0 \ @@ -165,7 +161,6 @@ Pick the tier that matches the native resolution of your conditioning input (`48 | `--action-path` | `None` | URL or local JSON action path for `forward_dynamics`. | | `--action-chunk-size` | `None` | Number of action tokens. Action runs generate/use `action_chunk_size + 1` video frames. | | `--domain-name` | `None` | Action embodiment domain, for example `bridge_orig_lerobot` or `av`. | -| `--raw-action-dim` | `None` | Slice predicted action output to the unpadded action dimension. Required for `inverse_dynamics` and `policy`. | | `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1`. | | `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. | | `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. 
| From 6375b8db678d3c01e79d4a3c98d7d5cad2de767e Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Tue, 2 Jun 2026 13:00:53 -0700 Subject: [PATCH 27/27] Added prompt upsampler docs and examples --- docs/source/en/api/pipelines/cosmos3.md | 344 ++++++++---------------- 1 file changed, 119 insertions(+), 225 deletions(-) diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index ab8cd34a9b3d..0d086e878f75 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -43,161 +43,155 @@ Two checkpoints are released on the Hub — [`nvidia/Cosmos3-Nano`](https://hugg > [!TIP] > Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. -## Text-to-image +## Prompt upsampling + +Cosmos 3 was trained on long, highly descriptive captions. For optimal quality, short text prompts should be **upsampled into a specific JSON structure** before they are passed to the pipeline. The upsampler lives in the [cosmos-framework](https://github.com/NVIDIA/cosmos-framework) package. + +Start from a short, plain-text prompt and save it to `assets/prompt.txt`. For the text-to-video example below, the original prompt is *"A robotic arm is cleaning a plate in a kitchen"*: + +```bash +mkdir -p assets +echo "A robotic arm is cleaning a plate in a kitchen" > assets/prompt.txt +``` + +Then install the framework and run the upsampler. 
The example below upsamples for text-to-video using Claude Opus 4.6: + +```bash +git clone https://github.com/NVIDIA/cosmos-framework.git packages/cosmos-framework +pip install -e packages/cosmos-framework + +export PROMPT_UPSAMPLER_ENDPOINT_URL="https://api.anthropic.com/v1/" +export PROMPT_UPSAMPLER_MODEL_NAME="claude-opus-4-6" +export PROMPT_UPSAMPLER_API_TOKEN="<your-api-token>" + +python -m cosmos_framework.inference.prompt_upsampling \ + --input assets/prompt.txt \ + --output assets/example_t2v_prompt.json \ + --mode text2video \ + --endpoint-url "${PROMPT_UPSAMPLER_ENDPOINT_URL}" \ + --model "${PROMPT_UPSAMPLER_MODEL_NAME}" \ + --api-token "${PROMPT_UPSAMPLER_API_TOKEN}" \ + --resolution 720 \ + --aspect-ratio "16,9" +``` -Single-frame generation. The model is conditioned only on the text prompt; pass `num_frames=1`. +Switch `--mode` to match the workflow you are targeting (`text2image`, `text2video`, `image2video`). The command writes the upsampled prompt(s) to the `--output` file as a JSON array (one object per non-empty line in `--input`); pass a `.jsonl` path instead to get one JSON object per line. For `image2video`, you must also supply the conditioning image via `--image-url` (a URL or local path) or `--image-list` (one image per prompt). + +A pre-upsampled positive prompt (`assets/example_t2v_prompt.json`) and negative prompt (`assets/negative_prompt.json`) are provided for convenience, and are used by the generation examples below. The examples load these JSON files and pass them to the pipeline as JSON strings via `json.dumps(...)`. + +## Text-to-video + +Multi-frame generation conditioned on text alone. Pick `num_frames` based on the target duration — the default `num_frames=189` produces ≈ 7.9 s at 24 FPS. The prompt and negative prompt are read from the JSON-upsampled files described in [Prompt upsampling](#prompt-upsampling). 
```python +import json import torch from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video + +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=10.0) -prompt = ( - "A medium shot of a modern robotics research laboratory with white walls and a gray floor. " - "A robotic arm with a metallic finish is mounted on a clean white workbench, its gripper positioned " - "above a row of small colored objects. A laptop and neatly arranged tools sit beside the robot. " - "A large monitor on the wall behind displays a software interface. The scene is brightly lit by " - "overhead fluorescent lights." +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, ) - -result = pipe(prompt=prompt, num_frames=1, height=720, width=1280) -result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ``` ```python +import json import torch from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video + +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=10.0) -prompt = ( - "A medium shot of a modern robotics research laboratory with white walls and a gray floor. " - "A robotic arm with a metallic finish is mounted on a clean white workbench, its gripper positioned " - "above a row of small colored objects. A laptop and neatly arranged tools sit beside the robot. " - "A large monitor on the wall behind displays a software interface. The scene is brightly lit by " - "overhead fluorescent lights." +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, ) - -result = pipe(prompt=prompt, num_frames=1, height=720, width=1280) -result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ``` -## Text-to-video +## Text-to-image -Multi-frame generation conditioned on text alone. Pick `num_frames` based on the target duration — the default `num_frames=189` produces ≈ 7.9 s at 24 FPS. +Single-frame generation. The model is conditioned only on the text prompt; pass `num_frames=1`. Upsample with `--mode text2image` to produce the JSON prompt. ```python +import json import torch from diffusers import Cosmos3OmniPipeline -from diffusers.utils import export_to_video + +# JSON-upsampled prompt (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2i_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting." -) - -# Recommended quality-control negative prompt for text-to-video. 
-negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - -result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=189, - height=720, - width=1280, - fps=24.0, -) -# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). -export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) +result = pipe(prompt=json.dumps(json_prompt), num_frames=1, height=720, width=1280) +result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) ``` ```python +import json import torch from diffusers import Cosmos3OmniPipeline -from diffusers.utils import export_to_video + +# JSON-upsampled prompt (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2i_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. 
The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting." -) - -# Recommended quality-control negative prompt for text-to-video. -negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - -result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=189, - height=720, - width=1280, - fps=24.0, -) -# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). 
-export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) +result = pipe(prompt=json.dumps(json_prompt), num_frames=1, height=720, width=1280) +result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) ``` @@ -205,16 +199,21 @@ export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ## Image-to-video -Pass a conditioning image via `image=`. The pipeline anchors frame 0 to the supplied image and denoises the rest. +Pass a conditioning image via `image=`. The pipeline anchors frame 0 to the supplied image and denoises the rest. Upsample with `--mode image2video` to produce the JSON prompt. ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import export_to_video, load_image +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_i2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -222,42 +221,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( image = load_image( "https://github.com/nvidia-cosmos/cosmos-dependencies/releases/download/assets/robot_153.jpg" ) -prompt = ( - "The video opens with a view of a testing environment, characterized by a large wooden table at the " - "center. On this table, two robot arms are positioned at opposite ends, with the left arm closer to " - "the camera and the right arm further away. Between the hands lies a dark wooden shelf with a red " - "spherical object on its top rack, likely serving as a platform or obstacle. In the background, " - "various pieces of equipment, including a tripod, a chair, are visible. A person wearing a blue " - "jacket and black pants stands near the center of the room, observing the experiment, with a static " - "hand position throughout. 
The floor is tiled with a patterned design, and additional items like a " - "small robot figure and some cables can be seen scattered around the space. As the video progresses, " - "the right robotic hand extends outward, moving from its initial position towards the red spherical " - "object on the shelf. The hand then picks up the object and places it on the lowest rack of the " - "shelf, completing a smooth, deliberate manipulation. The left robotic hand remains stationary " - "throughout the sequence. No new objects appear in the video; all existing elements maintain their " - "positions except for the movement of the right robotic hand. The scene concludes with the right " - "robotic hand returning to its initial position, while the left hand continues to rest on the table. " - "The overall environment remains unchanged, with the focus remaining on the interaction between the " - "robotic hands and the wooden block, highlighting precise control during the demonstration." -) - -# Recommended quality-control negative prompt for image-to-video. -negative_prompt = ( - "The video captures a series of frames showing macroblocking artifacts, chromatic aberration, " - "high-frequency noise, and rolling shutter distortion. It includes static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, bit-depth compression artifacts, color banding, unnatural transitions, " - "outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual " - "noise, and flickering. Avoid moiré patterns, edge halos, and temporal aliasing. 
Furthermore, the content " - "defies common sense, generating illogical scenarios, nonsensical entities, absurd character behaviors, " - "and conceptual paradoxes that violate basic human reasoning and everyday reality. The video looks like a " - "surreal or glitchy hallucination. Overall, the video is of poor quality." -) result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), image=image, num_frames=189, height=720, @@ -272,10 +239,15 @@ export_to_video(result.video, "cosmos3_i2v.mp4", fps=24, macro_block_size=1) ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import export_to_video, load_image +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_i2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -283,42 +255,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( image = load_image( "https://github.com/nvidia-cosmos/cosmos-dependencies/releases/download/assets/robot_153.jpg" ) -prompt = ( - "The video opens with a view of a testing environment, characterized by a large wooden table at the " - "center. On this table, two robot arms are positioned at opposite ends, with the left arm closer to " - "the camera and the right arm further away. Between the hands lies a dark wooden shelf with a red " - "spherical object on its top rack, likely serving as a platform or obstacle. In the background, " - "various pieces of equipment, including a tripod, a chair, are visible. A person wearing a blue " - "jacket and black pants stands near the center of the room, observing the experiment, with a static " - "hand position throughout. 
The floor is tiled with a patterned design, and additional items like a " - "small robot figure and some cables can be seen scattered around the space. As the video progresses, " - "the right robotic hand extends outward, moving from its initial position towards the red spherical " - "object on the shelf. The hand then picks up the object and places it on the lowest rack of the " - "shelf, completing a smooth, deliberate manipulation. The left robotic hand remains stationary " - "throughout the sequence. No new objects appear in the video; all existing elements maintain their " - "positions except for the movement of the right robotic hand. The scene concludes with the right " - "robotic hand returning to its initial position, while the left hand continues to rest on the table. " - "The overall environment remains unchanged, with the focus remaining on the interaction between the " - "robotic hands and the wooden block, highlighting precise control during the demonstration." -) - -# Recommended quality-control negative prompt for image-to-video. -negative_prompt = ( - "The video captures a series of frames showing macroblocking artifacts, chromatic aberration, " - "high-frequency noise, and rolling shutter distortion. It includes static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, bit-depth compression artifacts, color banding, unnatural transitions, " - "outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual " - "noise, and flickering. Avoid moiré patterns, edge halos, and temporal aliasing. 
Furthermore, the content " - "defies common sense, generating illogical scenarios, nonsensical entities, absurd character behaviors, " - "and conceptual paradoxes that violate basic human reasoning and everyday reality. The video looks like a " - "surreal or glitchy hallucination. Overall, the video is of poor quality." -) result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), image=image, num_frames=189, height=720, @@ -342,45 +282,22 @@ This is the same call as the text-to-video example above with `enable_sound=True ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import encode_video +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2v_sound_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. 
The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting. Audio description: the soft whir of servo motors, gentle " - "thuds as fruits land in the plastic bag, the rustle of the bag settling in the shopping cart, and " - "a faint refrigeration hum in the background." -) - -# Recommended quality-control negative prompt (same as text-to-video). -negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), num_frames=189, height=720, width=1280, @@ -401,45 +318,22 @@ encode_video( ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import encode_video +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2v_sound_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting. Audio description: the soft whir of servo motors, gentle " - "thuds as fruits land in the plastic bag, the rustle of the bag settling in the shopping cart, and " - "a faint refrigeration hum in the background." -) - -# Recommended quality-control negative prompt (same as text-to-video). 
-negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), num_frames=189, height=720, width=1280,