diff --git a/tests/test_optim.py b/tests/test_optim.py index d70ec98d..e10ed532 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -294,7 +294,7 @@ def _build_params_dict_single(weight, bias, **kwargs): @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*'))) def test_optim_factory(optimizer): - assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer) + assert issubclass(get_optimizer_class(optimizer, bind_defaults=False), torch.optim.Optimizer) opt_info = get_optimizer_info(optimizer) assert isinstance(opt_info, OptimInfo) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 552585c9..35cf7bc0 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -12,10 +12,12 @@ from .lion import Lion from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam +from .nadamw import NAdamW from .nvnovograd import NvNovoGrad from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP +from .sgdw import SGDW from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \ create_optimizer_v2, create_optimizer, optimizer_kwargs