
feat(wan22): add WAN 2.2 text-to-video adapter and dataset for MLPerf inference #293

Draft
wu6u3tw wants to merge 6 commits into mlcommons:main from wu6u3tw:feat/wan22

Conversation


@wu6u3tw wu6u3tw commented Apr 22, 2026

Summary

  • Adds a new wan22 module with Wan22Adapter, Wan22Accumulator, Wan22Dataset, and Pydantic wire types for the trtllm-serve POST /v1/videos/generations endpoint.
  • Wan22Adapter uses response_format=video_path: the server saves the encoded video to shared storage (Lustre) and returns only the file path, avoiding 3–5 MB of base64 video bytes per request over
    HTTP and ZMQ transport.
  • Wan22Dataset loads MLPerf WAN2.2 prompt text files (one prompt per line); dataset ID wan22_mlperf is registered with DataLoaderFactory for --dataset CLI use.
  • Registers APIType.WAN22 and wires Wan22Adapter/Wan22Accumulator into HTTPClientConfig.
  • Adds an example offline benchmark YAML for Lyris (GB200/GB300) targeting a local trtllm-serve instance with MLPerf-standard params (720×1280, 5 s, 20 steps, guidance 4.0/3.0, seed 42, 248 prompts).
  • Fixes with_updates() to reset adapter and accumulator when api_type changes.
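
As a rough illustration of the wire format described above — field names and MLPerf-standard defaults are taken from the summary, while the endpoint payload shape, prompt string, and example response values are assumptions for illustration only:

```python
import json

# Hypothetical request body for the trtllm-serve POST /v1/videos/generations
# endpoint, mirroring the VideoPathRequest fields and MLPerf-standard defaults
# listed above. The prompt text is illustrative, not from the MLPerf set.
request = {
    "prompt": "A red fox running through fresh snow at dawn",
    "size": "720x1280",
    "seconds": 5.0,
    "fps": 16,
    "num_inference_steps": 20,
    "guidance_scale": 4.0,
    "guidance_scale_2": 3.0,
    "seed": 42,
    "response_format": "video_path",  # server writes to Lustre, returns only a path
}
body = json.dumps(request).encode()

# In video_path mode the response stays small: an ID plus a shared-storage path,
# instead of 3-5 MB of base64 video bytes. (Example values are made up.)
response_bytes = b'{"video_id": "vid-0001", "video_path": "/lustre/outputs/vid-0001.mp4"}'
result = json.loads(response_bytes)
print(result["video_path"])
```

This is only a sketch of the JSON that crosses the wire; the actual serialization in the PR goes through the Pydantic `VideoPathRequest`/`VideoPathResponse` models.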

Test plan

  • pytest -m unit tests/unit/wan22/ — adapter, dataset, factory, types, init, registration unit tests
  • pytest -m integration tests/integration/wan22/ — adapter round-trip with mock server
  • pre-commit run --all-files passes clean
  • Offline benchmark runs end-to-end against a live trtllm-serve instance on Lyris (manual)

What does this PR do?

Type of change

  • Bug fix
  • New feature
  • Documentation update
  • Refactor/cleanup

Related issues

Testing

  • Tests added/updated
  • All tests pass locally
  • Manual testing completed

Checklist

  • Code follows project style
  • Pre-commit hooks pass
  • Documentation updated (if needed)

@wu6u3tw wu6u3tw requested a review from a team April 22, 2026 22:25
@wu6u3tw wu6u3tw self-assigned this Apr 22, 2026
@wu6u3tw wu6u3tw marked this pull request as draft April 22, 2026 22:25

github-actions Bot commented Apr 22, 2026

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for the WAN2.2 MLPerf text-to-video benchmark, including a new adapter, dataset loader, and associated Pydantic models. It also adds comprehensive documentation and an example configuration for running benchmarks on Lyris. The review feedback identifies several critical omissions and inconsistencies: the VideoPathRequest and Wan22Dataset are missing the latent_path field required for MLPerf reproducibility, and there is a mismatch between the adapter implementation and unit tests regarding the response_format and handling of VideoPayloadResponse. Additionally, the feedback suggests using None as a default for negative_prompt to allow server-side defaults and injecting the canonical MLPerf negative prompt into the dataset.

Comment on lines +35 to +53
negative_prompt: str = ""
size: str = Field(default="720x1280", description="Frame size in 'WxH' format.")
seconds: float = Field(
default=5.0,
description="Video duration. 81 frames @ ~16.2 fps = 5 s (MLPerf standard).",
)
fps: int = Field(default=16, description="Frames per second (MLPerf: 16).")
num_inference_steps: int = Field(
default=20, description="Denoising steps (MLPerf: 20)."
)
guidance_scale: float = Field(
default=4.0, description="CFG guidance scale (MLPerf: 4.0)."
)
guidance_scale_2: float = Field(
default=3.0, description="Secondary guidance scale for null-text CFG (MLPerf: 3.0)."
)
seed: int = Field(default=42, description="Random seed (MLPerf: 42).")
output_format: Literal["mp4", "avi", "auto"] = "auto"
response_format: Literal["video_bytes", "video_path"] = "video_path"


Severity: high

The VideoPathRequest model is missing the latent_path field, which is required for MLPerf reproducibility as mentioned in the design plan. Additionally, negative_prompt should default to None to allow the server to use its own default when the field is omitted from the JSON payload (using exclude_none=True in the adapter).

    negative_prompt: str | None = Field(
        default=None,
        description="Text describing what to avoid. None = let server default.",
    )
    size: str = Field(default="720x1280", description="Frame size in 'WxH' format.")
    seconds: float = Field(
        default=5.0,
        description="Video duration. 81 frames @ ~16.2 fps = 5 s (MLPerf standard).",
    )
    fps: int = Field(default=16, description="Frames per second (MLPerf: 16).")
    num_inference_steps: int = Field(
        default=20, description="Denoising steps (MLPerf: 20)."
    )
    guidance_scale: float = Field(
        default=4.0, description="CFG guidance scale (MLPerf: 4.0)."
    )
    guidance_scale_2: float = Field(
        default=3.0, description="Secondary guidance scale for null-text CFG (MLPerf: 3.0)."
    )
    seed: int = Field(default=42, description="Random seed (MLPerf: 42).")
    latent_path: str | None = Field(
        default=None,
        description="Absolute path to a pre-computed latent tensor (.pt file) on shared storage.",
    )
    output_format: Literal["mp4", "avi", "auto"] = "auto"
    response_format: Literal["video_bytes", "video_path"] = "video_path"

Comment on lines +63 to +76
req = VideoPathRequest(
prompt=data["prompt"],
negative_prompt=data.get("negative_prompt", ""),
size=data.get("size", "720x1280"),
seconds=data.get("seconds", 5.0),
fps=data.get("fps", 16),
num_inference_steps=data.get("num_inference_steps", 20),
guidance_scale=data.get("guidance_scale", 4.0),
guidance_scale_2=data.get("guidance_scale_2", 3.0),
seed=data.get("seed", 42),
output_format=data.get("output_format", "auto"),
response_format="video_path",
)
return req.model_dump_json().encode()


Severity: high

There is a mismatch between the encode_query implementation and the unit tests. The code hardcodes response_format="video_path", but TestWan22Adapter.test_encode_query_always_requests_video_bytes (line 41) asserts it is video_bytes. Additionally, latent_path is not being passed from query.data to the request model, and exclude_none=True should be used to allow server-side defaults for optional fields like negative_prompt.

Suggested change (original code, followed by the proposed replacement)
req = VideoPathRequest(
prompt=data["prompt"],
negative_prompt=data.get("negative_prompt", ""),
size=data.get("size", "720x1280"),
seconds=data.get("seconds", 5.0),
fps=data.get("fps", 16),
num_inference_steps=data.get("num_inference_steps", 20),
guidance_scale=data.get("guidance_scale", 4.0),
guidance_scale_2=data.get("guidance_scale_2", 3.0),
seed=data.get("seed", 42),
output_format=data.get("output_format", "auto"),
response_format="video_path",
)
return req.model_dump_json().encode()
req = VideoPathRequest(
prompt=data["prompt"],
negative_prompt=data.get("negative_prompt"),
size=data.get("size", "720x1280"),
seconds=data.get("seconds", 5.0),
fps=data.get("fps", 16),
num_inference_steps=data.get("num_inference_steps", 20),
guidance_scale=data.get("guidance_scale", 4.0),
guidance_scale_2=data.get("guidance_scale_2", 3.0),
seed=data.get("seed", 42),
output_format=data.get("output_format", "auto"),
response_format=data.get("response_format", "video_bytes"),
latent_path=data.get("latent_path"),
)
return req.model_dump_json(exclude_none=True).encode()

Comment on lines +79 to +90
def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
"""Deserialise trtllm-serve VideoPathResponse JSON bytes to QueryResult.

metadata["video_path"] carries the Lustre path to the encoded video for
the accuracy evaluator.
"""
resp = VideoPathResponse.model_validate_json(response_bytes)
return QueryResult(
id=query_id,
response_output=TextModelOutput(output=resp.video_id),
metadata={"video_path": resp.video_path},
)


Severity: high

The decode_response method currently only handles VideoPathResponse. However, the unit tests (e.g., test_decode_response_returns_video_bytes_in_metadata at line 74) use VideoPayloadResponse and expect video_bytes in the metadata. The adapter should be updated to handle both response formats to support both performance (path-based) and accuracy (payload-based) modes.

    @classmethod
    def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
        """Deserialise trtllm-serve response JSON bytes to QueryResult.

        Supports both video_path (perf) and video_bytes (accuracy) formats.
        """
        try:
            # Try parsing as path response first (perf mode)
            resp = VideoPathResponse.model_validate_json(response_bytes)
            metadata = {"video_path": resp.video_path}
            video_id = resp.video_id
        except Exception:
            # Fallback to payload response (accuracy mode)
            from .types import VideoPayloadResponse
            resp = VideoPayloadResponse.model_validate_json(response_bytes)
            metadata = {"video_bytes": resp.video_bytes}
            video_id = resp.video_id

        return QueryResult(
            id=query_id,
            response_output=TextModelOutput(output=video_id),
            metadata=metadata,
        )

Comment on lines +36 to +79
def get_dataloader( # type: ignore[override]
cls,
path: Path | str | None = None,
negative_prompt: str = "",
**kwargs: Any,
) -> "Wan22Dataset":
"""Create a Wan22Dataset from a prompts file path.

Called by DataLoaderFactory when ``--dataset <path>`` is used with
``name=wan22_mlperf``. The ``path`` argument maps directly to
``prompts_path``.
"""
if path is None:
raise ValueError(
"Wan22Dataset requires a prompts file path. "
"Pass --dataset <path/to/prompts.txt> or set path= in the dataset config."
)
return cls(prompts_path=path, negative_prompt=negative_prompt)

def __init__(
self,
prompts_path: Path | str,
negative_prompt: str = "",
) -> None:
prompts = [
line.strip()
for line in Path(prompts_path).read_text().splitlines()
if line.strip()
]
super().__init__(dataframe=pd.DataFrame({"prompt": prompts}))
self.negative_prompt = negative_prompt

def load(self, **kwargs: Any) -> None: # type: ignore[override]
"""Build self.data from the loaded dataframe. No transforms needed."""
assert self.dataframe is not None
self.data = [
{
"prompt": row["prompt"],
"negative_prompt": self.negative_prompt,
"sample_id": str(i),
"sample_index": i,
}
for i, row in self.dataframe.iterrows()
]


Severity: high

The Wan22Dataset implementation is missing the latent_path parameter and the canonical MLPerf negative prompt defined in the design plan. These are necessary for the MLPerf T2V workload. The load method should also be updated to only include negative_prompt and latent_path in the sample data if they are provided.

    _MLPERF_NEGATIVE_PROMPT = (
        "vivid colors, overexposed, static, blurry details, subtitles, style, "
        "work of art, painting, picture, still, overall grayish, worst quality, "
        "low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, "
        "poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, "
        "static image, cluttered background, three legs, many people in the background, "
        "walking backwards"
    )

    @classmethod
    def get_dataloader(  # type: ignore[override]
        cls,
        path: Path | str | None = None,
        negative_prompt: str | None = _MLPERF_NEGATIVE_PROMPT,
        latent_path: Path | str | None = None,
        **kwargs: Any,
    ) -> "Wan22Dataset":
        if path is None:
            raise ValueError("Wan22Dataset requires a prompts file path.")
        return cls(prompts_path=path, negative_prompt=negative_prompt, latent_path=latent_path)

    def __init__(
        self,
        prompts_path: Path | str,
        negative_prompt: str | None = _MLPERF_NEGATIVE_PROMPT,
        latent_path: Path | str | None = None,
    ) -> None:
        prompts = [
            line.strip()
            for line in Path(prompts_path).read_text().splitlines()
            if line.strip()
        ]
        super().__init__(dataframe=pd.DataFrame({"prompt": prompts}))
        self.negative_prompt = negative_prompt
        self.latent_path = latent_path

    def load(self, **kwargs: Any) -> None:  # type: ignore[override]
        assert self.dataframe is not None
        self.data = [
            {
                "prompt": row["prompt"],
                **({"negative_prompt": self.negative_prompt} if self.negative_prompt is not None else {}),
                **({"latent_path": str(self.latent_path)} if self.latent_path is not None else {}),
                "sample_id": str(i),
                "sample_index": i,
            }
            for i, row in self.dataframe.iterrows()
        ]

wu6u3tw and others added 3 commits April 23, 2026 16:36
…er, Wan22Dataset→VideoGenDataset

- Rename src/inference_endpoint/wan22/ → videogen/
- Rename tests/unit/wan22/ → tests/unit/videogen/
- Rename tests/integration/wan22/ → tests/integration/videogen/
- APIType.WAN22 → APIType.VIDEOGEN ("wan22" → "videogen")
- Wan22Adapter → VideoGenAdapter
- Wan22Dataset → VideoGenDataset
- Wan22Accumulator → VideoGenAccumulator
- Update all imports, maps, __all__, tests, docs, and example yaml
- Keep dataset_id="wan22_mlperf" and model_params.name="wan22" (MLPerf identifiers)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…bytes)

encode_query: response_format defaults to "video_bytes" but can be overridden
via query.data["response_format"] = "video_path" for Lustre-path mode.

decode_response: dispatches on response shape — "video_bytes" key → VideoPayloadResponse,
otherwise → VideoPathResponse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
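
The key-based dispatch described in the decode_response commit above can be sketched roughly as follows. The function name, return shape, and example payloads are illustrative only; the real adapter validates into the `VideoPayloadResponse`/`VideoPathResponse` Pydantic models rather than returning plain dicts:

```python
import json

def dispatch_response(response_bytes: bytes) -> dict:
    """Sketch of the dispatch rule from the commit message above:
    a "video_bytes" key means a payload-style response (accuracy mode);
    otherwise the JSON is treated as a path-style response (perf mode).
    Names and return shape here are illustrative, not the adapter's API."""
    data = json.loads(response_bytes)
    if "video_bytes" in data:
        return {"mode": "accuracy", "payload": data["video_bytes"]}
    return {"mode": "perf", "path": data["video_path"]}

# Path-style response (perf): only an ID and a shared-storage path come back.
perf = dispatch_response(b'{"video_id": "v1", "video_path": "/lustre/v1.mp4"}')
# Payload-style response (accuracy): base64 video bytes are inlined.
acc = dispatch_response(b'{"video_id": "v1", "video_bytes": "AAAA"}')
print(perf["mode"], acc["mode"])  # perf accuracy
```

Dispatching on the presence of a discriminating key avoids the try/except fallback pattern, since the two response shapes are distinguishable from their fields alone.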