From dab0360e00b3d0b62bae0cfa0519d1a13b9d4889 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 13 Jun 2023 20:43:44 -0700
Subject: [PATCH] Add NadamW based on mlcommons algorithm, added multi-tensor step

---
 timm/optim/nadam.py         | 7 ++++++-
 timm/optim/optim_factory.py | 3 +++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py
index 6268d5d4..4e911420 100644
--- a/timm/optim/nadam.py
+++ b/timm/optim/nadam.py
@@ -32,7 +32,12 @@ class Nadam(Optimizer):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            schedule_decay=schedule_decay,
+        )
         super(Nadam, self).__init__(params, defaults)
 
     @torch.no_grad()
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index d8702a12..af316dee 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -22,6 +22,7 @@ from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
+from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
@@ -301,6 +302,8 @@ def create_optimizer_v2(
             optimizer = optim.Nadam(parameters, **opt_args)
         except AttributeError:
             optimizer = Nadam(parameters, **opt_args)
+    elif opt_lower == 'nadamw':
+        optimizer = NAdamW(parameters, **opt_args)
     elif opt_lower == 'radam':
         optimizer = RAdam(parameters, **opt_args)
     elif opt_lower == 'adamax':
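
Not part of the patch itself: a minimal usage sketch for the new 'nadamw' factory entry, assuming a timm build that includes this change (the 'nadamw' case in create_optimizer_v2 and the timm/optim/nadamw.py module the new import refers to). The model and hyperparameter values below are placeholders for illustration only.

# Usage sketch (not part of the patch): create the new NAdamW optimizer
# through the factory entry added above. Placeholder model and hyperparameters.
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # stand-in model for illustration
optimizer = create_optimizer_v2(model, opt='nadamw', lr=1e-3, weight_decay=1e-2)

# one ordinary training step to show the optimizer in use
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()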