
Utilities

param_groups_weight_decay

param_groups_weight_decay(
    model, weight_decay=0.01, additional_layers=None
)

Creates parameter groups excluding bias and normalization layers from weight decay.

Parameters:

model (Module): PyTorch model to create parameter groups for. Required.

weight_decay (float): Weight decay coefficient applied to eligible parameters. Default: 0.01.

additional_layers (Iterable[str] | None): Iterable of layer name substrings to exclude from weight decay. Any parameter whose name contains one of these substrings will be excluded from weight decay. Default: None.

Returns:

list[dict[str, Any]]: List of two parameter group dictionaries, one with and one without weight decay.

param_groups_weight_decay is adapted from timm's optimizer factory methods.

Examples

param_groups_weight_decay takes a model and returns two optimizer parameter group dictionaries: one with the bias and normalization parameters and no weight decay, and one with the rest of the model's parameters and weight decay applied. The weight_decay passed to param_groups_weight_decay overrides the optimizer's default weight decay.

from optimi import StableAdamW, param_groups_weight_decay

params = param_groups_weight_decay(model, weight_decay=1e-5)
optimizer = StableAdamW(params, decouple_lr=True)
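
The returned list uses the standard PyTorch optimizer parameter group format. As a quick sanity check (a minimal sketch, assuming each group dictionary exposes the usual params and weight_decay keys), you can inspect both groups before handing them to the optimizer:

# each group reports its weight_decay (0.0 for the bias and normalization
# parameters) and the list of parameters it contains
for group in params:
    print(group["weight_decay"], len(group["params"]))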

The additional_layers parameter allows you to specify additional layer names or name substrings that should be excluded from weight decay. This is useful for excluding specific layers, such as token embeddings, which also benefit from not having weight decay applied.

The parameter accepts an iterable of strings, where each string is matched as a substring against the full parameter name (as returned by model.named_parameters()).

import torch.nn as nn

class MiniLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(1000, 20)
        self.pos_embeddings = nn.Embedding(100, 20)
        self.norm = nn.LayerNorm(20)
        self.layer1 = nn.Linear(20, 30)
        self.layer2 = nn.Linear(30, 1000)

model = MiniLM()

# Exclude token embeddings from weight decay in addition to bias and normalization layers
params = param_groups_weight_decay(
    model,
    weight_decay=1e-5,
    additional_layers=["tok_embeddings"]
)
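
To confirm the exclusion worked as intended, you can check which group a parameter landed in. This is a minimal sketch, assuming the groups follow the standard params/weight_decay layout shown above:

# find the group without weight decay and verify the token embedding
# weight was routed into it by the "tok_embeddings" substring match
no_decay = next(g for g in params if g["weight_decay"] == 0.0)
assert any(p is model.tok_embeddings.weight for p in no_decay["params"])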

prepare_for_gradient_release

prepare_for_gradient_release(
    model, optimizer, ignore_existing_hooks=False
)

Registers post_accumulate_grad_hooks on parameters for the gradient release optimization step.

Parameters:

model (Module): Model to register post_accumulate_grad_hooks on. Hooks are only registered on parameters with requires_grad=True. Required.

optimizer (OptimiOptimizer): Optimizer providing the fused optimizer step during the backward pass. The optimizer must be initialized with gradient_release=True. Required.

ignore_existing_hooks (bool): If True, ignores existing post_accumulate_grad_hooks on parameters and registers gradient release hooks. Default: False.

For details on using prepare_for_gradient_release, please see the gradient release docs.
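
As a quick orientation before reading the gradient release docs, the intended call order looks roughly like this. This is a minimal sketch: the model, learning rate, and choice of StableAdamW are placeholders, and it assumes prepare_for_gradient_release is importable from the top-level optimi package like the other utilities.

import torch
from torch import nn
from optimi import StableAdamW, prepare_for_gradient_release

model = nn.Linear(20, 1)

# the optimizer must be created with gradient_release=True
optimizer = StableAdamW(model.parameters(), lr=1e-3, gradient_release=True)

# register the post_accumulate_grad_hooks that run the fused optimizer step
prepare_for_gradient_release(model, optimizer)

# backward now applies the optimizer step as each parameter's gradient is ready
loss = model(torch.randn(20)).sum()
loss.backward()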

remove_gradient_release

remove_gradient_release(model)

Removes post_accumulate_grad_hooks created by prepare_for_gradient_release.

Parameters:

model (Module): Model to remove gradient release post_accumulate_grad_hooks from. Required.

For details on using remove_gradient_release, please see the gradient release docs.
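
Continuing the sketch above, removing the hooks once gradient release is no longer wanted is a single call (again assuming a top-level optimi import):

from optimi import remove_gradient_release

# removes the hooks registered by prepare_for_gradient_release,
# returning the model to standard optimizer stepping
remove_gradient_release(model)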