Add NAdamW based on the MLCommons reference algorithm; add a multi-tensor step

Ross Wightman 2023-06-13 20:43:44 -07:00
parent fb4f220c2e
commit dab0360e00
2 changed files with 9 additions and 1 deletion
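For reference, a minimal per-tensor sketch of the update the new optimizer implements, assuming the MLCommons reference formulation: NAdam's Nesterov momentum blend plus AdamW-style decoupled weight decay. The function name and signature are illustrative, not the code added by this commit.

```python
import torch

def nadamw_step(
        param, grad, exp_avg, exp_avg_sq, step,
        lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # Decoupled (AdamW-style) weight decay: shrink the weights directly
    # instead of adding weight_decay * param to the gradient.
    param.mul_(1. - lr * weight_decay)

    # Standard Adam first/second moment updates.
    exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)

    bias_correction1 = 1. - beta1 ** step
    bias_correction2 = 1. - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)

    # Nesterov look-ahead: apply a beta1-weighted blend of the updated
    # momentum buffer and the raw gradient.
    step_size = lr / bias_correction1
    param.addcdiv_(exp_avg, denom, value=-step_size * beta1)
    param.addcdiv_(grad, denom, value=-step_size * (1. - beta1))
```

The Nesterov blend in the last two lines is what separates this from a plain AdamW step: part of the update uses the current gradient rather than only the momentum buffer.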

timm/optim/nadam.py

@@ -32,7 +32,12 @@ class Nadam(Optimizer):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            schedule_decay=schedule_decay,
+        )
         super(Nadam, self).__init__(params, defaults)

     @torch.no_grad()
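The commit message also mentions a multi-tensor step. A sketch of how such a step can batch the same math across a whole parameter list with PyTorch's private `torch._foreach_*` ops, as PyTorch's own optimizers do; this is an assumption about the approach, not the committed code.

```python
import torch

def nadamw_multi_tensor(
        params, grads, exp_avgs, exp_avg_sqs, step,
        lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # Each _foreach_ call operates on every tensor in the list at once,
    # avoiding a Python loop over parameters.
    torch._foreach_mul_(params, 1. - lr * weight_decay)

    # Batched moment updates.
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1. - beta1)
    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1. - beta2)

    bias_correction1 = 1. - beta1 ** step
    bias_correction2 = 1. - beta2 ** step

    denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
    denom = torch._foreach_sqrt(denom)
    torch._foreach_add_(denom, eps)

    # Same Nesterov-blended update as the per-tensor version.
    step_size = lr / bias_correction1
    torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size * beta1)
    torch._foreach_addcdiv_(params, grads, denom, value=-step_size * (1. - beta1))
```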

timm/optim/optim_factory.py

@@ -22,6 +22,7 @@ from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
+from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
@@ -301,6 +302,8 @@ def create_optimizer_v2(
             optimizer = optim.Nadam(parameters, **opt_args)
         except AttributeError:
             optimizer = Nadam(parameters, **opt_args)
+    elif opt_lower == 'nadamw':
+        optimizer = NAdamW(parameters, **opt_args)
     elif opt_lower == 'radam':
         optimizer = RAdam(parameters, **opt_args)
     elif opt_lower == 'adamax':
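With the factory branch above, the new optimizer is selectable by name. A short usage sketch; the model choice and hyperparameter values are illustrative.

```python
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18')
optimizer = create_optimizer_v2(model, opt='nadamw', lr=1e-3, weight_decay=1e-2)
```

Extra keyword arguments to `create_optimizer_v2` are forwarded to the optimizer constructor via `opt_args`, so optimizer-specific settings such as `betas` can be overridden the same way.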