Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit afdf11d9ae (parent 553ded5c6b)

Add caution to Adan. Add decoupled decay option to LAMB.
timm/optim/_optim_factory.py

@@ -485,6 +485,20 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'trust_clip': True}
         ),
+        OptimInfo(
+            name='lambw',
+            opt_class=Lamb,
+            description='LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_decay': True}
+        ),
+        OptimInfo(
+            name='lambcw',
+            opt_class=Lamb,
+            description='LAMB with trust ratio clipping for stability and decoupled decay',
+            has_betas=True,
+            defaults={'trust_clip': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='lars',
             opt_class=Lars,
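Note: a registered name such as 'lambw' is normally consumed through timm's optimizer factory rather than by instantiating Lamb directly. A minimal usage sketch (placeholder model and hyper-parameters, assuming the usual create_optimizer_v2 entry point):

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # placeholder model for illustration

# 'lambw' resolves to Lamb with decoupled_decay=True via the registration above
optimizer = create_optimizer_v2(model, opt='lambw', lr=5e-3, weight_decay=0.02)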
@@ -544,6 +558,22 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             description='Cautious Adopt',
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='cadan',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum Algorithm',
+            defaults={'caution': True, 'no_prox': False},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='cadanw',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
+            defaults={'caution': True, 'no_prox': True},
+            has_betas=True,
+            num_betas=3
+        ),
         OptimInfo(
             name='cadoptw',
             opt_class=Adopt,
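Note: the 'w' suffix mirrors Adan's no_prox flag. With no_prox=True the decay is decoupled (AdamW-style, parameter shrunk before the update); with no_prox=False it is applied as a proximal step after the update. A rough single-parameter sketch of the two modes, ignoring Adan's gradient-difference term for brevity (illustrative values, not an excerpt from adan.py):

import torch

p = torch.ones(4)
exp_avg = torch.full((4,), 0.1)
denom = torch.ones(4)
lr, weight_decay, step_size = 1e-3, 0.02, 1e-3

# no_prox=True ('cadanw'): decoupled, shrink the parameter, then step
p_decoupled = (p * (1 - lr * weight_decay)).addcdiv(exp_avg, denom, value=-step_size)

# no_prox=False ('cadan'): proximal, step first, then divide
p_proximal = p.addcdiv(exp_avg, denom, value=-step_size) / (1 + lr * weight_decay)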
@@ -557,6 +587,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='clambw',
+            opt_class=Lamb,
+            description='Cautious LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'caution': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='claprop',
             opt_class=LaProp,
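Note: recent timm releases expose the registry through a listing helper; the exact call below (list_optimizers) is an assumption about that API, used here only to show how the new names could be verified:

from timm.optim import list_optimizers

# assumed helper: returns the names known to the OptimizerRegistry
names = set(list_optimizers())
assert {'lambw', 'lambcw', 'cadan', 'cadanw', 'clambw'} <= names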
timm/optim/adan.py

@@ -20,7 +20,7 @@ Implementation adapted from https://github.com/sail-sg/Adan
 # limitations under the License.
 
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Decoupled weight decay (L2 penalty)
         no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
         foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
     """
 
@@ -66,7 +67,8 @@ class Adan(Optimizer):
             eps: float = 1e-8,
             weight_decay: float = 0.0,
             no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
     ):
         if not 0.0 <= lr:
             raise ValueError('Invalid learning rate: {}'.format(lr))
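Note: with the new arguments the optimizer can also be constructed directly; a small sketch (placeholder model, assuming Adan is exported from timm.optim):

import torch.nn as nn
from timm.optim import Adan

model = nn.Linear(16, 4)  # placeholder model for illustration

# caution=True enables the cautious masking below; foreach is left at its new
# default (None), so step() picks between the foreach and single-tensor paths.
opt = Adan(model.parameters(), lr=1e-3, weight_decay=0.02, no_prox=True, caution=True)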
@@ -85,6 +87,7 @@ class Adan(Optimizer):
             eps=eps,
             weight_decay=weight_decay,
             no_prox=no_prox,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -93,6 +96,7 @@ class Adan(Optimizer):
         super(Adan, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('no_prox', False)
+            group.setdefault('caution', False)
 
     @torch.no_grad()
     def restart_opt(self):
@@ -118,6 +122,11 @@ class Adan(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
         for group in self.param_groups:
             params_with_grad = []
             grads = []
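Note: the probe above guards the cautious foreach path. The same check can be run standalone to see whether the installed PyTorch build exposes a scalar overload of _foreach_maximum_ (a sketch mirroring the added lines):

import torch

try:
    # the cautious multi-tensor path clamps a list of mask means against a
    # Python scalar, which requires the 'Scalar' overload of _foreach_maximum_
    has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except Exception:
    has_scalar_maximum = False

print('scalar _foreach_maximum_ available:', has_scalar_maximum)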
@@ -161,9 +170,19 @@ class Adan(Optimizer):
             if not params_with_grad:
                 continue
 
-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                 exp_avgs=exp_avgs,
                 exp_avg_sqs=exp_avg_sqs,
                 exp_avg_diffs=exp_avg_diffs,
@@ -178,13 +197,9 @@ class Adan(Optimizer):
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
                 no_prox=group['no_prox'],
+                caution=group['caution'],
             )
 
-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
         return loss
 
 
@@ -206,6 +221,7 @@ def _single_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
         step_size_diff = lr * beta2 / bias_correction2
         step_size = lr / bias_correction1
 
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         if no_prox:
             param.mul_(1 - lr * weight_decay)
             param.addcdiv_(exp_avg, denom, value=-step_size)
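Note: the masking added above can be exercised in isolation. A minimal numeric sketch (the helper name apply_caution is made up for illustration, it is not part of timm):

import torch

def apply_caution(exp_avg: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # zero momentum entries whose sign disagrees with the current gradient,
    # then rescale the survivors so the mean update magnitude is preserved
    mask = (exp_avg * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return exp_avg * mask

exp_avg = torch.tensor([0.5, -0.2, 0.1, -0.4])
grad = torch.tensor([0.3, 0.1, -0.2, -0.5])
# the two sign-mismatched entries are zeroed, the remaining two are scaled by 2
print(apply_caution(exp_avg, grad))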
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     if len(params) == 0:
         return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
     step_size_diff = lr * beta2 / bias_correction2
     step_size = lr / bias_correction1
 
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
     if no_prox:
         torch._foreach_mul_(params, 1 - lr * weight_decay)
         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
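Note: the foreach formulation above is intended to match the per-tensor one. A quick equivalence sketch, assuming a PyTorch build whose _foreach_maximum_ accepts a scalar bound:

import torch

exp_avgs = [torch.randn(4) for _ in range(3)]
grads = [torch.randn(4) for _ in range(3)]

# foreach path, as in _multi_tensor_adan
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
foreach_out = torch._foreach_mul(exp_avgs, masks)

# per-tensor path, as in _single_tensor_adan
single_out = []
for e, g in zip(exp_avgs, grads):
    m = (e * g > 0).to(g.dtype)
    m.div_(m.mean().clamp_(min=1e-3))
    single_out.append(e * m)

for a, b in zip(foreach_out, single_out):
    torch.testing.assert_close(a, b)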
timm/optim/lamb.py

@@ -94,6 +94,7 @@ class Lamb(Optimizer):
             trust_clip: bool = False,
             always_adapt: bool = False,
             caution: bool = False,
+            decoupled_decay: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -106,6 +107,7 @@ class Lamb(Optimizer):
             trust_clip=trust_clip,
             always_adapt=always_adapt,
             caution=caution,
+            decoupled_decay=decoupled_decay,
         )
         super().__init__(params, defaults)
 
@@ -113,6 +115,7 @@ class Lamb(Optimizer):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('decoupled_decay', False)
 
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
@@ -199,7 +202,10 @@ class Lamb(Optimizer):
 
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
-                    update.add_(p, alpha=weight_decay)
+                    if group.get('decoupled_decay', False):
+                        p.add_(p, alpha=-group['lr'] * weight_decay)
+                    else:
+                        update.add_(p, alpha=weight_decay)
 
                 if weight_decay != 0 or group['always_adapt']:
                     # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
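Note: the behavioural difference is that the coupled path folds the decay into update, so it is later rescaled by the layer-wise trust ratio, while decoupled_decay=True shrinks the parameter by lr * weight_decay directly and leaves the trust ratio to the gradient-based update alone. A toy sketch of the two decay terms (illustrative values, not lamb.py code):

import torch

p = torch.ones(3)
update = torch.full((3,), 0.1)  # stand-in for the bias-corrected Adam-style update
lr, weight_decay = 1e-2, 5e-2

# coupled (default): decay joins the update and passes through the trust ratio
coupled_update = update + weight_decay * p

# decoupled (decoupled_decay=True): equivalent to p.add_(p, alpha=-lr * weight_decay)
p_after_decay = p - lr * weight_decay * p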
timm/optim/mars.py

@@ -54,10 +54,13 @@ def _mars_single_tensor_step(
         if c_t_norm > 1.:
             c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+
         if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
             mask = (exp_avg * grad > 0).to(grad.dtype)
             mask.div_(mask.mean().clamp_(min=1e-3))
             exp_avg = exp_avg * mask
+
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step