diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 40494d92..cda9bac0 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -21,13 +21,12 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, RuntimeInfoHook) from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor -from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor, - MultiStepLR, OptimWrapper, OptimWrapperDict, - StepLR) +from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, + OptimWrapper, OptimWrapperDict, StepLR) from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, - OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS, - RUNNERS, Registry) + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, + PARAM_SCHEDULERS, RUNNERS, Registry) from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.loops import _InfiniteDataloaderIterator @@ -216,6 +215,11 @@ class ToyMetric2(BaseMetric): return dict(acc=1) +@OPTIM_WRAPPERS.register_module() +class ToyOptimWrapper(OptimWrapper): + ... + + @HOOKS.register_module() class ToyHook(Hook): priority = 'Lowest' @@ -935,12 +939,13 @@ class TestRunner(TestCase): self.assertIsInstance(optim_wrapper, OptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) + # Specify the type of optimizer wrapper model = nn.Linear(1, 1) optimizer = SGD(model.parameters(), lr=0.1) optim_wrapper_cfg = dict( - optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2) + optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - self.assertIsInstance(optim_wrapper, AmpOptimWrapper) + self.assertIsInstance(optim_wrapper, ToyOptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) self.assertEqual(optim_wrapper._accumulative_counts, 2)