Cautious optimizer impl plus some typing cleanup.

Ross Wightman 2024-11-28 12:34:51 -08:00
parent aeb1ed7a15
commit 3086dd03fd
13 changed files with 521 additions and 234 deletions
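The change common to every optimizer touched by this commit is one cautious masking rule: the candidate update (or momentum buffer) is gated element-wise by sign agreement with the current gradient, then rescaled by the inverse mean of the mask. A minimal standalone sketch of that rule, using illustrative names rather than code lifted from any single file below:

import torch

def cautious(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Keep only coordinates where the update agrees in sign with the gradient,
    # per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
    mask = (update * grad > 0).to(grad.dtype)
    # Rescale so the expected step magnitude is preserved; the clamp guards
    # against a near-empty mask.
    mask.div_(mask.mean().clamp_(min=1e-3))
    return update * mask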

View File

@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
assert isinstance(opt_info, OptimInfo) assert isinstance(opt_info, OptimInfo)
lr = (1e-2,) * 4 lr = (1e-2,) * 4
if optimizer in ('mars',): if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4 lr = (1e-3,) * 4
try: try:

View File

@ -5,15 +5,16 @@ Hacked together by / Copyright 2021 Ross Wightman
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from fnmatch import fnmatch from fnmatch import fnmatch
import importlib import importlib
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from ._types import ParamsT, OptimType, OptimizerCallable
from .adabelief import AdaBelief from .adabelief import AdaBelief
from .adafactor import Adafactor from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision from .adafactor_bv import AdafactorBigVision
@ -39,11 +40,6 @@ from .sgdw import SGDW
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# Type variables
T = TypeVar('T')
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
OptimType = TypeVar('OptimType', bound='optim.Optimizer')
def _import_class(class_string: str) -> Type: def _import_class(class_string: str) -> Type:
"""Dynamically import a class from a string.""" """Dynamically import a class from a string."""
@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
raise ImportError(f"Could not import {class_string}: {e}") raise ImportError(f"Could not import {class_string}: {e}")
class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""
def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
@dataclass(frozen=True) @dataclass(frozen=True)
class OptimInfo: class OptimInfo:
@ -76,7 +67,7 @@ class OptimInfo:
defaults: Optional default parameters for the optimizer defaults: Optional default parameters for the optimizer
""" """
name: str name: str
opt_class: Union[str, Type[optim.Optimizer]] opt_class: Union[str, OptimType]
description: str = '' description: str = ''
has_eps: bool = True has_eps: bool = True
has_momentum: bool = False has_momentum: bool = False
@ -185,7 +176,7 @@ class OptimizerRegistry:
self, self,
name_or_info: Union[str, OptimInfo], name_or_info: Union[str, OptimInfo],
bind_defaults: bool = True, bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]: ) -> Union[OptimType, OptimizerCallable]:
"""Get the optimizer class with any default arguments applied. """Get the optimizer class with any default arguments applied.
This allows direct instantiation of optimizers with their default configs This allows direct instantiation of optimizers with their default configs
@ -234,7 +225,7 @@ class OptimizerRegistry:
def create_optimizer( def create_optimizer(
self, self,
model_or_params: Union[nn.Module, Params], model_or_params: Union[nn.Module, ParamsT],
opt: str, opt: str,
lr: Optional[float] = None, lr: Optional[float] = None,
weight_decay: float = 0., weight_decay: float = 0.,
@ -242,9 +233,9 @@ class OptimizerRegistry:
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True, weight_decay_exclude_1d: bool = True,
layer_decay: Optional[float] = None, layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None, param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any, **kwargs: Any,
) -> optim.Optimizer: ) -> torch.optim.Optimizer:
"""Create an optimizer instance. """Create an optimizer instance.
Args: Args:
@ -347,7 +338,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
sgd_optimizers = [ sgd_optimizers = [
OptimInfo( OptimInfo(
name='sgd', name='sgd',
opt_class=optim.SGD, opt_class=torch.optim.SGD,
description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum', description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
has_eps=False, has_eps=False,
has_momentum=True, has_momentum=True,
@ -355,7 +346,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
), ),
OptimInfo( OptimInfo(
name='momentum', name='momentum',
opt_class=optim.SGD, opt_class=torch.optim.SGD,
description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum', description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
has_eps=False, has_eps=False,
has_momentum=True, has_momentum=True,
@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
adam_optimizers = [ adam_optimizers = [
OptimInfo( OptimInfo(
name='adam', name='adam',
opt_class=optim.Adam, opt_class=torch.optim.Adam,
description='torch.optim.Adam, Adaptive Moment Estimation', description='torch.optim.Adam, Adaptive Moment Estimation',
has_betas=True has_betas=True
), ),
OptimInfo( OptimInfo(
name='adamw', name='adamw',
opt_class=optim.AdamW, opt_class=torch.optim.AdamW,
description='torch.optim.AdamW, Adam with decoupled weight decay', description='torch.optim.AdamW, Adam with decoupled weight decay',
has_betas=True has_betas=True
), ),
@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
), ),
OptimInfo( OptimInfo(
name='adamax', name='adamax',
opt_class=optim.Adamax, opt_class=torch.optim.Adamax,
description='torch.optim.Adamax, Adam with infinity norm for more stable updates', description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
has_betas=True has_betas=True
), ),
@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
registry.register(opt) registry.register(opt)
def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
cautious_optimizers = [
OptimInfo(
name='cadafactor',
opt_class=Adafactor,
description='Cautious Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadafactorbv',
opt_class=AdafactorBigVision,
description='Cautious Big Vision Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadamw',
opt_class=AdamWLegacy,
description='Cautious AdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cadopt',
opt_class=Adopt,
description='Cautious Adopt',
defaults={'caution': True}
),
OptimInfo(
name='cadoptw',
opt_class=Adopt,
description='Cautious AdoptW (decoupled decay)',
defaults={'decoupled': True, 'caution': True}
),
OptimInfo(
name='clamb',
opt_class=Lamb,
description='Cautious LAMB',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='claprop',
opt_class=LaProp,
description='Cautious LaProp',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='clion',
opt_class=Lion,
description='Cautious Lion',
has_eps=False,
has_betas=True,
defaults = {'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,
description='Cautious NAdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='crmsproptf',
opt_class=RMSpropTF,
description='Cautious TensorFlow-style RMSprop',
has_momentum=True,
defaults={'alpha': 0.9, 'caution': True}
),
OptimInfo(
name='csgdw',
opt_class=SGDW,
description='Cautious SGD with decoupled weight decay and Nesterov momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True, 'caution': True}
),
]
for opt in cautious_optimizers:
registry.register(opt)
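Because each cautious entry is an ordinary OptimInfo whose defaults bind caution=True, looking a name up with bind_defaults returns the underlying class with that flag pre-applied (typically via functools.partial). A small sketch, with a placeholder model and assuming get_optimizer_class is re-exported from timm.optim:

import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(8, 8)                 # placeholder model
cadamw = get_optimizer_class('cadamw')  # AdamWLegacy with caution=True bound from the registry defaults
opt = cadamw(model.parameters(), lr=1e-3)
assert opt.defaults['caution'] is True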
def _register_other_optimizers(registry: OptimizerRegistry) -> None: def _register_other_optimizers(registry: OptimizerRegistry) -> None:
"""Register miscellaneous optimizers""" """Register miscellaneous optimizers"""
other_optimizers = [ other_optimizers = [
@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
), ),
OptimInfo( OptimInfo(
name='adadelta', name='adadelta',
opt_class=optim.Adadelta, opt_class=torch.optim.Adadelta,
description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients' description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
), ),
OptimInfo( OptimInfo(
name='adagrad', name='adagrad',
opt_class=optim.Adagrad, opt_class=torch.optim.Adagrad,
description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients', description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
defaults={'eps': 1e-8} defaults={'eps': 1e-8}
), ),
@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
), ),
OptimInfo( OptimInfo(
name='rmsprop', name='rmsprop',
opt_class=optim.RMSprop, opt_class=torch.optim.RMSprop,
description='torch.optim.RMSprop, Root Mean Square Propagation', description='torch.optim.RMSprop, Root Mean Square Propagation',
has_momentum=True, has_momentum=True,
defaults={'alpha': 0.9} defaults={'alpha': 0.9}
@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
_register_other_optimizers(default_registry) _register_other_optimizers(default_registry)
_register_apex_optimizers(default_registry) _register_apex_optimizers(default_registry)
_register_bnb_optimizers(default_registry) _register_bnb_optimizers(default_registry)
_register_cautious_optimizers(default_registry)
# Register aliases # Register aliases
default_registry.register_alias('nesterov', 'sgd') default_registry.register_alias('nesterov', 'sgd')
@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
def get_optimizer_class( def get_optimizer_class(
name: str, name: str,
bind_defaults: bool = True, bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]: ) -> Union[OptimType, OptimizerCallable]:
"""Get optimizer class by name with option to bind default arguments. """Get optimizer class by name with option to bind default arguments.
Retrieves the optimizer class or a partial function with default arguments bound. Retrieves the optimizer class or a partial function with default arguments bound.
@ -874,7 +947,7 @@ def get_optimizer_class(
def create_optimizer_v2( def create_optimizer_v2(
model_or_params: Union[nn.Module, Params], model_or_params: Union[nn.Module, ParamsT],
opt: str = 'sgd', opt: str = 'sgd',
lr: Optional[float] = None, lr: Optional[float] = None,
weight_decay: float = 0., weight_decay: float = 0.,
@ -882,9 +955,9 @@ def create_optimizer_v2(
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True, filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None, layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None, param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any, **kwargs: Any,
) -> optim.Optimizer: ) -> torch.optim.Optimizer:
"""Create an optimizer instance via timm registry. """Create an optimizer instance via timm registry.
Creates and configures an optimizer with appropriate parameter groups and settings. Creates and configures an optimizer with appropriate parameter groups and settings.
@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
return kwargs return kwargs
def create_optimizer(args, model, filter_bias_and_bn=True): def create_optimizer(
args,
model: Union[nn.Module, ParamsT],
filter_bias_and_bn: bool = True,
) -> torch.optim.Optimizer:
""" Legacy optimizer factory for backwards compatibility. """ Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code. NOTE: Use create_optimizer_v2 for new code.
""" """

timm/optim/_types.py (new file, 25 lines added)
View File

@ -0,0 +1,25 @@
from typing import Any, Dict, Iterable, Union, Protocol, Type
try:
from typing import TypeAlias, TypeVar
except ImportError:
from typing_extensions import TypeAlias, TypeVar
import torch
import torch.optim
try:
from torch.optim.optimizer import ParamsT
except (ImportError, TypeError):
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
OptimType = Type[torch.optim.Optimizer]
class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""
def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
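A short sketch of how these aliases are intended to be used in annotations; build_with is a hypothetical helper, not part of the commit:

import torch
import torch.nn as nn
from timm.optim._types import ParamsT, OptimizerCallable

def build_with(params: ParamsT, ctor: OptimizerCallable, **kwargs) -> torch.optim.Optimizer:
    # Any callable taking a params iterable and returning an Optimizer satisfies the
    # protocol, including plain classes and the partials handed back by the registry.
    return ctor(params, **kwargs)

model = nn.Linear(4, 4)
opt = build_with(model.parameters(), torch.optim.AdamW, lr=1e-3)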

View File

@ -10,8 +10,12 @@ Original header/copyright below.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch
import math import math
from typing import Optional, Tuple
import torch
from ._types import ParamsT
class Adafactor(torch.optim.Optimizer): class Adafactor(torch.optim.Optimizer):
@ -26,33 +30,33 @@ class Adafactor(torch.optim.Optimizer):
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`. `relative_step=False`.
Arguments: Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups params: iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): external learning rate (default: None) lr: external learning rate
eps (tuple[float, float]): regularization constants for square gradient eps: regularization constant for square gradient
and parameter scale respectively (default: (1e-30, 1e-3)) eps_scale: regularization constant for parameter scale
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) clip_threshold: threshold of root-mean-square of final gradient update
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) decay_rate: coefficient used to compute running averages of square gradient
beta1 (float): coefficient used for computing running averages of gradient (default: None) beta1: coefficient used for computing running averages of gradient
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) weight_decay: weight decay
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) scale_parameter: if True, learning rate is scaled by root-mean-square of parameter
warmup_init (bool): time-dependent learning rate computation depends on warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used
whether warm-up initialization is being used (default: False)
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=None, lr: Optional[float] = None,
eps=1e-30, eps: float = 1e-30,
eps_scale=1e-3, eps_scale: float = 1e-3,
clip_threshold=1.0, clip_threshold: float = 1.0,
decay_rate=-0.8, decay_rate: float = -0.8,
betas=None, betas: Optional[Tuple[float, float]] = None,
weight_decay=0.0, weight_decay: float = 0.0,
scale_parameter=True, scale_parameter: bool = True,
warmup_init=False, warmup_init: bool = False,
min_dim_size_to_factor=32, min_dim_size_to_factor: int = 16,
caution: bool = False,
): ):
relative_step = not lr relative_step = not lr
if warmup_init and not relative_step: if warmup_init and not relative_step:
@ -71,9 +75,16 @@ class Adafactor(torch.optim.Optimizer):
relative_step=relative_step, relative_step=relative_step,
warmup_init=warmup_init, warmup_init=warmup_init,
min_dim_size_to_factor=min_dim_size_to_factor, min_dim_size_to_factor=min_dim_size_to_factor,
caution=caution,
) )
super(Adafactor, self).__init__(params, defaults) super(Adafactor, self).__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('min_dim_size_to_factor', 32)
@staticmethod @staticmethod
def _get_lr(param_group, param_state): def _get_lr(param_group, param_state):
if param_group['relative_step']: if param_group['relative_step']:
@ -86,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
return param_group['lr'] return param_group['lr']
@staticmethod @staticmethod
def _get_options(param_group, param_shape, min_size_to_factor=32): def _get_options(param_group, param_shape, min_size_to_factor=16):
use_first_moment = param_group['beta1'] is not None use_first_moment = param_group['beta1'] is not None
factored = None factored = None
ndim = len(param_shape) ndim = len(param_shape)
@ -98,7 +109,7 @@ class Adafactor(torch.optim.Optimizer):
# nD convs in torch are ND + 2 dim weights with leading in/out chs # nD convs in torch are ND + 2 dim weights with leading in/out chs
factored = 0, 1 factored = 0, 1
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor: elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
# if the criteria above didn't match, test trailing dims for eligibility # if the criteria above didn't match, test trailing dims for eligibility as per original impl
factored = ndim - 2, ndim - 1 factored = ndim - 2, ndim - 1
return factored, use_first_moment return factored, use_first_moment
@ -113,7 +124,6 @@ class Adafactor(torch.optim.Optimizer):
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt() c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
return torch.mul(r_factor, c_factor) return torch.mul(r_factor, c_factor)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -201,6 +211,12 @@ class Adafactor(torch.optim.Optimizer):
if use_first_moment: if use_first_moment:
exp_avg = state['exp_avg'] exp_avg = state['exp_avg']
exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update = exp_avg * mask
else:
update = exp_avg update = exp_avg
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
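The new __setstate__ hook above backfills param groups loaded from checkpoints saved before this commit: caution defaults to False and min_dim_size_to_factor to the previous value of 32 (the constructor default moved to 16). A quick check of that behaviour, with a placeholder model and a simulated old checkpoint; note Optimizer.load_state_dict finishes by calling __setstate__:

import torch.nn as nn
from timm.optim import Adafactor

model = nn.Linear(4, 4)                  # placeholder model
opt = Adafactor(model.parameters(), lr=1e-3)
sd = opt.state_dict()
del sd['param_groups'][0]['caution']     # pretend the checkpoint predates the caution flag
opt.load_state_dict(sd)
assert opt.param_groups[0]['caution'] is False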

View File

@ -6,13 +6,14 @@ Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
Adaptation and PyTorch modifications by Ross Wightman Adaptation and PyTorch modifications by Ross Wightman
""" """
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim import Optimizer from torch.optim import Optimizer
from ._types import ParamsT
def _get_scalar_dtype(): def _get_scalar_dtype():
"""Get the scalar dtype that the optimizer uses for state""" """Get the scalar dtype that the optimizer uses for state"""
@ -54,9 +55,9 @@ class AdafactorBigVision(Optimizer):
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr: float = 1.0, lr: float = 1.0,
min_dim_size_to_factor: int = 32, min_dim_size_to_factor: int = 16,
decay_rate: float = 0.8, decay_rate: float = 0.8,
decay_offset: int = 0, decay_offset: int = 0,
beta2_cap: float = 0.999, beta2_cap: float = 0.999,
@ -66,6 +67,7 @@ class AdafactorBigVision(Optimizer):
weight_decay: float = 0.0, weight_decay: float = 0.0,
clipping_threshold: Optional[float] = None, clipping_threshold: Optional[float] = None,
unscaled_wd: bool = False, unscaled_wd: bool = False,
caution: bool = False,
*, *,
foreach: Optional[bool] = False, foreach: Optional[bool] = False,
): ):
@ -91,6 +93,7 @@ class AdafactorBigVision(Optimizer):
weight_decay=weight_decay, weight_decay=weight_decay,
clipping_threshold=clipping_threshold, clipping_threshold=clipping_threshold,
unscaled_wd=unscaled_wd, unscaled_wd=unscaled_wd,
caution=caution,
foreach=foreach, foreach=foreach,
) )
super().__init__(params, defaults) super().__init__(params, defaults)
@ -98,6 +101,7 @@ class AdafactorBigVision(Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('foreach', None) group.setdefault('foreach', None)
for p in group['params']: for p in group['params']:
p_state = self.state.get(p, {}) p_state = self.state.get(p, {})
@ -192,6 +196,7 @@ class AdafactorBigVision(Optimizer):
momentum_dtype=group['momentum_dtype'], momentum_dtype=group['momentum_dtype'],
clipping_threshold=group['clipping_threshold'], clipping_threshold=group['clipping_threshold'],
unscaled_wd=group['unscaled_wd'], unscaled_wd=group['unscaled_wd'],
caution=group['caution'],
) )
return loss return loss
@ -216,6 +221,7 @@ def _single_tensor_adafactor(
momentum_dtype: Union[str, torch.dtype], momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float], clipping_threshold: Optional[float],
unscaled_wd: bool, unscaled_wd: bool,
caution: bool,
): ):
for i, param in enumerate(params): for i, param in enumerate(params):
grad = grads[i] grad = grads[i]
@ -267,6 +273,12 @@ def _single_tensor_adafactor(
exp_avg.lerp_(update, 1 - momentum) # ema exp_avg.lerp_(update, 1 - momentum) # ema
update = exp_avg.clone() update = exp_avg.clone()
if caution:
# apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
# Scale by learning rate # Scale by learning rate
update.mul_(lr) update.mul_(lr)
@ -302,6 +314,7 @@ def _multi_tensor_adafactor(
momentum_dtype: Union[str, torch.dtype], momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float], clipping_threshold: Optional[float],
unscaled_wd: bool, unscaled_wd: bool,
caution: bool,
): ):
# FIXME TODO # FIXME TODO
assert False, 'multi-tensor fn (foreach=True) not implemented yet' assert False, 'multi-tensor fn (foreach=True) not implemented yet'

View File

@ -4,49 +4,45 @@ Impl copied from PyTorch master
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
""" """
import math import math
from typing import Tuple
import torch import torch
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from ._types import ParamsT
class AdamWLegacy(Optimizer): class AdamWLegacy(Optimizer):
r"""Implements AdamW algorithm. r"""Implements AdamW algorithm.
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. References:
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
- Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Arguments: Args:
params (iterable): iterable of parameters to optimize or dicts defining params: iterable of parameters to optimize or dicts defining parameter groups
parameter groups lr: learning rate
lr (float, optional): learning rate (default: 1e-3) betas: coefficients used for computing running averages of gradient and its square
betas (Tuple[float, float], optional): coefficients used for computing eps: term added to the denominator to improve numerical stability
running averages of gradient and its square (default: (0.9, 0.999)) weight_decay: weight decay coefficient
eps (float, optional): term added to the denominator to improve amsgrad: whether to use the AMSGrad variant of this algorithm
numerical stability (default: 1e-8) from the paper `On the Convergence of Adam and Beyond`
weight_decay (float, optional): weight decay coefficient (default: 1e-2) caution: apply caution when using AdamW
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-3, lr: float = 1e-3,
betas=(0.9, 0.999), betas: Tuple[float, float] = (0.9, 0.999),
eps=1e-8, eps: float = 1e-8,
weight_decay=1e-2, weight_decay: float = 1e-2,
amsgrad=False, amsgrad: bool = False,
caution: bool = False,
): ):
# NOTE: deprecated in favour of builtin torch.optim.AdamW
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps: if not 0.0 <= eps:
@ -61,6 +57,7 @@ class AdamWLegacy(Optimizer):
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
amsgrad=amsgrad, amsgrad=amsgrad,
caution=caution,
) )
super(AdamWLegacy, self).__init__(params, defaults) super(AdamWLegacy, self).__init__(params, defaults)
@ -68,6 +65,7 @@ class AdamWLegacy(Optimizer):
super(AdamWLegacy, self).__setstate__(state) super(AdamWLegacy, self).__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault('amsgrad', False) group.setdefault('amsgrad', False)
group.setdefault('caution', False)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
@ -131,6 +129,12 @@ class AdamWLegacy(Optimizer):
step_size = group['lr'] / bias_correction1 step_size = group['lr'] / bias_correction1
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
p.addcdiv_(exp_avg, denom, value=-step_size) p.addcdiv_(exp_avg, denom, value=-step_size)
return loss return loss

View File

@ -10,16 +10,15 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate}, title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
year = {2024} year = {2024}
} }
""" """
from typing import cast, List, Optional, Tuple, Union
from typing import cast, Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from ._types import ParamsT
__all__ = ["Adopt", "adopt"] __all__ = ["Adopt", "adopt"]
def _view_as_real(params, *state_and_grads): def _view_as_real(params, *state_and_grads):
@ -60,7 +59,7 @@ class Adopt(Optimizer):
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr: Union[float, Tensor] = 1e-3, lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999), betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6, eps: float = 1e-6,
@ -68,7 +67,8 @@ class Adopt(Optimizer):
weight_decay: float = 0.0, weight_decay: float = 0.0,
decoupled: bool = False, decoupled: bool = False,
*, *,
foreach: Optional[bool] = None, caution: bool = False,
foreach: Optional[bool] = False,
maximize: bool = False, maximize: bool = False,
capturable: bool = False, capturable: bool = False,
differentiable: bool = False, differentiable: bool = False,
@ -98,6 +98,7 @@ class Adopt(Optimizer):
weight_decay=weight_decay, weight_decay=weight_decay,
clip_exp=clip_exp, clip_exp=clip_exp,
decoupled=decoupled, decoupled=decoupled,
caution=caution,
maximize=maximize, maximize=maximize,
foreach=foreach, foreach=foreach,
capturable=capturable, capturable=capturable,
@ -105,7 +106,6 @@ class Adopt(Optimizer):
) )
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
@ -114,6 +114,7 @@ class Adopt(Optimizer):
group.setdefault("capturable", False) group.setdefault("capturable", False)
group.setdefault("differentiable", False) group.setdefault("differentiable", False)
group.setdefault("clip_exp", None) group.setdefault("clip_exp", None)
group.setdefault("caution", False)
for p in group["params"]: for p in group["params"]:
p_state = self.state.get(p, []) p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@ -223,6 +224,7 @@ class Adopt(Optimizer):
clip_exp=group["clip_exp"], clip_exp=group["clip_exp"],
decoupled=group["decoupled"], decoupled=group["decoupled"],
eps=group["eps"], eps=group["eps"],
caution=group["caution"],
maximize=group["maximize"], maximize=group["maximize"],
foreach=group["foreach"], foreach=group["foreach"],
capturable=group["capturable"], capturable=group["capturable"],
@ -251,6 +253,7 @@ def _single_tensor_adopt(
clip_exp: Optional[float], clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
differentiable: bool, differentiable: bool,
@ -306,6 +309,13 @@ def _single_tensor_adopt(
normed_grad.clamp_(-clip_val, clip_val) normed_grad.clamp_(-clip_val, clip_val)
exp_avg.lerp_(normed_grad, 1 - beta1) exp_avg.lerp_(normed_grad, 1 - beta1)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
param.add_(exp_avg, alpha=-lr) param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@ -328,6 +338,7 @@ def _multi_tensor_adopt(
clip_exp: Optional[float], clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
differentiable: bool, differentiable: bool,
@ -403,6 +414,7 @@ def _multi_tensor_adopt(
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps) torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt) normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_exp is not None: if clip_exp is not None:
@ -411,6 +423,16 @@ def _multi_tensor_adopt(
torch._foreach_minimum_(normed_grad, clip_val) torch._foreach_minimum_(normed_grad, clip_val)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1) torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(device_exp_avgs, device_grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2) torch._foreach_mul_(device_exp_avg_sqs, beta2)
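The foreach branch computes the same mask as the single-tensor branch, just batched across the parameter list; a toy equivalence check using only the _foreach ops already present above:

import torch

exp_avgs = [torch.randn(6) for _ in range(3)]
grads = [torch.randn(6) for _ in range(3)]

# batched masking, mirroring the _multi_tensor path
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
masked = torch._foreach_mul(exp_avgs, masks)

# per-tensor reference, mirroring the _single_tensor path
for out, ea, g in zip(masked, exp_avgs, grads):
    ref = (ea * g > 0).to(g.dtype)
    ref = ref / ref.mean().clamp(min=1e-3)
    assert torch.allclose(out, ea * ref)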
@ -440,6 +462,7 @@ def adopt(
clip_exp: Optional[float], clip_exp: Optional[float],
decoupled: bool, decoupled: bool,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
): ):
r"""Functional API that performs ADOPT algorithm computation. r"""Functional API that performs ADOPT algorithm computation.
@ -477,6 +500,7 @@ def adopt(
clip_exp=clip_exp, clip_exp=clip_exp,
decoupled=decoupled, decoupled=decoupled,
eps=eps, eps=eps,
caution=caution,
maximize=maximize, maximize=maximize,
capturable=capturable, capturable=capturable,
differentiable=differentiable, differentiable=differentiable,

View File

@ -52,50 +52,48 @@ Modifications Copyright 2021 Ross Wightman
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import math import math
from typing import Optional, Tuple
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from ._types import ParamsT
class Lamb(Optimizer): class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. LAMB was proposed in:
- Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Arguments: Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups. params: Iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3) lr: Learning rate
betas (Tuple[float, float], optional): coefficients used for computing betas: Coefficients used for computing running averages of gradient and its norm.
running averages of gradient and its norm. (default: (0.9, 0.999)) eps: Term added to the denominator to improve numerical stability.
eps (float, optional): term added to the denominator to improve weight_decay: Weight decay
numerical stability. (default: 1e-8) grad_averaging: Whether to apply (1-beta2) to grad when calculating running averages of gradient.
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) max_grad_norm: Value used to clip global grad norm.
grad_averaging (bool, optional): whether apply (1-beta2) to grad when trust_clip: Enable LAMBC trust ratio clipping.
calculating running averages of gradient. (default: True) always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) caution: Apply caution.
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-3, lr: float = 1e-3,
bias_correction=True, bias_correction: bool = True,
betas=(0.9, 0.999), betas: Tuple[float, float] = (0.9, 0.999),
eps=1e-6, eps: float = 1e-6,
weight_decay=0.01, weight_decay: float = 0.01,
grad_averaging=True, grad_averaging: bool = True,
max_grad_norm=1.0, max_grad_norm: Optional[float] = 1.0,
trust_clip=False, trust_clip: bool = False,
always_adapt=False, always_adapt: bool = False,
caution: bool = False,
): ):
defaults = dict( defaults = dict(
lr=lr, lr=lr,
@ -107,9 +105,15 @@ class Lamb(Optimizer):
max_grad_norm=max_grad_norm, max_grad_norm=max_grad_norm,
trust_clip=trust_clip, trust_clip=trust_clip,
always_adapt=always_adapt, always_adapt=always_adapt,
caution=caution,
) )
super().__init__(params, defaults) super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
def _get_clip_grad_norm(self): def _get_clip_grad_norm(self):
max_grad_norm = self.defaults['max_grad_norm'] max_grad_norm = self.defaults['max_grad_norm']
if max_grad_norm is None: if max_grad_norm is None:
@ -187,6 +191,12 @@ class Lamb(Optimizer):
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
update = (exp_avg / bias_correction1).div_(denom) update = (exp_avg / bias_correction1).div_(denom)
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
weight_decay = group['weight_decay'] weight_decay = group['weight_decay']
if weight_decay != 0: if weight_decay != 0:
update.add_(p, alpha=weight_decay) update.add_(p, alpha=weight_decay)

View File

@ -12,9 +12,13 @@ Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs
} }
""" """
from typing import Tuple
from torch.optim import Optimizer from torch.optim import Optimizer
import torch import torch
from ._types import ParamsT
class LaProp(Optimizer): class LaProp(Optimizer):
""" LaProp Optimizer """ LaProp Optimizer
@ -23,11 +27,12 @@ class LaProp(Optimizer):
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=4e-4, lr: float = 4e-4,
betas=(0.9, 0.999), betas: Tuple[float, float] = (0.9, 0.999),
eps=1e-15, eps: float = 1e-15,
weight_decay=0, weight_decay: float = 0.,
caution: bool = False,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
@ -42,6 +47,7 @@ class LaProp(Optimizer):
betas=betas, betas=betas,
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
caution=caution,
) )
super(LaProp, self).__init__(params, defaults) super(LaProp, self).__init__(params, defaults)
@ -101,7 +107,14 @@ class LaProp(Optimizer):
step_of_this_grad = grad / denom step_of_this_grad = grad / denom
exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1) exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
p.add_(exp_avg, alpha=-step_size) p.add_(exp_avg, alpha=-step_size)
if group['weight_decay'] != 0: if group['weight_decay'] != 0:
p.add_(p, alpha=-group['weight_decay']) p.add_(p, alpha=-group['weight_decay'])

View File

@ -16,33 +16,35 @@ Original Impl: https://github.com/google/automl/tree/master/lion
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
from typing import List from typing import List, Optional, Tuple
import torch import torch
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from ._types import ParamsT
class Lion(Optimizer): class Lion(Optimizer):
r"""Implements Lion algorithm.""" r"""Implements Lion algorithm."""
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-4, lr: float = 1e-4,
betas=(0.9, 0.99), betas: Tuple[float, float] = (0.9, 0.99),
weight_decay=0.0, weight_decay: float = 0.0,
maximize=False, caution: bool = False,
foreach=None, maximize: bool = False,
foreach: Optional[bool] = None,
): ):
"""Initialize the hyperparameters. """Initialize the hyperparameters.
Args: Args:
params (iterable): iterable of parameters to optimize or dicts defining params: iterable of parameters to optimize or dicts defining parameter groups
parameter groups lr: learning rate
lr (float, optional): learning rate (default: 1e-4) betas: coefficients used for computing running averages of gradient and its square
betas (Tuple[float, float], optional): coefficients used for computing weight_decay: weight decay coefficient
running averages of gradient and its square (default: (0.9, 0.99)) caution: apply caution
weight_decay (float, optional): weight decay coefficient (default: 0)
""" """
if not 0.0 <= lr: if not 0.0 <= lr:
@ -55,6 +57,7 @@ class Lion(Optimizer):
lr=lr, lr=lr,
betas=betas, betas=betas,
weight_decay=weight_decay, weight_decay=weight_decay,
caution=caution,
foreach=foreach, foreach=foreach,
maximize=maximize, maximize=maximize,
) )
@ -63,6 +66,7 @@ class Lion(Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('maximize', False) group.setdefault('maximize', False)
group.setdefault('foreach', None) group.setdefault('foreach', None)
@ -71,8 +75,7 @@ class Lion(Optimizer):
"""Performs a single optimization step. """Performs a single optimization step.
Args: Args:
closure (callable, optional): A closure that reevaluates the model closure: A closure that reevaluates the model and returns the loss.
and returns the loss.
Returns: Returns:
the loss. the loss.
@ -112,6 +115,7 @@ class Lion(Optimizer):
beta2=beta2, beta2=beta2,
lr=group['lr'], lr=group['lr'],
weight_decay=group['weight_decay'], weight_decay=group['weight_decay'],
caution=group['caution'],
maximize=group['maximize'], maximize=group['maximize'],
foreach=group['foreach'], foreach=group['foreach'],
) )
@ -132,6 +136,7 @@ def lion(
beta2: float, beta2: float,
lr: float, lr: float,
weight_decay: float, weight_decay: float,
caution: bool,
): ):
r"""Functional API that performs Lion algorithm computation. r"""Functional API that performs Lion algorithm computation.
""" """
@ -155,6 +160,7 @@ def lion(
beta2=beta2, beta2=beta2,
lr=lr, lr=lr,
weight_decay=weight_decay, weight_decay=weight_decay,
caution=caution,
maximize=maximize, maximize=maximize,
) )
@ -168,6 +174,7 @@ def _single_tensor_lion(
beta2: float, beta2: float,
lr: float, lr: float,
weight_decay: float, weight_decay: float,
caution: bool,
maximize: bool, maximize: bool,
): ):
for i, param in enumerate(params): for i, param in enumerate(params):
@ -183,8 +190,15 @@ def _single_tensor_lion(
param.mul_(1 - lr * weight_decay) param.mul_(1 - lr * weight_decay)
# Weight update # Weight update
update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
param.add_(torch.sign(update), alpha=-lr)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
param.add_(update, alpha=-lr)
# Decay the momentum running average coefficient # Decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2) exp_avg.lerp_(grad, 1 - beta2)
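Worth noting in the Lion path above: the sign is now taken before the caution mask, so the mask compares the signed step against the raw gradient, and the momentum EMA is still updated afterwards. A compressed restatement of the single-tensor step with toy tensors (weight decay omitted):

import torch

param, grad = torch.randn(8), torch.randn(8)
exp_avg = torch.zeros(8)
lr, beta1, beta2 = 1e-4, 0.9, 0.99

update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
mask = (update * grad > 0).to(grad.dtype)  # caution gate on the signed direction
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
param.add_(update, alpha=-lr)
exp_avg.lerp_(grad, 1 - beta2)             # momentum EMA moves toward the gradient after the step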
@ -199,6 +213,7 @@ def _multi_tensor_lion(
beta2: float, beta2: float,
lr: float, lr: float,
weight_decay: float, weight_decay: float,
caution: bool,
maximize: bool, maximize: bool,
): ):
if len(params) == 0: if len(params) == 0:
@ -217,8 +232,17 @@ def _multi_tensor_lion(
# Weight update # Weight update
updates = torch._foreach_mul(exp_avgs, beta1) updates = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(updates, grads, alpha=1 - beta1) torch._foreach_add_(updates, grads, alpha=1 - beta1)
updates = [u.sign_() for u in updates]
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(updates, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(updates, masks)
updates = [u.sign() for u in updates]
torch._foreach_add_(params, updates, alpha=-lr) torch._foreach_add_(params, updates, alpha=-lr)
# Decay the momentum running average coefficient # Decay the momentum running average coefficient

View File

@ -5,44 +5,43 @@ Based on simplified algorithm in https://github.com/mlcommons/algorithmic-effici
Added multi-tensor (foreach) path. Added multi-tensor (foreach) path.
""" """
import math import math
from typing import List, Optional from typing import List, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from ._types import ParamsT
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
class NAdamW(torch.optim.Optimizer): class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm. """ Implements NAdamW algorithm.
See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
the NAdam algorithm (there is also a comment in the code which highlights the NAdam algorithm (there is also a comment in the code which highlights
the only difference of NAdamW and AdamW). the only difference of NAdamW and AdamW).
For further details regarding the algorithm we refer to For further details regarding the algorithm we refer to
`Decoupled Weight Decay Regularization`_. - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Args: Args:
params (iterable): iterable of parameters to optimize or dicts defining params: iterable of parameters to optimize or dicts defining parameter groups
parameter groups lr: learning rate
lr (float, optional): learning rate (default: 1e-3) betas: coefficients used for computing running averages of gradient and its square
betas (Tuple[float, float], optional): coefficients used for computing eps: term added to the denominator to improve numerical stability
running averages of gradient and its square (default: (0.9, 0.999)) weight_decay: weight decay coefficient
eps (float, optional): term added to the denominator to improve caution: enable caution
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-3, lr: float = 1e-3,
betas=(0.9, 0.999), betas: Tuple[float, float] = (0.9, 0.999),
eps=1e-8, eps: float = 1e-8,
weight_decay=1e-2, weight_decay: float = 1e-2,
caution: bool = False,
maximize: bool = False, maximize: bool = False,
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
capturable: bool = False, capturable: bool = False,
@ -62,6 +61,7 @@ class NAdamW(torch.optim.Optimizer):
betas=betas, betas=betas,
eps=eps, eps=eps,
weight_decay=weight_decay, weight_decay=weight_decay,
caution=caution,
foreach=foreach, foreach=foreach,
maximize=maximize, maximize=maximize,
capturable=capturable, capturable=capturable,
@ -71,11 +71,12 @@ class NAdamW(torch.optim.Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
state_values = list(self.state.values()) state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor( step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
state_values[0]['step'])
if not step_is_tensor: if not step_is_tensor:
for s in state_values: for s in state_values:
s['step'] = torch.tensor(float(s['step'])) s['step'] = torch.tensor(float(s['step']))
for group in self.param_groups:
group.setdefault('caution', False)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
@ -133,6 +134,7 @@ class NAdamW(torch.optim.Optimizer):
lr=group['lr'], lr=group['lr'],
weight_decay=group['weight_decay'], weight_decay=group['weight_decay'],
eps=group['eps'], eps=group['eps'],
caution=group['caution'],
maximize=group['maximize'], maximize=group['maximize'],
capturable=group['capturable'], capturable=group['capturable'],
) )
@ -154,6 +156,7 @@ def nadamw(
lr: float, lr: float,
weight_decay: float, weight_decay: float,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
) -> None: ) -> None:
r"""Functional API that performs NAdamW algorithm computation. r"""Functional API that performs NAdamW algorithm computation.
@ -183,6 +186,7 @@ def nadamw(
lr=lr, lr=lr,
weight_decay=weight_decay, weight_decay=weight_decay,
eps=eps, eps=eps,
caution=caution,
maximize=maximize, maximize=maximize,
capturable=capturable, capturable=capturable,
) )
@ -200,6 +204,7 @@ def _single_tensor_nadamw(
lr: float, lr: float,
weight_decay: float, weight_decay: float,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
capturable: bool capturable: bool
): ):
@ -238,6 +243,14 @@ def _single_tensor_nadamw(
exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
# FIXME not 100% sure if this remains capturable?
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg.mul_(mask)
param.addcdiv_(exp_avg, denom) param.addcdiv_(exp_avg, denom)
else: else:
step = step_t.item() step = step_t.item()
@ -246,11 +259,17 @@ def _single_tensor_nadamw(
step_size = lr / bias_correction1 step_size = lr / bias_correction1
bias_correction2_sqrt = math.sqrt(bias_correction2) bias_correction2_sqrt = math.sqrt(bias_correction2)
# Only difference between NAdamW and AdamW in this implementation. # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
# The official PyTorch implementation of NAdam uses a different algorithm. # The official PyTorch implementation of NAdam uses a different algorithm.
exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg.mul_(mask)
param.addcdiv_(exp_avg, denom, value=-step_size) param.addcdiv_(exp_avg, denom, value=-step_size)
@ -266,6 +285,7 @@ def _multi_tensor_nadamw(
lr: float, lr: float,
weight_decay: float, weight_decay: float,
eps: float, eps: float,
caution: bool,
maximize: bool, maximize: bool,
capturable: bool, capturable: bool,
): ):
@ -322,12 +342,22 @@ def _multi_tensor_nadamw(
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs) exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_( torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size) exp_avg_sq_sqrt,
torch._foreach_mul(bias_correction2_sqrt, step_size)
) )
eps_over_step_size = torch._foreach_div(step_size, eps) eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size) torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size) denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)] # capturable?
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(exp_avgs, masks)
torch._foreach_addcdiv_(params, exp_avgs, denom) torch._foreach_addcdiv_(params, exp_avgs, denom)
else: else:
bias_correction1 = [1 - beta1 ** step.item() for step in state_steps] bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
@ -337,7 +367,7 @@ def _multi_tensor_nadamw(
bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2] bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2]
# Only difference between NAdamW and AdamW in this implementation. # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
# The official PyTorch implementation of NAdam uses a different algorithm. # The official PyTorch implementation of NAdam uses a different algorithm.
exp_avgs = torch._foreach_mul(exp_avgs, beta1) exp_avgs = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
@ -346,4 +376,13 @@ def _multi_tensor_nadamw(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps) denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(exp_avgs, masks)
torch._foreach_addcdiv_(params, exp_avgs, denom, step_size) torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)

View File

@ -10,6 +10,8 @@ Modifications Copyright 2021 Ross Wightman
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from ._types import ParamsT
class RMSpropTF(Optimizer): class RMSpropTF(Optimizer):
"""Implements RMSprop algorithm (TensorFlow style epsilon) """Implements RMSprop algorithm (TensorFlow style epsilon)
@ -28,34 +30,31 @@ class RMSpropTF(Optimizer):
The centered version first appears in `Generating Sequences The centered version first appears in `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_. With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
Arguments: Args:
params (iterable): iterable of parameters to optimize or dicts defining params: iterable of parameters to optimize or dicts defining parameter groups
parameter groups lr: learning rate
lr (float, optional): learning rate (default: 1e-2) momentum: momentum factor
momentum (float, optional): momentum factor (default: 0) alpha: smoothing (decay) constant
alpha (float, optional): smoothing (decay) constant (default: 0.9) eps: term added to the denominator to improve numerical stability
eps (float, optional): term added to the denominator to improve centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
numerical stability (default: 1e-10) weight_decay: weight decay (L2 penalty) (default: 0)
centered (bool, optional) : if ``True``, compute the centered RMSProp, decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101
the gradient is normalized by an estimation of its variance lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) caution: apply caution
decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
update as per defaults in Tensorflow
""" """
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-2, lr: float = 1e-2,
alpha=0.9, alpha: float = 0.9,
eps=1e-10, eps: float = 1e-10,
weight_decay=0, weight_decay: float = 0,
momentum=0., momentum: float = 0.,
centered=False, centered: bool = False,
decoupled_decay=False, decoupled_decay: bool = False,
lr_in_momentum=True, lr_in_momentum: bool = True,
caution: bool = False,
): ):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
@ -77,6 +76,7 @@ class RMSpropTF(Optimizer):
weight_decay=weight_decay, weight_decay=weight_decay,
decoupled_decay=decoupled_decay, decoupled_decay=decoupled_decay,
lr_in_momentum=lr_in_momentum, lr_in_momentum=lr_in_momentum,
caution=caution,
) )
super(RMSpropTF, self).__init__(params, defaults) super(RMSpropTF, self).__init__(params, defaults)
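A hedged usage sketch for the new flag (not part of this commit; the import path is an assumption):

# Illustrative usage of RMSpropTF with caution=True (not part of this commit).
import torch
from timm.optim import RMSpropTF  # assumed export location

model = torch.nn.Linear(8, 2)
opt = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, momentum=0.9, caution=True)

loss = model(torch.randn(4, 8)).sum()
loss.backward()
opt.step()  # the caution mask only applies on the momentum-buffer branch (momentum > 0)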
@ -85,6 +85,7 @@ class RMSpropTF(Optimizer):
for group in self.param_groups: for group in self.param_groups:
group.setdefault('momentum', 0) group.setdefault('momentum', 0)
group.setdefault('centered', False) group.setdefault('centered', False)
group.setdefault('caution', False)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
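The setdefault above keeps older optimizer checkpoints loadable; a minimal sketch of that scenario (not part of this commit; import path assumed):

# Loading an optimizer state saved before the `caution` flag existed (sketch).
import torch
from timm.optim import RMSpropTF  # assumed export location

p = torch.nn.Parameter(torch.zeros(2))
saved = RMSpropTF([p], lr=1e-2, momentum=0.9).state_dict()
saved['param_groups'][0].pop('caution', None)   # simulate a pre-`caution` checkpoint

opt = RMSpropTF([p], lr=1e-2, momentum=0.9)
opt.load_state_dict(saved)                      # __setstate__ fills in caution=False
assert opt.param_groups[0]['caution'] is False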
@ -142,13 +143,25 @@ class RMSpropTF(Optimizer):
if group['momentum'] > 0: if group['momentum'] > 0:
buf = state['momentum_buffer'] buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer buf.mul_(group['momentum'])
def _apply_caution(_m, _g):
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (_m * _g > 0).to(_g.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
return _m * mask
if group['lr_in_momentum']: if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) # Tensorflow accumulates the LR scaling in the momentum buffer
buf.addcdiv_(grad, avg, value=group['lr'])
if group['caution']:
buf = _apply_caution(buf, grad)
p.add_(-buf) p.add_(-buf)
else: else:
# PyTorch scales the param update by LR # PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg) buf.addcdiv_(grad, avg)
if group['caution']:
buf = _apply_caution(buf, grad)
p.add_(buf, alpha=-group['lr']) p.add_(buf, alpha=-group['lr'])
else: else:
p.addcdiv_(grad, avg, value=-group['lr']) p.addcdiv_(grad, avg, value=-group['lr'])
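A small worked example of the mask and rescale in _apply_caution above (toy numbers, not part of this commit):

# Toy illustration of the sign-agreement mask used in _apply_caution.
import torch

buf = torch.tensor([0.5, -0.2, 0.1, 0.3])     # momentum buffer
grad = torch.tensor([1.0, 1.0, -1.0, 1.0])    # current gradient

mask = (buf * grad > 0).to(grad.dtype)        # -> [1., 0., 0., 1.]
mask.div_(mask.mean().clamp_(min=1e-3))       # mean = 0.5 -> mask becomes [2., 0., 0., 2.]
update = buf * mask                           # -> [1.0, 0.0, 0.0, 0.6]
# The clamp only matters when almost every component disagrees (mean near zero),
# preventing a huge rescale of the few surviving components.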

View File

@ -1,4 +1,5 @@
from functools import update_wrapper, wraps from typing import List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
@ -8,7 +9,7 @@ try:
except ImportError: except ImportError:
has_recent_pt = False has_recent_pt = False
from typing import List, Optional from ._types import ParamsT
__all__ = ['SGDW', 'sgdw'] __all__ = ['SGDW', 'sgdw']
@ -16,13 +17,14 @@ __all__ = ['SGDW', 'sgdw']
class SGDW(Optimizer): class SGDW(Optimizer):
def __init__( def __init__(
self, self,
params, params: ParamsT,
lr=1e-3, lr: float = 1e-3,
momentum=0, momentum: float = 0.,
dampening=0, dampening: float = 0.,
weight_decay=0, weight_decay: float = 0.,
nesterov=False, nesterov: bool = False,
*, *,
caution: bool = False,
maximize: bool = False, maximize: bool = False,
foreach: Optional[bool] = None, foreach: Optional[bool] = None,
differentiable: bool = False, differentiable: bool = False,
@ -40,6 +42,7 @@ class SGDW(Optimizer):
dampening=dampening, dampening=dampening,
weight_decay=weight_decay, weight_decay=weight_decay,
nesterov=nesterov, nesterov=nesterov,
caution=caution,
maximize=maximize, maximize=maximize,
foreach=foreach, foreach=foreach,
differentiable=differentiable, differentiable=differentiable,
@ -51,18 +54,19 @@ class SGDW(Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super().__setstate__(state) super().__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('nesterov', False) group.setdefault('nesterov', False)
group.setdefault('maximize', False) group.setdefault('maximize', False)
group.setdefault('foreach', None) group.setdefault('foreach', None)
group.setdefault('differentiable', False) group.setdefault('differentiable', False)
def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
has_sparse_grad = False has_sparse_grad = False
for p in group['params']: for p in group['params']:
if p.grad is not None: if p.grad is not None:
params_with_grad.append(p) params_with_grad.append(p)
d_p_list.append(p.grad) grads.append(p.grad)
if p.grad.is_sparse: if p.grad.is_sparse:
has_sparse_grad = True has_sparse_grad = True
@ -91,20 +95,21 @@ class SGDW(Optimizer):
for group in self.param_groups: for group in self.param_groups:
params_with_grad = [] params_with_grad = []
d_p_list = [] grads = []
momentum_buffer_list = [] momentum_buffer_list = []
has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list) has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list)
sgdw( sgdw(
params_with_grad, params_with_grad,
d_p_list, grads,
momentum_buffer_list, momentum_buffer_list,
weight_decay=group['weight_decay'], weight_decay=group['weight_decay'],
momentum=group['momentum'], momentum=group['momentum'],
lr=group['lr'], lr=group['lr'],
dampening=group['dampening'], dampening=group['dampening'],
nesterov=group['nesterov'], nesterov=group['nesterov'],
caution=group['caution'],
maximize=group['maximize'], maximize=group['maximize'],
has_sparse_grad=has_sparse_grad, has_sparse_grad=has_sparse_grad,
foreach=group['foreach'], foreach=group['foreach'],
@ -120,7 +125,7 @@ class SGDW(Optimizer):
def sgdw( def sgdw(
params: List[Tensor], params: List[Tensor],
d_p_list: List[Tensor], grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]], momentum_buffer_list: List[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
@ -132,6 +137,7 @@ def sgdw(
lr: float, lr: float,
dampening: float, dampening: float,
nesterov: bool, nesterov: bool,
caution: bool,
maximize: bool maximize: bool
): ):
r"""Functional API that performs SGD algorithm computation. r"""Functional API that performs SGD algorithm computation.
@ -159,13 +165,14 @@ def sgdw(
func( func(
params, params,
d_p_list, grads,
momentum_buffer_list, momentum_buffer_list,
weight_decay=weight_decay, weight_decay=weight_decay,
momentum=momentum, momentum=momentum,
lr=lr, lr=lr,
dampening=dampening, dampening=dampening,
nesterov=nesterov, nesterov=nesterov,
caution=caution,
has_sparse_grad=has_sparse_grad, has_sparse_grad=has_sparse_grad,
maximize=maximize, maximize=maximize,
) )
@ -173,7 +180,7 @@ def sgdw(
def _single_tensor_sgdw( def _single_tensor_sgdw(
params: List[Tensor], params: List[Tensor],
d_p_list: List[Tensor], grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]], momentum_buffer_list: List[Optional[Tensor]],
*, *,
weight_decay: float, weight_decay: float,
@ -181,11 +188,12 @@ def _single_tensor_sgdw(
lr: float, lr: float,
dampening: float, dampening: float,
nesterov: bool, nesterov: bool,
caution: bool,
maximize: bool, maximize: bool,
has_sparse_grad: bool has_sparse_grad: bool
): ):
for i, param in enumerate(params): for i, param in enumerate(params):
d_p = d_p_list[i] if not maximize else -d_p_list[i] grad = grads[i] if not maximize else -grads[i]
param.mul_(1. - lr * weight_decay) param.mul_(1. - lr * weight_decay)
@ -193,17 +201,25 @@ def _single_tensor_sgdw(
buf = momentum_buffer_list[i] buf = momentum_buffer_list[i]
if buf is None: if buf is None:
buf = torch.clone(d_p).detach() buf = torch.clone(grad).detach()
momentum_buffer_list[i] = buf momentum_buffer_list[i] = buf
else: else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening) buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if caution:
if nesterov: if nesterov:
d_p = d_p.add(buf, alpha=momentum) buf = grad.add(buf, alpha=momentum)
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (buf * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
grad = buf * mask
else: else:
d_p = buf if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
param.add_(d_p, alpha=-lr) param.add_(grad, alpha=-lr)
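A short sketch exercising the cautious single-tensor path above with Nesterov momentum (not part of this commit; import path assumed):

# Illustrative run of SGDW with nesterov=True and caution=True (sketch only).
import torch
from timm.optim import SGDW  # assumed export location

p = torch.nn.Parameter(torch.zeros(4))
opt = SGDW([p], lr=0.1, momentum=0.9, nesterov=True, caution=True)

for _ in range(3):
    opt.zero_grad()
    p.grad = torch.tensor([1.0, -1.0, 1.0, -1.0])  # fixed grad keeps buf and grad sign-aligned
    opt.step()  # the caution mask compares the look-ahead buffer against the raw gradient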
def _multi_tensor_sgdw( def _multi_tensor_sgdw(
@ -216,6 +232,7 @@ def _multi_tensor_sgdw(
lr: float, lr: float,
dampening: float, dampening: float,
nesterov: bool, nesterov: bool,
caution: bool,
maximize: bool, maximize: bool,
has_sparse_grad: bool has_sparse_grad: bool
): ):
@ -258,6 +275,18 @@ def _multi_tensor_sgdw(
bufs.append(buf) bufs.append(buf)
if caution:
if nesterov:
# Can't do nesterov in-place if we want to compare against orig grad for caution
bufs = torch._foreach_add(device_grads, bufs, alpha=momentum)
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(bufs, device_grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
device_grads = torch._foreach_mul(bufs, masks)
else:
if nesterov: if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum) torch._foreach_add_(device_grads, bufs, alpha=momentum)
else: else: