Optimizations for MXFP8/NVFP4 dequantize kernels #2865
YigongQin wants to merge 11 commits into NVIDIA:main
Conversation
Force-pushed from f5e7375 to 39c0fb1.
The following relevant unit tests passed on SM100 (with the drop
Force-pushed from ddab15d to 3a4afdd.
After this PR, the forward pass is around 3%-4% faster for DeepSeek-shape MoE:
Greptile Summary
This PR adds GEMM-swizzled scale support to the NVFP4 and MXFP8 dequantize kernels, enables CUDA-graph compatibility via an early return on empty tensors, and removes now-unnecessary
Confidence Score: 4/5
Safe to merge with a minor fix: the missing set_scale() update doesn't break the new tests, but it contradicts the PR description and will silently fail for future callers. One P1 finding (set_scale() not updated for NVFP4 as claimed) and one P2 finding (missing null guard in from_cpu()). The kernel logic, swizzle index math, and Python module changes are correct. See tests/cpp/test_common.cu: the set_scale() condition and the from_cpu() null guard.
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_dequantize] --> B{input.numel == 0?}
    B -- yes --> C[Early return<br/>CUDA graph safe]
    B -- no --> D{scaling_mode}
    D -- NVTE_NVFP4_1D_SCALING --> E{with_gemm_swizzled_scales?}
    D -- NVTE_MXFP8_1D_SCALING --> F{with_gemm_swizzled_scales?}
    D -- NVTE_DELAYED_TENSOR_SCALING --> G[FP8 delayed kernel]
    E -- false --> H[dequantize_fp4_kernel<br/>compact scales<br/>x + y*scale_stride]
    E -- true --> I[dequantize_fp4_kernel<br/>swizzled scales<br/>gemm_swizzled_scale_idx]
    F -- false --> J[dequantize_mxfp8_kernel<br/>compact scales]
    F -- true --> K[dequantize_mxfp8_kernel<br/>swizzled scales<br/>gemm_swizzled_scale_idx]
    I --> L[mxfp8::swizzle::<br/>gemm_swizzled_scale_idx]
    K --> L
```
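To make the two scale-indexing paths concrete, here is a hedged NumPy sketch. The compact path matches the x + y*scale_stride indexing in the flowchart; the swizzled path is an assumption modeled on the commonly documented 128x4-tile GEMM scale-factor layout, and the real gemm_swizzled_scale_idx may differ in detail.

```python
import numpy as np

def compact_scale_idx(row, col, scale_stride):
    # Compact path: scales stored row-major, index = x + y * scale_stride.
    return col + row * scale_stride

def swizzled_scale_idx(row, col, padded_rows, padded_cols):
    # Assumed GEMM swizzle (illustration only, not the kernel's exact math):
    # scales live in 128x4 tiles of 512 entries; within a tile the offset is
    # (row % 32) * 16 + ((row % 128) // 32) * 4 + (col % 4).
    tiles_per_row = padded_cols // 4
    tile_base = ((row // 128) * tiles_per_row + col // 4) * 512
    r, c = row % 128, col % 4
    return tile_base + (r % 32) * 16 + (r // 32) * 4 + c

# Within one 128x4 tile the swizzle must be a bijection onto 0..511.
idxs = sorted(swizzled_scale_idx(r, c, 128, 4) for r in range(128) for c in range(4))
assert idxs == list(range(512))
```

The bijection check is the key property: dequantize with swizzled scales must read every scale exactly once, just from a permuted location.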
```cpp
std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
```
There is one edge case:
For MXFP8, when the input shape is 64x64, quantization produces a scaling-factor tensor of shape 64x2, which is then zero-padded to 128x4. Because the test buffers are allocated with torch.empty rather than torch.zeros, we should be able to inject very large random values into the padded region and detect whether the dequantize results are affected. If things work as expected, this line will be triggered and the dequantize numerics won't be affected.
For NVFP4, I think optimize-for-GEMM (i.e. the swizzle fusion) is actually not enabled; does the same apply to the zero-out edge-case handling logic? So there shouldn't be any unswizzle logic needed here?
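That edge case can be modeled in plain NumPy (the 1x32 block size, shapes, and function names here are assumptions for illustration, not TE's actual API): fill the zero-padded region of a 128x4 scale buffer with huge garbage values and check that compact-indexed dequantize never reads them.

```python
import numpy as np

BLOCK = 32  # assumed 1x32 scaling block, as in MXFP8 1D scaling

def dequantize_compact(qdata, scales_flat, scale_stride):
    # Reference dequantize, compact indexing: each 1x32 block in row r,
    # block-column bc is multiplied by scales_flat[r * scale_stride + bc].
    rows, cols = qdata.shape
    out = np.empty_like(qdata)
    for r in range(rows):
        for bc in range(cols // BLOCK):
            s = scales_flat[r * scale_stride + bc]
            out[r, bc * BLOCK:(bc + 1) * BLOCK] = qdata[r, bc * BLOCK:(bc + 1) * BLOCK] * s
    return out

rng = np.random.default_rng(0)
qdata = rng.standard_normal((64, 64))

# Valid scales form a 64x2 region, padded out to 128x4 (row stride 4).
scales = np.full(128 * 4, 1e30)              # garbage where zero-padding would be
valid = rng.uniform(0.5, 2.0, size=(64, 2))
for r in range(64):
    scales[r * 4:r * 4 + 2] = valid[r]

ref = np.repeat(valid, BLOCK, axis=1) * qdata
out = dequantize_compact(qdata, scales, 4)
assert np.isfinite(out).all() and np.allclose(out, ref)
```

If the kernel ever indexed into the padded region, the 1e30 garbage would blow up the output, so the finiteness check alone is a useful canary.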
For NVFP4, I believe currently only device-initialized grouped quantize with RHT has the swizzle-fusion feature, so zeroing out the scaling factors is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.
For both MXFP8 and NVFP4, the unit-test logic is:
1. Generate compact scales (or obtain them from quantization).
2. Call nvte_swizzle_scaling_factors to swizzle the compact scales.
3. Compare the results of nvte_dequantize with compact scales against the results with swizzled scales.
Quantize with swizzle fusion is never enabled for either MXFP8 or NVFP4.
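The three test steps can be sketched end-to-end in NumPy. The swizzle formula here is an assumed 128x4-tile layout, not necessarily what nvte_swizzle_scaling_factors does; the point is only the test structure: any bijective swizzle must leave dequantize results bit-identical.

```python
import numpy as np

BLOCK = 32  # assumed 1x32 scale block

def swizzle_idx(r, c):
    # Assumed 128x4-tile swizzle (illustration only).
    return (r % 32) * 16 + ((r // 32) % 4) * 4 + (c % 4)

rng = np.random.default_rng(1)
qdata = rng.standard_normal((128, 128))
compact = rng.uniform(0.5, 2.0, size=(128, 4))        # step 1: compact scales

swizzled = np.empty(512)                               # step 2: swizzle them
for r in range(128):
    for c in range(4):
        swizzled[swizzle_idx(r, c)] = compact[r, c]

def dequantize(q, get_scale):
    out = np.empty_like(q)
    for r in range(q.shape[0]):
        for bc in range(q.shape[1] // BLOCK):
            out[r, bc * BLOCK:(bc + 1) * BLOCK] = (
                q[r, bc * BLOCK:(bc + 1) * BLOCK] * get_scale(r, bc)
            )
    return out

ref = dequantize(qdata, lambda r, bc: compact[r, bc])                # compact path
swz = dequantize(qdata, lambda r, bc: swizzled[swizzle_idx(r, bc)])  # swizzled path
assert np.array_equal(ref, swz)                        # step 3: must match exactly
```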
Force-pushed from e6f2a6c to 0eccfb1, then from 0eccfb1 to 2c479b0.
```python
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_override == "dequantized" and (
    fp8_recipe.mxfp8() or fp8_recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
    if grad_output_quantizer is not None:
        grad_output_quantizer.optimize_for_gemm = False
```
I'm of two minds about this:
- Logically, GEMM-optimized data is not guaranteed to support anything except GEMMs. Even if the MXFP8 and NVFP4 dequantize kernels happen to support it, these are custom optimizations. Future recipes cannot be expected to support dequantizing GEMM-optimized data by default.
- It's a little pedantic to have edge-case logic that won't be triggered by any of our existing use cases. Given how subtle this is, I worry about it becoming stale and distracting.
I think for now, this change is fine. However, if we encounter problems in a future recipe, we should reimplement it properly:

```python
# LOGICALLY WRONG!
# Fails if we add a new recipe
if recipe.backward_override == "dequantized" and recipe.future_recipe():
    input_quantizer.optimize_for_gemm = False

# LOGICALLY RIGHT!
# Automatically handles new recipes
if recipe.backward_override == "dequantized" and not (
    recipe.float8_per_tensor_scaling()
    or recipe.float8_block_scaling()
    or recipe.mxfp8()
    or recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
```
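A runnable toy version of the "logically right" form: the recipe classes and the wants_gemm_optimized helper below are hypothetical stand-ins for illustration, not the TE recipe API.

```python
# Hypothetical stand-ins for the TE recipe API, just to exercise the pattern.
class Recipe:
    backward_override = "dequantized"
    def float8_per_tensor_scaling(self): return False
    def float8_block_scaling(self): return False
    def mxfp8(self): return False
    def nvfp4(self): return False

class MXFP8Recipe(Recipe):
    def mxfp8(self): return True

class FutureRecipe(Recipe):
    # A recipe added later, with no special casing anywhere.
    pass

def wants_gemm_optimized(recipe):
    # Default to disabling optimize_for_gemm for any recipe not explicitly
    # known to support dequantizing GEMM-optimized data.
    if recipe.backward_override == "dequantized" and not (
        recipe.float8_per_tensor_scaling()
        or recipe.float8_block_scaling()
        or recipe.mxfp8()
        or recipe.nvfp4()
    ):
        return False
    return True

assert wants_gemm_optimized(MXFP8Recipe()) is True
assert wants_gemm_optimized(FutureRecipe()) is False  # new recipe handled safely
```

The design point is that an unknown recipe falls into the safe branch automatically, instead of silently keeping GEMM-optimized layouts it may not support.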
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Force-pushed from 0eda58a to 1bf24be.
Signed-off-by: Ziang Li <ziangli@umich.edu> Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Force-pushed from 666c496 to 80484a9.
/te-ci pytorch L1
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
/te-ci core
timmoon10 left a comment:
LGTM, pending CI. These kernels will be very useful.