Replace radam & nadam impl with torch.optim ver, rename legacy adamw, nadam, radam impl in timm. Update optim factory & tests.

Ross Wightman 2024-11-26 10:54:17 -08:00
parent 8293d3f786
commit b59058bd88
6 changed files with 106 additions and 30 deletions
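For orientation (not part of the diff), a minimal sketch of how optimizer names resolve after this change; the `nn.Linear` module and hyper-parameters are purely illustrative:

```python
# Minimal sketch, assuming timm at this commit: 'nadam'/'radam' now map to the
# torch.optim implementations, while the previous timm versions remain
# available under the new *legacy names.
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # toy module, illustration only

opt = create_optimizer_v2(model, 'radam', lr=1e-3)               # torch.optim.RAdam
opt_legacy = create_optimizer_v2(model, 'radamlegacy', lr=1e-3)  # timm RAdamLegacy
```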

View File

@@ -376,10 +376,17 @@ def test_adam(optimizer):
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
-    # FIXME rosenbrock is not passing for ADOPT
-    # _test_rosenbrock(
-    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
-    # )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for ADOPT
+
+
+@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
+def test_adan(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
     _test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for ADOPT
@@ -432,6 +439,14 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


+@pytest.mark.parametrize('optimizer', ['laprop'])
+def test_laprop(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
+    )
+    _test_model(optimizer, dict(lr=1e-2))
+
+
 @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
 def test_lars(optimizer):
     _test_rosenbrock(
@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
     _test_model(optimizer, dict(lr=1e-2))


+@pytest.mark.parametrize('optimizer', ['mars'])
+def test_mars(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1) # note no convergence in first step for ADOPT
+
+
 @pytest.mark.parametrize('optimizer', ['novograd'])
 def test_novograd(optimizer):
     _test_rosenbrock(
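The optimizer ids exercised by the new tests should also be discoverable through the factory registry updated in this commit; a hedged sketch:

```python
# Hedged sketch: 'adan', 'laprop' and 'mars' are expected to show up in the
# registry listing once the factory registrations in this commit are in place.
from timm.optim import list_optimizers, get_optimizer_class

names = list_optimizers()
assert 'laprop' in names and 'mars' in names and 'adan' in names
print(get_optimizer_class('laprop'))  # expected to reference the LaProp class
```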

View File

@@ -3,22 +3,27 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
-from .adamw import AdamW
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
+# bring torch optim into timm.optim namespace for consistency
+from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD
 from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
     create_optimizer_v2, create_optimizer, optimizer_kwargs
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
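A hedged sketch of what the new torch.optim re-exports imply for callers of `timm.optim`:

```python
# Minimal sketch: with the re-exports above, these names in timm.optim are the
# torch.optim classes themselves; the old timm impls keep their *Legacy names.
import torch
import timm.optim

assert timm.optim.NAdam is torch.optim.NAdam
assert timm.optim.RAdam is torch.optim.RAdam
assert timm.optim.AdamWLegacy is not torch.optim.AdamW  # legacy impl kept separately
```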

View File

@@ -19,17 +19,20 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adam',
             opt_class=optim.Adam,
-            description='torch.optim Adam (Adaptive Moment Estimation)',
+            description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
             opt_class=optim.AdamW,
-            description='torch.optim Adam with decoupled weight decay regularization',
+            description='torch.optim.AdamW, Adam with decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamwlegacy',
+            opt_class=AdamWLegacy,
+            description='legacy impl of AdamW that pre-dates inclusion to torch.optim',
             has_betas=True
         ),
         OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='nadam',
-            opt_class=Nadam,
-            description='Adam with Nesterov momentum',
+            opt_class=torch.optim.NAdam,
+            description='torch.optim.NAdam, Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamlegacy',
+            opt_class=NAdamLegacy,
+            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
             name='nadamw',
             opt_class=NAdamW,
-            description='Adam with Nesterov momentum and decoupled weight decay',
+            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
             has_betas=True
         ),
         OptimInfo(
             name='radam',
-            opt_class=RAdam,
-            description='Rectified Adam with variance adaptation',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
             has_betas=True
         ),
+        OptimInfo(
+            name='radamlegacy',
+            opt_class=RAdamLegacy,
+            description='legacy impl of RAdam that predates inclusion in torch.optim',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radamw',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_weight_decay': True}
+        ),
         OptimInfo(
             name='adamax',
             opt_class=optim.Adamax,
-            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
         OptimInfo(
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adadelta',
             opt_class=optim.Adadelta,
-            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
             opt_class=optim.Adagrad,
-            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
         OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='laprop',
+            opt_class=LaProp,
+            description='Separating Momentum and Adaptivity in Adam',
+            has_betas=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_momentum=True,
             defaults={'decoupled_decay': True}
         ),
+        OptimInfo(
+            name='mars',
+            opt_class=Mars,
+            description='Unleashing the Power of Variance Reduction for Training Large Models',
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='rmsprop',
             opt_class=optim.RMSprop,
-            description='torch.optim RMSprop, Root Mean Square Propagation',
+            description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
         ),
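One detail worth noting in the registrations above: `radamw` reuses `torch.optim.RAdam` and only enables `decoupled_weight_decay` via the `OptimInfo` defaults. A hedged sketch of the rough equivalence (assumes a PyTorch version whose `RAdam` accepts `decoupled_weight_decay`):

```python
# Hedged sketch: 'radamw' ~= torch.optim.RAdam with decoupled weight decay.
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(8, 8)  # toy module, illustration only

opt = create_optimizer_v2(model, 'radamw', lr=1e-3, weight_decay=0.05)
# roughly equivalent (modulo timm's param grouping) to:
ref = torch.optim.RAdam(
    model.parameters(), lr=1e-3, weight_decay=0.05, decoupled_weight_decay=True,
)
```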

View File

@@ -1,17 +1,18 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
-NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
-someday
+NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class AdamW(Optimizer):
+class AdamWLegacy(Optimizer):
     r"""Implements AdamW algorithm.
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
     The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
@@ -61,10 +62,10 @@ class AdamW(Optimizer):
             weight_decay=weight_decay,
             amsgrad=amsgrad,
         )
-        super(AdamW, self).__init__(params, defaults)
+        super(AdamWLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(AdamW, self).__setstate__(state)
+        super(AdamWLegacy, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
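For completeness, the renamed class stays importable as a reference implementation; a minimal sketch of constructing it directly (hyper-parameter values here are illustrative):

```python
# Minimal sketch: AdamWLegacy remains usable directly, while the registry's
# 'adamw' id now points at torch.optim.AdamW.
import torch
from timm.optim import AdamWLegacy

params = [torch.nn.Parameter(torch.randn(4, 4))]
opt = AdamWLegacy(params, lr=1e-3, weight_decay=1e-2)
opt.zero_grad()
```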

View File

@@ -4,9 +4,11 @@ import torch
 from torch.optim.optimizer import Optimizer
 
 
-class Nadam(Optimizer):
+class NAdamLegacy(Optimizer):
     """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
 
     Arguments:
@@ -45,7 +47,7 @@ class Nadam(Optimizer):
             weight_decay=weight_decay,
             schedule_decay=schedule_decay,
         )
-        super(Nadam, self).__init__(params, defaults)
+        super(NAdamLegacy, self).__init__(params, defaults)
 
     @torch.no_grad()
     def step(self, closure=None):

View File

@@ -1,14 +1,19 @@
 """RAdam Optimizer.
 Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
 Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
+
+NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class RAdam(Optimizer):
+class RAdamLegacy(Optimizer):
+    """ PyTorch RAdam optimizer
+
+    NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
+    """
     def __init__(
         self,
         params,
@@ -24,10 +29,10 @@ class RAdam(Optimizer):
             weight_decay=weight_decay,
             buffer=[[None, None, None] for _ in range(10)]
         )
-        super(RAdam, self).__init__(params, defaults)
+        super(RAdamLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(RAdam, self).__setstate__(state)
+        super(RAdamLegacy, self).__setstate__(state)
 
     @torch.no_grad()
     def step(self, closure=None):