mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhancement] Refactor Runner (#139)
* [Enhancement] Rename build_from_cfg to from_cfg * refactor build_logger and build_message_hub * remove time.sleep from unit tests * minor fix * move set_randomness from setup_env * improve docstring * refine comments * print a warning information * refine comments * simplify the interface of build_logger
This commit is contained in:
parent
9a61b389e7
commit
f1de071cf0
@ -32,7 +32,8 @@ from mmengine.optim import _ParamScheduler, build_optimizer
|
|||||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||||
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
|
||||||
DefaultScope)
|
DefaultScope)
|
||||||
from mmengine.utils import find_latest_checkpoint, is_list_of, symlink
|
from mmengine.utils import (TORCH_VERSION, digit_version,
|
||||||
|
find_latest_checkpoint, is_list_of, symlink)
|
||||||
from mmengine.visualization import ComposedWriter
|
from mmengine.visualization import ComposedWriter
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
|
||||||
@ -47,6 +48,20 @@ ConfigType = Union[Dict, Config, ConfigDict]
|
|||||||
class Runner:
|
class Runner:
|
||||||
"""A training helper for PyTorch.
|
"""A training helper for PyTorch.
|
||||||
|
|
||||||
|
Runner object can be built from config by ``runner = Runner.from_cfg(cfg)``
|
||||||
|
where the ``cfg`` usually contains training, validation, and test-related
|
||||||
|
configurations to build corresponding components. We usually use the
|
||||||
|
same config to launch training, testing, and validation tasks. However,
|
||||||
|
only some of these components are necessary at the same time, e.g.,
|
||||||
|
testing a model does not need training or validation-related components.
|
||||||
|
|
||||||
|
To avoid repeatedly modifying config, the construction of ``Runner`` adopts
|
||||||
|
lazy initialization to only initialize components when they are going to be
|
||||||
|
used. Therefore, the model is always initialized at the beginning, and
|
||||||
|
training, validation, and, testing related components are only initialized
|
||||||
|
when calling ``runner.train()``, ``runner.val()``, and ``runner.test()``,
|
||||||
|
respectively.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
|
model (:obj:`torch.nn.Module` or dict): The model to be run. It can be
|
||||||
a dict used for build a model.
|
a dict used for build a model.
|
||||||
@ -114,23 +129,21 @@ class Runner:
|
|||||||
non-distributed environment will be launched.
|
non-distributed environment will be launched.
|
||||||
env_cfg (dict): A dict used for setting environment. Defaults to
|
env_cfg (dict): A dict used for setting environment. Defaults to
|
||||||
dict(dist_cfg=dict(backend='nccl')).
|
dict(dist_cfg=dict(backend='nccl')).
|
||||||
logger (MMLogger or dict, optional): A MMLogger object or a dict to
|
log_level (int or str): The log level of MMLogger handlers.
|
||||||
build MMLogger object. Defaults to None. If not specified, default
|
Defaults to 'INFO'.
|
||||||
config will be used.
|
|
||||||
message_hub (MessageHub or dict, optional): A Messagehub object or a
|
|
||||||
dict to build MessageHub object. Defaults to None. If not
|
|
||||||
specified, default config will be used.
|
|
||||||
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
|
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
|
||||||
dict build ComposedWriter object. Defaults to None. If not
|
dict build ComposedWriter object. Defaults to None. If not
|
||||||
specified, default config will be used.
|
specified, default config will be used.
|
||||||
default_scope (str, optional): Used to reset registries location.
|
default_scope (str, optional): Used to reset registries location.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
seed (int, optional): A number to set random modules. If not specified,
|
randomness (dict): Some settings to make the experiment as reproducible
|
||||||
a random number will be set as seed. Defaults to None.
|
as possible like seed and deterministic.
|
||||||
deterministic (bool): Whether cudnn to select deterministic algorithms.
|
Defaults to ``dict(seed=None)``. If seed is None, a random number
|
||||||
Defaults to False.
|
will be generated and it will be broadcasted to all other processes
|
||||||
See https://pytorch.org/docs/stable/notes/randomness.html for
|
if in distributed environment. If ``cudnn_benchmarch`` is
|
||||||
more details.
|
``True`` in ``env_cfg`` but ``deterministic`` is ``True`` in
|
||||||
|
``randomness``, the value of ``torch.backends.cudnn.benchmark``
|
||||||
|
will be ``False`` finally.
|
||||||
experiment_name (str, optional): Name of current experiment. If not
|
experiment_name (str, optional): Name of current experiment. If not
|
||||||
specified, timestamp will be used as ``experiment_name``.
|
specified, timestamp will be used as ``experiment_name``.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
@ -173,13 +186,11 @@ class Runner:
|
|||||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
param_scheduler=dict(type='ParamSchedulerHook')),
|
||||||
launcher='none',
|
launcher='none',
|
||||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||||
logger=dict(log_level='INFO'),
|
|
||||||
message_hub=None,
|
|
||||||
writer=dict(
|
writer=dict(
|
||||||
name='composed_writer',
|
name='composed_writer',
|
||||||
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
|
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
|
||||||
)
|
)
|
||||||
>>> runner = Runner.build_from_cfg(cfg)
|
>>> runner = Runner.from_cfg(cfg)
|
||||||
>>> runner.train()
|
>>> runner.train()
|
||||||
>>> runner.test()
|
>>> runner.test()
|
||||||
"""
|
"""
|
||||||
@ -208,19 +219,17 @@ class Runner:
|
|||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
launcher: str = 'none',
|
launcher: str = 'none',
|
||||||
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
|
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
|
||||||
logger: Optional[Union[MMLogger, Dict]] = None,
|
log_level: str = 'INFO',
|
||||||
message_hub: Optional[Union[MessageHub, Dict]] = None,
|
|
||||||
writer: Optional[Union[ComposedWriter, Dict]] = None,
|
writer: Optional[Union[ComposedWriter, Dict]] = None,
|
||||||
default_scope: Optional[str] = None,
|
default_scope: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
randomness: Dict = dict(seed=None),
|
||||||
deterministic: bool = False,
|
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
cfg: Optional[ConfigType] = None,
|
cfg: Optional[ConfigType] = None,
|
||||||
):
|
):
|
||||||
self._work_dir = osp.abspath(work_dir)
|
self._work_dir = osp.abspath(work_dir)
|
||||||
mmengine.mkdir_or_exist(self._work_dir)
|
mmengine.mkdir_or_exist(self._work_dir)
|
||||||
|
|
||||||
# recursively copy the ``cfg`` because `self.cfg` will be modified
|
# recursively copy the `cfg` because `self.cfg` will be modified
|
||||||
# everywhere.
|
# everywhere.
|
||||||
if cfg is not None:
|
if cfg is not None:
|
||||||
self.cfg = copy.deepcopy(cfg)
|
self.cfg = copy.deepcopy(cfg)
|
||||||
@ -231,21 +240,25 @@ class Runner:
|
|||||||
self._iter = 0
|
self._iter = 0
|
||||||
|
|
||||||
# lazy initialization
|
# lazy initialization
|
||||||
training_related = [
|
training_related = [train_dataloader, train_cfg, optimizer]
|
||||||
train_dataloader, train_cfg, optimizer, param_scheduler
|
|
||||||
]
|
|
||||||
if not (all(item is None for item in training_related)
|
if not (all(item is None for item in training_related)
|
||||||
or all(item is not None for item in training_related)):
|
or all(item is not None for item in training_related)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'train_dataloader, train_cfg, optimizer and param_scheduler '
|
'train_dataloader, train_cfg, and optimizer should be either '
|
||||||
'should be either all None or not None, but got '
|
'all None or not None, but got '
|
||||||
f'train_dataloader={train_dataloader}, '
|
f'train_dataloader={train_dataloader}, '
|
||||||
f'train_cfg={train_cfg}, '
|
f'train_cfg={train_cfg}, '
|
||||||
f'optimizer={optimizer}, '
|
f'optimizer={optimizer}.')
|
||||||
f'param_scheduler={param_scheduler}.')
|
|
||||||
self.train_dataloader = train_dataloader
|
self.train_dataloader = train_dataloader
|
||||||
self.train_loop = train_cfg
|
self.train_loop = train_cfg
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
|
||||||
|
# If there is no need to adjust learning rate, momentum or other
|
||||||
|
# parameters of optimizer, param_scheduler can be None
|
||||||
|
if param_scheduler is not None and self.optimizer is None:
|
||||||
|
raise ValueError(
|
||||||
|
'param_scheduler should be None when optimizer is None, '
|
||||||
|
f'but got {param_scheduler}')
|
||||||
if not isinstance(param_scheduler, Sequence):
|
if not isinstance(param_scheduler, Sequence):
|
||||||
self.param_schedulers = [param_scheduler]
|
self.param_schedulers = [param_scheduler]
|
||||||
else:
|
else:
|
||||||
@ -256,7 +269,7 @@ class Runner:
|
|||||||
for item in val_related) or all(item is not None
|
for item in val_related) or all(item is not None
|
||||||
for item in val_related)):
|
for item in val_related)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'val_dataloader, val_cfg and val_evaluator should be either '
|
'val_dataloader, val_cfg, and val_evaluator should be either '
|
||||||
'all None or not None, but got '
|
'all None or not None, but got '
|
||||||
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
|
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
|
||||||
f'val_evaluator={val_evaluator}')
|
f'val_evaluator={val_evaluator}')
|
||||||
@ -268,8 +281,8 @@ class Runner:
|
|||||||
if not (all(item is None for item in test_related)
|
if not (all(item is None for item in test_related)
|
||||||
or all(item is not None for item in test_related)):
|
or all(item is not None for item in test_related)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'test_dataloader, test_cfg and test_evaluator should be either'
|
'test_dataloader, test_cfg, and test_evaluator should be '
|
||||||
' all None or not None, but got '
|
'either all None or not None, but got '
|
||||||
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
|
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
|
||||||
f'test_evaluator={test_evaluator}')
|
f'test_evaluator={test_evaluator}')
|
||||||
self.test_dataloader = test_dataloader
|
self.test_dataloader = test_dataloader
|
||||||
@ -282,10 +295,13 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
self._distributed = True
|
self._distributed = True
|
||||||
|
|
||||||
# self._deterministic, self._seed and self._timestamp will be set in
|
# self._timestamp will be set in the `setup_env` method. Besides,
|
||||||
# the `setup_env`` method. Besides, it also will initialize
|
# it also will initialize multi-process and (or) distributed
|
||||||
# multi-process and (or) distributed environment.
|
# environment.
|
||||||
self.setup_env(env_cfg, seed, deterministic)
|
self.setup_env(env_cfg)
|
||||||
|
# self._deterministic and self._seed will be set in the
|
||||||
|
# `set_randomness`` method
|
||||||
|
self.set_randomness(**randomness)
|
||||||
|
|
||||||
if experiment_name is not None:
|
if experiment_name is not None:
|
||||||
self._experiment_name = f'{experiment_name}_{self._timestamp}'
|
self._experiment_name = f'{experiment_name}_{self._timestamp}'
|
||||||
@ -296,9 +312,9 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
self._experiment_name = self.timestamp
|
self._experiment_name = self.timestamp
|
||||||
|
|
||||||
self.logger = self.build_logger(logger)
|
self.logger = self.build_logger(log_level=log_level)
|
||||||
# message hub used for component interaction
|
# message hub used for component interaction
|
||||||
self.message_hub = self.build_message_hub(message_hub)
|
self.message_hub = self.build_message_hub()
|
||||||
# writer used for writing log or visualizing all kinds of data
|
# writer used for writing log or visualizing all kinds of data
|
||||||
self.writer = self.build_writer(writer)
|
self.writer = self.build_writer(writer)
|
||||||
# Used to reset registries location. See :meth:`Registry.build` for
|
# Used to reset registries location. See :meth:`Registry.build` for
|
||||||
@ -333,7 +349,7 @@ class Runner:
|
|||||||
self.dump_config()
|
self.dump_config()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_from_cfg(cls, cfg: ConfigType) -> 'Runner':
|
def from_cfg(cls, cfg: ConfigType) -> 'Runner':
|
||||||
"""Build a runner from config.
|
"""Build a runner from config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -363,12 +379,11 @@ class Runner:
|
|||||||
resume=cfg.get('resume', False),
|
resume=cfg.get('resume', False),
|
||||||
launcher=cfg.get('launcher', 'none'),
|
launcher=cfg.get('launcher', 'none'),
|
||||||
env_cfg=cfg.get('env_cfg'), # type: ignore
|
env_cfg=cfg.get('env_cfg'), # type: ignore
|
||||||
logger=cfg.get('log_cfg'),
|
log_level=cfg.get('log_level', 'INFO'),
|
||||||
message_hub=cfg.get('message_hub'),
|
|
||||||
writer=cfg.get('writer'),
|
writer=cfg.get('writer'),
|
||||||
default_scope=cfg.get('default_scope'),
|
default_scope=cfg.get('default_scope'),
|
||||||
seed=cfg.get('seed'),
|
randomness=cfg.get('randomness', dict(seed=None)),
|
||||||
deterministic=cfg.get('deterministic', False),
|
experiment_name=cfg.get('experiment_name'),
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -439,10 +454,7 @@ class Runner:
|
|||||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||||
return self._hooks
|
return self._hooks
|
||||||
|
|
||||||
def setup_env(self,
|
def setup_env(self, env_cfg: Dict) -> None:
|
||||||
env_cfg: Dict,
|
|
||||||
seed: Optional[int],
|
|
||||||
deterministic: bool = False) -> None:
|
|
||||||
"""Setup environment.
|
"""Setup environment.
|
||||||
|
|
||||||
An example of ``env_cfg``::
|
An example of ``env_cfg``::
|
||||||
@ -458,17 +470,7 @@ class Runner:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
env_cfg (dict): Config for setting environment.
|
env_cfg (dict): Config for setting environment.
|
||||||
seed (int, optional): A number to set random modules. If not
|
|
||||||
specified, a random number will be set as seed.
|
|
||||||
Defaults to None.
|
|
||||||
deterministic (bool): Whether cudnn to select deterministic
|
|
||||||
algorithms. Defaults to False.
|
|
||||||
See https://pytorch.org/docs/stable/notes/randomness.html for
|
|
||||||
more details.
|
|
||||||
"""
|
"""
|
||||||
self._deterministic = deterministic
|
|
||||||
self._seed = seed
|
|
||||||
|
|
||||||
if env_cfg.get('cudnn_benchmark'):
|
if env_cfg.get('cudnn_benchmark'):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
@ -490,9 +492,6 @@ class Runner:
|
|||||||
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
|
self._timestamp = time.strftime('%Y%m%d_%H%M%S',
|
||||||
time.localtime(timestamp.item()))
|
time.localtime(timestamp.item()))
|
||||||
|
|
||||||
# set random seeds
|
|
||||||
self._set_random_seed()
|
|
||||||
|
|
||||||
def _set_multi_processing(self,
|
def _set_multi_processing(self,
|
||||||
mp_start_method: str = 'fork',
|
mp_start_method: str = 'fork',
|
||||||
opencv_num_threads: int = 0) -> None:
|
opencv_num_threads: int = 0) -> None:
|
||||||
@ -546,16 +545,20 @@ class Runner:
|
|||||||
'optimal performance in your application as needed.')
|
'optimal performance in your application as needed.')
|
||||||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
|
||||||
|
|
||||||
def _set_random_seed(self) -> None:
|
def set_randomness(self, seed, deterministic: bool = False) -> None:
|
||||||
"""Set random seed to guarantee reproducible results.
|
"""Set random seed to guarantee reproducible results.
|
||||||
|
|
||||||
Warning:
|
Args:
|
||||||
Results can not be guaranteed to resproducible if ``self.seed`` is
|
seed (int): A number to set random modules.
|
||||||
None because :meth:`_set_random_seed` will generate a random seed
|
deterministic (bool): Whether to set the deterministic option for
|
||||||
when launching a new experiment.
|
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
||||||
|
to True and `torch.backends.cudnn.benchmark` to False.
|
||||||
See https://pytorch.org/docs/stable/notes/randomness.html for details.
|
Defaults to False.
|
||||||
|
See https://pytorch.org/docs/stable/notes/randomness.html for
|
||||||
|
more details.
|
||||||
"""
|
"""
|
||||||
|
self._deterministic = deterministic
|
||||||
|
self._seed = seed
|
||||||
if self._seed is None:
|
if self._seed is None:
|
||||||
self._seed = sync_random_seed()
|
self._seed = sync_random_seed()
|
||||||
|
|
||||||
@ -563,69 +566,62 @@ class Runner:
|
|||||||
np.random.seed(self._seed)
|
np.random.seed(self._seed)
|
||||||
torch.manual_seed(self._seed)
|
torch.manual_seed(self._seed)
|
||||||
torch.cuda.manual_seed_all(self._seed)
|
torch.cuda.manual_seed_all(self._seed)
|
||||||
if self._deterministic:
|
if deterministic:
|
||||||
torch.backends.cudnn.deterministic = True
|
if torch.backends.cudnn.benchmark:
|
||||||
|
warnings.warn(
|
||||||
|
'torch.backends.cudnn.benchmark is going to be set as '
|
||||||
|
'`False` to cause cuDNN to deterministically select an '
|
||||||
|
'algorithm')
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
|
||||||
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
def build_logger(self,
|
def build_logger(self,
|
||||||
logger: Optional[Union[MMLogger,
|
log_level: Union[int, str] = 'INFO',
|
||||||
Dict]] = None) -> MMLogger:
|
log_file: str = None,
|
||||||
|
**kwargs) -> MMLogger:
|
||||||
"""Build a global asscessable MMLogger.
|
"""Build a global asscessable MMLogger.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logger (MMLogger or dict, optional): A MMLogger object or a dict to
|
log_level (int or str): The log level of MMLogger handlers.
|
||||||
build MMLogger object. If ``logger`` is a MMLogger object, just
|
Defaults to 'INFO'.
|
||||||
returns itself. If not specified, default config will be used
|
log_file (str, optional): Path of filename to save log.
|
||||||
to build MMLogger object. Defaults to None.
|
Defaults to None.
|
||||||
|
**kwargs: Remaining parameters passed to ``MMLogger``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MMLogger: A MMLogger object build from ``logger``.
|
MMLogger: A MMLogger object build from ``logger``.
|
||||||
"""
|
"""
|
||||||
if isinstance(logger, MMLogger):
|
if log_file is None:
|
||||||
return logger
|
log_file = osp.join(self.work_dir, f'{self._experiment_name}.log')
|
||||||
elif logger is None:
|
|
||||||
logger = dict(
|
|
||||||
name=self._experiment_name,
|
|
||||||
log_level='INFO',
|
|
||||||
log_file=osp.join(self.work_dir,
|
|
||||||
f'{self._experiment_name}.log'))
|
|
||||||
elif isinstance(logger, dict):
|
|
||||||
# ensure logger containing name key
|
|
||||||
logger.setdefault('name', self._experiment_name)
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
'logger should be MMLogger object, a dict or None, '
|
|
||||||
f'but got {logger}')
|
|
||||||
|
|
||||||
return MMLogger.get_instance(**logger)
|
log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
|
||||||
|
log_cfg.setdefault('name', self._experiment_name)
|
||||||
|
|
||||||
def build_message_hub(
|
return MMLogger.get_instance(**log_cfg) # type: ignore
|
||||||
self,
|
|
||||||
message_hub: Optional[Union[MessageHub,
|
def build_message_hub(self,
|
||||||
Dict]] = None) -> MessageHub:
|
message_hub: Optional[Dict] = None) -> MessageHub:
|
||||||
"""Build a global asscessable MessageHub.
|
"""Build a global asscessable MessageHub.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_hub (MessageHub or dict, optional): A MessageHub object or
|
message_hub (dict, optional): A dict to build MessageHub object.
|
||||||
a dict to build MessageHub object. If ``message_hub`` is a
|
If not specified, default config will be used to build
|
||||||
MessageHub object, just returns itself. If not specified,
|
MessageHub object. Defaults to None.
|
||||||
default config will be used to build MessageHub object.
|
|
||||||
Defaults to None.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MessageHub: A MessageHub object build from ``message_hub``.
|
MessageHub: A MessageHub object build from ``message_hub``.
|
||||||
"""
|
"""
|
||||||
if isinstance(message_hub, MessageHub):
|
if message_hub is None:
|
||||||
return message_hub
|
|
||||||
elif message_hub is None:
|
|
||||||
message_hub = dict(name=self._experiment_name)
|
message_hub = dict(name=self._experiment_name)
|
||||||
elif isinstance(message_hub, dict):
|
elif isinstance(message_hub, dict):
|
||||||
# ensure message_hub containing name key
|
# ensure message_hub containing name key
|
||||||
message_hub.setdefault('name', self._experiment_name)
|
message_hub.setdefault('name', self._experiment_name)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'message_hub should be MessageHub object, a dict or None, '
|
f'message_hub should be dict or None, but got {message_hub}')
|
||||||
f'but got {message_hub}')
|
|
||||||
|
|
||||||
return MessageHub.get_instance(**message_hub)
|
return MessageHub.get_instance(**message_hub)
|
||||||
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -229,8 +227,6 @@ class TestRunner(TestCase):
|
|||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook'))
|
param_scheduler=dict(type='ParamSchedulerHook'))
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.temp_dir)
|
shutil.rmtree(self.temp_dir)
|
||||||
|
|
||||||
@ -238,89 +234,99 @@ class TestRunner(TestCase):
|
|||||||
# 1. test arguments
|
# 1. test arguments
|
||||||
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
|
# 1.1 train_dataloader, train_cfg, optimizer and param_scheduler
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init1'
|
||||||
cfg.pop('train_cfg')
|
cfg.pop('train_cfg')
|
||||||
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
||||||
Runner(**cfg)
|
Runner(**cfg)
|
||||||
|
|
||||||
# all of training related configs are None
|
# all of training related configs are None and param_scheduler should
|
||||||
|
# also be None
|
||||||
|
cfg.experiment_name = 'test_init2'
|
||||||
cfg.pop('train_dataloader')
|
cfg.pop('train_dataloader')
|
||||||
cfg.pop('optimizer')
|
cfg.pop('optimizer')
|
||||||
cfg.pop('param_scheduler')
|
cfg.pop('param_scheduler')
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
# avoid different runners having same timestamp
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# all of training related configs are not None
|
# all of training related configs are not None
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init3'
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
|
# all of training related configs are not None and param_scheduler
|
||||||
|
# can be None
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init4'
|
||||||
|
cfg.pop('param_scheduler')
|
||||||
|
runner = Runner(**cfg)
|
||||||
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
|
# param_scheduler should be None when optimizer is None
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init5'
|
||||||
|
cfg.pop('train_cfg')
|
||||||
|
cfg.pop('train_dataloader')
|
||||||
|
cfg.pop('optimizer')
|
||||||
|
with self.assertRaisesRegex(ValueError, 'should be None'):
|
||||||
|
runner = Runner(**cfg)
|
||||||
|
|
||||||
# 1.2 val_dataloader, val_evaluator, val_cfg
|
# 1.2 val_dataloader, val_evaluator, val_cfg
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init6'
|
||||||
cfg.pop('val_cfg')
|
cfg.pop('val_cfg')
|
||||||
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
||||||
Runner(**cfg)
|
Runner(**cfg)
|
||||||
|
|
||||||
time.sleep(1)
|
cfg.experiment_name = 'test_init7'
|
||||||
|
|
||||||
cfg.pop('val_dataloader')
|
cfg.pop('val_dataloader')
|
||||||
cfg.pop('val_evaluator')
|
cfg.pop('val_evaluator')
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init8'
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
# 1.3 test_dataloader, test_evaluator and test_cfg
|
# 1.3 test_dataloader, test_evaluator and test_cfg
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init9'
|
||||||
cfg.pop('test_cfg')
|
cfg.pop('test_cfg')
|
||||||
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
with self.assertRaisesRegex(ValueError, 'either all None or not None'):
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
|
|
||||||
time.sleep(1)
|
cfg.experiment_name = 'test_init10'
|
||||||
|
|
||||||
cfg.pop('test_dataloader')
|
cfg.pop('test_dataloader')
|
||||||
cfg.pop('test_evaluator')
|
cfg.pop('test_evaluator')
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 1.4 test env params
|
# 1.4 test env params
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init11'
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertFalse(runner.distributed)
|
self.assertFalse(runner.distributed)
|
||||||
self.assertFalse(runner.deterministic)
|
self.assertFalse(runner.deterministic)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 1.5 message_hub, logger and writer
|
# 1.5 message_hub, logger and writer
|
||||||
# they are all not specified
|
# they are all not specified
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_init12'
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner.logger, MMLogger)
|
self.assertIsInstance(runner.logger, MMLogger)
|
||||||
self.assertIsInstance(runner.message_hub, MessageHub)
|
self.assertIsInstance(runner.message_hub, MessageHub)
|
||||||
self.assertIsInstance(runner.writer, ComposedWriter)
|
self.assertIsInstance(runner.writer, ComposedWriter)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# they are all specified
|
# they are all specified
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.logger = dict(name='test_logger')
|
cfg.experiment_name = 'test_init13'
|
||||||
cfg.message_hub = dict(name='test_message_hub')
|
cfg.log_level = 'INFO'
|
||||||
cfg.writer = dict(name='test_writer')
|
cfg.writer = dict(name='test_writer')
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner.logger, MMLogger)
|
self.assertIsInstance(runner.logger, MMLogger)
|
||||||
self.assertEqual(runner.logger.instance_name, 'test_logger')
|
|
||||||
self.assertIsInstance(runner.message_hub, MessageHub)
|
self.assertIsInstance(runner.message_hub, MessageHub)
|
||||||
self.assertEqual(runner.message_hub.instance_name, 'test_message_hub')
|
|
||||||
self.assertIsInstance(runner.writer, ComposedWriter)
|
self.assertIsInstance(runner.writer, ComposedWriter)
|
||||||
self.assertEqual(runner.writer.instance_name, 'test_writer')
|
|
||||||
|
|
||||||
assert runner.distributed is False
|
assert runner.distributed is False
|
||||||
assert runner.seed is not None
|
assert runner.seed is not None
|
||||||
@ -358,7 +364,6 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
|
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
|
||||||
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
|
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
# 4. initialize runner with objects rather than config
|
# 4. initialize runner with objects rather than config
|
||||||
model = ToyModel()
|
model = ToyModel()
|
||||||
optimizer = SGD(
|
optimizer = SGD(
|
||||||
@ -385,84 +390,66 @@ class TestRunner(TestCase):
|
|||||||
test_dataloader=test_dataloader,
|
test_dataloader=test_dataloader,
|
||||||
test_evaluator=ToyEvaluator1(),
|
test_evaluator=ToyEvaluator1(),
|
||||||
default_hooks=dict(param_scheduler=toy_hook),
|
default_hooks=dict(param_scheduler=toy_hook),
|
||||||
custom_hooks=[toy_hook2])
|
custom_hooks=[toy_hook2],
|
||||||
|
experiment_name='test_init14')
|
||||||
runner.train()
|
runner.train()
|
||||||
runner.test()
|
runner.test()
|
||||||
|
|
||||||
# 5. test `dump_config`
|
# 5. test `dump_config`
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
def test_build_from_cfg(self):
|
def test_from_cfg(self):
|
||||||
runner = Runner.build_from_cfg(cfg=self.epoch_based_cfg)
|
runner = Runner.from_cfg(cfg=self.epoch_based_cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
|
||||||
def test_setup_env(self):
|
def test_setup_env(self):
|
||||||
# TODO
|
# TODO
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_logger(self):
|
def test_build_logger(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
self.epoch_based_cfg.experiment_name = 'test_build_logger1'
|
||||||
|
runner = Runner.from_cfg(self.epoch_based_cfg)
|
||||||
self.assertIsInstance(runner.logger, MMLogger)
|
self.assertIsInstance(runner.logger, MMLogger)
|
||||||
self.assertEqual(runner.experiment_name, runner.logger.instance_name)
|
self.assertEqual(runner.experiment_name, runner.logger.instance_name)
|
||||||
self.assertEqual(runner.logger.level, logging.NOTSET)
|
|
||||||
|
|
||||||
# input is a MMLogger object
|
|
||||||
self.assertEqual(
|
|
||||||
id(runner.build_logger(runner.logger)), id(runner.logger))
|
|
||||||
|
|
||||||
# input is None
|
|
||||||
runner._experiment_name = 'logger_name1'
|
|
||||||
logger = runner.build_logger(None)
|
|
||||||
self.assertIsInstance(logger, MMLogger)
|
|
||||||
self.assertEqual(logger.instance_name, 'logger_name1')
|
|
||||||
|
|
||||||
# input is a dict
|
# input is a dict
|
||||||
log_cfg = dict(name='logger_name2')
|
logger = runner.build_logger(name='test_build_logger2')
|
||||||
logger = runner.build_logger(log_cfg)
|
|
||||||
self.assertIsInstance(logger, MMLogger)
|
self.assertIsInstance(logger, MMLogger)
|
||||||
self.assertEqual(logger.instance_name, 'logger_name2')
|
self.assertEqual(logger.instance_name, 'test_build_logger2')
|
||||||
|
|
||||||
# input is a dict but does not contain name key
|
# input is a dict but does not contain name key
|
||||||
runner._experiment_name = 'logger_name3'
|
runner._experiment_name = 'test_build_logger3'
|
||||||
log_cfg = dict()
|
logger = runner.build_logger()
|
||||||
logger = runner.build_logger(log_cfg)
|
|
||||||
self.assertIsInstance(logger, MMLogger)
|
self.assertIsInstance(logger, MMLogger)
|
||||||
self.assertEqual(logger.instance_name, 'logger_name3')
|
self.assertEqual(logger.instance_name, 'test_build_logger3')
|
||||||
|
|
||||||
# input is not a valid type
|
|
||||||
with self.assertRaisesRegex(TypeError, 'logger should be'):
|
|
||||||
runner.build_logger('invalid-type')
|
|
||||||
|
|
||||||
def test_build_message_hub(self):
|
def test_build_message_hub(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
self.epoch_based_cfg.experiment_name = 'test_build_message_hub1'
|
||||||
|
runner = Runner.from_cfg(self.epoch_based_cfg)
|
||||||
self.assertIsInstance(runner.message_hub, MessageHub)
|
self.assertIsInstance(runner.message_hub, MessageHub)
|
||||||
self.assertEqual(runner.message_hub.instance_name,
|
self.assertEqual(runner.message_hub.instance_name,
|
||||||
runner.experiment_name)
|
runner.experiment_name)
|
||||||
|
|
||||||
# input is a MessageHub object
|
|
||||||
self.assertEqual(
|
|
||||||
id(runner.build_message_hub(runner.message_hub)),
|
|
||||||
id(runner.message_hub))
|
|
||||||
|
|
||||||
# input is a dict
|
# input is a dict
|
||||||
message_hub_cfg = dict(name='message_hub_name1')
|
message_hub_cfg = dict(name='test_build_message_hub2')
|
||||||
message_hub = runner.build_message_hub(message_hub_cfg)
|
message_hub = runner.build_message_hub(message_hub_cfg)
|
||||||
self.assertIsInstance(message_hub, MessageHub)
|
self.assertIsInstance(message_hub, MessageHub)
|
||||||
self.assertEqual(message_hub.instance_name, 'message_hub_name1')
|
self.assertEqual(message_hub.instance_name, 'test_build_message_hub2')
|
||||||
|
|
||||||
# input is a dict but does not contain name key
|
# input is a dict but does not contain name key
|
||||||
runner._experiment_name = 'message_hub_name2'
|
runner._experiment_name = 'test_build_message_hub3'
|
||||||
message_hub_cfg = dict()
|
message_hub_cfg = dict()
|
||||||
message_hub = runner.build_message_hub(message_hub_cfg)
|
message_hub = runner.build_message_hub(message_hub_cfg)
|
||||||
self.assertIsInstance(message_hub, MessageHub)
|
self.assertIsInstance(message_hub, MessageHub)
|
||||||
self.assertEqual(message_hub.instance_name, 'message_hub_name2')
|
self.assertEqual(message_hub.instance_name, 'test_build_message_hub3')
|
||||||
|
|
||||||
# input is not a valid type
|
# input is not a valid type
|
||||||
with self.assertRaisesRegex(TypeError, 'message_hub should be'):
|
with self.assertRaisesRegex(TypeError, 'message_hub should be'):
|
||||||
runner.build_message_hub('invalid-type')
|
runner.build_message_hub('invalid-type')
|
||||||
|
|
||||||
def test_build_writer(self):
|
def test_build_writer(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
self.epoch_based_cfg.experiment_name = 'test_build_writer1'
|
||||||
|
runner = Runner.from_cfg(self.epoch_based_cfg)
|
||||||
self.assertIsInstance(runner.writer, ComposedWriter)
|
self.assertIsInstance(runner.writer, ComposedWriter)
|
||||||
self.assertEqual(runner.experiment_name, runner.writer.instance_name)
|
self.assertEqual(runner.experiment_name, runner.writer.instance_name)
|
||||||
|
|
||||||
@ -471,17 +458,17 @@ class TestRunner(TestCase):
|
|||||||
id(runner.build_writer(runner.writer)), id(runner.writer))
|
id(runner.build_writer(runner.writer)), id(runner.writer))
|
||||||
|
|
||||||
# input is a dict
|
# input is a dict
|
||||||
writer_cfg = dict(name='writer_name1')
|
writer_cfg = dict(name='test_build_writer2')
|
||||||
writer = runner.build_writer(writer_cfg)
|
writer = runner.build_writer(writer_cfg)
|
||||||
self.assertIsInstance(writer, ComposedWriter)
|
self.assertIsInstance(writer, ComposedWriter)
|
||||||
self.assertEqual(writer.instance_name, 'writer_name1')
|
self.assertEqual(writer.instance_name, 'test_build_writer2')
|
||||||
|
|
||||||
# input is a dict but does not contain name key
|
# input is a dict but does not contain name key
|
||||||
runner._experiment_name = 'writer_name2'
|
runner._experiment_name = 'test_build_writer3'
|
||||||
writer_cfg = dict()
|
writer_cfg = dict()
|
||||||
writer = runner.build_writer(writer_cfg)
|
writer = runner.build_writer(writer_cfg)
|
||||||
self.assertIsInstance(writer, ComposedWriter)
|
self.assertIsInstance(writer, ComposedWriter)
|
||||||
self.assertEqual(writer.instance_name, 'writer_name2')
|
self.assertEqual(writer.instance_name, 'test_build_writer3')
|
||||||
|
|
||||||
# input is not a valid type
|
# input is not a valid type
|
||||||
with self.assertRaisesRegex(TypeError, 'writer should be'):
|
with self.assertRaisesRegex(TypeError, 'writer should be'):
|
||||||
@ -501,12 +488,16 @@ class TestRunner(TestCase):
|
|||||||
type='ToyScheduler', milestones=[1, 2])
|
type='ToyScheduler', milestones=[1, 2])
|
||||||
self.epoch_based_cfg.default_scope = 'toy'
|
self.epoch_based_cfg.default_scope = 'toy'
|
||||||
|
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_default_scope'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
|
self.assertIsInstance(runner.param_schedulers[0], ToyScheduler)
|
||||||
|
|
||||||
def test_build_model(self):
|
def test_build_model(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_model'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
self.assertIsInstance(runner.model, ToyModel)
|
self.assertIsInstance(runner.model, ToyModel)
|
||||||
|
|
||||||
# input should be a nn.Module object or dict
|
# input should be a nn.Module object or dict
|
||||||
@ -526,12 +517,15 @@ class TestRunner(TestCase):
|
|||||||
# TODO: test on distributed environment
|
# TODO: test on distributed environment
|
||||||
# custom model wrapper
|
# custom model wrapper
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_wrap_model'
|
||||||
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
|
cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
|
||||||
runner = Runner.build_from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
self.assertIsInstance(runner.model, CustomModelWrapper)
|
self.assertIsInstance(runner.model, CustomModelWrapper)
|
||||||
|
|
||||||
def test_build_optimizer(self):
|
def test_build_optimizer(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_optimizer'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# input should be an Optimizer object or dict
|
# input should be an Optimizer object or dict
|
||||||
with self.assertRaisesRegex(TypeError, 'optimizer should be'):
|
with self.assertRaisesRegex(TypeError, 'optimizer should be'):
|
||||||
@ -547,7 +541,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(optimizer, SGD)
|
self.assertIsInstance(optimizer, SGD)
|
||||||
|
|
||||||
def test_build_param_scheduler(self):
|
def test_build_param_scheduler(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_param_scheduler'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# `build_optimizer` should be called before `build_param_scheduler`
|
# `build_optimizer` should be called before `build_param_scheduler`
|
||||||
cfg = dict(type='MultiStepLR', milestones=[1, 2])
|
cfg = dict(type='MultiStepLR', milestones=[1, 2])
|
||||||
@ -584,7 +580,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(param_schedulers[1], StepLR)
|
self.assertIsInstance(param_schedulers[1], StepLR)
|
||||||
|
|
||||||
def test_build_evaluator(self):
|
def test_build_evaluator(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_evaluator'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# input is a BaseEvaluator or ComposedEvaluator object
|
# input is a BaseEvaluator or ComposedEvaluator object
|
||||||
evaluator = ToyEvaluator1()
|
evaluator = ToyEvaluator1()
|
||||||
@ -603,7 +601,9 @@ class TestRunner(TestCase):
|
|||||||
runner.build_evaluator(evaluator), ComposedEvaluator)
|
runner.build_evaluator(evaluator), ComposedEvaluator)
|
||||||
|
|
||||||
def test_build_dataloader(self):
|
def test_build_dataloader(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_dataloader'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
cfg = dict(
|
cfg = dict(
|
||||||
dataset=dict(type='ToyDataset'),
|
dataset=dict(type='ToyDataset'),
|
||||||
@ -616,8 +616,11 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(dataloader.sampler, DefaultSampler)
|
self.assertIsInstance(dataloader.sampler, DefaultSampler)
|
||||||
|
|
||||||
def test_build_train_loop(self):
|
def test_build_train_loop(self):
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_train_loop'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# input should be a Loop object or dict
|
# input should be a Loop object or dict
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
|
||||||
with self.assertRaisesRegex(TypeError, 'should be'):
|
with self.assertRaisesRegex(TypeError, 'should be'):
|
||||||
runner.build_train_loop('invalid-type')
|
runner.build_train_loop('invalid-type')
|
||||||
|
|
||||||
@ -653,7 +656,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(loop, CustomTrainLoop)
|
self.assertIsInstance(loop, CustomTrainLoop)
|
||||||
|
|
||||||
def test_build_val_loop(self):
|
def test_build_val_loop(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_val_loop'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# input should be a Loop object or dict
|
# input should be a Loop object or dict
|
||||||
with self.assertRaisesRegex(TypeError, 'should be'):
|
with self.assertRaisesRegex(TypeError, 'should be'):
|
||||||
@ -678,7 +683,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(loop, CustomValLoop)
|
self.assertIsInstance(loop, CustomValLoop)
|
||||||
|
|
||||||
def test_build_test_loop(self):
|
def test_build_test_loop(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_build_test_loop'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
# input should be a Loop object or dict
|
# input should be a Loop object or dict
|
||||||
with self.assertRaisesRegex(TypeError, 'should be'):
|
with self.assertRaisesRegex(TypeError, 'should be'):
|
||||||
@ -705,16 +712,15 @@ class TestRunner(TestCase):
|
|||||||
def test_train(self):
|
def test_train(self):
|
||||||
# 1. test `self.train_loop` is None
|
# 1. test `self.train_loop` is None
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train1'
|
||||||
cfg.pop('train_dataloader')
|
cfg.pop('train_dataloader')
|
||||||
cfg.pop('train_cfg')
|
cfg.pop('train_cfg')
|
||||||
cfg.pop('optimizer')
|
cfg.pop('optimizer')
|
||||||
cfg.pop('param_scheduler')
|
cfg.pop('param_scheduler')
|
||||||
runner = Runner.build_from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 2. test iter and epoch counter of EpochBasedTrainLoop
|
# 2. test iter and epoch counter of EpochBasedTrainLoop
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
epoch_targets = [i for i in range(3)]
|
epoch_targets = [i for i in range(3)]
|
||||||
@ -733,10 +739,10 @@ class TestRunner(TestCase):
|
|||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
batch_idx_results.append(batch_idx)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
self.epoch_based_cfg.custom_hooks = [
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
dict(type='TestEpochHook', priority=50)
|
cfg.experiment_name = 'test_train2'
|
||||||
]
|
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
@ -749,8 +755,6 @@ class TestRunner(TestCase):
|
|||||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
# 3. test iter and epoch counter of IterBasedTrainLoop
|
# 3. test iter and epoch counter of IterBasedTrainLoop
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
iter_results = []
|
iter_results = []
|
||||||
@ -768,11 +772,11 @@ class TestRunner(TestCase):
|
|||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
batch_idx_results.append(batch_idx)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
self.iter_based_cfg.custom_hooks = [
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
dict(type='TestIterHook', priority=50)
|
cfg.experiment_name = 'test_train3'
|
||||||
]
|
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||||
self.iter_based_cfg.val_cfg = dict(interval=4)
|
cfg.val_cfg = dict(interval=4)
|
||||||
runner = Runner.build_from_cfg(self.iter_based_cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||||
@ -786,32 +790,38 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_val1'
|
||||||
cfg.pop('val_dataloader')
|
cfg.pop('val_dataloader')
|
||||||
cfg.pop('val_cfg')
|
cfg.pop('val_cfg')
|
||||||
cfg.pop('val_evaluator')
|
cfg.pop('val_evaluator')
|
||||||
runner = Runner.build_from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
||||||
runner.val()
|
runner.val()
|
||||||
|
|
||||||
time.sleep(1)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg.experiment_name = 'test_val2'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.val()
|
runner.val()
|
||||||
|
|
||||||
def test_test(self):
|
def test_test(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_test1'
|
||||||
cfg.pop('test_dataloader')
|
cfg.pop('test_dataloader')
|
||||||
cfg.pop('test_cfg')
|
cfg.pop('test_cfg')
|
||||||
cfg.pop('test_evaluator')
|
cfg.pop('test_evaluator')
|
||||||
runner = Runner.build_from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
||||||
runner.test()
|
runner.test()
|
||||||
|
|
||||||
time.sleep(1)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg.experiment_name = 'test_test2'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.test()
|
runner.test()
|
||||||
|
|
||||||
def test_register_hook(self):
|
def test_register_hook(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_register_hook'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
|
|
||||||
# 1. test `hook` parameter
|
# 1. test `hook` parameter
|
||||||
@ -870,7 +880,9 @@ class TestRunner(TestCase):
|
|||||||
get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))
|
get_priority(runner._hooks[3].priority), get_priority('VERY_LOW'))
|
||||||
|
|
||||||
def test_default_hooks(self):
|
def test_default_hooks(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_default_hooks'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
|
|
||||||
# register five hooks by default
|
# register five hooks by default
|
||||||
@ -893,7 +905,10 @@ class TestRunner(TestCase):
|
|||||||
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
||||||
|
|
||||||
def test_custom_hooks(self):
|
def test_custom_hooks(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_custom_hooks'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
self.assertEqual(len(runner._hooks), 5)
|
self.assertEqual(len(runner._hooks), 5)
|
||||||
custom_hooks = [dict(type='ToyHook')]
|
custom_hooks = [dict(type='ToyHook')]
|
||||||
runner.register_custom_hooks(custom_hooks)
|
runner.register_custom_hooks(custom_hooks)
|
||||||
@ -901,7 +916,10 @@ class TestRunner(TestCase):
|
|||||||
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
self.assertTrue(isinstance(runner._hooks[5], ToyHook))
|
||||||
|
|
||||||
def test_register_hooks(self):
|
def test_register_hooks(self):
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_register_hooks'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
runner._hooks = []
|
runner._hooks = []
|
||||||
custom_hooks = [dict(type='ToyHook')]
|
custom_hooks = [dict(type='ToyHook')]
|
||||||
runner.register_hooks(custom_hooks=custom_hooks)
|
runner.register_hooks(custom_hooks=custom_hooks)
|
||||||
@ -975,7 +993,8 @@ class TestRunner(TestCase):
|
|||||||
self.iter_based_cfg.custom_hooks = [
|
self.iter_based_cfg.custom_hooks = [
|
||||||
dict(type='TestWarmupHook', priority=50)
|
dict(type='TestWarmupHook', priority=50)
|
||||||
]
|
]
|
||||||
runner = Runner.build_from_cfg(self.iter_based_cfg)
|
self.iter_based_cfg.experiment_name = 'test_custom_loop'
|
||||||
|
runner = Runner.from_cfg(self.iter_based_cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
|
self.assertIsInstance(runner.train_loop, CustomTrainLoop2)
|
||||||
@ -990,7 +1009,9 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
def test_checkpoint(self):
|
def test_checkpoint(self):
|
||||||
# 1. test epoch based
|
# 1. test epoch based
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint1'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
# 1.1 test `save_checkpoint` which called by `CheckpointHook`
|
# 1.1 test `save_checkpoint` which called by `CheckpointHook`
|
||||||
@ -1006,17 +1027,19 @@ class TestRunner(TestCase):
|
|||||||
assert isinstance(ckpt['optimizer'], dict)
|
assert isinstance(ckpt['optimizer'], dict)
|
||||||
assert isinstance(ckpt['param_schedulers'], list)
|
assert isinstance(ckpt['param_schedulers'], list)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
# 1.2 test `load_checkpoint`
|
# 1.2 test `load_checkpoint`
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint2'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.load_checkpoint(path)
|
runner.load_checkpoint(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
self.assertEqual(runner.iter, 0)
|
self.assertEqual(runner.iter, 0)
|
||||||
self.assertTrue(runner._has_loaded)
|
self.assertTrue(runner._has_loaded)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
# 1.3 test `resume`
|
# 1.3 test `resume`
|
||||||
runner = Runner.build_from_cfg(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint3'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.resume(path)
|
runner.resume(path)
|
||||||
self.assertEqual(runner.epoch, 3)
|
self.assertEqual(runner.epoch, 3)
|
||||||
self.assertEqual(runner.iter, 12)
|
self.assertEqual(runner.iter, 12)
|
||||||
@ -1025,8 +1048,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
# 2. test iter based
|
# 2. test iter based
|
||||||
time.sleep(1)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
runner = Runner.build_from_cfg(self.iter_based_cfg)
|
cfg.experiment_name = 'test_checkpoint4'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
# 2.1 test `save_checkpoint` which called by `CheckpointHook`
|
# 2.1 test `save_checkpoint` which called by `CheckpointHook`
|
||||||
@ -1043,16 +1067,18 @@ class TestRunner(TestCase):
|
|||||||
assert isinstance(ckpt['param_schedulers'], list)
|
assert isinstance(ckpt['param_schedulers'], list)
|
||||||
|
|
||||||
# 2.2 test `load_checkpoint`
|
# 2.2 test `load_checkpoint`
|
||||||
time.sleep(1)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
runner = Runner.build_from_cfg(self.iter_based_cfg)
|
cfg.experiment_name = 'test_checkpoint5'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.load_checkpoint(path)
|
runner.load_checkpoint(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
self.assertEqual(runner.iter, 0)
|
self.assertEqual(runner.iter, 0)
|
||||||
self.assertTrue(runner._has_loaded)
|
self.assertTrue(runner._has_loaded)
|
||||||
|
|
||||||
time.sleep(1)
|
|
||||||
# 2.3 test `resume`
|
# 2.3 test `resume`
|
||||||
runner = Runner.build_from_cfg(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint6'
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.resume(path)
|
runner.resume(path)
|
||||||
self.assertEqual(runner.epoch, 0)
|
self.assertEqual(runner.epoch, 0)
|
||||||
self.assertEqual(runner.iter, 12)
|
self.assertEqual(runner.iter, 12)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user