mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Support Custom Runner (#258)
* support custom runner * change build_runner_from_cfg * refine docstring * refine docstring
This commit is contained in:
parent
94c7c3be2c
commit
65bc95036c
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .default_scope import DefaultScope
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .registry import Registry, build_from_cfg, build_runner_from_cfg
|
||||
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS,
|
||||
METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
|
||||
OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS,
|
||||
@ -14,5 +14,6 @@ __all__ = [
|
||||
'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS',
|
||||
'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS',
|
||||
'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope',
|
||||
'traverse_registry_tree', 'count_registered_modules'
|
||||
'traverse_registry_tree', 'count_registered_modules',
|
||||
'build_runner_from_cfg'
|
||||
]
|
||||
|
@ -10,6 +10,73 @@ from ..utils import ManagerMixin, is_seq_of
|
||||
from .default_scope import DefaultScope
|
||||
|
||||
|
||||
def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config],
|
||||
registry: 'Registry') -> Any:
|
||||
"""Build a Runner object.
|
||||
Examples:
|
||||
>>> from mmengine import Registry, build_runner_from_cfg
|
||||
>>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg)
|
||||
>>> @RUNNERS.register_module()
|
||||
>>> class CustomRunner(Runner):
|
||||
>>> def setup_env(env_cfg):
|
||||
>>> pass
|
||||
>>> cfg = dict(runner_type='CustomRunner', ...)
|
||||
>>> custom_runner = RUNNERS.build(cfg)
|
||||
|
||||
Args:
|
||||
cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key
|
||||
exists, it will be used to build a custom runner. Otherwise, it
|
||||
will be used to build a default runner.
|
||||
registry (:obj:`Registry`): The registry to search the type from.
|
||||
|
||||
Returns:
|
||||
object: The constructed runner object.
|
||||
"""
|
||||
from ..logging.logger import MMLogger
|
||||
|
||||
assert isinstance(
|
||||
cfg,
|
||||
(dict, ConfigDict, Config
|
||||
)), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}'
|
||||
assert isinstance(
|
||||
registry, Registry), ('registry should be a mmengine.Registry object',
|
||||
f'but got {type(registry)}')
|
||||
|
||||
args = cfg.copy()
|
||||
obj_type = args.pop('runner_type', 'mmengine.Runner')
|
||||
if isinstance(obj_type, str):
|
||||
runner_cls = registry.get(obj_type)
|
||||
if runner_cls is None:
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name} registry. '
|
||||
f'Please check whether the value of `{obj_type}` is correct or'
|
||||
' it was registered as expected. More details can be found at'
|
||||
' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501
|
||||
)
|
||||
elif inspect.isclass(obj_type):
|
||||
runner_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
|
||||
try:
|
||||
runner = runner_cls.from_cfg(args) # type: ignore
|
||||
logger: MMLogger = MMLogger.get_current_instance()
|
||||
logger.info(
|
||||
f'An `{runner_cls.__name__}` instance is built ' # type: ignore
|
||||
f'from registry, its implementation can be found in'
|
||||
f'{runner_cls.__module__}') # type: ignore
|
||||
return runner
|
||||
|
||||
except Exception as e:
|
||||
# Normal TypeError does not print class name.
|
||||
cls_location = '/'.join(
|
||||
runner_cls.__module__.split('.')) # type: ignore
|
||||
raise type(e)(
|
||||
f'class `{runner_cls.__name__}` in ' # type: ignore
|
||||
f'{cls_location}.py: {e}')
|
||||
|
||||
|
||||
def build_from_cfg(
|
||||
cfg: Union[dict, ConfigDict, Config],
|
||||
registry: 'Registry',
|
||||
|
@ -6,10 +6,10 @@ More datails can be found at
|
||||
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
"""
|
||||
|
||||
from .registry import Registry
|
||||
from .registry import Registry, build_runner_from_cfg
|
||||
|
||||
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
|
||||
RUNNERS = Registry('runner')
|
||||
RUNNERS = Registry('runner', build_func=build_runner_from_cfg)
|
||||
# manage runner constructors that define how to initialize runners
|
||||
RUNNER_CONSTRUCTORS = Registry('runner constructor')
|
||||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||
|
@ -30,7 +30,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
|
||||
build_optim_wrapper)
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
||||
VISUALIZERS, DefaultScope,
|
||||
RUNNERS, VISUALIZERS, DefaultScope,
|
||||
count_registered_modules)
|
||||
from mmengine.registry.root import LOG_PROCESSORS
|
||||
from mmengine.utils import (TORCH_VERSION, digit_version,
|
||||
@ -49,6 +49,7 @@ ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
|
||||
OptimWrapperType = Union[OptimWrapper, OptimWrapperDict]
|
||||
|
||||
|
||||
@RUNNERS.register_module()
|
||||
class Runner:
|
||||
"""A training helper for PyTorch.
|
||||
|
||||
|
@ -24,7 +24,7 @@ from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||
from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
|
||||
MODEL_WRAPPERS, MODELS,
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS,
|
||||
Registry)
|
||||
RUNNERS, Registry)
|
||||
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
||||
Runner, TestLoop, ValLoop)
|
||||
from mmengine.runner.priority import Priority, get_priority
|
||||
@ -215,6 +215,41 @@ class CustomLogProcessor(LogProcessor):
|
||||
self._check_custom_cfg()
|
||||
|
||||
|
||||
@RUNNERS.register_module()
|
||||
class CustomRunner(Runner):
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
work_dir,
|
||||
train_dataloader=None,
|
||||
val_dataloader=None,
|
||||
test_dataloader=None,
|
||||
train_cfg=None,
|
||||
val_cfg=None,
|
||||
test_cfg=None,
|
||||
optimizer=None,
|
||||
param_scheduler=None,
|
||||
val_evaluator=None,
|
||||
test_evaluator=None,
|
||||
default_hooks=None,
|
||||
custom_hooks=None,
|
||||
load_from=None,
|
||||
resume=False,
|
||||
launcher='none',
|
||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||
log_processor=None,
|
||||
log_level='INFO',
|
||||
visualizer=None,
|
||||
default_scope=None,
|
||||
randomness=dict(seed=None),
|
||||
experiment_name=None,
|
||||
cfg=None):
|
||||
pass
|
||||
|
||||
def setup_env(self, env_cfg):
|
||||
pass
|
||||
|
||||
|
||||
def collate_fn(data_batch):
|
||||
return data_batch
|
||||
|
||||
@ -1511,3 +1546,17 @@ class TestRunner(TestCase):
|
||||
self.assertTrue(runner._has_loaded)
|
||||
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
def test_build_runner(self):
|
||||
# No need to test other cases which have been tested in
|
||||
# `test_build_from_cfg`
|
||||
# test custom runner
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_runner1'
|
||||
cfg.runner_type = 'CustomRunner'
|
||||
assert isinstance(RUNNERS.build(cfg), CustomRunner)
|
||||
|
||||
# test default runner
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_runner2'
|
||||
assert isinstance(RUNNERS.build(cfg), Runner)
|
||||
|
Loading…
x
Reference in New Issue
Block a user