From d6ad01a4cfd9e9274cfeaff274e9a2a0ce2b76df Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Thu, 18 Aug 2022 14:04:19 +0800 Subject: [PATCH] [Fix]: fix ci (#441) --- tests/test_runner/test_runner.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 40494d92..cda9bac0 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -21,13 +21,12 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, RuntimeInfoHook) from mmengine.logging import LogProcessor, MessageHub, MMLogger from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor -from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor, - MultiStepLR, OptimWrapper, OptimWrapperDict, - StepLR) +from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, + OptimWrapper, OptimWrapperDict, StepLR) from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, - OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS, - RUNNERS, Registry) + OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, + PARAM_SCHEDULERS, RUNNERS, Registry) from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.loops import _InfiniteDataloaderIterator @@ -216,6 +215,11 @@ class ToyMetric2(BaseMetric): return dict(acc=1) +@OPTIM_WRAPPERS.register_module() +class ToyOptimWrapper(OptimWrapper): + ... + + @HOOKS.register_module() class ToyHook(Hook): priority = 'Lowest' @@ -935,12 +939,13 @@ class TestRunner(TestCase): self.assertIsInstance(optim_wrapper, OptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) + # Specify the type of optimizer wrapper model = nn.Linear(1, 1) optimizer = SGD(model.parameters(), lr=0.1) optim_wrapper_cfg = dict( - optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2) + optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) - self.assertIsInstance(optim_wrapper, AmpOptimWrapper) + self.assertIsInstance(optim_wrapper, ToyOptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) self.assertEqual(optim_wrapper._accumulative_counts, 2)