Tim Lawson

Nesterov lookahead and Muon 1

Momentum-based optimizers are ubiquitous in deep learning, but their implementation details can be subtle. In this post, I’ll walk through classical and Nesterov momentum, comparing PyTorch implementations with their mathematical formulations and highlighting memory-efficient tricks. We’ll use this foundation to understand a recent change in the Muon optimizer and why it was accompanied by a dramatic shift in the recommended learning rate.

Classical and Nesterov momentum

In stochastic gradient descent (SGD), the classical momentum update is typically given by:

\begin{align}
v_{t+1} &= \mu v_t + \nabla f(\theta_t) \\
\theta_{t+1} &= \theta_t - \varepsilon v_{t+1}
\end{align}

where $v$ is the momentum buffer, $\mu$ is the momentum coefficient, $\nabla f(\theta_t)$ is the gradient at $\theta_t$, and $\varepsilon$ is the learning rate. Using a tiny framework (see the appendix), we can write a simple PyTorch version:

class ClassicalMomentum(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 1: v_{t+1} = mu * v_t + grad
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 2: theta_{t+1} = theta_t - lr * v_{t+1}
        param.add_(state["momentum_buffer"], alpha=-group["lr"])

For each parameter, we first update the momentum buffer in place from $v_t$ to $v_{t+1}$, as in eq. 1, and then update the parameter itself in place, as in eq. 2. The real PyTorch implementation is the same (except that we’ve ignored dampening).
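
As a quick sanity check (my own toy example, not from the original post), we can compare this class against torch.optim.SGD; the seed, shapes, and hyperparameters below are arbitrary, and the appendix framework is assumed to be in scope:

import torch

# Sanity check (assumed setup): ClassicalMomentum should match torch.optim.SGD
# with nesterov=False and zero dampening.
torch.manual_seed(0)
p1 = torch.nn.Parameter(torch.randn(4))
p2 = torch.nn.Parameter(p1.detach().clone())
opt1 = ClassicalMomentum([p1], lr=0.1, momentum=0.9)
opt2 = torch.optim.SGD([p2], lr=0.1, momentum=0.9, dampening=0.0, nesterov=False)
for _ in range(5):
    # the gradient of the toy loss 0.5 * ||p||^2 is p itself
    p1.grad, p2.grad = p1.detach().clone(), p2.detach().clone()
    opt1.step()
    opt2.step()
print(torch.allclose(p1, p2))  # expected: True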

Sutskever et al. (2013) showed that the Nesterov Accelerated Gradient (NAG) method can be interpreted as a momentum method where the gradient is evaluated at a ‘lookahead’ position:

\begin{align}
v_{t+1} &= \mu v_t + \nabla f(\theta_t - \varepsilon \mu v_t) \\
\theta_{t+1} &= \theta_t - \varepsilon v_{t+1}.
\end{align}

(In fact, they state both methods slightly differently, but I’ll stick to the PyTorch version.)

We can see how to implement this formulation of Nesterov momentum in PyTorch by rewriting the update in terms of $\theta_t' = \theta_t - \varepsilon \mu v_t$. By definition, we can express $\theta_{t+1}'$ as $\theta_{t+1} - \varepsilon \mu v_{t+1}$, and we can substitute eq. 4 and $\theta_t = \theta_t' + \varepsilon \mu v_t$ into this definition to give:

\begin{align}
\theta_{t+1}' &= \theta_t - \varepsilon v_{t+1} - \varepsilon \mu v_{t+1} \\
&= \theta_t' - \varepsilon \left( (1 + \mu) v_{t+1} - \mu v_t \right).
\end{align}

In other words, we can define an ‘update’ $z_{t+1} = (1 + \mu) v_{t+1} - \mu v_t$ that we scale by the learning rate and subtract from the current position, instead of the momentum buffer $v_{t+1}$ as in eq. 4.
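
To make the memory cost concrete, here is a hypothetical naive sketch of eq. 6 in the same tiny framework (the class name and the extra prev_momentum_buffer state key are mine, not part of any real implementation):

class NaiveNesterovMomentum(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        # two buffers per parameter: v_{t+1} and a copy of v_t
        self.state_init("momentum_buffer", "prev_momentum_buffer")

    def param_step(self, group, param, state):
        mu = group["momentum"]
        # keep a copy of v_t before overwriting the buffer with v_{t+1}
        state["prev_momentum_buffer"].copy_(state["momentum_buffer"])
        state["momentum_buffer"].mul_(mu).add_(param.grad)
        # z_{t+1} = (1 + mu) v_{t+1} - mu v_t (the update from eq. 6), allocated as a new tensor
        update = (1 + mu) * state["momentum_buffer"] - mu * state["prev_momentum_buffer"]
        param.add_(update, alpha=-group["lr"])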

However, if we naïvely implement eq. 6 (as sketched above), we need two momentum-buffer tensors for each parameter: one to hold $v_t$ and another to hold $v_{t+1}$, requiring extra memory. We can avoid this by writing the update $z_{t+1}$ in terms of only the momentum $v_{t+1}$ and the gradient $\nabla f(\theta_t')$. Rearranging eq. 3:

\begin{align}
v_t &= \frac{1}{\mu} \left( v_{t+1} - \nabla f(\theta_t') \right) \\
z_{t+1} &= \mu v_{t+1} + \nabla f(\theta_t').
\end{align}

With this in mind, we can write a simple PyTorch version with a single momentum buffer:

class NesterovMomentum(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 3: v_{t+1} = mu * v_t + grad (the gradient is evaluated at theta'_t)
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 8: z_{t+1} = mu * v_{t+1} + grad, computed in place in param.grad
        param.grad.add_(state["momentum_buffer"], alpha=group["momentum"])
        # eq. 6: theta'_{t+1} = theta'_t - lr * z_{t+1}
        param.add_(param.grad, alpha=-group["lr"])

Importantly, we re-use param.grad to hold the update $z_{t+1}$ instead of allocating memory for a new tensor (see PyTorch’s implementation). This is a useful trick that we’ll see again in Muon.
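
To see that no new tensor is created, here is a small, hypothetical check (arbitrary shapes and coefficients) that the in-place add_ returns the very same tensor object as param.grad:

import torch

# Hypothetical check that the Nesterov update re-uses param.grad's storage.
p = torch.nn.Parameter(torch.randn(3))
p.grad = torch.randn(3)
buf = torch.zeros_like(p)
buf.mul_(0.9).add_(p.grad)           # v_{t+1}
z = p.grad.add_(buf, alpha=0.9)      # z_{t+1}, computed in place
print(z.data_ptr() == p.grad.data_ptr())  # expected: True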

Muon and ‘average gradient’ momentum

Jordan et al. (2024) originally defined the Muon update (implementation) by:

\begin{align}
v_{t+1} &= \mu v_t + \nabla f(\theta_t) \\
z_{t+1} &= \begin{cases} \mu v_{t+1} + \nabla f(\theta_t) & \text{if Nesterov} \\ v_{t+1} & \text{otherwise} \end{cases} \\
\theta_{t+1} &= \theta_t - \varepsilon \operatorname{NewtonSchulz5}(z_{t+1}).
\end{align}

Let’s focus on the Nesterov case. As we saw with SGD above, the momentum buffer $v_{t+1}$ is an accumulation of past gradients (eq. 9). In our simple optimizer framework, we could write:

class Muon(Optimizer):
    def __init__(self, params: ParamsT, lr: float, momentum: float):
        defaults = dict(lr=lr, momentum=momentum)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 9: v_{t+1} = mu * v_t + grad
        state["momentum_buffer"].mul_(group["momentum"]).add_(param.grad)
        # eq. 10 (Nesterov case): z_{t+1} = mu * v_{t+1} + grad, in place in param.grad
        update = param.grad.add_(state["momentum_buffer"], alpha=group["momentum"])
        # eq. 11: theta_{t+1} = theta_t - lr * NewtonSchulz5(z_{t+1})
        param.add_(newtonschulz(update), alpha=-group["lr"])

Around the start of November 2024, the Muon update changed such that:

\begin{align}
v_{t+1} &= \beta v_t + (1 - \beta) \nabla f(\theta_t) \\
z_{t+1} &= \begin{cases} \beta v_{t+1} + (1 - \beta) \nabla f(\theta_t) & \text{if Nesterov} \\ v_{t+1} & \text{otherwise} \end{cases} \\
\theta_{t+1} &= \theta_t - \varepsilon \operatorname{NewtonSchulz5}(z_{t+1}).
\end{align}

Now, eq. 12 is an exponential moving average (EMA) of the gradient $\nabla f(\theta_t)$ with decay parameter $\beta$. At the same time, the learning rate used in the nanoGPT speedrun changed dramatically, from 0.0003 to 0.02. What’s going on? Well, it turns out there’s another way to formulate momentum:

\begin{align}
v_{t+1} &= \beta v_t + (1 - \beta) \nabla f(\theta_t) \\
\theta_{t+1} &= \theta_t - \varepsilon_\text{avg} v_{t+1}.
\end{align}

This formulation is equivalent to eqs. 1 and 2, but only when $\mu = \beta$ and we rescale the learning rate:

\begin{equation}
\varepsilon_\text{avg} = \frac{1}{1 - \beta} \varepsilon_\text{trad}.
\end{equation}

Given Muon’s default parameter $\beta = 0.95$, this rescaling suggests that we should increase the learning rate by a factor of 20 when moving to the ‘average gradient’ form. (In fact, the previous increase from 0.0003 to 0.02 would result from $\beta = 0.985$.) For a proof of the two formulations’ equivalence, see Sec. 3.1 of these lecture notes from Laurence Aitchison (my PhD advisor).
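
As a quick numerical illustration (my own toy example with an arbitrary gradient sequence), we can iterate both forms and confirm that they produce identical parameters once the learning rate is rescaled by eq. 17:

import torch

# Toy check of the rescaling in eq. 17 (all values here are arbitrary).
beta = 0.95
lr_trad = 0.0003
lr_avg = lr_trad / (1 - beta)  # factor of 20
theta_trad = torch.zeros(3)
theta_avg = torch.zeros(3)
v = torch.zeros(3)  # accumulating buffer, eq. 1
w = torch.zeros(3)  # EMA buffer, eq. 15
for t in range(10):
    grad = torch.full((3,), float(t + 1))  # stand-in gradient sequence
    v = beta * v + grad
    w = beta * w + (1 - beta) * grad
    theta_trad -= lr_trad * v
    theta_avg -= lr_avg * w
print(torch.allclose(theta_trad, theta_avg))  # expected: True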

A nice property of the ‘average gradient’ form of momentum is that we can efficiently implement EMAs like eqs. 12 and 15 with torch.lerp (e.g., in the new Muon update). Now, we can write:

class Muon(Optimizer):
    def __init__(self, params: ParamsT, lr: float, beta: float):
        defaults = dict(lr=lr, beta=beta)
        super().__init__(params, defaults)
        self.state_init("momentum_buffer")

    def param_step(self, group, param, state):
        # eq. 12: v_{t+1} = beta * v_t + (1 - beta) * grad
        state["momentum_buffer"].lerp_(param.grad, 1 - group["beta"])
        # eq. 13 (Nesterov case): z_{t+1} = beta * v_{t+1} + (1 - beta) * grad
        update = param.grad.lerp_(state["momentum_buffer"], group["beta"])
        # eq. 14: theta_{t+1} = theta_t - lr * NewtonSchulz5(z_{t+1})
        param.add_(newtonschulz(update), alpha=-group["lr"])

Again, we update param.grad in place to hold $z_{t+1}$, instead of allocating memory for a new tensor.
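
For completeness, here is a hypothetical single step with this class on a toy weight matrix, using the post’s $\beta = 0.95$ and learning rate 0.02 (remember that newtonschulz is just the identity placeholder from the appendix, not a real orthogonalization):

import torch

# Hypothetical usage of the EMA-form Muon on a toy 2D weight (arbitrary shapes).
W = torch.nn.Parameter(torch.randn(16, 32))
opt = Muon([W], lr=0.02, beta=0.95)
W.grad = torch.randn_like(W)
opt.step()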

Conclusion

We’ve seen that implementation details matter: equivalent mathematical formulations can have different computational requirements, and tricks like in-place updates become important at scale.

The two primary ways of expressing momentum—the traditional accumulating sum and the exponential moving average—are equivalent, but only after rescaling the learning rate. This explains the dramatic change in learning rate that accompanied the new Muon implementation.

In the next post, I’ll extend this analysis to lookahead methods and show how to modify lookahead to work with Muon’s orthogonalized weight updates (spoiler alert: it doesn’t seem to help).

Appendix

The code below is a tiny framework I wrote to minimize the number of lines needed to define different momentum-based optimizers in the examples above and provide type hints while doing so. I wouldn’t recommend using it for anything else. You can find all the examples in this post on GitHub Gist.

from collections.abc import Callable, Sequence
from typing import Any, Literal, Protocol, overload

import torch
import torch.optim
from torch.optim.optimizer import ParamsT


class ParamGroup(Protocol):
    @overload
    def __getitem__(self, key: Literal["params"]) -> Sequence[torch.Tensor]: ...
    @overload
    def __getitem__(self, key: str) -> Any: ...


class Optimizer(torch.optim.Optimizer):
    param_groups: Sequence[ParamGroup]
    state: dict[torch.Tensor, dict[str, torch.Tensor]]

    def __init__(self, params: ParamsT, defaults: dict):
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group["params"]:
                self.param_step(group, param, self.state[param])
        return loss

    def param_step(
        self, group: ParamGroup, param: torch.Tensor, state: dict[str, torch.Tensor]
    ) -> None: ...

    def state_init(self, *keys: str):
        for group in self.param_groups:
            for param in group["params"]:
                state = self.state[param]
                for key in keys:
                    if key not in state:
                        state[key] = torch.zeros_like(param)


# placeholder for zero-power via Newton-Schulz iteration
newtonschulz = lambda x: x
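
Finally, here is a hypothetical end-to-end usage sketch (a toy model and arbitrary hyperparameters), using the imports above, showing that the optimizers in this post plug into the usual PyTorch training loop:

# Hypothetical usage of the framework with NesterovMomentum on a toy model.
model = torch.nn.Linear(4, 2)
opt = NesterovMomentum(model.parameters(), lr=0.01, momentum=0.9)
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()       # param_step mutates param.grad in place
    opt.zero_grad()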