Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 3086dd03fd (parent aeb1ed7a15): Cautious optimizer impl plus some typing cleanup.
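The diff below threads a `caution` flag through the timm optimizers, following 'Cautious Optimizers' (https://arxiv.org/abs/2411.16085): the proposed update is masked wherever its sign disagrees with the current gradient, and the surviving entries are rescaled by the mean of the mask. A minimal standalone sketch of that masking step, kept separate from the timm classes and using illustrative names:

```python
import torch


def apply_caution(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Keep only update components whose sign agrees with the gradient,
    # then rescale so the surviving entries keep the expected magnitude.
    mask = (update * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return update * mask


def cautious_momentum_step(param, grad, exp_avg, lr=1e-3, beta=0.9):
    # Illustrative momentum step with the cautious mask applied before updating the weights.
    exp_avg.mul_(beta).add_(grad, alpha=1 - beta)
    param.add_(apply_caution(exp_avg, grad), alpha=-lr)
```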
@@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
     assert isinstance(opt_info, OptimInfo)
 
     lr = (1e-2,) * 4
-    if optimizer in ('mars',):
+    if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4
 
     try:
@@ -5,15 +5,16 @@ Hacked together by / Copyright 2021 Ross Wightman
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
+import torch.optim
 
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._types import ParamsT, OptimType, OptimizerCallable
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
@@ -39,11 +40,6 @@ from .sgdw import SGDW
 
 _logger = logging.getLogger(__name__)
 
-# Type variables
-T = TypeVar('T')
-Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
-OptimType = TypeVar('OptimType', bound='optim.Optimizer')
-
 
 def _import_class(class_string: str) -> Type:
     """Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
         raise ImportError(f"Could not import {class_string}: {e}")
 
 
-class OptimizerCallable(Protocol):
-    """Protocol for optimizer constructor signatures."""
-
-    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
-
-
 @dataclass(frozen=True)
 class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
         defaults: Optional default parameters for the optimizer
     """
     name: str
-    opt_class: Union[str, Type[optim.Optimizer]]
+    opt_class: Union[str, OptimType]
    description: str = ''
    has_eps: bool = True
    has_momentum: bool = False
@@ -185,7 +176,7 @@ class OptimizerRegistry:
            self,
            name_or_info: Union[str, OptimInfo],
            bind_defaults: bool = True,
-    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    ) -> Union[OptimType, OptimizerCallable]:
        """Get the optimizer class with any default arguments applied.
 
        This allows direct instantiation of optimizers with their default configs
@@ -234,7 +225,7 @@ class OptimizerRegistry:
 
    def create_optimizer(
            self,
-            model_or_params: Union[nn.Module, Params],
+            model_or_params: Union[nn.Module, ParamsT],
            opt: str,
            lr: Optional[float] = None,
            weight_decay: float = 0.,
@@ -242,9 +233,9 @@ class OptimizerRegistry:
            foreach: Optional[bool] = None,
            weight_decay_exclude_1d: bool = True,
            layer_decay: Optional[float] = None,
-            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
            **kwargs: Any,
-    ) -> optim.Optimizer:
+    ) -> torch.optim.Optimizer:
        """Create an optimizer instance.
 
        Args:
@@ -347,7 +338,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
    sgd_optimizers = [
        OptimInfo(
            name='sgd',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
            description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
            has_eps=False,
            has_momentum=True,
@@ -355,7 +346,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='momentum',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
            description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
            has_eps=False,
            has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
    adam_optimizers = [
        OptimInfo(
            name='adam',
-            opt_class=optim.Adam,
+            opt_class=torch.optim.Adam,
            description='torch.optim.Adam, Adaptive Moment Estimation',
            has_betas=True
        ),
        OptimInfo(
            name='adamw',
-            opt_class=optim.AdamW,
+            opt_class=torch.optim.AdamW,
            description='torch.optim.AdamW, Adam with decoupled weight decay',
            has_betas=True
        ),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='adamax',
-            opt_class=optim.Adamax,
+            opt_class=torch.optim.Adamax,
            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
            has_betas=True
        ),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
        registry.register(opt)
 
 
+def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
+    cautious_optimizers = [
+        OptimInfo(
+            name='cadafactor',
+            opt_class=Adafactor,
+            description='Cautious Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Cautious Big Vision Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadamw',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadopt',
+            opt_class=Adopt,
+            description='Cautious Adopt',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadoptw',
+            opt_class=Adopt,
+            description='Cautious AdoptW (decoupled decay)',
+            defaults={'decoupled': True, 'caution': True}
+        ),
+        OptimInfo(
+            name='clamb',
+            opt_class=Lamb,
+            description='Cautious LAMB',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='claprop',
+            opt_class=LaProp,
+            description='Cautious LaProp',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='clion',
+            opt_class=Lion,
+            description='Cautious Lion',
+            has_eps=False,
+            has_betas=True,
+            defaults = {'caution': True}
+        ),
+        OptimInfo(
+            name='cnadamw',
+            opt_class=NAdamW,
+            description='Cautious NAdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='crmsproptf',
+            opt_class=RMSpropTF,
+            description='Cautious TensorFlow-style RMSprop',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'caution': True}
+        ),
+        OptimInfo(
+            name='csgdw',
+            opt_class=SGDW,
+            description='Cautious SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True}
+        ),
+    ]
+    for opt in cautious_optimizers:
+        registry.register(opt)
+
+
 def _register_other_optimizers(registry: OptimizerRegistry) -> None:
    """Register miscellaneous optimizers"""
    other_optimizers = [
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='adadelta',
-            opt_class=optim.Adadelta,
+            opt_class=torch.optim.Adadelta,
            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
        ),
        OptimInfo(
            name='adagrad',
-            opt_class=optim.Adagrad,
+            opt_class=torch.optim.Adagrad,
            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
            defaults={'eps': 1e-8}
        ),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='rmsprop',
-            opt_class=optim.RMSprop,
+            opt_class=torch.optim.RMSprop,
            description='torch.optim.RMSprop, Root Mean Square Propagation',
            has_momentum=True,
            defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
    _register_other_optimizers(default_registry)
    _register_apex_optimizers(default_registry)
    _register_bnb_optimizers(default_registry)
+    _register_cautious_optimizers(default_registry)
 
    # Register aliases
    default_registry.register_alias('nesterov', 'sgd')
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
        name: str,
        bind_defaults: bool = True,
-) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+) -> Union[OptimType, OptimizerCallable]:
    """Get optimizer class by name with option to bind default arguments.
 
    Retrieves the optimizer class or a partial function with default arguments bound.
@@ -874,7 +947,7 @@ def get_optimizer_class(
 
 
 def create_optimizer_v2(
-        model_or_params: Union[nn.Module, Params],
+        model_or_params: Union[nn.Module, ParamsT],
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
@@ -882,9 +955,9 @@ def create_optimizer_v2(
        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
-        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
        **kwargs: Any,
-) -> optim.Optimizer:
+) -> torch.optim.Optimizer:
    """Create an optimizer instance via timm registry.
 
    Creates and configures an optimizer with appropriate parameter groups and settings.
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
    return kwargs
 
 
-def create_optimizer(args, model, filter_bias_and_bn=True):
+def create_optimizer(
+        args,
+        model: Union[nn.Module, ParamsT],
+        filter_bias_and_bn: bool = True,
+) -> torch.optim.Optimizer:
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
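With the registrations above in place, the cautious variants resolve through the normal timm factory path by name ('cadamw', 'clion', 'csgdw', and so on), each binding defaults={'caution': True} onto its underlying optimizer class. A small usage sketch (assuming the factory helpers shown in this diff are re-exported from timm.optim, as in current releases):

```python
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 16)  # stand-in for any model

# 'cadamw' maps to AdamWLegacy with caution=True bound as a default;
# the other registered names ('clion', 'csgdw', 'clamb', ...) work the same way.
optimizer = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)
```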
timm/optim/_types.py (new file, +25 lines)
@@ -0,0 +1,25 @@
+from typing import Any, Dict, Iterable, Union, Protocol, Type
+try:
+    from typing import TypeAlias, TypeVar
+except ImportError:
+    from typing_extensions import TypeAlias, TypeVar
+
+import torch
+import torch.optim
+
+try:
+    from torch.optim.optimizer import ParamsT
+except (ImportError, TypeError):
+    ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
+
+
+OptimType = Type[torch.optim.Optimizer]
+
+
+class OptimizerCallable(Protocol):
+    """Protocol for optimizer constructor signatures."""
+
+    def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
+
+
+__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
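These aliases are what the factory's return annotations lean on: when defaults are bound, get_optimizer_class hands back a callable rather than a raw class, and OptimizerCallable describes that shape. A short sketch of how downstream code might annotate against these types (illustrative only; the partial and hyper-parameters here are assumptions, not part of the diff):

```python
from functools import partial

import torch

from timm.optim._types import OptimizerCallable, OptimType, ParamsT

# A bare optimizer class satisfies OptimType...
opt_cls: OptimType = torch.optim.AdamW
# ...while a partial with defaults bound behaves like an OptimizerCallable at runtime.
opt_factory: OptimizerCallable = partial(torch.optim.AdamW, betas=(0.9, 0.95))


def build_optimizer(params: ParamsT, lr: float) -> torch.optim.Optimizer:
    return opt_factory(params, lr=lr)


optimizer = build_optimizer(torch.nn.Linear(8, 8).parameters(), lr=3e-4)
```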
@@ -10,8 +10,12 @@ Original header/copyright below.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import torch
 import math
+from typing import Optional, Tuple
+
+import torch
+
+from ._types import ParamsT
 
 
 class Adafactor(torch.optim.Optimizer):
@@ -26,33 +30,33 @@ class Adafactor(torch.optim.Optimizer):
    To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
    `relative_step=False`.
 
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
-        lr (float, optional): external learning rate (default: None)
-        eps (tuple[float, float]): regularization constants for square gradient
-            and parameter scale respectively (default: (1e-30, 1e-3))
-        clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
-        decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
-        beta1 (float): coefficient used for computing running averages of gradient (default: None)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
-        warmup_init (bool): time-dependent learning rate computation depends on
-            whether warm-up initialization is being used (default: False)
+    Ags:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: external learning rate
+        eps: regularization constants for square gradient and parameter scale respectively
+        eps_scale: regularization constants for parameter scale respectively
+        clip_threshold: threshold of root-mean-square of final gradient update
+        decay_rate: coefficient used to compute running averages of square gradient
+        beta1: coefficient used for computing running averages of gradient
+        weight_decay: weight decay
+        scale_parameter: if True, learning rate is scaled by root-mean-square of parameter
+        warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used
    """
 
    def __init__(
            self,
-            params,
-            lr=None,
-            eps=1e-30,
-            eps_scale=1e-3,
-            clip_threshold=1.0,
-            decay_rate=-0.8,
-            betas=None,
-            weight_decay=0.0,
-            scale_parameter=True,
-            warmup_init=False,
-            min_dim_size_to_factor=32,
+            params: ParamsT,
+            lr: Optional[float] = None,
+            eps: float = 1e-30,
+            eps_scale: float = 1e-3,
+            clip_threshold: float = 1.0,
+            decay_rate: float = -0.8,
+            betas: Optional[Tuple[float, float]] = None,
+            weight_decay: float = 0.0,
+            scale_parameter: bool = True,
+            warmup_init: bool = False,
+            min_dim_size_to_factor: int = 16,
+            caution: bool = False,
    ):
        relative_step = not lr
        if warmup_init and not relative_step:
@@ -71,9 +75,16 @@ class Adafactor(torch.optim.Optimizer):
            relative_step=relative_step,
            warmup_init=warmup_init,
            min_dim_size_to_factor=min_dim_size_to_factor,
+            caution=caution,
        )
        super(Adafactor, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+            group.setdefault('min_dim_size_to_factor', 32)
+
    @staticmethod
    def _get_lr(param_group, param_state):
        if param_group['relative_step']:
@@ -86,7 +97,7 @@ class Adafactor(torch.optim.Optimizer):
        return param_group['lr']
 
    @staticmethod
-    def _get_options(param_group, param_shape, min_size_to_factor=32):
+    def _get_options(param_group, param_shape, min_size_to_factor=16):
        use_first_moment = param_group['beta1'] is not None
        factored = None
        ndim = len(param_shape)
@@ -98,7 +109,7 @@ class Adafactor(torch.optim.Optimizer):
            # nD convs in torch are ND + 2 dim weights with leading in/out chs
            factored = 0, 1
        elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-            # if the criteria above didn't match, test trailing dims for eligibility
+            # if the criteria above didn't match, test trailing dims for eligibility as per original impl
            factored = ndim - 2, ndim - 1
 
        return factored, use_first_moment
@@ -113,7 +124,6 @@ class Adafactor(torch.optim.Optimizer):
        c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
        return torch.mul(r_factor, c_factor)
 
-
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -201,6 +211,12 @@ class Adafactor(torch.optim.Optimizer):
                if use_first_moment:
                    exp_avg = state['exp_avg']
                    exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
+                    if group['caution']:
+                        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                        mask = (exp_avg * grad > 0).to(grad.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        update = exp_avg * mask
+                    else:
                        update = exp_avg
 
                if group['weight_decay'] != 0:
@@ -6,13 +6,14 @@ Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
 
 Adaptation and PyTorch modifications by Ross Wightman
 """
 
 from typing import List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
 
+from ._types import ParamsT
 
 
 def _get_scalar_dtype():
    """Get the scalar dtype that the optimizer uses for state"""
@@ -54,9 +55,9 @@ class AdafactorBigVision(Optimizer):
 
    def __init__(
            self,
-            params,
+            params: ParamsT,
            lr: float = 1.0,
-            min_dim_size_to_factor: int = 32,
+            min_dim_size_to_factor: int = 16,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
@@ -66,6 +67,7 @@ class AdafactorBigVision(Optimizer):
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
+            caution: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
@@ -91,6 +93,7 @@ class AdafactorBigVision(Optimizer):
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
+            caution=caution,
            foreach=foreach,
        )
        super().__init__(params, defaults)
@@ -98,6 +101,7 @@ class AdafactorBigVision(Optimizer):
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
+            group.setdefault('caution', False)
            group.setdefault('foreach', None)
            for p in group['params']:
                p_state = self.state.get(p, {})
@@ -192,6 +196,7 @@ class AdafactorBigVision(Optimizer):
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
+                caution=group['caution'],
            )
 
        return loss
@@ -216,6 +221,7 @@ def _single_tensor_adafactor(
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
+        caution: bool,
 ):
    for i, param in enumerate(params):
        grad = grads[i]
@@ -267,6 +273,12 @@ def _single_tensor_adafactor(
            exp_avg.lerp_(update, 1 - momentum)  # ema
            update = exp_avg.clone()
 
+            if caution:
+                # apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085
+                mask = (update * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                update.mul_(mask)
+
        # Scale by learning rate
        update.mul_(lr)
 
@@ -302,6 +314,7 @@ def _multi_tensor_adafactor(
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
+        caution: bool,
 ):
    # FIXME TODO
    assert False, 'multi-tensor fn (foreach=True) not implemented yet'
@@ -4,49 +4,45 @@ Impl copied from PyTorch master
 NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
+from typing import Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
+from ._types import ParamsT
 
 
 class AdamWLegacy(Optimizer):
    r"""Implements AdamW algorithm.
 
    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
 
-    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
-    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+    References:
+        - Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
+        - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
+        - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
 
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False)
-
-    .. _Adam\: A Method for Stochastic Optimization:
-        https://arxiv.org/abs/1412.6980
-    .. _Decoupled Weight Decay Regularization:
-        https://arxiv.org/abs/1711.05101
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        betas: coefficients used for computing running averages of gradient and its square
+        eps: term added to the denominator to improve numerical stability
+        weight_decay: weight decay coefficient
+        amsgrad: whether to use the AMSGrad variant of this algorithm
+            from the paper `On the Convergence of Adam and Beyond`
+        caution: apply caution when using AdamW
    """
 
    def __init__(
            self,
-            params,
-            lr=1e-3,
-            betas=(0.9, 0.999),
-            eps=1e-8,
-            weight_decay=1e-2,
-            amsgrad=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-8,
+            weight_decay: float = 1e-2,
+            amsgrad: bool = False,
+            caution: bool = False,
    ):
-        # NOTE: deprecated in favour of builtin torch.optim.AdamW
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@@ -61,6 +57,7 @@ class AdamWLegacy(Optimizer):
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
+            caution=caution,
        )
        super(AdamWLegacy, self).__init__(params, defaults)
 
@@ -68,6 +65,7 @@ class AdamWLegacy(Optimizer):
        super(AdamWLegacy, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
+            group.setdefault('caution', False)
 
    @torch.no_grad()
    def step(self, closure=None):
@@ -131,6 +129,12 @@ class AdamWLegacy(Optimizer):
 
                step_size = group['lr'] / bias_correction1
 
+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg = exp_avg * mask
+
                p.addcdiv_(exp_avg, denom, value=-step_size)
 
        return loss
@@ -10,16 +10,15 @@ Modified for reduced dependencies on PyTorch internals from original at: https:/
    title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate},
    year = {2024}
 }
 
 """
-from typing import cast, Callable, List, Optional, Tuple, Union
+from typing import cast, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 
 from torch.optim.optimizer import Optimizer
 
+from ._types import ParamsT
 
 __all__ = ["Adopt", "adopt"]
 
 def _view_as_real(params, *state_and_grads):
@@ -60,7 +59,7 @@ class Adopt(Optimizer):
    """
    def __init__(
            self,
-            params,
+            params: ParamsT,
            lr: Union[float, Tensor] = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.9999),
            eps: float = 1e-6,
@@ -68,7 +67,8 @@ class Adopt(Optimizer):
            weight_decay: float = 0.0,
            decoupled: bool = False,
            *,
-            foreach: Optional[bool] = None,
+            caution: bool = False,
+            foreach: Optional[bool] = False,
            maximize: bool = False,
            capturable: bool = False,
            differentiable: bool = False,
@@ -98,6 +98,7 @@ class Adopt(Optimizer):
            weight_decay=weight_decay,
            clip_exp=clip_exp,
            decoupled=decoupled,
+            caution=caution,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
@@ -105,7 +106,6 @@ class Adopt(Optimizer):
        )
        super().__init__(params, defaults)
 
-
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
@@ -114,6 +114,7 @@ class Adopt(Optimizer):
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("clip_exp", None)
+            group.setdefault("caution", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -223,6 +224,7 @@ class Adopt(Optimizer):
                clip_exp=group["clip_exp"],
                decoupled=group["decoupled"],
                eps=group["eps"],
+                caution=group["caution"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
@@ -251,6 +253,7 @@ def _single_tensor_adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
        capturable: bool,
        differentiable: bool,
@@ -306,6 +309,13 @@ def _single_tensor_adopt(
            normed_grad.clamp_(-clip_val, clip_val)
 
        exp_avg.lerp_(normed_grad, 1 - beta1)
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
+
        param.add_(exp_avg, alpha=-lr)
 
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
@@ -328,6 +338,7 @@ def _multi_tensor_adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
        capturable: bool,
        differentiable: bool,
@@ -403,6 +414,7 @@ def _multi_tensor_adopt(
 
        exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+
        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
 
        if clip_exp is not None:
@@ -411,6 +423,16 @@ def _multi_tensor_adopt(
            torch._foreach_minimum_(normed_grad, clip_val)
 
        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(device_exp_avgs, device_grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)
+
        torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
 
        torch._foreach_mul_(device_exp_avg_sqs, beta2)
@@ -440,6 +462,7 @@ def adopt(
        clip_exp: Optional[float],
        decoupled: bool,
        eps: float,
+        caution: bool,
        maximize: bool,
 ):
    r"""Functional API that performs ADOPT algorithm computation.
@@ -477,6 +500,7 @@ def adopt(
        clip_exp=clip_exp,
        decoupled=decoupled,
        eps=eps,
+        caution=caution,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
@@ -52,50 +52,48 @@ Modifications Copyright 2021 Ross Wightman
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 import math
+from typing import Optional, Tuple
 
 import torch
 from torch.optim import Optimizer
 
+from ._types import ParamsT
 
 
 class Lamb(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
 
-    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+    LAMB was proposed in:
+        - Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
+        - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ
 
-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
-        lr (float, optional): learning rate. (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its norm. (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability. (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
-            calculating running averages of gradient. (default: True)
-        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
-        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
-        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
-            weight decay parameter (default: False)
-
-    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
-        https://arxiv.org/abs/1904.00962
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+    Args:
+        params: Iterable of parameters to optimize or dicts defining parameter groups.
+        lr: Learning rate
+        betas: Coefficients used for computing running averages of gradient and its norm.
+        eps: Term added to the denominator to improve numerical stability.
+        weight_decay: Weight decay
+        grad_averaging: Whether apply (1-beta2) to grad when calculating running averages of gradient.
+        max_grad_norm: Value used to clip global grad norm.
+        trust_clip: Enable LAMBC trust ratio clipping.
+        always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter.
+        caution: Apply caution.
    """
 
    def __init__(
            self,
-            params,
-            lr=1e-3,
-            bias_correction=True,
-            betas=(0.9, 0.999),
-            eps=1e-6,
-            weight_decay=0.01,
-            grad_averaging=True,
-            max_grad_norm=1.0,
-            trust_clip=False,
-            always_adapt=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            bias_correction: bool = True,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-6,
+            weight_decay: float = 0.01,
+            grad_averaging: bool = True,
+            max_grad_norm: Optional[float] = 1.0,
+            trust_clip: bool = False,
+            always_adapt: bool = False,
+            caution: bool = False,
    ):
        defaults = dict(
            lr=lr,
@@ -107,9 +105,15 @@ class Lamb(Optimizer):
            max_grad_norm=max_grad_norm,
            trust_clip=trust_clip,
            always_adapt=always_adapt,
+            caution=caution,
        )
        super().__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
    def _get_clip_grad_norm(self):
        max_grad_norm = self.defaults['max_grad_norm']
        if max_grad_norm is None:
@@ -187,6 +191,12 @@ class Lamb(Optimizer):
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)
 
+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (update * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    update.mul_(mask)
+
                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)
@@ -12,9 +12,13 @@ Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs
 }
 
 """
+from typing import Tuple
+
 from torch.optim import Optimizer
 import torch
 
+from ._types import ParamsT
+
 
 class LaProp(Optimizer):
    """ LaProp Optimizer
@@ -23,11 +27,12 @@ class LaProp(Optimizer):
    """
    def __init__(
            self,
-            params,
-            lr=4e-4,
-            betas=(0.9, 0.999),
-            eps=1e-15,
-            weight_decay=0,
+            params: ParamsT,
+            lr: float = 4e-4,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-15,
+            weight_decay: float = 0.,
+            caution: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
@@ -42,6 +47,7 @@ class LaProp(Optimizer):
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
+            caution=caution,
        )
        super(LaProp, self).__init__(params, defaults)
 
@@ -101,7 +107,14 @@ class LaProp(Optimizer):
                step_of_this_grad = grad / denom
                exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)
 
+                if group['caution']:
+                    # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                    mask = (exp_avg * grad > 0).to(grad.dtype)
+                    mask.div_(mask.mean().clamp_(min=1e-3))
+                    exp_avg = exp_avg * mask
+
                p.add_(exp_avg, alpha=-step_size)
 
                if group['weight_decay'] != 0:
                    p.add_(p, alpha=-group['weight_decay'])
@@ -16,33 +16,35 @@ Original Impl: https://github.com/google/automl/tree/master/lion
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
+from ._types import ParamsT
 
 
 class Lion(Optimizer):
    r"""Implements Lion algorithm."""
 
    def __init__(
            self,
-            params,
-            lr=1e-4,
-            betas=(0.9, 0.99),
-            weight_decay=0.0,
-            maximize=False,
-            foreach=None,
+            params: ParamsT,
+            lr: float = 1e-4,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            weight_decay: float = 0.0,
+            caution: bool = False,
+            maximize: bool = False,
+            foreach: Optional[bool] = None,
    ):
        """Initialize the hyperparameters.
 
        Args:
-            params (iterable): iterable of parameters to optimize or dicts defining
-                parameter groups
-            lr (float, optional): learning rate (default: 1e-4)
-            betas (Tuple[float, float], optional): coefficients used for computing
-                running averages of gradient and its square (default: (0.9, 0.99))
-            weight_decay (float, optional): weight decay coefficient (default: 0)
+            params: iterable of parameters to optimize or dicts defining parameter groups
+            lr: learning rate
+            betas: coefficients used for computing running averages of gradient and its square
+            weight_decay: weight decay coefficient
+            caution: apply caution
        """
 
        if not 0.0 <= lr:
@@ -55,6 +57,7 @@ class Lion(Optimizer):
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
+            caution=caution,
            foreach=foreach,
            maximize=maximize,
        )
@@ -63,6 +66,7 @@ class Lion(Optimizer):
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
+            group.setdefault('caution', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
 
@@ -71,8 +75,7 @@ class Lion(Optimizer):
        """Performs a single optimization step.
 
        Args:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
+            closure: A closure that reevaluates the model and returns the loss.
 
        Returns:
            the loss.
@@ -112,6 +115,7 @@ class Lion(Optimizer):
                beta2=beta2,
                lr=group['lr'],
                weight_decay=group['weight_decay'],
+                caution=group['caution'],
                maximize=group['maximize'],
                foreach=group['foreach'],
            )
@@ -132,6 +136,7 @@ def lion(
        beta2: float,
        lr: float,
        weight_decay: float,
+        caution: bool,
 ):
    r"""Functional API that performs Lion algorithm computation.
    """
@@ -155,6 +160,7 @@ def lion(
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
+        caution=caution,
        maximize=maximize,
    )
 
@@ -168,6 +174,7 @@ def _single_tensor_lion(
        beta2: float,
        lr: float,
        weight_decay: float,
+        caution: bool,
        maximize: bool,
 ):
    for i, param in enumerate(params):
@@ -183,8 +190,15 @@ def _single_tensor_lion(
        param.mul_(1 - lr * weight_decay)
 
        # Weight update
-        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
-        param.add_(torch.sign(update), alpha=-lr)
+        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
+
+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            mask = (update * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            update.mul_(mask)
+
+        param.add_(update, alpha=-lr)
 
        # Decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)
@@ -199,6 +213,7 @@ def _multi_tensor_lion(
        beta2: float,
        lr: float,
        weight_decay: float,
+        caution: bool,
        maximize: bool,
 ):
    if len(params) == 0:
@@ -217,8 +232,17 @@ def _multi_tensor_lion(
    # Weight update
    updates = torch._foreach_mul(exp_avgs, beta1)
    torch._foreach_add_(updates, grads, alpha=1 - beta1)
+    updates = [u.sign_() for u in updates]
+
+    if caution:
+        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+        masks = torch._foreach_mul(updates, grads)
+        masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+        mask_scale = [m.mean() for m in masks]
+        torch._foreach_maximum_(mask_scale, 1e-3)
+        torch._foreach_div_(masks, mask_scale)
+        torch._foreach_mul_(updates, masks)
 
-    updates = [u.sign() for u in updates]
    torch._foreach_add_(params, updates, alpha=-lr)
 
    # Decay the momentum running average coefficient
@@ -5,44 +5,43 @@ Based on simplified algorithm in https://github.com/mlcommons/algorithmic-effici
 Added multi-tensor (foreach) path.
 """
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch
 from torch import Tensor

+from ._types import ParamsT


 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
-    r"""Implements NAdamW algorithm.
+    """ Implements NAdamW algorithm.

     See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
     the NAdam algorithm (there is also a comment in the code which highlights
     the only difference of NAdamW and AdamW).

     For further details regarding the algorithm we refer to
-    `Decoupled Weight Decay Regularization`_.
+    - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
+    - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ

     Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-    .. _Decoupled Weight Decay Regularization:
-        https://arxiv.org/abs/1711.05101
-    .. _On the Convergence of Adam and Beyond:
-        https://openreview.net/forum?id=ryQu7f-RZ
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        betas: coefficients used for computing running averages of gradient and its square
+        eps: term added to the denominator to improve numerical stability
+        weight_decay: weight decay coefficient
+        caution: enable caution
     """

     def __init__(
             self,
-            params,
-            lr=1e-3,
-            betas=(0.9, 0.999),
-            eps=1e-8,
-            weight_decay=1e-2,
+            params: ParamsT,
+            lr: float = 1e-3,
+            betas: Tuple[float, float] = (0.9, 0.999),
+            eps: float = 1e-8,
+            weight_decay: float = 1e-2,
+            caution: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             capturable: bool = False,
@@ -62,6 +61,7 @@ class NAdamW(torch.optim.Optimizer):
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            caution=caution,
             foreach=foreach,
             maximize=maximize,
             capturable=capturable,
@@ -71,11 +71,12 @@ class NAdamW(torch.optim.Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         state_values = list(self.state.values())
-        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
-            state_values[0]['step'])
+        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
         if not step_is_tensor:
             for s in state_values:
                 s['step'] = torch.tensor(float(s['step']))
+        for group in self.param_groups:
+            group.setdefault('caution', False)

     @torch.no_grad()
     def step(self, closure=None):
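
The group.setdefault('caution', False) in __setstate__ exists so optimizer state saved before this commit still loads cleanly; older checkpoints carry no 'caution' key in their param groups. A sketch of that scenario, with the same assumed import path as above:

import torch.nn as nn
from timm.optim.nadamw import NAdamW  # assumed import path

model = nn.Linear(4, 4)
opt = NAdamW(model.parameters(), lr=1e-3)
state = opt.state_dict()
state['param_groups'][0].pop('caution', None)  # simulate a pre-caution checkpoint
opt.load_state_dict(state)  # __setstate__ backfills caution=False for each group
assert opt.param_groups[0]['caution'] is False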
@@ -133,6 +134,7 @@ class NAdamW(torch.optim.Optimizer):
                 lr=group['lr'],
                 weight_decay=group['weight_decay'],
                 eps=group['eps'],
+                caution=group['caution'],
                 maximize=group['maximize'],
                 capturable=group['capturable'],
             )
@@ -154,6 +156,7 @@ def nadamw(
         lr: float,
         weight_decay: float,
         eps: float,
+        caution: bool,
         maximize: bool,
 ) -> None:
     r"""Functional API that performs NAdamW algorithm computation.
@@ -183,6 +186,7 @@ def nadamw(
         lr=lr,
         weight_decay=weight_decay,
         eps=eps,
+        caution=caution,
         maximize=maximize,
         capturable=capturable,
     )
@@ -200,6 +204,7 @@ def _single_tensor_nadamw(
         lr: float,
         weight_decay: float,
         eps: float,
+        caution: bool,
         maximize: bool,
         capturable: bool
 ):
@@ -238,6 +243,14 @@ def _single_tensor_nadamw(
             exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

             denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)

+            if caution:
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                # FIXME not 100% sure if this remains capturable?
+                mask = (exp_avg * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                exp_avg.mul_(mask)
+
             param.addcdiv_(exp_avg, denom)
         else:
             step = step_t.item()
@@ -246,11 +259,17 @@ def _single_tensor_nadamw(
             step_size = lr / bias_correction1
             bias_correction2_sqrt = math.sqrt(bias_correction2)

-            # Only difference between NAdamW and AdamW in this implementation.
+            # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
             # The official PyTorch implementation of NAdam uses a different algorithm.
             exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

             denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

+            if caution:
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                mask = (exp_avg * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                exp_avg.mul_(mask)
+
             param.addcdiv_(exp_avg, denom, value=-step_size)


@@ -266,6 +285,7 @@ def _multi_tensor_nadamw(
         lr: float,
         weight_decay: float,
         eps: float,
+        caution: bool,
         maximize: bool,
         capturable: bool,
 ):
@@ -322,12 +342,22 @@ def _multi_tensor_nadamw(

         exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
         torch._foreach_div_(
-            exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
+            exp_avg_sq_sqrt,
+            torch._foreach_mul(bias_correction2_sqrt, step_size)
         )
         eps_over_step_size = torch._foreach_div(step_size, eps)
         torch._foreach_reciprocal_(eps_over_step_size)
         denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(exp_avgs, grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]  # capturable?
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            torch._foreach_mul_(exp_avgs, masks)
+
         torch._foreach_addcdiv_(params, exp_avgs, denom)
     else:
         bias_correction1 = [1 - beta1 ** step.item() for step in state_steps]
@@ -337,7 +367,7 @@ def _multi_tensor_nadamw(

         bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2]

-        # Only difference between NAdamW and AdamW in this implementation.
+        # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation.
         # The official PyTorch implementation of NAdam uses a different algorithm.
         exp_avgs = torch._foreach_mul(exp_avgs, beta1)
         torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
@@ -346,4 +376,13 @@ def _multi_tensor_nadamw(
         torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
         denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

+        if caution:
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(exp_avgs, grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)]
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            torch._foreach_mul_(exp_avgs, masks)
+
         torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
@@ -10,6 +10,8 @@ Modifications Copyright 2021 Ross Wightman
 import torch
 from torch.optim import Optimizer

+from ._types import ParamsT
+

 class RMSpropTF(Optimizer):
     """Implements RMSprop algorithm (TensorFlow style epsilon)
@@ -28,34 +30,31 @@ class RMSpropTF(Optimizer):
     The centered version first appears in `Generating Sequences
     With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.

-    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-2)
-        momentum (float, optional): momentum factor (default: 0)
-        alpha (float, optional): smoothing (decay) constant (default: 0.9)
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-10)
-        centered (bool, optional) : if ``True``, compute the centered RMSProp,
-            the gradient is normalized by an estimation of its variance
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
-        lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
-            update as per defaults in Tensorflow
+    Args:
+        params: iterable of parameters to optimize or dicts defining parameter groups
+        lr: learning rate
+        momentum: momentum factor
+        alpha: smoothing (decay) constant
+        eps: term added to the denominator to improve numerical stability
+        centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
+        weight_decay: weight decay (L2 penalty) (default: 0)
+        decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101
+        lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow
+        caution: apply caution

     """

     def __init__(
             self,
-            params,
-            lr=1e-2,
-            alpha=0.9,
-            eps=1e-10,
-            weight_decay=0,
-            momentum=0.,
-            centered=False,
-            decoupled_decay=False,
-            lr_in_momentum=True,
+            params: ParamsT,
+            lr: float = 1e-2,
+            alpha: float = 0.9,
+            eps: float = 1e-10,
+            weight_decay: float = 0,
+            momentum: float = 0.,
+            centered: bool = False,
+            decoupled_decay: bool = False,
+            lr_in_momentum: bool = True,
+            caution: bool = False,
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
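
As with NAdamW, caution is exposed as a plain constructor flag on RMSpropTF. A hedged usage sketch, assuming the module path timm.optim.rmsprop_tf (not shown in this diff):

import torch.nn as nn
from timm.optim.rmsprop_tf import RMSpropTF  # assumed import path

model = nn.Linear(8, 8)
opt = RMSpropTF(
    model.parameters(),
    lr=1e-2,
    alpha=0.9,
    momentum=0.9,
    decoupled_decay=True,
    caution=True,   # new flag added in this diff
)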
@@ -77,6 +76,7 @@ class RMSpropTF(Optimizer):
             weight_decay=weight_decay,
             decoupled_decay=decoupled_decay,
             lr_in_momentum=lr_in_momentum,
+            caution=caution,
         )
         super(RMSpropTF, self).__init__(params, defaults)

@@ -85,6 +85,7 @@ class RMSpropTF(Optimizer):
         for group in self.param_groups:
             group.setdefault('momentum', 0)
             group.setdefault('centered', False)
+            group.setdefault('caution', False)

     @torch.no_grad()
     def step(self, closure=None):
@@ -142,13 +143,25 @@ class RMSpropTF(Optimizer):

                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
-                    # Tensorflow accumulates the LR scaling in the momentum buffer
+                    buf.mul_(group['momentum'])
+
+                    def _apply_caution(_m, _g):
+                        # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                        mask = (_m * _g > 0).to(_g.dtype)
+                        mask.div_(mask.mean().clamp_(min=1e-3))
+                        return _m * mask
+
                     if group['lr_in_momentum']:
-                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
+                        # Tensorflow accumulates the LR scaling in the momentum buffer
+                        buf.addcdiv_(grad, avg, value=group['lr'])
+                        if group['caution']:
+                            buf = _apply_caution(buf, grad)
                         p.add_(-buf)
                     else:
                         # PyTorch scales the param update by LR
-                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
+                        buf.addcdiv_(grad, avg)
+                        if group['caution']:
+                            buf = _apply_caution(buf, grad)
                         p.add_(buf, alpha=-group['lr'])
                 else:
                     p.addcdiv_(grad, avg, value=-group['lr'])
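
The step() restructure splits the old fused buf.mul_(momentum).addcdiv_(...) into a separate momentum multiply followed by addcdiv_, so the caution mask can be applied to the fully accumulated buffer before the parameter update. With caution off, the two forms are numerically identical, which this small sketch (illustrative only, not part of the diff) checks:

import torch

grad, avg = torch.randn(5), torch.rand(5) + 1e-3
buf_old, buf_new = torch.ones(5), torch.ones(5)
momentum, lr = 0.9, 1e-2

buf_old.mul_(momentum).addcdiv_(grad, avg, value=lr)  # old fused form
buf_new.mul_(momentum)                                # new split form
buf_new.addcdiv_(grad, avg, value=lr)
assert torch.allclose(buf_old, buf_new)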
@@ -1,4 +1,5 @@
-from functools import update_wrapper, wraps
+from typing import List, Optional

 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
@@ -8,7 +9,7 @@ try:
 except ImportError:
     has_recent_pt = False

-from typing import List, Optional
+from ._types import ParamsT

 __all__ = ['SGDW', 'sgdw']

@@ -16,13 +17,14 @@ __all__ = ['SGDW', 'sgdw']
 class SGDW(Optimizer):
     def __init__(
             self,
-            params,
-            lr=1e-3,
-            momentum=0,
-            dampening=0,
-            weight_decay=0,
-            nesterov=False,
+            params: ParamsT,
+            lr: float = 1e-3,
+            momentum: float = 0.,
+            dampening: float = 0.,
+            weight_decay: float = 0.,
+            nesterov: bool = False,
             *,
+            caution: bool = False,
             maximize: bool = False,
             foreach: Optional[bool] = None,
             differentiable: bool = False,
@@ -40,6 +42,7 @@ class SGDW(Optimizer):
             dampening=dampening,
             weight_decay=weight_decay,
             nesterov=nesterov,
+            caution=caution,
             maximize=maximize,
             foreach=foreach,
             differentiable=differentiable,
@@ -51,18 +54,19 @@ class SGDW(Optimizer):
     def __setstate__(self, state):
         super().__setstate__(state)
         for group in self.param_groups:
+            group.setdefault('caution', False)
             group.setdefault('nesterov', False)
             group.setdefault('maximize', False)
             group.setdefault('foreach', None)
             group.setdefault('differentiable', False)

-    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
+    def _init_group(self, group, params_with_grad, grads, momentum_buffer_list):
         has_sparse_grad = False

         for p in group['params']:
             if p.grad is not None:
                 params_with_grad.append(p)
-                d_p_list.append(p.grad)
+                grads.append(p.grad)
                 if p.grad.is_sparse:
                     has_sparse_grad = True

@@ -91,20 +95,21 @@ class SGDW(Optimizer):

         for group in self.param_groups:
             params_with_grad = []
-            d_p_list = []
+            grads = []
             momentum_buffer_list = []

-            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)
+            has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list)

             sgdw(
                 params_with_grad,
-                d_p_list,
+                grads,
                 momentum_buffer_list,
                 weight_decay=group['weight_decay'],
                 momentum=group['momentum'],
                 lr=group['lr'],
                 dampening=group['dampening'],
                 nesterov=group['nesterov'],
+                caution=group['caution'],
                 maximize=group['maximize'],
                 has_sparse_grad=has_sparse_grad,
                 foreach=group['foreach'],
@@ -120,7 +125,7 @@ class SGDW(Optimizer):

 def sgdw(
         params: List[Tensor],
-        d_p_list: List[Tensor],
+        grads: List[Tensor],
         momentum_buffer_list: List[Optional[Tensor]],
         # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
         # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
@@ -132,6 +137,7 @@ def sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool
 ):
     r"""Functional API that performs SGD algorithm computation.
@@ -159,13 +165,14 @@ def sgdw(

     func(
         params,
-        d_p_list,
+        grads,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
+        caution=caution,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize,
     )
@@ -173,7 +180,7 @@ def sgdw(

 def _single_tensor_sgdw(
         params: List[Tensor],
-        d_p_list: List[Tensor],
+        grads: List[Tensor],
         momentum_buffer_list: List[Optional[Tensor]],
         *,
         weight_decay: float,
@@ -181,11 +188,12 @@ def _single_tensor_sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool,
         has_sparse_grad: bool
 ):
     for i, param in enumerate(params):
-        d_p = d_p_list[i] if not maximize else -d_p_list[i]
+        grad = grads[i] if not maximize else -grads[i]

         param.mul_(1. - lr * weight_decay)

@@ -193,17 +201,25 @@ def _single_tensor_sgdw(
             buf = momentum_buffer_list[i]

             if buf is None:
-                buf = torch.clone(d_p).detach()
+                buf = torch.clone(grad).detach()
                 momentum_buffer_list[i] = buf
             else:
-                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                buf.mul_(momentum).add_(grad, alpha=1 - dampening)

+            if caution:
                 if nesterov:
-                    d_p = d_p.add(buf, alpha=momentum)
+                    buf = grad.add(buf, alpha=momentum)
+                # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+                mask = (buf * grad > 0).to(grad.dtype)
+                mask.div_(mask.mean().clamp_(min=1e-3))
+                grad = buf * mask
             else:
-                d_p = buf
+                if nesterov:
+                    grad = grad.add(buf, alpha=momentum)
+                else:
+                    grad = buf

-        param.add_(d_p, alpha=-lr)
+        param.add_(grad, alpha=-lr)


 def _multi_tensor_sgdw(
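
In the cautious branch of _single_tensor_sgdw the Nesterov lookahead is computed out-of-place (buf = grad.add(buf, alpha=momentum)) so the raw gradient is still available for the sign comparison. A condensed standalone sketch of that branch (function and variable names are illustrative, not part of the diff):

import torch

def cautious_sgd_update(grad, buf, momentum=0.9, nesterov=True):
    # buf already holds the updated momentum buffer for this step
    if nesterov:
        buf = grad.add(buf, alpha=momentum)  # out-of-place: keep grad intact for the mask
    mask = (buf * grad > 0).to(grad.dtype)
    mask.div_(mask.mean().clamp_(min=1e-3))
    return buf * mask  # applied as param.add_(update, alpha=-lr)

grad, buf = torch.randn(6), torch.randn(6)
print(cautious_sgd_update(grad, buf))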
@@ -216,6 +232,7 @@ def _multi_tensor_sgdw(
         lr: float,
         dampening: float,
         nesterov: bool,
+        caution: bool,
         maximize: bool,
         has_sparse_grad: bool
 ):
@@ -258,6 +275,18 @@ def _multi_tensor_sgdw(

                 bufs.append(buf)

+        if caution:
+            if nesterov:
+                # Can't do nesterov in-place if we want to compare against orig grad for caution
+                bufs = torch._foreach_add(device_grads, bufs, alpha=momentum)
+            # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
+            masks = torch._foreach_mul(bufs, device_grads)
+            masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
+            mask_scale = [m.mean() for m in masks]
+            torch._foreach_maximum_(mask_scale, 1e-3)
+            torch._foreach_div_(masks, mask_scale)
+            device_grads = torch._foreach_mul(bufs, masks)
+        else:
             if nesterov:
                 torch._foreach_add_(device_grads, bufs, alpha=momentum)
             else:
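
The multi-tensor path makes the same point explicit in its comment: under caution the Nesterov add must not be in-place, otherwise device_grads would be overwritten before the mask compares the lookahead buffers against the original gradients. A small sketch of the difference (illustrative, not part of the diff):

import torch

device_grads = [torch.randn(3)]
bufs = [torch.randn(3)]
momentum = 0.9
orig = [g.clone() for g in device_grads]

# Out-of-place lookahead, as in the cautious branch: gradients stay intact,
# so the caution mask can still compare against them.
lookahead = torch._foreach_add(device_grads, bufs, alpha=momentum)
masks = [(l * g > 0).to(g.dtype) for l, g in zip(lookahead, device_grads)]
assert torch.equal(device_grads[0], orig[0])
print(masks[0])

# In-place variant, as in the non-cautious branch: gradients are overwritten.
torch._foreach_add_(device_grads, bufs, alpha=momentum)
assert not torch.equal(device_grads[0], orig[0])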