mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix]: fix ci (#441)
This commit is contained in:
parent
e08b9031fc
commit
d6ad01a4cf
@ -21,13 +21,12 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
|||||||
RuntimeInfoHook)
|
RuntimeInfoHook)
|
||||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||||
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||||
from mmengine.optim import (AmpOptimWrapper, DefaultOptimWrapperConstructor,
|
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||||
MultiStepLR, OptimWrapper, OptimWrapperDict,
|
OptimWrapper, OptimWrapperDict, StepLR)
|
||||||
StepLR)
|
|
||||||
from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||||
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
|
LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
|
||||||
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
|
||||||
RUNNERS, Registry)
|
PARAM_SCHEDULERS, RUNNERS, Registry)
|
||||||
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
||||||
Runner, TestLoop, ValLoop)
|
Runner, TestLoop, ValLoop)
|
||||||
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
||||||
@ -216,6 +215,11 @@ class ToyMetric2(BaseMetric):
|
|||||||
return dict(acc=1)
|
return dict(acc=1)
|
||||||
|
|
||||||
|
|
||||||
|
@OPTIM_WRAPPERS.register_module()
|
||||||
|
class ToyOptimWrapper(OptimWrapper):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class ToyHook(Hook):
|
class ToyHook(Hook):
|
||||||
priority = 'Lowest'
|
priority = 'Lowest'
|
||||||
@ -935,12 +939,13 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(optim_wrapper, OptimWrapper)
|
self.assertIsInstance(optim_wrapper, OptimWrapper)
|
||||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||||
|
|
||||||
|
# Specify the type of optimizer wrapper
|
||||||
model = nn.Linear(1, 1)
|
model = nn.Linear(1, 1)
|
||||||
optimizer = SGD(model.parameters(), lr=0.1)
|
optimizer = SGD(model.parameters(), lr=0.1)
|
||||||
optim_wrapper_cfg = dict(
|
optim_wrapper_cfg = dict(
|
||||||
optimizer=optimizer, type='AmpOptimWrapper', accumulative_counts=2)
|
optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2)
|
||||||
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
|
||||||
self.assertIsInstance(optim_wrapper, AmpOptimWrapper)
|
self.assertIsInstance(optim_wrapper, ToyOptimWrapper)
|
||||||
self.assertIs(optim_wrapper.optimizer, optimizer)
|
self.assertIs(optim_wrapper.optimizer, optimizer)
|
||||||
self.assertEqual(optim_wrapper._accumulative_counts, 2)
|
self.assertEqual(optim_wrapper._accumulative_counts, 2)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user