Fix comment, add 'stochastic weight decay' idea because why not

Ross Wightman 2025-01-30 15:42:27 -08:00
parent 5940cc167f
commit 5f85f8eefa


@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems a good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
@@ -118,6 +120,7 @@ class Kron(torch.optim.Optimizer):
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ class Kron(torch.optim.Optimizer):
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
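A hypothetical usage sketch of the new flag, assuming the constructor shown above; the import path and hyperparameter values are assumptions, not part of the commit:

import torch
from timm.optim.kron import Kron  # assumed import path

model = torch.nn.Linear(10, 10)
optimizer = Kron(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.05,
    decoupled_decay=True,           # AdamW-style decay, per the docstring above
    stochastic_weight_decay=True,   # the new flag: randomly modulate decay each step
)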
@@ -353,11 +357,15 @@ class Kron(torch.optim.Optimizer):
                 pre_grad = pre_grad.view(p.shape)
             # Apply weight decay
-            if group["weight_decay"] != 0:
+            weight_decay = group["weight_decay"]
+            if weight_decay != 0:
+                if group["stochastic_weight_decay"]:
+                    weight_decay = 2 * self.rng.random() * weight_decay
                 if group["decoupled_decay"]:
-                    p.mul_(1. - group["lr"] * group["weight_decay"])
+                    p.mul_(1. - group["lr"] * weight_decay)
                 else:
-                    pre_grad.add_(p, alpha=group["weight_decay"])
+                    pre_grad.add_(p, alpha=weight_decay)
             # Update parameters
             p.add_(pre_grad, alpha=-group["lr"])
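A back-of-the-envelope check (not part of the commit) of what the modulation does: 2 * rng.random() draws a factor uniformly from [0, 2), so the per-step effective decay is random while its expectation stays at the configured weight_decay. The sketch assumes self.rng behaves like Python's random.Random:

import random

rng = random.Random(0)  # stand-in for the optimizer's self.rng
weight_decay = 0.05

# Each step samples an effective decay uniformly in [0, 2 * weight_decay).
samples = [2 * rng.random() * weight_decay for _ in range(100_000)]

mean = sum(samples) / len(samples)
print(f"mean effective decay: {mean:.4f}")  # ~0.0500, matching the nominal value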