From 65bc95036c2a9ea3a27d3428fcc04f44ef2fb9cd Mon Sep 17 00:00:00 2001 From: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Date: Mon, 6 Jun 2022 14:33:32 +0800 Subject: [PATCH] [Enhance] Support Custom Runner (#258) * support custom runner * change build_runner_from_cfg * refine docstring * refine docstring --- mmengine/registry/__init__.py | 5 ++- mmengine/registry/registry.py | 67 ++++++++++++++++++++++++++++++++ mmengine/registry/root.py | 4 +- mmengine/runner/runner.py | 3 +- tests/test_runner/test_runner.py | 51 +++++++++++++++++++++++- 5 files changed, 124 insertions(+), 6 deletions(-) diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index b43fbb2b..bf301e2f 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .default_scope import DefaultScope -from .registry import Registry, build_from_cfg +from .registry import Registry, build_from_cfg, build_runner_from_cfg from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS, PARAM_SCHEDULERS, @@ -14,5 +14,6 @@ __all__ = [ 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS', 'OPTIM_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS', 'LOG_PROCESSORS', 'DefaultScope', - 'traverse_registry_tree', 'count_registered_modules' + 'traverse_registry_tree', 'count_registered_modules', + 'build_runner_from_cfg' ] diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index ac648b9e..9bfd25a4 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -10,6 +10,73 @@ from ..utils import ManagerMixin, is_seq_of from .default_scope import DefaultScope +def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], + registry: 'Registry') -> Any: + """Build a Runner object. + Examples: + >>> from mmengine import Registry, build_runner_from_cfg + >>> RUNNERS = Registry('runners', build_func=build_runner_from_cfg) + >>> @RUNNERS.register_module() + >>> class CustomRunner(Runner): + >>> def setup_env(env_cfg): + >>> pass + >>> cfg = dict(runner_type='CustomRunner', ...) + >>> custom_runner = RUNNERS.build(cfg) + + Args: + cfg (dict or ConfigDict or Config): Config dict. If "runner_type" key + exists, it will be used to build a custom runner. Otherwise, it + will be used to build a default runner. + registry (:obj:`Registry`): The registry to search the type from. + + Returns: + object: The constructed runner object. + """ + from ..logging.logger import MMLogger + + assert isinstance( + cfg, + (dict, ConfigDict, Config + )), f'cfg should be a dict, ConfigDict or Config, but got {type(cfg)}' + assert isinstance( + registry, Registry), ('registry should be a mmengine.Registry object', + f'but got {type(registry)}') + + args = cfg.copy() + obj_type = args.pop('runner_type', 'mmengine.Runner') + if isinstance(obj_type, str): + runner_cls = registry.get(obj_type) + if runner_cls is None: + raise KeyError( + f'{obj_type} is not in the {registry.name} registry. ' + f'Please check whether the value of `{obj_type}` is correct or' + ' it was registered as expected. More details can be found at' + ' https://mmengine.readthedocs.io/en/latest/tutorials/config.html#import-custom-python-modules' # noqa: E501 + ) + elif inspect.isclass(obj_type): + runner_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + try: + runner = runner_cls.from_cfg(args) # type: ignore + logger: MMLogger = MMLogger.get_current_instance() + logger.info( + f'An `{runner_cls.__name__}` instance is built ' # type: ignore + f'from registry, its implementation can be found in' + f'{runner_cls.__module__}') # type: ignore + return runner + + except Exception as e: + # Normal TypeError does not print class name. + cls_location = '/'.join( + runner_cls.__module__.split('.')) # type: ignore + raise type(e)( + f'class `{runner_cls.__name__}` in ' # type: ignore + f'{cls_location}.py: {e}') + + def build_from_cfg( cfg: Union[dict, ConfigDict, Config], registry: 'Registry', diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 5860167a..0abef9bc 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -6,10 +6,10 @@ More datails can be found at https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. """ -from .registry import Registry +from .registry import Registry, build_runner_from_cfg # manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` -RUNNERS = Registry('runner') +RUNNERS = Registry('runner', build_func=build_runner_from_cfg) # manage runner constructors that define how to initialize runners RUNNER_CONSTRUCTORS = Registry('runner constructor') # manage all kinds of loops like `EpochBasedTrainLoop` diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index d7e0d5a7..156b708e 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -30,7 +30,7 @@ from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler, build_optim_wrapper) from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, - VISUALIZERS, DefaultScope, + RUNNERS, VISUALIZERS, DefaultScope, count_registered_modules) from mmengine.registry.root import LOG_PROCESSORS from mmengine.utils import (TORCH_VERSION, digit_version, @@ -49,6 +49,7 @@ ParamSchedulerType = Union[List[_ParamScheduler], Dict[str, OptimWrapperType = Union[OptimWrapper, OptimWrapperDict] +@RUNNERS.register_module() class Runner: """A training helper for PyTorch. diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e43fa711..dcb2bddf 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -24,7 +24,7 @@ from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, PARAM_SCHEDULERS, - Registry) + RUNNERS, Registry) from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.priority import Priority, get_priority @@ -215,6 +215,41 @@ class CustomLogProcessor(LogProcessor): self._check_custom_cfg() +@RUNNERS.register_module() +class CustomRunner(Runner): + + def __init__(self, + model, + work_dir, + train_dataloader=None, + val_dataloader=None, + test_dataloader=None, + train_cfg=None, + val_cfg=None, + test_cfg=None, + optimizer=None, + param_scheduler=None, + val_evaluator=None, + test_evaluator=None, + default_hooks=None, + custom_hooks=None, + load_from=None, + resume=False, + launcher='none', + env_cfg=dict(dist_cfg=dict(backend='nccl')), + log_processor=None, + log_level='INFO', + visualizer=None, + default_scope=None, + randomness=dict(seed=None), + experiment_name=None, + cfg=None): + pass + + def setup_env(self, env_cfg): + pass + + def collate_fn(data_batch): return data_batch @@ -1511,3 +1546,17 @@ class TestRunner(TestCase): self.assertTrue(runner._has_loaded) self.assertIsInstance(runner.optim_wrapper.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + + def test_build_runner(self): + # No need to test other cases which have been tested in + # `test_build_from_cfg` + # test custom runner + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_build_runner1' + cfg.runner_type = 'CustomRunner' + assert isinstance(RUNNERS.build(cfg), CustomRunner) + + # test default runner + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_build_runner2' + assert isinstance(RUNNERS.build(cfg), Runner)