diff --git a/timm/optim/kron.py b/timm/optim/kron.py
index 25c1b047..533354ec 100644
--- a/timm/optim/kron.py
+++ b/timm/optim/kron.py
@@ -95,7 +95,9 @@ class Kron(torch.optim.Optimizer):
         precond_dtype: Dtype of the preconditioner.
         decoupled_decay: AdamW style decoupled weight decay
         flatten: Flatten dimensions instead of fully relying on expressions for higher rank params
-        flatten_start_end: Range of dimensions to flatten, defaults to (2, -1).
+        flatten_start_dim: Start of flatten range, defaults to 2. Seems good tradeoff for ConvNets.
+        flatten_end_dim: End of flatten range, defaults to -1.
+        stochastic_weight_decay: Enable random modulation of weight decay
         deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
     """
 
@@ -118,6 +120,7 @@ class Kron(torch.optim.Optimizer):
             flatten: bool = False,
             flatten_start_dim: int = 2,
             flatten_end_dim: int = -1,
+            stochastic_weight_decay: bool = False,
             deterministic: bool = False,
     ):
         if not has_opt_einsum:
@@ -147,6 +150,7 @@ class Kron(torch.optim.Optimizer):
             flatten=flatten,
             flatten_start_dim=flatten_start_dim,
             flatten_end_dim=flatten_end_dim,
+            stochastic_weight_decay=stochastic_weight_decay,
         )
         super(Kron, self).__init__(params, defaults)
 
@@ -353,11 +357,15 @@ class Kron(torch.optim.Optimizer):
                     pre_grad = pre_grad.view(p.shape)
 
                 # Apply weight decay
-                if group["weight_decay"] != 0:
+                weight_decay = group["weight_decay"]
+                if weight_decay != 0:
+                    if group["stochastic_weight_decay"]:
+                        weight_decay = 2 * self.rng.random() * weight_decay
+
                     if group["decoupled_decay"]:
-                        p.mul_(1. - group["lr"] * group["weight_decay"])
+                        p.mul_(1. - group["lr"] * weight_decay)
                     else:
-                        pre_grad.add_(p, alpha=group["weight_decay"])
+                        pre_grad.add_(p, alpha=weight_decay)
 
                 # Update parameters
                 p.add_(pre_grad, alpha=-group["lr"])
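
Usage note (not part of the patch): a minimal sketch of how the new stochastic_weight_decay flag might be enabled on the Kron optimizer. The model and hyperparameter values below are illustrative assumptions.

    # Sketch only: toy model and values are placeholders, not a recommendation.
    # Kron (PSGD) may also require the opt_einsum package to be installed.
    import torch
    from timm.optim.kron import Kron

    model = torch.nn.Linear(10, 2)
    optimizer = Kron(
        model.parameters(),
        lr=1e-3,
        weight_decay=1e-2,
        decoupled_decay=True,           # AdamW-style decoupled decay
        stochastic_weight_decay=True,   # decay scaled by 2 * U(0, 1) each step
    )

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

With decoupled_decay=True the randomly modulated decay multiplies the parameters directly; otherwise it is added to the preconditioned gradient, as in the hunk at line 357 above.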