Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit b59058bd88 (parent 8293d3f786)
Replace radam & nadam impl with torch.optim ver, rename legacy adamw, nadam, radam impl in timm. Update optim factory & tests.
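The user-visible effect of the rename is in the optimizer names the factory resolves: 'nadam' and 'radam' now map to the torch.optim implementations, while the previous timm classes stay reachable under 'nadamlegacy', 'radamlegacy', and 'adamwlegacy'. A minimal sketch of that behaviour (the toy model and learning rate are illustrative placeholders, not part of the commit):

import torch
import torch.nn as nn

from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # stand-in for a real timm model

# 'nadam' now resolves to the torch.optim implementation
opt = create_optimizer_v2(model, 'nadam', lr=1e-3)
assert isinstance(opt, torch.optim.NAdam)

# the legacy timm implementation remains available under the renamed entry
legacy = create_optimizer_v2(model, 'nadamlegacy', lr=1e-3)
print(type(legacy).__name__)  # expected: NAdamLegacy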
tests/test_optim.py

@@ -376,10 +376,17 @@ def test_adam(optimizer):
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
-    # FIXME rosenbrock is not passing for ADOPT
-    # _test_rosenbrock(
-    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
-    # )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
+
+
+@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
+def test_adan(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
     _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
 
 
@@ -432,6 +439,14 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['laprop'])
+def test_laprop(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
+    )
+    _test_model(optimizer, dict(lr=1e-2))
+
+
 @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
 def test_lars(optimizer):
     _test_rosenbrock(
@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
     _test_model(optimizer, dict(lr=1e-2))
 
 
+@pytest.mark.parametrize('optimizer', ['mars'])
+def test_mars(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
+
+
 @pytest.mark.parametrize('optimizer', ['novograd'])
 def test_novograd(optimizer):
     _test_rosenbrock(
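For context on what these tests exercise: a Rosenbrock check drives the optimizer over the classic curved valley f(x, y) = (1 - x)^2 + 100(y - x^2)^2 and asserts that the loss drops. The timm helpers `_test_rosenbrock` and `_test_model` are not shown in this diff; the snippet below is only an illustrative stand-in for that pattern, using a plain torch.optim optimizer:

import torch


def rosenbrock(xy: torch.Tensor) -> torch.Tensor:
    # f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2, global minimum at (1, 1)
    x, y = xy[0], xy[1]
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


def check_converges(opt_factory, steps: int = 500) -> None:
    xy = torch.tensor([-1.5, 1.5], requires_grad=True)
    optimizer = opt_factory([xy])
    start = rosenbrock(xy).item()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = rosenbrock(xy)
        loss.backward()
        optimizer.step()
    assert rosenbrock(xy).item() < start  # the objective should have decreased


# mirrors the shape of the new tests, e.g. lr=3e-3 as in test_adopt above
check_converges(lambda params: torch.optim.Adam(params, lr=3e-3))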
timm/optim/__init__.py

@@ -3,22 +3,27 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
-from .adamw import AdamW
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
 
+# bring torch optim into timm.optim namespace for consistency
+from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD
+
 from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
     create_optimizer_v2, create_optimizer, optimizer_kwargs
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
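With the renamed exports, both the legacy reference classes and the re-exported torch.optim classes are importable from timm.optim. A small sketch, assuming a build of timm at this revision:

import torch

# legacy reference implementations under their new names
from timm.optim import AdamWLegacy, NAdamLegacy, RAdamLegacy

# the subset of torch.optim re-exported above for namespace consistency
from timm.optim import Adam, NAdam, RAdam, SGD

assert NAdam is torch.optim.NAdam and RAdam is torch.optim.RAdam

param = torch.nn.Parameter(torch.zeros(4))
legacy_opt = NAdamLegacy([param], lr=1e-3)  # old timm implementation
torch_opt = NAdam([param], lr=1e-3)         # torch.optim implementation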
timm/optim/_optim_factory.py

@@ -19,17 +19,20 @@ from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adam',
             opt_class=optim.Adam,
-            description='torch.optim Adam (Adaptive Moment Estimation)',
+            description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
             opt_class=optim.AdamW,
-            description='torch.optim Adam with decoupled weight decay regularization',
+            description='torch.optim.AdamW, Adam with decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamwlegacy',
+            opt_class=AdamWLegacy,
+            description='legacy impl of AdamW that pre-dates inclusion to torch.optim',
             has_betas=True
         ),
         OptimInfo(
@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='nadam',
-            opt_class=Nadam,
-            description='Adam with Nesterov momentum',
+            opt_class=torch.optim.NAdam,
+            description='torch.optim.NAdam, Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamlegacy',
+            opt_class=NAdamLegacy,
+            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
             name='nadamw',
             opt_class=NAdamW,
-            description='Adam with Nesterov momentum and decoupled weight decay',
+            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
             has_betas=True
         ),
         OptimInfo(
             name='radam',
-            opt_class=RAdam,
-            description='Rectified Adam with variance adaptation',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
             has_betas=True
         ),
+        OptimInfo(
+            name='radamlegacy',
+            opt_class=RAdamLegacy,
+            description='legacy impl of RAdam that predates inclusion in torch.optim',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radamw',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_weight_decay': True}
+        ),
         OptimInfo(
             name='adamax',
             opt_class=optim.Adamax,
-            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
         OptimInfo(
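Note that 'radamw' reuses torch.optim.RAdam and only binds decoupled_weight_decay=True as a registry default. A hedged sketch of inspecting that mapping, assuming get_optimizer_info returns the OptimInfo records registered above:

import torch
from timm.optim import create_optimizer_v2, get_optimizer_info

info = get_optimizer_info('radamw')
print(info.opt_class, info.defaults)  # expected: torch.optim.RAdam and {'decoupled_weight_decay': True}

model = torch.nn.Linear(8, 8)
opt = create_optimizer_v2(model, 'radamw', lr=1e-3)
assert isinstance(opt, torch.optim.RAdam)
print(opt.defaults.get('decoupled_weight_decay'))  # expected: True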
@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adadelta',
             opt_class=optim.Adadelta,
-            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
             opt_class=optim.Adagrad,
-            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
         OptimInfo(
@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='laprop',
+            opt_class=LaProp,
+            description='Separating Momentum and Adaptivity in Adam',
+            has_betas=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_momentum=True,
             defaults={'decoupled_decay': True}
         ),
+        OptimInfo(
+            name='mars',
+            opt_class=Mars,
+            description='Unleashing the Power of Variance Reduction for Training Large Models',
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,
@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='rmsprop',
             opt_class=optim.RMSprop,
-            description='torch.optim RMSprop, Root Mean Square Propagation',
+            description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
         ),
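After these registry additions, the new names should be discoverable alongside the existing ones; a quick check using the factory's listing helper:

from timm.optim import list_optimizers

available = set(list_optimizers())
for name in ('adamwlegacy', 'nadamlegacy', 'radamlegacy', 'radamw', 'laprop', 'mars'):
    print(name, name in available)  # each expected to be True at this revision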
timm/optim/adamw.py

@@ -1,17 +1,18 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
 
-NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
-someday
+NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class AdamW(Optimizer):
+class AdamWLegacy(Optimizer):
     r"""Implements AdamW algorithm.
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
     The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
 
@@ -61,10 +62,10 @@ class AdamW(Optimizer):
             weight_decay=weight_decay,
             amsgrad=amsgrad,
         )
-        super(AdamW, self).__init__(params, defaults)
+        super(AdamWLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(AdamW, self).__setstate__(state)
+        super(AdamWLegacy, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
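Since AdamWLegacy only changes its name here, one way to sanity-check the deprecation path is to run it step-for-step against torch.optim.AdamW on a toy objective. This is a hedged sketch; close numerical agreement is an expectation, not something the commit asserts:

import torch
from timm.optim import AdamWLegacy

torch.manual_seed(0)
p_legacy = torch.nn.Parameter(torch.randn(16))
p_torch = torch.nn.Parameter(p_legacy.detach().clone())

opt_legacy = AdamWLegacy([p_legacy], lr=1e-2, weight_decay=1e-2)
opt_torch = torch.optim.AdamW([p_torch], lr=1e-2, weight_decay=1e-2)

for _ in range(10):
    for p, opt in ((p_legacy, opt_legacy), (p_torch, opt_torch)):
        opt.zero_grad()
        (p ** 2).sum().backward()
        opt.step()

print(torch.allclose(p_legacy, p_torch, atol=1e-6))  # expected to be close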
timm/optim/nadam.py

@@ -4,9 +4,11 @@ import torch
 from torch.optim.optimizer import Optimizer
 
 
-class Nadam(Optimizer):
+class NAdamLegacy(Optimizer):
     """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
 
     Arguments:
@@ -45,7 +47,7 @@ class Nadam(Optimizer):
             weight_decay=weight_decay,
             schedule_decay=schedule_decay,
         )
-        super(Nadam, self).__init__(params, defaults)
+        super(NAdamLegacy, self).__init__(params, defaults)
 
     @torch.no_grad()
     def step(self, closure=None):
timm/optim/radam.py

@@ -1,14 +1,19 @@
 """RAdam Optimizer.
 Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
 Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
 
+NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class RAdam(Optimizer):
+class RAdamLegacy(Optimizer):
+    """ PyTorch RAdam optimizer
+
+    NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
+    """
     def __init__(
             self,
             params,
@@ -24,10 +29,10 @@ class RAdam(Optimizer):
             weight_decay=weight_decay,
             buffer=[[None, None, None] for _ in range(10)]
         )
-        super(RAdam, self).__init__(params, defaults)
+        super(RAdamLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(RAdam, self).__setstate__(state)
+        super(RAdamLegacy, self).__setstate__(state)
 
     @torch.no_grad()
     def step(self, closure=None):