Are your neural networks training too slowly? Struggling with learning rate tuning when scaling up your models? Looking for an optimizer that simply works better across different architectures? If so, you’ve come to the right place. In this comprehensive guide, I’ll walk you through implementing the groundbreaking Muon optimizer in PyTorch — a practical, powerful tool that could transform how you train your neural networks.
Muon isn’t just another optimizer — it’s a fundamental rethinking of how we should update weights in neural networks. While Adam and its variants have dominated the landscape for years, Muon brings something genuinely new to the table: a geometric perspective on optimization that delivers impressive real-world results.
What You’ll Learn
By the end of this guide, you’ll be able to:
- Understand the mathematical principles that make Muon so effective
- Implement a production-ready Muon optimizer in PyTorch
- Apply Muon to your own neural network training workflows
- Benefit from automatic learning rate transfer across different network widths
- Achieve faster convergence without the headache of extensive hyperparameter tuning
Why Muon Matters
Developed by Keller Jordan, Jeremy Bernstein, and their collaborators, Muon has already proven its worth by setting training speed records for NanoGPT and CIFAR-10. Unlike traditional optimizers that focus solely on the loss landscape, Muon takes a different approach by asking: “How do changes in weight matrices affect the actual behavior of the network?”
This seemingly simple question leads to profound insights about optimization. Instead of treating weight updates as abstract mathematical operations, Muon views them through the lens of how they transform the network’s function. This perspective yields an optimizer that:
- Automatically transfers learning rates across different network widths
- Converges faster than Adam and other popular optimizers
- Requires less hyperparameter tuning when scaling to larger models
- Creates more predictable and stable training behavior
This isn’t just theoretical — you’ll see these benefits firsthand when you implement and use Muon in your own projects.
Let’s start by understanding the core mathematical principles behind Muon, and then we’ll dive into building a robust implementation that you can use in your everyday deep learning workflows.
The Theory Behind Muon Optimizer
Let’s break down the key theoretical components that make Muon so effective:
1. Metrizing the Linear Layer
Linear layers form the backbone of most neural networks, operating as a simple matrix multiplication:
y = Wx
Where:
- x is the input vector
- W is the weight matrix
- y is the output vector
Muon introduces a consistent way to measure the sizes of vectors and matrices using RMS (Root-Mean-Square) norms:
For vectors, the RMS norm is defined as:
|x|_RMS = sqrt(1/d_x * sum(x_i^2)) = (1/sqrt(d_x)) * |x|_2
This norm gives us the average magnitude of entries in a vector, which aligns with how neural network activations typically have entries of size around 1.
For weight matrices, Muon uses the RMS-to-RMS operator norm:
|W|_op,RMS = max_{x≠0} (|Wx|_RMS / |x|_RMS) = |W|_sp / sqrt(d_x * d_y)
Where:
- |W|_sp is the spectral norm (the maximum singular value)
- d_x and d_y are the input and output dimensions
This norm measures how much a matrix can amplify the size of input vectors on average, providing a natural way to control the influence of weights regardless of layer size.
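To make these norms concrete, here is a small sketch (the rms_norm helper and the tensor shapes are illustrative choices, not part of Muon itself) showing that unit-variance activations have an RMS norm near 1, and that a typically initialized weight matrix amplifies them by roughly a constant factor regardless of width:
import torch

def rms_norm(v: torch.Tensor) -> torch.Tensor:
    # RMS norm: the L2 norm divided by sqrt(dimension), i.e. the typical entry magnitude
    return v.norm() / v.numel() ** 0.5

d_x, d_y = 256, 512
x = torch.randn(d_x)                      # unit-variance activations
W = torch.randn(d_y, d_x) / d_x ** 0.5    # a typical randomly initialized weight matrix

print(f"|x|_RMS ≈ {rms_norm(x).item():.3f}")                            # close to 1 regardless of d_x
print(f"|Wx|_RMS / |x|_RMS ≈ {(rms_norm(W @ x) / rms_norm(x)).item():.3f}")
Because both quantities stay of order 1 as layers get wider, an update budget expressed in these norms means the same thing at every scale.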
2. Perturbing the Linear Layer
When we update weights by ΔW, outputs change by Δy = ΔW · x. Muon bounds this change using the RMS norm:
|Δy|_RMS ≤ |ΔW|_op,RMS · |x|_RMS
This means that by controlling the operator norm of the weight update, we can directly limit how much the network’s outputs can change, leading to more stable training.
3. Dualizing the Gradient
Muon formulates weight updates as an optimization problem:
min_ΔW ⟨G, ΔW⟩ subject to |ΔW|_op,RMS ≤ β
Where:
- G is the gradient of the loss with respect to the weights
- β controls the magnitude of the change
If we decompose the gradient using singular value decomposition (SVD) as G = UΣV^T, the optimal update becomes:
ΔW = -α · β · UV^T
This preserves the directions of the gradient (through U and V) while standardizing all singular values to 1, effectively “orthogonalizing” the gradient.
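As a quick illustration (the orthogonalize_svd helper and the shapes are illustrative, not the production path), the exact version of this “orthogonalization” can be computed directly with an SVD:
import torch

def orthogonalize_svd(G: torch.Tensor) -> torch.Tensor:
    # Decompose G = U Σ V^T and return U V^T: same singular directions,
    # but every singular value replaced by 1
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

G = torch.randn(512, 256)                # a mock gradient for a 512x256 weight matrix
ortho = orthogonalize_svd(G)
print(torch.linalg.svdvals(ortho)[:5])   # all (numerically) equal to 1
This is the shape of Muon’s update direction; the Newton-Schulz iteration described next is simply a cheaper way to get there.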
4. Fast Orthogonalization via Newton-Schulz
Computing full SVDs is computationally expensive for large networks. Instead, Muon uses Newton-Schulz iterations, a fast approximation method that applies this polynomial function iteratively:
f(X) = (3X - X^3)/2
When applied to a properly normalized matrix, this polynomial converges to a matrix with the same singular vectors but with all singular values equal to 1, achieving orthogonalization without the computational burden of a full SVD. (For rectangular matrices, X^3 is computed as X·X^T·X.)
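Here is a minimal sketch of that iteration (the newton_schulz helper and the iteration counts are illustrative), comparing its output against the exact SVD result as the number of iterations grows:
import torch

def newton_schulz(G: torch.Tensor, num_iters: int) -> torch.Tensor:
    # Frobenius normalization puts every singular value in (0, 1],
    # inside the region where the cubic iteration converges to 1
    X = G / G.norm()
    for _ in range(num_iters):
        # For rectangular X, "X^3" is computed as X @ X^T @ X
        X = 1.5 * X - 0.5 * X @ X.transpose(-2, -1) @ X
    return X

G = torch.randn(512, 256)
for k in (1, 5, 10, 20):
    deviation = (torch.linalg.svdvals(newton_schulz(G, k)) - 1).abs().max()
    print(f"{k:2d} iterations: max |singular value - 1| = {deviation.item():.3f}")
The deviation shrinks as the iteration count grows; in practice Muon does not need the singular values to reach exactly 1, and production implementations typically use a higher-order polynomial with tuned coefficients to converge in fewer steps.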
The Final Update Rule
Combining all these elements, Muon’s update rule becomes:
W ← W - α · (sqrt(d_x * d_y) / |G|_F) · NS(G)
Where:
- α is the learning rate
- |G|_F is the Frobenius norm of the gradient
- NS(G) is the Newton-Schulz orthogonalization of the gradient
This rule updates weights by subtracting a scaled, orthogonalized version of the gradient, ensuring consistent behavior regardless of layer size.
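Before wrapping this into an optimizer class, here is a functional sketch of the rule applied to a single weight matrix (muon_update and the shapes are illustrative names; newton_schulz is the same helper as above, repeated so the snippet is self-contained):
import torch

def newton_schulz(G: torch.Tensor, num_iters: int = 5) -> torch.Tensor:
    X = G / G.norm()
    for _ in range(num_iters):
        X = 1.5 * X - 0.5 * X @ X.transpose(-2, -1) @ X
    return X

def muon_update(W: torch.Tensor, G: torch.Tensor, lr: float) -> torch.Tensor:
    # W <- W - lr * (sqrt(d_x * d_y) / |G|_F) * NS(G)
    d_y, d_x = W.shape
    scale = (d_x * d_y) ** 0.5 / G.norm()
    return W - lr * scale * newton_schulz(G)

d_y, d_x = 512, 256
W = torch.randn(d_y, d_x) / d_x ** 0.5
G = torch.randn(d_y, d_x)                 # a mock gradient
x = torch.randn(d_x)

W_new = muon_update(W, G, lr=0.02)
delta_y = (W_new - W) @ x
print(f"RMS change in the layer's output: {(delta_y.norm() / d_y ** 0.5).item():.4f}")
Because the raw gradient magnitude is divided out by |G|_F, the size of the resulting functional change is set by the learning rate rather than by however large the gradient happened to be.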
Implementing Muon in PyTorch
Now, let’s implement the Muon optimizer in PyTorch. We’ll start by defining the optimizer class and then work through each component:
import torch
from torch.optim.optimizer import Optimizer
from typing import List, Dict, Optional, Tuple, Union, Callable, Any, Iterator
class Muon(Optimizer):
"""
Implements the Muon optimization algorithm for linear layers.
Muon uses a geometric approach to optimization, specifically addressing
how changes in weight matrices affect neural network behavior.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate (default: 1e-3)
ns_iters (int, optional): number of Newton-Schulz iterations (default: 5)
momentum (float, optional): momentum factor (default: 0.9)
weight_decay (float, optional): weight decay coefficient (default: 0)
"""
def __init__(self,
params: Iterator[torch.nn.Parameter],
lr: float = 1e-3,
ns_iters: int = 5,
momentum: float = 0.9,
weight_decay: float = 0):
defaults = dict(lr=lr, ns_iters=ns_iters, momentum=momentum, weight_decay=weight_decay)
super(Muon, self).__init__(params, defaults)
def newton_schulz_orthogonalize(self, X: torch.Tensor, num_iters: int) -> torch.Tensor:
"""
Apply Newton-Schulz iterations to approximate orthogonalization.
This function applies the polynomial f(X) = (3X - X^3)/2 repeatedly to a normalized matrix,
which gradually forces all singular values to 1 while preserving singular vectors.
Args:
X (torch.Tensor): Input matrix to orthogonalize
num_iters (int): Number of Newton-Schulz iterations
Returns:
torch.Tensor: Orthogonalized matrix
"""
        # First, normalize by the Frobenius norm; since |X|_sp <= |X|_F, this
        # guarantees the spectral norm is at most 1, where the iteration converges
norm = torch.norm(X, p='fro')
if norm < 1e-8:
return X # Avoid division by zero
X = X / norm
# Apply Newton-Schulz iterations
for _ in range(num_iters):
            # Use X @ X^T @ X so the update also works for rectangular (non-square) matrices
            X = 1.5 * X - 0.5 * torch.matmul(X, torch.matmul(X.transpose(-2, -1), X))
return X
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""
Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model and returns the loss
Returns:
Optional[float]: Loss value if closure is provided, else None
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group['lr']
ns_iters = group['ns_iters']
momentum_factor = group['momentum']
weight_decay = group['weight_decay']
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
# Handle weight decay
if weight_decay != 0:
grad = grad.add(p, alpha=weight_decay)
state = self.state[p]
# Initialize momentum buffer if needed
if len(state) == 0:
state['momentum_buffer'] = torch.zeros_like(grad)
# Get momentum buffer
momentum_buffer = state['momentum_buffer']
# Update momentum buffer with current gradient
momentum_buffer.mul_(momentum_factor).add_(grad, alpha=1 - momentum_factor)
# Only apply Muon updates to matrices (linear layers)
if len(p.shape) == 2:
                    # Get dimensions for scaling (nn.Linear stores weights as (out_features, in_features))
                    d_out, d_in = p.shape
# Use the momentum buffer for orthogonalization
ortho_grad = self.newton_schulz_orthogonalize(momentum_buffer, ns_iters)
# Scale by sqrt(d_in * d_out) / |G|_F as per Muon's formula
grad_norm = torch.norm(momentum_buffer, p='fro')
if grad_norm > 1e-8: # Avoid division by zero
scaling = (d_in * d_out)**0.5 / grad_norm
update = ortho_grad * scaling
# Apply the update
p.add_(update, alpha=-lr)
else:
# For non-matrix parameters (biases, etc.), use standard update with momentum
p.add_(momentum_buffer, alpha=-lr)
return loss
Let’s break down this implementation:
- We define the Muon class, which inherits from PyTorch’s Optimizer base class.
- The __init__ method sets up the optimizer with hyperparameters: learning rate, number of Newton-Schulz iterations, momentum, and weight decay.
- The newton_schulz_orthogonalize method implements the fast orthogonalization algorithm using the polynomial f(X) = (3X - X^3)/2.
- The step method performs the actual optimization update, applying the Muon formula to matrix parameters (linear layers) and a standard momentum update to all other parameters.
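Here is a minimal usage sketch: a toy two-layer network trained on random data, purely to show how the optimizer plugs into a standard PyTorch loop (the model, data, and hyperparameters are illustrative, not recommendations):
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
inputs, targets = torch.randn(64, 128), torch.randn(64, 10)

optimizer = Muon(model.parameters(), lr=1e-3, ns_iters=5, momentum=0.9)
loss_fn = nn.MSELoss()

for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    if step % 20 == 0:
        print(f"step {step:3d}  loss {loss.item():.4f}")
The 2-D weight matrices of the two Linear layers receive the orthogonalized Muon update, while the 1-D bias vectors fall through to the plain momentum branch of step.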
Enhancements and Practical Considerations
Now let’s enhance our implementation with some practical features:
class EnhancedMuon(Optimizer):
"""
Enhanced implementation of the Muon optimization algorithm.
This version includes additional features like adaptive momentum, gradient clipping,
learning rate scheduling support, and detailed parameter tracking.
Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate (default: 1e-3)
ns_iters (int, optional): number of Newton-Schulz iterations (default: 5)
momentum (float, optional): momentum factor (default: 0.9)
weight_decay (float, optional): weight decay coefficient (default: 0)
grad_clip_norm (float, optional): gradient clipping norm (default: None)
track_stats (bool, optional): whether to track optimization statistics (default: False)
"""
def __init__(self,
params: Iterator[torch.nn.Parameter],
lr: float = 1e-3,
ns_iters: int = 5,
momentum: float = 0.9,
weight_decay: float = 0,
grad_clip_norm: Optional[float] = None,
track_stats: bool = False):
defaults = dict(
lr=lr,
ns_iters=ns_iters,
momentum=momentum,
weight_decay=weight_decay,
grad_clip_norm=grad_clip_norm,
track_stats=track_stats
)
super(EnhancedMuon, self).__init__(params, defaults)
self.global_stats = {
'update_magnitudes': [],
'gradient_norms': []
}
def newton_schulz_orthogonalize(self, X: torch.Tensor, num_iters: int) -> torch.Tensor:
"""
Enhanced Newton-Schulz orthogonalization with stability checks.
"""
# Check for numerical stability
if torch.isnan(X).any() or torch.isinf(X).any():
# Return identity matrix of appropriate size as fallback
return torch.eye(X.shape[0], X.shape[1], device=X.device)
# Normalize the input matrix
norm = torch.norm(X, p='fro')
if norm < 1e-8:
return X
X = X / norm
# Apply Newton-Schulz iterations with stability checks
for i in range(num_iters):
            # Form X @ X^T @ X so the iteration also handles rectangular matrices
            X_gram = torch.matmul(X, X.transpose(-2, -1))
            X_cubed = torch.matmul(X_gram, X)
# Check for numerical issues
if torch.isnan(X_cubed).any() or torch.isinf(X_cubed).any():
# Revert to previous iteration
break
X_new = (3 * X - X_cubed) / 2
# Early stopping if convergence is reached
if torch.norm(X_new - X, p='fro') < 1e-6:
X = X_new
break
X = X_new
return X
def get_dimension_scaling(self, shape: Tuple[int, ...]) -> float:
"""
Calculate the appropriate dimension scaling factor for different parameter shapes.
For matrices (linear layers), this is sqrt(d_in * d_out).
For other parameter types, we use appropriate heuristics.
Args:
shape (tuple): Shape of the parameter tensor
Returns:
float: Scaling factor
"""
if len(shape) == 2: # Linear layer weights
            d_out, d_in = shape  # nn.Linear stores weights as (out_features, in_features)
return (d_in * d_out) ** 0.5
elif len(shape) == 1: # Bias vectors
return shape[0] ** 0.5
elif len(shape) == 4: # Conv layer weights
# For convolutions, the scaling considers channels and kernel size
c_out, c_in, k_h, k_w = shape
return (c_in * c_out * k_h * k_w) ** 0.5
else:
# Default scaling for other parameter types
return torch.prod(torch.tensor(shape)).float() ** 0.5
@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""
Performs a single optimization step with enhanced capabilities.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
lr = group['lr']
ns_iters = group['ns_iters']
momentum_factor = group['momentum']
weight_decay = group['weight_decay']
grad_clip_norm = group['grad_clip_norm']
track_stats = group['track_stats']
# Gradient clipping if enabled
if grad_clip_norm is not None:
torch.nn.utils.clip_grad_norm_(group['params'], grad_clip_norm)
for p in group['params']:
if p.grad is None:
continue
grad = p.grad
# Apply weight decay
if weight_decay != 0:
grad = grad.add(p, alpha=weight_decay)
state = self.state[p]
# Initialize momentum buffer and other state if needed
if len(state) == 0:
state['momentum_buffer'] = torch.zeros_like(grad)
state['step'] = 0
state['update_history'] = [] if track_stats else None
# Increment step count
state['step'] += 1
# Update momentum buffer
momentum_buffer = state['momentum_buffer']
momentum_buffer.mul_(momentum_factor).add_(grad, alpha=1 - momentum_factor)
# Store gradient norm for tracking
grad_norm = torch.norm(grad, p='fro').item()
if track_stats:
self.global_stats['gradient_norms'].append(grad_norm)
# Apply Muon update based on parameter type
if len(p.shape) >= 2: # For matrices and higher-dimensional tensors
# Reshape to matrix for higher dimensions
original_shape = p.shape
if len(p.shape) > 2:
p_flat = p.reshape(p.shape[0], -1)
momentum_flat = momentum_buffer.reshape(momentum_buffer.shape[0], -1)
else:
p_flat = p
momentum_flat = momentum_buffer
# Apply Newton-Schulz orthogonalization
ortho_grad = self.newton_schulz_orthogonalize(momentum_flat, ns_iters)
# Get dimension scaling
dim_scaling = self.get_dimension_scaling(original_shape)
# Calculate update
buffer_norm = torch.norm(momentum_flat, p='fro')
if buffer_norm > 1e-8:
scaling = dim_scaling / buffer_norm
update = ortho_grad * scaling
# Reshape back if needed
if len(p.shape) > 2:
update = update.reshape(original_shape)
# Apply the update
p.add_(update, alpha=-lr)
# Track update magnitude
if track_stats:
update_mag = torch.norm(update, p='fro').item() * lr
state['update_history'].append(update_mag)
self.global_stats['update_magnitudes'].append(update_mag)
else:
# For non-matrix parameters, use standard update
p.add_(momentum_buffer, alpha=-lr)
return loss
def get_stats(self) -> Dict[str, Any]:
"""
Return optimization statistics for analysis.
Returns:
Dict[str, Any]: Dictionary containing tracked statistics
"""
stats = {
'global': self.global_stats,
'parameters': {}
}
for group in self.param_groups:
if group['track_stats']:
for p in group['params']:
if p in self.state and 'update_history' in self.state[p]:
state = self.state[p]
stats['parameters'][id(p)] = {
'shape': p.shape,
'updates': state['update_history'],
'steps': state['step']
}
return stats
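A short usage sketch for the enhanced version, including the statistics tracking (the model, data, and choice of hyperparameters are again just placeholders):
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = EnhancedMuon(
    model.parameters(),
    lr=1e-3,
    ns_iters=5,
    momentum=0.9,
    weight_decay=1e-4,
    grad_clip_norm=1.0,
    track_stats=True,
)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(64, 128), torch.randn(64, 10)

for step in range(50):
    optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()
    optimizer.step()

stats = optimizer.get_stats()
updates = stats['global']['update_magnitudes']
print(f"tracked {len(updates)} matrix updates, mean magnitude {sum(updates) / len(updates):.4e}")
Because EnhancedMuon exposes param_groups like any other PyTorch optimizer, standard learning-rate schedulers such as torch.optim.lr_scheduler.CosineAnnealingLR can be attached to it in the usual way.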
Conclusion: The Power and Potential of Muon
The Muon optimizer represents a significant advancement in the field of neural network optimization. By approaching optimization from a geometric perspective and considering how weight updates affect a network’s behavior, Muon offers several compelling advantages over traditional methods:
Advantages of Muon
- Automatic Learning Rate Transfer: Perhaps Muon’s most remarkable feature is its ability to maintain consistent behavior across different network widths. This eliminates one of the most frustrating aspects of scaling neural networks: the need to repeatedly tune hyperparameters when changing model sizes.
- Faster Convergence: Muon consistently demonstrates faster convergence than traditional optimizers like Adam. This acceleration reflects a more efficient traversal of the loss landscape, guided by the geometric principles encoded in the optimizer.
- Improved Numerical Properties: Muon moves weights away from the neighborhood of their initialization more effectively than traditional methods. This challenges the notion that neural networks primarily operate in a “linearized regime” around their initialization.
- Theoretical Soundness: By formalizing optimization in terms of the RMS-to-RMS operator norm, Muon provides a clear mathematical framework that connects weight space updates to functional changes in the network.
The Future of Muon and Geometric Optimization
Muon’s success points to a promising future for geometric approaches to neural network optimization. The key insight — that optimization should be guided by how parameter changes affect a network’s function — opens up new avenues for research.
The modular approach exemplified by Muon — treating different layer types with specialized optimization techniques — may lead to a more principled understanding of deep learning. Rather than treating neural networks as “inscrutable beasts,” this approach aims to transform them into well-understood mathematical systems with predictable behaviors.
As deep learning continues to scale to ever-larger models, techniques like Muon that offer automatic learning rate transfer and principled parameter updates will become increasingly valuable. The days of extensive hyperparameter tuning may eventually give way to optimizers that “just work” across a wide range of model architectures and sizes.
By implementing Muon in PyTorch, we’ve made this cutting-edge optimizer accessible to the broader deep learning community. We encourage researchers and practitioners to experiment with Muon in their own projects and contribute to the ongoing development of geometric optimization methods.
References
- Jordan, K., Jin, Y., Boza, V., You, J., Cesista, F., Newhouse, L., & Bernstein, J. (2024). “Muon: An Optimizer for Hidden Layers in Neural Networks.”
- Bernstein, J., Vahdat, A., Yue, Y., & Liu, M. Y. (2020). “On the distance between two neural networks and the stability of learning.” Neural Information Processing Systems.
- Nesterov, Y. (2004). “Introductory Lectures on Convex Optimization: A Basic Course.” Springer Science & Business Media.
- Martens, J., & Grosse, R. (2015). “Optimizing Neural Networks with Kronecker-factored Approximate Curvature.” International Conference on Machine Learning.
- Anil, R., Gupta, V., Koren, T., & Singer, Y. (2019). “Memory-Efficient Adaptive Optimization.” Neural Information Processing Systems.
- Shazeer, N., & Stern, M. (2018). “Adafactor: Adaptive Learning Rates with Sublinear Memory Cost.” International Conference on Machine Learning.
- You, Y., Gitman, I., & Ginsburg, B. (2017). “Scaling SGD Batch Size to 32K for ImageNet Training.” Technical Report.
- Duchi, J., Hazan, E., & Singer, Y. (2011). “Adaptive subgradient methods for online learning and stochastic optimization.” Journal of Machine Learning Research.
- Kingma, D. P., & Ba, J. (2015). “Adam: A Method for Stochastic Optimization.” International Conference on Learning Representations.
- Loshchilov, I., & Hutter, F. (2019). “Decoupled Weight Decay Regularization.” International Conference on Learning Representations.