
Make TE Sequential Grouped linear Op CUDA graphable #2923

Draft

vthumbe1503 wants to merge 8 commits into NVIDIA:main from vthumbe1503:grouped_linear_integration_v2

Conversation

@vthumbe1503 (Collaborator)

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as draft April 24, 2026 20:05
@vthumbe1503 vthumbe1503 changed the title Grouped linear integration v2 Make Grouped linear TE Sequential Op CUDA graphable Apr 24, 2026
@vthumbe1503 vthumbe1503 changed the title Make Grouped linear TE Sequential Op CUDA graphable Make TE Sequential Grouped linear Op CUDA graphable Apr 24, 2026
@greptile-apps (Contributor)

greptile-apps Bot commented Apr 24, 2026

Greptile Summary

This PR refactors GroupedLinear to support a new CUDA-graph-safe forward/backward path based on GroupedTensor and general_grouped_gemm_for_grouped_tensor, while keeping the legacy tex.split_quantize + general_grouped_gemm path for quantization recipes that require CPU-side split sizes (FP8 delayed/current/block scaling, NVFP4). The dispatch selects the new path for BF16/FP16 and MXFP8; it also unifies single_grouped_weight and single_grouped_bias handling across both paths and fixes the Megatron-LM main_grad fusion logic in backward_grouped_mlp.py.
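The dispatch described above can be sketched as a small predicate. This is an illustrative stand-in, not the actual TE implementation; the function name and recipe strings below are hypothetical:

```python
from typing import Optional


def use_grouped_tensor_path(dtype: str, recipe: Optional[str]) -> bool:
    """Sketch of the path selection: pick the CUDA-graph-safe GroupedTensor
    path when split sizes can stay on the GPU, otherwise fall back to the
    legacy split_quantize path, which needs CPU-side m_splits."""
    if recipe is None:
        # Unquantized compute: the graph-safe path covers BF16/FP16.
        return dtype in ("bfloat16", "float16")
    # MXFP8 is handled on-device; FP8 delayed/current/block scaling and
    # NVFP4 still require CPU-side split sizes, so they keep the legacy path.
    return recipe == "mxfp8"
```

The key constraint is that the graph-safe path cannot read split sizes on the host during capture, which is why recipes that quantize with CPU-side m_splits remain on the legacy path.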

Confidence Score: 3/5

Safe to merge after addressing the missing contiguity error-handling and confirming the weight_requires_grad intent; no data-corruption risk identified.

All findings are P2: a missing try/except around main_grad.view(-1), a misleading comment about component-buffer saving, an undocumented behavioral change removing ctx.requires_grad from the weight_requires_grad guard, and trailing whitespace. No silent gradient corruption or data-loss path was found.

transformer_engine/pytorch/ops/basic/grouped_linear.py — specifically _fuser_backward_grouped_tensor (main_grad view) and the weight_requires_grad change in fuser_forward.

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Major refactor splitting fuser_forward/fuser_backward into two dispatch paths: a new graph-safe GroupedTensor path and the legacy split_quantize path. Missing contiguity error handling for main_grad.view(-1) and a misleading comment about component-buffer saving.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: Introduces request_main_grad_fusion to decouple the user-facing opt-in from the GEMM accumulate flag; gates post-GEMM grad_added_to_main_grad bookkeeping on the user request. Changes are targeted and well-commented.
  • tests/pytorch/test_fusible_ops.py: Extends test_grouped_linear with single_grouped_weight/single_grouped_bias parameters and adds test_grouped_linear_cuda_graph_safe. Skip conditions and grad assertions are correctly handled.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    FW["fuser_forward()"] --> DISPATCH{"use_grouped_tensor_path?"}
    DISPATCH -->|"BF16/FP16 or MXFP8"| GT["_fuser_forward_grouped_tensor()"]
    DISPATCH -->|"FP32 / FP8 delayed/current/block / NVFP4"| SQ["_fuser_forward_split_quantize()"]
    GT --> GGEMM["general_grouped_gemm_for_grouped_tensor (TN)"]
    SQ --> LEGACY["general_grouped_gemm (TN, m_splits on CPU)"]
    GGEMM --> OUT["return output"]
    LEGACY --> OUT
    BW["fuser_backward()"] --> DISP_BW{"ctx.use_grouped_tensor_path?"}
    DISP_BW -->|True| GTB["_fuser_backward_grouped_tensor()"]
    DISP_BW -->|False| SQB["_fuser_backward_split_quantize()"]
    GTB --> DGRAD["dgrad GEMM (NN, grouped)"]
    GTB --> WGRAD["wgrad GEMM (NT, grouped)"]
    SQB --> DGRAD2["dgrad GEMM (NN, m_splits)"]
    SQB --> WGRAD2["wgrad GEMM (NT, m_splits)"]

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/basic/grouped_linear.py, line 328-331 (link)

    P2 Trailing whitespace after closing paren

    There is trailing whitespace on line 331 after the closing ) of _is_grouped_quantize_supported. Most linters and CI checks flag this.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +1157 to +1170
bias_scale: Optional[torch.Tensor] = None
if has_bias:
    # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
    grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
    if self._scale_bias:
        bias_scale = scales.reshape(-1)
        if bias_scale.dtype != torch.float32:
            bias_scale = bias_scale.to(dtype=torch.float32)

# Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
general_grouped_gemm_for_grouped_tensor(
    grouped_weights,
    grouped_x,
    grouped_out,
P2 Missing contiguity error handling for main_grad.view(-1)

main_grad.view(-1) will raise a generic RuntimeError if main_grad is non-contiguous (e.g. when returned by get_main_grad() via __fsdp_param__). The equivalent code in backward_grouped_mlp.py wraps the reshape in try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
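The guard the reviewer asks for can be sketched as follows, modeled on the existing try/except in backward_grouped_mlp.py. The helper name is illustrative, not TE's:

```python
def flatten_main_grad(main_grad):
    """Flatten main_grad for wgrad accumulation, re-raising the opaque
    view() failure with shape and stride so a non-contiguous main_grad
    (e.g. one surfaced through FSDP's get_main_grad()) is diagnosable."""
    try:
        # view(-1) requires the tensor's strides to be view-compatible.
        return main_grad.view(-1)
    except RuntimeError as err:
        raise RuntimeError(
            "main_grad must be contiguous to be flattened for wgrad "
            f"accumulation (shape={tuple(main_grad.shape)}, "
            f"stride={main_grad.stride()})"
        ) from err
```

With this wrapper, the user sees the offending shape and stride instead of PyTorch's generic "view size is not compatible" message.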

Comment on lines +1191 to 1218
if ctx.requires_grad:
    saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
    if self._scale_bias:
        saved.append(scales)
    # For the wgrad input we save (data, scale_inv).
    # * Quantized path saves columnwise data + scale.
    # * Unquantized path saves the raw rowwise data and a None scale.
    if grouped_x is not None:
        if with_quantized_compute:
            saved.extend(
                [
                    grouped_x.columnwise_data,
                    grouped_x.columnwise_scale_inv,
                ]
            )
        else:
            saved.extend([grouped_x.rowwise_data, None])
    else:
        saved.extend([None, None])
    if self.single_grouped_weight:
        saved.append(grouped_weights)
    else:
        saved.extend(grouped_weights)
    ctx.save_for_backward(*saved)
    ctx.use_grouped_tensor_path = True
    ctx.with_quantized_compute = with_quantized_compute
    ctx.input_quantizers = input_quantizers
    ctx.weight_quantizers = weight_quantizers

P2 Comment contradicts implementation for weight saving

The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward" — but the code that follows saves the entire GroupedTensor wrapper for grouped_weights (when single_grouped_weight=True, saved.append(grouped_weights)). Component-buffer saving only applies to grouped_x (which saves columnwise_data/rowwise_data). The misleading comment could cause confusion when debugging or extending this path.
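The input-saving rule the comment intends (component buffers only, so save_for_backward sees plain tensors) can be shown in miniature. The GroupedTensor stand-in below is a toy dataclass for illustration, not TE's class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGroupedTensor:
    # Stand-in for a GroupedTensor wrapper: flat buffers plus scales.
    rowwise_data: list
    columnwise_data: Optional[list] = None
    columnwise_scale_inv: Optional[list] = None


def tensors_to_save(grouped_x, with_quantized_compute: bool):
    """Mirror the input-saving branch above: save component buffers,
    never the wrapper object itself."""
    if grouped_x is None:
        return [None, None]
    if with_quantized_compute:
        # Quantized path: columnwise data plus its scale-inverse.
        return [grouped_x.columnwise_data, grouped_x.columnwise_scale_inv]
    # Unquantized path: raw rowwise data and a None scale.
    return [grouped_x.rowwise_data, None]
```

The weights branch in the PR instead appends the wrapper itself when single_grouped_weight=True, which is exactly the divergence the review flags.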

Comment on lines 875 to 878
  input_requires_grad = ctx.requires_grad
- weight_requires_grad = ctx.requires_grad and weight_param.requires_grad
+ weight_requires_grad = weight_param.requires_grad

# Quantizers

P2 Behavior change: weight_requires_grad no longer gated on ctx.requires_grad

The old formula was weight_requires_grad = ctx.requires_grad and weight_param.requires_grad. Removing the ctx.requires_grad gate means that if input tensors don't require grad (ctx.requires_grad = False) but the weight does (weight_param.requires_grad = True), the new code will now save additional tensors (columnwise data, saved weights) for a backward that will never be called. If this change is intentional (e.g. to fix wgrad when only weights need grad), please add a comment explaining why ctx.requires_grad is deliberately excluded here.
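The old and new gating can be contrasted directly; the two functions below are an illustrative restatement of the review's point, not TE source:

```python
def weight_requires_grad_old(ctx_requires_grad: bool,
                             weight_param_requires_grad: bool) -> bool:
    # Old: wgrad bookkeeping only when the autograd graph is live,
    # so nothing is saved for a backward that will never run.
    return ctx_requires_grad and weight_param_requires_grad


def weight_requires_grad_new(ctx_requires_grad: bool,
                             weight_param_requires_grad: bool) -> bool:
    # New: gated on the parameter alone; with ctx_requires_grad=False
    # this still saves backward tensors (columnwise data, weights).
    return weight_param_requires_grad
```

The only case where the two disagree is ctx_requires_grad=False with a grad-requiring weight, which is precisely where the extra tensors get saved.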

vthumbe1503 and others added 4 commits April 25, 2026 00:34
# (Megatron-FSDP) or when wgrad is delayed.
if fc_op.single_grouped_weight:
    packed_wgrad = None
    if not delay_wgrad:

should this line be like this?:

if not delay_wgrad and not fc_op._accumulate_into_main_grad:
    packed_wgrad = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape)


2 participants