diff --git a/tests/test_optim.py b/tests/test_optim.py index 10142eeb..26c09a50 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -12,10 +12,9 @@ import torch from torch.testing._internal.common_utils import TestCase from torch.nn import Parameter -from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay -from timm.scheduler import PlateauLRScheduler - from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class +from timm.optim import param_groups_layer_decay, param_groups_weight_decay +from timm.scheduler import PlateauLRScheduler import importlib import os diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index c3c533cc..53921b76 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -18,4 +18,5 @@ from .rmsprop_tf import RMSpropTF from .sgdp import SGDP from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \ - create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry \ No newline at end of file + create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry +from ._param_groups import param_groups_layer_decay, param_groups_weight_decay \ No newline at end of file