From a2df6f83dac390ffa852b1508aae6ed0b458369b Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 23 Apr 2026 18:50:01 +0000 Subject: [PATCH 1/3] Add distributed Muon optimizer Signed-off-by: Vladimir Cherepanov --- .../pytorch/distributed/run_muon_optimizer.py | 218 ++++++++++++++++ .../distributed/test_muon_optimizer.py | 52 ++++ .../pytorch/optimizers/__init__.py | 1 + transformer_engine/pytorch/optimizers/muon.py | 234 ++++++++++++++++++ 4 files changed, 505 insertions(+) create mode 100644 tests/pytorch/distributed/run_muon_optimizer.py create mode 100644 tests/pytorch/distributed/test_muon_optimizer.py create mode 100644 transformer_engine/pytorch/optimizers/muon.py diff --git a/tests/pytorch/distributed/run_muon_optimizer.py b/tests/pytorch/distributed/run_muon_optimizer.py new file mode 100644 index 0000000000..005bdd6ec6 --- /dev/null +++ b/tests/pytorch/distributed/run_muon_optimizer.py @@ -0,0 +1,218 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed Muon optimizer test worker. + +Launched via torchrun from test_muon_optimizer.py. +""" + +import argparse +import sys + +import torch +import torch.distributed as dist +from torch.distributed.elastic.multiprocessing.errors import record + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.newton_schulz import get_coefficients +from transformer_engine.pytorch.optimizers.muon import get_muon_scale_factor + + +def _reference_orthogonalize( + grad: torch.Tensor, + *, + partition_dim: int, + world_size: int, + coefficients: list[tuple[float, float, float]], + scale_mode: str, + extra_scale_factor: float, + eps: float, +) -> torch.Tensor: + global_shape = [grad.size(0), grad.size(1)] + global_shape[partition_dim] *= world_size + + x = grad.clone() + if partition_dim == 0: + x = x.mT.contiguous() + + x = x / torch.sqrt((x.float() * x.float()).sum()).clamp_min(eps).to(dtype=x.dtype) + + for a, b, c in coefficients: + xxt = x @ x.mT + x = a * x + b * (xxt @ x) + c * ((xxt @ xxt) @ x) + + if partition_dim == 0: + x = x.mT.contiguous() + + scale = get_muon_scale_factor(global_shape[0], global_shape[1], mode=scale_mode) + return x * (scale * extra_scale_factor) + + +def _reference_step( + param: torch.Tensor, + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + *, + lr: float, + momentum: float, + nesterov: bool, + weight_decay: float, + use_decoupled_weight_decay: bool, + partition_dim: int, + world_size: int, + coefficients: list[tuple[float, float, float]], + scale_mode: str, + extra_scale_factor: float, + eps: float, +) -> tuple[torch.Tensor, torch.Tensor]: + param = param.clone() + grad = grad.clone() + momentum_buffer = momentum_buffer.clone() + + if use_decoupled_weight_decay: + param = param * (1.0 - lr * weight_decay) + elif weight_decay != 0: + grad = grad + weight_decay * param + + momentum_buffer = momentum * momentum_buffer + (1.0 - momentum) * grad + if nesterov: + update = (1.0 - momentum) * grad + momentum * momentum_buffer + else: + update = momentum_buffer + + orth_update = _reference_orthogonalize( + update, + partition_dim=partition_dim, + world_size=world_size, + coefficients=coefficients, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + eps=eps, + ) + param = param - lr * orth_update + return param, momentum_buffer + + +@record +def main(): + parser = argparse.ArgumentParser(description="Distributed Muon optimizer test") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) + parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1]) + parser.add_argument("--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"]) + parser.add_argument("--num-steps", type=int, default=2) + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + torch.cuda.set_device(rank) + + dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16 + if args.partition_dim == 0: + full_shape = (world_size * 64, 96) + else: + full_shape = (96, world_size * 64) + + lr = 3e-4 + momentum = 0.95 + nesterov = True + weight_decay = 0.01 + use_decoupled_weight_decay = args.weight_decay_mode == "decoupled" + coefficient_type = "quintic" + num_ns_steps = 5 + scale_mode = "spectral" + extra_scale_factor = 1.0 + eps = 1e-7 + coefficients = get_coefficients(num_ns_steps, coefficient_type) + + if rank == 0: + torch.manual_seed(1234) + full_param = torch.randn(full_shape, device="cuda", dtype=dtype) + full_grads = [ + torch.randn(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps) + ] + else: + full_param = torch.empty(full_shape, device="cuda", dtype=dtype) + full_grads = [ + torch.empty(full_shape, device="cuda", dtype=dtype) for _ in range(args.num_steps) + ] + + dist.broadcast(full_param, src=0) + for grad in full_grads: + dist.broadcast(grad, src=0) + + shard_size = full_shape[args.partition_dim] // world_size + shard_slice = slice(rank * shard_size, (rank + 1) * shard_size) + if args.partition_dim == 0: + local_param_init = full_param[shard_slice, :].contiguous() + else: + local_param_init = full_param[:, shard_slice].contiguous() + + param = torch.nn.Parameter(local_param_init.clone()) + optimizer = te.optimizers.MuonOptimizer( + [param], + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + use_decoupled_weight_decay=use_decoupled_weight_decay, + coefficient_type=coefficient_type, + num_ns_steps=num_ns_steps, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + process_group=dist.group.WORLD, + partition_dim=args.partition_dim, + eps=eps, + ) + + ref_param = full_param.float() + ref_momentum = torch.zeros_like(ref_param) + for full_grad in full_grads: + if args.partition_dim == 0: + param.grad = full_grad[shard_slice, :].contiguous() + else: + param.grad = full_grad[:, shard_slice].contiguous() + optimizer.step() + + ref_param, ref_momentum = _reference_step( + ref_param, + full_grad.float(), + ref_momentum, + lr=lr, + momentum=momentum, + nesterov=nesterov, + weight_decay=weight_decay, + use_decoupled_weight_decay=use_decoupled_weight_decay, + partition_dim=args.partition_dim, + world_size=world_size, + coefficients=coefficients, + scale_mode=scale_mode, + extra_scale_factor=extra_scale_factor, + eps=eps, + ) + + gathered = [torch.empty_like(param) for _ in range(world_size)] + dist.all_gather(gathered, param) + if args.partition_dim == 0: + test_param = torch.cat(gathered, dim=0) + else: + test_param = torch.cat(gathered, dim=1) + + if rank == 0: + expected = ref_param.to(dtype) + atol, rtol = (5e-2, 5e-2) if dtype == torch.bfloat16 else (2e-3, 2e-3) + if torch.allclose(test_param, expected, atol=atol, rtol=rtol): + print("MUON OPTIMIZER CHECK PASSED", flush=True) + else: + max_diff = (test_param - expected).abs().max().item() + print(f"Max |optimizer - reference|: {max_diff:.6e}", flush=True) + print("MUON OPTIMIZER CHECK FAILED", flush=True, file=sys.stderr) + sys.exit(1) + + optimizer.destroy() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/test_muon_optimizer.py b/tests/pytorch/distributed/test_muon_optimizer.py new file mode 100644 index 0000000000..6fdac25fc6 --- /dev/null +++ b/tests/pytorch/distributed/test_muon_optimizer.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for distributed Muon optimizer.""" + +import os +import subprocess +from pathlib import Path + +import pytest +import torch + +if torch.cuda.device_count() < 2: + pytest.skip("Muon optimizer tests require at least 2 GPUs.", allow_module_level=True) + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS = torch.cuda.device_count() +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(dtype: str, partition_dim: int, weight_decay_mode: str) -> None: + test_path = TEST_ROOT / "run_muon_optimizer.py" + test_cmd = LAUNCH_CMD + [ + str(test_path), + f"--dtype={dtype}", + f"--partition-dim={partition_dim}", + f"--weight-decay-mode={weight_decay_mode}", + ] + result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300) + if ( + result.returncode != 0 + or "MUON OPTIMIZER CHECK FAILED" in result.stderr.decode() + or "MUON OPTIMIZER CHECK PASSED" not in result.stdout.decode() + ): + raise AssertionError( + "Muon optimizer test failed.\n" + f"stdout: {result.stdout.decode()}\n" + f"stderr: {result.stderr.decode()}" + ) + + +@pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) +@pytest.mark.parametrize("partition_dim", [0, 1]) +def test_muon_optimizer_matches_reference(dtype: str, partition_dim: int) -> None: + """Compare distributed Muon updates with a full-matrix reference.""" + _run_test(dtype, partition_dim, "decoupled") + + +def test_muon_optimizer_l2_weight_decay() -> None: + """Exercise the L2 weight decay branch against the same reference.""" + _run_test("float32", 1, "l2") diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index 7220f1924a..c643d32287 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -16,4 +16,5 @@ ) from .fused_adam import FusedAdam from .fused_sgd import FusedSGD +from .muon import MuonOptimizer, get_muon_scale_factor from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier diff --git a/transformer_engine/pytorch/optimizers/muon.py b/transformer_engine/pytorch/optimizers/muon.py new file mode 100644 index 0000000000..e4125c3391 --- /dev/null +++ b/transformer_engine/pytorch/optimizers/muon.py @@ -0,0 +1,234 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Muon optimizer backed by distributed Newton-Schulz orthogonalization.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Literal, Optional + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from transformer_engine.pytorch.newton_schulz import ( + CusolverMpCtx, + NSCoeffT, + get_coefficients, + newton_schulz, +) + + +MuonScaleT = Literal["shape_scaling", "spectral", "unit_rms_norm"] + + +def get_muon_scale_factor(size_out: int, size_in: int, mode: MuonScaleT = "spectral") -> float: + """Return the Muon update scale factor for the logical matrix shape.""" + if mode == "shape_scaling": + return max(1, size_out / size_in) ** 0.5 + if mode == "spectral": + return max(size_out, size_in) ** 0.5 + if mode == "unit_rms_norm": + return (size_out / size_in) ** 0.5 + raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") + + +class MuonOptimizer(Optimizer): + """Distributed Muon optimizer for 2D CUDA parameters. + + This optimizer applies SGD-momentum followed by Newton-Schulz orthogonalization + on tensor-parallel parameter shards. The local parameter shard must represent a + partition of a logical 2D matrix across the provided NCCL process group. + + Args: + params: Iterable of parameters or parameter group dicts. + lr: Learning rate. + momentum: Momentum coefficient. + nesterov: Whether to use Nesterov momentum. + weight_decay: Weight decay coefficient. + use_decoupled_weight_decay: Whether to apply decoupled weight decay. + coefficient_type: Newton-Schulz coefficient schedule. + num_ns_steps: Number of Newton-Schulz iterations. + scale_mode: Muon update scale mode. + extra_scale_factor: Extra multiplicative scale applied after orthogonalization. + process_group: NCCL process group for distributed Newton-Schulz. Defaults to world. + partition_dim: Dimension along which each logical 2D parameter is partitioned. + Must be 0 or 1. + eps: Lower bound for the distributed normalization denominator. + """ + + def __init__( + self, + params: Iterable[torch.nn.Parameter | dict], + lr: float = 3e-4, + momentum: float = 0.95, + nesterov: bool = True, + weight_decay: float = 0.01, + *, + use_decoupled_weight_decay: bool = True, + coefficient_type: NSCoeffT = "quintic", + num_ns_steps: int = 5, + scale_mode: MuonScaleT = "spectral", + extra_scale_factor: float = 1.0, + process_group: Optional[dist.ProcessGroup] = None, + partition_dim: int = 1, + eps: float = 1e-7, + ) -> None: + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0 or momentum >= 1.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if num_ns_steps < 1: + raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") + if partition_dim not in (0, 1): + raise ValueError(f"partition_dim must be 0 or 1, got {partition_dim}") + get_coefficients(num_ns_steps, coefficient_type) + + if process_group is None: + if not dist.is_initialized(): + raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.") + process_group = dist.group.WORLD + if dist.get_backend(process_group) != "nccl": + raise RuntimeError("MuonOptimizer requires an NCCL process group.") + + defaults = { + "lr": lr, + "momentum": momentum, + "nesterov": nesterov, + "weight_decay": weight_decay, + "use_decoupled_weight_decay": use_decoupled_weight_decay, + "coefficient_type": coefficient_type, + "num_ns_steps": num_ns_steps, + "scale_mode": scale_mode, + "extra_scale_factor": extra_scale_factor, + "partition_dim": partition_dim, + "eps": eps, + } + super().__init__(params, defaults) + self.process_group = process_group + self._ns_ctx: CusolverMpCtx | None = None + + def __del__(self) -> None: + self.destroy() + + def destroy(self) -> None: + """Release the underlying cuSolverMp context.""" + if self._ns_ctx is not None: + self._ns_ctx.destroy() + self._ns_ctx = None + + def _get_ctx(self) -> CusolverMpCtx: + if self._ns_ctx is None: + self._ns_ctx = CusolverMpCtx(self.process_group) + return self._ns_ctx + + @staticmethod + def _validate_param(param: torch.Tensor, partition_dim: int) -> None: + if param.ndim != 2: + raise ValueError("MuonOptimizer only supports 2D parameters.") + if not param.is_cuda: + raise ValueError("MuonOptimizer only supports CUDA parameters.") + if param.dtype not in (torch.float32, torch.bfloat16): + raise ValueError( + f"MuonOptimizer requires float32 or bfloat16 parameters, got {param.dtype}." + ) + if param.size(partition_dim) == 0: + raise ValueError("MuonOptimizer does not support empty tensor-parallel shards.") + + def _distributed_normalize_p2_( + self, + x: torch.Tensor, + eps: float, + ) -> None: + norm_sq = (x.float() * x.float()).sum() + dist.all_reduce(norm_sq, op=dist.ReduceOp.SUM, group=self.process_group) + x.div_(torch.sqrt(norm_sq).clamp_min(eps).to(dtype=x.dtype)) + + def _orthogonalize( + self, + grad: torch.Tensor, + *, + partition_dim: int, + coefficient_type: NSCoeffT, + num_ns_steps: int, + scale_mode: MuonScaleT, + extra_scale_factor: float, + eps: float, + ) -> torch.Tensor: + self._validate_param(grad, partition_dim) + world_size = dist.get_world_size(self.process_group) + global_shape = [grad.size(0), grad.size(1)] + global_shape[partition_dim] *= world_size + + orth_grad = grad.clone() + transposed = partition_dim == 0 + if transposed: + orth_grad = orth_grad.mT.contiguous() + else: + orth_grad = orth_grad.contiguous() + + self._distributed_normalize_p2_(orth_grad, eps) + coefficients = get_coefficients(num_ns_steps, coefficient_type) + newton_schulz(orth_grad, self._get_ctx(), num_ns_steps, coefficients=coefficients) + + if transposed: + orth_grad = orth_grad.mT.contiguous() + + scale_factor = get_muon_scale_factor(global_shape[0], global_shape[1], mode=scale_mode) + orth_grad.mul_(scale_factor * extra_scale_factor) + return orth_grad + + @torch.no_grad() + def step(self, closure=None): + """Perform a single optimization step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + self._validate_param(p, group["partition_dim"]) + grad = p.grad + if grad.dtype != p.dtype: + raise ValueError( + f"Gradient dtype {grad.dtype} must match parameter dtype {p.dtype}." + ) + if grad.shape != p.shape: + raise ValueError("Gradient shape must match parameter shape.") + + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(p) + + if group["use_decoupled_weight_decay"]: + p.mul_(1.0 - group["lr"] * group["weight_decay"]) + elif group["weight_decay"] != 0: + grad = grad.add(p, alpha=group["weight_decay"]) + + momentum_buffer = state["momentum_buffer"] + momentum_buffer.lerp_(grad, 1.0 - group["momentum"]) + + if group["nesterov"]: + update = grad.lerp(momentum_buffer, group["momentum"]) + else: + update = momentum_buffer + + orth_update = self._orthogonalize( + update, + partition_dim=group["partition_dim"], + coefficient_type=group["coefficient_type"], + num_ns_steps=group["num_ns_steps"], + scale_mode=group["scale_mode"], + extra_scale_factor=group["extra_scale_factor"], + eps=group["eps"], + ) + p.add_(orth_update, alpha=-group["lr"]) + + return loss From e332a8eb7e4fa26ccaa527ffef0329cc71f3a198 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:16:24 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/distributed/run_muon_optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/distributed/run_muon_optimizer.py b/tests/pytorch/distributed/run_muon_optimizer.py index 005bdd6ec6..f9f1678568 100644 --- a/tests/pytorch/distributed/run_muon_optimizer.py +++ b/tests/pytorch/distributed/run_muon_optimizer.py @@ -99,7 +99,9 @@ def main(): parser = argparse.ArgumentParser(description="Distributed Muon optimizer test") parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) parser.add_argument("--partition-dim", type=int, default=1, choices=[0, 1]) - parser.add_argument("--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"]) + parser.add_argument( + "--weight-decay-mode", type=str, default="decoupled", choices=["decoupled", "l2"] + ) parser.add_argument("--num-steps", type=int, default=2) args = parser.parse_args() From 1304712a083bac7e425bf57e70d94ff3ebf2c52c Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Thu, 23 Apr 2026 19:40:42 +0000 Subject: [PATCH 3/3] Fix Muon closure and reference test Signed-off-by: Vladimir Cherepanov --- tests/pytorch/distributed/run_muon_optimizer.py | 5 ----- transformer_engine/pytorch/optimizers/muon.py | 3 ++- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/distributed/run_muon_optimizer.py b/tests/pytorch/distributed/run_muon_optimizer.py index f9f1678568..8725dbf5ff 100644 --- a/tests/pytorch/distributed/run_muon_optimizer.py +++ b/tests/pytorch/distributed/run_muon_optimizer.py @@ -23,14 +23,12 @@ def _reference_orthogonalize( grad: torch.Tensor, *, partition_dim: int, - world_size: int, coefficients: list[tuple[float, float, float]], scale_mode: str, extra_scale_factor: float, eps: float, ) -> torch.Tensor: global_shape = [grad.size(0), grad.size(1)] - global_shape[partition_dim] *= world_size x = grad.clone() if partition_dim == 0: @@ -60,7 +58,6 @@ def _reference_step( weight_decay: float, use_decoupled_weight_decay: bool, partition_dim: int, - world_size: int, coefficients: list[tuple[float, float, float]], scale_mode: str, extra_scale_factor: float, @@ -84,7 +81,6 @@ def _reference_step( orth_update = _reference_orthogonalize( update, partition_dim=partition_dim, - world_size=world_size, coefficients=coefficients, scale_mode=scale_mode, extra_scale_factor=extra_scale_factor, @@ -187,7 +183,6 @@ def main(): weight_decay=weight_decay, use_decoupled_weight_decay=use_decoupled_weight_decay, partition_dim=args.partition_dim, - world_size=world_size, coefficients=coefficients, scale_mode=scale_mode, extra_scale_factor=extra_scale_factor, diff --git a/transformer_engine/pytorch/optimizers/muon.py b/transformer_engine/pytorch/optimizers/muon.py index e4125c3391..5c174eca5f 100644 --- a/transformer_engine/pytorch/optimizers/muon.py +++ b/transformer_engine/pytorch/optimizers/muon.py @@ -187,7 +187,8 @@ def step(self, closure=None): """Perform a single optimization step.""" loss = None if closure is not None: - loss = closure() + with torch.enable_grad(): + loss = closure() for group in self.param_groups: for p in group["params"]: