feat(wan22): add WAN 2.2 text-to-video adapter and dataset for MLPerf inference #293
wu6u3tw wants to merge 6 commits into mlcommons:main from
Conversation
MLCommons CLA bot: All contributors have signed the MLCommons CLA ✍️ ✅
Code Review
This pull request introduces support for the WAN2.2 MLPerf text-to-video benchmark, including a new adapter, dataset loader, and associated Pydantic models. It also adds comprehensive documentation and an example configuration for running benchmarks on Lyris. The review feedback identifies several critical omissions and inconsistencies: the VideoPathRequest and Wan22Dataset are missing the latent_path field required for MLPerf reproducibility, and there is a mismatch between the adapter implementation and unit tests regarding the response_format and handling of VideoPayloadResponse. Additionally, the feedback suggests using None as a default for negative_prompt to allow server-side defaults and injecting the canonical MLPerf negative prompt into the dataset.
negative_prompt: str = ""
size: str = Field(default="720x1280", description="Frame size in 'WxH' format.")
seconds: float = Field(
    default=5.0,
    description="Video duration. 81 frames @ ~16.2 fps = 5 s (MLPerf standard).",
)
fps: int = Field(default=16, description="Frames per second (MLPerf: 16).")
num_inference_steps: int = Field(
    default=20, description="Denoising steps (MLPerf: 20)."
)
guidance_scale: float = Field(
    default=4.0, description="CFG guidance scale (MLPerf: 4.0)."
)
guidance_scale_2: float = Field(
    default=3.0, description="Secondary guidance scale for null-text CFG (MLPerf: 3.0)."
)
seed: int = Field(default=42, description="Random seed (MLPerf: 42).")
output_format: Literal["mp4", "avi", "auto"] = "auto"
response_format: Literal["video_bytes", "video_path"] = "video_path"
The VideoPathRequest model is missing the latent_path field, which is required for MLPerf reproducibility as mentioned in the design plan. Additionally, negative_prompt should default to None to allow the server to use its own default when the field is omitted from the JSON payload (using exclude_none=True in the adapter).
negative_prompt: str | None = Field(
default=None,
description="Text describing what to avoid. None = let server default.",
)
size: str = Field(default="720x1280", description="Frame size in 'WxH' format.")
seconds: float = Field(
default=5.0,
description="Video duration. 81 frames @ ~16.2 fps = 5 s (MLPerf standard).",
)
fps: int = Field(default=16, description="Frames per second (MLPerf: 16).")
num_inference_steps: int = Field(
default=20, description="Denoising steps (MLPerf: 20)."
)
guidance_scale: float = Field(
default=4.0, description="CFG guidance scale (MLPerf: 4.0)."
)
guidance_scale_2: float = Field(
default=3.0, description="Secondary guidance scale for null-text CFG (MLPerf: 3.0)."
)
seed: int = Field(default=42, description="Random seed (MLPerf: 42).")
latent_path: str | None = Field(
default=None,
description="Absolute path to a pre-computed latent tensor (.pt file) on shared storage.",
)
output_format: Literal["mp4", "avi", "auto"] = "auto"
response_format: Literal["video_bytes", "video_path"] = "video_path"

req = VideoPathRequest(
    prompt=data["prompt"],
    negative_prompt=data.get("negative_prompt", ""),
    size=data.get("size", "720x1280"),
    seconds=data.get("seconds", 5.0),
    fps=data.get("fps", 16),
    num_inference_steps=data.get("num_inference_steps", 20),
    guidance_scale=data.get("guidance_scale", 4.0),
    guidance_scale_2=data.get("guidance_scale_2", 3.0),
    seed=data.get("seed", 42),
    output_format=data.get("output_format", "auto"),
    response_format="video_path",
)
return req.model_dump_json().encode()
There is a mismatch between the encode_query implementation and the unit tests. The code hardcodes response_format="video_path", but TestWan22Adapter.test_encode_query_always_requests_video_bytes (line 41) asserts it is video_bytes. Additionally, latent_path is not being passed from query.data to the request model, and exclude_none=True should be used to allow server-side defaults for optional fields like negative_prompt.
Current:

req = VideoPathRequest(
    prompt=data["prompt"],
    negative_prompt=data.get("negative_prompt", ""),
    size=data.get("size", "720x1280"),
    seconds=data.get("seconds", 5.0),
    fps=data.get("fps", 16),
    num_inference_steps=data.get("num_inference_steps", 20),
    guidance_scale=data.get("guidance_scale", 4.0),
    guidance_scale_2=data.get("guidance_scale_2", 3.0),
    seed=data.get("seed", 42),
    output_format=data.get("output_format", "auto"),
    response_format="video_path",
)
return req.model_dump_json().encode()

Suggested:

req = VideoPathRequest(
    prompt=data["prompt"],
    negative_prompt=data.get("negative_prompt"),
    size=data.get("size", "720x1280"),
    seconds=data.get("seconds", 5.0),
    fps=data.get("fps", 16),
    num_inference_steps=data.get("num_inference_steps", 20),
    guidance_scale=data.get("guidance_scale", 4.0),
    guidance_scale_2=data.get("guidance_scale_2", 3.0),
    seed=data.get("seed", 42),
    output_format=data.get("output_format", "auto"),
    response_format=data.get("response_format", "video_bytes"),
    latent_path=data.get("latent_path"),
)
return req.model_dump_json(exclude_none=True).encode()
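The point of `exclude_none=True` is that omitted optional fields never reach the wire, so the server can apply its own defaults. A minimal sketch of that behavior using only stdlib `json` (the `encode_request` helper and its field subset are illustrative, not the adapter's actual Pydantic model):

```python
import json

def encode_request(data: dict) -> bytes:
    # Build the request dict, leaving optional fields as None when absent.
    req = {
        "prompt": data["prompt"],
        "negative_prompt": data.get("negative_prompt"),  # None if omitted upstream
        "latent_path": data.get("latent_path"),          # None if omitted upstream
        "seed": data.get("seed", 42),
    }
    # Same effect as Pydantic's model_dump_json(exclude_none=True):
    # None-valued keys are dropped, so the server applies its own defaults.
    body = {k: v for k, v in req.items() if v is not None}
    return json.dumps(body).encode()

payload = json.loads(encode_request({"prompt": "a red fox"}))
# "negative_prompt" and "latent_path" are absent from the JSON entirely,
# which is different from sending them as empty strings.
```

Sending `""` instead would silently override the server-side MLPerf negative prompt, which is exactly the failure mode the review comment flags.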
def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
    """Deserialise trtllm-serve VideoPathResponse JSON bytes to QueryResult.

    metadata["video_path"] carries the Lustre path to the encoded video for
    the accuracy evaluator.
    """
    resp = VideoPathResponse.model_validate_json(response_bytes)
    return QueryResult(
        id=query_id,
        response_output=TextModelOutput(output=resp.video_id),
        metadata={"video_path": resp.video_path},
    )
The decode_response method currently only handles VideoPathResponse. However, the unit tests (e.g., test_decode_response_returns_video_bytes_in_metadata at line 74) use VideoPayloadResponse and expect video_bytes in the metadata. The adapter should be updated to handle both response formats to support both performance (path-based) and accuracy (payload-based) modes.
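The core dispatch idea can be sketched with stdlib `json` alone, branching on which key the payload carries (plain dicts stand in for the QueryResult and Pydantic wire types here; only the `video_id`/`video_path`/`video_bytes` field names are taken from this PR):

```python
import json

def decode_by_shape(response_bytes: bytes) -> dict:
    # Accuracy mode replies carry "video_bytes"; performance mode replies
    # carry a "video_path" on shared storage. Branch on which key is present.
    payload = json.loads(response_bytes)
    if "video_bytes" in payload:
        return {"output": payload["video_id"],
                "metadata": {"video_bytes": payload["video_bytes"]}}
    return {"output": payload["video_id"],
            "metadata": {"video_path": payload["video_path"]}}

perf = decode_by_shape(b'{"video_id": "vid-1", "video_path": "/lustre/vid-1.mp4"}')
acc = decode_by_shape(b'{"video_id": "vid-1", "video_bytes": "AAAA"}')
```

Key-based dispatch avoids relying on a broad try/except around model validation, at the cost of coupling the adapter to the wire field names.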
@classmethod
def decode_response(cls, response_bytes: bytes, query_id: str) -> QueryResult:
"""Deserialise trtllm-serve response JSON bytes to QueryResult.
Supports both video_path (perf) and video_bytes (accuracy) formats.
"""
try:
# Try parsing as path response first (perf mode)
resp = VideoPathResponse.model_validate_json(response_bytes)
metadata = {"video_path": resp.video_path}
video_id = resp.video_id
except Exception:
# Fallback to payload response (accuracy mode)
from .types import VideoPayloadResponse
resp = VideoPayloadResponse.model_validate_json(response_bytes)
metadata = {"video_bytes": resp.video_bytes}
video_id = resp.video_id
return QueryResult(
id=query_id,
response_output=TextModelOutput(output=video_id),
metadata=metadata,
)

def get_dataloader(  # type: ignore[override]
    cls,
    path: Path | str | None = None,
    negative_prompt: str = "",
    **kwargs: Any,
) -> "Wan22Dataset":
    """Create a Wan22Dataset from a prompts file path.

    Called by DataLoaderFactory when ``--dataset <path>`` is used with
    ``name=wan22_mlperf``. The ``path`` argument maps directly to
    ``prompts_path``.
    """
    if path is None:
        raise ValueError(
            "Wan22Dataset requires a prompts file path. "
            "Pass --dataset <path/to/prompts.txt> or set path= in the dataset config."
        )
    return cls(prompts_path=path, negative_prompt=negative_prompt)

def __init__(
    self,
    prompts_path: Path | str,
    negative_prompt: str = "",
) -> None:
    prompts = [
        line.strip()
        for line in Path(prompts_path).read_text().splitlines()
        if line.strip()
    ]
    super().__init__(dataframe=pd.DataFrame({"prompt": prompts}))
    self.negative_prompt = negative_prompt

def load(self, **kwargs: Any) -> None:  # type: ignore[override]
    """Build self.data from the loaded dataframe. No transforms needed."""
    assert self.dataframe is not None
    self.data = [
        {
            "prompt": row["prompt"],
            "negative_prompt": self.negative_prompt,
            "sample_id": str(i),
            "sample_index": i,
        }
        for i, row in self.dataframe.iterrows()
    ]
The Wan22Dataset implementation is missing the latent_path parameter and the canonical MLPerf negative prompt defined in the design plan. These are necessary for the MLPerf T2V workload. The load method should also be updated to only include negative_prompt and latent_path in the sample data if they are provided.
_MLPERF_NEGATIVE_PROMPT = (
"vivid colors, overexposed, static, blurry details, subtitles, style, "
"work of art, painting, picture, still, overall grayish, worst quality, "
"low quality, JPEG artifacts, ugly, deformed, extra fingers, poorly drawn hands, "
"poorly drawn face, deformed, disfigured, deformed limbs, fused fingers, "
"static image, cluttered background, three legs, many people in the background, "
"walking backwards"
)
@classmethod
def get_dataloader( # type: ignore[override]
cls,
path: Path | str | None = None,
negative_prompt: str | None = _MLPERF_NEGATIVE_PROMPT,
latent_path: Path | str | None = None,
**kwargs: Any,
) -> "Wan22Dataset":
if path is None:
raise ValueError("Wan22Dataset requires a prompts file path.")
return cls(prompts_path=path, negative_prompt=negative_prompt, latent_path=latent_path)
def __init__(
self,
prompts_path: Path | str,
negative_prompt: str | None = _MLPERF_NEGATIVE_PROMPT,
latent_path: Path | str | None = None,
) -> None:
prompts = [
line.strip()
for line in Path(prompts_path).read_text().splitlines()
if line.strip()
]
super().__init__(dataframe=pd.DataFrame({"prompt": prompts}))
self.negative_prompt = negative_prompt
self.latent_path = latent_path
def load(self, **kwargs: Any) -> None: # type: ignore[override]
assert self.dataframe is not None
self.data = [
{
"prompt": row["prompt"],
**({"negative_prompt": self.negative_prompt} if self.negative_prompt is not None else {}),
**({"latent_path": str(self.latent_path)} if self.latent_path is not None else {}),
"sample_id": str(i),
"sample_index": i,
}
for i, row in self.dataframe.iterrows()
]

…er, Wan22Dataset→VideoGenDataset
- Rename src/inference_endpoint/wan22/ → videogen/
- Rename tests/unit/wan22/ → tests/unit/videogen/
- Rename tests/integration/wan22/ → tests/integration/videogen/
- APIType.WAN22 → APIType.VIDEOGEN ("wan22" → "videogen")
- Wan22Adapter → VideoGenAdapter
- Wan22Dataset → VideoGenDataset
- Wan22Accumulator → VideoGenAccumulator
- Update all imports, maps, __all__, tests, docs, and example yaml
- Keep dataset_id="wan22_mlperf" and model_params.name="wan22" (MLPerf identifiers)
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…bytes)

encode_query: response_format defaults to "video_bytes" but can be overridden via query.data["response_format"] = "video_path" for Lustre-path mode.
decode_response: dispatches on response shape — "video_bytes" key → VideoPayloadResponse, otherwise → VideoPathResponse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary

- Adds a `wan22` module with `Wan22Adapter`, `Wan22Accumulator`, `Wan22Dataset`, and Pydantic wire types for the trtllm-serve `POST /v1/videos/generations` endpoint.
- `Wan22Adapter` uses `response_format=video_path`: the server saves the encoded video to shared storage (Lustre) and returns only the file path, avoiding 3–5 MB of base64 video bytes per request over HTTP and ZMQ transport.
- `Wan22Dataset` loads MLPerf WAN2.2 prompt text files (one prompt per line); dataset ID `wan22_mlperf` is registered with `DataLoaderFactory` for `--dataset` CLI use.
- Adds `APIType.WAN22` and wires `Wan22Adapter`/`Wan22Accumulator` into `HTTPClientConfig.with_updates()` to reset adapter and accumulator when `api_type` changes.

Test plan

- `pytest -m unit tests/unit/wan22/` — adapter, dataset, factory, types, init, registration unit tests
- `pytest -m integration tests/integration/wan22/` — adapter round-trip with mock server
- `pre-commit run --all-files` passes clean