StableAdamW: AdamW with Update Clipping

StableAdamW is an AdamW-Adafactor hybrid, porting Adafactor’s update clipping into AdamW as a per-parameter learning rate modification. StableAdamW’s update clipping outperforms gradient clipping on downstream tasks while avoiding model training instability.

StableAdamW was introduced by Wortsman et al. in Stable and low-precision training for large-scale vision-language models.

Hyperparameters

StableAdamW is a drop-in replacement for AdamW and uses the same hyperparameters, with one exception: StableAdamW removes the need for gradient clipping.

If training on large batch sizes or still observing training loss spikes, consider reducing \(\beta_2\) to a value in \([0.95, 0.99)\).
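As a usage sketch (the toy `nn.Linear` model and random data below are assumptions purely for illustration), swapping AdamW for StableAdamW changes only the optimizer class, with no gradient clipping step required:

```python
import torch
from torch import nn
from optimi import StableAdamW

model = nn.Linear(128, 2)

# Same hyperparameters as AdamW; lower beta2 toward 0.95 if loss spikes persist
# or when training with large batch sizes.
opt = StableAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2)

x, y = torch.randn(32, 128), torch.randn(32, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # no clip_grad_norm_ call needed before stepping
opt.step()
opt.zero_grad()
```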

optimi’s implementation of StableAdamW also supports fully decoupled weight decay (`decouple_lr=True`). The default weight decay of 0.01 will likely need to be reduced when using fully decoupled weight decay, as the learning rate no longer modifies the effective weight decay.
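A minimal sketch of enabling fully decoupled weight decay; the reduced `weight_decay` value below is illustrative, not a prescribed setting:

```python
from torch import nn
from optimi import StableAdamW

model = nn.Linear(128, 2)

# With decouple_lr=True the learning rate no longer scales the effective weight
# decay, so the coefficient is set well below the 0.01 default (value illustrative).
opt = StableAdamW(model.parameters(), lr=1e-3, weight_decay=1e-5, decouple_lr=True)
```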

StableAdamW

StableAdamW optimizer. An AdamW-Adafactor hybrid with learning rate update clipping.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `Iterable[Tensor] \| Iterable[dict]` | Iterable of parameters to optimize or dicts defining parameter groups | required |
| `lr` | `float` | Learning rate | required |
| `betas` | `tuple[float, float]` | Coefficients for gradient and squared gradient moving averages | `(0.9, 0.99)` |
| `weight_decay` | `float` | Weight decay coefficient. If `decouple_lr` is `False`, applies decoupled weight decay | `0.01` |
| `eps` | `float` | Added to denominator to improve numerical stability | `1e-6` |
| `decouple_lr` | `bool` | Apply fully decoupled weight decay instead of decoupled weight decay | `False` |
| `max_lr` | `float \| None` | Maximum scheduled learning rate. Set if `lr` is not the maximum scheduled learning rate and `decouple_lr` is `True` | `None` |
| `kahan_sum` | `bool \| None` | Enables Kahan summation for more accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters | `None` |
| `foreach` | `bool \| None` | Enables the foreach implementation. If unspecified, tries to use foreach over the for-loop implementation since it is significantly faster | `None` |
| `gradient_release` | `bool` | Fuses optimizer step and zero_grad as part of the parameter's backward pass. Requires model hooks created with `register_gradient_release`. Incompatible with closure | `False` |
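For example, a low-precision setup might look like the sketch below (the bfloat16 `nn.Linear` model is an assumption for illustration): with bfloat16 parameters, Kahan summation is applied automatically when `kahan_sum` is left as `None`, or it can be enabled explicitly.

```python
import torch
from torch import nn
from optimi import StableAdamW

# bfloat16 parameters: Kahan summation would apply automatically (kahan_sum=None);
# passing kahan_sum=True makes the choice explicit.
model = nn.Linear(128, 2).to(dtype=torch.bfloat16)
opt = StableAdamW(model.parameters(), lr=1e-3, kahan_sum=True)
```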

Algorithm

StableAdam with decoupled weight decay (StableAdamW).

\[ \begin{aligned} &\rule{100mm}{0.4pt}\\ &\hspace{2mm} \textbf{\textcolor{#9a3fe4}{Stable}AdamW} \\ &\hspace{5mm} \text{inputs} : \bm{\theta}_0 \: \text{(params)}; \: f(\bm{\theta}) \: \text{(objective)}; \: \gamma_t \:\text{(learning rate at } t \text{)}; \\ &\hspace{17.25mm} \beta_1, \beta_2 \: \text{(betas)}; \: \lambda \: \text{(weight decay)}; \: \epsilon \: \text{(epsilon)}\\ &\hspace{5mm} \text{initialize} : \bm{m}_{0} \leftarrow \bm{0}; \: \bm{v}_{0} \leftarrow \bm{0}\\[-0.5em] &\rule{100mm}{0.4pt}\\ &\hspace{5mm} \textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}\text{:}\\ &\hspace{10mm} \bm{g}_t \leftarrow \nabla_{\theta} f_t(\bm{\theta}_{t-1})\\[0.5em] &\hspace{10mm} \bm{m}_t \leftarrow \beta_1 \bm{m}_{t-1} + (1 - \beta_1) \bm{g}_t\\ &\hspace{10mm} \bm{v}_t \leftarrow \beta_2 \bm{v}_{t-1} + (1 - \beta_2) \bm{g}^2_t\\[0.5em] &\hspace{10mm} \hat{\bm{m}}_t \leftarrow \bm{m}_t/(1 - \beta_1^t)\\ &\hspace{10mm} \hat{\bm{v}}_t \leftarrow \bm{v}_t/(1 - \beta_2^t)\\[0.5em] &\hspace{10mm} \textcolor{#9a3fe4}{\textbf{RMS}_t \leftarrow \sqrt{\mathbb{E}[\bm{g}^2_t/\max(\bm{v}_t, \epsilon^2)]}}\\ &\hspace{10mm} \textcolor{#9a3fe4}{\bm{\eta}_t \leftarrow \gamma_t/\max(1,\textbf{RMS}_t)}\\[0.5em] &\hspace{10mm} \bm{\theta}_t \leftarrow \bm{\theta}_{t-1} - \textcolor{#9a3fe4}{\bm{\eta}_t} \bigl( \hat{\bm{m}}_t / (\sqrt{\hat{\bm{v}}_t} + \epsilon) + \lambda\bm{\theta}_{t-1} \bigr)\\[-0.5em] &\rule{100mm}{0.4pt}\\ \end{aligned} \]

Following Stable and low-precision training for large-scale vision-language models, the \(\text{RMS}_t\) step is computed independently for each tensor. Likewise, the denominator uses \(\max(\bm{v}_t, \epsilon^2)\) instead of \(\bm{v}_t\), so \(\text{RMS}_t = \sqrt{\mathbb{E}[\bm{g}^2_t/\max(\bm{v}_t, \epsilon^2)]}\) rather than \(\sqrt{\mathbb{E}[\bm{g}^2_t/\bm{v}_t]}\), which prevents division by zero.

optimi’s StableAdamW also supports fully decoupled weight decay, which is not shown.
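To make the update concrete, below is a simplified single-tensor sketch of one StableAdamW step following the pseudocode above (decoupled weight decay only; Kahan summation, foreach, and fully decoupled weight decay are omitted). The function name and signature are illustrative and do not correspond to optimi's internal API.

```python
import torch

def stableadamw_step(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.99,
                     weight_decay=1e-2, eps=1e-6):
    """One in-place StableAdamW update for a single parameter tensor (step >= 1)."""
    # Moving averages of the gradient and squared gradient.
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

    # Bias-corrected estimates.
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)

    # Per-tensor RMS of g^2 / max(v, eps^2), then clip the learning rate.
    rms = (g.pow(2) / v.clamp(min=eps ** 2)).mean().sqrt()
    eta = lr / max(1.0, rms.item())

    # Decoupled weight decay and the Adam update, both scaled by the clipped lr.
    p.mul_(1 - eta * weight_decay)
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-eta)
```

In this form the clipped learning rate \(\bm{\eta}_t\) is a per-tensor scalar, so tensors whose recent gradients are large relative to \(\bm{v}_t\) take proportionally smaller steps.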