Add caution to Adan. Add decoupled decay option to LAMB.

Ross Wightman 2024-12-05 13:50:30 -08:00
parent 553ded5c6b
commit afdf11d9ae
4 changed files with 89 additions and 11 deletions


@@ -485,6 +485,20 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'trust_clip': True}
         ),
+        OptimInfo(
+            name='lambw',
+            opt_class=Lamb,
+            description='LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_decay': True}
+        ),
+        OptimInfo(
+            name='lambcw',
+            opt_class=Lamb,
+            description='LAMB with trust ratio clipping for stability and decoupled decay',
+            has_betas=True,
+            defaults={'trust_clip': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='lars',
             opt_class=Lars,
@@ -544,6 +558,22 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             description='Cautious Adopt',
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='cadan',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum Algorithm',
+            defaults={'caution': True, 'no_prox': False},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='cadanw',
+            opt_class=Adan,
+            description='Cautious Adaptive Nesterov Momentum with decoupled weight decay',
+            defaults={'caution': True, 'no_prox': True},
+            has_betas=True,
+            num_betas=3
+        ),
         OptimInfo(
             name='cadoptw',
             opt_class=Adopt,
@@ -557,6 +587,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='clambw',
+            opt_class=Lamb,
+            description='Cautious LAMB with decoupled weight decay',
+            has_betas=True,
+            defaults={'caution': True, 'decoupled_decay': True}
+        ),
         OptimInfo(
             name='claprop',
             opt_class=LaProp,
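For reference, a minimal sketch of how the newly registered names could be exercised through timm's optimizer factory. It is not part of the commit and assumes the create_optimizer_v2 / list_optimizers helpers exported by timm.optim in a version containing this registry.

# Sketch only (not part of this commit): exercising the new registry names.
import torch.nn as nn
from timm.optim import create_optimizer_v2, list_optimizers

model = nn.Linear(10, 10)

# 'lambw' -> Lamb with decoupled_decay=True; 'lambcw' additionally sets trust_clip=True
opt = create_optimizer_v2(model, opt='lambw', lr=1e-3, weight_decay=0.05)

# 'cadan' / 'cadanw' -> Adan with caution=True (no_prox=False / True respectively)
opt = create_optimizer_v2(model, opt='cadanw', lr=1e-3, weight_decay=0.02)

# the new names should appear alongside the existing registrations
print([n for n in list_optimizers() if n in ('lambw', 'lambcw', 'cadan', 'cadanw', 'clambw')])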


@@ -20,7 +20,7 @@ Implementation adapted from https://github.com/sail-sg/Adan
 # limitations under the License.
 
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -56,6 +56,7 @@ class Adan(Optimizer):
         eps: Term added to the denominator to improve numerical stability.
         weight_decay: Decoupled weight decay (L2 penalty)
         no_prox: How to perform the weight decay
+        caution: Enable caution from 'Cautious Optimizers'
         foreach: If True would use torch._foreach implementation. Faster but uses slightly more memory.
     """
@@ -66,7 +67,8 @@ class Adan(Optimizer):
             eps: float = 1e-8,
             weight_decay: float = 0.0,
             no_prox: bool = False,
-            foreach: bool = True,
+            caution: bool = False,
+            foreach: Optional[bool] = None,
     ):
         if not 0.0 <= lr:
             raise ValueError('Invalid learning rate: {}'.format(lr))
@@ -85,6 +87,7 @@ class Adan(Optimizer):
             eps=eps,
             weight_decay=weight_decay,
             no_prox=no_prox,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -93,6 +96,7 @@ class Adan(Optimizer):
         super(Adan, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('no_prox', False)
+            group.setdefault('caution', False)
 
     @torch.no_grad()
     def restart_opt(self):
@@ -118,6 +122,11 @@ class Adan(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
+        try:
+            has_scalar_maximum = 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
+        except:
+            has_scalar_maximum = False
+
         for group in self.param_groups:
             params_with_grad = []
             grads = []
@@ -161,9 +170,19 @@ class Adan(Optimizer):
             if not params_with_grad:
                 continue
 
-            kwargs = dict(
-                params=params_with_grad,
-                grads=grads,
+            if group['foreach'] is None:
+                use_foreach = not group['caution'] or has_scalar_maximum
+            else:
+                use_foreach = group['foreach']
+
+            if use_foreach:
+                func = _multi_tensor_adan
+            else:
+                func = _single_tensor_adan
+
+            func(
+                params_with_grad,
+                grads,
                 exp_avgs=exp_avgs,
                 exp_avg_sqs=exp_avg_sqs,
                 exp_avg_diffs=exp_avg_diffs,
@@ -178,13 +197,9 @@ class Adan(Optimizer):
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
                 no_prox=group['no_prox'],
+                caution=group['caution'],
             )
 
-            if group['foreach']:
-                _multi_tensor_adan(**kwargs)
-            else:
-                _single_tensor_adan(**kwargs)
-
         return loss
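A note on the foreach selection above: with caution enabled, the multi-tensor path clamps the per-tensor mask scales via torch._foreach_maximum_(mask_scale, 1e-3), which needs the Scalar overload of that op; older PyTorch builds may lack it, so with foreach=None the step falls back to the single-tensor path in that case. A small standalone sketch of the capability probe (run against your local PyTorch to see which overloads it exposes; the printed contents are an assumption about typical builds):

# Sketch: probing PyTorch for the Scalar overload used by the cautious foreach path.
import torch

try:
    overloads = torch.ops.aten._foreach_maximum_.overloads()
except Exception:
    overloads = []

# On recent releases this typically includes 'Scalar' alongside the list overloads;
# when it is missing, Adan with caution=True and foreach=None uses _single_tensor_adan.
print(overloads, 'Scalar' in overloads)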
@@ -206,6 +221,7 @@ def _single_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -227,6 +243,12 @@ def _single_tensor_adan(
         step_size_diff = lr * beta2 / bias_correction2
         step_size = lr / bias_correction1
 
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         if no_prox:
             param.mul_(1 - lr * weight_decay)
             param.addcdiv_(exp_avg, denom, value=-step_size)
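For intuition, the caution mask zeroes the components of the momentum whose sign disagrees with the current gradient and rescales the survivors by the inverse of the kept fraction, so the overall step size is roughly preserved. A toy standalone sketch of the same masking (illustrative values, not timm code):

# Toy illustration of the caution mask ('Cautious Optimizers', arXiv:2411.16085).
import torch

exp_avg = torch.tensor([0.5, -0.2, 0.1, -0.4])   # momentum / first moment
grad = torch.tensor([0.3, 0.1, 0.2, -0.1])       # current gradient

mask = (exp_avg * grad > 0).to(grad.dtype)        # 1 where momentum and grad agree in sign
mask.div_(mask.mean().clamp_(min=1e-3))           # 3 of 4 kept -> scale survivors by 4/3
masked_avg = exp_avg * mask

print(mask)        # -> approx [1.333, 0.0, 1.333, 1.333]
print(masked_avg)  # -> approx [0.667, 0.0, 0.133, -0.533]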
@@ -257,6 +279,7 @@ def _multi_tensor_adan(
         weight_decay: float,
         eps: float,
         no_prox: bool,
+        caution: bool,
 ):
     if len(params) == 0:
         return
@@ -282,6 +305,15 @@ def _multi_tensor_adan(
     step_size_diff = lr * beta2 / bias_correction2
     step_size = lr / bias_correction1
 
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(exp_avgs, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        exp_avgs = torch._foreach_mul(exp_avgs, masks)
+
     if no_prox:
         torch._foreach_mul_(params, 1 - lr * weight_decay)
         torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
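With the changes above, the new behaviour can also be reached by constructing Adan directly instead of going through the registry. A minimal sketch using only arguments visible in this diff, assuming Adan is importable from timm.optim as usual; the model and hyper-parameters are placeholders:

# Sketch: the modified Adan constructor. 'caution' and the Optional 'foreach'
# default (None -> auto-selected per step) are what this commit adds.
import torch
import torch.nn as nn
from timm.optim import Adan

model = nn.Linear(10, 2)
optimizer = Adan(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.02,
    no_prox=True,    # decoupled-style decay, as in the 'adanw'/'cadanw' registrations
    caution=True,    # matches the 'cadanw' defaults added to the registry
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()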


@@ -94,6 +94,7 @@ class Lamb(Optimizer):
             trust_clip: bool = False,
             always_adapt: bool = False,
             caution: bool = False,
+            decoupled_decay: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -106,6 +107,7 @@ class Lamb(Optimizer):
             trust_clip=trust_clip,
             always_adapt=always_adapt,
             caution=caution,
+            decoupled_decay=decoupled_decay,
         )
         super().__init__(params, defaults)
@@ -113,6 +115,7 @@ class Lamb(Optimizer):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('caution', False)
+            group.setdefault('decoupled_decay', False)
 
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
@@ -199,7 +202,10 @@ class Lamb(Optimizer):
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
-                    update.add_(p, alpha=weight_decay)
+                    if group.get('decoupled_decay', False):
+                        p.add_(p, alpha=-group['lr'] * weight_decay)
+                    else:
+                        update.add_(p, alpha=weight_decay)
 
                 if weight_decay != 0 or group['always_adapt']:
                     # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
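The two branches differ in where the decay enters: the default path folds weight_decay * p into the update that is subsequently scaled by the layer-wise trust ratio, while decoupled_decay=True shrinks the parameter directly by lr * weight_decay, outside the trust-ratio scaling (AdamW-style). A toy comparison for a single parameter tensor (illustrative only, not timm code):

# Toy comparison of coupled vs decoupled weight decay for one parameter tensor.
# 'update' stands in for the bias-corrected Adam-style step Lamb has computed.
import torch

lr, weight_decay = 1e-2, 0.1
p = torch.tensor([1.0, -2.0])
update = torch.tensor([0.05, 0.02])

# default (coupled): decay joins the update and is later scaled by the trust ratio
coupled_update = update + weight_decay * p        # update.add_(p, alpha=weight_decay)

# decoupled: the parameter is decayed directly, the update is left untouched
p_decoupled = p + (-lr * weight_decay) * p        # p.add_(p, alpha=-lr * weight_decay)

print(coupled_update)   # -> [0.15, -0.18]
print(p_decoupled)      # -> [0.999, -1.998]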


@@ -54,10 +54,13 @@ def _mars_single_tensor_step(
             if c_t_norm > 1.:
                 c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+
         if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
             mask = (exp_avg * grad > 0).to(grad.dtype)
             mask.div_(mask.mean().clamp_(min=1e-3))
             exp_avg = exp_avg * mask
+
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step