Cautious optimizer impl plus some typing cleanup.

mars_tweak
Ross Wightman 2024-11-28 12:34:51 -08:00 committed by Ross Wightman
parent aeb1ed7a15
commit 7cf683628f
13 changed files with 521 additions and 234 deletions
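The cautious variants land behind the existing optimizer factory, registered under c-prefixed names ('cadamw', 'clion', 'csgdw', etc.) with caution=True bound as a default. A minimal usage sketch, assuming timm at this commit; the model and hyperparameters below are placeholders:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 16)  # stands in for any nn.Module

# 'cadamw' resolves to AdamWLegacy registered with defaults={'caution': True}, so each
# step applies the sign-agreement mask from 'Cautious Optimizers' (arXiv:2411.16085).
optimizer = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)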

View File

@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
assert isinstance(opt_info, OptimInfo)
lr = (1e-2,) * 4
if optimizer in ('mars',):
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4
try:

View File

@ -5,15 +5,16 @@ Hacked together by / Copyright 2021 Ross Wightman
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from fnmatch import fnmatch
import importlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from ._types import ParamsT, OptimType, OptimizerCallable
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
@ -39,11 +40,6 @@ from .sgdw import SGDW
_logger = logging.getLogger(__name__)
# Type variables
T = TypeVar('T')
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
OptimType = TypeVar('OptimType', bound='optim.Optimizer')
def _import_class(class_string: str) -> Type:
"""Dynamically import a class from a string."""
@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
raise ImportError(f"Could not import {class_string}: {e}")
class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""
def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
@dataclass(frozen=True)
class OptimInfo:
@ -76,7 +67,7 @@ class OptimInfo:
defaults: Optional default parameters for the optimizer
"""
name: str
opt_class: Union[str, Type[optim.Optimizer]]
opt_class: Union[str, OptimType]
description: str = ''
has_eps: bool = True
has_momentum: bool = False
@ -185,7 +176,7 @@ class OptimizerRegistry:
self,
name_or_info: Union[str, OptimInfo],
bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
) -> Union[OptimType, OptimizerCallable]:
"""Get the optimizer class with any default arguments applied.
This allows direct instantiation of optimizers with their default configs
@ -234,7 +225,7 @@ class OptimizerRegistry:
def create_optimizer(
self,
model_or_params: Union[nn.Module, Params],
model_or_params: Union[nn.Module, ParamsT],
opt: str,
lr: Optional[float] = None,
weight_decay: float = 0.,
@ -242,9 +233,9 @@ class OptimizerRegistry:
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> optim.Optimizer:
) -> torch.optim.Optimizer:
"""Create an optimizer instance.
Args:
@ -347,7 +338,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
sgd_optimizers = [
OptimInfo(
name='sgd',
opt_class=optim.SGD,
opt_class=torch.optim.SGD,
description='torch.optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
has_eps=False,
has_momentum=True,
@ -355,7 +346,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='momentum',
opt_class=optim.SGD,
opt_class=torch.optim.SGD,
description='torch.optim Stochastic Gradient Descent (SGD) with classical momentum',
has_eps=False,
has_momentum=True,
@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
adam_optimizers = [
OptimInfo(
name='adam',
opt_class=optim.Adam,
opt_class=torch.optim.Adam,
description='torch.optim.Adam, Adaptive Moment Estimation',
has_betas=True
),
OptimInfo(
name='adamw',
opt_class=optim.AdamW,
opt_class=torch.optim.AdamW,
description='torch.optim.AdamW, Adam with decoupled weight decay',
has_betas=True
),
@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='adamax',
opt_class=optim.Adamax,
opt_class=torch.optim.Adamax,
description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
has_betas=True
),
@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
registry.register(opt)
def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
cautious_optimizers = [
OptimInfo(
name='cadafactor',
opt_class=Adafactor,
description='Cautious Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadafactorbv',
opt_class=AdafactorBigVision,
description='Cautious Big Vision Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadamw',
opt_class=AdamWLegacy,
description='Cautious AdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cadopt',
opt_class=Adopt,
description='Cautious Adopt',
defaults={'caution': True}
),
OptimInfo(
name='cadoptw',
opt_class=Adopt,
description='Cautious AdoptW (decoupled decay)',
defaults={'decoupled': True, 'caution': True}
),
OptimInfo(
name='clamb',
opt_class=Lamb,
description='Cautious LAMB',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='claprop',
opt_class=LaProp,
description='Cautious LaProp',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='clion',
opt_class=Lion,
description='Cautious Lion',
has_eps=False,
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,
description='Cautious NAdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='crmsproptf',
opt_class=RMSpropTF,
description='Cautious TensorFlow-style RMSprop',
has_momentum=True,
defaults={'alpha': 0.9, 'caution': True}
),
OptimInfo(
name='csgdw',
opt_class=SGDW,
description='Cautious SGD with decoupled weight decay and Nesterov momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True, 'caution': True}
),
]
for opt in cautious_optimizers:
registry.register(opt)
def _register_other_optimizers(registry: OptimizerRegistry) -> None:
"""Register miscellaneous optimizers"""
other_optimizers = [
@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='adadelta',
opt_class=optim.Adadelta,
opt_class=torch.optim.Adadelta,
description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
),
OptimInfo(
name='adagrad',
opt_class=optim.Adagrad,
opt_class=torch.optim.Adagrad,
description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
defaults={'eps': 1e-8}
),
@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='rmsprop',
opt_class=optim.RMSprop,
opt_class=torch.optim.RMSprop,
description='torch.optim.RMSprop, Root Mean Square Propagation',
has_momentum=True,
defaults={'alpha': 0.9}
@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
_register_other_optimizers(default_registry)
_register_apex_optimizers(default_registry)
_register_bnb_optimizers(default_registry)
_register_cautious_optimizers(default_registry)
# Register aliases
default_registry.register_alias('nesterov', 'sgd')
@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
def get_optimizer_class(
name: str,
bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
) -> Union[OptimType, OptimizerCallable]:
"""Get optimizer class by name with option to bind default arguments.
Retrieves the optimizer class or a partial function with default arguments bound.
@ -874,7 +947,7 @@ def get_optimizer_class(
def create_optimizer_v2(
model_or_params: Union[nn.Module, Params],
model_or_params: Union[nn.Module, ParamsT],
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
@ -882,9 +955,9 @@ def create_optimizer_v2(
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> optim.Optimizer:
) -> torch.optim.Optimizer:
"""Create an optimizer instance via timm registry.
Creates and configures an optimizer with appropriate parameter groups and settings.
@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
return kwargs
def create_optimizer(args, model, filter_bias_and_bn=True):
def create_optimizer(
args,
model: Union[nn.Module, ParamsT],
filter_bias_and_bn: bool = True,
) -> torch.optim.Optimizer:
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""

View File

@ -0,0 +1,25 @@
from typing import Any, Dict, Iterable, Union, Protocol, Type
try:
from typing import TypeAlias, TypeVar
except ImportError:
from typing_extensions import TypeAlias, TypeVar
import torch
import torch.optim
try:
from torch.optim.optimizer import ParamsT
except (ImportError, TypeError):
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
OptimType = Type[torch.optim.Optimizer]
class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""
def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
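A rough illustration of how these aliases are meant to be consumed downstream; the build_opt helper is hypothetical (not part of the commit) and assumes the new module lands at timm/optim/_types.py. ParamsT types the parameter argument, OptimType a concrete optimizer class, and OptimizerCallable anything that constructs an optimizer from params, including a functools.partial with defaults bound:

from functools import partial

import torch
import torch.nn as nn

from timm.optim._types import ParamsT, OptimizerCallable  # assumed module path


def build_opt(params: ParamsT, factory: OptimizerCallable) -> torch.optim.Optimizer:
    # Any optimizer class, or a partial with defaults pre-bound, satisfies the protocol.
    return factory(params, lr=1e-3)


opt = build_opt(nn.Linear(4, 4).parameters(), partial(torch.optim.AdamW, weight_decay=0.05))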

View File

@ -10,8 +10,12 @@ Original header/copyright below.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import math
from typing import Optional, Tuple
import torch
from ._types import ParamsT
class Adafactor(torch.optim.Optimizer):
@ -26,33 +30,33 @@ class Adafactor(torch.optim.Optimizer):
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): external learning rate (default: None)
eps (tuple[float, float]): regularization constants for square gradient
and parameter scale respectively (default: (1e-30, 1e-3))
clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
beta1 (float): coefficient used for computing running averages of gradient (default: None)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
warmup_init (bool): time-dependent learning rate computation depends on
whether warm-up initialization is being used (default: False)
Args:
params: iterable of parameters to optimize or dicts defining parameter groups
lr: external learning rate
eps: regularization constant for square gradient
eps_scale: regularization constant for parameter scale
clip_threshold: threshold of root-mean-square of final gradient update
decay_rate: coefficient used to compute running averages of square gradient
beta1: coefficient used for computing running averages of gradient
weight_decay: weight decay
scale_parameter: if True, learning rate is scaled by root-mean-square of parameter
warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used
"""
def __init__(
self,
params,
lr=None,
eps=1e-30,
eps_scale=1e-3,
clip_threshold=1.0,
decay_rate=-0.8,
betas=None,
weight_decay=0.0,
scale_parameter=True,
warmup_init=False,
min_dim_size_to_factor=32,
params: ParamsT,
lr: Optional[float] = None,
eps: float = 1e-30,
eps_scale: float = 1e-3,
clip_threshold: float = 1.0,
decay_rate: float = -0.8,
betas: Optional[Tuple[float, float]] = None,
weight_decay: float = 0.0,
scale_parameter: bool = True,
warmup_init: bool = False,
min_dim_size_to_factor: int = 16,
caution: bool = False,
):
relative_step = not lr
if warmup_init and not relative_step:
@ -71,9 +75,16 @@ class Adafactor(torch.optim.Optimizer):
relative_step=relative_step,
warmup_init=warmup_init,
min_dim_size_to_factor=min_dim_size_to_factor,
caution=caution,
)
super(Adafactor, self).__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('min_dim_size_to_factor', 32)
@staticmethod
def _get_lr(param_group, param_state):
if param_group['relative_step']:
@ -86,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
return param_group['lr']
@staticmethod
def _get_options(param_group, param_shape, min_size_to_factor=32):
def _get_options(param_group, param_shape, min_size_to_factor=16):
use_first_moment = param_group['beta1'] is not None
factored = None
ndim = len(param_shape)
@ -98,7 +109,7 @@ class Adafactor(torch.optim.Optimizer):
# nD convs in torch are ND + 2 dim weights with leading in/out chs
factored = 0, 1
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
# if the criteria above didn't match, test trailing dims for eligibility
# if the criteria above didn't match, test trailing dims for eligibility as per original impl
factored = ndim - 2, ndim - 1
return factored, use_first_moment
@ -113,7 +124,6 @@ class Adafactor(torch.optim.Optimizer):
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@ -201,7 +211,13 @@ class Adafactor(torch.optim.Optimizer):
if use_first_moment:
exp_avg = state['exp_avg']
exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
update = exp_avg
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update = exp_avg * mask
else:
update = exp_avg
if group['weight_decay'] != 0:
p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
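The caution branch above is the same three-line recipe this commit repeats across optimizers: zero out update components whose sign disagrees with the current gradient, then divide by the (clamped) mean agreement fraction so the overall update scale is roughly preserved. A standalone sketch of just that mask math on dummy tensors:

import torch

torch.manual_seed(0)
exp_avg = torch.randn(6)  # stand-in for the momentum / update direction
grad = torch.randn(6)     # stand-in for the current gradient

# Caution mask per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085):
# keep components whose sign agrees with the gradient...
mask = (exp_avg * grad > 0).to(grad.dtype)
# ...then rescale by the mean agreement fraction, clamped to avoid division by ~0
mask.div_(mask.mean().clamp_(min=1e-3))
update = exp_avg * mask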

View File

@ -6,13 +6,14 @@ Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
Adaptation and PyTorch modifications by Ross Wightman
"""
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.optim import Optimizer
from ._types import ParamsT
def _get_scalar_dtype():
"""Get the scalar dtype that the optimizer uses for state"""
@ -54,9 +55,9 @@ class AdafactorBigVision(Optimizer):
def __init__(
self,
params,
params: ParamsT,
lr: float = 1.0,
min_dim_size_to_factor: int = 32,
min_dim_size_to_factor: int = 16,
decay_rate: float = 0.8,
decay_offset: int = 0,
beta2_cap: float = 0.999,
@ -66,6 +67,7 @@ class AdafactorBigVision(Optimizer):
weight_decay: float = 0.0,
clipping_threshold: Optional[float] = None,
unscaled_wd: bool = False,
caution: bool = False,
*,
foreach: Optional[bool] = False,
):
@ -91,6 +93,7 @@ class AdafactorBigVision(Optimizer):
weight_decay=weight_decay,
clipping_threshold=clipping_threshold,
unscaled_wd=unscaled_wd,
caution=caution,
foreach=foreach,
)
super().__init__(params, defaults)
@ -98,6 +101,7 @@ class AdafactorBigVision(Optimizer):
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('foreach', None)
for p in group['params']:
p_state = self.state.get(p, {})
@ -192,6 +196,7 @@ class AdafactorBigVision(Optimizer):
momentum_dtype=group['momentum_dtype'],
clipping_threshold=group['clipping_threshold'],
unscaled_wd=group['unscaled_wd'],
caution=group['caution'],
)
return loss
@ -216,6 +221,7 @@ def _single_tensor_adafactor(
momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float],
unscaled_wd: bool,
caution: bool,
):
for i, param in enumerate(params):
grad = grads[i]
@ -267,6 +273,12 @@ def _single_tensor_adafactor(
exp_avg.lerp_(update, 1 - momentum) # ema
update = exp_avg.clone()
if caution:
# apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
# Scale by learning rate
update.mul_(lr)
@ -302,6 +314,7 @@ def _multi_tensor_adafactor(
momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float],
unscaled_wd: bool,
caution: bool,
):
# FIXME TODO
assert False, 'multi-tensor fn (foreach=True) not implemented yet'

View File

@ -4,49 +4,45 @@ Impl copied from PyTorch master
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
"""
import math
from typing import Tuple
import torch
from torch.optim.optimizer import Optimizer
from ._types import ParamsT
class AdamWLegacy(Optimizer):
r"""Implements AdamW algorithm.
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
References:
- Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
- Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
Args:
params: iterable of parameters to optimize or dicts defining parameter groups
lr: learning rate
betas: coefficients used for computing running averages of gradient and its square
eps: term added to the denominator to improve numerical stability
weight_decay: weight decay coefficient
amsgrad: whether to use the AMSGrad variant of this algorithm
from the paper `On the Convergence of Adam and Beyond`
caution: apply the caution update mask as per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085)
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
params: ParamsT,
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
amsgrad: bool = False,
caution: bool = False,
):
# NOTE: deprecated in favour of builtin torch.optim.AdamW
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@ -61,6 +57,7 @@ class AdamWLegacy(Optimizer):
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
caution=caution,
)
super(AdamWLegacy, self).__init__(params, defaults)
@ -68,6 +65,7 @@ class AdamWLegacy(Optimizer):
super(AdamWLegacy, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
group.setdefault('caution', False)
@torch.no_grad()
def step(self, closure=None):
@ -131,6 +129,12 @@ class AdamWLegacy(Optimizer):
step_size = group['lr'] / bias_correction1
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
p.addcdiv_(exp_avg, denom, value=-step_size)
return loss

View File

@ -10,16 +10,15 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
year = {2024}
}
"""
from typing import cast, Callable, List, Optional, Tuple, Union
from typing import cast, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from ._types import ParamsT
__all__ = ["Adopt", "adopt"]
def _view_as_real(params, *state_and_grads):
@ -60,7 +59,7 @@ class Adopt(Optimizer):
"""
def __init__(
self,
params,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.9999),
eps: float = 1e-6,
@ -68,7 +67,8 @@ class Adopt(Optimizer):
weight_decay: float = 0.0,
decoupled: bool = False,
*,
foreach: Optional[bool] = None,
caution: bool = False,
foreach: Optional[bool] = False,
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
@ -98,6 +98,7 @@ class Adopt(Optimizer):
weight_decay=weight_decay,
clip_exp=clip_exp,
decoupled=decoupled,
caution=caution,
maximize=maximize,
foreach=foreach,
capturable=capturable,
@ -105,7 +106,6 @@ class Adopt(Optimizer):
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
@ -114,6 +114,7 @@ class Adopt(Optimizer):
group.setdefault("capturable", False)
group.setdefault("differentiable", False)
group.setdefault("clip_exp", None)
group.setdefault("caution", False)
for p in group["params"]:
p_state = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@ -223,6 +224,7 @@ class Adopt(Optimizer):
clip_exp=group["clip_exp"],
decoupled=group["decoupled"],
eps=group["eps"],
caution=group["caution"],
maximize=group["maximize"],
foreach=group["foreach"],
capturable=group["capturable"],
@ -251,6 +253,7 @@ def _single_tensor_adopt(
clip_exp: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
maximize: bool,
capturable: bool,
differentiable: bool,
@ -306,6 +309,13 @@ def _single_tensor_adopt(
normed_grad.clamp_(-clip_val, clip_val)
exp_avg.lerp_(normed_grad, 1 - beta1)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
param.add_(exp_avg, alpha=-lr)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@ -328,6 +338,7 @@ def _multi_tensor_adopt(
clip_exp: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
maximize: bool,
capturable: bool,
differentiable: bool,
@ -403,6 +414,7 @@ def _multi_tensor_adopt(
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
if clip_exp is not None:
@ -411,6 +423,16 @@ def _multi_tensor_adopt(
torch._foreach_minimum_(normed_grad, clip_val)
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(device_exp_avgs, device_grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
@ -440,6 +462,7 @@ def adopt(
clip_exp: Optional[float],
decoupled: bool,
eps: float,
caution: bool,
maximize: bool,
):
r"""Functional API that performs ADOPT algorithm computation.
@ -477,6 +500,7 @@ def adopt(
clip_exp=clip_exp,
decoupled=decoupled,
eps=eps,
caution=caution,
maximize=maximize,
capturable=capturable,
differentiable=differentiable,

View File

@ -52,50 +52,48 @@ Modifications Copyright 2021 Ross Wightman
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import Optional, Tuple
import torch
from torch.optim import Optimizer
from ._types import ParamsT
class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
LAMB was proposed in:
- Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
lr: Learning rate
betas: Coefficients used for computing running averages of gradient and its norm.
eps: Term added to the denominator to improve numerical stability.
weight_decay: Weight decay
grad_averaging: Whether to apply (1-beta2) to the gradient when calculating running averages of the gradient.
max_grad_norm: Value used to clip global grad norm.
trust_clip: Enable LAMBC trust ratio clipping.
always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
caution: Apply the caution update mask as per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085).
"""
def __init__(
self,
params,
lr=1e-3,
bias_correction=True,
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=0.01,
grad_averaging=True,
max_grad_norm=1.0,
trust_clip=False,
always_adapt=False,
params: ParamsT,
lr: float = 1e-3,
bias_correction: bool = True,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.01,
grad_averaging: bool = True,
max_grad_norm: Optional[float] = 1.0,
trust_clip: bool = False,
always_adapt: bool = False,
caution: bool = False,
):
defaults = dict(
lr=lr,
@ -107,9 +105,15 @@ class Lamb(Optimizer):
max_grad_norm=max_grad_norm,
trust_clip=trust_clip,
always_adapt=always_adapt,
caution=caution,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
def _get_clip_grad_norm(self):
max_grad_norm = self.defaults['max_grad_norm']
if max_grad_norm is None:
@ -187,6 +191,12 @@ class Lamb(Optimizer):
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
update = (exp_avg / bias_correction1).div_(denom)
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
weight_decay = group['weight_decay']
if weight_decay != 0:
update.add_(p, alpha=weight_decay)

View File

@ -12,9 +12,13 @@ Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs
}
"""
from typing import Tuple
from torch.optim import Optimizer
import torch
from ._types import ParamsT
class LaProp(Optimizer):
""" LaProp Optimizer
@ -23,11 +27,12 @@ class LaProp(Optimizer):
"""
def __init__(
self,
params,
lr=4e-4,
betas=(0.9, 0.999),
eps=1e-15,
weight_decay=0,
params: ParamsT,
lr: float = 4e-4,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-15,
weight_decay: float = 0.,
caution: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@ -42,6 +47,7 @@ class LaProp(Optimizer):
betas=betas,
eps=eps,
weight_decay=weight_decay,
caution=caution,
)
super(LaProp, self).__init__(params, defaults)
@ -101,7 +107,14 @@ class LaProp(Optimizer):
step_of_this_grad = grad / denom
exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)
if group['caution']:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
p.add_(exp_avg, alpha=-step_size)
if group['weight_decay'] != 0:
p.add_(p, alpha=-group['weight_decay'])

View File

@ -16,33 +16,35 @@ Original Impl: https://github.com/google/automl/tree/master/lion
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import List
from typing import List, Optional, Tuple
import torch
from torch.optim.optimizer import Optimizer
from ._types import ParamsT
class Lion(Optimizer):
r"""Implements Lion algorithm."""
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0.0,
maximize=False,
foreach=None,
params: ParamsT,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
caution: bool = False,
maximize: bool = False,
foreach: Optional[bool] = None,
):
"""Initialize the hyperparameters.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-4)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.99))
weight_decay (float, optional): weight decay coefficient (default: 0)
params: iterable of parameters to optimize or dicts defining parameter groups
lr: learning rate
betas: coefficients used for computing running averages of gradient and its square
weight_decay: weight decay coefficient
caution: apply the caution update mask as per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085)
"""
if not 0.0 <= lr:
@ -55,6 +57,7 @@ class Lion(Optimizer):
lr=lr,
betas=betas,
weight_decay=weight_decay,
caution=caution,
foreach=foreach,
maximize=maximize,
)
@ -63,6 +66,7 @@ class Lion(Optimizer):
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('maximize', False)
group.setdefault('foreach', None)
@ -71,8 +75,7 @@ class Lion(Optimizer):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
closure: A closure that reevaluates the model and returns the loss.
Returns:
the loss.
@ -112,6 +115,7 @@ class Lion(Optimizer):
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
caution=group['caution'],
maximize=group['maximize'],
foreach=group['foreach'],
)
@ -132,6 +136,7 @@ def lion(
beta2: float,
lr: float,
weight_decay: float,
caution: bool,
):
r"""Functional API that performs Lion algorithm computation.
"""
@ -155,6 +160,7 @@ def lion(
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
caution=caution,
maximize=maximize,
)
@ -168,6 +174,7 @@ def _single_tensor_lion(
beta2: float,
lr: float,
weight_decay: float,
caution: bool,
maximize: bool,
):
for i, param in enumerate(params):
@ -183,8 +190,15 @@ def _single_tensor_lion(
param.mul_(1 - lr * weight_decay)
# Weight update
update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
param.add_(torch.sign(update), alpha=-lr)
update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (update * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
update.mul_(mask)
param.add_(update, alpha=-lr)
# Decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2)
@ -199,6 +213,7 @@ def _multi_tensor_lion(
beta2: float,
lr: float,
weight_decay: float,
caution: bool,
maximize: bool,
):
if len(params) == 0:
@ -217,8 +232,17 @@ def _multi_tensor_lion(
# Weight update
updates = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(updates, grads, alpha=1 - beta1)
updates = [u.sign_() for u in updates]
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(updates, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(updates, masks)
updates = [u.sign() for u in updates]
torch._foreach_add_(params, updates, alpha=-lr)
# Decay the momentum running average coefficient

View File

@ -5,44 +5,43 @@ Based on simplified algorithm in https://github.com/mlcommons/algorithmic-effici
Added multi-tensor (foreach) path.
"""
import math
from typing import List, Optional
from typing import List, Optional, Tuple
import torch
from torch import Tensor
from ._types import ParamsT
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
""" Implements NAdamW algorithm.
See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
the NAdam algorithm (there is also a comment in the code which highlights
the only difference of NAdamW and AdamW).
For further details regarding the algorithm we refer to
`Decoupled Weight Decay Regularization`_.
See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
the NAdam algorithm (there is also a comment in the code which highlights
the only difference of NAdamW and AdamW).
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
For further details regarding the algorithm we refer to
- Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
- On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
Args:
params: iterable of parameters to optimize or dicts defining parameter groups
lr: learning rate
betas: coefficients used for computing running averages of gradient and its square
eps: term added to the denominator to improve numerical stability
weight_decay: weight decay coefficient
caution: enable the caution update mask as per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085)
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
params: ParamsT,
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
caution: bool = False,
maximize: bool = False,
foreach: Optional[bool] = None,
capturable: bool = False,
@ -62,6 +61,7 @@ class NAdamW(torch.optim.Optimizer):
betas=betas,
eps=eps,
weight_decay=weight_decay,
caution=caution,
foreach=foreach,
maximize=maximize,
capturable=capturable,
@ -71,11 +71,12 @@ class NAdamW(torch.optim.Optimizer):
def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]['step'])
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
for group in self.param_groups:
group.setdefault('caution', False)
@torch.no_grad()
def step(self, closure=None):
@ -133,6 +134,7 @@ class NAdamW(torch.optim.Optimizer):
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
caution=group['caution'],
maximize=group['maximize'],
capturable=group['capturable'],
)
@ -154,6 +156,7 @@ def nadamw(
lr: float,
weight_decay: float,
eps: float,
caution: bool,
maximize: bool,
) -> None:
r"""Functional API that performs NAdamW algorithm computation.
@ -183,6 +186,7 @@ def nadamw(
lr=lr,
weight_decay=weight_decay,
eps=eps,
caution=caution,
maximize=maximize,
capturable=capturable,
)
@ -200,6 +204,7 @@ def _single_tensor_nadamw(
lr: float,
weight_decay: float,
eps: float,
caution: bool,
maximize: bool,
capturable: bool
):
@ -238,6 +243,14 @@ def _single_tensor_nadamw(
exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
# FIXME not 100% sure if this remains capturable?
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg.mul_(mask)
param.addcdiv_(exp_avg, denom)
else:
step = step_t.item()
@ -246,11 +259,17 @@ def _single_tensor_nadamw(
step_size = lr / bias_correction1
bias_correction2_sqrt = math.sqrt(bias_correction2)
# Only difference between NAdamW and AdamW in this implementation.
# Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
# The official PyTorch implementation of NAdam uses a different algorithm.
exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg.mul_(mask)
param.addcdiv_(exp_avg, denom, value=-step_size)
@ -266,6 +285,7 @@ def _multi_tensor_nadamw(
lr: float,
weight_decay: float,
eps: float,
caution: bool,
maximize: bool,
capturable: bool,
):
@ -322,12 +342,22 @@ def _multi_tensor_nadamw(
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
exp_avg_sq_sqrt,
torch._foreach_mul(bias_correction2_sqrt, step_size)
)
eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)] # capturable?
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(exp_avgs, masks)
torch._foreach_addcdiv_(params, exp_avgs, denom)
else:
bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
@ -337,7 +367,7 @@ def _multi_tensor_nadamw(
bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2]
# Only difference between NAdamW and AdamW in this implementation.
# Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
# The official PyTorch implementation of NAdam uses a different algorithm.
exp_avgs = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
@ -346,4 +376,13 @@ def _multi_tensor_nadamw(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(exp_avgs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
torch._foreach_mul_(exp_avgs, masks)
torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)

View File

@ -10,6 +10,8 @@ Modifications Copyright 2021 Ross Wightman
import torch
from torch.optim import Optimizer
from ._types import ParamsT
class RMSpropTF(Optimizer):
"""Implements RMSprop algorithm (TensorFlow style epsilon)
@ -28,34 +30,31 @@ class RMSpropTF(Optimizer):
The centered version first appears in `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-2)
momentum (float, optional): momentum factor (default: 0)
alpha (float, optional): smoothing (decay) constant (default: 0.9)
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-10)
centered (bool, optional) : if ``True``, compute the centered RMSProp,
the gradient is normalized by an estimation of its variance
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
update as per defaults in Tensorflow
Args:
params: iterable of parameters to optimize or dicts defining parameter groups
lr: learning rate
momentum: momentum factor
alpha: smoothing (decay) constant
eps: term added to the denominator to improve numerical stability
centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
weight_decay: weight decay (L2 penalty) (default: 0)
decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101
lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
caution: apply the caution update mask as per 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085)
"""
def __init__(
self,
params,
lr=1e-2,
alpha=0.9,
eps=1e-10,
weight_decay=0,
momentum=0.,
centered=False,
decoupled_decay=False,
lr_in_momentum=True,
params: ParamsT,
lr: float = 1e-2,
alpha: float = 0.9,
eps: float = 1e-10,
weight_decay: float = 0,
momentum: float = 0.,
centered: bool = False,
decoupled_decay: bool = False,
lr_in_momentum: bool = True,
caution: bool = False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@ -77,6 +76,7 @@ class RMSpropTF(Optimizer):
weight_decay=weight_decay,
decoupled_decay=decoupled_decay,
lr_in_momentum=lr_in_momentum,
caution=caution,
)
super(RMSpropTF, self).__init__(params, defaults)
@ -85,6 +85,7 @@ class RMSpropTF(Optimizer):
for group in self.param_groups:
group.setdefault('momentum', 0)
group.setdefault('centered', False)
group.setdefault('caution', False)
@torch.no_grad()
def step(self, closure=None):
@ -142,13 +143,25 @@ class RMSpropTF(Optimizer):
if group['momentum'] > 0:
buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer
buf.mul_(group['momentum'])
def _apply_caution(_m, _g):
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (_m * _g > 0).to(_g.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
return _m * mask
if group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
# Tensorflow accumulates the LR scaling in the momentum buffer
buf.addcdiv_(grad, avg, value=group['lr'])
if group['caution']:
buf = _apply_caution(buf, grad)
p.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group['momentum']).addcdiv_(grad, avg)
buf.addcdiv_(grad, avg)
if group['caution']:
buf = _apply_caution(buf, grad)
p.add_(buf, alpha=-group['lr'])
else:
p.addcdiv_(grad, avg, value=-group['lr'])

View File

@ -1,4 +1,5 @@
from functools import update_wrapper, wraps
from typing import List, Optional
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
@ -8,7 +9,7 @@ try:
except ImportError:
has_recent_pt = False
from typing import List, Optional
from ._types import ParamsT
__all__ = ['SGDW', 'sgdw']
@ -16,13 +17,14 @@ __all__ = ['SGDW', 'sgdw']
class SGDW(Optimizer):
def __init__(
self,
params,
lr=1e-3,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
params: ParamsT,
lr: float = 1e-3,
momentum: float = 0.,
dampening: float = 0.,
weight_decay: float = 0.,
nesterov: bool = False,
*,
caution: bool = False,
maximize: bool = False,
foreach: Optional[bool] = None,
differentiable: bool = False,
@ -40,6 +42,7 @@ class SGDW(Optimizer):
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
caution=caution,
maximize=maximize,
foreach=foreach,
differentiable=differentiable,
@ -51,18 +54,19 @@ class SGDW(Optimizer):
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
group.setdefault('nesterov', False)
group.setdefault('maximize', False)
group.setdefault('foreach', None)
group.setdefault('differentiable', False)
def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
has_sparse_grad = False
for p in group['params']:
if p.grad is not None:
params_with_grad.append(p)
d_p_list.append(p.grad)
grads.append(p.grad)
if p.grad.is_sparse:
has_sparse_grad = True
@ -91,20 +95,21 @@ class SGDW(Optimizer):
for group in self.param_groups:
params_with_grad = []
d_p_list = []
grads = []
momentum_buffer_list = []
has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)
has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list)
sgdw(
params_with_grad,
d_p_list,
grads,
momentum_buffer_list,
weight_decay=group['weight_decay'],
momentum=group['momentum'],
lr=group['lr'],
dampening=group['dampening'],
nesterov=group['nesterov'],
caution=group['caution'],
maximize=group['maximize'],
has_sparse_grad=has_sparse_grad,
foreach=group['foreach'],
@ -120,7 +125,7 @@ class SGDW(Optimizer):
def sgdw(
params: List[Tensor],
d_p_list: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
@ -132,6 +137,7 @@ def sgdw(
lr: float,
dampening: float,
nesterov: bool,
caution: bool,
maximize: bool
):
r"""Functional API that performs SGD algorithm computation.
@ -159,13 +165,14 @@ def sgdw(
func(
params,
d_p_list,
grads,
momentum_buffer_list,
weight_decay=weight_decay,
momentum=momentum,
lr=lr,
dampening=dampening,
nesterov=nesterov,
caution=caution,
has_sparse_grad=has_sparse_grad,
maximize=maximize,
)
@ -173,7 +180,7 @@ def sgdw(
def _single_tensor_sgdw(
params: List[Tensor],
d_p_list: List[Tensor],
grads: List[Tensor],
momentum_buffer_list: List[Optional[Tensor]],
*,
weight_decay: float,
@ -181,11 +188,12 @@ def _single_tensor_sgdw(
lr: float,
dampening: float,
nesterov: bool,
caution: bool,
maximize: bool,
has_sparse_grad: bool
):
for i, param in enumerate(params):
d_p = d_p_list[i] if not maximize else -d_p_list[i]
grad = grads[i] if not maximize else -grads[i]
param.mul_(1. - lr * weight_decay)
@ -193,17 +201,25 @@ def _single_tensor_sgdw(
buf = momentum_buffer_list[i]
if buf is None:
buf = torch.clone(d_p).detach()
buf = torch.clone(grad).detach()
momentum_buffer_list[i] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
buf.mul_(momentum).add_(grad, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
if caution:
if nesterov:
buf = grad.add(buf, alpha=momentum)
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (buf * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
grad = buf * mask
else:
d_p = buf
if nesterov:
grad = grad.add(buf, alpha=momentum)
else:
grad = buf
param.add_(d_p, alpha=-lr)
param.add_(grad, alpha=-lr)
def _multi_tensor_sgdw(
@ -216,6 +232,7 @@ def _multi_tensor_sgdw(
lr: float,
dampening: float,
nesterov: bool,
caution: bool,
maximize: bool,
has_sparse_grad: bool
):
@ -258,10 +275,22 @@ def _multi_tensor_sgdw(
bufs.append(buf)
if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum)
if caution:
if nesterov:
# Can't do nesterov in-place if we want to compare against orig grad for caution
bufs = torch._foreach_add(device_grads, bufs, alpha=momentum)
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(bufs, device_grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
device_grads = torch._foreach_mul(bufs, masks)
else:
device_grads = bufs
if nesterov:
torch._foreach_add_(device_grads, bufs, alpha=momentum)
else:
device_grads = bufs
if not device_has_sparse_grad:
torch._foreach_add_(device_params, device_grads, alpha=-lr)
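For intuition on the multi-tensor branches added throughout this commit, the _foreach_* caution path is just the single-tensor mask applied per tensor in the list. A quick sanity sketch on dummy tensors (not part of the commit) comparing the two formulations:

import torch

torch.manual_seed(0)
bufs = [torch.randn(4), torch.randn(3)]   # stand-ins for momentum buffers
grads = [torch.randn(4), torch.randn(3)]  # stand-ins for gradients

# foreach-style caution masks, mirroring the multi-tensor branch above
masks = torch._foreach_mul(bufs, grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
mask_scale = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)

# per-tensor reference, matching the single-tensor branch
for buf, grad, mask in zip(bufs, grads, masks):
    ref = (buf * grad > 0).to(grad.dtype)
    ref = ref / ref.mean().clamp(min=1e-3)
    assert torch.allclose(mask, ref)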