Add NadamW based on mlcommons algorithm, added multi-tensor step

Ross Wightman 2023-06-13 20:43:44 -07:00
parent fb4f220c2e
commit dab0360e00
2 changed files with 9 additions and 1 deletion
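For background on the commit message: NAdamW as described in the MLCommons reference algorithms is essentially AdamW (Adam with decoupled weight decay) plus a Nesterov look-ahead on the first moment, and the "multi-tensor step" refers to a foreach-style variant that batches the per-parameter tensor ops. The single-tensor sketch below illustrates the update rule only; it is not the timm implementation, and the helper name nadamw_step and its default hyper-parameters are assumptions.

import math
import torch  # arguments below are expected to be torch.Tensors

def nadamw_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # Hypothetical single-tensor NAdamW update, for illustration only.
    # Decoupled (AdamW-style) weight decay applied directly to the parameter.
    param.mul_(1 - lr * weight_decay)
    # Update biased first and second moment estimates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # Nesterov look-ahead: blend the freshly updated first moment with the current gradient.
    nesterov_avg = exp_avg.mul(beta1).add(grad, alpha=1 - beta1)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(nesterov_avg, denom, value=-lr / bias_correction1)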


@@ -32,7 +32,12 @@ class Nadam(Optimizer):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            schedule_decay=schedule_decay,
+        )
         super(Nadam, self).__init__(params, defaults)
 
     @torch.no_grad()


@@ -22,6 +22,7 @@ from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
+from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
@@ -301,6 +302,8 @@ def create_optimizer_v2(
         optimizer = optim.Nadam(parameters, **opt_args)
     except AttributeError:
         optimizer = Nadam(parameters, **opt_args)
+    elif opt_lower == 'nadamw':
+        optimizer = NAdamW(parameters, **opt_args)
     elif opt_lower == 'radam':
         optimizer = RAdam(parameters, **opt_args)
     elif opt_lower == 'adamax':
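Assuming the factory change above lands as shown, 'nadamw' becomes selectable by name through timm's optimizer factory. A hedged usage sketch follows; the model choice and hyper-parameter values are placeholders, not anything taken from the commit.

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet50', pretrained=False)  # arbitrary example model
optimizer = create_optimizer_v2(model, opt='nadamw', lr=1e-3, weight_decay=0.01)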