[Enhancement] Refactor Runner (#139)

* [Enhancement] Rename build_from_cfg to from_cfg

* refactor build_logger and build_message_hub

* remove time.sleep from unit tests

* minor fix

* move set_randomness from setup_env

* improve docstring

* refine comments

* print a warning information

* refine comments

* simplify the interface of build_logger
This commit is contained in:
Zaida Zhou 2022-03-30 14:26:40 +08:00 committed by GitHub
parent 9a61b389e7
commit f1de071cf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 240 additions and 218 deletions

View File

@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
DefaultScope) DefaultScope)
from mmengine.utils import find_latest_checkpoint, is_list_of, symlink from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink)
from mmengine.visualization import ComposedWriter from mmengine.visualization import ComposedWriter
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
class Runner: class Runner:
"""A training helper for PyTorch. """A training helper for PyTorch.
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
where the ``cfg`` usually contains training, validation, and test-related
configurations to build corresponding components. We usually use the
same config to launch training, testing, and validation tasks. However,
only some of these components are necessary at the same time, e.g.,
testing a model does not need training or validation-related components.
To avoid repeatedly modifying config, the construction of ``Runner`` adopts
lazy initialization to only initialize components when they are going to be
used. Therefore, the model is always initialized at the beginning, and
training, validation, and testing-related components are only initialized
when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
respectively.
Args: Args:
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
a dict used for build a model. a dict used for build a model.
@ -114,23 +129,21 @@ class Runner:
non-distributed environment will be launched. non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')). dict(dist_cfg=dict(backend='nccl')).
logger (MMLogger or dict, optional): A MMLogger object or a dict to log_level (int or str): The log level of MMLogger handlers.
build MMLogger object. Defaults to None. If not specified, default Defaults to 'INFO'.
config will be used.
message_hub (MessageHub or dict, optional): A Messagehub object or a
dict to build MessageHub object. Defaults to None. If not
specified, default config will be used.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a writer (ComposedWriter or dict, optional): A ComposedWriter object or a
dict build ComposedWriter object. Defaults to None. If not dict build ComposedWriter object. Defaults to None. If not
specified, default config will be used. specified, default config will be used.
default_scope (str, optional): Used to reset registries location. default_scope (str, optional): Used to reset registries location.
Defaults to None. Defaults to None.
seed (int, optional): A number to set random modules. If not specified, randomness (dict): Some settings to make the experiment as reproducible
a random number will be set as seed. Defaults to None. as possible like seed and deterministic.
deterministic (bool): Whether cudnn to select deterministic algorithms. Defaults to ``dict(seed=None)``. If seed is None, a random number
Defaults to False. will be generated and it will be broadcasted to all other processes
See https://pytorch.org/docs/stable/notes/randomness.html for if in distributed environment. If ``cudnn_benchmark`` is
more details. ``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
``randomness``, the value of ``torch.backends.cudnn.benchmark``
will be ``False`` finally.
experiment_name (str, optional): Name of current experiment. If not experiment_name (str, optional): Name of current experiment. If not
specified, timestamp will be used as ``experiment_name``. specified, timestamp will be used as ``experiment_name``.
Defaults to None. Defaults to None.
@ -173,13 +186,11 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')), param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none', launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')), env_cfg=dict(dist_cfg=dict(backend='nccl')),
logger=dict(log_level='INFO'),
message_hub=None,
writer=dict( writer=dict(
name='composed_writer', name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')]) writers=[dict(type='LocalWriter', save_dir='temp_dir')])
) )
>>> runner = Runner.build_from_cfg(cfg) >>> runner = Runner.from_cfg(cfg)
>>> runner.train() >>> runner.train()
>>> runner.test() >>> runner.test()
""" """
@ -208,19 +219,17 @@ class Runner:
resume: bool = False, resume: bool = False,
launcher: str = 'none', launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')), env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
logger: Optional[Union[MMLogger, Dict]] = None, log_level: str = 'INFO',
message_hub: Optional[Union[MessageHub, Dict]] = None,
writer: Optional[Union[ComposedWriter, Dict]] = None, writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None, default_scope: Optional[str] = None,
seed: Optional[int] = None, randomness: Dict = dict(seed=None),
deterministic: bool = False,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
cfg: Optional[ConfigType] = None, cfg: Optional[ConfigType] = None,
): ):
self._work_dir = osp.abspath(work_dir) self._work_dir = osp.abspath(work_dir)
mmengine.mkdir_or_exist(self._work_dir) mmengine.mkdir_or_exist(self._work_dir)
# recursively copy the ``cfg`` because `self.cfg` will be modified # recursively copy the `cfg` because `self.cfg` will be modified
# everywhere. # everywhere.
if cfg is not None: if cfg is not None:
self.cfg = copy.deepcopy(cfg) self.cfg = copy.deepcopy(cfg)
@ -231,21 +240,25 @@ class Runner:
self._iter = 0 self._iter = 0
# lazy initialization # lazy initialization
training_related = [ training_related = [train_dataloader, train_cfg, optimizer]
train_dataloader, train_cfg, optimizer, param_scheduler
]
if not (all(item is None for item in training_related) if not (all(item is None for item in training_related)
or all(item is not None for item in training_related)): or all(item is not None for item in training_related)):
raise ValueError( raise ValueError(
'train_dataloader, train_cfg, optimizer and param_scheduler ' 'train_dataloader, train_cfg, and optimizer should be either '
'should be either all None or not None, but got ' 'all None or not None, but got '
f'train_dataloader={train_dataloader}, ' f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, ' f'train_cfg={train_cfg}, '
f'optimizer={optimizer}, ' f'optimizer={optimizer}.')
f'param_scheduler={param_scheduler}.')
self.train_dataloader = train_dataloader self.train_dataloader = train_dataloader
self.train_loop = train_cfg self.train_loop = train_cfg
self.optimizer = optimizer self.optimizer = optimizer
# If there is no need to adjust learning rate, momentum or other
# parameters of optimizer, param_scheduler can be None
if param_scheduler is not None and self.optimizer is None:
raise ValueError(
'param_scheduler should be None when optimizer is None, '
f'but got {param_scheduler}')
if not isinstance(param_scheduler, Sequence): if not isinstance(param_scheduler, Sequence):
self.param_schedulers = [param_scheduler] self.param_schedulers = [param_scheduler]
else: else:
@ -256,7 +269,7 @@ class Runner:
for item in val_related) or all(item is not None for item in val_related) or all(item is not None
for item in val_related)): for item in val_related)):
raise ValueError( raise ValueError(
'val_dataloader, val_cfg and val_evaluator should be either ' 'val_dataloader, val_cfg, and val_evaluator should be either '
'all None or not None, but got ' 'all None or not None, but got '
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, ' f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
f'val_evaluator={val_evaluator}') f'val_evaluator={val_evaluator}')
@ -268,8 +281,8 @@ class Runner:
if not (all(item is None for item in test_related) if not (all(item is None for item in test_related)
or all(item is not None for item in test_related)): or all(item is not None for item in test_related)):
raise ValueError( raise ValueError(
'test_dataloader, test_cfg and test_evaluator should be either' 'test_dataloader, test_cfg, and test_evaluator should be '
' all None or not None, but got ' 'either all None or not None, but got '
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, ' f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
f'test_evaluator={test_evaluator}') f'test_evaluator={test_evaluator}')
self.test_dataloader = test_dataloader self.test_dataloader = test_dataloader
@ -282,10 +295,13 @@ class Runner:
else: else:
self._distributed = True self._distributed = True
# self._deterministic, self._seed and self._timestamp will be set in # self._timestamp will be set in the `setup_env` method. Besides,
# the `setup_env`` method. Besides, it also will initialize # it also will initialize multi-process and (or) distributed
# multi-process and (or) distributed environment. # environment.
self.setup_env(env_cfg, seed, deterministic) self.setup_env(env_cfg)
# self._deterministic and self._seed will be set in the
# `set_randomness` method
self.set_randomness(**randomness)
if experiment_name is not None: if experiment_name is not None:
self._experiment_name = f'{experiment_name}_{self._timestamp}' self._experiment_name = f'{experiment_name}_{self._timestamp}'
@ -296,9 +312,9 @@ class Runner:
else: else:
self._experiment_name = self.timestamp self._experiment_name = self.timestamp
self.logger = self.build_logger(logger) self.logger = self.build_logger(log_level=log_level)
# message hub used for component interaction # message hub used for component interaction
self.message_hub = self.build_message_hub(message_hub) self.message_hub = self.build_message_hub()
# writer used for writing log or visualizing all kinds of data # writer used for writing log or visualizing all kinds of data
self.writer = self.build_writer(writer) self.writer = self.build_writer(writer)
# Used to reset registries location. See :meth:`Registry.build` for # Used to reset registries location. See :meth:`Registry.build` for
@ -333,7 +349,7 @@ class Runner:
self.dump_config() self.dump_config()
@classmethod @classmethod
def build_from_cfg(cls, cfg: ConfigType) -> 'Runner': def from_cfg(cls, cfg: ConfigType) -> 'Runner':
"""Build a runner from config. """Build a runner from config.
Args: Args:
@ -363,12 +379,11 @@ class Runner:
resume=cfg.get('resume', False), resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'), launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore env_cfg=cfg.get('env_cfg'), # type: ignore
logger=cfg.get('log_cfg'), log_level=cfg.get('log_level', 'INFO'),
message_hub=cfg.get('message_hub'),
writer=cfg.get('writer'), writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'), default_scope=cfg.get('default_scope'),
seed=cfg.get('seed'), randomness=cfg.get('randomness', dict(seed=None)),
deterministic=cfg.get('deterministic', False), experiment_name=cfg.get('experiment_name'),
cfg=cfg, cfg=cfg,
) )
@ -439,10 +454,7 @@ class Runner:
"""list[:obj:`Hook`]: A list of registered hooks.""" """list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks return self._hooks
def setup_env(self, def setup_env(self, env_cfg: Dict) -> None:
env_cfg: Dict,
seed: Optional[int],
deterministic: bool = False) -> None:
"""Setup environment. """Setup environment.
An example of ``env_cfg``:: An example of ``env_cfg``::
@ -458,17 +470,7 @@ class Runner:
Args: Args:
env_cfg (dict): Config for setting environment. env_cfg (dict): Config for setting environment.
seed (int, optional): A number to set random modules. If not
specified, a random number will be set as seed.
Defaults to None.
deterministic (bool): Whether cudnn to select deterministic
algorithms. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
""" """
self._deterministic = deterministic
self._seed = seed
if env_cfg.get('cudnn_benchmark'): if env_cfg.get('cudnn_benchmark'):
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@ -490,9 +492,6 @@ class Runner:
self._timestamp = time.strftime('%Y%m%d_%H%M%S', self._timestamp = time.strftime('%Y%m%d_%H%M%S',
time.localtime(timestamp.item())) time.localtime(timestamp.item()))
# set random seeds
self._set_random_seed()
def _set_multi_processing(self, def _set_multi_processing(self,
mp_start_method: str = 'fork', mp_start_method: str = 'fork',
opencv_num_threads: int = 0) -> None: opencv_num_threads: int = 0) -> None:
@ -546,16 +545,20 @@ class Runner:
'optimal performance in your application as needed.') 'optimal performance in your application as needed.')
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
def _set_random_seed(self) -> None: def set_randomness(self, seed, deterministic: bool = False) -> None:
"""Set random seed to guarantee reproducible results. """Set random seed to guarantee reproducible results.
Warning: Args:
Results can not be guaranteed to resproducible if ``self.seed`` is seed (int): A number to set random modules.
None because :meth:`_set_random_seed` will generate a random seed deterministic (bool): Whether to set the deterministic option for
when launching a new experiment. CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
See https://pytorch.org/docs/stable/notes/randomness.html for details. Defaults to False.
See https://pytorch.org/docs/stable/notes/randomness.html for
more details.
""" """
self._deterministic = deterministic
self._seed = seed
if self._seed is None: if self._seed is None:
self._seed = sync_random_seed() self._seed = sync_random_seed()
@ -563,69 +566,62 @@ class Runner:
np.random.seed(self._seed) np.random.seed(self._seed)
torch.manual_seed(self._seed) torch.manual_seed(self._seed)
torch.cuda.manual_seed_all(self._seed) torch.cuda.manual_seed_all(self._seed)
if self._deterministic: if deterministic:
torch.backends.cudnn.deterministic = True if torch.backends.cudnn.benchmark:
warnings.warn(
'torch.backends.cudnn.benchmark is going to be set as '
'`False` to cause cuDNN to deterministically select an '
'algorithm')
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
torch.use_deterministic_algorithms(True)
def build_logger(self, def build_logger(self,
logger: Optional[Union[MMLogger, log_level: Union[int, str] = 'INFO',
Dict]] = None) -> MMLogger: log_file: str = None,
**kwargs) -> MMLogger:
"""Build a globally accessible MMLogger. """Build a globally accessible MMLogger.
Args: Args:
logger (MMLogger or dict, optional): A MMLogger object or a dict to log_level (int or str): The log level of MMLogger handlers.
build MMLogger object. If ``logger`` is a MMLogger object, just Defaults to 'INFO'.
returns itself. If not specified, default config will be used log_file (str, optional): Path of filename to save log.
to build MMLogger object. Defaults to None. Defaults to None.
**kwargs: Remaining parameters passed to ``MMLogger``.
Returns: Returns:
MMLogger: A MMLogger object build from ``logger``. MMLogger: A MMLogger object build from ``logger``.
""" """
if isinstance(logger, MMLogger): if log_file is None:
return logger log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
elif logger is None:
logger = dict(
name=self._experiment_name,
log_level='INFO',
log_file=osp.join(self.work_dir,
f'{self._experiment_name}.log'))
elif isinstance(logger, dict):
# ensure logger containing name key
logger.setdefault('name', self._experiment_name)
else:
raise TypeError(
'logger should be MMLogger object, a dict or None, '
f'but got {logger}')
return MMLogger.get_instance(**logger) log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
log_cfg.setdefault('name', self._experiment_name)
def build_message_hub( return MMLogger.get_instance(**log_cfg) # type: ignore
self,
message_hub: Optional[Union[MessageHub, def build_message_hub(self,
Dict]] = None) -> MessageHub: message_hub: Optional[Dict] = None) -> MessageHub:
"""Build a globally accessible MessageHub. """Build a globally accessible MessageHub.
Args: Args:
message_hub (MessageHub or dict, optional): A MessageHub object or message_hub (dict, optional): A dict to build MessageHub object.
a dict to build MessageHub object. If ``message_hub`` is a If not specified, default config will be used to build
MessageHub object, just returns itself. If not specified, MessageHub object. Defaults to None.
default config will be used to build MessageHub object.
Defaults to None.
Returns: Returns:
MessageHub: A MessageHub object build from ``message_hub``. MessageHub: A MessageHub object build from ``message_hub``.
""" """
if isinstance(message_hub, MessageHub): if message_hub is None:
return message_hub
elif message_hub is None:
message_hub = dict(name=self._experiment_name) message_hub = dict(name=self._experiment_name)
elif isinstance(message_hub, dict): elif isinstance(message_hub, dict):
# ensure message_hub containing name key # ensure message_hub containing name key
message_hub.setdefault('name', self._experiment_name) message_hub.setdefault('name', self._experiment_name)
else: else:
raise TypeError( raise TypeError(
'message_hub should be MessageHub object, a dict or None, ' f'message_hub should be dict or None, but got {message_hub}')
f'but got {message_hub}')
return MessageHub.get_instance(**message_hub) return MessageHub.get_instance(**message_hub)

View File

@ -1,10 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import logging
import os.path as osp import os.path as osp
import shutil import shutil
import tempfile import tempfile
import time
from unittest import TestCase from unittest import TestCase
import torch import torch
@ -229,8 +227,6 @@ class TestRunner(TestCase):
optimizer=dict(type='OptimizerHook', grad_clip=None), optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook')) param_scheduler=dict(type='ParamSchedulerHook'))
time.sleep(1)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.temp_dir) shutil.rmtree(self.temp_dir)
@ -238,89 +234,99 @@ class TestRunner(TestCase):
# 1. test arguments # 1. test arguments
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler # 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init1'
cfg.pop('train_cfg') cfg.pop('train_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg) Runner(**cfg)
# all of training related configs are None # all of training related configs are None and param_scheduler should
# also be None
cfg.experiment_name = 'test_init2'
cfg.pop('train_dataloader') cfg.pop('train_dataloader')
cfg.pop('optimizer') cfg.pop('optimizer')
cfg.pop('param_scheduler') cfg.pop('param_scheduler')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# avoid different runners having same timestamp
time.sleep(1)
# all of training related configs are not None # all of training related configs are not None
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init3'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# all of training related configs are not None and param_scheduler
# can be None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init4'
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
# param_scheduler should be None when optimizer is None
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init5'
cfg.pop('train_cfg')
cfg.pop('train_dataloader')
cfg.pop('optimizer')
with self.assertRaisesRegex(ValueError, 'should be None'):
runner = Runner(**cfg)
# 1.2 val_dataloader, val_evaluator, val_cfg # 1.2 val_dataloader, val_evaluator, val_cfg
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init6'
cfg.pop('val_cfg') cfg.pop('val_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
Runner(**cfg) Runner(**cfg)
time.sleep(1) cfg.experiment_name = 'test_init7'
cfg.pop('val_dataloader') cfg.pop('val_dataloader')
cfg.pop('val_evaluator') cfg.pop('val_evaluator')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
time.sleep(1)
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init8'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
# 1.3 test_dataloader, test_evaluator and test_cfg # 1.3 test_dataloader, test_evaluator and test_cfg
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init9'
cfg.pop('test_cfg') cfg.pop('test_cfg')
with self.assertRaisesRegex(ValueError, 'either all None or not None'): with self.assertRaisesRegex(ValueError, 'either all None or not None'):
runner = Runner(**cfg) runner = Runner(**cfg)
time.sleep(1) cfg.experiment_name = 'test_init10'
cfg.pop('test_dataloader') cfg.pop('test_dataloader')
cfg.pop('test_evaluator') cfg.pop('test_evaluator')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
time.sleep(1)
# 1.4 test env params # 1.4 test env params
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init11'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertFalse(runner.distributed) self.assertFalse(runner.distributed)
self.assertFalse(runner.deterministic) self.assertFalse(runner.deterministic)
time.sleep(1)
# 1.5 message_hub, logger and writer # 1.5 message_hub, logger and writer
# they are all not specified # they are all not specified
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init12'
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
time.sleep(1)
# they are all specified # they are all specified
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.logger = dict(name='test_logger') cfg.experiment_name = 'test_init13'
cfg.message_hub = dict(name='test_message_hub') cfg.log_level = 'INFO'
cfg.writer = dict(name='test_writer') cfg.writer = dict(name='test_writer')
runner = Runner(**cfg) runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.logger.instance_name, 'test_logger')
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name, 'test_message_hub')
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.writer.instance_name, 'test_writer')
assert runner.distributed is False assert runner.distributed is False
assert runner.seed is not None assert runner.seed is not None
@ -358,7 +364,6 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.test_loop.dataloader, DataLoader) self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1) self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
time.sleep(1)
# 4. initialize runner with objects rather than config # 4. initialize runner with objects rather than config
model = ToyModel() model = ToyModel()
optimizer = SGD( optimizer = SGD(
@ -385,84 +390,66 @@ class TestRunner(TestCase):
test_dataloader=test_dataloader, test_dataloader=test_dataloader,
test_evaluator=ToyEvaluator1(), test_evaluator=ToyEvaluator1(),
default_hooks=dict(param_scheduler=toy_hook), default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2]) custom_hooks=[toy_hook2],
experiment_name='test_init14')
runner.train() runner.train()
runner.test() runner.test()
# 5. test `dump_config` # 5. test `dump_config`
# TODO # TODO
def test_build_from_cfg(self): def test_from_cfg(self):
runner = Runner.build_from_cfg(cfg=self.epoch_based_cfg) runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
self.assertIsInstance(runner, Runner) self.assertIsInstance(runner, Runner)
def test_setup_env(self): def test_setup_env(self):
# TODO # TODO
pass pass
def test_logger(self): def test_build_logger(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_logger1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.logger, MMLogger) self.assertIsInstance(runner.logger, MMLogger)
self.assertEqual(runner.experiment_name, runner.logger.instance_name) self.assertEqual(runner.experiment_name, runner.logger.instance_name)
self.assertEqual(runner.logger.level, logging.NOTSET)
# input is a MMLogger object
self.assertEqual(
id(runner.build_logger(runner.logger)), id(runner.logger))
# input is None
runner._experiment_name = 'logger_name1'
logger = runner.build_logger(None)
self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name1')
# input is a dict # input is a dict
log_cfg = dict(name='logger_name2') logger = runner.build_logger(name='test_build_logger2')
logger = runner.build_logger(log_cfg)
self.assertIsInstance(logger, MMLogger) self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name2') self.assertEqual(logger.instance_name, 'test_build_logger2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'logger_name3' runner._experiment_name = 'test_build_logger3'
log_cfg = dict() logger = runner.build_logger()
logger = runner.build_logger(log_cfg)
self.assertIsInstance(logger, MMLogger) self.assertIsInstance(logger, MMLogger)
self.assertEqual(logger.instance_name, 'logger_name3') self.assertEqual(logger.instance_name, 'test_build_logger3')
# input is not a valid type
with self.assertRaisesRegex(TypeError, 'logger should be'):
runner.build_logger('invalid-type')
def test_build_message_hub(self): def test_build_message_hub(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_message_hub1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.message_hub, MessageHub) self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.instance_name, self.assertEqual(runner.message_hub.instance_name,
runner.experiment_name) runner.experiment_name)
# input is a MessageHub object
self.assertEqual(
id(runner.build_message_hub(runner.message_hub)),
id(runner.message_hub))
# input is a dict # input is a dict
message_hub_cfg = dict(name='message_hub_name1') message_hub_cfg = dict(name='test_build_message_hub2')
message_hub = runner.build_message_hub(message_hub_cfg) message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub) self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name1') self.assertEqual(message_hub.instance_name, 'test_build_message_hub2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'message_hub_name2' runner._experiment_name = 'test_build_message_hub3'
message_hub_cfg = dict() message_hub_cfg = dict()
message_hub = runner.build_message_hub(message_hub_cfg) message_hub = runner.build_message_hub(message_hub_cfg)
self.assertIsInstance(message_hub, MessageHub) self.assertIsInstance(message_hub, MessageHub)
self.assertEqual(message_hub.instance_name, 'message_hub_name2') self.assertEqual(message_hub.instance_name, 'test_build_message_hub3')
# input is not a valid type # input is not a valid type
with self.assertRaisesRegex(TypeError, 'message_hub should be'): with self.assertRaisesRegex(TypeError, 'message_hub should be'):
runner.build_message_hub('invalid-type') runner.build_message_hub('invalid-type')
def test_build_writer(self): def test_build_writer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) self.epoch_based_cfg.experiment_name = 'test_build_writer1'
runner = Runner.from_cfg(self.epoch_based_cfg)
self.assertIsInstance(runner.writer, ComposedWriter) self.assertIsInstance(runner.writer, ComposedWriter)
self.assertEqual(runner.experiment_name, runner.writer.instance_name) self.assertEqual(runner.experiment_name, runner.writer.instance_name)
@ -471,17 +458,17 @@ class TestRunner(TestCase):
id(runner.build_writer(runner.writer)), id(runner.writer)) id(runner.build_writer(runner.writer)), id(runner.writer))
# input is a dict # input is a dict
writer_cfg = dict(name='writer_name1') writer_cfg = dict(name='test_build_writer2')
writer = runner.build_writer(writer_cfg) writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter) self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name1') self.assertEqual(writer.instance_name, 'test_build_writer2')
# input is a dict but does not contain name key # input is a dict but does not contain name key
runner._experiment_name = 'writer_name2' runner._experiment_name = 'test_build_writer3'
writer_cfg = dict() writer_cfg = dict()
writer = runner.build_writer(writer_cfg) writer = runner.build_writer(writer_cfg)
self.assertIsInstance(writer, ComposedWriter) self.assertIsInstance(writer, ComposedWriter)
self.assertEqual(writer.instance_name, 'writer_name2') self.assertEqual(writer.instance_name, 'test_build_writer3')
# input is not a valid type # input is not a valid type
with self.assertRaisesRegex(TypeError, 'writer should be'): with self.assertRaisesRegex(TypeError, 'writer should be'):
@ -501,12 +488,16 @@ class TestRunner(TestCase):
type='ToyScheduler', milestones=[1, 2]) type='ToyScheduler', milestones=[1, 2])
self.epoch_based_cfg.default_scope = 'toy' self.epoch_based_cfg.default_scope = 'toy'
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_scope'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler) self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
def test_build_model(self): def test_build_model(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_model'
runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, ToyModel) self.assertIsInstance(runner.model, ToyModel)
# input should be a nn.Module object or dict # input should be a nn.Module object or dict
@ -526,12 +517,15 @@ class TestRunner(TestCase):
# TODO: test in a distributed environment # TODO: test in a distributed environment
# custom model wrapper # custom model wrapper
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_wrap_model'
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper') cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
self.assertIsInstance(runner.model, CustomModelWrapper) self.assertIsInstance(runner.model, CustomModelWrapper)
def test_build_optimizer(self): def test_build_optimizer(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_optimizer'
runner = Runner.from_cfg(cfg)
# input should be an Optimizer object or dict # input should be an Optimizer object or dict
with self.assertRaisesRegex(TypeError, 'optimizer should be'): with self.assertRaisesRegex(TypeError, 'optimizer should be'):
@ -547,7 +541,9 @@ class TestRunner(TestCase):
self.assertIsInstance(optimizer, SGD) self.assertIsInstance(optimizer, SGD)
def test_build_param_scheduler(self): def test_build_param_scheduler(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_param_scheduler'
runner = Runner.from_cfg(cfg)
# `build_optimizer` should be called before `build_param_scheduler` # `build_optimizer` should be called before `build_param_scheduler`
cfg = dict(type='MultiStepLR', milestones=[1, 2]) cfg = dict(type='MultiStepLR', milestones=[1, 2])
@ -584,7 +580,9 @@ class TestRunner(TestCase):
self.assertIsInstance(param_schedulers[1], StepLR) self.assertIsInstance(param_schedulers[1], StepLR)
def test_build_evaluator(self): def test_build_evaluator(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_evaluator'
runner = Runner.from_cfg(cfg)
# input is a BaseEvaluator or ComposedEvaluator object # input is a BaseEvaluator or ComposedEvaluator object
evaluator = ToyEvaluator1() evaluator = ToyEvaluator1()
@ -603,7 +601,9 @@ class TestRunner(TestCase):
runner.build_evaluator(evaluator), ComposedEvaluator) runner.build_evaluator(evaluator), ComposedEvaluator)
def test_build_dataloader(self): def test_build_dataloader(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader'
runner = Runner.from_cfg(cfg)
cfg = dict( cfg = dict(
dataset=dict(type='ToyDataset'), dataset=dict(type='ToyDataset'),
@ -616,8 +616,11 @@ class TestRunner(TestCase):
self.assertIsInstance(dataloader.sampler, DefaultSampler) self.assertIsInstance(dataloader.sampler, DefaultSampler)
def test_build_train_loop(self): def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
runner = Runner.build_from_cfg(self.epoch_based_cfg)
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
runner.build_train_loop('invalid-type') runner.build_train_loop('invalid-type')
@ -653,7 +656,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomTrainLoop) self.assertIsInstance(loop, CustomTrainLoop)
def test_build_val_loop(self): def test_build_val_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_val_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
@ -678,7 +683,9 @@ class TestRunner(TestCase):
self.assertIsInstance(loop, CustomValLoop) self.assertIsInstance(loop, CustomValLoop)
def test_build_test_loop(self): def test_build_test_loop(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_test_loop'
runner = Runner.from_cfg(cfg)
# input should be a Loop object or dict # input should be a Loop object or dict
with self.assertRaisesRegex(TypeError, 'should be'): with self.assertRaisesRegex(TypeError, 'should be'):
@ -705,16 +712,15 @@ class TestRunner(TestCase):
def test_train(self): def test_train(self):
# 1. test `self.train_loop` is None # 1. test `self.train_loop` is None
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train1'
cfg.pop('train_dataloader') cfg.pop('train_dataloader')
cfg.pop('train_cfg') cfg.pop('train_cfg')
cfg.pop('optimizer') cfg.pop('optimizer')
cfg.pop('param_scheduler') cfg.pop('param_scheduler')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.train() runner.train()
time.sleep(1)
# 2. test iter and epoch counter of EpochBasedTrainLoop # 2. test iter and epoch counter of EpochBasedTrainLoop
epoch_results = [] epoch_results = []
epoch_targets = [i for i in range(3)] epoch_targets = [i for i in range(3)]
@ -733,10 +739,10 @@ class TestRunner(TestCase):
iter_results.append(runner.iter) iter_results.append(runner.iter)
batch_idx_results.append(batch_idx) batch_idx_results.append(batch_idx)
self.epoch_based_cfg.custom_hooks = [ cfg = copy.deepcopy(self.epoch_based_cfg)
dict(type='TestEpochHook', priority=50) cfg.experiment_name = 'test_train2'
] cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
runner = Runner.build_from_cfg(self.epoch_based_cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
@ -749,8 +755,6 @@ class TestRunner(TestCase):
for result, target, in zip(batch_idx_results, batch_idx_targets): for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
time.sleep(1)
# 3. test iter and epoch counter of IterBasedTrainLoop # 3. test iter and epoch counter of IterBasedTrainLoop
epoch_results = [] epoch_results = []
iter_results = [] iter_results = []
@ -768,11 +772,11 @@ class TestRunner(TestCase):
iter_results.append(runner.iter) iter_results.append(runner.iter)
batch_idx_results.append(batch_idx) batch_idx_results.append(batch_idx)
self.iter_based_cfg.custom_hooks = [ cfg = copy.deepcopy(self.iter_based_cfg)
dict(type='TestIterHook', priority=50) cfg.experiment_name = 'test_train3'
] cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
self.iter_based_cfg.val_cfg = dict(interval=4) cfg.val_cfg = dict(interval=4)
runner = Runner.build_from_cfg(self.iter_based_cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop) assert isinstance(runner.train_loop, IterBasedTrainLoop)
@ -786,32 +790,38 @@ class TestRunner(TestCase):
def test_val(self): def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'
cfg.pop('val_dataloader') cfg.pop('val_dataloader')
cfg.pop('val_cfg') cfg.pop('val_cfg')
cfg.pop('val_evaluator') cfg.pop('val_evaluator')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.val() runner.val()
time.sleep(1) cfg = copy.deepcopy(self.epoch_based_cfg)
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg.experiment_name = 'test_val2'
runner = Runner.from_cfg(cfg)
runner.val() runner.val()
def test_test(self): def test_test(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_test1'
cfg.pop('test_dataloader') cfg.pop('test_dataloader')
cfg.pop('test_cfg') cfg.pop('test_cfg')
cfg.pop('test_evaluator') cfg.pop('test_evaluator')
runner = Runner.build_from_cfg(cfg) runner = Runner.from_cfg(cfg)
with self.assertRaisesRegex(RuntimeError, 'should not be None'): with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.test() runner.test()
time.sleep(1) cfg = copy.deepcopy(self.epoch_based_cfg)
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg.experiment_name = 'test_test2'
runner = Runner.from_cfg(cfg)
runner.test() runner.test()
def test_register_hook(self): def test_register_hook(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hook'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
# 1. test `hook` parameter # 1. test `hook` parameter
@ -870,7 +880,9 @@ class TestRunner(TestCase):
get_priority(runner._hooks[3].priority), get_priority('VERY_LOW')) get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))
def test_default_hooks(self): def test_default_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_default_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
# register five hooks by default # register five hooks by default
@ -893,7 +905,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook)) self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_custom_hooks(self): def test_custom_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_custom_hooks'
runner = Runner.from_cfg(cfg)
self.assertEqual(len(runner._hooks), 5) self.assertEqual(len(runner._hooks), 5)
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_custom_hooks(custom_hooks) runner.register_custom_hooks(custom_hooks)
@ -901,7 +916,10 @@ class TestRunner(TestCase):
self.assertTrue(isinstance(runner._hooks[5], ToyHook)) self.assertTrue(isinstance(runner._hooks[5], ToyHook))
def test_register_hooks(self): def test_register_hooks(self):
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hooks'
runner = Runner.from_cfg(cfg)
runner._hooks = [] runner._hooks = []
custom_hooks = [dict(type='ToyHook')] custom_hooks = [dict(type='ToyHook')]
runner.register_hooks(custom_hooks=custom_hooks) runner.register_hooks(custom_hooks=custom_hooks)
@ -975,7 +993,8 @@ class TestRunner(TestCase):
self.iter_based_cfg.custom_hooks = [ self.iter_based_cfg.custom_hooks = [
dict(type='TestWarmupHook', priority=50) dict(type='TestWarmupHook', priority=50)
] ]
runner = Runner.build_from_cfg(self.iter_based_cfg) self.iter_based_cfg.experiment_name = 'test_custom_loop'
runner = Runner.from_cfg(self.iter_based_cfg)
runner.train() runner.train()
self.assertIsInstance(runner.train_loop, CustomTrainLoop2) self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
@ -990,7 +1009,9 @@ class TestRunner(TestCase):
def test_checkpoint(self): def test_checkpoint(self):
# 1. test epoch based # 1. test epoch based
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint1'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 1.1 test `save_checkpoint` which is called by `CheckpointHook` # 1.1 test `save_checkpoint` which is called by `CheckpointHook`
@ -1006,17 +1027,19 @@ class TestRunner(TestCase):
assert isinstance(ckpt['optimizer'], dict) assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list) assert isinstance(ckpt['param_schedulers'], list)
time.sleep(1)
# 1.2 test `load_checkpoint` # 1.2 test `load_checkpoint`
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint2'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path) runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0) self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded) self.assertTrue(runner._has_loaded)
time.sleep(1)
# 1.3 test `resume` # 1.3 test `resume`
runner = Runner.build_from_cfg(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_checkpoint3'
runner = Runner.from_cfg(cfg)
runner.resume(path) runner.resume(path)
self.assertEqual(runner.epoch, 3) self.assertEqual(runner.epoch, 3)
self.assertEqual(runner.iter, 12) self.assertEqual(runner.iter, 12)
@ -1025,8 +1048,9 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2. test iter based # 2. test iter based
time.sleep(1) cfg = copy.deepcopy(self.iter_based_cfg)
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg.experiment_name = 'test_checkpoint4'
runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 2.1 test `save_checkpoint` which is called by `CheckpointHook` # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
@ -1043,16 +1067,18 @@ class TestRunner(TestCase):
assert isinstance(ckpt['param_schedulers'], list) assert isinstance(ckpt['param_schedulers'], list)
# 2.2 test `load_checkpoint` # 2.2 test `load_checkpoint`
time.sleep(1) cfg = copy.deepcopy(self.iter_based_cfg)
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg.experiment_name = 'test_checkpoint5'
runner = Runner.from_cfg(cfg)
runner.load_checkpoint(path) runner.load_checkpoint(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 0) self.assertEqual(runner.iter, 0)
self.assertTrue(runner._has_loaded) self.assertTrue(runner._has_loaded)
time.sleep(1)
# 2.3 test `resume` # 2.3 test `resume`
runner = Runner.build_from_cfg(self.iter_based_cfg) cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint6'
runner = Runner.from_cfg(cfg)
runner.resume(path) runner.resume(path)
self.assertEqual(runner.epoch, 0) self.assertEqual(runner.epoch, 0)
self.assertEqual(runner.iter, 12) self.assertEqual(runner.iter, 12)