Improve the dimension checks for the FP8 recipes#2894

Open
ptrendx wants to merge 19 commits into NVIDIA:main from ptrendx:pr_relax_dim_checks

Conversation

@ptrendx
Member

@ptrendx ptrendx commented Apr 16, 2026

Description

Improved the dimension checks for the FP8 recipes. The old one was too restrictive and did not have any recipe-aware logic.

Dimension check changes

| Layer | Recipe | Before | After |
| --- | --- | --- | --- |
| Python `assert_dim_for_fp8_exec` (utils.py) | all FP8/FP4 | `M % 8 == 0`, `last_dim % 16 == 0`; run before the recipe is known | Removed — guard was recipe-agnostic and stricter than any underlying kernel required |
| C++ `MXFP8Quantizer::create_tensor` / `get_scale_shape` | MXFP8 | `first_dim % 32`, `last_dim % 32` | `last_dim % 16` only (TMA 16 B alignment). Scale tensor uses `DIVUP` for partial trailing blocks |
| C++ `NVFP4Quantizer::create_tensor` / `get_scale_shape` | NVFP4 | `first_dim % 32`, `last_dim % 32` | `last_dim % 32` always; `first_dim % 16` only when RHT is enabled (Hadamard block size) |
| C++ `swizzle_block_scaling.cu` kernel | Float8 1×128 block | `data_rows % 4 == 0` | Unchanged — real kernel requirement; error message clarified with actual `data_rows` |
| C++ `cublaslt_gemm.cu` MXFP8 path | MXFP8 | No explicit check — cuBLAS returns opaque `CUBLAS_STATUS_NOT_SUPPORTED` / "no algo found" | `lda % 16 == 0`, `ldb % 16 == 0` via `min_alignment_elements(dtype)` using `typeToNumBits` (16 B alignment for FP8) |
| C++ `cublaslt_gemm.cu` NVFP4 path | NVFP4 | No explicit check — opaque cuBLAS error | `lda % 32 == 0`, `ldb % 32 == 0` (16 B alignment for FP4 = 32 elements) |
| C++ `cublaslt_gemm.cu` 1D block-scaled FP8 path | Float8BlockScaling 1×128 | No explicit check | `n % 8 == 0` for the native 1D block-scaling path |
| Python `is_quantizable` → `supports_quantized_allgather` (mxfp8_tensor.py) | MXFP8 | Both dims `% 32` | Renamed. `last_dim % 16`; also `first_dim % 32` when columnwise usage (block boundary across ranks) |
| Python `is_quantizable` → `supports_quantized_allgather` (nvfp4_tensor.py) | NVFP4 | Both dims `% 32` | Renamed. `last_dim % 32`; also `first_dim % 16` when columnwise usage |
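The per-recipe rules in the table can be sketched as a small recipe-aware check. This is an illustrative Python sketch of the logic described above — the function and parameter names are hypothetical, not the Transformer Engine API:

```python
# Hypothetical sketch of recipe-aware dimension checks; names are illustrative.
ALIGNMENT_BYTES = 16  # cuBLAS wants 16-byte-aligned leading dimensions

def alignment_elements(type_bits: int) -> int:
    """Elements needed for 16-byte alignment given the element width in bits."""
    return ALIGNMENT_BYTES * 8 // type_bits

def check_dims(recipe: str, first_dim: int, last_dim: int, rht: bool = False) -> bool:
    if recipe == "mxfp8":  # FP8 = 8 bits -> 16 elements per 16 bytes
        return last_dim % alignment_elements(8) == 0
    if recipe == "nvfp4":  # FP4 = 4 bits -> 32 elements; RHT adds first_dim % 16
        ok = last_dim % alignment_elements(4) == 0
        return ok and (not rht or first_dim % 16 == 0)
    return True  # other recipes: no blanket guard at this level

print(alignment_elements(8), alignment_elements(4))  # 16 32
print(check_dims("mxfp8", 3, 16))                    # True: last_dim % 16 == 0
print(check_dims("nvfp4", 8, 32, rht=True))          # False: first_dim % 16 != 0
```

The point of the change is that these requirements live with the recipe (or the GEMM) that actually imposes them, rather than in a single guard applied before the recipe is known.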

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

ptrendx added 12 commits April 14, 2026 17:38
This one-size-fits-all Python guard checked dimensions before the
recipe was known, rejecting valid shapes. Dimension validation is
handled per-recipe in the C++ quantizer where requirements are known.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
- MXFP8: only require last dim divisible by 16 (was: both dims by 32).
  The kernel handles partial blocks via bounds checks and TMA zero-padding.
  Scale tensors are over-allocated with roundup alignment.
- NVFP4: only require last dim divisible by 32 (was: both dims by 16).
  4-bit data needs 32 elements for 16-byte alignment.
- Float8BlockScaling swizzle: remove data_rows%4 assertion. The kernel
  already handles non-aligned rows via DIVUP and OOB zero-fill.
- Fix integer truncation in MXFP8/NVFP4 get_scale_shape to use ceildiv
  instead of plain division, ensuring correct scale allocation for
  non-block-aligned dimensions.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
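The ceildiv fix described in this commit is easy to illustrate. A minimal sketch, where `ceildiv` mirrors the C++ `DIVUP` macro (the block size is MXFP8's, values are illustrative):

```python
def ceildiv(a: int, b: int) -> int:
    # Python equivalent of the C++ DIVUP macro: round up instead of truncating
    return -(-a // b)

BLOCK = 32  # MXFP8 block-scaling size
# Plain floor division drops the partial trailing block...
print(48 // BLOCK)         # 1 -> undersized scale allocation
# ...while ceildiv allocates a scale entry for it
print(ceildiv(48, BLOCK))  # 2
```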
- MXFP8: only check last dim divisible by 16 (removed first-dim check)
- NVFP4: only check last dim divisible by 32 (removed first-dim check)

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
NVFP4 backward pass uses the Hadamard transform which requires
num_rows % 16 == 0. Add a clear error message at the quantizer level
when columnwise_usage is enabled, instead of letting the user hit
the raw Hadamard kernel assertion.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
The old name was misleading — it doesn't check whether a tensor can
be quantized in general, but whether a local shard's shape supports
quantized all-gather without scaling factor blocks spanning across
GPU boundaries. The new name reflects the actual purpose.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Include actual dimension values (m, n, k, lda, ldb) in error messages
so users can trace which tensor dimension is misaligned. Reference
the cuBLAS documentation for FP8 alignment requirements.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Remove "Ensure all tensor dimensions..." sentence and parentheses
from cuBLAS documentation links for consistent error messages.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Replace hardcoded alignment values (16, 32) with a helper that
computes the minimum element alignment from the data type's bit
width and the 16-byte cuBLAS alignment requirement. This ensures
alignment checks stay correct if new data types are added.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Verifies that the relaxed dimension checks allow small M values
for recipes that support them on the forward pass.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Each recipe gets its tightest viable (M, N, K) for both inference and
training. The baseline pre-quantizes+dequantizes inputs and weights via
the recipe's quantizers so the comparison holds at BF16 tolerance.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 April 16, 2026 20:26
@greptile-apps
Contributor

greptile-apps Bot commented Apr 16, 2026

Greptile Summary

This PR refines FP8/FP4 dimension checks to be recipe-aware: removes the over-restrictive Python-level assert_dim_for_fp8_exec, relaxes MXFP8 to last_dim % 16 with ceildiv scale allocation (matching the C++ DIVUP), scopes NVFP4 C++ checks to quantization-only requirements, and adds explicit leading-dimension alignment checks in the cuBLAS GEMM paths. The previously flagged P0 (flat_first_dim NameError in NVFP4) and P1 (make_empty/get_scale_shape floor-div and FSDP2 columnwise guard) issues appear addressed in recent commits.

  • The aten.new_zeros dispatch in mxfp8_tensor.py (lines 576–583) computes scale shapes without round_up_to_nearest_multiple padding, while make_empty and get_scale_shape both apply this padding. This creates an undersized allocation for the newly-allowed last_dim = 16 case.

Confidence Score: 4/5

Safe to merge with minor concern about missing alignment padding in the new_zeros dispatch for newly-relaxed MXFP8 shapes.

The core dimension-check relaxation is correct and well-reasoned. Previously flagged P0/P1 issues (flat_first_dim NameError, ceildiv inconsistency, fsdp_pre_all_gather floor-div) are resolved. One P2 concern remains: the new_zeros dispatch uses unpadded scale shapes inconsistent with make_empty, which could produce undersized scale buffers for shapes newly enabled by this PR (last_dim = 16).

transformer_engine/pytorch/tensor/mxfp8_tensor.py — new_zeros dispatch scale shape computation (lines 576–583)

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/tensor/mxfp8_tensor.py | Relaxes MXFP8 constraints to `last_dim % 16`; fixes ceildiv in `make_empty`/`get_scale_shape`; relaxes `new_zeros` guard; but `new_zeros` does not apply `round_up_to_nearest_multiple` padding to computed scale shapes. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Adds `supports_quantized_allgather` with corrected NVFP4 constraints; `make_empty` now uses `shape_2d` for columnwise, avoiding the previous `flat_first_dim` NameError; looks correct. |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Adds explicit lda/ldb alignment checks for MXFP8, NVFP4, and 1D block-scaled FP8 GEMMs using `min_alignment_elements()`; also adds `m % 8` and `n % 8` checks for block-scaled paths. Looks correct. |
| transformer_engine/pytorch/csrc/quantizer.cpp | MXFP8 `get_scale_shape` now uses ceildiv (matching Python); NVFP4 `create_tensor` checks `flat_last_dim % 2` (minimal quantization requirement) and `flat_first_dim % 16` when RHT enabled; PR description claims `last_dim % 32` for NVFP4 but code only enforces `% 2`. |
| transformer_engine/pytorch/utils.py | `assert_dim_for_fp8_exec` removed; rest of utils.py unchanged and unaffected. |
| transformer_engine/common/swizzle/swizzle_block_scaling.cu | `data_rows % 4 == 0` check unchanged; error message clarified with actual `data_rows` value. No correctness concerns. |
| transformer_engine/pytorch/distributed.py | Uses `supports_quantized_allgather` (renamed from `is_quantizable`) for MXFP8 and NVFP4 all-gather fallback logic; correct guard semantics preserved. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Tensor shape] --> B{last_dim % 16 == 0?}
    B -- No --> C[Reject / fallback to BF16]
    B -- Yes --> D{Recipe?}

    D -- MXFP8 --> E[C++: last_dim % 16 check\nPython make_empty: last_dim % 16\nScale via ceildiv+roundup]
    D -- NVFP4 --> F[C++: last_dim % 2, first_dim % 16 if RHT\nPython make_empty: last_dim % 32\nScale via ceildiv+roundup]
    D -- Float8 1x128 --> G[swizzle_block_scaling: rows % 4\ncublaslt: lda % 16, m % 8]

    E --> H{Distributed?}
    F --> H
    H -- MXFP8 AllGather --> I[supports_quantized_allgather:\nlast_dim % 16\nfirst_dim % 32 if columnwise]
    H -- NVFP4 AllGather --> J[supports_quantized_allgather:\nlast_dim % 32\nfirst_dim % 16 if columnwise]
    I -- passes --> K[Quantized AllGather]
    J -- passes --> K
    I -- fails --> L[Fallback: BF16 AllGather + re-quantize]
    J -- fails --> L

    E --> M[cuBLAS GEMM:\nlda % 16 for FP8]
    F --> N[cuBLAS GEMM:\nlda % 32 for FP4]

Reviews (5): Last reviewed commit: "Scope quantizer.cpp dim checks to quanti..."

Comment on lines +1278 to +1285
Constraints:
- cuBLAS needs 16B-aligned leading dims: 16 elts for FP8, 32 for FP4.
- Linear(K, N) fprop has lda=ldb=K, ldc=N (no M alignment needed).
- Training adds wgrad/dgrad with lda/ldb/ldc spanning M, N, K -> all must be aligned.
- Float8BlockScaling swizzle requires data_rows (first dim) % 4.
- NVFP4 RHT requires input first dim % 16 (subsumed by 32-elt alignment for training).
"""
if recipe_obj.delayed() or recipe_obj.float8_current_scaling():
Contributor

P2 MXFP8 tight-dim test doesn't exercise the newly-allowed K = 16

The test uses K = 32 for MXFP8 inference, so last_dim = 32 is divisible by 32 and the Python //-based scale-shape bug is never triggered. Adding a separate assertion or updating the inference tuple to (1, 16, 16, ...) would catch the ceildiv inconsistency before it reaches production.
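The inconsistency this comment describes is easy to reproduce numerically (illustrative values; `32` is the MXFP8 block-scaling size):

```python
import math

BLOCK = 32  # MXFP8 block-scaling size
# K = 32 hides the bug: floor and ceiling division agree on whole blocks
assert 32 // BLOCK == math.ceil(32 / BLOCK) == 1
# K = 16 (newly allowed) exposes it: floor division yields a zero-sized scale dim
print(16 // BLOCK, math.ceil(16 / BLOCK))  # 0 1
```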

ptrendx and others added 5 commits April 16, 2026 15:05
The C++ quantizer uses DIVUP for partial trailing blocks, so a last_dim of
16 (now allowed by the relaxed C++ guard) produces one scale. The Python
helpers were still floor-dividing, so the corresponding scale tensors
collapsed to zero size. Mirrors the C++ behavior across make_empty,
get_scale_shape, and the new_zeros torch_dispatch path; the new_zeros path
also relaxes its fall-back guard so shapes the quantizer now accepts aren't
silently downgraded to a generic tensor.

Adds a Python-level test covering last_dim=16 since the end-to-end GEMM
test can't exercise it (the MXFP8 GEMM kernel still requires K >= 32).

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
- nvfp4_tensor.make_empty: flat_first_dim was never defined; use math.prod(shape[:-1])
  (which is what the reader would expect from the surrounding code).
- quantized_tensor.supports_quantized_allgather: move the unused-argument
  pragma to disable-next so pylint actually picks it up.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
…ble by 32.

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Member Author

ptrendx commented Apr 16, 2026

/te-ci

Comment on lines +367 to 375
def supports_quantized_allgather(self, inp: torch.Tensor) -> bool:
"""Whether tensor shape supports quantized all-gather.
When False, the distributed all-gather falls back to gathering
in high precision and quantizing afterward. This is needed when
the local shard's shape would cause scaling factor blocks to
span across GPU boundaries.
"""
return True
Collaborator

  • We shouldn't assume that quantized tensors use a block-scale format.
  • While we're touching this function, changing the default to False is more general. It will always work, even without a custom all-gather impl.
Suggested change
def supports_quantized_allgather(self, inp: torch.Tensor) -> bool:
"""Whether tensor shape supports quantized all-gather.
When False, the distributed all-gather falls back to gathering
in high precision and quantizing afterward. This is needed when
the local shard's shape would cause scaling factor blocks to
span across GPU boundaries.
"""
return True
def supports_quantized_allgather(self, inp: torch.Tensor) -> bool:
"""Whether tensor shape supports quantized all-gather.
When False, the distributed all-gather falls back to gathering
in high precision and quantizing afterward.
"""
return False

Comment on lines -311 to -315
flat_first_dim = math.prod(shape[:-1])
assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
f" {NVFP4_BLOCK_SCALING_SIZE}"
)
Collaborator

It would be safer to keep this check if allocating column-wise data.

Member Author

Actually, not sure if this check is needed at all (even the 32 elements on the rowwise side) that was left there. This check comes from the GEMM and so it is checked there. There is also the check in the allgather support function. Here we should just check if the quantization functions support this shape - ultimately we don't know what is the goal of the person invoking this function, maybe they just need this as storage for their custom operations and so they don't care about the GEMM requirements?

Comment on lines +98 to 99
if inp.shape[-1] % 16 != 0:
return False
Collaborator

Wouldn't partial blocks break all-gather? If we wanted to be as permissive as possible, we could reintroduce the check if row-wise usage is requested:

Suggested change
if inp.shape[-1] % 16 != 0:
return False
if inp.shape[-1] % 16 != 0:
return False
if self.rowwise_usage and inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0:
return False

Member Author

I don't think so - in the end the MXFP8 tensors are always rowwise in memory so the last dimension does not matter as it is not the dimension that is being allgathered. The dimension that is checked is the first dimension when the columnwise usage is true.

scale_inv = torch.empty(
round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
round_up_to_nearest_multiple(math.ceil(shape[-1] / MXFP8_BLOCK_SCALING_SIZE), 4),
Collaborator

Nit: I'd prefer going through a ceildiv utility function. It's not like FP64's 52 mantissa bits aren't enough, but I have residual superstition about floating point error.

Comment on lines +95 to +102
// Minimum number of elements for 16-byte alignment, given a data type.
// cuBLAS requires (dim * typeSize) % 16 == 0 for FP8 tensor core usage,
// i.e. dim % (128 / typeBits) == 0.
constexpr size_t kAlignmentBytes = 16;

size_t min_alignment_elements(transformer_engine::DType dtype) {
return kAlignmentBytes * 8 / transformer_engine::typeToNumBits(dtype);
}
Collaborator

The comment is misplaced. Also, the min_ is kind of redundant.

Suggested change
// Minimum number of elements for 16-byte alignment, given a data type.
// cuBLAS requires (dim * typeSize) % 16 == 0 for FP8 tensor core usage,
// i.e. dim % (128 / typeBits) == 0.
constexpr size_t kAlignmentBytes = 16;
size_t min_alignment_elements(transformer_engine::DType dtype) {
return kAlignmentBytes * 8 / transformer_engine::typeToNumBits(dtype);
}
constexpr size_t kAlignmentBytes = 16;
// Number of elements for 16-byte alignment, given a data type.
// cuBLAS requires (dim * typeSize) % 16 == 0 for FP8 tensor core usage,
// i.e. dim % (128 / typeBits) == 0.
size_t alignment_elements(transformer_engine::DType dtype) {
return kAlignmentBytes * 8 / transformer_engine::typeToNumBits(dtype);
}

def test_linear_tight_dims(recipe, inference, dtype):
"""te.Linear with the tightest M/N/K per recipe, vs a pytorch baseline.

Previously the Python assert_dim_for_fp8_exec rejected any M not divisible
Collaborator

These historical comments are good for checking what Claude did, but they're useless and distracting if they end up in the final code. assert_dim_for_fp8_exec is gone and mentioning it doesn't help you understand the code as it is.

Member Author

Missed that during reviewing of the generated code. Will clean.

torch.backends.cudnn.allow_tf32 = False
try:
te_linear = Linear(K, N, bias=False, params_dtype=dtype, device=device)
torch_linear = torch.nn.Linear(K, N, bias=False, device=device, dtype=dtype)
Collaborator

The real baseline isn't PyTorch, it's mathematical ground truth. FP64 compute helps us get as close as possible:

Suggested change
torch_linear = torch.nn.Linear(K, N, bias=False, device=device, dtype=dtype)
torch_linear = torch.nn.Linear(K, N, bias=False, device="cpu", dtype=torch.float64)

This also lets us get rid of the TF32 logic.

Member Author

I know you like that approach, but I generally disagree - while you are correct that the FP64 would give us the best estimate of the real values that "should" be there, it doesn't necessarily make the best reference. This is because in the case of the quantized execution we know that we deviate from the ground truth (due to various factors: quantization, TF32, etc.) and using the reference that is correct but closer to the actual computation being done helps with tightening the tolerances. And those ultimately provide better confidence in the result being correct.

Comment on lines +1348 to +1353
# Share weights: TE gets the raw weight (it quantizes internally); the
# baseline gets dequantize(quantize(W)) so both do the same matmul.
W = torch.randn(N, K, dtype=dtype, device=device)
with torch.no_grad():
te_linear.weight.copy_(W)
torch_linear.weight.copy_(weight_quantizer(W).dequantize())
Collaborator

Nit: It's a bit odd that the TE and reference impls don't actually share the same weight values.

Suggested change
# Share weights: TE gets the raw weight (it quantizes internally); the
# baseline gets dequantize(quantize(W)) so both do the same matmul.
W = torch.randn(N, K, dtype=dtype, device=device)
with torch.no_grad():
te_linear.weight.copy_(W)
torch_linear.weight.copy_(weight_quantizer(W).dequantize())
# Share weights
W = torch.randn(N, K, dtype=dtype, device=device)
W = weight_quantizer(W).dequantize()
with torch.no_grad():
te_linear.weight.copy_(W)
torch_linear.weight.copy_(W)

torch.backends.cudnn.allow_tf32 = prev_tf32_cudnn


def test_mxfp8_scale_shape_partial_block():
Collaborator

This doesn't test numerics, and generally test_quantized_tensor.py is a more natural place for basic quantization tests.

Remove leading-dim alignment checks that duplicate the cuBLAS FP8
alignment requirement already enforced (with clearer error messages)
in cublaslt_gemm.cu via min_alignment_elements(): MXFP8 last_dim%16
and NVFP4 last_dim%32 in both create_tensor and get_scale_shape.

Keep quantization-specific checks: NVFP4 last_dim%2 for the 4-bit
byte-packed storage, and first_dim%16 for NVFP4 with RHT (Hadamard
transform).

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
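The quantization-scope checks this commit keeps can be sketched as follows. This is an illustrative Python rendering of the logic described in the commit message, not the actual quantizer.cpp code; the function name is hypothetical:

```python
# Hypothetical sketch: only what NVFP4 quantization itself requires, leaving
# GEMM alignment (last_dim % 32) to the cuBLAS path that imposes it.
def nvfp4_quantizer_shape_ok(first_dim: int, last_dim: int, rht_enabled: bool) -> bool:
    if last_dim % 2 != 0:  # two FP4 values are packed per byte
        return False
    if rht_enabled and first_dim % 16 != 0:  # Hadamard transform block size
        return False
    return True

print(nvfp4_quantizer_shape_ok(3, 18, rht_enabled=False))  # True: GEMM's %32 not required here
print(nvfp4_quantizer_shape_ok(8, 32, rht_enabled=True))   # False: RHT needs first_dim % 16
```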