StableAdamW: AdamW with Update Clipping¶
StableAdamW is an AdamW-Adafactor hybrid, porting Adafactor’s update clipping into AdamW as a per-parameter learning rate modification. StableAdamW’s update clipping outperforms gradient clipping on downstream tasks while avoiding model training instability.
StableAdamW was introduced by Wortsman et al. in Stable and low-precision training for large-scale vision-language models.
Hyperparameters¶
StableAdamW is a drop-in replacement for AdamW and uses the same hyperparameters, with one exception: StableAdamW removes the need for gradient clipping.
If training on large batch sizes or still observing training loss spikes, consider reducing \(\beta_2\) to a value in \([0.95, 0.99)\).
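For example, a minimal sketch of swapping AdamW for StableAdamW on a toy model, with \(\beta_2\) reduced to 0.95 and no gradient clipping step (the model and data here are placeholders, not a recommended setup):

```python
import torch
from torch import nn
from optimi import StableAdamW

model = nn.Linear(512, 512)

# drop-in replacement for torch.optim.AdamW: same hyperparameters,
# but no gradient clipping step is needed
opt = StableAdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=0.01)

x = torch.randn(64, 512)
loss = model(x).pow(2).mean()
loss.backward()
opt.step()       # update clipping is applied per parameter inside step()
opt.zero_grad()
```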
optimi’s implementation of StableAdamW also supports fully decoupled weight decay: `decouple_lr=True`. The default weight decay of 0.01 will likely need to be reduced when using fully decoupled weight decay, as the learning rate will no longer modify the effective weight decay.
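A short sketch of enabling fully decoupled weight decay with a reduced weight decay value (the 1e-5 value is illustrative, not a recommendation):

```python
from torch import nn
from optimi import StableAdamW

model = nn.Linear(512, 512)

# fully decoupled weight decay: weight decay is no longer scaled by the
# learning rate, so a smaller value than the 0.01 default is usually needed
opt = StableAdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,  # illustrative value; tune for your model
    decouple_lr=True,
)
```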
StableAdamW ¶
StableAdamW optimizer. An AdamW-Adafactor hybrid with learning rate update clipping.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `Iterable[Tensor] \| Iterable[dict]` | Iterable of parameters to optimize or dicts defining parameter groups | *required* |
| `lr` | `float` | Learning rate | *required* |
| `betas` | `tuple[float, float]` | Coefficients for gradient and squared gradient moving averages (default: (0.9, 0.99)) | `(0.9, 0.99)` |
| `weight_decay` | `float` | Weight decay coefficient. If `decouple_lr` is False, applies decoupled weight decay (default: 0.01) | `0.01` |
| `eps` | `float` | Added to denominator to improve numerical stability (default: 1e-6) | `1e-06` |
| `decouple_lr` | `bool` | Apply fully decoupled weight decay instead of decoupled weight decay (default: False) | `False` |
| `max_lr` | `float \| None` | Maximum scheduled learning rate. Set if `lr` is not the maximum scheduled learning rate when using fully decoupled weight decay (default: None) | `None` |
| `kahan_sum` | `bool \| None` | Enables Kahan summation for more accurate parameter updates when training in low precision (float16 or bfloat16). If unspecified, automatically applies for low precision parameters (default: None) | `None` |
| `foreach` | `bool \| None` | Enables the foreach implementation. If unspecified, tries to use foreach over for-loop implementation since it is significantly faster (default: None) | `None` |
| `gradient_release` | `bool` | Fuses optimizer step and zero_grad as part of the parameter's backward pass. Requires model hooks created with optimi's gradient release utilities (default: False) | `False` |
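Because `params` accepts dicts defining parameter groups, per-group settings follow standard PyTorch optimizer behavior. A hypothetical sketch disabling weight decay for one-dimensional parameters such as biases and norm weights:

```python
from torch import nn
from optimi import StableAdamW

model = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512))

# standard PyTorch parameter groups: apply weight decay only to 2D+ weights
decay, no_decay = [], []
for name, p in model.named_parameters():
    (decay if p.ndim >= 2 else no_decay).append(p)

opt = StableAdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```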
Algorithm¶
StableAdam with decoupled weight decay (StableAdamW).
Following Stable and low-precision training for large-scale vision-language models, the \(\text{RMS}_t\) steps occur independently for each tensor. Likewise, \(\text{RMS}_t\) is computed as \(\sqrt{\mathbb{E}[\bm{g}^2_t/\max(\bm{v}_t, \epsilon^2)]}\) instead of \(\sqrt{\mathbb{E}[\bm{g}^2_t/\bm{v}_t]}\), with the \(\max(\bm{v}_t, \epsilon^2)\) term preventing division by zero.
optimi’s StableAdamW also supports fully decoupled weight decay, which is not shown.
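To make the per-tensor update clipping concrete, here is a minimal single-tensor sketch of the update step described above. It omits Kahan summation, foreach, gradient release, and fully decoupled weight decay, and its names and structure are illustrative rather than optimi's actual implementation:

```python
import torch

def stableadamw_step(p, g, m, v, step, lr=1e-3, betas=(0.9, 0.99),
                     weight_decay=0.01, eps=1e-6):
    """One StableAdamW update for a single tensor (simplified sketch)."""
    beta1, beta2 = betas

    # AdamW exponential moving averages of the gradient and squared gradient
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

    # bias-corrected moments
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)

    # per-tensor update clipping: RMS_t = sqrt(E[g_t^2 / max(v_t, eps^2)]),
    # then the learning rate is divided by max(1, RMS_t)
    rms = (g.pow(2) / torch.clamp(v_hat, min=eps**2)).mean().sqrt()
    lr_t = lr / max(1.0, rms.item())

    # decoupled weight decay, scaled by the clipped learning rate
    p.mul_(1 - lr_t * weight_decay)

    # AdamW parameter update with the clipped per-tensor learning rate
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr_t)


# usage on a toy tensor
p, g = torch.randn(10), torch.randn(10)
m, v = torch.zeros_like(p), torch.zeros_like(p)
stableadamw_step(p, g, m, v, step=1)
```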