
Optimizations for MXFP8/NVFP4 dequantize kernels #2865

Open
YigongQin wants to merge 11 commits into NVIDIA:main from YigongQin:yigongq/bwd-dequantize-optim

Conversation

@YigongQin

@YigongQin YigongQin commented Apr 10, 2026

Description

  • Handle empty tensors in dequantize for CUDA graph compatibility
  • Add swizzled scale support to the NVFP4 dequantize kernel, reusing the existing MXFP8 swizzle index computation
  • Add C++ unit tests for both NVFP4 and MXFP8 dequantization (including swizzled scale variants)
  • Fix to_cpu() and set_scale() in test infrastructure to correctly sync amax/scale for NVTE_NVFP4_1D_SCALING mode

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Handle empty tensors in dequantize for CUDA graph compatibility — Early return when input has zero elements, avoiding kernel launches on empty tensors.
  • Add GEMM-swizzled scale support to NVFP4 dequantize kernel — Template the kernel with WITH_GEMM_SWIZZLED_SCALES to support reading scales from swizzled layout, reusing the MXFP8 swizzle index computation.
  • Add GEMM-swizzled scale support to MXFP8 dequantize kernel — Extend the MXFP8 dequantize kernel to handle swizzled scale inputs.
  • Add C++ unit tests for NVFP4 dequantization — 21 tests for compact scales + 21 tests for swizzled scales, covering multiple sizes and output dtypes (fp32, bf16, fp16).
  • Add C++ unit tests for MXFP8 dequantization with swizzled scales — New swizzled test suite for MXFP8.
  • Fix to_cpu() to sync amax/scale for NVFP4 tensors — Previously only synced for NVTE_DELAYED_TENSOR_SCALING, causing the CPU reference to use stale amax=0.
  • Fix set_scale() to work for NVFP4 tensors — Same condition fix, enabling the scale to be properly uploaded to GPU before quantization.
  • Fix swizzled test ordering — Move from_cpu() before the FP4 data copy to prevent from_cpu() from overwriting the copied data with zeros.
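A minimal Python-level sketch of the first bullet's guard (the actual change lives in the C++ dispatch code; the function name here is hypothetical):

```python
def dispatch_dequantize(input_numel: int, launch_kernel):
    """Hypothetical sketch of the empty-tensor guard.

    During CUDA graph capture, a zero-element input would otherwise reach
    a kernel launch with an empty grid; returning before any launch keeps
    the captured graph valid.
    """
    if input_numel == 0:
        return "early-return"   # no kernel launch is recorded
    return launch_kernel()
```

With a zero-element input, the kernel callback is never invoked.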

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from f5e7375 to 39c0fb1 on April 10, 2026 at 22:04
@zianglih
Contributor

zianglih commented Apr 14, 2026

The following relevant unit tests passed on SM100 (with the changes that drop the optimize_for_gemm = False overrides):

python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

@zianglih zianglih force-pushed the yigongq/bwd-dequantize-optim branch from ddab15d to 3a4afdd on April 14, 2026 at 18:46
@zianglih
Contributor

After this PR, the forward pass is around 3%-4% faster for DeepSeek-shape MoE:

# With the optimization
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.272829
1  32768  7168  2048  mxfp8          4             0.509788
2  65536  7168  2048  mxfp8          4             0.948633
3  98304  7168  2048  mxfp8          4             1.391146
0  16384  7168  2048  mxfp8          8             0.303238
1  32768  7168  2048  mxfp8          8             0.533896
2  65536  7168  2048  mxfp8          8             1.003446
3  98304  7168  2048  mxfp8          8             1.470030

# Without the optimization
git restore --source 77b8681de5cf -- transformer_engine/pytorch/module
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.282720
1  32768  7168  2048  mxfp8          4             0.526736
2  65536  7168  2048  mxfp8          4             0.982166
3  98304  7168  2048  mxfp8          4             1.451485
0  16384  7168  2048  mxfp8          8             0.313753
1  32768  7168  2048  mxfp8          8             0.551043
2  65536  7168  2048  mxfp8          8             1.040773
3  98304  7168  2048  mxfp8          8             1.527951
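The tables can be cross-checked with a quick calculation of the per-row relative speedup (numbers copied from the runs above):

```python
# grouped_fwd_time_ms with and without the optimization, in table row order
# (num_gemms=4 rows first, then num_gemms=8)
with_opt = [0.272829, 0.509788, 0.948633, 1.391146,
            0.303238, 0.533896, 1.003446, 1.470030]
without_opt = [0.282720, 0.526736, 0.982166, 1.451485,
               0.313753, 0.551043, 1.040773, 1.527951]

# relative reduction in forward time per row
speedups = [(slow - fast) / slow for fast, slow in zip(with_opt, without_opt)]
```

Every row lands in roughly the 3.1%-4.2% range, consistent with the 3%-4% claim.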

@YigongQin YigongQin marked this pull request as ready for review April 15, 2026 16:49
@greptile-apps
Contributor

greptile-apps Bot commented Apr 15, 2026

Greptile Summary

This PR adds GEMM-swizzled scale support to NVFP4 and MXFP8 dequantize kernels, enables CUDA-graph compatibility via early return on empty tensors, removes now-unnecessary optimize_for_gemm = False overrides from several Python module backward paths, and adds C++ unit tests for both formats.

  • P1 — set_scale() not updated for NVFP4: The PR description explicitly states "Fix set_scale() to work for NVFP4 tensors," but the condition in Tensor::set_scale() still only handles NVTE_DELAYED_TENSOR_SCALING. For NVTE_NVFP4_1D_SCALING tensors the call silently no-ops, leaving the GPU scalar scale unchanged.

Confidence Score: 4/5

Safe to merge with minor fix — the missing set_scale() update doesn't break the new tests but contradicts the PR description and will silently fail for future callers.

One P1 finding (set_scale() not updated for NVFP4 as claimed), one P2 finding (missing null guard in from_cpu()). Kernel logic, swizzle index math, and Python module changes are correct.

tests/cpp/test_common.cu — set_scale() condition and from_cpu() null guard.

Important Files Changed

Filename Overview
tests/cpp/test_common.cu Fixes to_cpu() and Tensor constructor for NVFP4, but set_scale() is not updated as claimed in the PR description, and from_cpu() lacks the null guard added to to_cpu().
tests/cpp/test_common.h Adds set_amax() helper and GPU memory cleanup for amax/scale in destructor; changes look correct.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh Templates kernel with WITH_GEMM_SWIZZLED_SCALES, reuses mxfp8 swizzle index computation; num_scale_tiles_X computation and index addressing look correct.
transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh Adds WITH_GEMM_SWIZZLED_SCALES template parameter; swizzle index computation correctly accounts for rowwise vs colwise axis transposition and num_scale_tiles_X.
transformer_engine/common/cast/dispatch/dequantize.cuh Adds early return for empty tensors (numel == 0) enabling CUDA-graph compatibility; straightforward and correct.
tests/cpp/operator/test_dequantize_nvfp4.cu New NVFP4 dequantize test suite with compact and swizzled scale variants; tests skip on non-Blackwell devices and handle empty-tensor cases.
tests/cpp/operator/test_dequantize_mxfp8.cu Extends MXFP8 test suite with swizzled-scale variant and zero-row test dimensions; logic mirrors the existing non-swizzled test correctly.
transformer_engine/pytorch/module/grouped_linear.py Removes empty-tensor special-case in backward dequantize and drops optimize_for_gemm=False override for MXFP8/NVFP4 dequantized backward.
transformer_engine/pytorch/module/linear.py Removes the optimize_for_gemm=False override block for MXFP8/NVFP4 backward; consistent with grouped_linear and layernorm_linear changes.
transformer_engine/pytorch/module/layernorm_linear.py Removes optimize_for_gemm=False override for MXFP8/NVFP4 dequantized backward path; change is symmetric with linear.py.
transformer_engine/pytorch/ops/basic/basic_linear.py Removes optimize_for_gemm=False override when backward_override is set for MXFP8/NVFP4; now safe because dequantize supports swizzled scales.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_dequantize] --> B{input.numel == 0?}
    B -- yes --> C[Early return<br/>CUDA graph safe]
    B -- no --> D{scaling_mode}
    D -- NVTE_NVFP4_1D_SCALING --> E{with_gemm_swizzled_scales?}
    D -- NVTE_MXFP8_1D_SCALING --> F{with_gemm_swizzled_scales?}
    D -- NVTE_DELAYED_TENSOR_SCALING --> G[FP8 delayed kernel]
    E -- false --> H[dequantize_fp4_kernel<br/>compact scales<br/>x + y*scale_stride]
    E -- true --> I[dequantize_fp4_kernel<br/>swizzled scales<br/>gemm_swizzled_scale_idx]
    F -- false --> J[dequantize_mxfp8_kernel<br/>compact scales]
    F -- true --> K[dequantize_mxfp8_kernel<br/>swizzled scales<br/>gemm_swizzled_scale_idx]
    I --> L[mxfp8::swizzle::gemm_swizzled_scale_idx]
    K --> L
```
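The shared gemm_swizzled_scale_idx node maps a compact (row, col) scale coordinate to its offset in the tiled layout the GEMM expects. Below is a hedged Python model of a 128-row x 4-column tile swizzle; the intra-tile formula is an assumption for illustration, not copied from the kernel, and the self-check only verifies that the mapping is a permutation of one tile:

```python
def gemm_swizzled_scale_idx(r: int, c: int, n_cols_padded: int) -> int:
    """Hypothetical model of a 128x4-tile scale swizzle.

    Assumed layout: scales are grouped into 128-row x 4-col tiles stored
    contiguously; within a tile, element (r, c) sits at
    (r % 32) * 16 + c * 4 + r // 32. The real kernel's formula may differ.
    """
    tile_r, tile_c = r // 128, c // 4
    tiles_per_row = n_cols_padded // 4
    tile_base = (tile_r * tiles_per_row + tile_c) * 512  # 128*4 scales per tile
    rr, cc = r % 128, c % 4
    return tile_base + (rr % 32) * 16 + cc * 4 + rr // 32

# Sanity check: over one padded 128x4 tile the mapping hits every slot
# exactly once, i.e. swizzling permutes rather than drops scales.
idxs = {gemm_swizzled_scale_idx(r, c, 4) for r in range(128) for c in range(4)}
```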

Comments Outside Diff (1)

  1. tests/cpp/test_common.cu, lines 517-525

    P1 set_scale() not updated for NVTE_NVFP4_1D_SCALING

    The PR description says "Fix set_scale() to work for NVFP4 tensors — Same condition fix, enabling the scale to be properly uploaded to GPU before quantization," but set_scale() still only handles NVTE_DELAYED_TENSOR_SCALING. For NVFP4 tensors the function silently does nothing — *scale_cpu_data_ is never updated and from_cpu() is never called — so calling set_scale() on an NVFP4 tensor leaves the GPU scalar scale unchanged.
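The finding can be illustrated with a Python-transliterated sketch of the condition fix the description claims (the real code is C++ in tests/cpp/test_common.cu; this is a stand-in model, not a copy of it):

```python
NVTE_DELAYED_TENSOR_SCALING = "delayed"
NVTE_NVFP4_1D_SCALING = "nvfp4_1d"

class FakeTensor:
    """Stand-in for the test-infrastructure Tensor wrapper."""
    def __init__(self, scaling_mode: str):
        self.scaling_mode = scaling_mode
        self.scale_cpu = 0.0
        self.scale_gpu = 0.0          # models the device-side scalar scale

    def from_cpu(self):
        self.scale_gpu = self.scale_cpu   # models the H2D copy

    def set_scale(self, scale: float):
        # The buggy version checked only NVTE_DELAYED_TENSOR_SCALING and
        # silently no-oped for NVFP4; the claimed fix widens the condition.
        if self.scaling_mode in (NVTE_DELAYED_TENSOR_SCALING,
                                 NVTE_NVFP4_1D_SCALING):
            self.scale_cpu = scale
            self.from_cpu()

t = FakeTensor(NVTE_NVFP4_1D_SCALING)
t.set_scale(2.5)    # with the fix, the scale reaches the (fake) GPU side
```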

Reviews (9): Last reviewed commit: "remove redundant set scale"

Comment thread on tests/cpp/operator/test_dequantize_nvfp4.cu (outdated), at the line `std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {`:
Collaborator

@zhongbozhu zhongbozhu Apr 15, 2026


There is one edge case:

For MXFP8, when the input shape is 64x64, quantization produces a scaling-factor shape of 64x2, which is then zero-padded to 128x4. We should be able to inject very large random values into the padded region at allocation time (the buffers come from torch.empty, not torch.zeros) and check whether the dequantize results are affected. If things work as expected, this line will be triggered:

// Zero out swizzled scales if padding is needed

and the dequantize numerics won't be affected.

For NVFP4, I think optimize-for-GEMM (i.e., the swizzle fusion) is actually not enabled, and the same goes for the zero-out edge-case handling:

NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");

So there shouldn't be any unswizzle logic needed here?
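The 64x64 input → 64x2 scales → 128x4 padded shape described above can be reproduced with a small sketch (the 1x32 scaling block and the 128-row x 4-column padding multiples are taken from the comment; everything else is illustrative):

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def padded_scale_shape(rows: int, cols: int, block: int = 32):
    """Compact MXFP8 rowwise scales are one per 1x32 block of elements;
    the swizzled layout pads them to multiples of 128 rows x 4 columns."""
    scale_rows = rows
    scale_cols = ceil_div(cols, block)
    return (ceil_div(scale_rows, 128) * 128, ceil_div(scale_cols, 4) * 4)
```

For a 64x64 input this yields compact scales of shape 64x2 padded out to 128x4, exactly the region where garbage values must be zeroed.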

Collaborator


For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle fusion feature, so the scaling factor zero-out is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.

Author


For both MXFP8 and NVFP4, the unit test logic is: (1) generate compact scales (or obtain them from quantization); (2) call nvte_swizzle_scaling_factors to swizzle the compact scales; (3) compare the results of nvte_dequantize with compact scales against the results with swizzled scales. Quantize-with-swizzle fusion is never enabled for either MXFP8 or NVFP4.

Comment on lines -1713 to -1719
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_override == "dequantized" and (
fp8_recipe.mxfp8() or fp8_recipe.nvfp4()
):
input_quantizer.optimize_for_gemm = False
if grad_output_quantizer is not None:
grad_output_quantizer.optimize_for_gemm = False
Collaborator

@timmoon10 timmoon10 Apr 23, 2026


I'm of two minds about this:

  • Logically, GEMM-optimized data is not guaranteed to support anything except GEMMs. Even if MXFP8 and NVFP4 dequant happen to support it, these are custom optimizations. Future recipes cannot be expected to support dequantizing GEMM-optimized data by default.
  • It's a little pedantic to have edge-case logic that won't be triggered by any of our existing use-cases. Given how subtle this is, I worry about it becoming stale and distracting.

I think for now, this change is fine. However, if we encounter problems in a future recipe, we should reimplement it properly:

# LOGICALLY WRONG!
# Fails if we add a new recipe
if recipe.backward_override == "dequantized" and recipe.future_recipe():
    input_quantizer.optimize_for_gemm = False

# LOGICALLY RIGHT!
# Automatically handles new recipes
if recipe.backward_override == "dequantized" and not (
    recipe.float8_per_tensor_scaling()
    or recipe.float8_block_scaling()
    or recipe.mxfp8()
    or recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False

CC @ptrendx @ksivaman @zhongbozhu

Comment thread tests/cpp/test_common.h Outdated
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from 0eda58a to 1bf24be on April 23, 2026 at 21:47
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from 666c496 to 80484a9 on April 23, 2026 at 21:53
@zhongbozhu
Collaborator

/te-ci pytorch L1

Comment thread tests/cpp/test_common.h
Comment thread tests/cpp/test_common.cu Outdated
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Collaborator

/te-ci core

timmoon10 previously approved these changes Apr 24, 2026
Collaborator

@timmoon10 timmoon10 left a comment


LGTM, pending CI. These kernels will be very useful.

Signed-off-by: YigongQin <qqqyyy1233@outlook.com>

5 participants