mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix]: fix ci (#441)
This commit is contained in:
parent
e08b9031fc
commit
d6ad01a4cf
@ -21,13 +21,12 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
||||
RuntimeInfoHook)
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||
from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
|
||||
MultiStepLR, OptimWrapper, OptimWrapperDict,
|
||||
StepLR)
|
||||
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||
OptimWrapper, OptimWrapperDict, StepLR)
|
||||
from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
||||
RUNNERS, Registry)
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||
PARAM_SCHEDULERS, RUNNERS, Registry)
|
||||
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
||||
Runner, TestLoop, ValLoop)
|
||||
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
||||
@ -216,6 +215,11 @@ class ToyMetric2(BaseMetric):
|
||||
return dict(acc=1)
|
||||
|
||||
|
||||
@OPTIM_WRAPPERS.register_module()
|
||||
class ToyOptimWrapper(OptimWrapper):
|
||||
...
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ToyHook(Hook):
|
||||
priority = 'Lowest'
|
||||
@ -935,12 +939,13 @@ class TestRunner(TestCase):
|
||||
self.assertIsInstance(optim_wrapper, OptimWrapper)
|
||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||
|
||||
# Specify the type of optimizer wrapper
|
||||
model = nn.Linear(1, 1)
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper_cfg = dict(
|
||||
optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
|
||||
optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2)
|
||||
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||
self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
|
||||
self.assertIsInstance(optim_wrapper, ToyOptimWrapper)
|
||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||
self.assertEqual(optim_wrapper._accumulative_counts, 2)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user