diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ebfb98b2d6..e6bedee0c0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -776,7 +776,7 @@ def __init__(self, name: Optional[str] = None) -> None: self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None self.fp8_meta_tensors_initialized = False - self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} + self.quantizers = {"scaling_fwd": [], "scaling_bwd": []} self.tp_group = None self.tp_size = 1 self.sequence_parallel = False