Replace radam & nadam impls with the torch.optim versions, rename the legacy adamw, nadam, radam impls in timm. Update optim factory & tests.

cautious_optim
Ross Wightman 2024-11-26 10:54:17 -08:00 committed by Ross Wightman
parent 7b54eab807
commit a024ab3170
6 changed files with 106 additions and 30 deletions
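For orientation, a hedged sketch of how the renamed registrations surface through the factory after this commit (the registry contents are defined in the _optim_factory.py changes below; the printed values are illustrative, not captured output):

from timm.optim import list_optimizers, get_optimizer_class

# Legacy reference impls stay selectable under new *legacy names.
print([name for name in list_optimizers() if 'legacy' in name])  # e.g. adamwlegacy, nadamlegacy, radamlegacy
# The short names now resolve to the torch.optim implementations.
print(get_optimizer_class('radam'))  # torch.optim.RAdam
print(get_optimizer_class('nadam'))  # torch.optim.NAdam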

View File

@@ -376,10 +376,17 @@ def test_adam(optimizer):
@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
    # FIXME rosenbrock is not passing for ADOPT
    # _test_rosenbrock(
    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    # )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT

@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
def test_adan(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step
@@ -432,6 +439,14 @@ def test_lamb(optimizer):
    _test_model(optimizer, dict(lr=1e-3))

@pytest.mark.parametrize('optimizer', ['laprop'])
def test_laprop(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
    )
    _test_model(optimizer, dict(lr=1e-2))

@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
    _test_rosenbrock(
@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
    _test_model(optimizer, dict(lr=1e-2))

@pytest.mark.parametrize('optimizer', ['mars'])
def test_mars(optimizer):
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
    )
    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step

@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
    _test_rosenbrock(
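
The `_test_rosenbrock` and `_test_model` helpers used throughout these tests are not shown in this diff; for context, here is a minimal sketch of what a Rosenbrock convergence check of this kind could look like (the function name, starting point, and tolerances are assumptions, not the project's actual helper):

import torch
from timm.optim import create_optimizer_v2

def rosenbrock(t):
    # Classic 2-D test function with its minimum at (1, 1).
    x, y = t
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

def check_rosenbrock(constructor, steps=2000):
    params = torch.tensor([1.5, 1.5], requires_grad=True)
    optimizer = constructor([params])
    for _ in range(steps):
        optimizer.zero_grad()
        loss = rosenbrock(params)
        loss.backward()
        optimizer.step()
    # Expect the iterate to end up near the known minimum.
    assert torch.allclose(params, torch.tensor([1.0, 1.0]), atol=1e-2)

# e.g. check_rosenbrock(lambda p: create_optimizer_v2(p, 'laprop', lr=1e-2))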

View File

@@ -3,22 +3,27 @@ from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamW
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
# bring torch optim into timm.optim namespace for consistency
from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD
from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
create_optimizer_v2, create_optimizer, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
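
A quick, hedged sanity check of what the re-exported names resolve to after this change (assumes a build containing this commit):

import torch
import timm.optim

# torch.optim classes are now re-exported under timm.optim for consistency.
assert timm.optim.NAdam is torch.optim.NAdam
assert timm.optim.RAdam is torch.optim.RAdam
# The previous in-repo implementations remain importable under the *Legacy names.
from timm.optim import AdamWLegacy, NAdamLegacy, RAdamLegacy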

View File

@@ -19,17 +19,20 @@ from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamWLegacy
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .laprop import LaProp
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .mars import Mars
from .nadam import NAdamLegacy
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .radam import RAdamLegacy
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='adam',
            opt_class=optim.Adam,
            description='torch.optim Adam (Adaptive Moment Estimation)',
            description='torch.optim.Adam, Adaptive Moment Estimation',
            has_betas=True
        ),
        OptimInfo(
            name='adamw',
            opt_class=optim.AdamW,
            description='torch.optim Adam with decoupled weight decay regularization',
            description='torch.optim.AdamW, Adam with decoupled weight decay',
            has_betas=True
        ),
        OptimInfo(
            name='adamwlegacy',
            opt_class=AdamWLegacy,
            description='legacy impl of AdamW that pre-dates inclusion in torch.optim',
            has_betas=True
        ),
        OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
        ),
        OptimInfo(
            name='nadam',
            opt_class=Nadam,
            description='Adam with Nesterov momentum',
            opt_class=torch.optim.NAdam,
            description='torch.optim.NAdam, Adam with Nesterov momentum',
            has_betas=True
        ),
        OptimInfo(
            name='nadamlegacy',
            opt_class=NAdamLegacy,
            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
            has_betas=True
        ),
        OptimInfo(
            name='nadamw',
            opt_class=NAdamW,
            description='Adam with Nesterov momentum and decoupled weight decay',
            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
            has_betas=True
        ),
        OptimInfo(
            name='radam',
            opt_class=RAdam,
            description='Rectified Adam with variance adaptation',
            opt_class=torch.optim.RAdam,
            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
            has_betas=True
        ),
        OptimInfo(
            name='radamlegacy',
            opt_class=RAdamLegacy,
            description='legacy impl of RAdam that pre-dates inclusion in torch.optim',
            has_betas=True
        ),
        OptimInfo(
            name='radamw',
            opt_class=torch.optim.RAdam,
            description='torch.optim.RAdam with decoupled_weight_decay=True, Rectified Adam with variance adaptation and decoupled weight decay',
            has_betas=True,
            defaults={'decoupled_weight_decay': True}
        ),
        OptimInfo(
            name='adamax',
            opt_class=optim.Adamax,
            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
            has_betas=True
        ),
        OptimInfo(
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='adadelta',
            opt_class=optim.Adadelta,
            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
        ),
        OptimInfo(
            name='adagrad',
            opt_class=optim.Adagrad,
            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
            defaults={'eps': 1e-8}
        ),
        OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
            has_betas=True,
            second_order=True,
        ),
        OptimInfo(
            name='laprop',
            opt_class=LaProp,
            description='Separating Momentum and Adaptivity in Adam',
            has_betas=True,
        ),
        OptimInfo(
            name='lion',
            opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
            has_momentum=True,
            defaults={'decoupled_decay': True}
        ),
        OptimInfo(
            name='mars',
            opt_class=Mars,
            description='Unleashing the Power of Variance Reduction for Training Large Models',
            has_betas=True,
        ),
        OptimInfo(
            name='novograd',
            opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
        OptimInfo(
            name='rmsprop',
            opt_class=optim.RMSprop,
            description='torch.optim RMSprop, Root Mean Square Propagation',
            description='torch.optim.RMSprop, Root Mean Square Propagation',
            has_momentum=True,
            defaults={'alpha': 0.9}
        ),
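
Tying the registrations above together, a hedged usage sketch of how the new names resolve through create_optimizer_v2 (the model is a stand-in; the assertions reflect the opt_class values registered above):

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2, NAdamLegacy

model = nn.Linear(10, 2)

opt = create_optimizer_v2(model, 'nadam', lr=1e-3)         # now torch.optim.NAdam
assert isinstance(opt, torch.optim.NAdam)

opt = create_optimizer_v2(model, 'radamw', lr=1e-3)        # torch.optim.RAdam with decoupled_weight_decay=True
assert isinstance(opt, torch.optim.RAdam)

opt = create_optimizer_v2(model, 'nadamlegacy', lr=1e-3)   # deprecated reference impl kept for comparison
assert isinstance(opt, NAdamLegacy)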

View File

@@ -1,17 +1,18 @@
""" AdamW Optimizer
Impl copied from PyTorch master
NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
someday
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
"""
import math
import torch
from torch.optim.optimizer import Optimizer
class AdamW(Optimizer):
class AdamWLegacy(Optimizer):
r"""Implements AdamW algorithm.
NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
@@ -61,10 +62,10 @@ class AdamW(Optimizer):
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(AdamW, self).__init__(params, defaults)
        super(AdamWLegacy, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        super(AdamWLegacy, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
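
As a reminder of what the "decoupled weight decay" in these classes refers to, an illustrative fragment (not code from the class above): the decay is applied to the parameter directly rather than being folded into the gradient.

def adamw_style_update(p, exp_avg, denom, lr, weight_decay, step_size):
    # p, exp_avg, denom are torch tensors: parameter, first moment, sqrt of second moment + eps.
    # Decoupled decay: shrink the parameter independently of the adaptive statistics.
    p.mul_(1 - lr * weight_decay)
    # Adam direction computed from the bias-corrected moment estimates only.
    p.addcdiv_(exp_avg, denom, value=-step_size)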

View File

@@ -4,9 +4,11 @@ import torch
from torch.optim.optimizer import Optimizer
class Nadam(Optimizer):
class NAdamLegacy(Optimizer):
"""Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
Arguments:
@@ -45,7 +47,7 @@ class Nadam(Optimizer):
            weight_decay=weight_decay,
            schedule_decay=schedule_decay,
        )
        super(Nadam, self).__init__(params, defaults)
        super(NAdamLegacy, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):

View File

@@ -1,14 +1,19 @@
"""RAdam Optimizer.
Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
"""
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
class RAdamLegacy(Optimizer):
""" PyTorch RAdam optimizer
NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
"""
def __init__(
self,
params,
@@ -24,10 +29,10 @@ class RAdam(Optimizer):
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)]
        )
        super(RAdam, self).__init__(params, defaults)
        super(RAdamLegacy, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)
        super(RAdamLegacy, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
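
For reference, a small illustrative computation of the variance rectification term that gives RAdam its name (following the paper linked above; the threshold mirrors the reference implementation, and this is a sketch rather than the class's actual code):

import math

def radam_rectification(beta2, step):
    # Maximum length of the approximated SMA and its value at this step.
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * step * beta2 ** step / (1 - beta2 ** step)
    if rho_t < 5:
        # Variance is not yet tractable: fall back to an SGD-with-momentum style step.
        return None
    # Rectification factor applied to the adaptive learning rate.
    return math.sqrt(
        ((rho_t - 4) * (rho_t - 2) * rho_inf) /
        ((rho_inf - 4) * (rho_inf - 2) * rho_t)
    )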