[Enhancement] Refactor Runner (#139)

* [Enhancement] Rename build_from_cfg to from_cfg

* refactor build_logger and build_message_hub

* remove time.sleep from unit tests

* minor fix

* move set_randomness from setup_env

* improve docstring

* refine comments

* print a warning message

* refine comments

* simplify the interface of build_logger
Zaida Zhou 2022-03-30 14:26:40 +08:00 committed by GitHub
parent 9a61b389e7
commit f1de071cf0
2 changed files with 240 additions and 218 deletions


@@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
DefaultScope)
from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink)
from mmengine.visualization import ComposedWriter
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
class Runner:
"""A training helper for PyTorch.
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
where the ``cfg`` usually contains training, validation, and test-related
configurations to build corresponding components. We usually use the
same config to launch training, testing, and validation tasks. However,
only some of these components are necessary at the same time, e.g.,
testing a model does not need training or validation-related components.
To avoid repeatedly modifying config, the construction of ``Runner`` adopts
lazy initialization to only initialize components when they are going to be
used. Therefore, the model is always initialized at the beginning, and
training-, validation-, and testing-related components are only initialized
when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
respectively.
Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
a dict used to build a model.
@@ -114,23 +129,21 @@ class Runner:
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
logger (MMLogger or dict, optional): A MMLogger object or a dict to
build MMLogger object. Defaults to None. If not specified, default
config will be used.
message_hub (MessageHub or dict, optional): A Messagehub object or a
dict to build MessageHub object. Defaults to None. If not
specified, default config will be used.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
dict to build a ComposedWriter object. Defaults to None. If not
specified, the default config will be used.
default_scope (str, optional): Used to reset registries location.
Defaults to None.
seed (int, optional): A number to set random modules. If not specified,
a random number will be set as seed. Defaults to None.
deterministic (bool): Whether cudnn to select deterministic algorithms.
Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
randomness (dict): Settings to make the experiment reproducible, such
as ``seed`` and ``deterministic``.
Defaults to ``dict(seed=None)``. If seed is None, a random number
will be generated and it will be broadcasted to all other processes
if in distributed environment. If ``cudnn_benchmark`` is
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will be ``False`` finally.
experiment_name (str, optional): Name of current experiment. If not
specified, timestamp will be used as ``experiment_name``.
Defaults to None.
@@ -173,13 +186,11 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
logger=dict(log_level='INFO'),
message_hub=None,
writer=dict(
name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
)
>>> runner = Runner.build_from_cfg(cfg)
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
"""
@@ -208,19 +219,17 @@
resume: bool = False,
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
logger: Optional[Union[MMLogger, Dict]] = None,
message_hub: Optional[Union[MessageHub, Dict]] = None,
log_level: str = 'INFO',
writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None,
seed: Optional[int] = None,
deterministic: bool = False,
randomness: Dict = dict(seed=None),
experiment_name: Optional[str] = None,
cfg: Optional[ConfigType] = None,
):
self._work_dir = osp.abspath(work_dir)
mmengine.mkdir_or_exist(self._work_dir)
# recursively copy the ``cfg`` because `self.cfg` will be modified
# recursively copy the `cfg` because `self.cfg` will be modified
# everywhere.
if cfg is not None:
self.cfg = copy.deepcopy(cfg)
@@ -231,21 +240,25 @@
self._iter = 0
# lazy initialization
training_related = [
train_dataloader, train_cfg, optimizer, param_scheduler
]
training_related = [train_dataloader, train_cfg, optimizer]
if not (all(item is None for item in training_related)
or all(item is not None for item in training_related)):
raise ValueError(
'train_dataloader, train_cfg, optimizer and param_scheduler '
'should be either all None or not None, but got '
'train_dataloader, train_cfg, and optimizer should be either '
'all None or not None, but got '
f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, '
f'optimizer={optimizer}, '
f'param_scheduler={param_scheduler}.')
f'optimizer={optimizer}.')
self.train_dataloader = train_dataloader
self.train_loop = train_cfg
self.optimizer = optimizer
# If there is no need to adjust learning rate, momentum, or other
# optimizer parameters, param_scheduler can be None
if param_scheduler is not None and self.optimizer is None:
raise ValueError(
'param_scheduler should be None when optimizer is None, '
f'but got {param_scheduler}')
if not isinstance(param_scheduler, Sequence):
self.param_schedulers = [param_scheduler]
else:
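
The same all-None-or-all-set check recurs for the val and test groups below; a standalone sketch of the invariant (helper name hypothetical):

    def check_all_or_none(**group):
        """Raise unless the group's values are all None or all set."""
        values = list(group.values())
        if not (all(v is None for v in values)
                or all(v is not None for v in values)):
            raise ValueError(
                f"{', '.join(group)} should be either all None or all "
                f'not None, but got {group}')

    check_all_or_none(train_dataloader=None, train_cfg=None, optimizer=None)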
@@ -256,7 +269,7 @@
for item in val_related) or all(item is not None
for item in val_related)):
raise ValueError(
'val_dataloader, val_cfg and val_evaluator should be either '
'val_dataloader, val_cfg, and val_evaluator should be either '
'all None or not None, but got '
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
f'val_evaluator={val_evaluator}')
@@ -268,8 +281,8 @@
if not (all(item is None for item in test_related)
or all(item is not None for item in test_related)):
raise ValueError(
'test_dataloader, test_cfg and test_evaluator should be either'
' all None or not None, but got '
'test_dataloader, test_cfg, and test_evaluator should be '
'either all None or not None, but got '
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
f'test_evaluator={test_evaluator}')
self.test_dataloader = test_dataloader
@@ -282,10 +295,13 @@
else:
self._distributed = True
# self._deterministic, self._seed and self._timestamp will be set in
# the ``setup_env`` method. Besides, it will also initialize the
# multi-process and (or) distributed environment.
self.setup_env(env_cfg, seed, deterministic)
# self._timestamp will be set in the ``setup_env`` method. Besides,
# it will also initialize the multi-process and (or) distributed
# environment.
self.setup_env(env_cfg)
# self._deterministic and self._seed will be set in the
# ``set_randomness`` method.
self.set_randomness(**randomness)
if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}'
@@ -296,9 +312,9 @@
else:
self._experiment_name = self.timestamp
self.logger = self.build_logger(logger)
self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction
self.message_hub = self.build_message_hub(message_hub)
self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer)
# Used to reset registries location. See :meth:`Registry.build` for
@@ -333,7 +349,7 @@
self.dump_config()
@classmethod
def build_from_cfg(cls, cfg: ConfigType) -> 'Runner':
def from_cfg(cls, cfg: ConfigType) -> 'Runner':
"""Build a runner from config.
Args:
@@ -363,12 +379,11 @@
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
logger=cfg.get('log_cfg'),
message_hub=cfg.get('message_hub'),
log_level=cfg.get('log_level', 'INFO'),
writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'),
seed=cfg.get('seed'),
deterministic=cfg.get('deterministic', False),
randomness=cfg.get('randomness', dict(seed=None)),
experiment_name=cfg.get('experiment_name'),
cfg=cfg,
)
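
Note that reproducibility now travels through the single ``randomness`` key instead of top-level ``seed`` and ``deterministic``; a hedged sketch of the lookup that ``from_cfg`` performs:

    cfg = dict(randomness=dict(seed=2022, deterministic=True))
    randomness = cfg.get('randomness', dict(seed=None))  # same default as the signature
    # the constructor then calls self.set_randomness(**randomness)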
@@ -439,10 +454,7 @@
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
def setup_env(self,
env_cfg: Dict,
seed: Optional[int],
deterministic: bool = False) -> None:
def setup_env(self, env_cfg: Dict) -> None:
"""Setup environment.
An example of ``env_cfg``::
@@ -458,17 +470,7 @@
Args:
env_cfg (dict): Config for setting environment.
seed (int, optional): A number to set random modules. If not
specified, a random number will be set as seed.
Defaults to None.
deterministic (bool): Whether cudnn to select deterministic
algorithms. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
"""
self._deterministic = deterministic
self._seed = seed
if env_cfg.get('cudnn_benchmark'):
torch.backends.cudnn.benchmark = True
@@ -490,9 +492,6 @@
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
time.localtime(timestamp.item()))
# set random seeds
self._set_random_seed()
def _set_multi_processing(self,
mp_start_method: str = 'fork',
opencv_num_threads: int = 0) -> None:
@@ -546,16 +545,20 @@
'optimal performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
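
For reference, a plausible shape of ``env_cfg`` covering the keys consumed above (values illustrative; ``mp_cfg`` follows the ``_set_multi_processing`` defaults):

    env_cfg = dict(
        cudnn_benchmark=False,  # keep False when randomness asks for determinism
        mp_cfg=dict(
            mp_start_method='fork',
            opencv_num_threads=0),
        dist_cfg=dict(backend='nccl'))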
def _set_random_seed(self) -> None:
def set_randomness(self, seed, deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results.
Warning:
Results cannot be guaranteed to be reproducible if ``seed`` is
None because :meth:`set_randomness` will generate a random seed
when launching a new experiment.
See https://pytorch.org/docs/stable/notes/randomness.html for details.
Args:
seed (int, optional): A number to set random modules. If None, a
random seed will be generated and synchronized across all
processes.
deterministic (bool): Whether to set the deterministic option for
the CUDNN backend, i.e., set ``torch.backends.cudnn.deterministic``
to True and ``torch.backends.cudnn.benchmark`` to False.
Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
"""
self._deterministic = deterministic
self._seed = seed
if self._seed is None:
self._seed = sync_random_seed()
@@ -563,69 +566,62 @@
np.random.seed(self._seed)
torch.manual_seed(self._seed)
torch.cuda.manual_seed_all(self._seed)
if self._deterministic:
torch.backends.cudnn.deterministic = True
if deterministic:
if torch.backends.cudnn.benchmark:
warnings.warn(
'torch.backends.cudnn.benchmark is going to be set to '
'`False` so that cuDNN deterministically selects an '
'algorithm')
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
torch.use_deterministic_algorithms(True)
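
Outside the Runner, the same seeding recipe can be reproduced in a few lines; a minimal single-process sketch mirroring ``set_randomness`` (Python's own ``random`` module is seeded too, a conventional extra step not shown in this hunk):

    import random

    import numpy as np
    import torch

    def seed_everything(seed: int, deterministic: bool = False) -> None:
        # Seed every RNG that commonly affects training results.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        if deterministic:
            # benchmark=False forces cuDNN to pick algorithms deterministically
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True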
def build_logger(self,
logger: Optional[Union[MMLogger,
Dict]] = None) -> MMLogger:
log_level: Union[int, str] = 'INFO',
log_file: str = None,
**kwargs) -> MMLogger:
"""Build a global asscessable MMLogger.
Args:
logger (MMLogger or dict, optional): A MMLogger object or a dict to
build MMLogger object. If ``logger`` is a MMLogger object, just
returns itself. If not specified, default config will be used
to build MMLogger object. Defaults to None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
log_file (str, optional): Path of the file to save logs.
Defaults to None.
**kwargs: Remaining parameters passed to ``MMLogger``.
Returns:
MMLogger: An MMLogger object built from the given arguments.
"""
if isinstance(logger, MMLogger):
return logger
elif logger is None:
logger = dict(
name=self._experiment_name,
log_level='INFO',
log_file=osp.join(self.work_dir,
f'{self._experiment_name}.log'))
elif isinstance(logger, dict):
# ensure the logger config contains the name key
logger.setdefault('name', self._experiment_name)
else:
raise TypeError(
'logger should be MMLogger object, a dict or None, '
f'but got {logger}')
if log_file is None:
log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
return MMLogger.get_instance(**logger)
log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
log_cfg.setdefault('name', self._experiment_name)
def build_message_hub(
self,
message_hub: Optional[Union[MessageHub,
Dict]] = None) -> MessageHub:
return MMLogger.get_instance(**log_cfg) # type: ignore
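
With the simplified interface, callers pass plain keyword arguments instead of a prebuilt object; a hedged usage sketch (file path illustrative):

    >>> logger = runner.build_logger(log_level='DEBUG',
    ...                              log_file='temp_dir/debug.log')
    >>> logger.debug('only visible at DEBUG level')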
def build_message_hub(self,
message_hub: Optional[Dict] = None) -> MessageHub:
"""Build a global asscessable MessageHub.
Args:
message_hub (MessageHub or dict, optional): A MessageHub object or
a dict to build MessageHub object. If ``message_hub`` is a
MessageHub object, just returns itself. If not specified,
default config will be used to build MessageHub object.
Defaults to None.
message_hub (dict, optional): A dict to build a MessageHub object.
If not specified, a default config will be used. Defaults to
None.
Returns:
MessageHub: A MessageHub object built from ``message_hub``.
"""
if isinstance(message_hub, MessageHub):
return message_hub
elif message_hub is None:
if message_hub is None:
message_hub = dict(name=self._experiment_name)
elif isinstance(message_hub, dict):
# ensure message_hub contains the name key
message_hub.setdefault('name', self._experiment_name)
else:
raise TypeError(
'message_hub should be MessageHub object, a dict or None, '
f'but got {message_hub}')
f'message_hub should be dict or None, but got {message_hub}')
return MessageHub.get_instance(**message_hub)
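
A short usage sketch; ``update_info`` is assumed to be MessageHub's write API here:

    >>> hub = runner.build_message_hub(dict(name='demo_hub'))
    >>> hub.update_info('epoch', 0)  # assumed API, for illustration only
    >>> hub = runner.build_message_hub()  # falls back to experiment_name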


@@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import os.path as osp
import shutil
import tempfile
import time
from unittest import TestCase
import torch
@@ -229,8 +227,6 @@ class TestRunner(TestCase):
optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook'))
time.sleep(1)
def tearDown(self):
shutil.rmtree(self.temp_dir)
@@ -238,89 +234,99 @@
# 1. test arguments
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init1'
cfg.pop('train_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg)
# all of training related configs are None
# all training-related configs are None and param_scheduler should
# also be None
cfg.experiment_name = 'test_init2'
cfg.pop('train_dataloader')
cfg.pop('optimizer')
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# avoid different runners having same timestamp
time.sleep(1)
# all of training related configs are not None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init3'
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# all training-related configs are not None and param_scheduler
# can be None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init4'
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# param_scheduler should be None when optimizer is None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init5'
cfg.pop('train_cfg')
cfg.pop('train_dataloader')
cfg.pop('optimizer')
with self.assertRaisesRegex(ValueError, 'should be None'):
runner = Runner(**cfg)
# 1.2 val_dataloader, val_evaluator, val_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init6'
cfg.pop('val_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg)
time.sleep(1)
cfg.experiment_name = 'test_init7'
cfg.pop('val_dataloader')
cfg.pop('val_evaluator')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
time.sleep(1)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init8'
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# 1.3 test_dataloader, test_evaluator and test_cfg
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init9'
cfg.pop('test_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
runner = Runner(**cfg)
time.sleep(1)
cfg.experiment_name = 'test_init10'
cfg.pop('test_dataloader')
cfg.pop('test_evaluator')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
time.sleep(1)
# 1.4 test env params
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init11'
runner = Runner(**cfg)
self.assertFalse(runner.distributed)
self.assertFalse(runner.deterministic)
time.sleep(1)
# 1.5 message_hub, logger and writer
# they are all not specified
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init12'
runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertIsInstance(runner.writer, ComposedWriter)
time.sleep(1)
# they are all specified
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.logger = dict(name='test_logger')
cfg.message_hub = dict(name='test_message_hub')
cfg.experiment_name = 'test_init13'
cfg.log_level = 'INFO'
cfg.writer = dict(name='test_writer')
runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.logger.instance_name, 'test_logger')
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name, 'test_message_hub')
self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.writer.instance_name, 'test_writer')
assert runner.distributed is False
assert runner.seed is not None
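
The pattern above is why the ``time.sleep(1)`` calls could be removed: global instances (logger, message hub, writer) are keyed by name, so giving each runner a distinct ``experiment_name`` avoids the timestamp collisions the sleeps papered over. The sketch each test now follows:

    cfg = copy.deepcopy(self.epoch_based_cfg)
    cfg.experiment_name = 'some_unique_name'  # unique per constructed runner
    runner = Runner(**cfg)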
@@ -358,7 +364,6 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
time.sleep(1)
# 4. initialize runner with objects rather than config
model = ToyModel()
optimizer = SGD(
@@ -385,84 +390,66 @@ class TestRunner(TestCase):
test_dataloader=test_dataloader,
test_evaluator=ToyEvaluator1(),
default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2])
custom_hooks=[toy_hook2],
experiment_name='test_init14')
runner.train()
runner.test()
# 5. test `dump_config`
# TODO
def test_build_from_cfg(self):
runner = Runner.build_from_cfg(cfg=self.epoch_based_cfg)
def test_from_cfg(self):
runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
self.assertIsInstance(runner, Runner)
def test_setup_env(self):
# TODO
pass
def test_logger(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
def test_build_logger(self):
self.epoch_based_cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.experiment_name, runner.logger.instance_name)
self.assertEqual(runner.logger.level, logging.NOTSET)
# input is a MMLogger object
self.assertEqual(
id(runner.build_logger(runner.logger)), id(runner.logger))
# input is None
runner._experiment_name = 'logger_name1'
logger = runner.build_logger(None)
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name1')
# input is a dict
log_cfg = dict(name='logger_name2')
logger = runner.build_logger(log_cfg)
logger = runner.build_logger(name='test_build_logger2')
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name2')
self.assertEqual(logger.instance_name, 'test_build_logger2')
# the name is not specified, so experiment_name is used as the name
runner._experiment_name = 'logger_name3'
log_cfg = dict()
logger = runner.build_logger(log_cfg)
runner._experiment_name = 'test_build_logger3'
logger = runner.build_logger()
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'logger should be'):
runner.build_logger('invalid-type')
self.assertEqual(logger.instance_name, 'test_build_logger3')
def test_build_message_hub(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
self.epoch_based_cfg.experiment_name = 'test_build_message_hub1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name,
runner.experiment_name)
# input is a MessageHub object
self.assertEqual(
id(runner.build_message_hub(runner.message_hub)),
id(runner.message_hub))
# input is a dict
message_hub_cfg = dict(name='message_hub_name1')
message_hub_cfg = dict(name='test_build_message_hub2')
message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name1')
self.assertEqual(message_hub.instance_name, 'test_build_message_hub2')
# input is a dict but does not contain name key
runner._experiment_name = 'message_hub_name2'
runner._experiment_name = 'test_build_message_hub3'
message_hub_cfg = dict()
message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name2')
self.assertEqual(message_hub.instance_name, 'test_build_message_hub3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'message_hub should be'):
runner.build_message_hub('invalid-type')
def test_build_writer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
self.epoch_based_cfg.experiment_name = 'test_build_writer1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.experiment_name, runner.writer.instance_name)
@@ -471,17 +458,17 @@ class TestRunner(TestCase):
id(runner.build_writer(runner.writer)), id(runner.writer))
# input is a dict
writer_cfg = dict(name='writer_name1')
writer_cfg = dict(name='test_build_writer2')
writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name1')
self.assertEqual(writer.instance_name, 'test_build_writer2')
# input is a dict but does not contain name key
runner._experiment_name = 'writer_name2'
runner._experiment_name = 'test_build_writer3'
writer_cfg = dict()
writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name2')
self.assertEqual(writer.instance_name, 'test_build_writer3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'writer should be'):
@@ -501,12 +488,16 @@ class TestRunner(TestCase):
type='ToyScheduler', milestones=[1, 2])
self.epoch_based_cfg.default_scope = 'toy'
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_scope'
runner = Runner.from_cfg(cfg)
runner.train()
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
def test_build_model(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_model'
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, ToyModel)
# input should be a nn.Module object or dict
@@ -526,12 +517,15 @@ class TestRunner(TestCase):
# TODO: test on distributed environment
# custom model wrapper
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_wrap_model'
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
runner = Runner.build_from_cfg(cfg)
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, CustomModelWrapper)
def test_build_optimizer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_optimizer'
runner = Runner.from_cfg(cfg)
# input should be an Optimizer object or dict
with self.assertRaisesRegex(TypeError, 'optimizer should be'):
@@ -547,7 +541,9 @@ class TestRunner(TestCase):
self.assertIsInstance(optimizer, SGD)
def test_build_param_scheduler(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_param_scheduler'
runner = Runner.from_cfg(cfg)
# `build_optimizer` should be called before `build_param_scheduler`
cfg = dict(type='MultiStepLR', milestones=[1, 2])
@@ -584,7 +580,9 @@ class TestRunner(TestCase):
self.assertIsInstance(param_schedulers[1], StepLR)
def test_build_evaluator(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_evaluator'
runner = Runner.from_cfg(cfg)
# input is a BaseEvaluator or ComposedEvaluator object
evaluator = ToyEvaluator1()
@@ -603,7 +601,9 @@ class TestRunner(TestCase):
runner.build_evaluator(evaluator), ComposedEvaluator)
def test_build_dataloader(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader'
runner = Runner.from_cfg(cfg)
cfg = dict(
dataset=dict(type='ToyDataset'),
@@ -616,8 +616,11 @@ class TestRunner(TestCase):
self.assertIsInstance(dataloader.sampler, DefaultSampler)
def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
runner = Runner.build_from_cfg(self.epoch_based_cfg)
with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_train_loop('invalid-type')
@@ -653,7 +656,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomTrainLoop)
def test_build_val_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_val_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
@@ -678,7 +683,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomValLoop)
def test_build_test_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_test_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'):
@@ -705,16 +712,15 @@ class TestRunner(TestCase):
def test_train(self):
# 1. test `self.train_loop` is None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train1'
cfg.pop('train_dataloader')
cfg.pop('train_cfg')
cfg.pop('optimizer')
cfg.pop('param_scheduler')
runner = Runner.build_from_cfg(cfg)
runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.train()
time.sleep(1)
# 2. test iter and epoch counter of EpochBasedTrainLoop
epoch_results = []
epoch_targets = [i for i in range(3)]
@@ -733,10 +739,10 @@ class TestRunner(TestCase):
iter_results.append(runner.iter)
batch_idx_results.append(batch_idx)
self.epoch_based_cfg.custom_hooks = [
dict(type='TestEpochHook', priority=50)
]
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train2'
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
runner = Runner.from_cfg(cfg)
runner.train()
@@ -749,8 +755,6 @@ class TestRunner(TestCase):
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
time.sleep(1)
# 3. test iter and epoch counter of IterBasedTrainLoop
epoch_results = []
iter_results = []
@@ -768,11 +772,11 @@ class TestRunner(TestCase):
iter_results.append(runner.iter)
batch_idx_results.append(batch_idx)
self.iter_based_cfg.custom_hooks = [
dict(type='TestIterHook', priority=50)
]
self.iter_based_cfg.val_cfg = dict(interval=4)
runner = Runner.build_from_cfg(self.iter_based_cfg)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train3'
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
cfg.val_cfg = dict(interval=4)
runner = Runner.from_cfg(cfg)
runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop)
@@ -786,32 +790,38 @@ class TestRunner(TestCase):
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'
cfg.pop('val_dataloader')
cfg.pop('val_cfg')
cfg.pop('val_evaluator')
runner = Runner.build_from_cfg(cfg)
runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.val()
time.sleep(1)
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val2'
runner = Runner.from_cfg(cfg)
runner.val()
def test_test(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_test1'
cfg.pop('test_dataloader')
cfg.pop('test_cfg')
cfg.pop('test_evaluator')
runner = Runner.build_from_cfg(cfg)
runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.test()
time.sleep(1)
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_test2'
runner = Runner.from_cfg(cfg)
runner.test()
def test_register_hook(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hook'
runner = Runner.from_cfg(cfg)
runner._hooks = []
# 1. test `hook` parameter
@@ -870,7 +880,9 @@ class TestRunner(TestCase):
get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))
def test_default_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = []
# register five hooks by default
@@ -893,7 +905,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_custom_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_custom_hooks'
runner = Runner.from_cfg(cfg)
self.assertEqual(len(runner._hooks), 5)
custom_hooks = [dict(type='ToyHook')]
runner.register_custom_hooks(custom_hooks)
@@ -901,7 +916,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_register_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = []
custom_hooks = [dict(type='ToyHook')]
runner.register_hooks(custom_hooks=custom_hooks)
@@ -975,7 +993,8 @@ class TestRunner(TestCase):
self.iter_based_cfg.custom_hooks = [
dict(type='TestWarmupHook', priority=50)
]
runner = Runner.build_from_cfg(self.iter_based_cfg)
self.iter_based_cfg.experiment_name = 'test_custom_loop'
runner = Runner.from_cfg(self.iter_based_cfg)
runner.train()
self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
@@ -990,7 +1009,9 @@ class TestRunner(TestCase):
def test_checkpoint(self):
# 1. test epoch based
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint1'
runner = Runner.from_cfg(cfg)
runner.train()
# 1.1 test `save_checkpoint` which called by `CheckpointHook`
@@ -1006,17 +1027,19 @@ class TestRunner(TestCase):
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
time.sleep(1)
# 1.2 test `load_checkpoint`
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint2'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded)
time.sleep(1)
# 1.3 test `resume`
runner = Runner.build_from_cfg(self.epoch_based_cfg)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12)
@@ -1025,8 +1048,9 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based
time.sleep(1)
runner = Runner.build_from_cfg(self.iter_based_cfg)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint4'
runner = Runner.from_cfg(cfg)
runner.train()
# 2.1 test `save_checkpoint` which called by `CheckpointHook`
@@ -1043,16 +1067,18 @@ class TestRunner(TestCase):
assert isinstance(ckpt['param_schedulers'], list)
# 2.2 test `load_checkpoint`
time.sleep(1)
runner = Runner.build_from_cfg(self.iter_based_cfg)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint5'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded)
time.sleep(1)
# 2.3 test `resume`
runner = Runner.build_from_cfg(self.iter_based_cfg)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
runner = Runner.from_cfg(cfg)
runner.resume(path)
self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12)
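
These assertions pin down the contract between the two checkpoint entry points: ``load_checkpoint`` restores weights but leaves the counters untouched, while ``resume`` also brings back epoch/iter, optimizer, and scheduler state. In sketch form:

    runner.load_checkpoint(path)  # weights only: epoch == 0, iter == 0
    runner.resume(path)           # full state: counters and optimizer restored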