Nesterov lookahead and Muon, part 1
Momentum-based optimizers are ubiquitous in deep learning, but their implementation details can be subtle. In this post, I’ll walk through classical and Nesterov momentum, comparing PyTorch implementations with their mathematical formulations and highlighting memory-efficient tricks. We’ll use this foundation to understand a recent change in the Muon optimizer and why it was accompanied by a dramatic shift in the recommended learning rate.
Classical and Nesterov momentum
In stochastic gradient descent (SGD), the classical momentum update is typically given by:

$$\begin{align}
b_{t+1} &= \mu\, b_t + g_t \tag{1} \\
\theta_{t+1} &= \theta_t - \gamma\, b_{t+1} \tag{2}
\end{align}$$

where $b_t$ is the momentum buffer, $\mu$ is the momentum coefficient, $g_t = \nabla f(\theta_t)$ is the gradient at $\theta_t$, and $\gamma$ is the learning rate. Using a tiny framework (see the appendix), we can write a simple PyTorch version:
```python
class ClassicalMomentum(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 1: b <- mu * b + g, in place
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 2: theta <- theta - lr * b
        param.add_(state["momentum_buffer"], alpha=-group["lr"])
```
For each parameter, `param_step` first updates the momentum buffer in place from $b_t$ to $\mu\, b_t + g_t$, as in eq. 1, and then updates the parameter itself in place, as in eq. 2. The real PyTorch implementation is the same (except that we’ve ignored dampening).
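As a quick smoke test (my addition, with made-up numbers), one step is easy to verify by hand against eqs. 1 and 2:

```python
import torch

# b_1 = 0.9 * 0 + [1, 1] = [1, 1]; theta_1 = [0, 0] - 0.1 * [1, 1]
w = torch.nn.Parameter(torch.zeros(2))
opt = ClassicalMomentum([w], lr=0.1, momentum=0.9)
w.grad = torch.ones(2)
opt.step()
print(w)  # expected: [-0.1, -0.1]
```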
Sutskever et al. (2013) showed that the Nesterov Accelerated Gradient (NAG) method can be interpreted as a momentum method where the gradient is evaluated at a ‘lookahead’ position:

$$\begin{align}
b_{t+1} &= \mu\, b_t + \nabla f(\theta_t - \gamma \mu\, b_t) \tag{3} \\
\theta_{t+1} &= \theta_t - \gamma\, b_{t+1} \tag{4}
\end{align}$$

(In fact, they state both methods slightly differently, but I’ll stick to the PyTorch version.)
We can see how to implement this formulation of Nesterov momentum in PyTorch by rewriting the update in terms of the lookahead position $\phi_t = \theta_t - \gamma \mu\, b_t$. By definition, we can express $\phi_{t+1}$ as $\theta_{t+1} - \gamma \mu\, b_{t+1}$, and we can substitute eq. 4 and $\theta_t = \phi_t + \gamma \mu\, b_t$ into this definition to give:

$$\begin{align}
\phi_{t+1} &= \phi_t + \gamma \mu\, b_t - \gamma (1 + \mu)\, b_{t+1} \tag{5} \\
&= \phi_t - \gamma \left[ (1 + \mu)\, b_{t+1} - \mu\, b_t \right] \tag{6}
\end{align}$$

In other words, we can define an ‘update’ $u_t = (1 + \mu)\, b_{t+1} - \mu\, b_t$ that we scale by the learning rate and subtract from the current position, instead of the momentum buffer as in eq. 4.

However, if we naïvely implement eq. 6, we need two momentum-buffer tensors for each parameter: one to hold $b_t$ and another to hold $b_{t+1}$, requiring extra memory. We can avoid this by writing the update in terms of only the momentum buffer $b_{t+1}$ and the gradient $\nabla f(\phi_t)$. Rearranging eq. 3 (whose gradient is evaluated exactly at $\phi_t$):

$$\mu\, b_t = b_{t+1} - \nabla f(\phi_t) \tag{7}$$

and substituting into eq. 6 gives:

$$\phi_{t+1} = \phi_t - \gamma \left[ \nabla f(\phi_t) + \mu\, b_{t+1} \right] \tag{8}$$

Conveniently, the gradient in eq. 8 is evaluated at $\phi_t$, the position we actually track, so autograd hands us exactly the tensor we need.
With this in mind, we can write a simple PyTorch version with a single momentum buffer:
```python
class NesterovMomentum(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 3: b <- mu * b + grad(phi), in place
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 8: write the update grad(phi) + mu * b into param.grad
        param.grad.add_(state["momentum_buffer"], alpha=group["momentum"])
        param.add_(param.grad, alpha=-group["lr"])
```
Importantly, we can re-use `param.grad` to hold the update $\nabla f(\phi_t) + \mu\, b_{t+1}$ instead of allocating memory for a new tensor (see PyTorch’s implementation). This is a useful trick that we’ll see again in Muon.
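As a sanity check (my addition, not from the post), we can confirm that this single-buffer version follows the same trajectory as PyTorch’s built-in Nesterov SGD:

```python
import torch

torch.manual_seed(0)
w1 = torch.nn.Parameter(torch.randn(5))
w2 = torch.nn.Parameter(w1.detach().clone())

opt1 = NesterovMomentum([w1], lr=0.1, momentum=0.9)
opt2 = torch.optim.SGD([w2], lr=0.1, momentum=0.9, nesterov=True)

for _ in range(10):
    grad = torch.randn(5)  # stand-in for a real backward pass
    w1.grad = grad.clone()  # clone: our param_step overwrites .grad in place
    w2.grad = grad.clone()
    opt1.step()
    opt2.step()

print(torch.allclose(w1, w2))  # expected: True
```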
Muon and ‘average gradient’ momentum
Jordan et al. (2024) originally defined the Muon update (implementation) by:

$$\begin{align}
b_{t+1} &= \mu\, b_t + g_t \tag{9} \\
u_t &= \begin{cases} g_t + \mu\, b_{t+1} & \text{if Nesterov} \\ b_{t+1} & \text{otherwise} \end{cases} \tag{10} \\
\theta_{t+1} &= \theta_t - \gamma\, \mathrm{NewtonSchulz}(u_t) \tag{11}
\end{align}$$

Let’s focus on the Nesterov case. Like we saw with SGD above, the momentum buffer is an accumulation of past gradients (eq. 9): unrolling the recursion from $b_0 = 0$ gives $b_{t+1} = \sum_{i=0}^{t} \mu^{t-i} g_i$. In our simple optimizer framework, we could write:
```python
class Muon(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 9: b <- mu * b + g, in place
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 10 (Nesterov): u = g + mu * b, re-using param.grad
        update = param.grad.add_(state["momentum_buffer"], alpha=group["momentum"])
        # eq. 11: theta <- theta - lr * NewtonSchulz(u)
        param.add_(newtonschulz(update), alpha=-group["lr"])
```
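To make the ‘accumulation’ concrete, here’s a quick numeric check (my addition, not from the post) that iterating eq. 9 from $b_0 = 0$ matches the unrolled weighted sum:

```python
import torch

mu = 0.95
grads = [torch.randn(3) for _ in range(4)]

buf = torch.zeros(3)
for g in grads:
    buf.mul_(mu).add_(g)  # eq. 9, exactly as in param_step above

t = len(grads) - 1
unrolled = sum(mu ** (t - i) * g for i, g in enumerate(grads))
print(torch.allclose(buf, unrolled))  # expected: True
```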
Around the start of November 2024, the Muon update changed such that:

$$\begin{align}
b_{t+1} &= \mu\, b_t + (1 - \mu)\, g_t \tag{12} \\
u_t &= (1 - \mu)\, g_t + \mu\, b_{t+1} \tag{13} \\
\theta_{t+1} &= \theta_t - \gamma\, \mathrm{NewtonSchulz}(u_t) \tag{14}
\end{align}$$
Now, eq. 12 is an exponential moving average (EMA) of the gradient with decay parameter $\mu$. At the same time, the learning rate used in the nanoGPT speedrun changed dramatically, from 0.0003 to 0.02. What’s going on? Well, it turns out there’s another way to formulate momentum:

$$\begin{align}
a_{t+1} &= \mu\, a_t + (1 - \mu)\, g_t \tag{15} \\
\theta_{t+1} &= \theta_t - \tilde{\gamma}\, a_{t+1} \tag{16}
\end{align}$$
This formulation is equivalent to eqs. 1 and 2, but only when both buffers are initialized to zero (so that $a_t = (1 - \mu)\, b_t$ for all $t$) and we rescale the learning rate:

$$\tilde{\gamma} = \frac{\gamma}{1 - \mu} \tag{17}$$
Given Muon’s default parameter $\mu = 0.95$, this rescaling suggests that we should increase the learning rate by a factor of $1 / (1 - 0.95) = 20$ when moving to the ‘average gradient’ form. (In fact, the previous increase from 0.0003 to 0.02 would result from $\mu = 0.985$.) For a proof of the two formulations’ equivalence, see Sec. 3.1 of these lecture notes from Laurence Aitchison (my PhD advisor).
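We can also check the equivalence numerically. Below, `AverageGradientMomentum` is a hypothetical optimizer of my own (it implements eqs. 15 and 16 in the appendix framework, and isn’t from the post’s Gist); with the learning rate rescaled as in eq. 17, it should trace exactly the same trajectory as `ClassicalMomentum`:

```python
import torch

class AverageGradientMomentum(Optimizer):
    # Hypothetical 'average gradient' momentum (eqs. 15-16), for illustration.
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        super().__init__(params, dict(lr=lr, momentum=momentum))
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 15: a <- mu * a + (1 - mu) * g
        state["momentum_buffer"].lerp_(param.grad, 1 - group["momentum"])
        # eq. 16: theta <- theta - lr~ * a
        param.add_(state["momentum_buffer"], alpha=-group["lr"])

# Both buffers start at zero, so rescaling lr by 1 / (1 - mu) (eq. 17)
# should give identical trajectories.
torch.manual_seed(0)
gamma, mu = 0.01, 0.95
w1 = torch.nn.Parameter(torch.randn(5))
w2 = torch.nn.Parameter(w1.detach().clone())
opt1 = ClassicalMomentum([w1], lr=gamma, momentum=mu)
opt2 = AverageGradientMomentum([w2], lr=gamma / (1 - mu), momentum=mu)
for _ in range(10):
    grad = torch.randn(5)
    w1.grad, w2.grad = grad.clone(), grad.clone()
    opt1.step()
    opt2.step()
print(torch.allclose(w1, w2))  # expected: True
```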
A nice property of the ‘average gradient’ form of momentum is that we can efficiently implement EMAs like eqs. 12 and 15 with `torch.lerp` (e.g., in the new Muon update).
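As a side note (my addition): `torch.lerp(a, b, w)` computes $a + w(b - a)$, so a weight of $1 - \mu$ gives exactly the EMA of eq. 12 as a single fused in-place op, instead of a `mul_` followed by an `add_`:

```python
import torch

mu = 0.95
buf, grad = torch.randn(3), torch.randn(3)
ema = mu * buf + (1 - mu) * grad  # eq. 12, written out
# lerp_(grad, 1 - mu) computes buf + (1 - mu) * (grad - buf) in place
print(torch.allclose(buf.lerp_(grad, 1 - mu), ema))  # expected: True
```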
Now, we can write:
```python
class Muon(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 12: b <- mu * b + (1 - mu) * g, a single fused lerp
        state["momentum_buffer"].lerp_(param.grad, 1 - group["momentum"])
        # eq. 13: u = (1 - mu) * g + mu * b, re-using param.grad
        update = param.grad.lerp_(state["momentum_buffer"], group["momentum"])
        # eq. 14: theta <- theta - lr * NewtonSchulz(u)
        param.add_(newtonschulz(update), alpha=-group["lr"])
```
Again, we update `param.grad` in place to hold the update $u_t$ from eq. 13, instead of allocating memory for a new tensor.
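One more hypothetical check (my addition): with the identity `newtonschulz` placeholder from the appendix, the old (accumulated) and new (EMA) Muon updates coincide once the learning rate is rescaled by $1 / (1 - \mu)$. This assumes the two `Muon` classes above have been renamed `MuonOld` and `MuonNew` so both can coexist:

```python
import torch

torch.manual_seed(0)
mu, lr = 0.95, 0.0003
w1 = torch.nn.Parameter(torch.randn(4, 4))
w2 = torch.nn.Parameter(w1.detach().clone())

opt_old = MuonOld([w1], lr=lr, momentum=mu)
opt_new = MuonNew([w2], lr=lr / (1 - mu), momentum=mu)

for _ in range(10):
    grad = torch.randn(4, 4)
    w1.grad, w2.grad = grad.clone(), grad.clone()  # both mutate .grad in place
    opt_old.step()
    opt_new.step()

print(torch.allclose(w1, w2))  # expected: True
```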
Conclusion
We’ve seen that implementation details matter: equivalent mathematical formulations can have different computational requirements, and tricks like in-place updates become important at scale.
The two primary ways of expressing momentum—the traditional accumulating sum and the exponential moving average—are equivalent, but only after rescaling the learning rate. This explains the dramatic change in learning rate that accompanied the new Muon implementation.
In the next post, I’ll extend this analysis to lookahead methods and show how to modify lookahead to work with Muon’s orthogonalized weight updates (spoiler alert: it doesn’t seem to help).
Appendix
The code below is a tiny framework I wrote to minimize the number of lines needed to define different momentum-based optimizers in the examples above and provide type hints while doing so. I wouldn’t recommend using it for anything else. You can find all the examples in this post on GitHub Gist.
```python
from collections.abc import Callable, Sequence
from typing import Any, Literal, Protocol, overload

import torch
import torch.optim
from torch.optim.optimizer import ParamsT


class ParamGroup(Protocol):
    @overload
    def __getitem__(self, key: Literal["params"]) -> Sequence[torch.Tensor]: ...

    @overload
    def __getitem__(self, key: str) -> Any: ...


class Optimizer(torch.optim.Optimizer):
    param_groups: Sequence[ParamGroup]
    state: dict[torch.Tensor, dict[str, torch.Tensor]]

    def __init__(self, params: ParamsT, defaults: dict):
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group["params"]:
                self.param_step(group, param, self.state[param])
        return loss

    def param_step(
        self, group: ParamGroup, param: torch.Tensor, state: dict[str, torch.Tensor]
    ) -> None: ...

    def state_init(self, *keys: str):
        for group in self.param_groups:
            for param in group["params"]:
                state = self.state[param]
                for key in keys:
                    if key not in state:
                        state[key] = torch.zeros_like(param)


# placeholder for the zeroth power via Newton-Schulz iteration
newtonschulz = lambda x: x
```
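For example (my addition, not in the Gist), plain SGD is a short subclass; each optimizer only has to fill in `param_step`:

```python
class SGD(Optimizer):
    def __init__(self, params: ParamsT, lr: float):
        super().__init__(params, dict(lr=lr))

    def param_step(self, group, param, state):
        # theta <- theta - lr * g
        param.add_(param.grad, alpha=-group["lr"])
```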