Mirror of https://github.com/huggingface/pytorch-image-models.git
Fix comment, add 'stochastic weight decay' idea because why not
parent 5940cc167f
commit 5f85f8eefa
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

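The corrected docstring splits the old flatten_start_end tuple into separate start and end dims. As a rough illustration of what the default (2, -1) range means for a higher-rank parameter (not the optimizer's internal code), flattening a rank-4 conv weight over dims 2..-1 collapses the spatial dims into one:

import torch

w = torch.randn(64, 32, 3, 3)                         # conv weight: (out, in, kh, kw)
w_flat = torch.flatten(w, start_dim=2, end_dim=-1)    # dims 2..-1 merged
print(w_flat.shape)                                   # torch.Size([64, 32, 9])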
@@ -118,6 +120,7 @@ class Kron(torch.optim.Optimizer):
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ class Kron(torch.optim.Optimizer):
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)

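With the new argument plumbed through __init__ and the param-group defaults, enabling the behaviour is a single keyword. A minimal usage sketch, assuming the class is importable as timm.optim.kron.Kron (import path not shown in this diff) and that lr / weight_decay / decoupled_decay keep their existing meanings:

import torch
from timm.optim.kron import Kron   # import path assumed, not part of this diff

model = torch.nn.Linear(16, 4)
optimizer = Kron(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    decoupled_decay=True,            # AdamW-style decoupled decay
    stochastic_weight_decay=True,    # randomly modulate the decay each step
)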
@@ -353,11 +357,15 @@ class Kron(torch.optim.Optimizer):
             pre_grad = pre_grad.view(p.shape)

             # Apply weight decay
-            if group["weight_decay"] != 0:
+            weight_decay = group["weight_decay"]
+            if weight_decay != 0:
+                if group["stochastic_weight_decay"]:
+                    weight_decay = 2 * self.rng.random() * weight_decay
+
                 if group["decoupled_decay"]:
-                    p.mul_(1. - group["lr"] * group["weight_decay"])
+                    p.mul_(1. - group["lr"] * weight_decay)
                 else:
-                    pre_grad.add_(p, alpha=group["weight_decay"])
+                    pre_grad.add_(p, alpha=weight_decay)

             # Update parameters
             p.add_(pre_grad, alpha=-group["lr"])
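The modulation itself is tiny: each step the nominal weight_decay is scaled by 2 * rng.random(), i.e. a factor uniform in [0, 2), so the applied decay varies per step while its expectation stays at the configured value. A standalone sketch of that scaling, using Python's random module in place of the optimizer's self.rng:

import random

def modulated_decay(weight_decay: float, rng: random.Random) -> float:
    # Uniform factor in [0, 2): per-step decay is randomized, mean stays at weight_decay.
    return 2 * rng.random() * weight_decay

rng = random.Random(0)
samples = [modulated_decay(1e-2, rng) for _ in range(10_000)]
print(sum(samples) / len(samples))   # ~= 1e-2 on average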