Are your neural networks training too slowly? Struggling with learning rate tuning when scaling up your models? Looking for an optimizer that simply works better across different architectures? If so, you’ve come to the right place. In this comprehensive guide, I’ll walk you through implementing the groundbreaking Muon optimizer in PyTorch — a practical, powerful tool that could transform how you train your neural networks.

Muon isn’t just another optimizer — it’s a fundamental rethinking of how we should update weights in neural networks. While Adam and its variants have dominated the landscape for years, Muon brings something genuinely new to the table: a geometric perspective on optimization that delivers impressive real-world results.

What You’ll Learn

By the end of this guide, you’ll be able to:

  • Understand the mathematical principles that make Muon so effective
  • Implement a production-ready Muon optimizer in PyTorch
  • Apply Muon to your own neural network training workflows
  • Benefit from automatic learning rate transfer across different network widths
  • Achieve faster convergence without the headache of extensive hyperparameter tuning

Why Muon Matters

Developed by Keller Jordan, Jeremy Bernstein, and their collaborators, Muon has already proven its worth by setting training speed records for NanoGPT and CIFAR-10. Unlike traditional optimizers that focus solely on the loss landscape, Muon takes a different approach by asking: “How do changes in weight matrices affect the actual behavior of the network?”

This seemingly simple question leads to profound insights about optimization. Instead of treating weight updates as abstract mathematical operations, Muon views them through the lens of how they transform the network’s function. This perspective yields an optimizer that:

  1. Automatically transfers learning rates across different network widths
  2. Converges faster than Adam and other popular optimizers
  3. Requires less hyperparameter tuning when scaling to larger models
  4. Creates more predictable and stable training behavior

This isn’t just theoretical — you’ll see these benefits firsthand when you implement and use Muon in your own projects.

Let’s start by understanding the core mathematical principles behind Muon, and then we’ll dive into building a robust implementation that you can use in your everyday deep learning workflows.

The Theory Behind Muon Optimizer

Let’s break down the key theoretical components that make Muon so effective:

1. Metrizing the Linear Layer

Linear layers form the backbone of most neural networks, operating as a simple matrix multiplication:

y = Wx

Where:

  • x is the input vector
  • W is the weight matrix
  • y is the output vector

Muon introduces a consistent way to measure the sizes of vectors and matrices using RMS (Root-Mean-Square) norms:

For vectors, the RMS norm is defined as:

|x|_RMS = sqrt(1/d_x * sum(x_i^2)) = (1/sqrt(d_x)) * |x|_2

This norm gives us the average magnitude of entries in a vector, which aligns with how neural network activations typically have entries of size around 1.

For weight matrices, Muon uses the RMS-to-RMS operator norm:

|W|_op,RMS = max_{x≠0} (|Wx|_RMS / |x|_RMS) = sqrt(d_x / d_y) · |W|_sp

Where:

  • |W|_sp is the spectral norm (maximum singular value)
  • d_x and d_y are the input and output dimensions

This norm measures how much a matrix can amplify the size of input vectors on average, providing a natural way to control the influence of weights regardless of layer size.
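
To make these definitions concrete, here is a small PyTorch sketch that computes both norms for a random layer and checks the closed form above. The helper names are my own, and the sketch assumes the weight is stored as (d_y, d_x), the convention nn.Linear uses.

import torch

def rms_norm(x: torch.Tensor) -> torch.Tensor:
    # |x|_RMS = |x|_2 / sqrt(d)
    return x.norm(p=2) / x.numel() ** 0.5

def rms_to_rms_operator_norm(W: torch.Tensor) -> torch.Tensor:
    # |W|_op,RMS = sqrt(d_x / d_y) * |W|_sp for a weight of shape (d_y, d_x)
    d_y, d_x = W.shape
    return (d_x / d_y) ** 0.5 * torch.linalg.matrix_norm(W, ord=2)

torch.manual_seed(0)
d_x, d_y = 256, 128
x = torch.randn(d_x)
W = torch.randn(d_y, d_x) / d_x ** 0.5

print(rms_norm(x))                    # ~1.0 for standard-normal entries
print(rms_to_rms_operator_norm(W))    # how much W can amplify RMS magnitude
print(rms_norm(W @ x) / rms_norm(x))  # never exceeds the operator norm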

2. Perturbing the Linear Layer

When we update weights by ΔW, outputs change by Δy = ΔW · x. Muon bounds this change using the RMS norm:

|Δy|_RMS ≤ |ΔW|_op,RMS · |x|_RMS

This means that by controlling the operator norm of the weight update, we can directly limit how much the network’s outputs can change, leading to more stable training.
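
Continuing the sketch above (reusing rms_norm and rms_to_rms_operator_norm), a quick numerical check of this bound looks like the following; the perturbation size is arbitrary.

# Check |Δy|_RMS <= |ΔW|_op,RMS · |x|_RMS for a small random perturbation
delta_W = 0.01 * torch.randn(d_y, d_x)
delta_y = delta_W @ x

lhs = rms_norm(delta_y)
rhs = rms_to_rms_operator_norm(delta_W) * rms_norm(x)
assert lhs <= rhs + 1e-6  # the output change is controlled by the update's operator norm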

3. Dualizing the Gradient

Muon formulates weight updates as an optimization problem:

min_ΔW ⟨G, ΔW⟩ subject to |ΔW|_op,RMS ≤ β

Where:

  • G is the gradient of the loss with respect to the weights
  • β controls the magnitude of change

If we decompose the gradient using singular value decomposition (SVD) as G = UΣV^T, the optimal update becomes:

ΔW = -β · sqrt(d_y / d_x) · UV^T

This preserves the directions of the gradient (through U and V) while standardizing all singular values to 1, effectively “orthogonalizing” the gradient.
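
For reference, here is what this exact dualized update looks like with an explicit SVD. Muon itself never computes an SVD (the next section replaces it with a cheap iteration), so treat this purely as an illustration; the shapes and β value are arbitrary.

import torch

torch.manual_seed(0)
d_y, d_x = 128, 256
G = torch.randn(d_y, d_x)                 # stand-in for a gradient matrix

U, S, Vh = torch.linalg.svd(G, full_matrices=False)
ortho_G = U @ Vh                          # same singular vectors, all singular values equal to 1
print(torch.linalg.svdvals(ortho_G)[:4])  # ~[1., 1., 1., 1.]

beta = 0.1
delta_W = -beta * (d_y / d_x) ** 0.5 * ortho_G  # optimal step under the RMS-to-RMS constraint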

4. Fast Orthogonalization via Newton-Schulz

Computing full SVDs is computationally expensive for large networks. Instead, Muon uses Newton-Schulz iterations, a fast approximation method that applies this polynomial function iteratively:

f(X) = (3X - X X^T X) / 2

On the singular values this acts like the scalar map f(σ) = (3σ - σ^3)/2, which pulls any value in (0, sqrt(3)) toward 1. Applied repeatedly to a properly normalized matrix, the iteration converges to a matrix with the same singular vectors but with all singular values equal to 1 — achieving orthogonalization without the computational burden of a full SVD.
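
Below is a minimal standalone sketch of this iteration, separate from the optimizer class we build later. Note that with simple Frobenius normalization the smallest singular values converge last, which is why the number of iterations is a hyperparameter worth watching.

import torch

def newton_schulz(G: torch.Tensor, num_iters: int = 10) -> torch.Tensor:
    # Normalize so all singular values lie in (0, 1], where the iteration converges
    X = G / (torch.norm(G, p='fro') + 1e-8)
    for _ in range(num_iters):
        # f(X) = (3X - X X^T X) / 2, valid for rectangular matrices
        X = 1.5 * X - 0.5 * X @ X.transpose(-2, -1) @ X
    return X

G = torch.randn(128, 256)
print(torch.linalg.svdvals(newton_schulz(G))[:4])  # leading singular values approach 1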

The Final Update Rule

Combining these elements, the practical update rule we will implement in this guide is:

W ← W - α · (sqrt(d_x * d_y) / |G|_F) · NS(G)

Where:

  • α is the learning rate
  • |G|_F is the Frobenius norm of the gradient
  • NS(G) represents the Newton-Schulz orthogonalization of the gradient

This rule updates weights by subtracting a scaled, orthogonalized version of the gradient, ensuring consistent behavior regardless of layer size.
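
As a sanity check, here is one hand-rolled update on a single weight matrix using this rule, reusing the newton_schulz sketch from the previous section. The shapes and learning rate are arbitrary placeholders.

# One manual Muon-style update, reusing newton_schulz from the sketch above
W = torch.randn(128, 256) / 256 ** 0.5  # a linear layer weight of shape (d_y, d_x)
G = torch.randn_like(W)                 # stand-in for the gradient dL/dW
lr = 0.02

d_y, d_x = W.shape
scaling = (d_x * d_y) ** 0.5 / torch.norm(G, p='fro')
W = W - lr * scaling * newton_schulz(G)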

Implementing Muon in PyTorch

Now, let’s implement the Muon optimizer in PyTorch. We’ll start by defining the optimizer class and then work through each component:

import torch
from torch.optim.optimizer import Optimizer
from typing import List, Dict, Optional, Tuple, Union, Callable, Any, Iterator
 
class Muon(Optimizer):
    """
    Implements the Muon optimization algorithm for linear layers.
    
    Muon uses a geometric approach to optimization, specifically addressing
    how changes in weight matrices affect neural network behavior.
    
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        ns_iters (int, optional): number of Newton-Schulz iterations (default: 5)
        momentum (float, optional): momentum factor (default: 0.9)
        weight_decay (float, optional): weight decay coefficient (default: 0)
    """
    
    def __init__(self, 
                 params: Iterator[torch.nn.Parameter], 
                 lr: float = 1e-3, 
                 ns_iters: int = 5, 
                 momentum: float = 0.9, 
                 weight_decay: float = 0):
        
        defaults = dict(lr=lr, ns_iters=ns_iters, momentum=momentum, weight_decay=weight_decay)
        super(Muon, self).__init__(params, defaults)
    
    def newton_schulz_orthogonalize(self, X: torch.Tensor, num_iters: int) -> torch.Tensor:
        """
        Apply Newton-Schulz iterations to approximate orthogonalization.
        
        This function applies the polynomial f(X) = (3X - X X^T X)/2 repeatedly to a normalized matrix,
        which gradually forces all singular values to 1 while preserving singular vectors.
        
        Args:
            X (torch.Tensor): Input matrix to orthogonalize
            num_iters (int): Number of Newton-Schulz iterations
            
        Returns:
            torch.Tensor: Orthogonalized matrix
        """
        # First, normalize the input matrix so its singular values are at most 1
        # (the Frobenius norm is a cheap upper bound on the spectral norm)
        norm = torch.norm(X, p='fro')
        if norm < 1e-8:
            return X  # Avoid division by zero
        
        X = X / norm
        
        # Apply Newton-Schulz iterations
        for _ in range(num_iters):
            # f(X) = (3X - X X^T X) / 2, valid for rectangular matrices
            X = (3 * X - torch.matmul(torch.matmul(X, X.transpose(-2, -1)), X)) / 2
            
        return X
    
    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """
        Performs a single optimization step.
        
        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss
            
        Returns:
            Optional[float]: Loss value if closure is provided, else None
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            lr = group['lr']
            ns_iters = group['ns_iters']
            momentum_factor = group['momentum']
            weight_decay = group['weight_decay']
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                
                # Handle weight decay
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)
                
                state = self.state[p]
                
                # Initialize momentum buffer if needed
                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(grad)
                
                # Get momentum buffer
                momentum_buffer = state['momentum_buffer']
                
                # Update momentum buffer with current gradient
                momentum_buffer.mul_(momentum_factor).add_(grad, alpha=1 - momentum_factor)
                
                # Only apply Muon updates to matrices (linear layers)
                if len(p.shape) == 2:
                    # Get input and output dimensions (nn.Linear stores weights as (d_out, d_in))
                    d_out, d_in = p.shape
                    
                    # Use the momentum buffer for orthogonalization
                    ortho_grad = self.newton_schulz_orthogonalize(momentum_buffer, ns_iters)
                    
                    # Scale by sqrt(d_in * d_out) / |G|_F as per Muon's formula
                    grad_norm = torch.norm(momentum_buffer, p='fro')
                    if grad_norm > 1e-8:  # Avoid division by zero
                        scaling = (d_in * d_out)**0.5 / grad_norm
                        update = ortho_grad * scaling
                        
                        # Apply the update
                        p.add_(update, alpha=-lr)
                    
                else:
                    # For non-matrix parameters (biases, etc.), use standard update with momentum
                    p.add_(momentum_buffer, alpha=-lr)
        
        return loss

Let’s break down this implementation:

  1. We define the Muon class inheriting from PyTorch’s Optimizer base class.
  2. The __init__ method sets up the optimizer with hyperparameters like learning rate, number of Newton-Schulz iterations, momentum, and weight decay.
  3. The newton_schulz_orthogonalize method implements the fast orthogonalization algorithm using the polynomial f(X) = (3X - X X^T X)/2.
  4. The step method performs the actual optimization update, applying the Muon formula to matrix parameters (linear layers) and a standard update to other parameters.
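
Before adding extra features, here is a minimal usage sketch showing how the class above drops into an ordinary training loop. The model, batch, and hyperparameters are placeholders, not recommendations.

import torch
import torch.nn as nn

# Placeholder model and data, just to exercise the optimizer
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = Muon(model.parameters(), lr=1e-3, ns_iters=5, momentum=0.9)
criterion = nn.CrossEntropyLoss()

inputs = torch.randn(64, 784)
targets = torch.randint(0, 10, (64,))

for step in range(10):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss = {loss.item():.4f}")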

Enhancements and Practical Considerations

Now let’s enhance our implementation with some practical features:

class EnhancedMuon(Optimizer):
    """
    Enhanced implementation of the Muon optimization algorithm.
    
    This version includes additional features like adaptive momentum, gradient clipping,
    learning rate scheduling support, and detailed parameter tracking.
    
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        ns_iters (int, optional): number of Newton-Schulz iterations (default: 5)
        momentum (float, optional): momentum factor (default: 0.9)
        weight_decay (float, optional): weight decay coefficient (default: 0)
        grad_clip_norm (float, optional): gradient clipping norm (default: None)
        track_stats (bool, optional): whether to track optimization statistics (default: False)
    """
    
    def __init__(self, 
                 params: Iterator[torch.nn.Parameter], 
                 lr: float = 1e-3, 
                 ns_iters: int = 5, 
                 momentum: float = 0.9, 
                 weight_decay: float = 0,
                 grad_clip_norm: Optional[float] = None,
                 track_stats: bool = False):
        
        defaults = dict(
            lr=lr, 
            ns_iters=ns_iters, 
            momentum=momentum, 
            weight_decay=weight_decay,
            grad_clip_norm=grad_clip_norm,
            track_stats=track_stats
        )
        super(EnhancedMuon, self).__init__(params, defaults)
        
        self.global_stats = {
            'update_magnitudes': [],
            'gradient_norms': []
        }
    
    def newton_schulz_orthogonalize(self, X: torch.Tensor, num_iters: int) -> torch.Tensor:
        """
        Enhanced Newton-Schulz orthogonalization with stability checks.
        """
        # Check for numerical stability
        if torch.isnan(X).any() or torch.isinf(X).any():
            # Return identity matrix of appropriate size as fallback
            return torch.eye(X.shape[0], X.shape[1], device=X.device)
        
        # Normalize the input matrix
        norm = torch.norm(X, p='fro')
        if norm < 1e-8:
            return X
        
        X = X / norm
        
        # Apply Newton-Schulz iterations with stability checks
        for i in range(num_iters):
            # X X^T X term of f(X) = (3X - X X^T X)/2, valid for rectangular matrices
            X_cubed = torch.matmul(torch.matmul(X, X.transpose(-2, -1)), X)
            
            # Check for numerical issues
            if torch.isnan(X_cubed).any() or torch.isinf(X_cubed).any():
                # Revert to previous iteration
                break
            
            X_new = (3 * X - X_cubed) / 2
            
            # Early stopping if convergence is reached
            if torch.norm(X_new - X, p='fro') < 1e-6:
                X = X_new
                break
                
            X = X_new
            
        return X
    
    def get_dimension_scaling(self, shape: Tuple[int, ...]) -> float:
        """
        Calculate the appropriate dimension scaling factor for different parameter shapes.
        
        For matrices (linear layers), this is sqrt(d_in * d_out).
        For other parameter types, we use appropriate heuristics.
        
        Args:
            shape (tuple): Shape of the parameter tensor
            
        Returns:
            float: Scaling factor
        """
        if len(shape) == 2:  # Linear layer weights, stored as (d_out, d_in)
            d_out, d_in = shape
            return (d_in * d_out) ** 0.5
        elif len(shape) == 1:  # Bias vectors
            return shape[0] ** 0.5
        elif len(shape) == 4:  # Conv layer weights
            # For convolutions, the scaling considers channels and kernel size
            c_out, c_in, k_h, k_w = shape
            return (c_in * c_out * k_h * k_w) ** 0.5
        else:
            # Default scaling for other parameter types
            return torch.prod(torch.tensor(shape)).float() ** 0.5
    
    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """
        Performs a single optimization step with enhanced capabilities.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        for group in self.param_groups:
            lr = group['lr']
            ns_iters = group['ns_iters']
            momentum_factor = group['momentum']
            weight_decay = group['weight_decay']
            grad_clip_norm = group['grad_clip_norm']
            track_stats = group['track_stats']
            
            # Gradient clipping if enabled
            if grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(group['params'], grad_clip_norm)
            
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad
                
                # Apply weight decay
                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)
                
                state = self.state[p]
                
                # Initialize momentum buffer and other state if needed
                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(grad)
                    state['step'] = 0
                    state['update_history'] = [] if track_stats else None
                
                # Increment step count
                state['step'] += 1
                
                # Update momentum buffer
                momentum_buffer = state['momentum_buffer']
                momentum_buffer.mul_(momentum_factor).add_(grad, alpha=1 - momentum_factor)
                
                # Store gradient norm for tracking
                grad_norm = torch.norm(grad, p='fro').item()
                if track_stats:
                    self.global_stats['gradient_norms'].append(grad_norm)
                
                # Apply Muon update based on parameter type
                if len(p.shape) >= 2:  # For matrices and higher-dimensional tensors
                    # Reshape to matrix for higher dimensions
                    original_shape = p.shape
                    if len(p.shape) > 2:
                        p_flat = p.reshape(p.shape[0], -1)
                        momentum_flat = momentum_buffer.reshape(momentum_buffer.shape[0], -1)
                    else:
                        p_flat = p
                        momentum_flat = momentum_buffer
                    
                    # Apply Newton-Schulz orthogonalization
                    ortho_grad = self.newton_schulz_orthogonalize(momentum_flat, ns_iters)
                    
                    # Get dimension scaling
                    dim_scaling = self.get_dimension_scaling(original_shape)
                    
                    # Calculate update
                    buffer_norm = torch.norm(momentum_flat, p='fro')
                    if buffer_norm > 1e-8:
                        scaling = dim_scaling / buffer_norm
                        update = ortho_grad * scaling
                        
                        # Reshape back if needed
                        if len(p.shape) > 2:
                            update = update.reshape(original_shape)
                        
                        # Apply the update
                        p.add_(update, alpha=-lr)
                        
                        # Track update magnitude
                        if track_stats:
                            update_mag = torch.norm(update, p='fro').item() * lr
                            state['update_history'].append(update_mag)
                            self.global_stats['update_magnitudes'].append(update_mag)
                    
                else:
                    # For non-matrix parameters, use standard update
                    p.add_(momentum_buffer, alpha=-lr)
        
        return loss
    
    def get_stats(self) -> Dict[str, Any]:
        """
        Return optimization statistics for analysis.
        
        Returns:
            Dict[str, Any]: Dictionary containing tracked statistics
        """
        stats = {
            'global': self.global_stats,
            'parameters': {}
        }
        
        for group in self.param_groups:
            if group['track_stats']:
                for p in group['params']:
                    if p in self.state and 'update_history' in self.state[p]:
                        state = self.state[p]
                        stats['parameters'][id(p)] = {
                            'shape': p.shape,
                            'updates': state['update_history'],
                            'steps': state['step']
                        }
        
        return stats
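
As with the basic version, here is a short usage sketch with gradient clipping and statistics tracking enabled; the model and batch are again placeholders.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
optimizer = EnhancedMuon(model.parameters(), lr=1e-3, grad_clip_norm=1.0, track_stats=True)
criterion = nn.CrossEntropyLoss()

for step in range(10):
    optimizer.zero_grad()
    loss = criterion(model(torch.randn(64, 784)), torch.randint(0, 10, (64,)))
    loss.backward()
    optimizer.step()

stats = optimizer.get_stats()
print(len(stats['global']['gradient_norms']))  # one entry per parameter per optimizer step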

Conclusion: The Power and Potential of Muon

The Muon optimizer represents a significant advancement in the field of neural network optimization. By approaching optimization from a geometric perspective and considering how weight updates affect a network’s behavior, Muon offers several compelling advantages over traditional methods:

Advantages of Muon

  1. Automatic Learning Rate Transfer: Perhaps Muon’s most remarkable feature is its ability to maintain consistent behavior across different network widths. This eliminates one of the most frustrating aspects of scaling neural networks: the need to repeatedly tune hyperparameters when changing model sizes.
  2. Faster Convergence: Muon consistently demonstrates faster convergence than traditional optimizers like Adam. This acceleration reflects a more efficient traversal of the loss landscape, guided by the geometric principles encoded in the optimizer.
  3. Improved Numerical Properties: Muon’s numerical behavior shows a distinctive ability to escape the neighborhood of initialization more effectively than traditional methods. This challenges the notion that neural networks primarily operate in a “linearized regime” around their initialization.
  4. Theoretical Soundness: By formalizing optimization in terms of the RMS-to-RMS operator norm, Muon provides a clear mathematical framework that connects weight space updates to functional changes in the network.

The Future of Muon and Geometric Optimization

Muon’s success points to a promising future for geometric approaches to neural network optimization. The key insight — that optimization should be guided by how parameter changes affect a network’s function — opens up new avenues for research.

The modular approach exemplified by Muon — treating different layer types with specialized optimization techniques — may lead to a more principled understanding of deep learning. Rather than treating neural networks as “inscrutable beasts,” this approach aims to transform them into well-understood mathematical systems with predictable behaviors.

As deep learning continues to scale to ever-larger models, techniques like Muon that offer automatic learning rate transfer and principled parameter updates will become increasingly valuable. The days of extensive hyperparameter tuning may eventually give way to optimizers that “just work” across a wide range of model architectures and sizes.

By implementing Muon in PyTorch, we’ve made this cutting-edge optimizer accessible to the broader deep learning community. We encourage researchers and practitioners to experiment with Muon in their own projects and contribute to the ongoing development of geometric optimization methods.

References

  1. Jordan, K., Jin, Y., Boza, V., You, J., Cesista, F., Newhouse, L., & Bernstein, J. (2024). “Muon: An optimizer for hidden layers in neural networks.”
  2. Bernstein, J., Vahdat, A., Yue, Y., & Liu, M. Y. (2020). “On the distance between two neural networks and the stability of learning.” Neural Information Processing Systems.
  3. Nesterov, Y. (2004). “Introductory Lectures on Convex Optimization: A Basic Course.” Springer Science & Business Media.
  4. Martens, J., & Grosse, R. (2015). “Optimizing Neural Networks with Kronecker-factored Approximate Curvature.” International Conference on Machine Learning.
  5. Anil, R., Gupta, V., Koren, T., & Singer, Y. (2019). “Memory-Efficient Adaptive Optimization.” Neural Information Processing Systems.
  6. Shazeer, N., & Stern, M. (2018). “Adafactor: Adaptive Learning Rates with Sublinear Memory Cost.” International Conference on Machine Learning.
  7. You, Y., Gitman, I., & Ginsburg, B. (2017). “Scaling SGD Batch Size to 32K for ImageNet Training.” Technical Report.
  8. Duchi, J., Hazan, E., & Singer, Y. (2011). “Adaptive subgradient methods for online learning and stochastic optimization.” Journal of Machine Learning Research.
  9. Kingma, D. P., & Ba, J. (2015). “Adam: A Method for Stochastic Optimization.” International Conference on Learning Representations.
  10. Loshchilov, I., & Hutter, F. (2019). “Decoupled Weight Decay Regularization.” International Conference on Learning Representations.

](https://medium.com/?source=post_page---read_next_recirc—17f4601be548---------------------------------------)