Correctly pad scaling factor inverses to satisfy cuteDSL requirements #2924
ksivaman wants to merge 5 commits into NVIDIA:main from
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Greptile Summary

This PR fixes grouped MXFP8 swizzle correctness when per-expert row counts are not multiples of 128 by (1) allocating the output scale buffer with the correct per-tensor 128×4-padded shape instead of reusing the input shape, and (2) introducing separate input/output strides so the swizzle kernel correctly reads from the compact quantize-kernel output layout and writes to the padded layout expected by cuDNN grouped GEMM. The CUDA kernels are also refactored to promote the padding predicates to compile-time template parameters.

Confidence Score: 4/5

Safe to merge; fixes a real correctness bug with no new P0/P1 issues introduced. Only P2 findings: a missing divisibility assertion on `per_tensor_first_dim` and an implicit alignment contract in the compact-buffer size detection. The core bug-fix logic — split input/output strides, padded output allocation, and compile-time `IS_PADDED_K`/`IS_PADDED_M` dispatch — is correct. The compact-buffer size detection block in swizzle.cu (~line 2077) warrants a second look for edge cases where the compact buffer lacks trailing alignment padding.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["maybe_swizzle_grouped_tensor()"] --> B["compute_padded_grouped_scale_shape()"]
    B --> C["allocateSpace(num_tensors × padded_m, padded_k)"]
    C --> D["nvte_swizzle_grouped_scaling_factors()"]
    D --> E["swizzle_grouped_scaling_factors() in swizzle.cu"]
    E --> F{"input_scale_numel matches?"}
    F -->|"== num_tensors × padded_scale_elems"| G["input_is_compact = false"]
    F -->|"== compact_total_scale_elems"| H["input_is_compact = true"]
    F -->|"no match"| I["NVTE_ERROR"]
    G --> J["dispatch_swizzle_*_kernel_impl()"]
    H --> J
    J --> K{"padding flags?"}
    K -->|"K+M"| L["impl<true,true>"]
    K -->|"K only"| M["impl<true,false>"]
    K -->|"M only"| N["impl<false,true>"]
    K -->|"neither"| O["impl<false,false>"]
    L & M & N & O --> P["output: padded layout\nnum_tensors × padded_m × padded_k"]
```
Reviews (1): Last reviewed commit: "Review suggestion from @Oleg-Goncharov"
```cpp
const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
```
Silent truncation when `logical_shape_nvte.data[0]` is not divisible by `num_tensors`

`per_tensor_first_dim` is computed with plain integer division. If `logical_shape_nvte.data[0]` is not an exact multiple of `num_tensors` (e.g. due to a caller bug or an unexpected grouped layout), the result is silently truncated, causing `padded_m` to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.
Suggested change:

```cpp
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");
```
```cpp
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
  input_is_compact = true;
} else {
  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
             "Grouped input columnwise_scale_inv size does not match expected packed size.");
  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
             "Grouped output columnwise_scale_inv size does not match expected packed size.");
  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
             " size does not match expected packed size (got ", input_scale_numel,
             ", expected either ", input->num_tensors * padded_scale_elems,
             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}
```
Implicit contract on compact-buffer alignment is not validated
The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.
Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.
@ksivaman Could you add a test exercising the change?
Description
Fix grouped MXFP8 swizzle when per-expert rows aren't a multiple of 128 and pad each expert's scales to (128, 4).
Type of change
Changes
Checklist: