Make TE Sequential Grouped linear Op CUDA graphable #2923
vthumbe1503 wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR refactors the TE Sequential grouped linear op to make it CUDA graphable.

Confidence Score: 3/5 — safe to merge after addressing the missing contiguity error handling and confirming the `weight_requires_grad` intent; no data-corruption risk identified.

All findings are P2: a missing try/except around `main_grad.view(-1)`, a misleading comment about component-buffer saving, an undocumented behavioral change removing `ctx.requires_grad` from the `weight_requires_grad` guard, and trailing whitespace. No silent gradient corruption or data-loss path was found.

Important Files Changed: `transformer_engine/pytorch/ops/basic/grouped_linear.py` — specifically `_fuser_backward_grouped_tensor` (main_grad view) and the `weight_requires_grad` change in `fuser_forward`.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    FW["fuser_forward()"] --> DISPATCH{"use_grouped_tensor_path?"}
    DISPATCH -->|"BF16/FP16 or MXFP8"| GT["_fuser_forward_grouped_tensor()"]
    DISPATCH -->|"FP32 / FP8 delayed/current/block / NVFP4"| SQ["_fuser_forward_split_quantize()"]
    GT --> GGEMM["general_grouped_gemm_for_grouped_tensor (TN)"]
    SQ --> LEGACY["general_grouped_gemm (TN, m_splits on CPU)"]
    GGEMM --> OUT["return output"]
    LEGACY --> OUT
    BW["fuser_backward()"] --> DISP_BW{"ctx.use_grouped_tensor_path?"}
    DISP_BW -->|True| GTB["_fuser_backward_grouped_tensor()"]
    DISP_BW -->|False| SQB["_fuser_backward_split_quantize()"]
    GTB --> DGRAD["dgrad GEMM (NN, grouped)"]
    GTB --> WGRAD["wgrad GEMM (NT, grouped)"]
    SQB --> DGRAD2["dgrad GEMM (NN, m_splits)"]
    SQB --> WGRAD2["wgrad GEMM (NT, m_splits)"]
```
```python
bias_scale: Optional[torch.Tensor] = None
if has_bias:
    # Bias always needs to be passed as a GroupedTensor for the grouped GEMM.
    grouped_bias = self._get_grouped_bias_for_gemm(dtype, device)
    if self._scale_bias:
        bias_scale = scales.reshape(-1)
        if bias_scale.dtype != torch.float32:
            bias_scale = bias_scale.to(dtype=torch.float32)

# Forward grouped GEMM (TN layout: out[i] = x[i] @ w[i]^T)
general_grouped_gemm_for_grouped_tensor(
    grouped_weights,
    grouped_x,
    grouped_out,
```
Missing contiguity error handling for `main_grad.view(-1)`

`main_grad.view(-1)` will raise a generic `RuntimeError` if `main_grad` is non-contiguous (e.g. when returned by `get_main_grad()` via `__fsdp_param__`). The equivalent code in `backward_grouped_mlp.py` wraps the reshape in a try/except and re-raises with an actionable message that includes the shape and stride. Without that guard, users hitting this case will see an opaque PyTorch error instead of a clear diagnostic.
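A minimal sketch of the guard this finding asks for, mirroring the try/except pattern described for `backward_grouped_mlp.py`; the helper name `flatten_main_grad` is hypothetical, not part of TE's API:

```python
import torch


def flatten_main_grad(main_grad: torch.Tensor) -> torch.Tensor:
    """Flatten main_grad for wgrad accumulation (hypothetical helper).

    view(-1) requires compatible strides, so re-raise with shape/stride
    info when the tensor is non-contiguous instead of surfacing an
    opaque PyTorch error.
    """
    try:
        return main_grad.view(-1)
    except RuntimeError as err:
        raise RuntimeError(
            "Failed to flatten main_grad with view(-1); the tensor is "
            f"likely non-contiguous (shape={tuple(main_grad.shape)}, "
            f"stride={main_grad.stride()})."
        ) from err
```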
```python
if ctx.requires_grad:
    saved: list[Optional[torch.Tensor]] = [split_sizes, base_offsets]
    if self._scale_bias:
        saved.append(scales)
    # For the wgrad input we save (data, scale_inv).
    # * Quantized path saves columnwise data + scale.
    # * Unquantized path saves the raw rowwise data and a None scale.
    if grouped_x is not None:
        if with_quantized_compute:
            saved.extend(
                [
                    grouped_x.columnwise_data,
                    grouped_x.columnwise_scale_inv,
                ]
            )
        else:
            saved.extend([grouped_x.rowwise_data, None])
    else:
        saved.extend([None, None])
    if self.single_grouped_weight:
        saved.append(grouped_weights)
    else:
        saved.extend(grouped_weights)
    ctx.save_for_backward(*saved)
    ctx.use_grouped_tensor_path = True
    ctx.with_quantized_compute = with_quantized_compute
    ctx.input_quantizers = input_quantizers
    ctx.weight_quantizers = weight_quantizers
```
Comment contradicts implementation for weight saving

The block comment says "we save the GroupedTensor's component buffers (rather than the wrapper) and rebuild it in backward", but the code that follows saves the entire `GroupedTensor` wrapper for `grouped_weights` (when `single_grouped_weight=True`, `saved.append(grouped_weights)`). Component-buffer saving only applies to `grouped_x` (which saves `columnwise_data`/`rowwise_data`). The misleading comment could cause confusion when debugging or extending this path.
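To make the contrast concrete, here is a minimal sketch of the two saving strategies this finding describes, using a hypothetical stand-in class rather than TE's real `GroupedTensor`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGroupedTensor:
    """Hypothetical stand-in for TE's GroupedTensor wrapper."""
    rowwise_data: list
    columnwise_data: Optional[list] = None
    columnwise_scale_inv: Optional[list] = None


def saved_entries(grouped_x, grouped_weights, with_quantized_compute):
    saved = []
    # grouped_x: only component buffers are saved, as the comment claims.
    if with_quantized_compute:
        saved.extend([grouped_x.columnwise_data, grouped_x.columnwise_scale_inv])
    else:
        saved.extend([grouped_x.rowwise_data, None])
    # grouped_weights: the whole wrapper object is saved, which is the
    # behavior the stale comment fails to describe.
    saved.append(grouped_weights)
    return saved
```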
```diff
 input_requires_grad = ctx.requires_grad
-weight_requires_grad = ctx.requires_grad and weight_param.requires_grad
+weight_requires_grad = weight_param.requires_grad

 # Quantizers
```
Behavior change: `weight_requires_grad` no longer gated on `ctx.requires_grad`

The old formula was `weight_requires_grad = ctx.requires_grad and weight_param.requires_grad`. Removing the `ctx.requires_grad` gate means that if input tensors don't require grad (`ctx.requires_grad = False`) but the weight does (`weight_param.requires_grad = True`), the new code will now save additional tensors (columnwise data, saved weights) for a backward that will never be called. If this change is intentional (e.g. to fix wgrad when only weights need grad), please add a comment explaining why `ctx.requires_grad` is deliberately excluded here.
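A tiny illustration of the one input combination where the old and new gates disagree; the helper is hypothetical, not TE code (the real formula is set inline in `fuser_forward`):

```python
def wgrad_save_gate(ctx_requires_grad: bool, weight_param_requires_grad: bool):
    """Compare the old and new gating for saving wgrad tensors.

    Hypothetical helper for illustration only.
    """
    old_gate = ctx_requires_grad and weight_param_requires_grad  # pre-PR
    new_gate = weight_param_requires_grad                        # post-PR
    return old_gate, new_gate


# Inputs frozen but weight trainable: the only case where behavior changes.
# The old gate skips saving; the new gate saves columnwise data and weights
# even though backward may never run.
print(wgrad_save_gate(False, True))  # (False, True)
```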
```python
# (Megatron-FSDP) or when wgrad is delayed.
if fc_op.single_grouped_weight:
    packed_wgrad = None
    if not delay_wgrad:
```
Should this line instead be:

```python
if not delay_wgrad and not fc_op._accumulate_into_main_grad:
    packed_wgrad = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape)
```
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: