[Fix]: fix ci (#441)

This commit is contained in:
Mashiro 2022-08-18 14:04:19 +08:00 committed by GitHub
parent e08b9031fc
commit d6ad01a4cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -21,13 +21,12 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
RuntimeInfoHook)
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
MultiStepLR, OptimWrapper, OptimWrapperDict,
StepLR)
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR)
from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
RUNNERS, Registry)
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, RUNNERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.loops import _InfiniteDataloaderIterator
@ -216,6 +215,11 @@ class ToyMetric2(BaseMetric):
return dict(acc=1)
@OPTIM_WRAPPERS.register_module()
class ToyOptimWrapper(OptimWrapper):
...
@HOOKS.register_module()
class ToyHook(Hook):
priority = 'Lowest'
@ -935,12 +939,13 @@ class TestRunner(TestCase):
self.assertIsInstance(optim_wrapper, OptimWrapper)
self.assertIs(optim_wrapper.optimizer, optimizer)
# Specify the type of optimizer wrapper
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper_cfg = dict(
optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2)
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
self.assertIsInstance(optim_wrapper, ToyOptimWrapper)
self.assertIs(optim_wrapper.optimizer, optimizer)
self.assertEqual(optim_wrapper._accumulative_counts, 2)