mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add NadamW based on mlcommons algorithm, added multi-tensor step
This commit is contained in:
parent
fb4f220c2e
commit
dab0360e00
@ -32,7 +32,12 @@ class Nadam(Optimizer):
|
|||||||
if not 0.0 <= lr:
|
if not 0.0 <= lr:
|
||||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
|
lr=lr,
|
||||||
|
betas=betas,
|
||||||
|
eps=eps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
schedule_decay=schedule_decay,
|
||||||
|
)
|
||||||
super(Nadam, self).__init__(params, defaults)
|
super(Nadam, self).__init__(params, defaults)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -22,6 +22,7 @@ from .lion import Lion
|
|||||||
from .lookahead import Lookahead
|
from .lookahead import Lookahead
|
||||||
from .madgrad import MADGRAD
|
from .madgrad import MADGRAD
|
||||||
from .nadam import Nadam
|
from .nadam import Nadam
|
||||||
|
from .nadamw import NAdamW
|
||||||
from .nvnovograd import NvNovoGrad
|
from .nvnovograd import NvNovoGrad
|
||||||
from .radam import RAdam
|
from .radam import RAdam
|
||||||
from .rmsprop_tf import RMSpropTF
|
from .rmsprop_tf import RMSpropTF
|
||||||
@ -301,6 +302,8 @@ def create_optimizer_v2(
|
|||||||
optimizer = optim.Nadam(parameters, **opt_args)
|
optimizer = optim.Nadam(parameters, **opt_args)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
optimizer = Nadam(parameters, **opt_args)
|
optimizer = Nadam(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'nadamw':
|
||||||
|
optimizer = NAdamW(parameters, **opt_args)
|
||||||
elif opt_lower == 'radam':
|
elif opt_lower == 'radam':
|
||||||
optimizer = RAdam(parameters, **opt_args)
|
optimizer = RAdam(parameters, **opt_args)
|
||||||
elif opt_lower == 'adamax':
|
elif opt_lower == 'adamax':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user