diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index ce26ee0c36ef..7ce1ff4f58cf 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -43,161 +43,159 @@ Two checkpoints are released on the Hub — [`nvidia/Cosmos3-Nano`](https://hugg > [!TIP] > Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. -## Text-to-image +## Prompt upsampling + +Cosmos 3 was trained on long, highly descriptive captions. For optimal quality, short text prompts should be **upsampled into a specific JSON structure** before they are passed to the pipeline. The upsampler lives in the [cosmos-framework](https://github.com/NVIDIA/cosmos-framework) package. + +Start from a short, plain-text prompt and save it to `assets/prompt.txt`. For the text-to-video example below, the original prompt is *"A robotic arm is cleaning a plate in a kitchen"*: + +```bash +mkdir -p assets +echo "A robotic arm is cleaning a plate in a kitchen" > assets/prompt.txt +``` + +Then install the framework and run the upsampler. 
The example below upsamples for text-to-video using Claude Opus 4.6; set `PROMPT_UPSAMPLER_API_TOKEN` to your API key before running it: + +```bash +git clone https://github.com/NVIDIA/cosmos-framework.git packages/cosmos-framework +pip install -e packages/cosmos-framework + +export PROMPT_UPSAMPLER_ENDPOINT_URL="https://api.anthropic.com/v1/" +export PROMPT_UPSAMPLER_MODEL_NAME="claude-opus-4-6" +export PROMPT_UPSAMPLER_API_TOKEN="" + +python -m cosmos_framework.inference.prompt_upsampling \ + --input assets/prompt.txt \ + --output assets/example_t2v_prompt.json \ + --mode text2video \ + --endpoint-url "${PROMPT_UPSAMPLER_ENDPOINT_URL}" \ + --model "${PROMPT_UPSAMPLER_MODEL_NAME}" \ + --api-token "${PROMPT_UPSAMPLER_API_TOKEN}" \ + --resolution 720 \ + --aspect-ratio "16,9" +``` + +Switch `--mode` to match the workflow you are targeting (`text2image`, `text2video`, `image2video`). The command writes the upsampled prompt(s) to the `--output` file as a JSON array (one object per non-empty line in `--input`); pass a `.jsonl` path instead to get one JSON object per line. For `image2video`, you must also supply the conditioning image via `--image-url` (a URL or local path) or `--image-list` (one image per prompt). + +A pre-upsampled positive prompt (`assets/example_t2v_prompt.json`) and negative prompt (`assets/negative_prompt.json`) are provided for convenience and are used by the generation examples below. The examples load these JSON files and pass them to the pipeline as JSON strings via `json.dumps(...)`. -Single-frame generation. The model is conditioned only on the text prompt; pass `num_frames=1`. +## Text-to-video + +Multi-frame generation conditioned on text alone. Pick `num_frames` based on the target duration — the default `num_frames=189` produces ≈ 7.9 s at 24 FPS. The prompt and negative prompt are read from the JSON-upsampled files described in [Prompt upsampling](#prompt-upsampling). 
```python +import json import torch from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video + +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) - -prompt = ( - "A medium shot of a modern robotics research laboratory with white walls and a gray floor. " - "A robotic arm with a metallic finish is mounted on a clean white workbench, its gripper positioned " - "above a row of small colored objects. A laptop and neatly arranged tools sit beside the robot. " - "A large monitor on the wall behind displays a software interface. The scene is brightly lit by " - "overhead fluorescent lights." +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False ) -result = pipe(prompt=prompt, num_frames=1, height=720, width=1280) -result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ``` ```python +import json import torch from diffusers import Cosmos3OmniPipeline +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video + +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) - -prompt = ( - "A medium shot of a modern robotics research laboratory with white walls and a gray floor. " - "A robotic arm with a metallic finish is mounted on a clean white workbench, its gripper positioned " - "above a row of small colored objects. A laptop and neatly arranged tools sit beside the robot. " - "A large monitor on the wall behind displays a software interface. The scene is brightly lit by " - "overhead fluorescent lights." +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False ) -result = pipe(prompt=prompt, num_frames=1, height=720, width=1280) -result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) +result = pipe( + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), + num_frames=189, + height=720, + width=1280, + num_inference_steps=35, + guidance_scale=6.0, + fps=24.0, +) +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ``` -## Text-to-video +## Text-to-image -Multi-frame generation conditioned on text alone. Pick `num_frames` based on the target duration — the default `num_frames=189` produces ≈ 7.9 s at 24 FPS. +Single-frame generation. The model is conditioned only on the text prompt; pass `num_frames=1`. Upsample with `--mode text2image` to produce the JSON prompt. ```python +import json import torch from diffusers import Cosmos3OmniPipeline -from diffusers.utils import export_to_video + +# JSON-upsampled prompt (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2i_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting." -) - -# Recommended quality-control negative prompt for text-to-video. 
-negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - -result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=189, - height=720, - width=1280, - fps=24.0, -) -# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). -export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) +result = pipe(prompt=json.dumps(json_prompt), num_frames=1, height=720, width=1280) +result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) ``` ```python +import json import torch from diffusers import Cosmos3OmniPipeline -from diffusers.utils import export_to_video + +# JSON-upsampled prompt (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2i_prompt.json")) pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. 
The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting." -) - -# Recommended quality-control negative prompt for text-to-video. -negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - -result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=189, - height=720, - width=1280, - fps=24.0, -) -# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). 
-export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) +result = pipe(prompt=json.dumps(json_prompt), num_frames=1, height=720, width=1280) +result.video[0].save("cosmos3_t2i.jpg", format="JPEG", quality=85) ``` @@ -205,16 +203,21 @@ export_to_video(result.video, "cosmos3_t2v.mp4", fps=24, macro_block_size=1) ## Image-to-video -Pass a conditioning image via `image=`. The pipeline anchors frame 0 to the supplied image and denoises the rest. +Pass a conditioning image via `image=`. The pipeline anchors frame 0 to the supplied image and denoises the rest. Upsample with `--mode image2video` to produce the JSON prompt. ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import export_to_video, load_image +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_i2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -222,42 +225,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( image = load_image( "https://github.com/nvidia-cosmos/cosmos-dependencies/releases/download/assets/robot_153.jpg" ) -prompt = ( - "The video opens with a view of a testing environment, characterized by a large wooden table at the " - "center. On this table, two robot arms are positioned at opposite ends, with the left arm closer to " - "the camera and the right arm further away. Between the hands lies a dark wooden shelf with a red " - "spherical object on its top rack, likely serving as a platform or obstacle. In the background, " - "various pieces of equipment, including a tripod, a chair, are visible. A person wearing a blue " - "jacket and black pants stands near the center of the room, observing the experiment, with a static " - "hand position throughout. 
The floor is tiled with a patterned design, and additional items like a " - "small robot figure and some cables can be seen scattered around the space. As the video progresses, " - "the right robotic hand extends outward, moving from its initial position towards the red spherical " - "object on the shelf. The hand then picks up the object and places it on the lowest rack of the " - "shelf, completing a smooth, deliberate manipulation. The left robotic hand remains stationary " - "throughout the sequence. No new objects appear in the video; all existing elements maintain their " - "positions except for the movement of the right robotic hand. The scene concludes with the right " - "robotic hand returning to its initial position, while the left hand continues to rest on the table. " - "The overall environment remains unchanged, with the focus remaining on the interaction between the " - "robotic hands and the wooden block, highlighting precise control during the demonstration." -) - -# Recommended quality-control negative prompt for image-to-video. -negative_prompt = ( - "The video captures a series of frames showing macroblocking artifacts, chromatic aberration, " - "high-frequency noise, and rolling shutter distortion. It includes static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, bit-depth compression artifacts, color banding, unnatural transitions, " - "outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual " - "noise, and flickering. Avoid moiré patterns, edge halos, and temporal aliasing. 
Furthermore, the content " - "defies common sense, generating illogical scenarios, nonsensical entities, absurd character behaviors, " - "and conceptual paradoxes that violate basic human reasoning and everyday reality. The video looks like a " - "surreal or glitchy hallucination. Overall, the video is of poor quality." -) result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), image=image, num_frames=189, height=720, @@ -272,10 +243,15 @@ export_to_video(result.video, "cosmos3_i2v.mp4", fps=24, macro_block_size=1) ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import export_to_video, load_image +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_i2v_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt_i2v.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) @@ -283,42 +259,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( image = load_image( "https://github.com/nvidia-cosmos/cosmos-dependencies/releases/download/assets/robot_153.jpg" ) -prompt = ( - "The video opens with a view of a testing environment, characterized by a large wooden table at the " - "center. On this table, two robot arms are positioned at opposite ends, with the left arm closer to " - "the camera and the right arm further away. Between the hands lies a dark wooden shelf with a red " - "spherical object on its top rack, likely serving as a platform or obstacle. In the background, " - "various pieces of equipment, including a tripod, a chair, are visible. A person wearing a blue " - "jacket and black pants stands near the center of the room, observing the experiment, with a static " - "hand position throughout. 
The floor is tiled with a patterned design, and additional items like a " - "small robot figure and some cables can be seen scattered around the space. As the video progresses, " - "the right robotic hand extends outward, moving from its initial position towards the red spherical " - "object on the shelf. The hand then picks up the object and places it on the lowest rack of the " - "shelf, completing a smooth, deliberate manipulation. The left robotic hand remains stationary " - "throughout the sequence. No new objects appear in the video; all existing elements maintain their " - "positions except for the movement of the right robotic hand. The scene concludes with the right " - "robotic hand returning to its initial position, while the left hand continues to rest on the table. " - "The overall environment remains unchanged, with the focus remaining on the interaction between the " - "robotic hands and the wooden block, highlighting precise control during the demonstration." -) - -# Recommended quality-control negative prompt for image-to-video. -negative_prompt = ( - "The video captures a series of frames showing macroblocking artifacts, chromatic aberration, " - "high-frequency noise, and rolling shutter distortion. It includes static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, bit-depth compression artifacts, color banding, unnatural transitions, " - "outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual " - "noise, and flickering. Avoid moiré patterns, edge halos, and temporal aliasing. 
Furthermore, the content " - "defies common sense, generating illogical scenarios, nonsensical entities, absurd character behaviors, " - "and conceptual paradoxes that violate basic human reasoning and everyday reality. The video looks like a " - "surreal or glitchy hallucination. Overall, the video is of poor quality." -) result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), image=image, num_frames=189, height=720, @@ -342,45 +286,22 @@ This is the same call as the text-to-video example above with `enable_sound=True ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import encode_video +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). +json_prompt = json.load(open("assets/example_t2v_sound_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. 
The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting. Audio description: the soft whir of servo motors, gentle " - "thuds as fruits land in the plastic bag, the rustle of the bag settling in the shopping cart, and " - "a faint refrigeration hum in the background." -) - -# Recommended quality-control negative prompt (same as text-to-video). -negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), num_frames=189, height=720, width=1280, @@ -401,45 +322,22 @@ encode_video( ```python +import json import torch from diffusers import Cosmos3OmniPipeline from diffusers.utils import encode_video +# JSON-upsampled positive and negative prompts (see "Prompt upsampling" above). 
+json_prompt = json.load(open("assets/example_t2v_sound_prompt.json")) +negative_prompt = json.load(open("assets/negative_prompt.json")) + pipe = Cosmos3OmniPipeline.from_pretrained( "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" ) -prompt = ( - "The video opens with a view of a well-lit indoor space featuring a wooden display case with " - "compartments filled with various fruits, including bananas, apples, pears, oranges, and carambolas. " - "The bananas are neatly arranged in the middle compartment, while apples are in the left and a mix " - "of pears, oranges, and carambolas are in the right. Two robotic arms with grippers are positioned " - "at the bottom of the frame, with the one on the left remaining stationary, partially obscuring the " - "apples. The robotic arm on the right begins its action, extending towards the right side of the " - "display case. It carefully picks up a pear from the fruit section, placing it into a plastic bag " - "in the shopping cart nearby, which has red handles. After securing the pear, the arm retracts back " - "to its original position. The process repeats as the robotic arm picks up an orange and places it " - "in the bag, followed by a carambola. The final frame captures the robotic arm returning to its " - "initial position, leaving the display case and surrounding area unchanged. The video showcases a " - "seamless and efficient automated fruit-picking process, highlighting the precision and efficiency " - "of modern robotics in a retail setting. Audio description: the soft whir of servo motors, gentle " - "thuds as fruits land in the plastic bag, the rustle of the bag settling in the shopping cart, and " - "a faint refrigeration hum in the background." -) - -# Recommended quality-control negative prompt (same as text-to-video). 
-negative_prompt = ( - "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, " - "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, " - "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky " - "movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, " - "fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. " - "Overall, the video is of poor quality." -) - result = pipe( - prompt=prompt, - negative_prompt=negative_prompt, + prompt=json.dumps(json_prompt), + negative_prompt=json.dumps(negative_prompt), num_frames=189, height=720, width=1280, @@ -459,6 +357,113 @@ encode_video( +## Action-conditioned generation + +Action runs group every action-specific input into a [`CosmosActionCondition`] passed via the `action` argument instead of the top-level `image` / `video` / `height` / `width` arguments. Set `resolution_tier` (`256`/`480`/`704`/`720`) close to the input video's native resolution; it selects the conditioning canvas. Cosmos 3 supports three action modes — `policy`, `forward_dynamics`, and `inverse_dynamics`. `policy` and `forward_dynamics` condition only on the first frame (so an `image` or a `video` both work), while `inverse_dynamics` requires a `video`. The conditioning video for an action run is set on `action.video` (or `action.image`), not on the pipeline's top-level `video` argument. + +Pass a plain task description as `prompt` and pick the camera with `action.view_point` (default `"ego_view"`; also `"third_person_view"`, `"wrist_view"`, `"concat_view"`). The pipeline turns these into the structured JSON caption the model was trained on, so action prompts should not be LLM-upsampled. 
+ +### Action policy + +Action policy generation predicts future video and action tokens from the first observation frame, text prompt, and action domain metadata. The example below uses the Bridge robot domain and writes the predicted action chunk to JSON in model-normalized action space. + + + + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Nano", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +prompt = "Put the pot to the left of the purple item." +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_20260501_0.mp4" +) + +result = pipe( + prompt=prompt, + action=CosmosActionCondition( + mode="policy", + chunk_size=16, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + video=video, + view_point="ego_view", + ), + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). 
+export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + + + + +```python +import json + +import torch +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_video + +pipe = Cosmos3OmniPipeline.from_pretrained( + "nvidia/Cosmos3-Super", torch_dtype=torch.bfloat16, device_map="cuda" +) +pipe.scheduler = UniPCMultistepScheduler.from_config( + pipe.scheduler.config, flow_shift=10.0, use_karras_sigmas=False +) + +prompt = "Put the pot to the left of the purple item." +video = load_video( + "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_20260501_0.mp4" +) + +result = pipe( + prompt=prompt, + action=CosmosActionCondition( + mode="policy", + chunk_size=16, + domain_name="bridge_orig_lerobot", + resolution_tier=480, + video=video, + view_point="ego_view", + ), + fps=5, + num_inference_steps=30, + guidance_scale=1.0, + use_system_prompt=False, +) + +# macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). +export_to_video(result.video, "sample.mp4", fps=5, macro_block_size=1) + +if result.action is not None: + with open("sample_action.json", "w") as f: + json.dump(result.action[0].tolist(), f) +``` + + + + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). 
@@ -537,6 +542,10 @@ pipe = Cosmos3OmniPipeline.from_pretrained( - all - __call__ +## CosmosActionCondition + +[[autodoc]] CosmosActionCondition + ## Cosmos3OmniPipelineOutput [[autodoc]] pipelines.cosmos.pipeline_cosmos3_omni.Cosmos3OmniPipelineOutput \ No newline at end of file diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index 7a4cb277aa07..dd4be5dc286f 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -48,16 +48,123 @@ python examples/cosmos3/inference_cosmos3.py \ --enable-sound ``` +Action forward dynamics, robot domain (predict video from an observation video and a provided action chunk): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.json" \ + --action-chunk-size 16 \ + --domain-name bridge_orig_lerobot \ + --resolution-tier 480 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_robot +``` + +Action forward dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode forward_dynamics \ + --action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_action_25.json" \ + --action-chunk-size 60 \ + --domain-name av \ + --resolution-tier 480 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_forward_dynamics_av +``` + +Action inverse dynamics, robot domain (predict actions from an observed video): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 16 \ + --domain-name bridge_orig_lerobot \ + --resolution-tier 480 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_robot +``` + +Action inverse dynamics, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system." 
\ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode inverse_dynamics \ + --action-chunk-size 60 \ + --domain-name av \ + --resolution-tier 480 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_inverse_dynamics_av +``` + +Action policy, robot domain (predict both future video and actions from the first observation frame): + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "Put the pot to the left of the purple item." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \ + --action-mode policy \ + --action-chunk-size 16 \ + --domain-name bridge_orig_lerobot \ + --resolution-tier 480 --fps 5 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_policy_robot +``` + +Action policy, autonomous-vehicle domain: + +```bash +python examples/cosmos3/inference_cosmos3.py \ + --model nano \ + --prompt "You are an autonomous vehicle planning system. Please go backward." \ + --vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \ + --action-mode policy \ + --action-chunk-size 60 \ + --domain-name av \ + --resolution-tier 480 --fps 10 \ + --num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \ + --output results/cosmos3_policy_av +``` + +Action modes use `action_chunk_size + 1` conditioning frames. `forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample_action.json` in model-normalized action space. 
This script loads `--vision-path` as a video for all action modes; `policy` and `forward_dynamics` condition only on the first frame, while `inverse_dynamics` uses the whole video. + +Pass `--prompt` as a plain task description and select the camera perspective with `--view-point` (default `ego_view`); the pipeline builds the structured action caption (task, viewpoint, duration, FPS, resolution) the model was trained on. Do not hand-write the viewpoint sentence into `--prompt`. + +`--resolution-tier` is a resolution *tier* (`256`/`480`/`704`/`720`). The tier keys a table of predefined aspect-ratio canvases; the one closest to the input aspect ratio becomes the padded conditioning canvas. It is not the output frame size: the input is downscaled (never upscaled) and padded to fill the canvas, then the padding is cropped from the latents so the decoded output follows the downscaled input content. `--height` / `--width` (and `--num-frames`) are ignored for action modes. + +Pick the tier that matches the native resolution of your conditioning input (`480` for ~480p, `720` for ~720p). A tier below your input downscales it and discards detail; a tier above your input gains no resolution (content is never upscaled), wastes compute on padding, and is a train/inference distribution mismatch that can degrade quality. + ### Useful flags | Flag | Default | Description | |---|---|---| | `--prompt` | (required) | Text prompt. | -| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video). | -| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). | -| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). | +| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video), or the image/video conditioning for action modes. 
| +| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). Ignored for action modes (derived from `--action-chunk-size`). | +| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). Ignored for action modes; use `--resolution-tier`. | +| `--resolution-tier` | `480` | Action resolution tier (`256`/`480`/`704`/`720`): selects the aspect bin / padded conditioning canvas, not the output size. | | `--fps` | `24.0` | Frame rate of the generated video. | +| `--flow-shift` | `None` | Override `UniPCMultistepScheduler.flow_shift` (and force `use_karras_sigmas=False`); left at the checkpoint default when unset. Cosmos3 runs use `10.0`. | | `--enable-sound` | off | Generate a synchronized audio track. | -| `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1`. | -| `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. | +| `--action-mode` | `None` | Enable action conditioning/generation. One of `forward_dynamics`, `inverse_dynamics`, or `policy`. | +| `--action-path` | `None` | URL or local JSON action path for `forward_dynamics`. | +| `--action-chunk-size` | `None` | Number of action tokens. Action runs generate/use `action_chunk_size + 1` video frames. | +| `--domain-name` | `None` | Action embodiment domain, for example `bridge_orig_lerobot` or `av`. | +| `--view-point` | `ego_view` | Camera perspective for the action caption's framing (`ego_view`, `third_person_view`, `wrist_view`, `concat_view`). Action only. | +| `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1` and for action modes (which build a structured caption instead). 
| +| `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. Ignored for action modes. | | `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. | diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index fd0d0537cb0e..e9a5f5f369bb 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -23,13 +23,15 @@ """ import argparse +import json import pathlib +import urllib.request import torch from huggingface_hub import snapshot_download -from diffusers import Cosmos3OmniPipeline -from diffusers.utils import encode_video, export_to_video, load_image +from diffusers import Cosmos3OmniPipeline, CosmosActionCondition, UniPCMultistepScheduler +from diffusers.utils import encode_video, export_to_video, load_image, load_video HF_REPOS = { @@ -38,6 +40,22 @@ } +def _load_action(path: str | None): + if path is None: + raise ValueError("--action-path is required for forward_dynamics mode.") + if path.startswith(("http://", "https://")): + with urllib.request.urlopen(path) as response: + action = json.loads(response.read().decode("utf-8")) + else: + action = json.loads(pathlib.Path(path).read_text()) + tensor = torch.as_tensor(action, dtype=torch.float32) + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + if tensor.ndim != 2: + raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.") + return tensor + + def main(): parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument("--prompt", required=True, help="Text prompt.") @@ -50,11 +68,21 @@ def main(): parser.add_argument( "--vision-path", default=None, - help="Optional URL or local path for an image-conditioning frame (enables image-to-video).", + help="Optional URL or local path for an image-conditioning frame, or an action conditioning video.", ) 
parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.") - parser.add_argument("--height", type=int, default=720) - parser.add_argument("--width", type=int, default=1280) + parser.add_argument( + "--height", + type=int, + default=None, + help="Output height in pixels (default 720). Ignored for action modes; use --resolution-tier instead.", + ) + parser.add_argument( + "--width", + type=int, + default=None, + help="Output width in pixels (default 1280). Ignored for action modes; use --resolution-tier instead.", + ) parser.add_argument( "--num-frames", type=int, @@ -62,12 +90,46 @@ def main(): help="Number of frames to generate. Use 1 for text-to-image; defaults to 189 for video (≈ 7.9s @ 24 FPS).", ) parser.add_argument("--fps", type=float, default=24.0) + parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.") + parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.") + parser.add_argument( + "--flow-shift", + type=float, + default=None, + help="Override the scheduler's flow-matching shift (UniPCMultistepScheduler.flow_shift).", + ) + parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") parser.add_argument( "--enable-sound", action="store_true", default=False, help="Generate sound alongside video (requires a sound-capable checkpoint).", ) + parser.add_argument( + "--action-mode", + choices=["forward_dynamics", "inverse_dynamics", "policy"], + default=None, + help="Enable Cosmos3 action generation with a loaded conditioning video.", + ) + parser.add_argument("--action-path", default=None, help="JSON action path for forward_dynamics mode.") + parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.") + parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.") + parser.add_argument( + 
"--view-point", + choices=["ego_view", "third_person_view", "wrist_view", "concat_view"], + default="ego_view", + help="Camera perspective for the action caption's cinematography.framing field (default: ego_view).", + ) + parser.add_argument( + "--resolution-tier", + type=int, + default=480, + choices=[256, 480, 704, 720], + help=( + "Action resolution tier (256/480/704/720). Selects the aspect bin / padded conditioning canvas, " + "not the output frame size." + ), + ) parser.add_argument( "--no-duration-template", dest="add_duration_template", @@ -108,23 +170,59 @@ def main(): ) print("Pipeline loaded successfully.") + if args.flow_shift is not None: + pipeline.scheduler = UniPCMultistepScheduler.from_config( + pipeline.scheduler.config, flow_shift=args.flow_shift, use_karras_sigmas=False + ) + output_dir = pathlib.Path(args.output) output_dir.mkdir(parents=True, exist_ok=True) - - image = load_image(args.vision_path) if args.vision_path is not None else None - - result = pipeline( - prompt=args.prompt, - image=image, - num_frames=args.num_frames, - height=args.height, - width=args.width, - fps=args.fps, - enable_sound=args.enable_sound, - add_resolution_template=args.add_resolution_template, - add_duration_template=args.add_duration_template, - enable_safety_check=not args.no_safety_check, - ) + generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None + + if args.action_mode is not None: + if args.vision_path is None: + raise ValueError("--vision-path must point to a conditioning video for action modes.") + if args.action_chunk_size is None: + raise ValueError("--action-chunk-size is required for action modes.") + video = load_video(args.vision_path) + raw_actions = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None + result = pipeline( + prompt=args.prompt, + action=CosmosActionCondition( + mode=args.action_mode, + chunk_size=args.action_chunk_size, + domain_name=args.domain_name, + 
resolution_tier=args.resolution_tier, + raw_actions=raw_actions, + video=video, + view_point=args.view_point, + ), + fps=args.fps, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + generator=generator, + use_system_prompt=False, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) + else: + image = load_image(args.vision_path) if args.vision_path is not None else None + result = pipeline( + prompt=args.prompt, + image=image, + num_frames=args.num_frames, + height=args.height, + width=args.width, + fps=args.fps, + num_inference_steps=args.num_inference_steps, + enable_sound=args.enable_sound, + guidance_scale=args.guidance_scale, + generator=generator, + add_resolution_template=args.add_resolution_template, + add_duration_template=args.add_duration_template, + enable_safety_check=not args.no_safety_check, + ) if args.num_frames == 1: save_path = output_dir / "sample.jpg" @@ -145,6 +243,13 @@ def main(): export_to_video(result.video, str(save_path), fps=int(args.fps), quality=10, macro_block_size=1) print(f"Saved: {save_path}") + if result.action is not None: + for action in result.action: + action_path = output_dir / "sample_action.json" + with open(action_path, "w") as f: + json.dump(action.tolist(), f) + print(f"Saved: {action_path}") + if __name__ == "__main__": main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a8957183ef99..01d126d4b560 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -553,6 +553,7 @@ "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "Cosmos3OmniPipeline", + "CosmosActionCondition", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", "CycleDiffusionPipeline", @@ -1372,6 +1373,7 @@ Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, Cosmos3OmniPipeline, + CosmosActionCondition, CosmosTextToWorldPipeline, 
class DomainAwareLinear(nn.Module):
    """Linear layer whose weight matrix and bias are looked up per embodiment domain.

    Both parameter sets are stored as embedding tables indexed by domain id: row ``d`` of ``fc`` holds the
    flattened ``[input_size, output_size]`` matrix of domain ``d`` and row ``d`` of ``bias`` holds its bias.
    """

    def __init__(self, input_size: int, output_size: int, num_domains: int) -> None:
        super().__init__()
        self.input_size, self.output_size, self.num_domains = input_size, output_size, num_domains
        self.fc = nn.Embedding(num_domains, output_size * input_size)
        self.bias = nn.Embedding(num_domains, output_size)

    def forward(self, x: torch.Tensor, domain_id: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the weights of the domain chosen for each leading-axis entry.

        Args:
            x: Rank-2 ``[N, input_size]`` or rank-3 ``[N, L, input_size]`` input.
            domain_id: Domain ids, one per entry along the first axis of ``x`` (a 0-d tensor counts as one id).

        Returns:
            A tensor shaped like ``x`` with the last axis replaced by ``output_size``.

        Raises:
            ValueError: If the id count differs from ``x.shape[0]``, an id lies outside ``[0, num_domains)``,
                or ``x`` is neither rank 2 nor rank 3.
        """
        # `reshape(-1)` flattens the ids and also lifts a 0-d id to shape [1].
        ids = domain_id.to(device=x.device, dtype=torch.long).reshape(-1)
        batch = ids.shape[0]
        if x.shape[0] != batch:
            raise ValueError(
                "Cosmos3 action domain_id batch size must match action tokens: "
                f"tokens={x.shape[0]}, domain_id={batch}."
            )
        every_id_valid = torch.all((ids >= 0) & (ids < self.num_domains))
        if not every_id_valid:
            raise ValueError(f"Cosmos3 action domain_id must be in [0, {self.num_domains}), got {ids.tolist()}.")
        if x.ndim not in (2, 3):
            raise ValueError(f"Cosmos3 DomainAwareLinear expected rank-2 or rank-3 input, got {tuple(x.shape)}.")

        per_item_weight = self.fc(ids).view(batch, self.input_size, self.output_size)
        per_item_bias = self.bias(ids).view(batch, self.output_size)

        # Treat a rank-2 input as a batch of length-1 sequences so one bmm covers both cases.
        is_flat = x.ndim == 2
        sequences = x.unsqueeze(1) if is_flat else x
        projected = torch.bmm(sequences, per_item_weight) + per_item_bias.unsqueeze(1)
        return projected.squeeze(1) if is_flat else projected
`sound_gen=True`.") @@ -464,9 +507,43 @@ def _unpack_sound_latents( unpacked.append(output) return unpacked + def _pack_action_latents( + self, + tokens_action: list[torch.Tensor], + token_shapes_action: list[tuple[int, int, int]], + domain_ids_action: list[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + """List of ``[T, D]`` tensors → packed ``[total_T, D]`` plus per-token domain ids.""" + packed: list[torch.Tensor] = [] + domain_ids: list[torch.Tensor] = [] + for action, shape, domain_id in zip(tokens_action, token_shapes_action, domain_ids_action): + token_count = shape[0] + packed.append(action[:token_count]) + domain_ids.append(domain_id.reshape(1).expand(token_count)) + return torch.cat(packed, dim=0), torch.cat(domain_ids, dim=0) + + def _unpack_action_latents( + self, + packed_preds: torch.Tensor, + token_shapes_action: list[tuple[int, int, int]], + noisy_frame_indexes_action: list[torch.Tensor], + ) -> list[torch.Tensor]: + """Packed ``[total_noisy_T, D]`` predictions → list of ``[T, D]`` tensors.""" + unpacked: list[torch.Tensor] = [] + start_idx = 0 + for shape, noisy_idxs in zip(token_shapes_action, noisy_frame_indexes_action): + T = shape[0] + output = torch.zeros((T, self.action_dim), device=packed_preds.device, dtype=packed_preds.dtype) + t_n = len(noisy_idxs) + if t_n > 0: + output[noisy_idxs] = packed_preds[start_idx : start_idx + t_n] + start_idx += t_n + unpacked.append(output) + return unpacked + # ------------------------------------------------------------------------- - # forward: full per-step pass — encode text/vision/sound → run layers → - # decode vision/sound. Pipeline calls this once per CFG pass. + # forward: full per-step pass — encode text/vision/sound/action → run layers → + # decode vision/sound/action. Pipeline calls this once per CFG pass. 
# ------------------------------------------------------------------------- def forward( @@ -488,7 +565,14 @@ def forward( sound_mse_loss_indexes: torch.Tensor | None = None, sound_timesteps: torch.Tensor | None = None, sound_noisy_frame_indexes: list[torch.Tensor] | None = None, - ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None]: + action_tokens: list[torch.Tensor] | None = None, + action_token_shapes: list[tuple[int, int, int]] | None = None, + action_sequence_indexes: torch.Tensor | None = None, + action_mse_loss_indexes: torch.Tensor | None = None, + action_timesteps: torch.Tensor | None = None, + action_noisy_frame_indexes: list[torch.Tensor] | None = None, + action_domain_ids: list[torch.Tensor] | None = None, + ) -> tuple[list[torch.Tensor], list[torch.Tensor] | None, list[torch.Tensor] | None]: """Run a full denoising-step forward pass. Args: @@ -511,10 +595,11 @@ def forward( sound_noisy_frame_indexes: Optional noisy frame indices per sound item. Returns: - ``(preds_vision, preds_sound)`` — list of per-modality latents (``preds_sound`` is ``None`` when the model - has no sound branch or sound inputs are omitted). + ``(preds_vision, preds_sound, preds_action)`` — lists of per-modality predictions. Optional modalities + return ``None`` when their inputs are omitted. """ has_sound = sound_tokens is not None and sound_sequence_indexes is not None + has_action = action_tokens is not None and action_sequence_indexes is not None # Embed text tokens into the joint hidden_states buffer at their sequence positions. packed_text_embedding = self.embed_tokens(input_ids) @@ -551,6 +636,27 @@ def forward( ) hidden_states[sound_sequence_indexes] = packed_tokens_sound + # Pack + project action latents (when present). Domain ids select the action head weights. 
+ if has_action: + packed_tokens_action, per_token_domain_ids = self._pack_action_latents( + action_tokens, action_token_shapes, action_domain_ids + ) + packed_tokens_action = packed_tokens_action.to(target_dtype) + per_token_domain_ids = per_token_domain_ids.to(device=packed_tokens_action.device) + packed_tokens_action = self.action_proj_in(packed_tokens_action, per_token_domain_ids) + packed_tokens_action = packed_tokens_action + self.action_modality_embed + if action_mse_loss_indexes.numel() > 0: + timesteps_action = action_timesteps * self.config.timestep_scale + packed_timestep_embeds_action = self.time_embedder(self.time_proj(timesteps_action)) + packed_timestep_embeds_action = packed_timestep_embeds_action.to(target_dtype) + packed_tokens_action = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens=packed_tokens_action, + packed_timestep_embeds=packed_timestep_embeds_action, + noisy_frame_indexes=action_noisy_frame_indexes, + token_shapes=action_token_shapes, + ) + hidden_states[action_sequence_indexes] = packed_tokens_action + # Compute rotary embeddings once for the joint sequence, then slice into und/gen halves. 
_meta_tensor = torch.tensor([], dtype=hidden_states.dtype, device=hidden_states.device) cos, sin = self.rotary_emb( @@ -590,4 +696,18 @@ def forward( preds_sound_packed = self.audio_proj_out(last_hidden_state[sound_mse_loss_indexes]) preds_sound = self._unpack_sound_latents(preds_sound_packed, sound_token_shapes, sound_noisy_frame_indexes) - return preds_vision, preds_sound + preds_action: list[torch.Tensor] | None = None + if has_action: + per_noisy_domain_ids = [ + domain_id.reshape(1).expand(len(noisy_idxs)) + for domain_id, noisy_idxs in zip(action_domain_ids, action_noisy_frame_indexes) + ] + per_noisy_domain_ids = torch.cat(per_noisy_domain_ids, dim=0).to(device=last_hidden_state.device) + preds_action_packed = self.action_proj_out( + last_hidden_state[action_mse_loss_indexes], per_noisy_domain_ids + ) + preds_action = self._unpack_action_latents( + preds_action_packed, action_token_shapes, action_noisy_frame_indexes + ) + + return preds_vision, preds_sound, preds_action diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7c38837f308d..89ebd8e186d8 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -215,6 +215,7 @@ "Cosmos2TextToImagePipeline", "Cosmos2VideoToWorldPipeline", "Cosmos3OmniPipeline", + "CosmosActionCondition", "CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline", ] @@ -655,6 +656,7 @@ Cosmos2TextToImagePipeline, Cosmos2VideoToWorldPipeline, Cosmos3OmniPipeline, + CosmosActionCondition, CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline, ) diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py index 0f828933be09..54d841f5b998 100644 --- a/src/diffusers/pipelines/cosmos/__init__.py +++ b/src/diffusers/pipelines/cosmos/__init__.py @@ -34,6 +34,7 @@ _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"] _import_structure["pipeline_cosmos3_omni"] = [ "Cosmos3OmniPipeline", + 
"CosmosActionCondition", ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -52,6 +53,7 @@ from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline from .pipeline_cosmos3_omni import ( Cosmos3OmniPipeline, + CosmosActionCondition, ) from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 7225cce6ac9b..5425b7b575eb 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -13,12 +13,14 @@ # limitations under the License. import copy +import json import math from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, Literal import numpy as np import torch +import torch.nn.functional as F from PIL import Image from transformers import AutoTokenizer, BatchEncoding @@ -29,12 +31,15 @@ Cosmos3OmniTransformer, ) from ...schedulers import UniPCMultistepScheduler -from ...utils import BaseOutput, is_cosmos_guardrail_available +from ...utils import BaseOutput, is_cosmos_guardrail_available, logging from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + if is_cosmos_guardrail_available(): from cosmos_guardrail import CosmosSafetyChecker else: @@ -130,6 +135,100 @@ def get_3d_mrope_ids_vae_tokens( _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." 
# Predefined conditioning canvases per action resolution tier. Outer keys are the tier as a string
# (`CosmosActionCondition` validates `str(resolution_tier)` against them). Inner keys are the stringified ratio
# `first / second` of the size tuple they map to.
# NOTE(review): whether the tuples are (width, height) or (height, width) is not established here; confirm against
# the code that consumes this table.
_ACTION_RESOLUTION_BINS = {
    "256": {
        "1.0": (256, 256),
        "0.8": (256, 320),
        "1.25": (320, 256),
        "0.6": (192, 320),
        "1.6666666666666667": (320, 192),
    },
    "480": {
        "1.0": (640, 640),
        "0.7391304347826086": (544, 736),
        "1.3529411764705883": (736, 544),
        "0.5769230769230769": (480, 832),
        "1.7333333333333334": (832, 480),
    },
    "704": {
        "1.0": (960, 960),
        "0.7647058823529411": (832, 1088),
        "1.3076923076923077": (1088, 832),
        "0.55": (704, 1280),
        "1.8181818181818181": (1280, 704),
    },
    "720": {
        "1.0": (960, 960),
        "0.7536231884057971": (832, 1104),
        "1.3269230769230769": (1104, 832),
        "0.5625": (720, 1280),
        "1.7777777777777777": (1280, 720),
    },
}

# Viewpoint -> framing sentence, used to fill the action JSON `cinematography.framing` field. The action model was
# trained with these exact sentences; `"ego_view"` is the default when no viewpoint is supplied.
_ACTION_VIEWPOINT_TEMPLATES = {
    "ego_view": "This video is captured from a first-person perspective looking at the scene.",
    "third_person_view": "This video is captured from a third-person perspective looking towards the agent from the front.",
    "wrist_view": "This video is captured from a wrist-mounted camera.",
    "concat_view": "This video contains concatenated views from multiple camera perspectives.",
}

# Embodiment domain name -> integer domain id. `CosmosActionCondition` rejects any `domain_name` that is not a key
# of this table. Some names map to the same id (`droid_lerobot` / `robomind-franka` -> 8, the three `agibot*`
# names -> 15); presumably they share one set of projection weights. TODO confirm.
_EMBODIMENT_TO_DOMAIN_ID = {
    "no_action": 0,
    "av": 1,
    "camera_pose": 2,
    "hand_pose": 3,
    "pusht": 4,
    "libero": 5,
    "umi": 6,
    "bridge_orig_lerobot": 7,
    "droid_lerobot": 8,
    "robomind-franka": 8,
    "galbot": 9,
    "robomind-franka-dual": 12,
    "robomind-ur": 13,
    "agibotworld": 15,
    "agibot_gear_gripper": 15,
    "agibot_gear_gripper_ext": 15,
    "fractal": 20,
}

# Canonical (unpadded) action width per embodiment. The width is fixed per embodiment and resolved from
# `domain_name` via this table.
#
# Widths come from the Cosmos 3 unified action representation (paper Fig. 3), which composes a few shared geometric
# building blocks: a 9D pose (3D translation + 6D rotation, the over-parameterized rotation of Zhou et al. 2019), a
# 1D grasp state (gripper open/close), and a 15D grasp state (fingertip positions, 3D x 5 fingers). Each embodiment
# concatenates these blocks, so its width is just their sum. For example:
#   * av / camera_pose                  -> 9  : a single ego/effector 9D pose.
#   * bridge / droid / fractal / umi    -> 10 : one arm = 9D effector pose + 1D gripper.
#   * robomind-franka-dual              -> 20 : two arms = 2 x (9D + 1D).
#   * agibotworld / agibot_gear_gripper -> 29 : humanoid = 9D ego + 2 x (9D arm + 1D gripper).
#   * galbot                            -> 30 : humanoid-style stack with an extra pose block.
#   * hand_pose                         -> 57 : egocentric two-hand motion = 9D ego + 2 x (9D wrist + 15D fingertips).
#
# Two keys of `_EMBODIMENT_TO_DOMAIN_ID` have no entry here: `no_action` and `libero`.
# TODO: support the configuration-dependent `libero` domain, whose width is not fixed per embodiment
# (it depends on the dataset's rotation/keypoint configuration) and so is absent here.
_EMBODIMENT_TO_RAW_ACTION_DIM = {
    "av": 9,
    "camera_pose": 9,
    "pusht": 2,
    "umi": 10,
    "bridge_orig_lerobot": 10,
    "droid_lerobot": 10,
    "robomind-franka": 10,
    "robomind-franka-dual": 20,
    "robomind-ur": 10,
    "galbot": 30,
    "agibotworld": 29,
    "agibot_gear_gripper": 29,
    "agibot_gear_gripper_ext": 29,
    "fractal": 10,
    "hand_pose": 57,
}
+ + Pass this to [`Cosmos3OmniPipeline.__call__`] via the `action` argument instead of the top-level `image` / `height` + / `width` arguments, which are reserved for t2v, i2v runs. + + Attributes: + mode (`str`): + The action task. One of `"forward_dynamics"` (roll out future video from a first frame and a given + `raw_actions` sequence), `"inverse_dynamics"` (infer the actions connecting the conditioning frames), or + `"policy"` (jointly roll out future video and actions from the first frame). + chunk_size (`int`): + Number of action transition steps in the chunk. The paired conditioning video spans `chunk_size + 1` + frames. + domain_name (`str`): + Embodiment domain selecting the domain-aware action projection weights. Must be one of the registered + Cosmos 3 embodiment domains. It also fixes the unpadded action width used to slice predicted actions, + resolved internally from this name (see `_EMBODIMENT_TO_RAW_ACTION_DIM`). + resolution_tier (`int`, defaults to `480`): + Action conditioning resolution *tier* (one of `256`, `480`, `704`, `720`). The tier picks a predefined + canvas whose aspect ratio is closest to the input; the input is downscaled (never upscaled) and padded into + it for conditioning. This is not the output frame size, which tracks the input content. Match the tier to + the input's native resolution: a lower tier discards detail, while a higher tier adds no resolution (no + upscaling), wastes compute on padding, and is a train/inference mismatch that can hurt quality. + raw_actions (`torch.Tensor`, *optional*): + Raw domain action vectors of shape `[T, raw_action_dim]` driving `"forward_dynamics"`. Sequences shorter + than `chunk_size` repeat the last action; longer ones are truncated. Channels beyond the model's + `action_dim` are rejected, and narrower inputs are zero-padded up to `action_dim`. + image (`PIL.Image.Image`, `np.ndarray`, or `torch.Tensor`, *optional*): + Conditioning frame for `"policy"` / `"forward_dynamics"`. 
Mutually exclusive with `video`. + video (`list`, `np.ndarray`, or `torch.Tensor`, *optional*): + Conditioning video, required for `"inverse_dynamics"`. For `"policy"` / `"forward_dynamics"` only its first + frame is used. Mutually exclusive with `image`. + view_point (`str`, defaults to `"ego_view"`): + Camera perspective label used to populate the action caption's `cinematography.framing` field. One of + `"ego_view"`, `"third_person_view"`, `"wrist_view"`, or `"concat_view"`. The action model was trained on + structured JSON captions that carry this viewpoint sentence; an unrecognized label drops the framing field + (with a warning). + """ + + mode: Literal["policy", "forward_dynamics", "inverse_dynamics"] + chunk_size: int + domain_name: str + resolution_tier: int = 480 + raw_actions: torch.Tensor | None = None + image: Image.Image | np.ndarray | torch.Tensor | None = None + video: list | np.ndarray | torch.Tensor | None = None + view_point: str = "ego_view" + + def __post_init__(self) -> None: + """Validate self-contained action fields at construction time.""" + if self.mode not in ["policy", "forward_dynamics", "inverse_dynamics"]: + raise ValueError( + f"Unsupported action mode={self.mode!r}; expected one of ['forward_dynamics', 'inverse_dynamics', 'policy']." + ) + if self.chunk_size < 1: + raise ValueError(f"action `chunk_size` must be >= 1, got {self.chunk_size}.") + if self.domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={self.domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + if str(self.resolution_tier) not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Unsupported action resolution_tier={self.resolution_tier!r}; " + f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." 
+ ) + if self.image is not None and self.video is not None: + raise ValueError("Provide either `image` or `video` for the action condition, not both.") + elif self.image is None and self.video is None: + raise ValueError("`image` and `video` cannot both be None") + if self.mode == "inverse_dynamics" and self.video is None: + raise ValueError("action mode='inverse_dynamics' requires `video` conditioning.") + # Resolve the unpadded action width from the embodiment: the width is fixed per embodiment and looked up from + # the table. Domains absent from the table are unsupported for action inference in all modes. + # TODO: support the configuration-dependent domains (libero, hand_pose), whose width is set per-dataset. + if self.domain_name not in _EMBODIMENT_TO_RAW_ACTION_DIM: + raise ValueError( + f"domain_name={self.domain_name!r} is not supported for action inference: it has no canonical action " + f"width. Supported domains: {sorted(_EMBODIMENT_TO_RAW_ACTION_DIM)}." + ) + self.raw_action_dim = _EMBODIMENT_TO_RAW_ACTION_DIM[self.domain_name] + if self.mode == "forward_dynamics": + if self.raw_actions is None: + raise ValueError("action mode='forward_dynamics' requires `raw_actions`.") + if self.raw_actions.ndim != 2: + raise ValueError(f"`raw_actions` must have shape [T, D], got {tuple(self.raw_actions.shape)}.") + if self.raw_actions.shape[0] < 1: + raise ValueError("action mode='forward_dynamics' requires at least one action token.") + # The supplied action width must match the embodiment's expected width. + if self.raw_actions.shape[1] != self.raw_action_dim: + raise ValueError( + f"`raw_actions` width ({self.raw_actions.shape[1]}) does not match the expected action width " + f"({self.raw_action_dim}) for domain_name={self.domain_name!r}." 
+ ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents @@ -308,6 +507,7 @@ def _prepare_vision_segment( vision_fps: float | None, curr: int, device: torch.device | str, + condition_frame_indexes: list[int] | None = None, ) -> dict[str, Any]: """Build the static portion of the vision segment of the joint sequence. @@ -322,12 +522,16 @@ def _prepare_vision_segment( patch_w = math.ceil(latent_w / latent_patch_size) num_vision_tokens = latent_t * patch_h * patch_w - noisy_start = 1 if has_image_condition else 0 - noisy_frame_indexes = torch.arange(noisy_start, latent_t, device=device, dtype=torch.long) + if condition_frame_indexes is None: + condition_frame_indexes = [0] if has_image_condition else [] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < latent_t} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in cond_frames], device=device, dtype=torch.long + ) frame_token_stride = patch_h * patch_w mse_loss_indexes: list[int] = [] - for frame_idx in range(noisy_start, latent_t): + for frame_idx in noisy_frame_indexes.tolist(): frame_start = curr + frame_idx * frame_token_stride mse_loss_indexes.extend(range(frame_start, frame_start + frame_token_stride)) @@ -352,7 +556,7 @@ def _prepare_vision_segment( # Assembly helpers (consumed inline before the transformer call). 
"vision_mrope_ids": vision_mrope_ids.to(device), "num_vision_tokens": num_vision_tokens, - "num_noisy_vision_tokens": (latent_t - noisy_start) * frame_token_stride, + "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, } def _prepare_sound_segment( @@ -396,37 +600,143 @@ def _prepare_sound_segment( "sound_len": sound_len, } + def _prepare_action_segment( + self, + input_action_tokens: torch.Tensor, + condition_frame_indexes: list[int], + mrope_offset: int | float, + action_fps: float | None, + curr: int, + device: torch.device | str, + ) -> dict[str, Any]: + """Build the static action segment; per-step tokens/timesteps are spliced in the denoising loop.""" + config = self.transformer.config + action_len = input_action_tokens.shape[0] + cond_frames = {idx for idx in condition_frame_indexes if 0 <= idx < action_len} + noisy_frame_indexes = torch.tensor( + [idx for idx in range(action_len) if idx not in cond_frames], device=device, dtype=torch.long + ) + + effective_fps = action_fps if config.enable_fps_modulation else None + action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=action_len, + grid_h=1, + grid_w=1, + temporal_offset=mrope_offset, + reset_spatial_indices=config.unified_3d_mrope_reset_spatial_ids, + fps=effective_fps, + base_fps=float(config.base_fps), + temporal_compression_factor=1, + base_temporal_compression_factor=self.vae.config.scale_factor_temporal, + start_frame_offset=1, + ) + + sequence_indexes = torch.arange(curr, curr + action_len, dtype=torch.long, device=device) + return { + "action_token_shapes": [(action_len, 1, 1)], + "action_sequence_indexes": sequence_indexes, + "action_mse_loss_indexes": sequence_indexes[noisy_frame_indexes], + "action_noisy_frame_indexes": [noisy_frame_indexes], + "action_mrope_ids": action_mrope_ids.to(device), + "action_len": action_len, + "num_noisy_action_tokens": len(noisy_frame_indexes), + } + + def _prepare_action_video_conditioning( + self, + conditioning_clip: Any, + 
resolution_tier: int, + num_frames: int, + device: torch.device | str, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, int, int]: + frames = self.video_processor.preprocess_video(conditioning_clip).to(device=device, dtype=dtype) + source_h, source_w = frames.shape[-2:] + resolution_key = str(resolution_tier) + if resolution_key not in _ACTION_RESOLUTION_BINS: + raise ValueError( + f"Unsupported action resolution_tier={resolution_tier!r}; " + f"expected one of {sorted(int(k) for k in _ACTION_RESOLUTION_BINS)}." + ) + target_h, target_w = VideoProcessor.classify_height_width_bin( + source_h, source_w, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) + + if frames.shape[2] < num_frames: + frames = torch.cat([frames, frames[:, :, -1:].expand(-1, -1, num_frames - frames.shape[2], -1, -1)], dim=2) + else: + frames = frames[:, :, :num_frames] + + _, _, _, frame_h, frame_w = frames.shape + scale = min(target_w / frame_w, target_h / frame_h, 1.0) + content_h = max(1, int(scale * frame_h + 0.5)) + content_w = max(1, int(scale * frame_w + 0.5)) + + frames_t = frames.permute(0, 2, 1, 3, 4).reshape(-1, frames.shape[1], frame_h, frame_w) + if content_h != frame_h or content_w != frame_w: + frames_t = F.interpolate( + frames_t, + size=(content_h, content_w), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pad_right = target_w - content_w + pad_bottom = target_h - content_h + if pad_right or pad_bottom: + pad_mode = "replicate" if pad_right >= content_w or pad_bottom >= content_h else "reflect" + frames_t = F.pad(frames_t, (0, pad_right, 0, pad_bottom), mode=pad_mode) + frames = frames_t.reshape(frames.shape[0], num_frames, frames.shape[1], target_h, target_w).permute( + 0, 2, 1, 3, 4 + ) + image_size = torch.tensor([target_h, target_w, content_h, content_w], device=device, dtype=torch.float32) + return frames.to(dtype=dtype), image_size, target_h, target_w + + def _remove_action_video_padding_from_latent( + self, latents: torch.Tensor, 
image_size: torch.Tensor + ) -> torch.Tensor: + content_h = int(image_size[2].item()) + content_w = int(image_size[3].item()) + content_h_latent = max(content_h // self.vae_scale_factor_spatial, 1) + content_w_latent = max(content_w // self.vae_scale_factor_spatial, 1) + return latents[:, :, :, :content_h_latent, :content_w_latent].contiguous() + def prepare_latents( self, image: torch.Tensor | None = None, - num_frames: int = 189, - height: int = 720, - width: int = 1280, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, fps: float = 24.0, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, generator: torch.Generator | None = None, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, enable_sound: bool = False, + action: "CosmosActionCondition | None" = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, + torch.Tensor | None, float, float | None, torch.Tensor, torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + int | None, ]: """Build conditioning + initial noise for a single sample. Returns: - ``(vision_latents, sound_latents, fps_vision, fps_sound)``. ``vision_latents`` is the noisy vision tensor; - ``sound_latents`` is the noisy sound tensor (``None`` unless ``enable_sound`` was set). The FPS scalars - feed the per-step :meth:`_prepare_vision_segment` / :meth:`_prepare_sound_segment` calls in the denoising - loop. + Initial noisy tensors plus condition masks/metadata for vision, sound, and optional action modalities. """ + action_mode = action.mode if action is not None else None is_image = num_frames == 1 - has_image_condition = image is not None and not is_image + has_image_condition = (image is not None and not is_image) or action_mode is not None # video_processor.preprocess handles PIL/np/tensor → [1, 3, H, W] in [-1, 1], resized to (height, width). 
conditioning_frame_2d: torch.Tensor | None = None @@ -435,8 +745,43 @@ def prepare_latents( device=device, dtype=dtype ) + action_domain_id: torch.Tensor | None = None + action_condition_mask: torch.Tensor | None = None + raw_action_dim_resolved: int | None = ( + int(action.raw_action_dim) if action is not None and action.raw_action_dim is not None else None + ) + if raw_action_dim_resolved is not None and raw_action_dim_resolved > self.transformer.config.action_dim: + raise ValueError( + f"raw_action_dim={raw_action_dim_resolved} exceeds the model's trained action_dim=" + f"{self.transformer.config.action_dim}; this checkpoint cannot represent that action width." + ) + action_condition_frames: list[int] = [] + action_condition_frame_indexes: list[int] = [] + action_image_size: torch.Tensor | None = None + vision_condition_frames: list[int] | None = None + # Build the vision conditioning tensor (always [1, 3, T, H, W], in [-1, 1], on device). - if is_image: + if action is not None: + target_frames = action.chunk_size + 1 + conditioning_clip = [action.image] if action.image is not None else action.video + vision_tensor, action_image_size, height, width = self._prepare_action_video_conditioning( + conditioning_clip, action.resolution_tier, target_frames, device=device, dtype=dtype + ) + if action_mode == "forward_dynamics": + vision_condition_frames = [0] + action_condition_frames = list(range(action.chunk_size)) + elif action_mode == "policy": + vision_condition_frames = [0] + elif action_mode == "inverse_dynamics": + latent_frames = (target_frames - 1) // self.vae.config.scale_factor_temporal + 1 + vision_condition_frames = list(range(latent_frames)) + else: + raise ValueError( + f"Unsupported action_mode={action_mode!r}; expected one of " + "['forward_dynamics', 'inverse_dynamics', 'policy']." 
+ ) + action_condition_frame_indexes = action_condition_frames + elif is_image: vision_tensor = ( conditioning_frame_2d.unsqueeze(2) # [1, 3, 1, H, W] if conditioning_frame_2d is not None @@ -451,6 +796,8 @@ def prepare_latents( vision_tensor[:, :, 1:] = conditioning_frame_2d.unsqueeze(2).expand(-1, -1, num_frames - 1, -1, -1) x0_tokens_vision = self._encode_video(vision_tensor).contiguous().float() + if action_image_size is not None: + x0_tokens_vision = self._remove_action_video_padding_from_latent(x0_tokens_vision, action_image_size) vision_shape = tuple(x0_tokens_vision.shape) x0_tokens_sound: torch.Tensor | None = None @@ -463,9 +810,55 @@ def prepare_latents( T_sound = (n_audio_samples + hop_size - 1) // hop_size x0_tokens_sound = torch.zeros(sound_dim, T_sound, device=device, dtype=dtype) + x0_tokens_action: torch.Tensor | None = None + if action is not None: + action_chunk_size = action.chunk_size + action_dim = self.transformer.action_dim + if action_mode == "forward_dynamics": + raw_actions = action.raw_actions + if raw_actions is None: + raise ValueError("action_mode='forward_dynamics' requires an action tensor.") + raw_actions = raw_actions.to(device=device, dtype=dtype) + + # Action chunks describe transitions, so action length must match action_chunk_size + # while the paired video has action_chunk_size + 1 frames. Short inputs repeat the last action. + if raw_actions.shape[0] < action_chunk_size: + raw_actions = torch.cat( + [raw_actions, raw_actions[-1:].expand(action_chunk_size - raw_actions.shape[0], -1)], + dim=0, + ) + raw_actions = raw_actions[:action_chunk_size] + + # The model action head has a fixed action_dim; pad raw domain actions with zeros on the channel axis. 
+ if raw_actions.shape[-1] < action_dim: + action_padding = torch.zeros( + raw_actions.shape[0], + action_dim - raw_actions.shape[-1], + dtype=raw_actions.dtype, + device=raw_actions.device, + ) + raw_actions = torch.cat([raw_actions, action_padding], dim=-1) + x0_tokens_action = raw_actions + else: + x0_tokens_action = torch.zeros(action_chunk_size, action_dim, device=device, dtype=dtype) + if action.domain_name not in _EMBODIMENT_TO_DOMAIN_ID: + raise ValueError( + f"Unknown Cosmos3 action domain_name={action.domain_name!r}; " + f"expected one of {sorted(_EMBODIMENT_TO_DOMAIN_ID)}." + ) + action_domain_id = torch.tensor( + [_EMBODIMENT_TO_DOMAIN_ID[action.domain_name]], + dtype=torch.long, + device=device, + ) + # Vision conditioning mask [latent_t, 1, 1]: frame 0 anchored when image-conditioning, rest noisy. vision_condition_mask = torch.zeros((x0_tokens_vision.shape[2], 1, 1), device=device, dtype=dtype) - if has_image_condition: + if vision_condition_frames is not None: + for frame_idx in vision_condition_frames: + if 0 <= frame_idx < vision_condition_mask.shape[0]: + vision_condition_mask[frame_idx, 0, 0] = 1.0 + elif has_image_condition: vision_condition_mask[0, 0, 0] = 1.0 if latents is None: @@ -491,17 +884,50 @@ def prepare_latents( else: sound_latents = sound_latents.to(device=device, dtype=dtype) - return latents, sound_latents, fps, fps_sound, vision_condition_mask, sound_condition_mask + if action_mode is not None and x0_tokens_action is not None: + action_condition_mask = torch.zeros((x0_tokens_action.shape[0], 1), device=device, dtype=dtype) + for frame_idx in action_condition_frames: + if 0 <= frame_idx < action_condition_mask.shape[0]: + action_condition_mask[frame_idx, 0] = 1.0 + if action_latents is None: + pure_noise_action = randn_tensor( + tuple(x0_tokens_action.shape), generator=generator, device=device, dtype=dtype + ) + action_latents = ( + action_condition_mask * x0_tokens_action + (1.0 - action_condition_mask) * pure_noise_action + ) + 
if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + else: + action_latents = action_latents.to(device=device, dtype=dtype) + + return ( + latents, + sound_latents, + action_latents, + fps, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) def check_inputs( self, prompt, negative_prompt, - height: int, - width: int, - num_frames: int, + image, + height: int | None, + width: int | None, + num_frames: int | None, + guidance_scale: float, enable_sound: bool, callback_on_step_end_tensor_inputs: list[str], + action: "CosmosActionCondition | None" = None, ) -> None: if not isinstance(prompt, (str, list)) or ( isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt) @@ -511,11 +937,6 @@ def check_inputs( raise ValueError( f"`negative_prompt` must be a str, list of str, or None, got {type(negative_prompt).__name__}." ) - if num_frames < 1: - raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") - sf = int(self.vae.config.scale_factor_spatial) - if height % sf != 0 or width % sf != 0: - raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") if enable_sound: if self.sound_tokenizer is None: raise ValueError("`enable_sound=True` requires a sound-capable checkpoint with a `sound_tokenizer`.") @@ -527,6 +948,76 @@ def check_inputs( f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) + if action is not None: + # API-conflict + model-dependent checks live here. 
+ if num_frames is not None: + raise ValueError("`num_frames` has to be None if action is not None") + if height is not None or width is not None: + raise ValueError("`height` and `width` have to be None if action is not None") + if image is not None: + raise ValueError( + "Pass action conditioning via `action.image` / `action.video`, not the top-level `image` argument." + ) + if not getattr(self.transformer.config, "action_gen", False): + raise ValueError("`action` requires a transformer trained with action_gen=True.") + if action.mode == "forward_dynamics" and action.raw_actions is not None: + if action.raw_actions.shape[-1] > self.transformer.config.action_dim: + raise ValueError( + f"Cosmos3 action dimension {action.raw_actions.shape[-1]} exceeds model action_dim=" + f"{self.transformer.config.action_dim}." + ) + else: + if num_frames is None: + raise ValueError("`num_frames` must be provided when `action` is None.") + if height is None or width is None: + raise ValueError("`height` and `width` must be provided when `action` is None.") + if num_frames < 1: + raise ValueError(f"`num_frames` must be >= 1, got {num_frames}.") + sf = int(self.vae.config.scale_factor_spatial) + if height % sf != 0 or width % sf != 0: + raise ValueError(f"`height` and `width` must be multiples of {sf}, got ({height}, {width}).") + + @staticmethod + def _build_action_json_prompt( + description: str, + *, + view_point: str | None, + num_frames: int, + fps: float, + height: int, + width: int, + ) -> str: + """Build the structured action caption the model was trained on, then serialize it to a JSON string.""" + duration_seconds = num_frames / fps if fps > 0 else 0.0 + duration = int(duration_seconds) if duration_seconds >= 0 and math.isfinite(duration_seconds) else 0 + action_end = round(duration_seconds) if duration_seconds >= 0 and math.isfinite(duration_seconds) else 0 + minutes, seconds = divmod(action_end, 60) + + desc = description.strip() + if desc and not desc.endswith((".", "!", 
"?")): + desc = f"{desc}." + + prompt: dict[str, Any] = {} + framing = _ACTION_VIEWPOINT_TEMPLATES.get(view_point) if view_point is not None else None + if view_point is not None and framing is None: + logger.warning( + f"Unrecognized action view_point={view_point!r}; known viewpoints: " + f"{sorted(_ACTION_VIEWPOINT_TEMPLATES)}. Dropping the cinematography.framing field." + ) + if framing: + prompt["cinematography"] = {"framing": framing} + ratio = width / height if height > 0 else 1.0 + aspect_ratio = min( + ("1,1", "4,3", "3,4", "16,9", "9,16"), + key=lambda r: abs(int(r.split(",")[0]) / int(r.split(",")[1]) - ratio), + ) + prompt["actions"] = [{"time": f"0:00-{minutes}:{seconds:02d}", "description": desc}] + prompt["duration"] = f"{duration}s" + prompt["fps"] = float(fps) + prompt["resolution"] = {"H": int(height), "W": int(width)} + prompt["aspect_ratio"] = aspect_ratio + return json.dumps(prompt) + def tokenize_prompt( self, prompt: str, @@ -538,6 +1029,8 @@ def tokenize_prompt( use_system_prompt: bool = True, add_resolution_template: bool = True, add_duration_template: bool = True, + action_mode: str | None = None, + action_view_point: str | None = None, ) -> tuple[list[int], list[int]]: """Apply prompt-augmentation templates and tokenize cond/uncond prompts via the Qwen2 chat template. @@ -548,6 +1041,10 @@ def tokenize_prompt( quality-control negative prompts to pass explicitly for text2video / image2video. The duration and resolution templates are appended to the prompt, and inverse templates are appended to the negative prompt, when enabled. + When ``action_mode`` is set, the prompt is instead converted to the structured action JSON caption the model + was trained on (see :meth:`_build_action_json_prompt`), using ``action_view_point`` for the framing field; the + flat metadata templates are skipped because the JSON already carries duration/fps/resolution/aspect_ratio. + Returns: ``(cond_input_ids, uncond_input_ids)`` — token-id lists for this sample. 
""" @@ -594,9 +1091,18 @@ def _add_special_tokens(input_ids: list[int]) -> list[int]: self.llm_special_tokens["start_of_generation"], ] - cond_encodings = _tokenize(_apply_templates(prompt)) + if action_mode is not None: + cond_text = self._build_action_json_prompt( + prompt, view_point=action_view_point, num_frames=num_frames, fps=fps, height=height, width=width + ) + uncond_text = negative_prompt + else: + cond_text = _apply_templates(prompt) + uncond_text = _apply_templates(negative_prompt, is_negative=True) + + cond_encodings = _tokenize(cond_text) cond_input_ids = _add_special_tokens(cond_encodings.input_ids) - uncond_encodings = _tokenize(_apply_templates(negative_prompt, is_negative=True)) + uncond_encodings = _tokenize(uncond_text) uncond_input_ids = _add_special_tokens(uncond_encodings.input_ids) return cond_input_ids, uncond_input_ids @@ -606,7 +1112,10 @@ def _mask_velocity_predictions( preds_sound: list[torch.Tensor] | None, vision_condition_mask: list[torch.Tensor], sound_condition_mask: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + preds_action: list[torch.Tensor] | None = None, + action_condition_mask: list[torch.Tensor] | None = None, + raw_action_dim: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """Zero out conditioning positions in the transformer's velocity predictions. 
``preds_vision`` / ``preds_sound`` are returned per-sample by the transformer; the pipeline runs batch=1, so we @@ -625,7 +1134,16 @@ def _mask_velocity_predictions( noisy_mask_s = (1.0 - cond_mask_s).T.to(dtype=pred_s.dtype, device=pred_s.device) velocity_sound = pred_s * noisy_mask_s if noisy_mask_s.sum() > 0 else torch.zeros_like(pred_s) - return velocity_vision, velocity_sound + velocity_action: torch.Tensor | None = None + if preds_action is not None and action_condition_mask is not None: + pred_a = preds_action[0] + cond_mask_a = action_condition_mask[0] + noisy_mask_a = (1.0 - cond_mask_a).to(dtype=pred_a.dtype, device=pred_a.device) + velocity_action = pred_a * noisy_mask_a if noisy_mask_a.sum() > 0 else torch.zeros_like(pred_a) + if raw_action_dim is not None: + velocity_action[:, raw_action_dim:] = 0 + + return velocity_vision, velocity_sound, velocity_action def _apply_video_safety_check(self, video: Any, output_type: str, device: torch.device) -> Any: """Run the Cosmos video guardrail on a postprocessed video and return it in the same format. 
@@ -670,15 +1188,19 @@ def current_timestep(self): def interrupt(self): return self._interrupt + @property + def do_classifier_free_guidance(self): + return self._guidance_scale != 1.0 + @torch.no_grad() def __call__( self, prompt: str | list[str], negative_prompt: str | list[str] | None = None, image: torch.Tensor | None = None, - num_frames: int = 189, - height: int = 720, - width: int = 1280, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, fps: float = 24.0, num_inference_steps: int = 35, guidance_scale: float = 6.0, @@ -686,6 +1208,8 @@ def __call__( generator: torch.Generator | None = None, latents: torch.Tensor | None = None, sound_latents: torch.Tensor | None = None, + action_latents: torch.Tensor | None = None, + action: CosmosActionCondition | None = None, output_type: str = "pil", return_dict: bool = True, use_system_prompt: bool = True, @@ -710,13 +1234,17 @@ def __call__( The negative prompt used for classifier-free guidance. When `None`, the empty string is used. image (`torch.Tensor` or `PIL.Image.Image`, *optional*): Optional conditioning frame for image-to-video. The pipeline anchors frame 0 to this image and denoises - the remaining frames. Ignored when `num_frames == 1`. - num_frames (`int`, *optional*, defaults to `189`): - Number of frames to generate. Use `1` for text-to-image; the default produces ≈ 7.9 s at 24 FPS. - height (`int`, *optional*, defaults to `720`): - Output height in pixels. - width (`int`, *optional*, defaults to `1280`): - Output width in pixels. + the remaining frames. Ignored when `num_frames == 1`. Not used for action runs (pass `action` instead). + num_frames (`int`, *optional*, defaults to `None`): + Number of frames to generate. Use `1` for text-to-image. Defaults to `189` (≈ 7.9 s at 24 FPS) for + non-action modes when omitted (`None`). Must be `None` for action runs, where frame count is derived + from `action.chunk_size + 1`. 
+ height (`int`, *optional*, defaults to `None`): + Output height in pixels. Defaults to `720` for non-action modes when omitted (`None`). Must be `None` + for action runs, which size via `action.resolution_tier`. + width (`int`, *optional*, defaults to `None`): + Output width in pixels. Defaults to `1280` for non-action modes when omitted (`None`). Must be `None` + for action runs, which size via `action.resolution_tier`. fps (`float`, *optional*, defaults to `24.0`): Target frame rate, also injected into the mRoPE temporal modulation and into the duration metadata template. @@ -735,6 +1263,15 @@ def __call__( sound_latents (`torch.Tensor`, *optional*): Pre-generated sound latents to start denoising from. Only consulted when `enable_sound=True`; when `None`, fresh Gaussian noise is sampled. + action_latents (`torch.Tensor`, *optional*): + Pre-generated action latents to start the action stream's denoising from. Only consulted when an action + run is configured via `action`; when `None`, fresh Gaussian noise is sampled for the action tokens. + action (`CosmosActionCondition`, *optional*): + Bundles every input for an action-conditioned run (mode, chunk size, embodiment domain, resolution + tier, raw actions, and the conditioning image/video), and requires a transformer trained with + `action_gen=True`. When set, passing the top-level `image` argument raises; `height` / `width` / + `num_frames` must be `None`, since resolution comes from `action.resolution_tier` and frame count from + `action.chunk_size`. See [`CosmosActionCondition`]. output_type (`str`, *optional*, defaults to `"pil"`): Output format for the video. One of `"pil"` (list of `PIL.Image.Image`), `"np"` (`np.ndarray`, `[T, H, W, C]`), `"pt"` (`torch.Tensor`, `[T, C, H, W]`), or `"latent"` (raw vision latents). 
@@ -770,13 +1307,47 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if action is None: + if num_frames is None: + num_frames = 189 + if height is None: + height = 720 + if width is None: + width = 1280 + # 1. Check inputs self.check_inputs( - prompt, negative_prompt, height, width, num_frames, enable_sound, callback_on_step_end_tensor_inputs + prompt, + negative_prompt, + image, + height, + width, + num_frames, + guidance_scale, + enable_sound, + callback_on_step_end_tensor_inputs, + action, ) + # `action_mode` is the only action field consumed directly in __call__ (prompt template + output slicing); + # all other action fields are read from `action` at their point of use (e.g. in prepare_latents). + action_mode = action.mode if action is not None else None + + if action is not None: + num_frames = action.chunk_size + 1 + # Resolve the padded conditioning canvas from the tier + input aspect *before* tokenization, so the + # resolution prompt template matches the canvas the model is actually conditioned on. + conditioning_clip = [action.image] if action.image is not None else action.video + probe = self.video_processor.preprocess_video(conditioning_clip) + source_h, source_w = int(probe.shape[-2]), int(probe.shape[-1]) + resolution_key = str(action.resolution_tier) + height, width = VideoProcessor.classify_height_width_bin( + source_h, source_w, ratios=_ACTION_RESOLUTION_BINS[resolution_key] + ) + self._current_timestep = None self._interrupt = False + self._guidance_scale = guidance_scale # Pipeline supports a single sample at a time; collapse list-style inputs to a single string. 
if isinstance(prompt, list): @@ -809,6 +1380,8 @@ def __call__( use_system_prompt=use_system_prompt, add_resolution_template=add_resolution_template, add_duration_template=add_duration_template, + action_mode=action_mode, + action_view_point=action.view_point if action is not None else None, ) # 3. Pre-pack the text segment for each prompt — text packing is invariant @@ -817,22 +1390,37 @@ def __call__( uncond_text_segment = self._prepare_text_segment(uncond_input_ids, device=device) # 4. Prepare latents (initial noise per modality + pack metadata) - has_image_condition = image is not None and num_frames > 1 - latents, sound_latents, fps_vision, fps_sound, vision_condition_mask, sound_condition_mask = ( - self.prepare_latents( - image=image, - num_frames=num_frames, - height=height, - width=width, - fps=fps, - latents=latents, - sound_latents=sound_latents, - generator=generator, - device=device, - dtype=dtype, - enable_sound=enable_sound, - ) + ( + latents, + sound_latents, + action_latents, + fps_vision, + fps_sound, + vision_condition_mask, + sound_condition_mask, + action_condition_mask, + action_domain_id, + action_image_size, + raw_action_dim_resolved, + action_condition_frame_indexes, + ) = self.prepare_latents( + image=image, + num_frames=num_frames, + height=height, + width=width, + fps=fps, + latents=latents, + sound_latents=sound_latents, + action_latents=action_latents, + generator=generator, + device=device, + dtype=dtype, + enable_sound=enable_sound, + action=action, ) + vision_condition_indexes_for_pack = torch.nonzero(vision_condition_mask[:, 0, 0] > 0, as_tuple=False).flatten() + vision_condition_indexes_for_pack = [int(idx.item()) for idx in vision_condition_indexes_for_pack] + has_image_condition = bool(vision_condition_indexes_for_pack) # 5. Pre-pack the static per-prompt vision / sound sequence segments. 
The only # fields that vary across denoising steps are the modality token tensors and the @@ -846,6 +1434,7 @@ def __call__( vision_fps=fps_vision, curr=cond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) cond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -856,17 +1445,33 @@ def __call__( curr=cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"], device=device, ) + cond_action_segment: dict[str, Any] = {} + if action_latents is not None: + cond_action_segment = self._prepare_action_segment( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=cond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=cond_text_segment["und_len"] + + cond_vision_segment["num_vision_tokens"] + + cond_sound_segment.get("sound_len", 0), + device=device, + ) cond_mrope_segments = [cond_text_segment["text_mrope_ids"], cond_vision_segment["vision_mrope_ids"]] if cond_sound_segment: cond_mrope_segments.append(cond_sound_segment["sound_mrope_ids"]) + if cond_action_segment: + cond_mrope_segments.append(cond_action_segment["action_mrope_ids"]) cond_packed_static = { **cond_text_segment, **cond_vision_segment, **cond_sound_segment, + **cond_action_segment, "position_ids": torch.cat(cond_mrope_segments, dim=1), "sequence_length": cond_text_segment["und_len"] + cond_vision_segment["num_vision_tokens"] - + cond_sound_segment.get("sound_len", 0), + + cond_sound_segment.get("sound_len", 0) + + cond_action_segment.get("action_len", 0), } uncond_vision_segment = self._prepare_vision_segment( @@ -876,6 +1481,7 @@ def __call__( vision_fps=fps_vision, curr=uncond_text_segment["und_len"], device=device, + condition_frame_indexes=vision_condition_indexes_for_pack, ) uncond_sound_segment: dict[str, Any] = {} if sound_latents is not None: @@ -886,26 +1492,44 @@ def __call__( curr=uncond_text_segment["und_len"] + 
uncond_vision_segment["num_vision_tokens"], device=device, ) + uncond_action_segment: dict[str, Any] = {} + if action_latents is not None: + uncond_action_segment = self._prepare_action_segment( + input_action_tokens=action_latents, + condition_frame_indexes=action_condition_frame_indexes, + mrope_offset=uncond_text_segment["vision_start_temporal_offset"], + action_fps=fps_vision, + curr=uncond_text_segment["und_len"] + + uncond_vision_segment["num_vision_tokens"] + + uncond_sound_segment.get("sound_len", 0), + device=device, + ) uncond_mrope_segments = [uncond_text_segment["text_mrope_ids"], uncond_vision_segment["vision_mrope_ids"]] if uncond_sound_segment: uncond_mrope_segments.append(uncond_sound_segment["sound_mrope_ids"]) + if uncond_action_segment: + uncond_mrope_segments.append(uncond_action_segment["action_mrope_ids"]) uncond_packed_static = { **uncond_text_segment, **uncond_vision_segment, **uncond_sound_segment, + **uncond_action_segment, "position_ids": torch.cat(uncond_mrope_segments, dim=1), "sequence_length": uncond_text_segment["und_len"] + uncond_vision_segment["num_vision_tokens"] - + uncond_sound_segment.get("sound_len", 0), + + uncond_sound_segment.get("sound_len", 0) + + uncond_action_segment.get("action_len", 0), } num_noisy_vision_tokens = cond_vision_segment["num_noisy_vision_tokens"] sound_len = cond_sound_segment.get("sound_len") + action_noisy_len = cond_action_segment.get("num_noisy_action_tokens") # 6. Set timesteps. UniPCMultistepScheduler keeps per-step state (_step_index, - # model_outputs history) on the instance, so audio gets its own copy. + # model_outputs history) on the instance, so sound/action each get their own copy. self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps sound_scheduler = copy.deepcopy(self.scheduler) if sound_latents is not None else None + action_scheduler = copy.deepcopy(self.scheduler) if action_latents is not None else None # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -922,15 +1546,19 @@ def __call__( # noisy tokens before packing so the modality tokens enter the model in the right dtype. vision_tokens = latents.to(device=device, dtype=dtype) sound_tokens = sound_latents.to(device=device, dtype=dtype) if sound_latents is not None else None + action_tokens = action_latents.to(device=device, dtype=dtype) if action_latents is not None else None # The static packs both report the same num_noisy_vision_tokens / sound_len, so a # single per-step timestep tensor per modality is shared by the cond / uncond passes. vision_timesteps = torch.full((num_noisy_vision_tokens,), timestep, device=device) sound_timesteps = ( torch.full((sound_len,), timestep, device=device) if sound_tokens is not None else None ) + action_timesteps = ( + torch.full((action_noisy_len,), timestep, device=device) if action_tokens is not None else None + ) # --- Conditional pass --- - preds_vision, preds_sound = self.transformer( + preds_vision, preds_sound, preds_action = self.transformer( input_ids=cond_packed_static["input_ids"], text_indexes=cond_packed_static["text_indexes"], position_ids=cond_packed_static["position_ids"], @@ -948,17 +1576,28 @@ def __call__( sound_mse_loss_indexes=cond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=cond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=cond_packed_static.get("action_token_shapes"), + action_sequence_indexes=cond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=cond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=cond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - cond_v_vision, cond_v_sound = 
self._mask_velocity_predictions( + cond_v_vision, cond_v_sound, cond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- Unconditional pass (Skip if not using CFG) --- - if guidance_scale != 1.0: - preds_vision, preds_sound = self.transformer( + uncond_v_vision = uncond_v_sound = uncond_v_action = None + if self.do_classifier_free_guidance: + preds_vision, preds_sound, preds_action = self.transformer( input_ids=uncond_packed_static["input_ids"], text_indexes=uncond_packed_static["text_indexes"], position_ids=uncond_packed_static["position_ids"], @@ -976,12 +1615,22 @@ def __call__( sound_mse_loss_indexes=uncond_packed_static.get("sound_mse_loss_indexes"), sound_timesteps=sound_timesteps, sound_noisy_frame_indexes=uncond_packed_static.get("sound_noisy_frame_indexes"), + action_tokens=[action_tokens] if action_tokens is not None else None, + action_token_shapes=uncond_packed_static.get("action_token_shapes"), + action_sequence_indexes=uncond_packed_static.get("action_sequence_indexes"), + action_mse_loss_indexes=uncond_packed_static.get("action_mse_loss_indexes"), + action_timesteps=action_timesteps, + action_noisy_frame_indexes=uncond_packed_static.get("action_noisy_frame_indexes"), + action_domain_ids=[action_domain_id] if action_domain_id is not None else None, ) - uncond_v_vision, uncond_v_sound = self._mask_velocity_predictions( + uncond_v_vision, uncond_v_sound, uncond_v_action = self._mask_velocity_predictions( preds_vision, preds_sound, vision_condition_mask=[vision_condition_mask], sound_condition_mask=[sound_condition_mask] if sound_condition_mask is not None else None, + preds_action=preds_action, + 
action_condition_mask=[action_condition_mask] if action_condition_mask is not None else None, + raw_action_dim=raw_action_dim_resolved, ) # --- CFG combine + per-modality scheduler step --- @@ -989,7 +1638,7 @@ def __call__( # to carry a batch dim; per-modality latents have no batch axis, so wrap for the step. # Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_vision = uncond_v_vision + guidance_scale * (cond_v_vision - uncond_v_vision) else: velocity_vision = cond_v_vision @@ -1000,7 +1649,7 @@ def __call__( if sound_scheduler is not None and cond_v_sound is not None: # Skip CFG for 1.0 guidance scale - if guidance_scale != 1.0: + if self.do_classifier_free_guidance: velocity_sound = uncond_v_sound + guidance_scale * (cond_v_sound - uncond_v_sound) else: velocity_sound = cond_v_sound @@ -1008,6 +1657,20 @@ def __call__( velocity_sound.unsqueeze(0), t, sound_latents.unsqueeze(0), return_dict=False )[0].squeeze(0) + has_noisy_action = ( + action_condition_mask is not None and action_condition_mask.sum() < action_condition_mask.numel() + ) + if action_scheduler is not None and has_noisy_action and cond_v_action is not None: + if self.do_classifier_free_guidance: + velocity_action = uncond_v_action + guidance_scale * (cond_v_action - uncond_v_action) + else: + velocity_action = cond_v_action + action_latents = action_scheduler.step( + velocity_action.unsqueeze(0), t, action_latents.unsqueeze(0), return_dict=False + )[0].squeeze(0) + if raw_action_dim_resolved is not None: + action_latents[:, raw_action_dim_resolved:] = 0 + if callback_on_step_end is not None: callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs} callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) @@ -1020,6 +1683,12 @@ def __call__( # 8. 
Postprocess + decode sound = self.decode_sound(sound_latents) if sound_latents is not None else None + action_output = None + if action_mode in {"inverse_dynamics", "policy"} and action_latents is not None: + action_output = action_latents + if raw_action_dim_resolved is not None: + action_output = action_output[:, :raw_action_dim_resolved] + action_output = [action_output.detach().cpu()] if output_type == "latent": video = latents else: @@ -1037,5 +1706,7 @@ def __call__( self.maybe_free_model_hooks() if not return_dict: + if action_mode is not None: + return (video, sound, action_output) return (video, sound) - return Cosmos3OmniPipelineOutput(video=video, sound=sound) + return Cosmos3OmniPipelineOutput(video=video, sound=sound, action=action_output) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 3b8004110802..e4a08776f143 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1397,6 +1397,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CosmosActionCondition(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CosmosTextToWorldPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]