
[PyTorch] Fix stale columnwise data usage #2925

Open
ksivaman wants to merge 1 commit into NVIDIA:main from ksivaman:fix_columnwise_usage_after_eval

Conversation

@ksivaman
Member

Description

This PR explicitly sets columnwise usage on all quantizers instead of retaining the value previously stored in the quantizer. The retained value can be stale when training resumes after validation steps, because columnwise usage is set to False during eval mode.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Set columnwise usage explicitly instead of retaining the value in the quantizer.
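To make the failure mode concrete, here is a minimal self-contained sketch (toy classes and names, not the real TransformerEngine API) of how reaching `set_usage` via an `elif` that skips pre-quantized weights leaves a stale `columnwise_usage` after an eval pass:

```python
class ToyQuantizer:
    """Toy stand-in for a TE quantizer; only models the usage flags."""

    def __init__(self):
        self.rowwise_usage = True
        self.columnwise_usage = True

    def set_usage(self, rowwise, columnwise):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise


def forward_pass(quantizer, grad_enabled, weight_prequantized, old_behavior):
    # Old code reached set_usage via an `elif`, so a pre-quantized weight
    # skipped the update and the quantizer kept whatever the previous pass
    # (e.g. an eval step) had written.
    if old_behavior and weight_prequantized:
        pass  # stale: usage flags retained from the last pass
    else:
        # Fixed code: always refresh usage from the current autograd state.
        quantizer.set_usage(rowwise=True, columnwise=grad_enabled)
    return quantizer.columnwise_usage


# train -> eval -> resume training, old behavior: columnwise stays False
q_old = ToyQuantizer()
forward_pass(q_old, grad_enabled=True, weight_prequantized=False, old_behavior=True)
forward_pass(q_old, grad_enabled=False, weight_prequantized=False, old_behavior=True)
stale = forward_pass(q_old, grad_enabled=True, weight_prequantized=True, old_behavior=True)

# same sequence with the fix: columnwise is refreshed to True
q_new = ToyQuantizer()
forward_pass(q_new, grad_enabled=True, weight_prequantized=False, old_behavior=False)
forward_pass(q_new, grad_enabled=False, weight_prequantized=False, old_behavior=False)
fixed = forward_pass(q_new, grad_enabled=True, weight_prequantized=True, old_behavior=False)

print(stale, fixed)  # False True
```

Under the old behavior the resumed training step silently reuses the eval-mode flags; the fix re-derives them on every forward.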

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from timmoon10 on April 25, 2026 at 01:24
@greptile-apps
Contributor

greptile-apps Bot commented Apr 25, 2026

Greptile Summary

This PR fixes stale columnwise_usage state on weight quantizers by changing elif conditions to if across linear.py, layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py, so set_usage is always called with the correct value rather than relying on whatever was last written to the quantizer (which could be False after an eval pass). The FSDP2 tensor-level fix replaces the stale self._quantizer.columnwise_usage read in all three tensor types with is_backward_pass or torch.is_grad_enabled().

  • P1 — linear.py: The old elif isinstance(weight, QuantizedTensor): weight_quantizer = weight._quantizer defensive path is gone. If weight_quantizer arrives as None while the weight is already a QuantizedTensor, quantize_weight immediately dereferences quantizer.rowwise_usage and raises AttributeError. Restoring the elif assignment (without calling set_usage on it, since it is None) preserves the original safety net.
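A minimal sketch of the scenario this bullet describes, using hypothetical stand-ins for the real classes (`quantize_weight` here only mimics the unconditional attribute access the review points at in base.py):

```python
class ToyQuantizer:
    rowwise_usage = True


class ToyQuantizedTensor:
    # pre-quantized weight carrying its own quantizer
    def __init__(self):
        self._quantizer = ToyQuantizer()


def quantize_weight(weight, quantizer):
    # like the real helper, dereferences the quantizer unconditionally
    return quantizer.rowwise_usage


weight = ToyQuantizedTensor()
weight_quantizer = None  # caller passed no quantizer

# Without the defensive fallback, a None quantizer crashes immediately.
try:
    quantize_weight(weight, weight_quantizer)
    crashed = False
except AttributeError:
    crashed = True

# The old `elif` fallback instead picks up the weight's own quantizer.
if weight_quantizer is None and isinstance(weight, ToyQuantizedTensor):
    weight_quantizer = weight._quantizer
result = quantize_weight(weight, weight_quantizer)

print(crashed, result)  # True True
```

This is why restoring the `elif` assignment (without calling `set_usage` on a None quantizer) preserves the original safety net.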

Confidence Score: 3/5

Mostly safe bug fix, but one module-level regression in linear.py drops a defensive assignment that guards against an AttributeError in quantize_weight.

The core fix (elif→if and FSDP2 tensor changes) is correct and well-tested for the primary code paths. A P1 regression in linear.py removes a guard for the weight_quantizer=None + QuantizedTensor scenario, turning a silent fallback into a potential crash. While this path may be rarely hit in practice, the asymmetry with how layernorm_mlp.py and layernorm_linear.py handle the same pattern makes linear.py's divergence worth addressing.

transformer_engine/pytorch/module/linear.py — dropped defensive quantizer assignment before quantize_weight call

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Correctly fixes stale columnwise_usage by always calling set_usage when weight_quantizer is non-None, but the old elif isinstance(weight, QuantizedTensor) fallback that prevents a crash in quantize_weight when weight_quantizer is None was dropped.
transformer_engine/pytorch/module/layernorm_linear.py Clean fix: elif → if ensures set_usage is called unconditionally when weight_quantizer is not None, including after the weight-quantizer re-assignment for pre-quantized weights; FSDP2 guard (is_fsdp2 flag) preserved.
transformer_engine/pytorch/module/layernorm_mlp.py Same correct elif → if pattern applied to both fc1 and fc2 weight quantizers, ensuring set_usage propagates after the weight-quantizer re-assignment step.
transformer_engine/pytorch/module/grouped_linear.py Correctly restructures condition so set_usage is always called on the true per-weight quantizers (extracted from the pre-quantized weights) rather than skipping when weights are already quantized.
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Adds set_usage propagation to pre-quantized weight tensors inside GroupedTensor.quantized_tensors and in the discrete-weight path; FC2 non-grouped path has a pre-existing redundant set_usage call now made more visible.
transformer_engine/pytorch/tensor/float8_tensor.py Replaces stale self._quantizer.columnwise_usage with is_backward_pass or torch.is_grad_enabled() for the FSDP2 reshard_after_forward=False path; also extracts training_state/is_backward_pass before the if reshard_after_forward branch.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Same FSDP2 reshard_after_forward=False fix as float8_tensor.py; clean and symmetric change.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Replaces stale quantizer state with is_backward_pass or torch.is_grad_enabled(); adds a RuntimeError when columnwise data is absent but needed — the error message could hint at wrapping eval loops with torch.no_grad() for clarity.
tests/pytorch/test_sanity.py Adds a focused regression test (test_quantizer_columnwise_usage_after_eval) covering the train→eval→train columnwise state scenario for all four module types across all three quantization recipes.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward pass enters module] --> B{weight is QuantizedTensorStorage AND not debug?}
    B -- Yes --> C[weight_quantizer = weight._quantizer]
    B -- No --> D{weight_quantizer is not None?}
    C --> D
    D -- Yes --> E[set_usage rowwise=True columnwise=is_grad_enabled ...]
    D -- No --> F[skip set_usage]
    E --> G[quantize_weight]
    F --> G
    G --> H{FSDP2 reshard_after_forward?}
    H -- Yes --> I[columnwise = is_backward_pass]
    H -- No --> J[columnwise = is_backward_pass OR grad_enabled]
    J --> K{columnwise=True but _columnwise_data is None?}
    K -- Yes --> L[RuntimeError raised - mxfp8_tensor only]
    K -- No --> M[all-gather sharded tensors]
    I --> M
    M --> N[GEMM forward]

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py, line 274-278 (link)

    P2 Redundant set_usage call in the FC2 non-grouped path

    quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) appears unconditionally at the start of the loop iteration (line 275) and then again inside the if not is_quantized_tensor(weight) branch (line 277), so for non-quantized weights the call is made twice with the same arguments. The outer call is unnecessary and can be removed.

Reviews (1): Last reviewed commit: "Fix columnwise_usage after eval mode"

Comment on lines +264 to +266
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer

P1 Dropped defensive weight_quantizer assignment loses the quantize_weight call

The original elif isinstance(weight, QuantizedTensor): weight_quantizer = weight._quantizer handled the case where weight_quantizer arrives as None while weight is already a QuantizedTensor. In that path, quantize_weight immediately dereferences quantizer.rowwise_usage (line 710 of base.py) and will raise AttributeError: 'NoneType' object has no attribute 'rowwise_usage'.

The new code only re-assigns weight_quantizer when it is already non-None, so the previously guarded scenario now crashes instead of falling back to the weight's own quantizer. The missing assignment should be:

if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
    columnwise_usage = ...
    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
    # weight_quantizer is None but weight is pre-quantized — pick up its quantizer
    weight_quantizer = weight._quantizer

Comment on lines +684 to +693
+ columnwise_usage = is_backward_pass or torch.is_grad_enabled()
  sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
- columnwise_usage = self._quantizer.columnwise_usage
  if columnwise_usage:
      if self._columnwise_data is None or columnwise_scale_inv is None:
          raise RuntimeError(
              "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
              "for the upcoming backward pass, but the local shard has none. "
              "Ensure the weight is quantized with columnwise_usage=True before "
              "this all-gather."
          )

P2 torch.is_grad_enabled() can be True during eval without torch.no_grad()

model.eval() alone does not disable the gradient tape — torch.is_grad_enabled() stays True unless the caller wraps the eval loop with torch.no_grad(). In that situation columnwise_usage becomes True, but the local shard may not have _columnwise_data (it was never quantized with columnwise support during eval), so the new RuntimeError fires.

Users who ran eval with grads enabled previously got silently incorrect (stale) data; they now get a hard crash. While the crash is more correct, the error message could guide them:

Suggested change

  columnwise_usage = is_backward_pass or torch.is_grad_enabled()
  sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
  if columnwise_usage:
      if self._columnwise_data is None or columnwise_scale_inv is None:
          raise RuntimeError(
              "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
              "for the upcoming backward pass, but the local shard has none. "
              "Ensure the weight is quantized with columnwise_usage=True before "
-             "this all-gather."
+             "this all-gather. If you are running evaluation without requiring "
+             "gradients, wrap the eval loop with torch.no_grad()."
          )
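The behavioral difference can be sketched with a pure-Python stand-in for the decision (hypothetical helper names; the real check lives in the tensor classes, and `grad_enabled` models `torch.is_grad_enabled()`):

```python
def needs_columnwise(is_backward_pass, grad_enabled):
    # mirrors: columnwise_usage = is_backward_pass or torch.is_grad_enabled()
    return is_backward_pass or grad_enabled


def all_gather_shard(columnwise_data, is_backward_pass, grad_enabled):
    # toy stand-in for the FSDP2 all-gather pre-check on the local shard
    if needs_columnwise(is_backward_pass, grad_enabled) and columnwise_data is None:
        raise RuntimeError(
            "columnwise data needed for backward but absent on the local shard; "
            "wrap eval loops that do not need gradients in torch.no_grad()"
        )
    return "ok"


# Eval forward WITHOUT torch.no_grad(): grad mode is still on, so the new
# check fires instead of silently using stale data.
try:
    all_gather_shard(columnwise_data=None, is_backward_pass=False, grad_enabled=True)
    crashed = False
except RuntimeError:
    crashed = True

# Eval forward WITH torch.no_grad(): no columnwise data is required.
ok = all_gather_shard(columnwise_data=None, is_backward_pass=False, grad_enabled=False)

print(crashed, ok)  # True ok
```

The sketch illustrates why the hard error only surfaces for eval loops run with gradients enabled, and why `torch.no_grad()` sidesteps it.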
