
Correctly pad scaling factor inverses to satisfy cuteDSL requirements#2924

Open
ksivaman wants to merge 5 commits into NVIDIA:main from ksivaman:pad_weight_scale_inv

Conversation

@ksivaman
Member

Description

Fix grouped MXFP8 swizzle when per-expert row counts are not a multiple of 128 by padding each expert's scale inverses to multiples of (128, 4).
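The padding arithmetic this implies is a round-up of each expert's scale-inverse dimensions to the next multiples of 128 and 4. A minimal C++ sketch of that arithmetic (round_up and the helper names are illustrative, not the PR's actual identifiers):

#include <cstddef>

// Round x up to the nearest multiple of mult (mult > 0).
constexpr std::size_t round_up(std::size_t x, std::size_t mult) {
  return ((x + mult - 1) / mult) * mult;
}

// Per-expert scale-inverse buffer dims: rows padded to a multiple of 128,
// columns to a multiple of 4, per the cuteDSL/cuDNN grouped GEMM requirement.
constexpr std::size_t padded_scale_rows(std::size_t rows) { return round_up(rows, 128); }
constexpr std::size_t padded_scale_cols(std::size_t cols) { return round_up(cols, 4); }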

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Ensure scaling factor inverses are padded to multiples of (128, 4) per tensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci

@greptile-apps
Contributor

greptile-apps Bot commented Apr 24, 2026

Greptile Summary

This PR fixes grouped MXFP8 swizzle correctness when per-expert row counts are not multiples of 128 by (1) allocating the output scale buffer with the correct per-tensor 128×4-padded shape instead of reusing the input shape, and (2) introducing separate input/output strides so the swizzle kernel correctly reads from the compact quantize-kernel output layout and writes to the padded layout expected by cuDNN grouped GEMM. The CUDA kernels are also refactored to promote the padding predicates to compile-time IS_PADDED_K/IS_PADDED_M template parameters, eliminating per-iteration runtime branches in the inner load loops.
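The compile-time promotion mentioned above follows the usual boolean-template dispatch pattern. A minimal CUDA sketch under assumed names (kernel body and signatures here are illustrative, not the PR's actual code):

#include <cuda_runtime.h>

template <bool IS_PADDED_K, bool IS_PADDED_M>
__global__ void swizzle_kernel_impl(const unsigned char* in, unsigned char* out) {
  if constexpr (IS_PADDED_K) {
    // Zero-fill the padded K tail instead of reading past the compact input.
  }
  if constexpr (IS_PADDED_M) {
    // Zero-fill the padded M tail.
  }
  // ... swizzle body elided ...
}

// Runtime flags are folded into template arguments once, so the inner
// load loops contain no per-iteration branches.
void dispatch_swizzle_kernel(bool padded_k, bool padded_m, const unsigned char* in,
                             unsigned char* out, dim3 grid, dim3 block) {
  if (padded_k && padded_m)
    swizzle_kernel_impl<true, true><<<grid, block>>>(in, out);
  else if (padded_k)
    swizzle_kernel_impl<true, false><<<grid, block>>>(in, out);
  else if (padded_m)
    swizzle_kernel_impl<false, true><<<grid, block>>>(in, out);
  else
    swizzle_kernel_impl<false, false><<<grid, block>>>(in, out);
}

Because the flags are baked into the template arguments, the compiler drops each dead branch entirely rather than testing it on every loop iteration.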

Confidence Score: 4/5

Safe to merge; fixes a real correctness bug with no new P0/P1 issues introduced.

Only P2 findings: a missing divisibility assertion on per_tensor_first_dim and an implicit alignment contract in the compact-buffer size detection. Core bug fix logic — split input/output strides, padded output allocation, and compile-time IS_PADDED_K/IS_PADDED_M dispatch — is correct.

The compact-buffer size detection block in swizzle.cu (~line 2077) warrants a second look for edge cases where the compact buffer lacks trailing alignment padding.

Important Files Changed

  • transformer_engine/common/swizzle/swizzle.cu: Refactors padding detection into compile-time IS_PADDED_K/IS_PADDED_M template parameters, adds dispatch helpers to avoid per-iteration runtime checks, splits scale_stride_bytes into separate input/output strides to support compact-vs-padded grouped layouts, and adds per-byte K zeroing to prevent out-of-bounds reads in the compact input case.
  • transformer_engine/pytorch/csrc/extensions/swizzle.cpp: Adds a compute_padded_grouped_scale_shape lambda to allocate the output scale buffer with the correct per-tensor 128x4-padded dimensions instead of reusing the input shape; updates post-swizzle set_rowwise/columnwise_scale_inv calls to use the actual allocated tensor shape. Minor: per_tensor_first_dim uses silent integer division without a divisibility assertion.
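The split input/output strides noted for swizzle.cu boil down to addressing the same (tensor, row, col) element against two different layouts. A sketch assuming, for simplicity, a uniform per-tensor row count (the real grouped layout may accumulate per-expert offsets instead); all names are illustrative:

#include <cstddef>

// Read offset: compact layout produced by the quantize kernel.
std::size_t input_offset(std::size_t tensor, std::size_t row, std::size_t col,
                         std::size_t rows, std::size_t compact_k) {
  return (tensor * rows + row) * compact_k + col;
}

// Write offset: 128x4-padded layout expected by cuDNN grouped GEMM.
std::size_t output_offset(std::size_t tensor, std::size_t row, std::size_t col,
                          std::size_t padded_m, std::size_t padded_k) {
  return (tensor * padded_m + row) * padded_k + col;
}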

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["maybe_swizzle_grouped_tensor()"] --> B["compute_padded_grouped_scale_shape()"]
    B --> C["allocateSpace(num_tensors × padded_m, padded_k)"]
    C --> D["nvte_swizzle_grouped_scaling_factors()"]
    D --> E["swizzle_grouped_scaling_factors() in swizzle.cu"]
    E --> F{"input_scale_numel matches?"}
    F -->|"== num_tensors × padded_scale_elems"| G["input_is_compact = false"]
    F -->|"== compact_total_scale_elems"| H["input_is_compact = true"]
    F -->|"no match"| I["NVTE_ERROR"]
    G --> J["dispatch_swizzle_*_kernel_impl()"]
    H --> J
    J --> K{"padding flags?"}
    K -->|"K+M"| L["impl<true,true>"]
    K -->|"K only"| M["impl<true,false>"]
    K -->|"M only"| N["impl<false,true>"]
    K -->|"neither"| O["impl<false,false>"]
    L & M & N & O --> P["output: padded layout\nnum_tensors × padded_m × padded_k"]

Reviews (1): Last reviewed commit: "Review suggestion from @Oleg-Goncharov"

const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
Contributor


P2 Silent truncation when logical_shape_nvte.data[0] is not divisible by num_tensors

per_tensor_first_dim is computed with plain integer division. If logical_shape_nvte.data[0] is not an exact multiple of num_tensors (e.g. due to a caller bug or unexpected grouped layout), the result is silently truncated, causing padded_m to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.

Suggested change
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");

Comment on lines +2077 to 2087
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
  input_is_compact = true;
} else {
  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
             "Grouped input columnwise_scale_inv size does not match expected packed size.");
  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
             "Grouped output columnwise_scale_inv size does not match expected packed size.");
  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
             " size does not match expected packed size (got ", input_scale_numel,
             ", expected either ", input->num_tensors * padded_scale_elems,
             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}
Contributor


P2 Implicit contract on compact-buffer alignment is not validated

The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.

Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.
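A sketch of the relaxed detection this suggests, reusing identifiers from the quoted snippet (plain_compact_elems is hypothetical):

// Also accept a compact buffer without trailing alignment slack.
const size_t plain_compact_elems = input->num_tensors * compact_scale_elems;
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems ||
           input_scale_numel == plain_compact_elems) {
  input_is_compact = true;
} else {
  NVTE_ERROR("Grouped input scale_inv size (", input_scale_numel,
             ") matches neither the per-tensor padded, the aligned compact, "
             "nor the plain compact size.");
}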

@ptrendx
Member

ptrendx commented Apr 24, 2026

@ksivaman Could you add a test exercising the change?

