Replace radam & nadam impls with the torch.optim versions; rename the legacy adamw, nadam, radam impls in timm. Update optim factory & tests.
parent 7b54eab807
commit a024ab3170
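For orientation (not part of the diff): a sketch of how the renamed registry entries are expected to resolve through the optimizer factory after this change, assuming a timm build that contains it. The toy nn.Linear model is illustrative only.

# Illustrative only: expected name resolution after this commit.
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # stand-in for a real network

opt_radam = create_optimizer_v2(model, 'radam', lr=1e-3)         # now torch.optim.RAdam
opt_legacy = create_optimizer_v2(model, 'radamlegacy', lr=1e-3)  # old timm reference impl
opt_nadam = create_optimizer_v2(model, 'nadam', lr=1e-3)         # now torch.optim.NAdam

print(type(opt_radam).__name__, type(opt_legacy).__name__, type(opt_nadam).__name__)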
@@ -376,10 +376,17 @@ def test_adam(optimizer):

@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
    # FIXME rosenbrock is not passing for ADOPT
    # _test_rosenbrock(
    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    # )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT


@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
def test_adan(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT

@@ -432,6 +439,14 @@ def test_lamb(optimizer):
    _test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['laprop'])
def test_laprop(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
    )
    _test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
    _test_rosenbrock(

@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
    _test_model(optimizer, dict(lr=1e-2))


@pytest.mark.parametrize('optimizer', ['mars'])
def test_mars(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT


@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
    _test_rosenbrock(

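The _test_rosenbrock and _test_model helpers used above live elsewhere in test_optim.py and are not shown in this diff. A minimal, hypothetical sketch of the kind of check a Rosenbrock-based helper performs:

# Hypothetical sketch; the real _test_rosenbrock helper may differ in details.
import torch

def rosenbrock(xy):
    x, y = xy
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def check_rosenbrock(constructor, steps=2000):
    params = torch.tensor([1.5, 1.5], requires_grad=True)
    opt = constructor([params])
    initial = rosenbrock(params).item()
    for _ in range(steps):
        opt.zero_grad()
        loss = rosenbrock(params)
        loss.backward()
        opt.step()
    # the Rosenbrock function has its minimum at (1, 1)
    assert rosenbrock(params).item() < initial

# e.g. check_rosenbrock(lambda p: torch.optim.RAdam(p, lr=1e-3))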
@@ -3,22 +3,27 @@ from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamW
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW

# bring torch optim into timm.optim namespace for consistency
from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD

from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
    create_optimizer_v2, create_optimizer, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers

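Given the re-export line above, timm.optim and torch.optim should now expose the same classes for the replaced optimizers, while the legacy impls remain importable under their new names. A small sanity sketch, assuming a build that includes this change:

# Sanity sketch: torch classes and renamed legacy classes side by side in timm.optim.
import torch
import timm.optim

assert timm.optim.RAdam is torch.optim.RAdam  # re-exported torch class
assert timm.optim.NAdam is torch.optim.NAdam
from timm.optim import RAdamLegacy, NAdamLegacy, AdamWLegacy  # old impls under new names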
@@ -19,17 +19,20 @@ from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW

@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='adam',
            opt_class=optim.Adam,
            description='torch.optim Adam (Adaptive Moment Estimation)',
            description='torch.optim.Adam, Adaptive Moment Estimation',
            has_betas=True
        ),
        OptimInfo(
            name='adamw',
            opt_class=optim.AdamW,
            description='torch.optim Adam with decoupled weight decay regularization',
            description='torch.optim.AdamW, Adam with decoupled weight decay',
            has_betas=True
        ),
        OptimInfo(
            name='adamwlegacy',
            opt_class=AdamWLegacy,
            description='legacy impl of AdamW that pre-dates inclusion to torch.optim',
            has_betas=True
        ),
        OptimInfo(

@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='nadam',
            opt_class=Nadam,
            description='Adam with Nesterov momentum',
            opt_class=torch.optim.NAdam,
            description='torch.optim.NAdam, Adam with Nesterov momentum',
            has_betas=True
        ),
        OptimInfo(
            name='nadamlegacy',
            opt_class=NAdamLegacy,
            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
            has_betas=True
        ),
        OptimInfo(
            name='nadamw',
            opt_class=NAdamW,
            description='Adam with Nesterov momentum and decoupled weight decay',
            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
            has_betas=True
        ),
        OptimInfo(
            name='radam',
            opt_class=RAdam,
            description='Rectified Adam with variance adaptation',
            opt_class=torch.optim.RAdam,
            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
            has_betas=True
        ),
        OptimInfo(
            name='radamlegacy',
            opt_class=RAdamLegacy,
            description='legacy impl of RAdam that predates inclusion in torch.optim',
            has_betas=True
        ),
        OptimInfo(
            name='radamw',
            opt_class=torch.optim.RAdam,
            description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
            has_betas=True,
            defaults={'decoupled_weight_decay': True}
        ),
        OptimInfo(
            name='adamax',
            opt_class=optim.Adamax,
            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
            has_betas=True
        ),
        OptimInfo(

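The 'radamw' entry above reuses torch.optim.RAdam and injects decoupled_weight_decay=True through its defaults. Assuming the factory forwards those defaults as constructor kwargs and a PyTorch version that supports the flag, the result is roughly equivalent to constructing the optimizer directly:

# Rough equivalence sketch for the 'radamw' entry, assuming OptimInfo.defaults are
# merged into the constructor kwargs by the factory.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# roughly what create_optimizer_v2(model, 'radamw', lr=1e-3, weight_decay=0.01) should build:
opt = torch.optim.RAdam(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
    decoupled_weight_decay=True,  # AdamW-style decoupled decay flag on torch.optim.RAdam
)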
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='adadelta',
            opt_class=optim.Adadelta,
            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
        ),
        OptimInfo(
            name='adagrad',
            opt_class=optim.Adagrad,
            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
            defaults={'eps': 1e-8}
        ),
        OptimInfo(

@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
            has_betas=True,
            second_order=True,
        ),
        OptimInfo(
            name='laprop',
            opt_class=LaProp,
            description='Separating Momentum and Adaptivity in Adam',
            has_betas=True,
        ),
        OptimInfo(
            name='lion',
            opt_class=Lion,

@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
            has_momentum=True,
            defaults={'decoupled_decay': True}
        ),
        OptimInfo(
            name='mars',
            opt_class=Mars,
            description='Unleashing the Power of Variance Reduction for Training Large Models',
            has_betas=True,
        ),
        OptimInfo(
            name='novograd',
            opt_class=NvNovoGrad,

@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='rmsprop',
            opt_class=optim.RMSprop,
            description='torch.optim RMSprop, Root Mean Square Propagation',
            description='torch.optim.RMSprop, Root Mean Square Propagation',
            has_momentum=True,
            defaults={'alpha': 0.9}
        ),

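With the new entries registered, the introspection helpers re-exported from timm.optim earlier in this diff can confirm what each name maps to. A hedged sketch; return details beyond what the diff shows are assumptions:

# Registry introspection via the helpers re-exported from timm.optim; exact return
# details beyond what this diff shows are assumptions.
from timm.optim import list_optimizers, get_optimizer_info, get_optimizer_class

names = list_optimizers()            # registered names, e.g. 'nadam', 'nadamlegacy', 'radamw', ...
info = get_optimizer_info('radamw')  # OptimInfo with description, defaults, etc.
print('radamlegacy' in names, info.description)
print(get_optimizer_class('mars'))   # resolve a registered name back to its optimizer class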
@@ -1,17 +1,18 @@
""" AdamW Optimizer
Impl copied from PyTorch master

NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
someday
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
"""
import math
import torch
from torch.optim.optimizer import Optimizer


class AdamW(Optimizer):
class AdamWLegacy(Optimizer):
    r"""Implements AdamW algorithm.

    NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

@@ -61,10 +62,10 @@ class AdamW(Optimizer):
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(AdamW, self).__init__(params, defaults)
        super(AdamWLegacy, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        super(AdamWLegacy, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

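The rename keeps the old pure-Python AdamW around as a reference. Assuming its constructor still mirrors the original torch arguments (lr, betas, eps, weight_decay, amsgrad), the legacy and builtin classes can be constructed interchangeably:

# Side-by-side construction sketch; AdamWLegacy arguments are assumed to mirror the
# original torch AdamW signature.
import torch
import torch.nn as nn
from timm.optim import AdamWLegacy

model = nn.Linear(4, 4)
kwargs = dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False)

legacy = AdamWLegacy(model.parameters(), **kwargs)          # reference impl kept in timm
builtin = torch.optim.AdamW(model.parameters(), **kwargs)   # preferred going forward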
@@ -4,9 +4,11 @@ import torch
from torch.optim.optimizer import Optimizer


class Nadam(Optimizer):
class NAdamLegacy(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:

@@ -45,7 +47,7 @@ class Nadam(Optimizer):
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super(Nadam, self).__init__(params, defaults)
        super(NAdamLegacy, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

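NAdamLegacy keeps the schedule_decay argument visible in the defaults above, while torch.optim.NAdam names what appears to be the analogous knob momentum_decay. A hedged construction sketch:

# schedule_decay (legacy) vs momentum_decay (torch.optim) for the Nesterov decay schedule;
# argument names for NAdamLegacy are assumptions based on the defaults shown above.
import torch
import torch.nn as nn
from timm.optim import NAdamLegacy

model = nn.Linear(4, 4)

legacy = NAdamLegacy(model.parameters(), lr=2e-3, schedule_decay=4e-3)
builtin = torch.optim.NAdam(model.parameters(), lr=2e-3, momentum_decay=4e-3)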
@@ -1,14 +1,19 @@
"""RAdam Optimizer.
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265

NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
"""
import math
import torch
from torch.optim.optimizer import Optimizer


class RAdam(Optimizer):
class RAdamLegacy(Optimizer):
    """ PyTorch RAdam optimizer

    NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
    """
    def __init__(
        self,
        params,

@@ -24,10 +29,10 @@ class RAdam(Optimizer):
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)]
        )
        super(RAdam, self).__init__(params, defaults)
        super(RAdamLegacy, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)
        super(RAdamLegacy, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):

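RAdamLegacy keeps the original constructor, with the per-group rectification buffer above built internally. A brief smoke-test sketch, assuming the public arguments follow the usual Liu et al. reference impl (lr, betas, eps, weight_decay):

# Minimal smoke-test sketch for the renamed legacy RAdam; argument names are assumptions.
import torch
import torch.nn as nn
from timm.optim import RAdamLegacy

model = nn.Linear(4, 1)
opt = RAdamLegacy(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0)

x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
opt.step()  # the rectification buffer entries are filled lazily as steps accumulate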