mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Make loop-related attributes to be runner's properties. (#236)
* [Enhance] Make loop related attributes to be runner's properties. * move iter and epoch to loop * resolve comments
This commit is contained in:
parent
cc8a6b86e1
commit
e37f1f905b
@ -194,7 +194,8 @@ class CheckpointHook(Hook):
|
|||||||
# 1. every ``self.interval`` iterations
|
# 1. every ``self.interval`` iterations
|
||||||
# 2. reach the last iteration of training
|
# 2. reach the last iteration of training
|
||||||
if self.every_n_iters(runner, self.interval) or \
|
if self.every_n_iters(runner, self.interval) or \
|
||||||
(self.save_last and self.is_last_iter(runner, mode='train')):
|
(self.save_last and
|
||||||
|
self.is_last_train_iter(runner)):
|
||||||
runner.logger.info(
|
runner.logger.info(
|
||||||
f'Saving checkpoint at {runner.iter + 1} iterations')
|
f'Saving checkpoint at {runner.iter + 1} iterations')
|
||||||
self._save_checkpoint(runner)
|
self._save_checkpoint(runner)
|
||||||
|
@ -409,25 +409,15 @@ class Hook:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: Whether reaches the end of training epoch.
|
bool: Whether reaches the end of training epoch.
|
||||||
"""
|
"""
|
||||||
return runner.epoch + 1 == runner.train_loop.max_epochs
|
return runner.epoch + 1 == runner.max_epochs
|
||||||
|
|
||||||
def is_last_iter(self, runner, mode='train') -> bool:
|
def is_last_train_iter(self, runner) -> bool:
|
||||||
"""Test whether current iteration is the last iteration.
|
"""Test whether current iteration is the last train iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training process.
|
||||||
process.
|
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether current iteration is the last iteration.
|
bool: Whether current iteration is the last train iteration.
|
||||||
"""
|
"""
|
||||||
if mode == 'train':
|
return runner.iter + 1 == runner.max_iters
|
||||||
return runner.iter + 1 == runner.train_loop.max_iters
|
|
||||||
elif mode == 'val':
|
|
||||||
return runner.iter + 1 == runner.val_loop.max_iters
|
|
||||||
elif mode == 'test':
|
|
||||||
return runner.iter + 1 == runner.test_loop.max_iters
|
|
||||||
else:
|
|
||||||
raise ValueError('mode should be train, val or test but got'
|
|
||||||
f'{mode}')
|
|
||||||
|
@ -98,14 +98,13 @@ class IterTimerHook(Hook):
|
|||||||
time_sec_avg = self.time_sec_tot / (
|
time_sec_avg = self.time_sec_tot / (
|
||||||
runner.iter - self.start_iter + 1)
|
runner.iter - self.start_iter + 1)
|
||||||
# Calculate eta.
|
# Calculate eta.
|
||||||
eta_sec = time_sec_avg * (
|
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
|
||||||
runner.train_loop.max_iters - runner.iter - 1)
|
|
||||||
runner.message_hub.update_info('eta', eta_sec)
|
runner.message_hub.update_info('eta', eta_sec)
|
||||||
else:
|
else:
|
||||||
if mode == 'val':
|
if mode == 'val':
|
||||||
cur_dataloader = runner.val_loop.dataloader
|
cur_dataloader = runner.val_dataloader
|
||||||
else:
|
else:
|
||||||
cur_dataloader = runner.test_loop.dataloader
|
cur_dataloader = runner.test_dataloader
|
||||||
|
|
||||||
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
|
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
|
||||||
runner.message_hub.update_info('eta', eta_sec)
|
runner.message_hub.update_info('eta', eta_sec)
|
||||||
|
@ -127,13 +127,13 @@ class LoggerHook(Hook):
|
|||||||
# Print experiment name every n iterations.
|
# Print experiment name every n iterations.
|
||||||
if self.every_n_iters(runner,
|
if self.every_n_iters(runner,
|
||||||
self.interval_exp_name) or (self.end_of_epoch(
|
self.interval_exp_name) or (self.end_of_epoch(
|
||||||
runner.train_loop.dataloader, batch_idx)):
|
runner.train_dataloader, batch_idx)):
|
||||||
exp_info = f'Exp name: {runner.experiment_name}'
|
exp_info = f'Exp name: {runner.experiment_name}'
|
||||||
runner.logger.info(exp_info)
|
runner.logger.info(exp_info)
|
||||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||||
tag, log_str = runner.log_processor.get_log_after_iter(
|
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||||
runner, batch_idx, 'train')
|
runner, batch_idx, 'train')
|
||||||
elif (self.end_of_epoch(runner.train_loop.dataloader, batch_idx)
|
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
|
||||||
and not self.ignore_last):
|
and not self.ignore_last):
|
||||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||||
# `self.ignore_last==True`, the log of remaining iterations will
|
# `self.ignore_last==True`, the log of remaining iterations will
|
||||||
|
@ -142,7 +142,7 @@ class LogProcessor:
|
|||||||
else:
|
else:
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
log_str = (f'Iter({mode}) '
|
log_str = (f'Iter({mode}) '
|
||||||
f'[{cur_iter}/{runner.train_loop.max_iters}] ')
|
f'[{cur_iter}/{runner.max_iters}] ')
|
||||||
else:
|
else:
|
||||||
log_str = (f'Iter({mode}) [{batch_idx+1}'
|
log_str = (f'Iter({mode}) [{batch_idx+1}'
|
||||||
f'/{len(current_loop.dataloader)}] ')
|
f'/{len(current_loop.dataloader)}] ')
|
||||||
|
@ -19,7 +19,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
runner (Runner): A reference of runner.
|
runner (Runner): A reference of runner.
|
||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
max_epoch (int): Total training epochs.
|
max_epochs (int): Total training epochs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||||
@ -27,6 +27,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_epochs = max_epochs
|
self._max_epochs = max_epochs
|
||||||
self._max_iters = max_epochs * len(self.dataloader)
|
self._max_iters = max_epochs * len(self.dataloader)
|
||||||
|
self._epoch = 0
|
||||||
|
self._iter = 0
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -46,15 +48,25 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
"""int: Total iterations to train model."""
|
"""int: Total iterations to train model."""
|
||||||
return self._max_iters
|
return self._max_iters
|
||||||
|
|
||||||
|
@property
|
||||||
|
def epoch(self):
|
||||||
|
"""int: Current epoch."""
|
||||||
|
return self._epoch
|
||||||
|
|
||||||
|
@property
|
||||||
|
def iter(self):
|
||||||
|
"""int: Current iteration."""
|
||||||
|
return self._iter
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Launch training."""
|
"""Launch training."""
|
||||||
self.runner.call_hook('before_train')
|
self.runner.call_hook('before_train')
|
||||||
|
|
||||||
while self.runner._epoch < self._max_epochs:
|
while self._epoch < self._max_epochs:
|
||||||
self.run_epoch()
|
self.run_epoch()
|
||||||
|
|
||||||
if (self.runner.val_loop is not None and
|
if (self.runner.val_loop is not None
|
||||||
self.runner._epoch % self.runner.val_loop.interval == 0):
|
and self._epoch % self.runner.val_loop.interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
self.runner.call_hook('after_train')
|
self.runner.call_hook('after_train')
|
||||||
@ -67,7 +79,9 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self.run_iter(idx, data_batch)
|
self.run_iter(idx, data_batch)
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
self.runner.epoch += 1
|
self._epoch += 1
|
||||||
|
# To allow components that cannot access runner to get current epoch.
|
||||||
|
self.runner.message_hub.update_info('epoch', self._epoch)
|
||||||
|
|
||||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||||
"""Iterate one min-batch.
|
"""Iterate one min-batch.
|
||||||
@ -90,7 +104,10 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
|
|
||||||
self.runner.iter += 1
|
self._iter += 1
|
||||||
|
# To allow components that cannot access runner to get current
|
||||||
|
# iteration.
|
||||||
|
self.runner.message_hub.update_info('iter', self._iter)
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
@ -101,13 +118,16 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
runner (Runner): A reference of runner.
|
runner (Runner): A reference of runner.
|
||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
max_iter (int): Total training iterations.
|
max_iters (int): Total training iterations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||||
max_iters: int) -> None:
|
max_iters: int) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_iters = max_iters
|
self._max_iters = max_iters
|
||||||
|
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
||||||
|
self._epoch = 0
|
||||||
|
self._iter = 0
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -118,25 +138,40 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
'None.')
|
'None.')
|
||||||
self.dataloader = iter(self.dataloader)
|
self.dataloader = iter(self.dataloader)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_epochs(self):
|
||||||
|
"""int: Total epochs to train model."""
|
||||||
|
return self._max_epochs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_iters(self):
|
def max_iters(self):
|
||||||
"""int: Total iterations to train model."""
|
"""int: Total iterations to train model."""
|
||||||
return self._max_iters
|
return self._max_iters
|
||||||
|
|
||||||
|
@property
|
||||||
|
def epoch(self):
|
||||||
|
"""int: Current epoch."""
|
||||||
|
return self._epoch
|
||||||
|
|
||||||
|
@property
|
||||||
|
def iter(self):
|
||||||
|
"""int: Current iteration."""
|
||||||
|
return self._iter
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Launch training."""
|
"""Launch training."""
|
||||||
self.runner.call_hook('before_train')
|
self.runner.call_hook('before_train')
|
||||||
# In iteration-based training loop, we treat the whole training process
|
# In iteration-based training loop, we treat the whole training process
|
||||||
# as a big epoch and execute the corresponding hook.
|
# as a big epoch and execute the corresponding hook.
|
||||||
self.runner.call_hook('before_train_epoch')
|
self.runner.call_hook('before_train_epoch')
|
||||||
while self.runner._iter < self._max_iters:
|
while self._iter < self._max_iters:
|
||||||
self.runner.model.train()
|
self.runner.model.train()
|
||||||
|
|
||||||
data_batch = next(self.dataloader)
|
data_batch = next(self.dataloader)
|
||||||
self.run_iter(data_batch)
|
self.run_iter(data_batch)
|
||||||
|
|
||||||
if (self.runner.val_loop is not None and
|
if (self.runner.val_loop is not None
|
||||||
self.runner._iter % self.runner.val_loop.interval == 0):
|
and self._iter % self.runner.val_interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
@ -149,9 +184,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
data_batch (Sequence[dict]): Batch of data from dataloader.
|
data_batch (Sequence[dict]): Batch of data from dataloader.
|
||||||
"""
|
"""
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'before_train_iter',
|
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
|
||||||
batch_idx=self.runner._iter,
|
|
||||||
data_batch=data_batch)
|
|
||||||
# outputs should be a dict containing loss tensor
|
# outputs should be a dict containing loss tensor
|
||||||
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
|
||||||
|
|
||||||
@ -161,10 +194,13 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
|
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'after_train_iter',
|
'after_train_iter',
|
||||||
batch_idx=self.runner._iter,
|
batch_idx=self._iter,
|
||||||
data_batch=data_batch,
|
data_batch=data_batch,
|
||||||
outputs=self.runner.outputs)
|
outputs=self.runner.outputs)
|
||||||
self.runner.iter += 1
|
self._iter += 1
|
||||||
|
# To allow components that cannot access runner to get current
|
||||||
|
# iteration.
|
||||||
|
self.runner.message_hub.update_info('iter', self._iter)
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
|
@ -199,9 +199,9 @@ class Runner:
|
|||||||
>>> runner.test()
|
>>> runner.test()
|
||||||
"""
|
"""
|
||||||
cfg: Config
|
cfg: Config
|
||||||
train_loop: Optional[Union[BaseLoop, Dict]]
|
_train_loop: Optional[Union[BaseLoop, Dict]]
|
||||||
val_loop: Optional[Union[BaseLoop, Dict]]
|
_val_loop: Optional[Union[BaseLoop, Dict]]
|
||||||
test_loop: Optional[Union[BaseLoop, Dict]]
|
_test_loop: Optional[Union[BaseLoop, Dict]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -244,9 +244,6 @@ class Runner:
|
|||||||
else:
|
else:
|
||||||
self.cfg = Config(dict())
|
self.cfg = Config(dict())
|
||||||
|
|
||||||
self._epoch = 0
|
|
||||||
self._iter = 0
|
|
||||||
|
|
||||||
# lazy initialization
|
# lazy initialization
|
||||||
training_related = [train_dataloader, train_cfg, optimizer]
|
training_related = [train_dataloader, train_cfg, optimizer]
|
||||||
if not (all(item is None for item in training_related)
|
if not (all(item is None for item in training_related)
|
||||||
@ -257,8 +254,8 @@ class Runner:
|
|||||||
f'train_dataloader={train_dataloader}, '
|
f'train_dataloader={train_dataloader}, '
|
||||||
f'train_cfg={train_cfg}, '
|
f'train_cfg={train_cfg}, '
|
||||||
f'optimizer={optimizer}.')
|
f'optimizer={optimizer}.')
|
||||||
self.train_dataloader = train_dataloader
|
self._train_dataloader = train_dataloader
|
||||||
self.train_loop = train_cfg
|
self._train_loop = train_cfg
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
|
||||||
# If there is no need to adjust learning rate, momentum or other
|
# If there is no need to adjust learning rate, momentum or other
|
||||||
@ -284,9 +281,9 @@ class Runner:
|
|||||||
'all None or not None, but got '
|
'all None or not None, but got '
|
||||||
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
|
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
|
||||||
f'val_evaluator={val_evaluator}')
|
f'val_evaluator={val_evaluator}')
|
||||||
self.val_dataloader = val_dataloader
|
self._val_dataloader = val_dataloader
|
||||||
self.val_loop = val_cfg
|
self._val_loop = val_cfg
|
||||||
self.val_evaluator = val_evaluator
|
self._val_evaluator = val_evaluator
|
||||||
|
|
||||||
test_related = [test_dataloader, test_cfg, test_evaluator]
|
test_related = [test_dataloader, test_cfg, test_evaluator]
|
||||||
if not (all(item is None for item in test_related)
|
if not (all(item is None for item in test_related)
|
||||||
@ -296,9 +293,9 @@ class Runner:
|
|||||||
'either all None or not None, but got '
|
'either all None or not None, but got '
|
||||||
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
|
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
|
||||||
f'test_evaluator={test_evaluator}')
|
f'test_evaluator={test_evaluator}')
|
||||||
self.test_dataloader = test_dataloader
|
self._test_dataloader = test_dataloader
|
||||||
self.test_loop = test_cfg
|
self._test_loop = test_cfg
|
||||||
self.test_evaluator = test_evaluator
|
self._test_evaluator = test_evaluator
|
||||||
|
|
||||||
self._launcher = launcher
|
self._launcher = launcher
|
||||||
if self._launcher == 'none':
|
if self._launcher == 'none':
|
||||||
@ -435,30 +432,25 @@ class Runner:
|
|||||||
"""str: The working directory to save checkpoints and logs."""
|
"""str: The working directory to save checkpoints and logs."""
|
||||||
return self._work_dir
|
return self._work_dir
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_epochs(self):
|
||||||
|
"""int: Total epochs to train model."""
|
||||||
|
return self.train_loop.max_epochs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_iters(self):
|
||||||
|
"""int: Total iterations to train model."""
|
||||||
|
return self.train_loop.max_iters
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def epoch(self):
|
def epoch(self):
|
||||||
"""int: Current epoch."""
|
"""int: Current epoch."""
|
||||||
return self._epoch
|
return self.train_loop.epoch
|
||||||
|
|
||||||
@epoch.setter
|
|
||||||
def epoch(self, epoch: int):
|
|
||||||
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
|
|
||||||
self._epoch = epoch
|
|
||||||
# To allow components that cannot access runner to get current epoch.
|
|
||||||
self.message_hub.update_info('epoch', epoch)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def iter(self):
|
def iter(self):
|
||||||
"""int: Current iteration."""
|
"""int: Current iteration."""
|
||||||
return self._iter
|
return self.train_loop.iter
|
||||||
|
|
||||||
@iter.setter
|
|
||||||
def iter(self, iter: int):
|
|
||||||
"""Update iter and synchronize iter in :attr:`message_hub`."""
|
|
||||||
self._iter = iter
|
|
||||||
# To allow components that cannot access runner to get current
|
|
||||||
# iteration.
|
|
||||||
self.message_hub.update_info('iter', iter)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def launcher(self):
|
def launcher(self):
|
||||||
@ -500,6 +492,63 @@ class Runner:
|
|||||||
"""list[:obj:`Hook`]: A list of registered hooks."""
|
"""list[:obj:`Hook`]: A list of registered hooks."""
|
||||||
return self._hooks
|
return self._hooks
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_loop(self):
|
||||||
|
""":obj:`BaseLoop`: A loop to run training."""
|
||||||
|
if isinstance(self._train_loop, BaseLoop) or self._train_loop is None:
|
||||||
|
return self._train_loop
|
||||||
|
else:
|
||||||
|
self._train_loop = self.build_train_loop(self._train_loop)
|
||||||
|
return self._train_loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_loop(self):
|
||||||
|
""":obj:`BaseLoop`: A loop to run validation."""
|
||||||
|
if isinstance(self._val_loop, BaseLoop) or self._val_loop is None:
|
||||||
|
return self._val_loop
|
||||||
|
else:
|
||||||
|
self._val_loop = self.build_val_loop(self._val_loop)
|
||||||
|
return self._val_loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_loop(self):
|
||||||
|
""":obj:`BaseLoop`: A loop to run testing."""
|
||||||
|
if isinstance(self._test_loop, BaseLoop) or self._test_loop is None:
|
||||||
|
return self._test_loop
|
||||||
|
else:
|
||||||
|
self._test_loop = self.build_test_loop(self._test_loop)
|
||||||
|
return self._test_loop
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_dataloader(self):
|
||||||
|
"""The data loader for training."""
|
||||||
|
return self.train_loop.dataloader
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_dataloader(self):
|
||||||
|
"""The data loader for validation."""
|
||||||
|
return self.val_loop.dataloader
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_dataloader(self):
|
||||||
|
"""The data loader for testing."""
|
||||||
|
return self.test_loop.dataloader
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_evaluator(self):
|
||||||
|
""":obj:`Evaluator`: An evaluator for validation."""
|
||||||
|
return self.val_loop.evaluator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def test_evaluator(self):
|
||||||
|
""":obj:`Evaluator`: An evaluator for testing."""
|
||||||
|
return self.test_loop.evaluator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_interval(self):
|
||||||
|
"""int: Interval to run validation during training."""
|
||||||
|
return self.val_loop.interval
|
||||||
|
|
||||||
def setup_env(self, env_cfg: Dict) -> None:
|
def setup_env(self, env_cfg: Dict) -> None:
|
||||||
"""Setup environment.
|
"""Setup environment.
|
||||||
|
|
||||||
@ -857,7 +906,7 @@ class Runner:
|
|||||||
'by_epoch', True
|
'by_epoch', True
|
||||||
), 'only epoch-based parameter scheduler can be ' \
|
), 'only epoch-based parameter scheduler can be ' \
|
||||||
'converted to iter-based'
|
'converted to iter-based'
|
||||||
assert isinstance(self.train_loop, BaseLoop), \
|
assert isinstance(self._train_loop, BaseLoop), \
|
||||||
'Scheduler can only be converted to iter-based ' \
|
'Scheduler can only be converted to iter-based ' \
|
||||||
'when train loop is built.'
|
'when train loop is built.'
|
||||||
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
|
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
|
||||||
@ -866,7 +915,7 @@ class Runner:
|
|||||||
optimizer=self.optimizer,
|
optimizer=self.optimizer,
|
||||||
**_scheduler,
|
**_scheduler,
|
||||||
epoch_length=len(
|
epoch_length=len(
|
||||||
self.train_loop.dataloader), # type: ignore
|
self.train_dataloader), # type: ignore
|
||||||
))
|
))
|
||||||
else:
|
else:
|
||||||
param_schedulers.append(
|
param_schedulers.append(
|
||||||
@ -1043,15 +1092,15 @@ class Runner:
|
|||||||
loop = LOOPS.build(
|
loop = LOOPS.build(
|
||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self, dataloader=self.train_dataloader))
|
runner=self, dataloader=self._train_dataloader))
|
||||||
else:
|
else:
|
||||||
by_epoch = loop_cfg.pop('by_epoch')
|
by_epoch = loop_cfg.pop('by_epoch')
|
||||||
if by_epoch:
|
if by_epoch:
|
||||||
loop = EpochBasedTrainLoop(
|
loop = EpochBasedTrainLoop(
|
||||||
**loop_cfg, runner=self, dataloader=self.train_dataloader)
|
**loop_cfg, runner=self, dataloader=self._train_dataloader)
|
||||||
else:
|
else:
|
||||||
loop = IterBasedTrainLoop(
|
loop = IterBasedTrainLoop(
|
||||||
**loop_cfg, runner=self, dataloader=self.train_dataloader)
|
**loop_cfg, runner=self, dataloader=self._train_dataloader)
|
||||||
|
|
||||||
# `build_optimizer` should be called before `build_param_scheduler`
|
# `build_optimizer` should be called before `build_param_scheduler`
|
||||||
# because the latter depends on the former
|
# because the latter depends on the former
|
||||||
@ -1095,13 +1144,13 @@ class Runner:
|
|||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.val_dataloader,
|
dataloader=self._val_dataloader,
|
||||||
evaluator=self.val_evaluator))
|
evaluator=self._val_evaluator))
|
||||||
else:
|
else:
|
||||||
loop = ValLoop(
|
loop = ValLoop(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.val_dataloader,
|
dataloader=self._val_dataloader,
|
||||||
evaluator=self.val_evaluator, # type: ignore
|
evaluator=self._val_evaluator, # type: ignore
|
||||||
**loop_cfg,
|
**loop_cfg,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
@ -1122,9 +1171,6 @@ class Runner:
|
|||||||
loop (BaseLoop or dict): A test loop or a dict to build test loop.
|
loop (BaseLoop or dict): A test loop or a dict to build test loop.
|
||||||
If ``loop`` is a test loop object, just returns itself.
|
If ``loop`` is a test loop object, just returns itself.
|
||||||
|
|
||||||
Args:
|
|
||||||
loop_cfg (dict): Config to build test loop.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`BaseLoop`: Test loop object build from ``loop_cfg``.
|
:obj:`BaseLoop`: Test loop object build from ``loop_cfg``.
|
||||||
"""
|
"""
|
||||||
@ -1141,13 +1187,13 @@ class Runner:
|
|||||||
loop_cfg,
|
loop_cfg,
|
||||||
default_args=dict(
|
default_args=dict(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.test_dataloader,
|
dataloader=self._test_dataloader,
|
||||||
evaluator=self.test_evaluator))
|
evaluator=self._test_evaluator))
|
||||||
else:
|
else:
|
||||||
loop = TestLoop(
|
loop = TestLoop(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self.test_dataloader,
|
dataloader=self._test_dataloader,
|
||||||
evaluator=self.test_evaluator) # type: ignore
|
evaluator=self._test_evaluator) # type: ignore
|
||||||
|
|
||||||
return loop # type: ignore
|
return loop # type: ignore
|
||||||
|
|
||||||
@ -1176,18 +1222,19 @@ class Runner:
|
|||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
"""Launch training."""
|
"""Launch training."""
|
||||||
if self.train_loop is None:
|
if self._train_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self.train_loop` should not be None when calling train '
|
'`self._train_loop` should not be None when calling train '
|
||||||
'method. Please provide `train_dataloader`, `train_cfg`, '
|
'method. Please provide `train_dataloader`, `train_cfg`, '
|
||||||
'`optimizer` and `param_scheduler` arguments when '
|
'`optimizer` and `param_scheduler` arguments when '
|
||||||
'initializing runner.')
|
'initializing runner.')
|
||||||
|
|
||||||
self.train_loop = self.build_train_loop(
|
self._train_loop = self.build_train_loop(
|
||||||
self.train_loop) # type: ignore
|
self._train_loop) # type: ignore
|
||||||
|
|
||||||
if self.val_loop is not None:
|
if self._val_loop is not None:
|
||||||
self.val_loop = self.build_val_loop(self.val_loop) # type: ignore
|
self._val_loop = self.build_val_loop(
|
||||||
|
self._val_loop) # type: ignore
|
||||||
|
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
@ -1198,13 +1245,13 @@ class Runner:
|
|||||||
|
|
||||||
def val(self) -> None:
|
def val(self) -> None:
|
||||||
"""Launch validation."""
|
"""Launch validation."""
|
||||||
if self.val_loop is None:
|
if self._val_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self.val_loop` should not be None when calling val method.'
|
'`self._val_loop` should not be None when calling val method.'
|
||||||
'Please provide `val_dataloader`, `val_cfg` and '
|
'Please provide `val_dataloader`, `val_cfg` and '
|
||||||
'`val_evaluator` arguments when initializing runner.')
|
'`val_evaluator` arguments when initializing runner.')
|
||||||
|
|
||||||
self.val_loop = self.build_val_loop(self.val_loop) # type: ignore
|
self._val_loop = self.build_val_loop(self._val_loop) # type: ignore
|
||||||
|
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
@ -1214,13 +1261,13 @@ class Runner:
|
|||||||
|
|
||||||
def test(self) -> None:
|
def test(self) -> None:
|
||||||
"""Launch test."""
|
"""Launch test."""
|
||||||
if self.test_loop is None:
|
if self._test_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self.test_loop` should not be None when calling test method.'
|
'`self._test_loop` should not be None when calling test '
|
||||||
'Please provide `test_dataloader`, `test_cfg` and '
|
'method. Please provide `test_dataloader`, `test_cfg` and '
|
||||||
'`test_evaluator` arguments when initializing runner.')
|
'`test_evaluator` arguments when initializing runner.')
|
||||||
|
|
||||||
self.test_loop = self.build_test_loop(self.test_loop) # type: ignore
|
self._test_loop = self.build_test_loop(self._test_loop) # type: ignore
|
||||||
|
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
@ -1424,8 +1471,8 @@ class Runner:
|
|||||||
checkpoint = self.load_checkpoint(
|
checkpoint = self.load_checkpoint(
|
||||||
filename, map_location=map_location)
|
filename, map_location=map_location)
|
||||||
|
|
||||||
self._epoch = checkpoint['meta']['epoch']
|
self.train_loop._epoch = checkpoint['meta']['epoch']
|
||||||
self._iter = checkpoint['meta']['iter']
|
self.train_loop._iter = checkpoint['meta']['iter']
|
||||||
|
|
||||||
if self.meta is None:
|
if self.meta is None:
|
||||||
self.meta = {}
|
self.meta = {}
|
||||||
@ -1466,7 +1513,7 @@ class Runner:
|
|||||||
|
|
||||||
self._has_loaded = True
|
self._has_loaded = True
|
||||||
|
|
||||||
self.logger.info(f'resumed epoch: {self._epoch}, iter: {self._iter}')
|
self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}')
|
||||||
|
|
||||||
def load_checkpoint(self,
|
def load_checkpoint(self,
|
||||||
filename: str,
|
filename: str,
|
||||||
@ -1544,13 +1591,13 @@ class Runner:
|
|||||||
meta.update(self.meta)
|
meta.update(self.meta)
|
||||||
|
|
||||||
if by_epoch:
|
if by_epoch:
|
||||||
# self._epoch increments 1 after
|
# self.epoch increments 1 after
|
||||||
# `self.call_hook('after_train_epoch)` but `save_checkpoint` is
|
# `self.call_hook('after_train_epoch)` but `save_checkpoint` is
|
||||||
# called by `after_train_epoch`` method of `CheckpointHook` so
|
# called by `after_train_epoch`` method of `CheckpointHook` so
|
||||||
# `epoch` should be `self_epoch + 1`
|
# `epoch` should be `self.epoch + 1`
|
||||||
meta.update(epoch=self._epoch + 1, iter=self._iter)
|
meta.update(epoch=self.epoch + 1, iter=self.iter)
|
||||||
else:
|
else:
|
||||||
meta.update(epoch=self._epoch, iter=self._iter + 1)
|
meta.update(epoch=self.epoch, iter=self.iter + 1)
|
||||||
|
|
||||||
filepath = osp.join(out_dir, filename)
|
filepath = osp.join(out_dir, filename)
|
||||||
|
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
|
|
||||||
|
|
||||||
@ -176,33 +174,21 @@ class TestHook:
|
|||||||
|
|
||||||
# last epoch
|
# last epoch
|
||||||
runner.epoch = 1
|
runner.epoch = 1
|
||||||
runner.train_loop.max_epochs = 2
|
runner.max_epochs = 2
|
||||||
return_val = hook.is_last_train_epoch(runner)
|
return_val = hook.is_last_train_epoch(runner)
|
||||||
assert return_val
|
assert return_val
|
||||||
|
|
||||||
# not the last epoch
|
# not the last epoch
|
||||||
runner.train_loop.max_epochs = 0
|
runner.max_epochs = 0
|
||||||
return_val = hook.is_last_train_epoch(runner)
|
return_val = hook.is_last_train_epoch(runner)
|
||||||
assert not return_val
|
assert not return_val
|
||||||
|
|
||||||
def test_is_last_iter(self):
|
def test_is_last_train_iter(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
|
|
||||||
# last iter
|
# last iter
|
||||||
runner.iter = 1
|
runner.iter = 1
|
||||||
runner.train_loop.max_iters = 2
|
runner.max_iters = 2
|
||||||
return_val = hook.is_last_iter(runner)
|
return_val = hook.is_last_train_iter(runner)
|
||||||
assert return_val
|
assert return_val
|
||||||
|
|
||||||
# not the last iter
|
|
||||||
runner.val_loop.max_iters = 0
|
|
||||||
return_val = hook.is_last_iter(runner, mode='val')
|
|
||||||
assert not return_val
|
|
||||||
|
|
||||||
runner.test_loop.max_iters = 0
|
|
||||||
return_val = hook.is_last_iter(runner, mode='test')
|
|
||||||
assert not return_val
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
hook.is_last_iter(runner, mode='error_mode')
|
|
||||||
|
@ -49,10 +49,10 @@ class TestIterTimerHook(TestCase):
|
|||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.log_buffer = dict()
|
runner.log_buffer = dict()
|
||||||
runner.log_processor.window_size = 10
|
runner.log_processor.window_size = 10
|
||||||
runner.train_loop.max_iters = 100
|
runner.max_iters = 100
|
||||||
runner.iter = 0
|
runner.iter = 0
|
||||||
runner.test_loop.dataloader = [0] * 20
|
runner.test_dataloader = [0] * 20
|
||||||
runner.val_loop.dataloader = [0] * 20
|
runner.val_dataloader = [0] * 20
|
||||||
self.hook._before_epoch(runner)
|
self.hook._before_epoch(runner)
|
||||||
self.hook.before_run(runner)
|
self.hook.before_run(runner)
|
||||||
self.hook._after_iter(runner, batch_idx=1)
|
self.hook._after_iter(runner, batch_idx=1)
|
||||||
|
@ -95,7 +95,7 @@ class TestLoggerHook:
|
|||||||
runner.log_processor.get_log_after_iter = MagicMock(
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
return_value=(dict(), 'log_str'))
|
return_value=(dict(), 'log_str'))
|
||||||
logger_hook = LoggerHook(ignore_last=False)
|
logger_hook = LoggerHook(ignore_last=False)
|
||||||
runner.train_loop.dataloader = [0] * 5
|
runner.train_dataloader = [0] * 5
|
||||||
logger_hook.after_train_iter(runner, batch_idx=4)
|
logger_hook.after_train_iter(runner, batch_idx=4)
|
||||||
runner.log_processor.get_log_after_iter.assert_called()
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class TestLogProcessor:
|
|||||||
assert out == log_str
|
assert out == log_str
|
||||||
else:
|
else:
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
max_iters = self.runner.train_loop.max_iters
|
max_iters = self.runner.max_iters
|
||||||
log_str = f'Iter({mode}) [11/{max_iters}] '
|
log_str = f'Iter({mode}) [11/{max_iters}] '
|
||||||
else:
|
else:
|
||||||
max_iters = len(cur_loop.dataloader)
|
max_iters = len(cur_loop.dataloader)
|
||||||
@ -227,7 +227,10 @@ class TestLogProcessor:
|
|||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.epoch = 1
|
runner.epoch = 1
|
||||||
runner.iter = 10
|
runner.iter = 10
|
||||||
runner.train_loop.max_iters = 50
|
runner.max_iters = 50
|
||||||
|
runner.train_dataloader = [0] * 20
|
||||||
|
runner.val_dataloader = [0] * 10
|
||||||
|
runner.test_dataloader = [0] * 5
|
||||||
runner.train_loop.dataloader = [0] * 20
|
runner.train_loop.dataloader = [0] * 20
|
||||||
runner.val_loop.dataloader = [0] * 10
|
runner.val_loop.dataloader = [0] * 10
|
||||||
runner.test_loop.dataloader = [0] * 5
|
runner.test_loop.dataloader = [0] * 5
|
||||||
|
@ -341,9 +341,9 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(runner.model_name, 'ToyModel')
|
self.assertEqual(runner.model_name, 'ToyModel')
|
||||||
|
|
||||||
# 3. test lazy initialization
|
# 3. test lazy initialization
|
||||||
self.assertIsInstance(runner.train_dataloader, dict)
|
self.assertIsInstance(runner._train_dataloader, dict)
|
||||||
self.assertIsInstance(runner.val_dataloader, dict)
|
self.assertIsInstance(runner._val_dataloader, dict)
|
||||||
self.assertIsInstance(runner.test_dataloader, dict)
|
self.assertIsInstance(runner._test_dataloader, dict)
|
||||||
self.assertIsInstance(runner.optimizer, dict)
|
self.assertIsInstance(runner.optimizer, dict)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], dict)
|
self.assertIsInstance(runner.param_schedulers[0], dict)
|
||||||
|
|
||||||
@ -352,20 +352,20 @@ class TestRunner(TestCase):
|
|||||||
# test_dataloader should also be dict
|
# test_dataloader should also be dict
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
self.assertIsInstance(runner.train_loop, BaseLoop)
|
self.assertIsInstance(runner._train_loop, BaseLoop)
|
||||||
self.assertIsInstance(runner.train_loop.dataloader, DataLoader)
|
self.assertIsInstance(runner.train_dataloader, DataLoader)
|
||||||
self.assertIsInstance(runner.optimizer, SGD)
|
self.assertIsInstance(runner.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
self.assertIsInstance(runner.val_loop, BaseLoop)
|
self.assertIsInstance(runner._val_loop, BaseLoop)
|
||||||
self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
|
self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
|
||||||
self.assertIsInstance(runner.val_loop.evaluator, Evaluator)
|
self.assertIsInstance(runner._val_loop.evaluator, Evaluator)
|
||||||
|
|
||||||
# After calling runner.test(), test_dataloader should be initialized
|
# After calling runner.test(), test_dataloader should be initialized
|
||||||
self.assertIsInstance(runner.test_loop, dict)
|
self.assertIsInstance(runner._test_loop, dict)
|
||||||
runner.test()
|
runner.test()
|
||||||
self.assertIsInstance(runner.test_loop, BaseLoop)
|
self.assertIsInstance(runner._test_loop, BaseLoop)
|
||||||
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
|
self.assertIsInstance(runner._test_loop.dataloader, DataLoader)
|
||||||
self.assertIsInstance(runner.test_loop.evaluator, Evaluator)
|
self.assertIsInstance(runner._test_loop.evaluator, Evaluator)
|
||||||
|
|
||||||
# 4. initialize runner with objects rather than config
|
# 4. initialize runner with objects rather than config
|
||||||
model = ToyModel()
|
model = ToyModel()
|
||||||
@ -637,7 +637,7 @@ class TestRunner(TestCase):
|
|||||||
begin=1,
|
begin=1,
|
||||||
end=7,
|
end=7,
|
||||||
convert_to_iter_based=True)
|
convert_to_iter_based=True)
|
||||||
runner.train_loop = runner.build_train_loop(runner.train_loop)
|
runner._train_loop = runner.build_train_loop(runner.train_loop)
|
||||||
param_schedulers = runner.build_param_scheduler(cfg)
|
param_schedulers = runner.build_param_scheduler(cfg)
|
||||||
self.assertFalse(param_schedulers[0].by_epoch)
|
self.assertFalse(param_schedulers[0].by_epoch)
|
||||||
self.assertEqual(param_schedulers[0].begin, 4)
|
self.assertEqual(param_schedulers[0].begin, 4)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user