Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions tests/pytorch/test_backward_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
id="NVFP4BlockScaling",
),
pytest.param(
"nvfp4_pertoken",
marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
id="NVFP4PerTokenBlockScaling",
),
]


Expand Down Expand Up @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype(
) -> None:
if dtype == torch.bfloat16 and not bf16_available:
pytest.skip(reason_for_no_bf16)
if recipe_name == "nvfp4":
if recipe_name in ("nvfp4", "nvfp4_pertoken"):
if module_type in ("linear", "layernorm_linear") and dtype not in (
torch.bfloat16,
torch.float32,
Expand Down Expand Up @@ -195,7 +200,9 @@ def _maybe_skip_unsupported_recipe_shape(
" by 32."
)
return
if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
flat_first_dim % 16 != 0 or last_dim % 16 != 0
):
pytest.skip(
"Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible"
" by 16."
Expand All @@ -220,7 +227,9 @@ def _maybe_skip_unsupported_recipe_shape(
pytest.skip(
"te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32."
)
if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
flat_first_dim % 16 != 0 or last_dim % 16 != 0
):
pytest.skip(
"te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16."
)
Expand All @@ -239,9 +248,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]
)
if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits):
pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.")
if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits):
if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits):
pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.")
if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits):
if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 64 != 0 for m in non_empty_splits):
pytest.skip(
"GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty "
"m_split divisible by 64 due to grouped amax kernel constraints."
Expand Down
Loading