Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add NadamW based on mlcommons algorithm, added multi-tensor step
commit dab0360e00
parent fb4f220c2e
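For readers skimming the diff: below is a minimal single-tensor sketch of the NAdamW update rule in the style of the mlcommons reference algorithm the commit message cites (NAdam's Nesterov momentum combined with AdamW-style decoupled weight decay). The function name and signature are illustrative assumptions, not the committed NAdamW class, and the fused multi-tensor step this commit adds is not shown.

# Sketch only: single-tensor NAdamW-style update. Names are hypothetical;
# the committed optimizer also has a multi-tensor (foreach) step path.
import numpy as np

def nadamw_step(param, grad, exp_avg, exp_avg_sq, step,
                lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    # Decoupled weight decay: shrink the parameter directly, independent
    # of the adaptive gradient step (as in AdamW). `step` starts at 1.
    param *= 1.0 - lr * weight_decay

    # Exponential moving averages of the gradient and squared gradient.
    exp_avg[:] = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq[:] = beta2 * exp_avg_sq + (1.0 - beta2) * grad ** 2

    # Nesterov-style bias correction: the first moment is corrected one
    # step ahead and blended with the current gradient.
    m_hat = (beta1 * exp_avg / (1.0 - beta1 ** (step + 1))
             + (1.0 - beta1) * grad / (1.0 - beta1 ** step))
    v_hat = exp_avg_sq / (1.0 - beta2 ** step)

    param -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return param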
timm/optim/nadam.py
@@ -32,7 +32,12 @@ class Nadam(Optimizer):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            schedule_decay=schedule_decay,
+        )
         super(Nadam, self).__init__(params, defaults)
 
     @torch.no_grad()
timm/optim/__init__.py
@@ -22,6 +22,7 @@ from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
+from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
timm/optim/optim_factory.py
@@ -301,6 +302,8 @@ def create_optimizer_v2(
             optimizer = optim.Nadam(parameters, **opt_args)
         except AttributeError:
             optimizer = Nadam(parameters, **opt_args)
+    elif opt_lower == 'nadamw':
+        optimizer = NAdamW(parameters, **opt_args)
     elif opt_lower == 'radam':
         optimizer = RAdam(parameters, **opt_args)
     elif opt_lower == 'adamax':
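With the factory branch above in place, 'nadamw' should be selectable by its string name through timm's optimizer factory. A hedged usage sketch; the model choice and hyperparameter values are placeholders, not taken from this commit:

# Placeholder usage: pick the new optimizer by name via the factory.
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18', pretrained=False)
optimizer = create_optimizer_v2(model, opt='nadamw', lr=1e-3, weight_decay=0.01)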