Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix comment, add 'stochastic weight decay' idea because why not
parent 5940cc167f
commit 5f85f8eefa
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """

@@ -118,6 +120,7 @@ class Kron(torch.optim.Optimizer):
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ class Kron(torch.optim.Optimizer):
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)

@@ -353,11 +357,15 @@ class Kron(torch.optim.Optimizer):
                 pre_grad = pre_grad.view(p.shape)

                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)

                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
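
For reference, a minimal standalone sketch of the stochastic weight decay idea (illustrative function and variable names, not part of timm's Kron optimizer): the nominal decay is scaled by a uniform draw in [0, 2), so the expected decay equals weight_decay while the per-step strength varies. Only the decoupled (AdamW-style) branch is shown.

import random

import torch

rng = random.Random(0)

def apply_decoupled_decay(p, lr, weight_decay, stochastic_weight_decay=False):
    # Hypothetical helper illustrating the idea above, not the Kron.step() implementation.
    wd = weight_decay
    if stochastic_weight_decay:
        # E[2 * U(0, 1)] = 1, so the decay averages to the nominal weight_decay.
        wd = 2 * rng.random() * wd
    if wd != 0:
        # Decoupled decay: shrink the parameter directly rather than adding to the gradient.
        p.mul_(1. - lr * wd)

p = torch.ones(4)
apply_decoupled_decay(p, lr=0.1, weight_decay=0.01, stochastic_weight_decay=True)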