Cautious optimizer impl plus some typing cleanup.
parent aeb1ed7a15
commit 7cf683628f
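The recurring change across the diffs below is the 'caution' mask from 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085): each optimizer's momentum-style update is zeroed wherever its sign disagrees with the current gradient, and the surviving entries are rescaled by the clamped mean of the mask. A minimal single-tensor sketch of that pattern (tensor names are illustrative, not taken from the source):

import torch

def apply_caution(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Keep only the components of `update` whose sign agrees with `grad`,
    # then rescale by the mean of the mask so the overall step size is roughly preserved.
    mask = (update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return update * mask

exp_avg = torch.randn(10)   # e.g. a momentum buffer
grad = torch.randn(10)      # current gradient
cautious_update = apply_caution(exp_avg, grad)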
@@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
     assert isinstance(opt_info, OptimInfo)

     lr = (1e-2,) * 4
-    if optimizer in ('mars',):
+    if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4

     try:
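The cautious names added to this lower-lr list can also be exercised directly; a rough smoke-test sketch, assuming Lamb is re-exported from timm.optim and picks up the `caution` flag added later in this diff:

import torch
from timm.optim import Lamb  # assumed re-export; the `caution` kwarg is introduced in this commit

w = torch.nn.Parameter(torch.randn(8))
opt = Lamb([w], lr=1e-3, caution=True)
loss = (w ** 2).sum()
loss.backward()
opt.step()
opt.zero_grad()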
@@ -5,15 +5,16 @@ Hacked together by / Copyright 2021 Ross Wightman
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib

 import torch
 import torch.nn as nn
-import torch.optim as optim
+import torch.optim

 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._types import ParamsT, OptimType, OptimizerCallable
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
@@ -39,11 +40,6 @@ from .sgdw import SGDW

 _logger = logging.getLogger(__name__)

-# Type variables
-T = TypeVar('T')
-Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
-OptimType = TypeVar('OptimType', bound='optim.Optimizer')
-

 def _import_class(class_string: str) -> Type:
     """Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
         raise ImportError(f"Could not import {class_string}: {e}")


-class OptimizerCallable(Protocol):
-    """Protocol for optimizer constructor signatures."""
-
-    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
-
-
 @dataclass(frozen=True)
 class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
         defaults: Optional default parameters for the optimizer
     """
     name: str
-    opt_class: Union[str, Type[optim.Optimizer]]
+    opt_class: Union[str, OptimType]
     description: str = ''
     has_eps: bool = True
     has_momentum: bool = False
@@ -185,7 +176,7 @@ class OptimizerRegistry:
            self,
            name_or_info: Union[str, OptimInfo],
            bind_defaults: bool = True,
-    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    ) -> Union[OptimType, OptimizerCallable]:
         """Get the optimizer class with any default arguments applied.

         This allows direct instantiation of optimizers with their default configs
@@ -234,7 +225,7 @@ class OptimizerRegistry:

     def create_optimizer(
             self,
-            model_or_params: Union[nn.Module, Params],
+            model_or_params: Union[nn.Module, ParamsT],
             opt: str,
             lr: Optional[float] = None,
             weight_decay: float = 0.,
@@ -242,9 +233,9 @@ class OptimizerRegistry:
             foreach: Optional[bool] = None,
             weight_decay_exclude_1d: bool = True,
             layer_decay: Optional[float] = None,
-            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
             **kwargs: Any,
-    ) -> optim.Optimizer:
+    ) -> torch.optim.Optimizer:
         """Create an optimizer instance.

         Args:
@@ -347,7 +338,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
     sgd_optimizers = [
         OptimInfo(
             name='sgd',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
@@ -355,7 +346,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='momentum',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
     adam_optimizers = [
         OptimInfo(
             name='adam',
-            opt_class=optim.Adam,
+            opt_class=torch.optim.Adam,
             description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
-            opt_class=optim.AdamW,
+            opt_class=torch.optim.AdamW,
             description='torch.optim.AdamW, Adam with decoupled weight decay',
             has_betas=True
         ),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adamax',
-            opt_class=optim.Adamax,
+            opt_class=torch.optim.Adamax,
             description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)


+def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
+    cautious_optimizers = [
+        OptimInfo(
+            name='cadafactor',
+            opt_class=Adafactor,
+            description='Cautious Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Cautious Big Vision Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadamw',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadopt',
+            opt_class=Adopt,
+            description='Cautious Adopt',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadoptw',
+            opt_class=Adopt,
+            description='Cautious AdoptW (decoupled decay)',
+            defaults={'decoupled': True, 'caution': True}
+        ),
+        OptimInfo(
+            name='clamb',
+            opt_class=Lamb,
+            description='Cautious LAMB',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='claprop',
+            opt_class=LaProp,
+            description='Cautious LaProp',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='clion',
+            opt_class=Lion,
+            description='Cautious Lion',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cnadamw',
+            opt_class=NAdamW,
+            description='Cautious NAdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='crmsproptf',
+            opt_class=RMSpropTF,
+            description='Cautious TensorFlow-style RMSprop',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'caution': True}
+        ),
+        OptimInfo(
+            name='csgdw',
+            opt_class=SGDW,
+            description='Cautious SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True}
+        ),
+    ]
+    for opt in cautious_optimizers:
+        registry.register(opt)
+
+
 def _register_other_optimizers(registry: OptimizerRegistry) -> None:
     """Register miscellaneous optimizers"""
     other_optimizers = [
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adadelta',
-            opt_class=optim.Adadelta,
+            opt_class=torch.optim.Adadelta,
             description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
-            opt_class=optim.Adagrad,
+            opt_class=torch.optim.Adagrad,
             description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='rmsprop',
-            opt_class=optim.RMSprop,
+            opt_class=torch.optim.RMSprop,
             description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
     _register_other_optimizers(default_registry)
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
+    _register_cautious_optimizers(default_registry)

     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
         name: str,
         bind_defaults: bool = True,
-) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+) -> Union[OptimType, OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.

     Retrieves the optimizer class or a partial function with default arguments bound.
@@ -874,7 +947,7 @@ def get_optimizer_class(


 def create_optimizer_v2(
-        model_or_params: Union[nn.Module, Params],
+        model_or_params: Union[nn.Module, ParamsT],
         opt: str = 'sgd',
         lr: Optional[float] = None,
         weight_decay: float = 0.,
@@ -882,9 +955,9 @@ def create_optimizer_v2(
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
-        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
-) -> optim.Optimizer:
+) -> torch.optim.Optimizer:
     """Create an optimizer instance via timm registry.

     Creates and configures an optimizer with appropriate parameter groups and settings.
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
     return kwargs


-def create_optimizer(args, model, filter_bias_and_bn=True):
+def create_optimizer(
+        args,
+        model: Union[nn.Module, ParamsT],
+        filter_bias_and_bn: bool = True,
+) -> torch.optim.Optimizer:
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """
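With the registrations above, the cautious variants resolve by name through the factory functions defined in this file; a usage sketch (the model is a placeholder, and the timm.optim re-exports are assumed):

import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class

model = nn.Linear(16, 4)  # placeholder model

# 'cadamw' maps to AdamWLegacy with defaults={'caution': True} per the registry above
opt = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)

# bind_defaults=True returns a partial with the registered defaults bound,
# roughly partial(SGDW, nesterov=True, caution=True) for 'csgdw'
CautiousSGDW = get_optimizer_class('csgdw')
opt2 = CautiousSGDW(model.parameters(), lr=0.1, momentum=0.9)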
@@ -0,0 +1,25 @@
+from typing import Any, Dict, Iterable, Union, Protocol, Type
+try:
+    from typing import TypeAlias, TypeVar
+except ImportError:
+    from typing_extensions import TypeAlias, TypeVar
+
+import torch
+import torch.optim
+
+try:
+    from torch.optim.optimizer import ParamsT
+except (ImportError, TypeError):
+    ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
+
+
+OptimType = Type[torch.optim.Optimizer]
+
+
+class OptimizerCallable(Protocol):
+    """Protocol for optimizer constructor signatures."""
+
+    def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
+
+
+__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
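The new aliases let downstream signatures accept either a concrete optimizer class or a factory callable; a sketch of a consumer (the helper itself is hypothetical, only the imported names come from the new module):

from typing import Union

import torch

from timm.optim._types import ParamsT, OptimType, OptimizerCallable


def build(opt_or_factory: Union[OptimType, OptimizerCallable], params: ParamsT, **kwargs) -> torch.optim.Optimizer:
    # Accepts both a concrete optimizer class (OptimType) and any callable matching
    # the OptimizerCallable protocol, e.g. a functools.partial with bound defaults.
    return opt_or_factory(params, **kwargs)


model = torch.nn.Linear(4, 2)
optimizer = build(torch.optim.AdamW, model.parameters(), lr=1e-3)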
@@ -10,8 +10,12 @@ Original header/copyright below.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 import math
+from typing import Optional, Tuple
+
+import torch
+
+from ._types import ParamsT


 class Adafactor(torch.optim.Optimizer):
@@ -26,33 +30,33 @@ class Adafactor(torch.optim.Optimizer):
     To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
     `relative_step=False`.

-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
-        lr (float, optional): external learning rate (default: None)
-        eps (tuple[float, float]): regularization constants for square gradient
-            and parameter scale respectively (default: (1e-30, 1e-3))
-        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
-        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
-        beta1 (float): coefficient used for computing running averages of gradient (default: None)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
-        warmup_init (bool): time-dependent learning rate computation depends on
-            whether warm-up initialization is being used (default: False)
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: external learning rate
+        eps: regularization constant for the square gradient
+        eps_scale: regularization constant for the parameter scale
+        clip_threshold: threshold of root-mean-square of final gradient update
+        decay_rate: coefficient used to compute running averages of square gradient
+        beta1: coefficient used for computing running averages of gradient
+        weight_decay: weight decay
+        scale_parameter: if True, learning rate is scaled by root-mean-square of parameter
+        warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used
     """

     def __init__(
             self,
-            params,
-            lr=None,
-            eps=1e-30,
-            eps_scale=1e-3,
-            clip_threshold=1.0,
-            decay_rate=-0.8,
-            betas=None,
-            weight_decay=0.0,
-            scale_parameter=True,
-            warmup_init=False,
-            min_dim_size_to_factor=32,
+            params: ParamsT,
+            lr: Optional[float] = None,
+            eps: float = 1e-30,
+            eps_scale: float = 1e-3,
+            clip_threshold: float = 1.0,
+            decay_rate: float = -0.8,
+            betas: Optional[Tuple[float, float]] = None,
+            weight_decay: float = 0.0,
+            scale_parameter: bool = True,
+            warmup_init: bool = False,
+            min_dim_size_to_factor: int = 16,
+            caution: bool = False,
     ):
         relative_step = not lr
         if warmup_init and not relative_step:
@@ -71,9 +75,16 @@ class Adafactor(torch.optim.Optimizer):
             relative_step=relative_step,
             warmup_init=warmup_init,
             min_dim_size_to_factor=min_dim_size_to_factor,
+            caution=caution,
         )
         super(Adafactor, self).__init__(params, defaults)

+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+            group.setdefault('min_dim_size_to_factor', 32)
+
     @staticmethod
     def _get_lr(param_group, param_state):
         if param_group['relative_step']:
@@ -86,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
         return param_group['lr']

     @staticmethod
-    def _get_options(param_group, param_shape, min_size_to_factor=32):
+    def _get_options(param_group, param_shape, min_size_to_factor=16):
         use_first_moment = param_group['beta1'] is not None
         factored = None
         ndim = len(param_shape)
@@ -98,7 +109,7 @@ class Adafactor(torch.optim.Optimizer):
             # nD convs in torch are ND + 2 dim weights with leading in/out chs
             factored = 0, 1
         elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-            # if the criteria above didn't match, test trailing dims for eligibility
+            # if the criteria above didn't match, test trailing dims for eligibility as per original impl
             factored = ndim - 2, ndim - 1

         return factored, use_first_moment
@@ -113,7 +124,6 @@ class Adafactor(torch.optim.Optimizer):
         c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
         return torch.mul(r_factor, c_factor)

-
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -201,7 +211,13 @@ class Adafactor(torch.optim.Optimizer):
                 if use_first_moment:
                     exp_avg = state['exp_avg']
                     exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
-                    update = exp_avg
+                    if group['caution']:
+                        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                        mask = (exp_avg * grad > 0).to(grad.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        update = exp_avg * mask
+                    else:
+                        update = exp_avg

                 if group['weight_decay'] != 0:
                     p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
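Note the factoring threshold default drops from 32 to 16 in both __init__ and _get_options, while __setstate__ backfills 32 so state restored from older checkpoints keeps its previous behaviour. A small sketch of the trailing-dims test only (it ignores the conv-specific branch above it; the shapes are examples):

def would_factor(shape, min_size_to_factor=16):
    # mirrors the `elif ndim >= 2 and ...` branch of _get_options
    return len(shape) >= 2 and shape[-2] > min_size_to_factor and shape[-1] > min_size_to_factor

print(would_factor((768, 768)))  # True: factored under both the old and new defaults
print(would_factor((24, 24)))    # True with 16, was False with 32
print(would_factor((10, 768)))   # False: keeps an unfactored second moment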
@@ -6,13 +6,14 @@ Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560

 Adaptation and PyTorch modifications by Ross Wightman
 """
 from typing import List, Optional, Tuple, Union

 import torch
 from torch import Tensor
 from torch.optim import Optimizer

+from ._types import ParamsT
+

 def _get_scalar_dtype():
     """Get the scalar dtype that the optimizer uses for state"""
@@ -54,9 +55,9 @@ class AdafactorBigVision(Optimizer):

     def __init__(
             self,
-            params,
+            params: ParamsT,
             lr: float = 1.0,
-            min_dim_size_to_factor: int = 32,
+            min_dim_size_to_factor: int = 16,
             decay_rate: float = 0.8,
             decay_offset: int = 0,
             beta2_cap: float = 0.999,
@@ -66,6 +67,7 @@ class AdafactorBigVision(Optimizer):
             weight_decay: float = 0.0,
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
+            caution: bool = False,
             *,
             foreach: Optional[bool] = False,
     ):
@@ -91,6 +93,7 @@ class AdafactorBigVision(Optimizer):
             weight_decay=weight_decay,
             clipping_threshold=clipping_threshold,
             unscaled_wd=unscaled_wd,
+            caution=caution,
             foreach=foreach,
         )
         super().__init__(params, defaults)
@@ -98,6 +101,7 @@ class AdafactorBigVision(Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
+            group.setdefault('caution', False)
             group.setdefault('foreach', None)
             for p in group['params']:
                 p_state = self.state.get(p, {})
@@ -192,6 +196,7 @@ class AdafactorBigVision(Optimizer):
                 momentum_dtype=group['momentum_dtype'],
                 clipping_threshold=group['clipping_threshold'],
                 unscaled_wd=group['unscaled_wd'],
+                caution=group['caution'],
             )

         return loss
@@ -216,6 +221,7 @@ def _single_tensor_adafactor(
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
+        caution: bool,
 ):
     for i, param in enumerate(params):
         grad = grads[i]
@@ -267,6 +273,12 @@ def _single_tensor_adafactor(
            exp_avg.lerp_(update, 1 - momentum)  # ema
            update = exp_avg.clone()

+            if caution:
+                # apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
+                mask = (update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                update.mul_(mask)
+
         # Scale by learning rate
         update.mul_(lr)
@@ -302,6 +314,7 @@ def _multi_tensor_adafactor(
         momentum_dtype: Union[str, torch.dtype],
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
+        caution: bool,
 ):
     # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
@@ -4,49 +4,45 @@ Impl copied from PyTorch master
 NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
+from typing import Tuple

 import torch
 from torch.optim.optimizer import Optimizer

+from ._types import ParamsT
+

 class AdamWLegacy(Optimizer):
     r"""Implements AdamW algorithm.

     NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference

-    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
-    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+    References:
+        - Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
+        - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
+        - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ

-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
-
-    .. _Adam\: A Method for Stochastic Optimization:
-        https://arxiv.org/abs/1412.6980
-    .. _Decoupled Weight Decay Regularization:
-        https://arxiv.org/abs/1711.05101
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        betas: coefficients used for computing running averages of gradient and its square
+        eps: term added to the denominator to improve numerical stability
+        weight_decay: weight decay coefficient
+        amsgrad: whether to use the AMSGrad variant of this algorithm
+            from the paper `On the Convergence of Adam and Beyond`
+        caution: apply caution when using AdamW
     """

     def __init__(
             self,
-            params,
-            lr=1e-3,
-            betas=(0.9, 0.999),
-            eps=1e-8,
-            weight_decay=1e-2,
-            amsgrad=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-8,
+            weight_decay: float = 1e-2,
+            amsgrad: bool = False,
+            caution: bool = False,
     ):
         # NOTE: deprecated in favour of builtin torch.optim.AdamW
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -61,6 +57,7 @@ class AdamWLegacy(Optimizer):
             eps=eps,
             weight_decay=weight_decay,
             amsgrad=amsgrad,
+            caution=caution,
         )
         super(AdamWLegacy, self).__init__(params, defaults)

@@ -68,6 +65,7 @@ class AdamWLegacy(Optimizer):
         super(AdamWLegacy, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
+            group.setdefault('caution', False)

     @torch.no_grad()
     def step(self, closure=None):
@@ -131,6 +129,12 @@ class AdamWLegacy(Optimizer):

                 step_size = group['lr'] / bias_correction1

+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg = exp_avg * mask
+
                 p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss
@@ -10,16 +10,15 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
     title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
     year = {2024}
 }

 """

-from typing import cast, Callable, List, Optional, Tuple, Union
+from typing import cast, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
-
 from torch.optim.optimizer import Optimizer

+from ._types import ParamsT
+
 __all__ = ["Adopt", "adopt"]

 def _view_as_real(params, *state_and_grads):
@@ -60,7 +59,7 @@ class Adopt(Optimizer):
     """
     def __init__(
             self,
-            params,
+            params: ParamsT,
             lr: Union[float, Tensor] = 1e-3,
             betas: Tuple[float, float] = (0.9, 0.9999),
             eps: float = 1e-6,
@@ -68,7 +67,8 @@ class Adopt(Optimizer):
             weight_decay: float = 0.0,
             decoupled: bool = False,
             *,
-            foreach: Optional[bool] = None,
+            caution: bool = False,
+            foreach: Optional[bool] = False,
             maximize: bool = False,
             capturable: bool = False,
             differentiable: bool = False,
@@ -98,6 +98,7 @@ class Adopt(Optimizer):
             weight_decay=weight_decay,
             clip_exp=clip_exp,
             decoupled=decoupled,
+            caution=caution,
             maximize=maximize,
             foreach=foreach,
             capturable=capturable,
@@ -105,7 +106,6 @@ class Adopt(Optimizer):
         )
         super().__init__(params, defaults)

-
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
@@ -114,6 +114,7 @@ class Adopt(Optimizer):
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
             group.setdefault("clip_exp", None)
+            group.setdefault("caution", False)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -223,6 +224,7 @@ class Adopt(Optimizer):
                 clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
+                caution=group["caution"],
                 maximize=group["maximize"],
                 foreach=group["foreach"],
                 capturable=group["capturable"],
@@ -251,6 +253,7 @@ def _single_tensor_adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
        capturable: bool,
        differentiable: bool,
@@ -306,6 +309,13 @@ def _single_tensor_adopt(
             normed_grad.clamp_(-clip_val, clip_val)

         exp_avg.lerp_(normed_grad, 1 - beta1)
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
         param.add_(exp_avg, alpha=-lr)

         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -328,6 +338,7 @@ def _multi_tensor_adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
        capturable: bool,
        differentiable: bool,
@@ -403,6 +414,7 @@ def _multi_tensor_adopt(

         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
         torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+
         normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)

         if clip_exp is not None:
@@ -411,6 +423,16 @@ def _multi_tensor_adopt(
             torch._foreach_minimum_(normed_grad, clip_val)

         torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(device_exp_avgs, device_grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)
+
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)

         torch._foreach_mul_(device_exp_avg_sqs, beta2)
@@ -440,6 +462,7 @@ def adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
 ):
     r"""Functional API that performs ADOPT algorithm computation.
@@ -477,6 +500,7 @@ def adopt(
         clip_exp=clip_exp,
         decoupled=decoupled,
         eps=eps,
+        caution=caution,
         maximize=maximize,
         capturable=capturable,
         differentiable=differentiable,
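The multi-tensor (foreach) paths here and in the following files all build the caution mask from the same torch._foreach_* primitives; isolated as a sketch with illustrative names:

from typing import List

import torch


def apply_caution_foreach(updates: List[torch.Tensor], grads: List[torch.Tensor]) -> List[torch.Tensor]:
    # Per-tensor sign-agreement masks
    masks = torch._foreach_mul(updates, grads)
    masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
    # Rescale each mask by its clamped mean so the expected step size is roughly preserved
    mask_scale = [m.mean() for m in masks]
    torch._foreach_maximum_(mask_scale, 1e-3)
    torch._foreach_div_(masks, mask_scale)
    return torch._foreach_mul(updates, masks)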
@@ -52,50 +52,48 @@ Modifications Copyright 2021 Ross Wightman
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 import math
+from typing import Optional, Tuple

 import torch
 from torch.optim import Optimizer

+from ._types import ParamsT
+

 class Lamb(Optimizer):
     """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
     reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

-    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+    LAMB was proposed in:
+        - Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
+        - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ

-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its norm. (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability. (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
-            calculating running averages of gradient. (default: True)
-        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
-        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
-        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
-            weight decay parameter (default: False)
-
-    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
-        https://arxiv.org/abs/1904.00962
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+    Args:
+        params: Iterable of parameters to optimize or dicts defining parameter groups.
+        lr: Learning rate
+        betas: Coefficients used for computing running averages of gradient and its norm.
+        eps: Term added to the denominator to improve numerical stability.
+        weight_decay: Weight decay
+        grad_averaging: Whether to apply (1-beta2) to grad when calculating running averages of gradient.
+        max_grad_norm: Value used to clip global grad norm.
+        trust_clip: Enable LAMBC trust ratio clipping.
+        always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
+        caution: Apply caution.
     """

     def __init__(
             self,
-            params,
-            lr=1e-3,
-            bias_correction=True,
-            betas=(0.9, 0.999),
-            eps=1e-6,
-            weight_decay=0.01,
-            grad_averaging=True,
-            max_grad_norm=1.0,
-            trust_clip=False,
-            always_adapt=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            bias_correction: bool = True,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-6,
+            weight_decay: float = 0.01,
+            grad_averaging: bool = True,
+            max_grad_norm: Optional[float] = 1.0,
+            trust_clip: bool = False,
+            always_adapt: bool = False,
+            caution: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -107,9 +105,15 @@ class Lamb(Optimizer):
             max_grad_norm=max_grad_norm,
             trust_clip=trust_clip,
             always_adapt=always_adapt,
+            caution=caution,
         )
         super().__init__(params, defaults)

+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
     def _get_clip_grad_norm(self):
         max_grad_norm = self.defaults['max_grad_norm']
         if max_grad_norm is None:
@@ -187,6 +191,12 @@ class Lamb(Optimizer):
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 update = (exp_avg / bias_correction1).div_(denom)

+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (update * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    update.mul_(mask)
+
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
                     update.add_(p, alpha=weight_decay)
@@ -12,9 +12,13 @@ Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs
 }

 """
+from typing import Tuple
+
 from torch.optim import Optimizer
 import torch

+from ._types import ParamsT
+

 class LaProp(Optimizer):
     """ LaProp Optimizer
@@ -23,11 +27,12 @@ class LaProp(Optimizer):
     """
     def __init__(
             self,
-            params,
-            lr=4e-4,
-            betas=(0.9, 0.999),
-            eps=1e-15,
-            weight_decay=0,
+            params: ParamsT,
+            lr: float = 4e-4,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-15,
+            weight_decay: float = 0.,
+            caution: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -42,6 +47,7 @@ class LaProp(Optimizer):
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            caution=caution,
         )
         super(LaProp, self).__init__(params, defaults)

@@ -101,7 +107,14 @@ class LaProp(Optimizer):
                 step_of_this_grad = grad / denom
                 exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)

+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg = exp_avg * mask
+
                 p.add_(exp_avg, alpha=-step_size)

                 if group['weight_decay'] != 0:
                     p.add_(p, alpha=-group['weight_decay'])
@@ -16,33 +16,35 @@ Original Impl: https://github.com/google/automl/tree/master/lion
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List
+from typing import List, Optional, Tuple

 import torch
 from torch.optim.optimizer import Optimizer

+from ._types import ParamsT
+

 class Lion(Optimizer):
     r"""Implements Lion algorithm."""

     def __init__(
             self,
-            params,
-            lr=1e-4,
-            betas=(0.9, 0.99),
-            weight_decay=0.0,
-            maximize=False,
-            foreach=None,
+            params: ParamsT,
+            lr: float = 1e-4,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            weight_decay: float = 0.0,
+            caution: bool = False,
+            maximize: bool = False,
+            foreach: Optional[bool] = None,
     ):
         """Initialize the hyperparameters.

         Args:
-            params (iterable): iterable of parameters to optimize or dicts defining
-                parameter groups
-            lr (float, optional): learning rate (default: 1e-4)
-            betas (Tuple[float, float], optional): coefficients used for computing
-                running averages of gradient and its square (default: (0.9, 0.99))
-            weight_decay (float, optional): weight decay coefficient (default: 0)
+            params: iterable of parameters to optimize or dicts defining parameter groups
+            lr: learning rate
+            betas: coefficients used for computing running averages of gradient and its square
+            weight_decay: weight decay coefficient
+            caution: apply caution
         """

         if not 0.0 <= lr:
@@ -55,6 +57,7 @@ class Lion(Optimizer):
             lr=lr,
             betas=betas,
             weight_decay=weight_decay,
+            caution=caution,
             foreach=foreach,
             maximize=maximize,
         )
@@ -63,6 +66,7 @@ class Lion(Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
+            group.setdefault('caution', False)
             group.setdefault('maximize', False)
             group.setdefault('foreach', None)

@@ -71,8 +75,7 @@ class Lion(Optimizer):
         """Performs a single optimization step.

         Args:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure: A closure that reevaluates the model and returns the loss.

         Returns:
             the loss.
@@ -112,6 +115,7 @@ class Lion(Optimizer):
                 beta2=beta2,
                 lr=group['lr'],
                 weight_decay=group['weight_decay'],
+                caution=group['caution'],
                 maximize=group['maximize'],
                 foreach=group['foreach'],
             )
@@ -132,6 +136,7 @@ def lion(
     beta2: float,
     lr: float,
     weight_decay: float,
+    caution: bool,
 ):
     r"""Functional API that performs Lion algorithm computation.
     """
@@ -155,6 +160,7 @@ def lion(
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
+        caution=caution,
         maximize=maximize,
     )

@@ -168,6 +174,7 @@ def _single_tensor_lion(
     beta2: float,
     lr: float,
     weight_decay: float,
+    caution: bool,
     maximize: bool,
 ):
     for i, param in enumerate(params):
@@ -183,8 +190,15 @@ def _single_tensor_lion(
         param.mul_(1 - lr * weight_decay)

         # Weight update
-        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
-        param.add_(torch.sign(update), alpha=-lr)
+        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (update * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            update.mul_(mask)
+
+        param.add_(update, alpha=-lr)

         # Decay the momentum running average coefficient
         exp_avg.lerp_(grad, 1 - beta2)
@@ -199,6 +213,7 @@ def _multi_tensor_lion(
     beta2: float,
     lr: float,
     weight_decay: float,
+    caution: bool,
     maximize: bool,
 ):
     if len(params) == 0:
@@ -217,8 +232,17 @@ def _multi_tensor_lion(
     # Weight update
     updates = torch._foreach_mul(exp_avgs, beta1)
     torch._foreach_add_(updates, grads, alpha=1 - beta1)
+    updates = [u.sign_() for u in updates]
+
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(updates, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        torch._foreach_mul_(updates, masks)

-    updates = [u.sign() for u in updates]
     torch._foreach_add_(params, updates, alpha=-lr)

     # Decay the momentum running average coefficient
@@ -5,44 +5,43 @@ Based on simplified algorithm in https://github.com/mlcommons/algorithmic-effici
 Added multi-tensor (foreach) path.
 """
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor

+from ._types import ParamsT
+

 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
-    r"""Implements NAdamW algorithm.
+    """ Implements NAdamW algorithm.

     See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
     the NAdam algorithm (there is also a comment in the code which highlights
     the only difference of NAdamW and AdamW).
-    For further details regarding the algorithm we refer to
-    `Decoupled Weight Decay Regularization`_.

-    Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-    .. _Decoupled Weight Decay Regularization:
-        https://arxiv.org/abs/1711.05101
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+    For further details regarding the algorithm we refer to
+        - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
+        - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        betas: coefficients used for computing running averages of gradient and its square
+        eps: term added to the denominator to improve numerical stability
+        weight_decay: weight decay coefficient
+        caution: enable caution
     """

     def __init__(
             self,
-            params,
-            lr=1e-3,
-            betas=(0.9, 0.999),
-            eps=1e-8,
-            weight_decay=1e-2,
+            params: ParamsT,
+            lr: float = 1e-3,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-8,
+            weight_decay: float = 1e-2,
+            caution: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             capturable: bool = False,
@@ -62,6 +61,7 @@ class NAdamW(torch.optim.Optimizer):
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            caution=caution,
             foreach=foreach,
             maximize=maximize,
             capturable=capturable,
@@ -71,11 +71,12 @@ class NAdamW(torch.optim.Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]['step'])
+        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
         if not step_is_tensor:
             for s in state_values:
                 s['step'] = torch.tensor(float(s['step']))
+        for group in self.param_groups:
+            group.setdefault('caution', False)

     @torch.no_grad()
     def step(self, closure=None):
@@ -133,6 +134,7 @@ class NAdamW(torch.optim.Optimizer):
                 lr=group['lr'],
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
+                caution=group['caution'],
                 maximize=group['maximize'],
                 capturable=group['capturable'],
             )
@@ -154,6 +156,7 @@ def nadamw(
     lr: float,
     weight_decay: float,
     eps: float,
+    caution: bool,
     maximize: bool,
 ) -> None:
     r"""Functional API that performs NAdamW algorithm computation.
@@ -183,6 +186,7 @@ def nadamw(
         lr=lr,
         weight_decay=weight_decay,
         eps=eps,
+        caution=caution,
         maximize=maximize,
         capturable=capturable,
     )
@@ -200,6 +204,7 @@ def _single_tensor_nadamw(
     lr: float,
     weight_decay: float,
     eps: float,
+    caution: bool,
     maximize: bool,
     capturable: bool
 ):
@@ -238,6 +243,14 @@ def _single_tensor_nadamw(
             exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

             denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
+
+            if caution:
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                # FIXME not 100% sure if this remains capturable?
+                mask = (exp_avg * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                exp_avg.mul_(mask)
+
             param.addcdiv_(exp_avg, denom)
         else:
             step = step_t.item()
@@ -246,11 +259,17 @@ def _single_tensor_nadamw(
             step_size = lr / bias_correction1
             bias_correction2_sqrt = math.sqrt(bias_correction2)

-            # Only difference between NAdamW and AdamW in this implementation.
+            # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
             # The official PyTorch implementation of NAdam uses a different algorithm.
             exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

             denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+
+            if caution:
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                mask = (exp_avg * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                exp_avg.mul_(mask)
+
             param.addcdiv_(exp_avg, denom, value=-step_size)

@@ -266,6 +285,7 @@ def _multi_tensor_nadamw(
     lr: float,
     weight_decay: float,
     eps: float,
+    caution: bool,
     maximize: bool,
     capturable: bool,
 ):
@@ -322,12 +342,22 @@ def _multi_tensor_nadamw(

         exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
         torch._foreach_div_(
-            exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
+            exp_avg_sq_sqrt,
+            torch._foreach_mul(bias_correction2_sqrt, step_size)
         )
         eps_over_step_size = torch._foreach_div(step_size, eps)
         torch._foreach_reciprocal_(eps_over_step_size)
         denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(exp_avgs, grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]  # capturable?
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            torch._foreach_mul_(exp_avgs, masks)
+
         torch._foreach_addcdiv_(params, exp_avgs, denom)
     else:
         bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
@@ -337,7 +367,7 @@ def _multi_tensor_nadamw(

         bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2]

-        # Only difference between NAdamW and AdamW in this implementation.
+        # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
         # The official PyTorch implementation of NAdam uses a different algorithm.
         exp_avgs = torch._foreach_mul(exp_avgs, beta1)
         torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
@@ -346,4 +376,13 @@ def _multi_tensor_nadamw(
         torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
         denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(exp_avgs, grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            torch._foreach_mul_(exp_avgs, masks)
+
         torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
@@ -10,6 +10,8 @@ Modifications Copyright 2021 Ross Wightman
 import torch
 from torch.optim import Optimizer

+from ._types import ParamsT
+

 class RMSpropTF(Optimizer):
     """Implements RMSprop algorithm (TensorFlow style epsilon)
@@ -28,34 +30,31 @@ class RMSpropTF(Optimizer):
     The centered version first appears in `Generating Sequences
     With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-2)
-        momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing (decay) constant (default: 0.9)
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-10)
-        centered (bool, optional) : if ``True``, compute the centered RMSProp,
-            the gradient is normalized by an estimation of its variance
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
-        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
-            update as per defaults in Tensorflow
-
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        momentum: momentum factor
+        alpha: smoothing (decay) constant
+        eps: term added to the denominator to improve numerical stability
+        centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
+        weight_decay: weight decay (L2 penalty)
+        decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
+        caution: apply caution
     """

     def __init__(
             self,
-            params,
-            lr=1e-2,
-            alpha=0.9,
-            eps=1e-10,
-            weight_decay=0,
-            momentum=0.,
-            centered=False,
-            decoupled_decay=False,
-            lr_in_momentum=True,
+            params: ParamsT,
+            lr: float = 1e-2,
+            alpha: float = 0.9,
+            eps: float = 1e-10,
+            weight_decay: float = 0,
+            momentum: float = 0.,
+            centered: bool = False,
+            decoupled_decay: bool = False,
+            lr_in_momentum: bool = True,
+            caution: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -77,6 +76,7 @@ class RMSpropTF(Optimizer):
             weight_decay=weight_decay,
             decoupled_decay=decoupled_decay,
             lr_in_momentum=lr_in_momentum,
+            caution=caution,
         )
         super(RMSpropTF, self).__init__(params, defaults)

@@ -85,6 +85,7 @@ class RMSpropTF(Optimizer):
         for group in self.param_groups:
             group.setdefault('momentum', 0)
             group.setdefault('centered', False)
+            group.setdefault('caution', False)

     @torch.no_grad()
     def step(self, closure=None):
@@ -142,13 +143,25 @@ class RMSpropTF(Optimizer):

                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    # Tensorflow accumulates the LR scaling in the momentum buffer
+                    buf.mul_(group['momentum'])
+
+                    def _apply_caution(_m, _g):
+                        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                        mask = (_m * _g > 0).to(_g.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        return _m * mask
+
                     if group['lr_in_momentum']:
-                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
+                        # Tensorflow accumulates the LR scaling in the momentum buffer
+                        buf.addcdiv_(grad, avg, value=group['lr'])
+                        if group['caution']:
+                            buf = _apply_caution(buf, grad)
                         p.add_(-buf)
                     else:
                         # PyTorch scales the param update by LR
-                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        buf.addcdiv_(grad, avg)
+                        if group['caution']:
+                            buf = _apply_caution(buf, grad)
                         p.add_(buf, alpha=-group['lr'])
                 else:
                     p.addcdiv_(grad, avg, value=-group['lr'])
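The momentum-buffer refactor above is behaviour-preserving when caution is off: decaying the buffer once up front and then calling addcdiv_ inside the branch performs the same in-place arithmetic as the old chained call; it is only split so both branches share the decay and the caution mask can act on the fully updated buffer before the parameter step. A quick check of that equivalence:

import torch

torch.manual_seed(0)
buf_old = torch.randn(5)
buf_new = buf_old.clone()
grad, avg = torch.randn(5), torch.rand(5) + 0.1
momentum, lr = 0.9, 1e-2

buf_old.mul_(momentum).addcdiv_(grad, avg, value=lr)  # old chained form
buf_new.mul_(momentum)                                # new: shared decay ...
buf_new.addcdiv_(grad, avg, value=lr)                 # ... then the branch-specific update

assert torch.allclose(buf_old, buf_new)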
@@ -1,4 +1,5 @@
 from functools import update_wrapper, wraps
+from typing import List, Optional

 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
@@ -8,7 +9,7 @@ try:
 except ImportError:
     has_recent_pt = False

-from typing import List, Optional
+from ._types import ParamsT

 __all__ = ['SGDW', 'sgdw']

@@ -16,13 +17,14 @@ __all__ = ['SGDW', 'sgdw']
 class SGDW(Optimizer):
     def __init__(
             self,
-            params,
-            lr=1e-3,
-            momentum=0,
-            dampening=0,
-            weight_decay=0,
-            nesterov=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            momentum: float = 0.,
+            dampening: float = 0.,
+            weight_decay: float = 0.,
+            nesterov: bool = False,
             *,
+            caution: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             differentiable: bool = False,
@@ -40,6 +42,7 @@ class SGDW(Optimizer):
             dampening=dampening,
             weight_decay=weight_decay,
             nesterov=nesterov,
+            caution=caution,
             maximize=maximize,
             foreach=foreach,
             differentiable=differentiable,
@@ -51,18 +54,19 @@ class SGDW(Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
+            group.setdefault('caution', False)
             group.setdefault('nesterov', False)
             group.setdefault('maximize', False)
             group.setdefault('foreach', None)
             group.setdefault('differentiable', False)

-    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
+    def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
         has_sparse_grad = False

         for p in group['params']:
             if p.grad is not None:
                 params_with_grad.append(p)
-                d_p_list.append(p.grad)
+                grads.append(p.grad)
                 if p.grad.is_sparse:
                     has_sparse_grad = True

@@ -91,20 +95,21 @@ class SGDW(Optimizer):

         for group in self.param_groups:
             params_with_grad = []
-            d_p_list = []
+            grads = []
             momentum_buffer_list = []

-            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)
+            has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list)

             sgdw(
                 params_with_grad,
-                d_p_list,
+                grads,
                 momentum_buffer_list,
                 weight_decay=group['weight_decay'],
                 momentum=group['momentum'],
                 lr=group['lr'],
                 dampening=group['dampening'],
                 nesterov=group['nesterov'],
+                caution=group['caution'],
                 maximize=group['maximize'],
                 has_sparse_grad=has_sparse_grad,
                 foreach=group['foreach'],
@@ -120,7 +125,7 @@ class SGDW(Optimizer):

 def sgdw(
         params: List[Tensor],
-        d_p_list: List[Tensor],
+        grads: List[Tensor],
         momentum_buffer_list: List[Optional[Tensor]],
         # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
         # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
@@ -132,6 +137,7 @@ def sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool
 ):
     r"""Functional API that performs SGD algorithm computation.
@@ -159,13 +165,14 @@ def sgdw(

     func(
         params,
-        d_p_list,
+        grads,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
+        caution=caution,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize,
     )
@@ -173,7 +180,7 @@ def sgdw(

 def _single_tensor_sgdw(
         params: List[Tensor],
-        d_p_list: List[Tensor],
+        grads: List[Tensor],
         momentum_buffer_list: List[Optional[Tensor]],
         *,
         weight_decay: float,
@@ -181,11 +188,12 @@ def _single_tensor_sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool,
         has_sparse_grad: bool
 ):
     for i, param in enumerate(params):
-        d_p = d_p_list[i] if not maximize else -d_p_list[i]
+        grad = grads[i] if not maximize else -grads[i]

         param.mul_(1. - lr * weight_decay)

@@ -193,17 +201,25 @@ def _single_tensor_sgdw(
             buf = momentum_buffer_list[i]

             if buf is None:
-                buf = torch.clone(d_p).detach()
+                buf = torch.clone(grad).detach()
                 momentum_buffer_list[i] = buf
             else:
-                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

-            if nesterov:
-                d_p = d_p.add(buf, alpha=momentum)
+            if caution:
+                if nesterov:
+                    buf = grad.add(buf, alpha=momentum)
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                mask = (buf * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                grad = buf * mask
             else:
-                d_p = buf
+                if nesterov:
+                    grad = grad.add(buf, alpha=momentum)
+                else:
+                    grad = buf

-        param.add_(d_p, alpha=-lr)
+        param.add_(grad, alpha=-lr)


 def _multi_tensor_sgdw(
@@ -216,6 +232,7 @@ def _multi_tensor_sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool,
         has_sparse_grad: bool
 ):
@@ -258,10 +275,22 @@ def _multi_tensor_sgdw(

                 bufs.append(buf)

-            if nesterov:
-                torch._foreach_add_(device_grads, bufs, alpha=momentum)
+            if caution:
+                if nesterov:
+                    # Can't do nesterov in-place if we want to compare against orig grad for caution
+                    bufs = torch._foreach_add(device_grads, bufs, alpha=momentum)
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                masks = torch._foreach_mul(bufs, device_grads)
+                masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
+                mask_scale = [m.mean() for m in masks]
+                torch._foreach_maximum_(mask_scale, 1e-3)
+                torch._foreach_div_(masks, mask_scale)
+                device_grads = torch._foreach_mul(bufs, masks)
             else:
-                device_grads = bufs
+                if nesterov:
+                    torch._foreach_add_(device_grads, bufs, alpha=momentum)
+                else:
+                    device_grads = bufs

         if not device_has_sparse_grad:
             torch._foreach_add_(device_params, device_grads, alpha=-lr)