[PyTorch] Fix stale columnwise data usage#2925
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR fixes stale columnwise data usage.
Confidence Score: 3/5

Mostly safe bug fix, but one module-level regression in linear.py drops a defensive assignment that guards against an AttributeError in quantize_weight. The core fix (elif→if and FSDP2 tensor changes) is correct and well-tested for the primary code paths. A P1 regression in linear.py removes a guard for the weight_quantizer=None + QuantizedTensor scenario, turning a silent fallback into a potential crash. While this path may be rarely hit in practice, the asymmetry with how layernorm_mlp.py and layernorm_linear.py handle the same pattern makes linear.py's divergence worth addressing.

Important Files Changed

transformer_engine/pytorch/module/linear.py — dropped defensive quantizer assignment before quantize_weight call
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Forward pass enters module] --> B{weight is QuantizedTensorStorage AND not debug?}
    B -- Yes --> C[weight_quantizer = weight._quantizer]
    B -- No --> D{weight_quantizer is not None?}
    C --> D
    D -- Yes --> E[set_usage rowwise=True columnwise=is_grad_enabled ...]
    D -- No --> F[skip set_usage]
    E --> G[quantize_weight]
    F --> G
    G --> H{FSDP2 reshard_after_forward?}
    H -- Yes --> I[columnwise = is_backward_pass]
    H -- No --> J[columnwise = is_backward_pass OR grad_enabled]
    J --> K{columnwise=True but _columnwise_data is None?}
    K -- Yes --> L[RuntimeError raised - mxfp8_tensor only]
    K -- No --> M[all-gather sharded tensors]
    I --> M
    M --> N[GEMM forward]
```
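The decision flow above can be sketched in plain Python. All names here (`Quantizer`, `QuantizedTensorStorage`, `resolve_usage`) are hypothetical stand-ins for illustration, not Transformer Engine API:

```python
class Quantizer:
    """Hypothetical stand-in for a quantizer carrying usage flags."""

    def __init__(self):
        self.rowwise_usage = False
        self.columnwise_usage = False

    def set_usage(self, rowwise, columnwise):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise


class QuantizedTensorStorage:
    """Hypothetical stand-in for a pre-quantized weight that knows its quantizer."""

    def __init__(self, quantizer):
        self._quantizer = quantizer


def resolve_usage(weight, weight_quantizer, is_grad_enabled, debug=False):
    """Mirror of the flowchart: pick the quantizer, then set rowwise/columnwise usage."""
    if isinstance(weight, QuantizedTensorStorage) and not debug:
        weight_quantizer = weight._quantizer
    if weight_quantizer is not None:
        # Columnwise data is only needed when a backward pass may follow.
        weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
    return weight_quantizer


q = Quantizer()
w = QuantizedTensorStorage(q)
resolve_usage(w, None, is_grad_enabled=True)
print(q.columnwise_usage)  # True: grads enabled, so columnwise data is kept
```

The key property the PR relies on is that usage is recomputed on every forward instead of being retained from a previous (e.g. eval) step.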
```python
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
```
Dropped defensive `weight_quantizer` assignment breaks the `quantize_weight` call
The original `elif isinstance(weight, QuantizedTensor): weight_quantizer = weight._quantizer` handled the case where `weight_quantizer` arrives as `None` while `weight` is already a `QuantizedTensor`. In that path, `quantize_weight` immediately dereferences `quantizer.rowwise_usage` (line 710 of base.py) and will raise `AttributeError: 'NoneType' object has no attribute 'rowwise_usage'`.

The new code only re-assigns `weight_quantizer` when it is already non-`None`, so the previously guarded scenario now crashes instead of falling back to the weight's own quantizer. The missing assignment should be:
```python
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor) and not debug:
        weight_quantizer = weight._quantizer
    columnwise_usage = ...
    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
elif isinstance(weight, QuantizedTensor):
    # weight_quantizer is None but weight is pre-quantized — pick up its quantizer
    weight_quantizer = weight._quantizer
```

Referenced diff hunk (FSDP2 all-gather path):

```python
columnwise_usage = is_backward_pass or torch.is_grad_enabled()
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
    if self._columnwise_data is None or columnwise_scale_inv is None:
        raise RuntimeError(
            "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
            "for the upcoming backward pass, but the local shard has none. "
            "Ensure the weight is quantized with columnwise_usage=True before "
            "this all-gather."
        )
```
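The failure mode described in the first comment can be reproduced with a minimal sketch. Every name here is a hypothetical stand-in; `quantize_weight` is a stub of only the first dereference the review cites, not the base.py implementation:

```python
class Quantizer:
    """Stand-in quantizer with the two usage flags the stub dereferences."""

    def __init__(self, rowwise=True, columnwise=False):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise


class QuantizedTensor:
    """Stand-in for a pre-quantized weight that carries its own quantizer."""

    def __init__(self, quantizer):
        self._quantizer = quantizer


def quantize_weight(weight, quantizer):
    # Stub: the real code immediately reads quantizer.rowwise_usage,
    # which is where a None quantizer blows up.
    return (quantizer.rowwise_usage, quantizer.columnwise_usage)


weight = QuantizedTensor(Quantizer())
weight_quantizer = None

# New code path: reassignment only happens when the quantizer is non-None,
# so the None case falls through untouched.
if weight_quantizer is not None:
    if isinstance(weight, QuantizedTensor):
        weight_quantizer = weight._quantizer

try:
    quantize_weight(weight, weight_quantizer)
except AttributeError as exc:
    print(exc)  # 'NoneType' object has no attribute 'rowwise_usage'

# The suggested elif restores the fallback to the weight's own quantizer:
if weight_quantizer is None and isinstance(weight, QuantizedTensor):
    weight_quantizer = weight._quantizer
print(quantize_weight(weight, weight_quantizer))  # (True, False)
```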
torch.is_grad_enabled() can be True during eval without torch.no_grad()
`model.eval()` alone does not disable the gradient tape — `torch.is_grad_enabled()` stays `True` unless the caller wraps the eval loop with `torch.no_grad()`. In that situation `columnwise_usage` becomes `True`, but the local shard may not have `_columnwise_data` (it was never quantized with columnwise support during eval), so the new RuntimeError fires.

Users who ran eval with grads enabled previously got silently incorrect (stale) data; they now get a hard crash. While the crash is more correct, the error message could guide them:
```diff
  columnwise_usage = is_backward_pass or torch.is_grad_enabled()
  sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
  columnwise_usage = self._quantizer.columnwise_usage
  if columnwise_usage:
      if self._columnwise_data is None or columnwise_scale_inv is None:
-         raise RuntimeError(
-             "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
-             "for the upcoming backward pass, but the local shard has none. "
-             "Ensure the weight is quantized with columnwise_usage=True before "
-             "this all-gather."
-         )
+         raise RuntimeError(
+             "FSDP2 (reshard_after_forward=False) needs columnwise MXFP8 data "
+             "for the upcoming backward pass, but the local shard has none. "
+             "Ensure the weight is quantized with columnwise_usage=True before "
+             "this all-gather. If you are running evaluation without requiring "
+             "gradients, wrap the eval loop with torch.no_grad()."
+         )
```
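The decision the suggestion documents can be sketched without torch at all; `grad_enabled` below stands in for `torch.is_grad_enabled()`, and the function name is hypothetical:

```python
def columnwise_needed(is_backward_pass, grad_enabled):
    # Mirrors: columnwise_usage = is_backward_pass or torch.is_grad_enabled()
    return is_backward_pass or grad_enabled


# Training forward: grads on, columnwise data must be gathered.
print(columnwise_needed(False, True))   # True

# Eval wrapped in torch.no_grad(): no backward coming, rowwise-only is fine.
print(columnwise_needed(False, False))  # False

# Eval WITHOUT torch.no_grad(): the tape is still recording, so this check
# demands columnwise data even though the shard was quantized rowwise-only.
# That mismatch is exactly when the RuntimeError above fires.
print(columnwise_needed(False, True))   # True
```

This is why the suggested message points users at `torch.no_grad()`: disabling grads in the eval loop is what makes the rowwise-only shard legitimate.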
Description
This PR sets columnwise usage correctly for all quantizers instead of retaining the value in the quantizer, which may be incorrect after resuming training post-validation, as the columnwise usage is set to `False` for eval mode.

Type of change
Changes
Checklist: