diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 1384224b..0ec0eeca 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -485,6 +485,20 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'trust_clip': True}
         ),
+        OptimInfo(
+            name='lambw',
+            opt_class=Lamb,
+            description='LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_decay': True}
+        ),
+        OptimInfo(
+            name='lambcw',
+            opt_class=Lamb,
+            description='LAMB with trust ratio clipping for stability and decoupled decay',
+            has_betas=True,
+            defaults={'trust_clip': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='lars',
             opt_class=Lars,
@@ -544,6 +558,22 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             description='Cautious Adopt',
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='cadan',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum Algorithm',
+            defaults={'caution': True, 'no_prox': False},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='cadanw',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
+            defaults={'caution': True, 'no_prox': True},
+            has_betas=True,
+            num_betas=3
+        ),
         OptimInfo(
             name='cadoptw',
             opt_class=Adopt,
@@ -557,6 +587,13 @@
             has_betas=True,
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='clambw',
+            opt_class=Lamb,
+            description='Cautious LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'caution': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='claprop',
             opt_class=LaProp,
diff --git a/timm/optim/adan.py b/timm/optim/adan.py
index 94fa9ef2..4db62e9c 100644
--- a/timm/optim/adan.py
+++ b/timm/optim/adan.py
@@ -20,7 +20,7 @@ Implementation adapted from https://github.com/sail-sg/Adan
 # limitations under the License.
 
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Decoupled weight decay (L2 penalty)
         no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
         foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
     """
 
@@ -66,7 +67,8 @@ class Adan(Optimizer):
             eps: float = 1e-8,
             weight_decay: float = 0.0,
             no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
     ):
         if not 0.0 <= lr:
             raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ class Adan(Optimizer):
             eps=eps,
             weight_decay=weight_decay,
             no_prox=no_prox,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -93,6 +96,7 @@ class Adan(Optimizer):
         super(Adan, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('no_prox', False)
+            group.setdefault('caution', False)
 
     @torch.no_grad()
     def restart_opt(self):
@@ -118,6 +122,11 @@
             with torch.enable_grad():
                 loss = closure()
 
+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
         for group in self.param_groups:
             params_with_grad = []
             grads = []
@@ -161,9 +170,19 @@
             if not params_with_grad:
                 continue
 
-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                 exp_avgs=exp_avgs,
                 exp_avg_sqs=exp_avg_sqs,
                 exp_avg_diffs=exp_avg_diffs,
@@ -178,13 +197,9 @@
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
                 no_prox=group['no_prox'],
+                caution=group['caution'],
             )
 
-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
         return loss
 
 
@@ -206,6 +221,7 @@ def _single_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -227,6 +243,12 @@
         step_size_diff = lr * beta2 / bias_correction2
         step_size = lr / bias_correction1
 
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         if no_prox:
             param.mul_(1 - lr * weight_decay)
             param.addcdiv_(exp_avg, denom, value=-step_size)
@@ -257,6 +279,7 @@
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     if len(params) == 0:
         return
@@ -282,6 +305,15 @@
     step_size_diff = lr * beta2 / bias_correction2
     step_size = lr / bias_correction1
 
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
     if no_prox:
         torch._foreach_mul_(params, 1 - lr * weight_decay)
         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index ee89225e..fa867574 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -94,6 +94,7 @@ class Lamb(Optimizer):
             trust_clip: bool = False,
             always_adapt: bool = False,
             caution: bool = False,
+            decoupled_decay: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -106,6 +107,7 @@ class Lamb(Optimizer):
             trust_clip=trust_clip,
             always_adapt=always_adapt,
             caution=caution,
+            decoupled_decay=decoupled_decay,
         )
         super().__init__(params, defaults)
 
@@ -113,6 +115,7 @@ class Lamb(Optimizer):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('decoupled_decay', False)
 
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
@@ -199,7 +202,10 @@ class Lamb(Optimizer):
 
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
-                    update.add_(p, alpha=weight_decay)
+                    if group.get('decoupled_decay', False):
+                        p.add_(p, alpha=-group['lr'] * weight_decay)
+                    else:
+                        update.add_(p, alpha=weight_decay)
 
                 if weight_decay != 0 or group['always_adapt']:
                     # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
diff --git a/timm/optim/mars.py b/timm/optim/mars.py
index 11b1cf20..1068ee91 100644
--- a/timm/optim/mars.py
+++ b/timm/optim/mars.py
@@ -54,10 +54,13 @@ def _mars_single_tensor_step(
             if c_t_norm > 1.:
                 c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+
         if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
             mask = (exp_avg * grad > 0).to(grad.dtype)
             mask.div_(mask.mean().clamp_(min=1e-3))
             exp_avg = exp_avg * mask
+
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step
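
A minimal usage sketch (not part of the patch), exercising the optimizer names registered above through timm's existing create_optimizer_v2 factory; the model choice and hyperparameter values are illustrative assumptions only:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet18')

    # 'cadanw': cautious Adan with decoupled weight decay (no_prox=True), per the
    # OptimInfo entries added in _optim_factory.py above.
    optimizer = create_optimizer_v2(model, opt='cadanw', lr=1e-3, weight_decay=0.02)

    # 'lambcw': LAMB with trust ratio clipping and decoupled weight decay.
    optimizer = create_optimizer_v2(model, opt='lambcw', lr=2e-3, weight_decay=0.01)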