mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Support Custom Runner (#258)
* support custom runner * change build_runner_from_cfg * refine docstring * refine docstring
This commit is contained in:
parent
94c7c3be2c
commit
65bc95036c
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .default_scope import DefaultScope
|
from .default_scope import DefaultScope
|
||||||
from .registry import Registry, build_from_cfg
|
from .registry import Registry, build_from_cfg, build_runner_from_cfg
|
||||||
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
|
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
|
||||||
METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
|
METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
|
||||||
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
|
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
|
||||||
@ -14,5 +14,6 @@ __all__ = [
|
|||||||
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
|
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
|
||||||
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
|
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
|
||||||
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
|
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
|
||||||
'traverse_registry_tree', 'count_registered_modules'
|
'traverse_registry_tree', 'count_registered_modules',
|
||||||
|
'build_runner_from_cfg'
|
||||||
]
|
]
|
||||||
|
@ -10,6 +10,73 @@ from ..utils import ManagerMixin, is_seq_of
|
|||||||
from .default_scope import DefaultScope
|
from .default_scope import DefaultScope
|
||||||
|
|
||||||
|
|
||||||
|
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
||||||
|
registry: 'Registry') -> Any:
|
||||||
|
"""Build a Runner object.
|
||||||
|
Examples:
|
||||||
|
>>> from mmengine import Registry, build_runner_from_cfg
|
||||||
|
>>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
|
||||||
|
>>> @RUNNERS.register_module()
|
||||||
|
>>> class CustomRunner(Runner):
|
||||||
|
>>> def setup_env(env_cfg):
|
||||||
|
>>> pass
|
||||||
|
>>> cfg = dict(runner_type='CustomRunner', ...)
|
||||||
|
>>> custom_runner = RUNNERS.build(cfg)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
|
||||||
|
exists, it will be used to build a custom runner. Otherwise, it
|
||||||
|
will be used to build a default runner.
|
||||||
|
registry (:obj:`Registry`): The registry to search the type from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The constructed runner object.
|
||||||
|
"""
|
||||||
|
from ..logging.logger import MMLogger
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
cfg,
|
||||||
|
(dict, ConfigDict, Config
|
||||||
|
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
|
||||||
|
assert isinstance(
|
||||||
|
registry, Registry), ('registry should be a mmengine.Registry object',
|
||||||
|
f'but got {type(registry)}')
|
||||||
|
|
||||||
|
args = cfg.copy()
|
||||||
|
obj_type = args.pop('runner_type', 'mmengine.Runner')
|
||||||
|
if isinstance(obj_type, str):
|
||||||
|
runner_cls = registry.get(obj_type)
|
||||||
|
if runner_cls is None:
|
||||||
|
raise KeyError(
|
||||||
|
f'{obj_type} is not in the {registry.name} registry. '
|
||||||
|
f'Please check whether the value of `{obj_type}` is correct or'
|
||||||
|
' it was registered as expected. More details can be found at'
|
||||||
|
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||||
|
)
|
||||||
|
elif inspect.isclass(obj_type):
|
||||||
|
runner_cls = obj_type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
runner = runner_cls.from_cfg(args) # type: ignore
|
||||||
|
logger: MMLogger = MMLogger.get_current_instance()
|
||||||
|
logger.info(
|
||||||
|
f'An `{runner_cls.__name__}` instance is built ' # type: ignore
|
||||||
|
f'from registry, its implementation can be found in'
|
||||||
|
f'{runner_cls.__module__}') # type: ignore
|
||||||
|
return runner
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Normal TypeError does not print class name.
|
||||||
|
cls_location = '/'.join(
|
||||||
|
runner_cls.__module__.split('.')) # type: ignore
|
||||||
|
raise type(e)(
|
||||||
|
f'class `{runner_cls.__name__}` in ' # type: ignore
|
||||||
|
f'{cls_location}.py: {e}')
|
||||||
|
|
||||||
|
|
||||||
def build_from_cfg(
|
def build_from_cfg(
|
||||||
cfg: Union[dict, ConfigDict, Config],
|
cfg: Union[dict, ConfigDict, Config],
|
||||||
registry: 'Registry',
|
registry: 'Registry',
|
||||||
|
@ -6,10 +6,10 @@ More datails can be found at
|
|||||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .registry import Registry
|
from .registry import Registry, build_runner_from_cfg
|
||||||
|
|
||||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||||
RUNNERS = Registry('runner')
|
RUNNERS = Registry('runner', build_func=build_runner_from_cfg)
|
||||||
# manage runner constructors that define how to initialize runners
|
# manage runner constructors that define how to initialize runners
|
||||||
RUNNER_CONSTRUCTORS = Registry('runner constructor')
|
RUNNER_CONSTRUCTORS = Registry('runner constructor')
|
||||||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||||
|
@ -30,7 +30,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
|||||||
build_optim_wrapper)
|
build_optim_wrapper)
|
||||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||||
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
||||||
VISUALIZERS, DefaultScope,
|
RUNNERS, VISUALIZERS, DefaultScope,
|
||||||
count_registered_modules)
|
count_registered_modules)
|
||||||
from mmengine.registry.root import LOG_PROCESSORS
|
from mmengine.registry.root import LOG_PROCESSORS
|
||||||
from mmengine.utils import (TORCH_VERSION, digit_version,
|
from mmengine.utils import (TORCH_VERSION, digit_version,
|
||||||
@ -49,6 +49,7 @@ ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
|
|||||||
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
|
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
|
||||||
|
|
||||||
|
|
||||||
|
@RUNNERS.register_module()
|
||||||
class Runner:
|
class Runner:
|
||||||
"""A training helper for PyTorch.
|
"""A training helper for PyTorch.
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
|||||||
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
|
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
|
||||||
MODEL_WRAPPERS, MODELS,
|
MODEL_WRAPPERS, MODELS,
|
||||||
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
||||||
Registry)
|
RUNNERS, Registry)
|
||||||
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
||||||
Runner, TestLoop, ValLoop)
|
Runner, TestLoop, ValLoop)
|
||||||
from mmengine.runner.priority import Priority, get_priority
|
from mmengine.runner.priority import Priority, get_priority
|
||||||
@ -215,6 +215,41 @@ class CustomLogProcessor(LogProcessor):
|
|||||||
self._check_custom_cfg()
|
self._check_custom_cfg()
|
||||||
|
|
||||||
|
|
||||||
|
@RUNNERS.register_module()
|
||||||
|
class CustomRunner(Runner):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model,
|
||||||
|
work_dir,
|
||||||
|
train_dataloader=None,
|
||||||
|
val_dataloader=None,
|
||||||
|
test_dataloader=None,
|
||||||
|
train_cfg=None,
|
||||||
|
val_cfg=None,
|
||||||
|
test_cfg=None,
|
||||||
|
optimizer=None,
|
||||||
|
param_scheduler=None,
|
||||||
|
val_evaluator=None,
|
||||||
|
test_evaluator=None,
|
||||||
|
default_hooks=None,
|
||||||
|
custom_hooks=None,
|
||||||
|
load_from=None,
|
||||||
|
resume=False,
|
||||||
|
launcher='none',
|
||||||
|
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||||
|
log_processor=None,
|
||||||
|
log_level='INFO',
|
||||||
|
visualizer=None,
|
||||||
|
default_scope=None,
|
||||||
|
randomness=dict(seed=None),
|
||||||
|
experiment_name=None,
|
||||||
|
cfg=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_env(self, env_cfg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(data_batch):
|
def collate_fn(data_batch):
|
||||||
return data_batch
|
return data_batch
|
||||||
|
|
||||||
@ -1511,3 +1546,17 @@ class TestRunner(TestCase):
|
|||||||
self.assertTrue(runner._has_loaded)
|
self.assertTrue(runner._has_loaded)
|
||||||
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
def test_build_runner(self):
|
||||||
|
# No need to test other cases which have been tested in
|
||||||
|
# `test_build_from_cfg`
|
||||||
|
# test custom runner
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_runner1'
|
||||||
|
cfg.runner_type = 'CustomRunner'
|
||||||
|
assert isinstance(RUNNERS.build(cfg), CustomRunner)
|
||||||
|
|
||||||
|
# test default runner
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_runner2'
|
||||||
|
assert isinstance(RUNNERS.build(cfg), Runner)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user