diff --git a/tests/test_optim.py b/tests/test_optim.py
index e10ed532..58cd40e7 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -376,10 +376,17 @@ def test_adam(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
-    # FIXME rosenbrock is not passing for ADOPT
-    # _test_rosenbrock(
-    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
-    # )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
+
+
+@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
+def test_adan(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
     _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
 
 
@@ -432,6 +439,14 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['laprop'])
+def test_laprop(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
+    )
+    _test_model(optimizer, dict(lr=1e-2))
+
+
 @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
 def test_lars(optimizer):
     _test_rosenbrock(
@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
     _test_model(optimizer, dict(lr=1e-2))
 
 
+@pytest.mark.parametrize('optimizer', ['mars'])
+def test_mars(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for MARS
+
+
 @pytest.mark.parametrize('optimizer', ['novograd'])
 def test_novograd(optimizer):
     _test_rosenbrock(
diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
index 35cf7bc0..eee5f45b 100644
--- a/timm/optim/__init__.py
+++ b/timm/optim/__init__.py
@@ -3,22 +3,27 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
-from .adamw import AdamW
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
 
+# bring torch optim into timm.optim namespace for consistency
+from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD
+
 from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
     create_optimizer_v2, create_optimizer, optimizer_kwargs
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 37b5fdc0..f4784d7c 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -19,17 +19,20 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adam',
             opt_class=optim.Adam,
-            description='torch.optim Adam (Adaptive Moment Estimation)',
+            description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
             opt_class=optim.AdamW,
-            description='torch.optim Adam with decoupled weight decay regularization',
+            description='torch.optim.AdamW, Adam with decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamwlegacy',
+            opt_class=AdamWLegacy,
+            description='legacy impl of AdamW that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='nadam',
-            opt_class=Nadam,
-            description='Adam with Nesterov momentum',
+            opt_class=torch.optim.NAdam,
+            description='torch.optim.NAdam, Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamlegacy',
+            opt_class=NAdamLegacy,
+            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
             name='nadamw',
             opt_class=NAdamW,
-            description='Adam with Nesterov momentum and decoupled weight decay',
+            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
             has_betas=True
         ),
         OptimInfo(
             name='radam',
-            opt_class=RAdam,
-            description='Rectified Adam with variance adaptation',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
             has_betas=True
         ),
+        OptimInfo(
+            name='radamlegacy',
+            opt_class=RAdamLegacy,
+            description='legacy impl of RAdam that pre-dates inclusion in torch.optim',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radamw',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam with decoupled weight decay, Rectified Adam with variance adaptation',
+            has_betas=True,
+            defaults={'decoupled_weight_decay': True}
+        ),
         OptimInfo(
             name='adamax',
             opt_class=optim.Adamax,
-            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
         OptimInfo(
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adadelta',
             opt_class=optim.Adadelta,
-            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
             opt_class=optim.Adagrad,
-            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
         OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='laprop',
+            opt_class=LaProp,
+            description='Separating Momentum and Adaptivity in Adam',
+            has_betas=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_momentum=True,
             defaults={'decoupled_decay': True}
         ),
+        OptimInfo(
+            name='mars',
+            opt_class=Mars,
+            description='Unleashing the Power of Variance Reduction for Training Large Models',
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='rmsprop',
             opt_class=optim.RMSprop,
-            description='torch.optim RMSprop, Root Mean Square Propagation',
+            description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
         ),
diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py
index b755a57c..fe34609c 100644
--- a/timm/optim/adamw.py
+++ b/timm/optim/adamw.py
@@ -1,17 +1,18 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
 
-NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
-someday
+NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class AdamW(Optimizer):
+class AdamWLegacy(Optimizer):
     r"""Implements AdamW algorithm.
 
+    NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
+
     The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
     The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
 
@@ -61,10 +62,10 @@ class AdamW(Optimizer):
             weight_decay=weight_decay,
             amsgrad=amsgrad,
         )
-        super(AdamW, self).__init__(params, defaults)
+        super(AdamWLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(AdamW, self).__setstate__(state)
+        super(AdamWLegacy, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py
index 892262cf..46f6150b 100644
--- a/timm/optim/nadam.py
+++ b/timm/optim/nadam.py
@@ -4,9 +4,11 @@ import torch
 from torch.optim.optimizer import Optimizer
 
 
-class Nadam(Optimizer):
+class NAdamLegacy(Optimizer):
     """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
 
     Arguments:
@@ -45,7 +47,7 @@ class Nadam(Optimizer):
             weight_decay=weight_decay,
             schedule_decay=schedule_decay,
         )
-        super(Nadam, self).__init__(params, defaults)
+        super(NAdamLegacy, self).__init__(params, defaults)
 
     @torch.no_grad()
     def step(self, closure=None):
diff --git a/timm/optim/radam.py b/timm/optim/radam.py
index d6f8d30a..9b12b98a 100644
--- a/timm/optim/radam.py
+++ b/timm/optim/radam.py
@@ -1,14 +1,19 @@
 """RAdam Optimizer.
 Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
 Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
+
+NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class RAdam(Optimizer):
+class RAdamLegacy(Optimizer):
+    """ PyTorch RAdam optimizer
+    NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
+    """
 
     def __init__(
             self,
             params,
@@ -24,10 +29,10 @@
             weight_decay=weight_decay,
             buffer=[[None, None, None] for _ in range(10)]
         )
-        super(RAdam, self).__init__(params, defaults)
+        super(RAdamLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(RAdam, self).__setstate__(state)
+        super(RAdamLegacy, self).__setstate__(state)
 
     @torch.no_grad()
     def step(self, closure=None):
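Below is a minimal usage sketch of the optimizers this patch registers, based only on the factory calls exercised in tests/test_optim.py above (create_optimizer_v2 with the new 'laprop' and 'mars' names) and the list_optimizers export from timm/optim/__init__.py; the toy model and learning rates are illustrative assumptions mirroring the tests, not recommended settings.

import torch
import torch.nn as nn

from timm.optim import create_optimizer_v2, list_optimizers

# The registry should now report the names added/renamed in _optim_factory.py.
new_names = {'laprop', 'mars', 'adamwlegacy', 'nadamlegacy', 'radamlegacy', 'radamw'}
print(sorted(new_names.intersection(list_optimizers())))

model = nn.Linear(10, 2)  # toy stand-in for a real network

# LaProp via the factory; lr mirrors test_laprop above.
opt = create_optimizer_v2(model, 'laprop', lr=1e-2)

# MARS via the factory; lr mirrors test_mars above.
# opt = create_optimizer_v2(model, 'mars', lr=1e-3)

# One dummy step to confirm the optimizer is wired up.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()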