diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..8ae502a1a1 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_pertoken", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4PerTokenBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -195,7 +200,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +227,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." 
) @@ -239,9 +248,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py new file mode 100644 index 0000000000..93a4376eb2 --- /dev/null +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -0,0 +1,505 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for per-token NVFP4 quantization kernel (tex.quantize_nvfp4_pertoken). + +These tests validate the CUDA kernel in quantize_pertoken_nvfp4.cuh, which +performs per-row amax reduction and NVFP4 quantization in a single kernel. + +Tests require SM100+ (Blackwell) for FP4 hardware support. 
+""" + +import math +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +# Check hardware support +_, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) +nvfp4_available = te.is_nvfp4_available() + +pytestmark = pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) + +FP4_MAX = 6.0 +FP8_E4M3_MAX = 448.0 + +# FP4 E2M1 look-up table: 4-bit index -> float value +# Lower nibble = first element, upper nibble = second element +_FP4_E2M1_LUT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + """Unpack uint8 packed FP4 data to two columns per byte. + + Each byte contains 2 FP4 values: lower nibble = first, upper nibble = second. + Returns a uint8 tensor with 2x the columns. + """ + repeated = packed.repeat_interleave(2, dim=1) + repeated[:, 0::2] = repeated[:, 0::2] & 0x0F # Lower 4 bits + repeated[:, 1::2] = repeated[:, 1::2] >> 4 # Upper 4 bits + return repeated + + +def fp4_to_fp32(unpacked: torch.Tensor) -> torch.Tensor: + """Convert unpacked FP4 indices to float32 values using E2M1 LUT.""" + lut = torch.tensor(_FP4_E2M1_LUT, dtype=torch.float32, device=unpacked.device) + return lut[unpacked.long()] + + +def dequantize_pertoken_fp4( + data: torch.Tensor, scales: torch.Tensor, per_token_scales: torch.Tensor +) -> torch.Tensor: + """Dequantize per-token NVFP4: result = fp4_val * block_scale * per_token_scale. 
+ + Args: + data: (M, K/2) uint8 packed FP4 + scales: (M, K/16) uint8 block scales (FP8 E4M3) + per_token_scales: (M,) FP32 per-token global scales + + Returns: + (M, K) float32 dequantized tensor + """ + num_rows = data.shape[0] + num_cols = data.shape[1] * 2 # 2 FP4 values per byte + + # Unpack FP4 -> float32 + fp4_vals = fp4_to_fp32(unpack_fp4(data)) # (M, K) + + # Expand block scales: each scale covers 16 elements + block_scales_f32 = scales.view(torch.float8_e4m3fn).float() # (M, K/16) + block_scales_expanded = block_scales_f32.repeat_interleave(16, dim=1) # (M, K) + block_scales_expanded = block_scales_expanded[:, :num_cols] + + # Expand per-token scales: one per row + token_scales_expanded = per_token_scales.unsqueeze(1) # (M, 1) + + return fp4_vals * block_scales_expanded * token_scales_expanded + + +def _has_pertoken_kernel(): + """Check if the per-token kernel binding is available.""" + return hasattr(tex, "quantize_nvfp4_pertoken") + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- + + +def nvfp4_pertoken_quantize_ref(input_tensor: torch.Tensor): + """Pure PyTorch reference for per-token NVFP4 quantization. 
+ + Reproduces the exact logic of quantize_pertoken_nvfp4_kernel: + Pass 1: per-row amax → S_enc → per_token_scale + Pass 2: per-block(16) amax → S_dec_b (E4M3) → scale + quantize to FP4 + + Returns: + data: (M, K/2) uint8 packed FP4 + scales: (M, K/16) uint8 (FP8 E4M3 block scales) + per_token_scales: (M,) FP32 + """ + from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import cast_to_fp4x2 + + assert input_tensor.dim() == 2 + num_rows, num_cols = input_tensor.shape + assert num_cols % 16 == 0 + + x = input_tensor.float() + + # --- Pass 1: Per-row amax → S_enc → per_token_scale --- + row_amax = x.abs().amax(dim=1) # (M,) + + # compute_global_encode_scaling_factor_FP4: S_enc = fp8_max * fp4_max / amax + S_enc = FP8_E4M3_MAX * FP4_MAX / row_amax + S_enc = torch.clamp(S_enc, max=torch.finfo(torch.float32).max) + S_enc = torch.where((row_amax == 0) | (S_enc == 0), torch.ones_like(S_enc), S_enc) + + per_token_scales = 1.0 / S_enc # global_scale = 1 / S_enc + per_token_scales = torch.where( + row_amax == 0, torch.ones_like(per_token_scales), per_token_scales + ) + + # --- Pass 2: Per-block quantization --- + num_blocks = num_cols // 16 + x_blocks = x.view(num_rows, num_blocks, 16) # (M, K/16, 16) + + # Per-block amax + block_amax = x_blocks.abs().amax(dim=-1) # (M, K/16) + + # compute_decoding_scaling_factor: S_dec_b = block_amax * S_enc / fp4_max + # Then cast to FP8 E4M3 + S_enc_expanded = S_enc.unsqueeze(1) # (M, 1) + S_dec_b = block_amax * S_enc_expanded / FP4_MAX + S_dec_b = torch.clamp(S_dec_b, max=FP8_E4M3_MAX) + S_dec_b_fp8 = S_dec_b.to(torch.float8_e4m3fn) + S_dec_b_f = S_dec_b_fp8.float() + + # Block encode scale = S_enc / S_dec_b_f (inverse for quantization) + block_encode_scale = torch.where( + S_dec_b_f != 0, + S_enc_expanded / S_dec_b_f, + torch.zeros_like(S_dec_b_f), + ) # (M, K/16) + + # Scale input and clamp to FP4 range [-6, 6] + block_encode_expanded = block_encode_scale.unsqueeze(-1) # (M, K/16, 1) + scaled_x = x_blocks * 
block_encode_expanded # (M, K/16, 16) + scaled_x = scaled_x.reshape(num_rows, num_cols) + clamped_x = torch.clamp(scaled_x, -FP4_MAX, FP4_MAX) + + # Pack to FP4 using TE's reference cast_to_fp4x2 + data = cast_to_fp4x2(clamped_x) + + # Block scales as uint8 (FP8 E4M3 raw bytes) + scales = S_dec_b_fp8.view(torch.uint8) + + return data, scales, per_token_scales + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _has_pertoken_kernel(), reason="tex.quantize_nvfp4_pertoken not available") +class TestQuantizeNvfp4Pertoken: + """Test suite for per-token NVFP4 quantization kernel.""" + + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (1, 16), + (1, 256), + (4, 256), + (32, 256), + (64, 4096), + (128, 4096), + (256, 4096), + (512, 14336), + ], + ids=lambda x: f"{x}", + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_output_shapes(self, num_rows, num_cols, dtype): + """Verify output tensor shapes are correct.""" + x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") + data, scales, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + assert data.shape == (num_rows, num_cols // 2), f"data shape: {data.shape}" + assert scales.shape == (num_rows, num_cols // 16), f"scales shape: {scales.shape}" + assert per_token_scales.shape == ( + num_rows, + ), f"per_token_scales shape: {per_token_scales.shape}" + assert data.dtype == torch.uint8 + assert scales.dtype == torch.uint8 + assert per_token_scales.dtype == torch.float32 + + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (1, 256), + (32, 256), + (64, 4096), + (256, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_per_token_scales_match_reference(self, num_rows, num_cols, dtype): + """Verify per-token scales match pure PyTorch reference.""" + x = torch.randn(num_rows, num_cols, 
dtype=dtype, device="cuda") + _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + _, _, ref_scales = nvfp4_pertoken_quantize_ref(x) + + torch.testing.assert_close( + per_token_scales, + ref_scales.to(device="cuda"), + atol=1e-5, + rtol=1e-3, + msg="Per-token scales should match reference", + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_zero_input(self, dtype): + """Zero input: S_enc = 1.0 (fallback), so global_scale = 1/1 = 1.0.""" + x = torch.zeros(16, 256, dtype=dtype, device="cuda") + _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + # When amax=0, compute_global_encode_scaling_factor_FP4 returns 1.0 + # so global_scale = 1/S_enc = 1/1 = 1.0 + assert ( + per_token_scales == 1.0 + ).all(), f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}" + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_uniform_rows_same_scale(self, dtype): + """Rows with the same magnitude should produce the same per-token scale.""" + num_rows = 8 + num_cols = 256 + x = torch.randn(1, num_cols, dtype=dtype, device="cuda").expand(num_rows, -1).contiguous() + _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + # All rows identical → all scales identical + assert torch.allclose( + per_token_scales, per_token_scales[0].expand_as(per_token_scales) + ), "Identical rows should produce identical per-token scales" + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_different_rows_different_scales(self, dtype): + """Rows with different magnitudes should produce different per-token scales.""" + num_cols = 256 + # Row 0: small values, Row 1: large values + x = torch.zeros(2, num_cols, dtype=dtype, device="cuda") + x[0] = torch.randn(num_cols, dtype=dtype, device="cuda") * 0.01 + x[1] = torch.randn(num_cols, dtype=dtype, device="cuda") * 100.0 + _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + # Scale for large row should be much larger + 
assert per_token_scales[1] > per_token_scales[0] * 10, ( + f"Large row scale ({per_token_scales[1].item():.6f}) should be >> " + f"small row scale ({per_token_scales[0].item():.6f})" + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_scale_formula(self, dtype): + """Verify scale = row_amax / (fp8_max * fp4_max).""" + num_rows = 4 + num_cols = 256 + x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") + _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + # Compute expected scales + row_amax = x.float().abs().amax(dim=1) + expected_scales = row_amax / (FP8_E4M3_MAX * FP4_MAX) + + torch.testing.assert_close( + per_token_scales, + expected_scales, + atol=1e-5, + rtol=1e-3, + msg="Scale should equal row_amax / (fp8_max * fp4_max)", + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_block_scales_are_valid_fp8(self, dtype): + """Block scales should be valid FP8 E4M3 values (non-NaN, non-Inf).""" + x = torch.randn(32, 4096, dtype=dtype, device="cuda") + _, scales, _ = tex.quantize_nvfp4_pertoken(x) + + # Reinterpret uint8 as FP8 E4M3 and check for validity + scales_f32 = scales.to(torch.float8_e4m3fn).float() + assert not torch.isnan(scales_f32).any(), "Block scales contain NaN" + assert not torch.isinf(scales_f32).any(), "Block scales contain Inf" + assert (scales_f32 >= 0).all(), "Block scales should be non-negative" + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_packed_fp4_data_shape(self, dtype): + """Packed FP4 output should have exactly half the columns (2 elements per byte).""" + for num_cols in [16, 32, 256, 4096]: + x = torch.randn(4, num_cols, dtype=dtype, device="cuda") + data, _, _ = tex.quantize_nvfp4_pertoken(x) + assert data.shape[1] == num_cols // 2 + + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (4, 256), + (32, 256), + (64, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def 
test_dequantized_data_close_to_input(self, num_rows, num_cols, dtype): + """Dequantized FP4 data should be close to the original input. + + Quantize -> dequantize round-trip should preserve values within FP4 precision. + FP4 E2M1 has ~1 bit mantissa, so expect ~25% relative error for non-tiny values. + """ + torch.manual_seed(42) + x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") + data, scales, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + dequant = dequantize_pertoken_fp4(data, scales, per_token_scales) + + # Compare against original (allow FP4 quantization error) + x_f32 = x.float() + nonzero = x_f32.abs() > 0.1 # skip very small values where relative error is meaningless + if nonzero.any(): + rel_error = ((dequant[nonzero] - x_f32[nonzero]).abs() / x_f32[nonzero].abs()).mean() + assert rel_error < 0.5, ( + f"Mean relative error {rel_error:.3f} too high for FP4 round-trip " + f"(shape={num_rows}x{num_cols}, dtype={dtype})" + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_fp4_values_in_valid_range(self, dtype): + """Unpacked FP4 indices should be in [0, 15] (valid 4-bit range).""" + x = torch.randn(16, 256, dtype=dtype, device="cuda") + data, _, _ = tex.quantize_nvfp4_pertoken(x) + + unpacked = unpack_fp4(data) + assert (unpacked >= 0).all() and ( + unpacked <= 15 + ).all(), f"FP4 indices out of range: min={unpacked.min()}, max={unpacked.max()}" + + def test_input_validation_not_2d(self): + """Should reject non-2D input.""" + x = torch.randn(2, 3, 256, dtype=torch.bfloat16, device="cuda") + with pytest.raises(RuntimeError): + tex.quantize_nvfp4_pertoken(x) + + def test_input_validation_not_multiple_of_16(self): + """Should reject num_cols not divisible by 16.""" + x = torch.randn(4, 100, dtype=torch.bfloat16, device="cuda") + with pytest.raises(RuntimeError): + tex.quantize_nvfp4_pertoken(x) + + def test_input_validation_wrong_dtype(self): + """Should reject non-BF16/FP16 input.""" + x = torch.randn(4, 
256, dtype=torch.float32, device="cuda") + with pytest.raises(RuntimeError): + tex.quantize_nvfp4_pertoken(x) + + # ----------------------------------------------------------------------- + # Exact byte-match tests (following test_nvfp4_quantize_exact.py pattern) + # ----------------------------------------------------------------------- + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_fp4_data_exact_match(self, M, N, dtype): + """FP4 packed data must exactly match Python reference (byte-for-byte).""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + data, scales, pts = tex.quantize_nvfp4_pertoken(x) + ref_data, ref_scales, ref_pts = nvfp4_pertoken_quantize_ref(x) + + # Unpack both to 4-bit indices for comparison + kernel_unpacked = unpack_fp4(data) + ref_unpacked = unpack_fp4(ref_data.to(device="cuda")) + + torch.testing.assert_close( + kernel_unpacked, + ref_unpacked, + atol=0.0, + rtol=0.0, + msg=f"FP4 data mismatch for shape ({M}, {N}), dtype={dtype}", + ) + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_block_scales_exact_match(self, M, N, dtype): + """Block scales must exactly match Python reference (byte-for-byte).""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + _, scales, _ = tex.quantize_nvfp4_pertoken(x) + _, ref_scales, _ = nvfp4_pertoken_quantize_ref(x) + + torch.testing.assert_close( + scales, + ref_scales.to(device="cuda"), + atol=0.0, + rtol=0.0, + msg=f"Block scales mismatch for shape ({M}, {N}), dtype={dtype}", + ) + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + @pytest.mark.parametrize("dtype", 
[torch.bfloat16, torch.float16]) + def test_per_token_scales_exact_match(self, M, N, dtype): + """Per-token scales must exactly match Python reference.""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + _, _, pts = tex.quantize_nvfp4_pertoken(x) + _, _, ref_pts = nvfp4_pertoken_quantize_ref(x) + + torch.testing.assert_close( + pts, + ref_pts.to(device="cuda"), + atol=0.0, + rtol=0.0, + msg=f"Per-token scales mismatch for shape ({M}, {N}), dtype={dtype}", + ) + + +# --------------------------------------------------------------------------- +# Standalone test (can run without tex binding for reference validation) +# --------------------------------------------------------------------------- + + +class TestPertokenScaleReference: + """Test the pure PyTorch reference implementation (no CUDA kernel needed).""" + + def test_reference_basic(self): + """Basic reference test on CPU.""" + x = torch.tensor([[1.0, 2.0, 3.0, 4.0] * 4], dtype=torch.float32) + _, _, pts = nvfp4_pertoken_quantize_ref(x) + expected = torch.tensor([4.0 / (FP8_E4M3_MAX * FP4_MAX)]) + torch.testing.assert_close(pts, expected) + + def test_reference_multi_row(self): + """Multi-row reference test.""" + x = torch.zeros(3, 16, dtype=torch.float32) + x[0] = 1.0 + x[1] = 10.0 + x[2] = 0.1 + _, _, pts = nvfp4_pertoken_quantize_ref(x) + + assert pts[1] > pts[0] > pts[2] + torch.testing.assert_close(pts[0], torch.tensor(1.0 / (FP8_E4M3_MAX * FP4_MAX))) + torch.testing.assert_close(pts[1], torch.tensor(10.0 / (FP8_E4M3_MAX * FP4_MAX))) + + def test_reference_zero_row(self): + """Zero row: S_enc=1.0 fallback, so global_scale=1.0.""" + x = torch.zeros(2, 16, dtype=torch.float32) + x[0] = 5.0 + _, _, pts = nvfp4_pertoken_quantize_ref(x) + assert pts[1] == 1.0 diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index fd9a6416ec..f5077ee294 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -149,6 +149,13 @@ def make_recipe(name: 
Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_pertoken": + return transformer_engine.common.recipe.NVFP4PerTokenBlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 61cfacd334..ab519b98d4 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -16,6 +16,7 @@ #include "../utils.cuh" #include "dispatch/dequantize.cuh" #include "dispatch/quantize.cuh" +#include "nvfp4/quantize_pertoken_nvfp4.cuh" #include "transformer_engine/transpose.h" void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -146,3 +147,36 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out dispatch::group_quantize_fwd_host_aware_helper( input, outputs, split_sections, num_tensors, quant_config, stream); } + +void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data, + NVTETensor output_scales, NVTETensor output_per_token_scales, + size_t num_rows, size_t num_cols, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_nvfp4_pertoken); + + NVTE_CHECK(num_cols % 16 == 0, + "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); + + const void *input_ptr = nvte_tensor_data(input); + void *data_ptr = nvte_tensor_data(output_data); + void *scales_ptr = nvte_tensor_data(output_scales); + void *pertoken_ptr = nvte_tensor_data(output_per_token_scales); + const NVTEDType itype = nvte_tensor_type(input); + + using namespace transformer_engine; + + if (itype == NVTEDType::kNVTEBFloat16) { + dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( + num_rows, num_cols, reinterpret_cast(input_ptr), nullptr, + 
reinterpret_cast(data_ptr), reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), stream); + } else if (itype == NVTEDType::kNVTEFloat16) { + dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( + num_rows, num_cols, reinterpret_cast(input_ptr), nullptr, + reinterpret_cast(data_ptr), reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), stream); + } else { + NVTE_ERROR( + "Unsupported input dtype for per-token NVFP4 quantization. " + "Expected BFloat16 or Float16."); + } +} diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh new file mode 100644 index 0000000000..659835e51a --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -0,0 +1,231 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_pertoken_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. + * + * Unlike standard NVFP4 quantization which uses a single per-tensor global scale, + * per-token NVFP4 computes a separate global scale for each row. This preserves + * more dynamic range per token, improving accuracy for MoE workloads. + * + * Scaling hierarchy: + * global_scale[row] = row_amax / (fp8_max * fp4_max) + * block_scale[row, block] = block_amax / (fp4_max * global_scale[row]) + * x_fp4 = quantize_to_fp4(x / (global_scale[row] * block_scale[row, block])) + * + * Based on the approach from FlashInfer (flashinfer-ai/flashinfer#3027): + * two-pass design with one CUDA block per row. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ + +#include +#include + +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_pertoken_kernel { + +using namespace core; + +constexpr int PERTOKEN_BLOCK_SIZE = 256; +constexpr int PERTOKEN_SF_VEC_SIZE = 16; + +/* + * Per-token NVFP4 quantization kernel. + * + * One CUDA block per row. Two passes: + * Pass 1: Vectorized load + per-row amax reduction via cub::BlockReduce + * Pass 2: Reload data, compute per-block E4M3 scale, quantize to FP4 + * + * Template parameters: + * IType - Input type (half, __nv_bfloat16) + * BLOCK_SIZE - Threads per block + * + * Parameters: + * num_rows - Number of rows (tokens) + * num_cols - Number of columns (hidden dim), must be multiple of 16 + * input - Input tensor (num_rows, num_cols), IType + * row_offsets - Optional row index remapping (for MoE expert routing), or nullptr + * output_data - Output packed FP4 data (num_rows, num_cols/2), uint8 + * output_scales - Output block scales, fp8e4m3 + * output_per_token_scales - Output per-row global scales (num_rows,), FP32 + * scale_stride - Stride of scale factor output (number of SF vectors per row) + */ +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + quantize_pertoken_nvfp4_kernel( + const int num_rows, const int num_cols, const IType *__restrict__ input, + const int *__restrict__ row_offsets, // optional: nullptr for identity mapping + uint8_t *__restrict__ output_data, fp8e4m3 *__restrict__ output_scales, + float *__restrict__ output_per_token_scales, const int scale_stride) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + using namespace 
detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f + constexpr float fp4_max = TypeExtrema::max; // 6.0f + constexpr float fp4_max_inv = 1.0f / fp4_max; + + // Packed type: 4 elements per float2 pair for FP4 conversion + using IType2 = + typename std::conditional::value, half2, __nv_bfloat162>::type; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + // Optional row remapping (for MoE routing) + const int actual_row = (row_offsets != nullptr) ? row_offsets[row_idx] : row_idx; + if (actual_row < 0) return; + + const int num_vec2 = num_cols / 2; // number of IType2 elements per row + const IType2 *input_row = reinterpret_cast(input + actual_row * num_cols); + + // ========================================================================= + // Pass 1: Per-row amax reduction + // ========================================================================= + float thread_max = 0.0f; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + IType2 val = input_row[i]; + float2 fval; + if constexpr (std::is_same_v) { + fval = __half22float2(val); + } else { + fval = __bfloat1622float2(val); + } + thread_max = fmaxf(thread_max, fabsf(fval.x)); + thread_max = fmaxf(thread_max, fabsf(fval.y)); + } + + // Block-wide max reduction + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); + + // Compute and store per-token global scale + // global_scale = row_amax / (fp8_max * fp4_max) + // S_enc = fp8_max * fp4_max / row_amax (encoding scale, inverse of global_scale) + __shared__ float shared_s_enc; + if (threadIdx.x == 0) { + float s_enc = compute_global_encode_scaling_factor_FP4(row_amax); + float global_scale = (s_enc > 0.0f) ? 
(1.0f / s_enc) : 0.0f; + output_per_token_scales[row_idx] = global_scale; + shared_s_enc = s_enc; + } + __syncthreads(); + const float S_enc = shared_s_enc; + + // ========================================================================= + // Pass 2: Compute block scales and quantize to FP4 + // ========================================================================= + // Each thread processes one 16-element block: computes block amax, + // derives E4M3 block scale, then quantizes 16 elements to 8 packed FP4 bytes. + const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE; + + for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) { + const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE; + + // Load 16 elements and find block amax + float block_max = 0.0f; + float vals[PERTOKEN_SF_VEC_SIZE]; + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j++) { + if constexpr (std::is_same_v) { + vals[j] = __half2float(input[actual_row * num_cols + col_start + j]); + } else { + vals[j] = __bfloat162float(input[actual_row * num_cols + col_start + j]); + } + block_max = fmaxf(block_max, fabsf(vals[j])); + } + + // Compute per-block E4M3 scale factor: S_dec_b = block_max / (fp4_max / S_enc) + fp8e4m3 S_dec_b = quantization_SF::compute_decoding_scaling_factor(block_max, S_enc); + float S_dec_b_f = static_cast(S_dec_b); + + // Store block scale (LINEAR layout: row-major) + output_scales[row_idx * scale_stride + sf_idx] = S_dec_b; + + // Compute encoding scale for this block: maps input range to [-6, 6] (FP4 range) + float block_encode_scale = (S_dec_b_f != 0.0f) ? __fdividef(S_enc, S_dec_b_f) : 0.0f; + + // Scale values and pack to FP4 using PTX cvt.rn.satfinite.e2m1x2 + // Process 8 elements (4 pairs) at a time -> 4 bytes -> 1 uint32_t + // Matching FlashInfer's fp32_vec_to_e2m1 pattern. 
+ uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; + for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 8) { + float s0 = vals[j] * block_encode_scale; + float s1 = vals[j + 1] * block_encode_scale; + float s2 = vals[j + 2] * block_encode_scale; + float s3 = vals[j + 3] * block_encode_scale; + float s4 = vals[j + 4] * block_encode_scale; + float s5 = vals[j + 5] * block_encode_scale; + float s6 = vals[j + 6] * block_encode_scale; + float s7 = vals[j + 7] * block_encode_scale; + uint32_t packed; + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}\n" + : "=r"(packed) + : "f"(s0), "f"(s1), "f"(s2), "f"(s3), "f"(s4), "f"(s5), "f"(s6), "f"(s7)); + reinterpret_cast(out_ptr)[j / 8] = packed; + } + // Handle remaining 8 elements (PERTOKEN_SF_VEC_SIZE=16, so exactly 2 iterations of 8) + // The loop above covers j=0..7 and j=8..15, so all 16 elements are handled. + } +#endif // __CUDA_ARCH__ >= 1000 +} + +/* + * Host-side launcher for per-token NVFP4 quantization. 
+ */ +template +void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input, + const int *row_offsets, uint8_t *output_data, + fp8e4m3 *output_scales, float *output_per_token_scales, + cudaStream_t stream) { + if (num_rows == 0 || num_cols == 0) return; + + NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ", + PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols); + + const int scale_stride = num_cols / PERTOKEN_SF_VEC_SIZE; + dim3 grid(num_rows); + dim3 block(PERTOKEN_BLOCK_SIZE); + + quantize_pertoken_nvfp4_kernel + <<>>(num_rows, num_cols, input, row_offsets, output_data, + output_scales, output_per_token_scales, scale_stride); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace quantize_pertoken_kernel +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_ diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 554d8c1ac9..0661fc5454 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -453,6 +453,23 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out const NVTEQuantizationConfig quant_config, cudaStream_t stream); +/*! \brief Per-token NVFP4 quantization. + * + * Quantizes an input tensor to NVFP4 with per-row (per-token) global scaling. + * Each row gets its own FP32 global scale derived from its row-wise amax. + * + * \param[in] input Input tensor (num_rows, num_cols). + * \param[out] output_data Packed FP4 data (num_rows, num_cols/2), uint8. + * \param[out] output_scales Block scales (num_rows, num_cols/16), FP8 E4M3. + * \param[out] output_per_token_scales Per-row global scales (num_rows,), FP32. + * \param[in] num_rows Number of rows. 
+ * \param[in] num_cols Number of columns (must be multiple of 16). + * \param[in] stream CUDA stream. + */ +void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data, + NVTETensor output_scales, NVTETensor output_per_token_scales, + size_t num_rows, size_t num_cols, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..033e9b82ac 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -96,6 +96,11 @@ def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" return issubclass(cls, NVFP4BlockScaling) + @classmethod + def nvfp4_pertoken(cls): + """Whether the given recipe is NVFP4 per-token block scaling.""" + return issubclass(cls, NVFP4PerTokenBlockScaling) + @classmethod def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" @@ -540,6 +545,47 @@ def __repr__(self) -> str: ) +@dataclass() +class NVFP4PerTokenBlockScaling(NVFP4BlockScaling): + """ + NVFP4 with per-token (per-row) global scaling. + + Extends NVFP4BlockScaling by computing a separate FP32 global scale factor + for each token row, rather than a single per-tensor global scale. This + preserves more dynamic range information per token, improving accuracy + for MoE grouped GEMM workloads. + + The forward pass uses cuDNN Frontend's grouped GEMM kernels with the + ``global_scale_tensor`` parameter to apply per-token scales. The backward + pass is controlled by ``backward_override``: + + - ``None``: Use standard NVFP4 backward (default) + - ``'high_precision'``: Keep original high-precision operands for backward + - ``'dequantized'``: Dequantize saved operands to BF16/FP32 for backward + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. 
+    backward_override : {None, 'high_precision', 'dequantized'}, default = None
+        Backward precision mode. Inherited from NVFP4BlockScaling.
+    disable_rht : bool, default = False
+        If set to `True`, random Hadamard transforms are not applied.
+    disable_stochastic_rounding : bool, default = False
+        If set to `True`, stochastic rounding is disabled.
+    disable_2d_quantization : bool, default = False
+        If set to `True`, 1D block scaling with block size 16 is used for all tensors.
+
+    Notes
+    -----
+    The per-token quantization kernel (``quantize_pertoken_nvfp4.cuh``) computes
+    row-wise amax in a single pass; where it is unavailable, the per-tensor amax
+    is broadcast to all tokens as a fallback approximation.
+    """
+
+    pass
+
+
 @dataclass()
 class CustomRecipe(Recipe):
     """
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index fb5783dfcb..eefbc2fdc7 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -314,6 +314,8 @@ py::object group_dequantize(const py::handle &input, DType otype);
 py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
                                 const size_t num_tensors, std::optional first_dims);
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tensor input);
+
 std::vector multi_tensor_quantize(const std::vector &tensor_list,
                                   std::vector quantizer_list);
 
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 5fb162c72d..44a9eb2a1a 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1556,5 +1556,47 @@ std::vector split_quantize(const at::Tensor &tensor,
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tensor input) {
+  // Input validation
+  NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)");
+  NVTE_CHECK(input.is_cuda(), "Input must be on CUDA 
device"); + NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Half, + "Input must be BFloat16 or Half"); + + const int num_rows = input.size(0); + const int num_cols = input.size(1); + NVTE_CHECK(num_cols % 16 == 0, + "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); + + if (num_rows == 0) { + auto options = input.options(); + return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), + at::empty({0, num_cols / 16}, options.dtype(at::kByte)), + at::empty({0}, options.dtype(at::kFloat))}; + } + + auto input_contig = input.contiguous(); + auto options = input_contig.options(); + + // Allocate outputs + auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); + auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); + auto output_per_token_scales = at::empty({num_rows}, options.dtype(at::kFloat)); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // Call C API + auto te_input = makeTransformerEngineTensor(input_contig); + auto te_data = makeTransformerEngineTensor(output_data); + auto te_scales = makeTransformerEngineTensor(output_scales); + auto te_pertoken = makeTransformerEngineTensor(output_per_token_scales); + + nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(), + te_pertoken.data(), num_rows, num_cols, stream); + + return {output_data, output_scales, output_per_token_scales}; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 27d26d3dab..a7ca590478 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -145,6 +145,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Dequantize group tensor", py::arg("input"), py::arg("otype")); m.def("bgrad_group_quantize", 
transformer_engine::pytorch::bgrad_group_quantize, py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); + m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken, + "Per-token NVFP4 quantization", py::arg("input")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index e21915a5a6..cd9b68c1e8 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -156,7 +156,7 @@ def fuse_grouped_mlp_ops( if not fused_op_cls.is_supported(): return ops - if recipe is None or not recipe.mxfp8(): + if recipe is None or not (recipe.mxfp8() or recipe.nvfp4_pertoken()): return ops fc1_bias_ok = ( diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19a090f121..06197db66f 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -33,6 +33,7 @@ # Note: Registration logic is non-trivial, so submodule handles it internally. 
from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 90c4204f06..17d66f906b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -19,8 +19,9 @@ from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...constants import MXFP8_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ...tensor.nvfp4_tensor import NVFP4Quantizer +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -543,6 +544,476 @@ def fuser_forward( return fc2_out, [(), (), ()] +class ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4(FusedOperation): + """Fused op for NVFP4 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end with NVFP4 + (FP4 E2M1 data + FP8 E4M3 block scales + FP32 per-token global scale). + + Forward pass only. Backward falls back to the unfused path. 
+ """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_glu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_glu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_glu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_quant_kernel(cls) -> Callable: + """Grouped GEMM quant kernel for block-scaled inputs.""" + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_quant_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if int(os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4", "0")) <= 0: + return False + if get_device_compute_capability()[0] != 10: + return False + try: + cls.grouped_gemm_glu_kernel() + cls.grouped_gemm_quant_kernel() + except ImportError: + return False + return True + + def __init__(self, ops: tuple[FusibleOperation, ...]) -> None: + super().__init__(ops) + fc1, swiglu, fc2 = ops + if not isinstance(fc1, GroupedLinear): + raise TypeError(f"Expected GroupedLinear for FC1, got {type(fc1).__name__}") + if not isinstance(swiglu, ScaledSwiGLU): + raise TypeError(f"Expected ScaledSwiGLU, got {type(swiglu).__name__}") + if not isinstance(fc2, GroupedLinear): + raise TypeError(f"Expected GroupedLinear for FC2, got {type(fc2).__name__}") + validate_grouped_mlp_dims(fc1, swiglu, fc2) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + 
fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + input_ = input_.reshape(-1, fc1_weight_shape[1]) + in_shape = list(input_.size()) + + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + device = fc1_weight_param.device + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_weight_param.dtype + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = requires_grad and ( + fc1_weight_param.requires_grad or fc2_weight_param.requires_grad + ) + + # Quantizers + fc1_input_quantizer = fc1_op.get_quantizer("forward", 0) + fc1_weight_quantizer = fc1_op.get_quantizer("forward", 1) + fc1_grad_output_quantizer = fc1_op.get_quantizer("backward", 0) + fc2_input_quantizer = fc2_op.get_quantizer("forward", 0) + fc2_weight_quantizer = fc2_op.get_quantizer("forward", 1) + fc2_grad_output_quantizer = fc2_op.get_quantizer("backward", 0) + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." 
+ ) + split_sizes = fc1_split_sizes + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = torch.cumsum(split_sizes, 0, dtype=torch.int) + split_points_offsets = torch.cumsum(split_sizes, 0) + base_offsets = torch.cat( + [ + torch.zeros(1, device=split_sizes.device, dtype=split_sizes.dtype), + split_points_offsets, + ] + ) + fc1_x_tensor_offsets = base_offsets * fc1_weight_shape[1] + fc2_x_tensor_offsets = base_offsets * fc2_weight_shape[1] + + # Extract post-scales from extra input + scales = basic_op_extra_inputs[1][0] + + # Prepare FC1 grouped weight tensor for fused kernels. + if fc1_op.single_grouped_weight: + if not isinstance(fc1_op.weight, GroupedTensor): + raise RuntimeError( + "FC1 expected GroupedTensor weight with single_grouped_weight=True." + ) + if fc1_op.weight.quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc1_op.weight.quantizer = fc1_weight_quantizer + grouped_fc1_weight = fc1_op.weight + else: + if fc1_op.weight.rowwise_data is None: + raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc1_weight = tex.group_quantize( + fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), + fc1_weight_quantizer, + num_groups, + None, + ) + else: + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc1_weights = [] + for idx, weight in enumerate(fc1_weights): + quantizer = fc1_op.get_quantizer("forward", 2 * idx + 1) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc1_weights.append(quantizer(weight)) + else: + quantized_fc1_weights.append(weight) + grouped_fc1_weight = quantized_fc1_weights + + # Prepare FC2 grouped weight 
tensor for fused kernels. + if fc2_op.single_grouped_weight: + if not isinstance(fc2_op.weight, GroupedTensor): + raise RuntimeError( + "FC2 expected GroupedTensor weight with single_grouped_weight=True." + ) + if fc2_op.weight.quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc2_op.weight.quantizer = fc2_weight_quantizer + grouped_fc2_weight = fc2_op.weight + else: + if fc2_op.weight.rowwise_data is None: + raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc2_weight = tex.group_quantize( + fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), + fc2_weight_quantizer, + num_groups, + None, + ) + else: + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc2_weights = [] + for idx, weight in enumerate(fc2_weights): + quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc2_weights.append(quantizer(weight)) + else: + quantized_fc2_weights.append(weight) + grouped_fc2_weight = quantized_fc2_weights + + # Enforce default swizzle metadata + if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc1_weight, GroupedTensor + ): + grouped_fc1_weight._with_gemm_swizzled_scales = False + if getattr(grouped_fc2_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc2_weight, GroupedTensor + ): + grouped_fc2_weight._with_gemm_swizzled_scales = False + + # Group-quantize input tensor (NVFP4) + fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc1_input_quantizer.optimize_for_gemm = True + if isinstance(input_, GroupedTensor) and isinstance( + getattr(input_, "quantizer", None), 
NVFP4Quantizer + ): + grouped_fc1_x = input_ + else: + fc1_x = maybe_dequantize(input_, dtype) + grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) + + # Pack data tensors for cuDNN kernel + # NVFP4: data is uint8 (packed FP4), reinterpret as float4_e2m1fn_x2 + # Scales are uint8, reinterpret as float8_e4m3fn + # Block size is 16 (NVFP4_BLOCK_SCALING_SIZE) + fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1] // 2) + fc1_x_data = fc1_x_data.view(dtype=torch.float4_e2m1fn_x2) + fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + + fc1_x_scales = grouped_fc1_x.scale_inv + fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e4m3fn) + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + in_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + + # Per-token global scale. + # The per-token NVFP4 kernel (tex.quantize_nvfp4_pertoken) produces + # data + block_scales + per_token_scales in one pass. Here we call it + # to get the per-token scales. The quantized data from group_quantize + # (above) is used for the GEMM since it handles grouped layout/swizzle. + # TODO: Unify into a single quantization call once the grouped per-token + # kernel supports the full TE scale factor layout. 
+ global_scale_tensor = None + try: + _, _, fc1_per_token_scales = tex.quantize_nvfp4_pertoken( + fc1_x.reshape(in_shape[0], in_shape[1]) + if not isinstance(input_, GroupedTensor) + else input_.dequantize(dtype=dtype).reshape(in_shape[0], in_shape[1]) + ) + global_scale_tensor = fc1_per_token_scales.reshape(-1, 1, 1) + except (AttributeError, RuntimeError): + # Fallback: per-tensor amax broadcast to all tokens + nvfp4_amax = grouped_fc1_x.amax + if nvfp4_amax is not None and nvfp4_amax.numel() == 1: + fp4_max = 6.0 + fp8_max = 448.0 + global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) + global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) + + alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + current_stream = torch.cuda.current_stream().cuda_stream + + fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) + fc2_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc2_op) + + fc1_glu_kwargs = { + "a_tensor": fc1_x_data, + "sfa_tensor": fc1_x_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "bias_tensor": fc1_bias_packed, + "norm_const_tensor": norm_const_tensor, + "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "global_scale_tensor": global_scale_tensor, + "acc_dtype": torch.float32, + "c_dtype": torch.bfloat16, + "d_dtype": torch.bfloat16, # NVFP4 output stays BF16 (no FP8 re-quant for FC2 input) + "cd_major": "n", + "sf_vec_size": NVFP4_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": False, + "act_func": "swiglu", + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=True, columnwise=False) + + fc1_w_data = fc1_weight_for_gemm.rowwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float4_e2m1fn_x2) + fc1_w_data = fc1_w_data.view(num_groups, 
fc1_weight_shape[0], fc1_weight_shape[1] // 2) + fc1_w_data = fc1_w_data.permute(1, 2, 0) + fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) + fc1_w_scales = fc1_w_scales.view( + num_groups, + fc1_weight_shape[0] // 128, + fc1_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_glu_kwargs["b_tensor"] = fc1_w_data + fc1_glu_kwargs["sfb_tensor"] = fc1_w_scales + else: + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc1_weight], + [w._rowwise_scale_inv for w in grouped_fc1_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc1_weight[0]._fp8_dtype, + ) + fc1_glu_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_glu_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_glu_kwargs["n"] = fc1_weight_shape[0] + fc1_glu_kwargs["b_dtype"] = torch.float4_e2m1fn_x2 + fc1_glu_kwargs["b_major"] = "k" + + fc1_kernel_out = self.grouped_gemm_glu_kernel()(**fc1_glu_kwargs) + + # Unpack FC1 kernel outputs + # NVFP4 FC1 output is BF16 (no SFD generation needed for FC2) + swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0]) + + fc2_in_data = fc1_kernel_out["d_tensor"] + fc2_in_data = fc2_in_data.view(in_shape[0], fc2_weight_shape[1]) + + # FC2 GEMM: input is BF16 from FC1 output, needs re-quantization to NVFP4 + # For now, quantize the BF16 FC2 input to NVFP4 for the quant kernel + fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc2_input_quantizer.optimize_for_gemm = True + grouped_fc2_x = tex.group_quantize( + fc2_in_data, fc2_input_quantizer, num_groups, split_sizes + ) + + fc2_x_data = grouped_fc2_x.rowwise_data.view(in_shape[0], fc2_weight_shape[1] // 2) + fc2_x_data = fc2_x_data.view(dtype=torch.float4_e2m1fn_x2) + fc2_x_data = fc2_x_data.unsqueeze(0).permute(1, 2, 0) + + fc2_x_scales = grouped_fc2_x.scale_inv + 
fc2_x_scales = fc2_x_scales.view(dtype=torch.float8_e4m3fn) + fc2_x_scales = fc2_x_scales.view( + 1, + in_shape[0] // 128, + fc2_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_x_scales = fc2_x_scales.permute(3, 4, 1, 5, 2, 0) + + # FC2 per-token global scale + fc2_nvfp4_amax = grouped_fc2_x.amax + if fc2_nvfp4_amax is not None and fc2_nvfp4_amax.numel() == 1: + fp4_max = 6.0 + fp8_max = 448.0 + fc2_gs_val = fc2_nvfp4_amax.float() / (fp4_max * fp8_max) + fc2_global_scale = fc2_gs_val.expand(in_shape[0]).reshape(-1, 1, 1) + else: + fc2_global_scale = None + + fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] + fc2_quant_kwargs = { + "a_tensor": fc2_x_data, + "sfa_tensor": fc2_x_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor.float(), + "norm_const_tensor": None, + "prob_tensor": torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device), + "global_scale_tensor": fc2_global_scale, + "acc_dtype": torch.float32, + "c_dtype": dtype, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": NVFP4_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "use_dynamic_sched": True, + } + + if fc2_op.single_grouped_weight: + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) + + fc2_w_data = fc2_weight_for_gemm.rowwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float4_e2m1fn_x2) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2) + fc2_w_data = fc2_w_data.permute(1, 2, 0) + + fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) + fc2_w_scales = fc2_w_scales.view( + num_groups, + fc2_weight_shape[0] // 128, + fc2_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_quant_kwargs["b_tensor"] = fc2_w_data + fc2_quant_kwargs["sfb_tensor"] = 
fc2_w_scales + else: + fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc2_weight], + [w._rowwise_scale_inv for w in grouped_fc2_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc2_weight[0]._fp8_dtype, + ) + fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_quant_kwargs["n"] = fc2_weight_shape[0] + fc2_quant_kwargs["b_dtype"] = torch.float4_e2m1fn_x2 + fc2_quant_kwargs["b_major"] = "k" + + fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) + fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() + + # Save state for backward pass + if requires_grad: + mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) + fc1_input_tensors = ( + grouped_fc1_x.columnwise_data, + grouped_fc1_x.columnwise_scale_inv, + fc1_x_tensor_offsets, + ) + fc1_weight_tensors = ( + [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight + ) + fc1_ctx.save_for_backward( + split_sizes, split_points, *fc1_weight_tensors, *fc1_input_tensors + ) + fc1_ctx.with_quantized_compute = True + fc1_ctx.input_quantizer = fc1_input_quantizer + fc1_ctx.weight_quantizer = fc1_weight_quantizer + fc1_ctx.grad_output_quantizer = fc1_grad_output_quantizer + fc1_ctx.grad_input_quantizers = None + fc1_ctx.dtype = dtype + fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = weight_requires_grad + fc1_ctx.base_split_offsets = base_offsets + + swiglu_ctx.save_for_backward(swiglu_in, scales) + swiglu_ctx.input_requires_grad = True + swiglu_ctx.extra_input_requires_grad = True + swiglu_ctx.dtype = dtype + + if grouped_fc2_x is not None: + fc2_input_tensors = ( + grouped_fc2_x.columnwise_data, + grouped_fc2_x.columnwise_scale_inv, + fc2_x_tensor_offsets, + ) + else: + fc2_input_tensors = (None, None, None) + + if fc2_op.single_grouped_weight: + fc2_ctx.save_for_backward(split_sizes, 
grouped_fc2_weight, *fc2_input_tensors) + else: + fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors) + + fc2_ctx.with_quantized_compute = True + fc2_ctx.input_quantizer = fc2_input_quantizer + fc2_ctx.weight_quantizer = fc2_weight_quantizer + fc2_ctx.grad_output_quantizer = fc2_grad_output_quantizer + fc2_ctx.grad_input_quantizers = None + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = weight_requires_grad + + return fc2_out, [(), (), ()] + + def fuse_forward_ops( ops: list[FusibleOperation], *, @@ -572,6 +1043,22 @@ def fuse_forward_ops( ) -# Register fusion if available +def fuse_forward_ops_nvfp4( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply NVFP4 operation fusion for forward pass.""" + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, + ) + + +# Register fusions if available if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) +if ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported(): + register_forward_fusion(fuse_forward_ops_nvfp4, prepend=True) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9956fb77ec..aefe8af39c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -24,6 +24,7 @@ Float8CurrentScaling, Float8BlockScaling, NVFP4BlockScaling, + NVFP4PerTokenBlockScaling, CustomRecipe, ) from .constants import dist_group_type @@ -1065,6 +1066,8 @@ def create( cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState + elif recipe.nvfp4_pertoken(): + cls = NVFP4PerTokenBlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState elif recipe.custom(): @@ -1396,6 
+1399,21 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
         raise RuntimeError(f"Unexpected recipe mode ({self.mode})")
 
 
+class NVFP4PerTokenBlockScalingRecipeState(NVFP4BlockScalingRecipeState):
+    """State for NVFP4PerTokenBlockScaling recipe.
+
+    Inherits all quantizer creation logic from NVFP4BlockScalingRecipeState.
+    The per-token global scale is handled at the fused op level (in the cuDNN
+    kernel via global_scale_tensor), not in the quantizer itself. The quantizer
+    still produces per-tensor amax which is broadcast to per-token in the fused op.
+
+    Once the quantizer integrates the per-token kernel
+    (quantize_pertoken_nvfp4.cuh), it will produce per-row amax directly.
+    """
+
+    pass
+
+
 class CustomRecipeState(RecipeState):
     """State for CustomRecipe: produce quantizers per tensor."""