mirror of https://github.com/open-mmlab/mmengine.git
[Refactor] Make loop-related attributes to be runner's properties. (#236)

* [Enhance] Make loop related attributes to be runner's properties.
* Move iter and epoch to loop
* Resolve comments

parent cc8a6b86e1
commit e37f1f905b
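In short: the `_epoch`/`_iter` counters move from `Runner` into the train loop, and the runner re-exposes them (plus `max_epochs`/`max_iters` and the dataloaders) as read-only properties that delegate to the loop. A minimal sketch of that delegation pattern, using simplified stand-in classes rather than the real mmengine `Runner`/`BaseLoop`:

class ToyLoop:
    """Stand-in for EpochBasedTrainLoop: the loop now owns the counters."""

    def __init__(self, max_epochs: int, epoch_len: int):
        self._max_epochs = max_epochs
        self._max_iters = max_epochs * epoch_len
        self._epoch = 0
        self._iter = 0

    @property
    def epoch(self):
        return self._epoch

    @property
    def iter(self):
        return self._iter


class ToyRunner:
    """Stand-in for Runner: delegates reads to its train loop."""

    def __init__(self, train_loop):
        self.train_loop = train_loop

    @property
    def epoch(self):
        # Read-only: only the loop advances the counter now.
        return self.train_loop.epoch

    @property
    def iter(self):
        return self.train_loop.iter


runner = ToyRunner(ToyLoop(max_epochs=2, epoch_len=10))
assert runner.epoch == 0 and runner.iter == 0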
@@ -194,7 +194,8 @@ class CheckpointHook(Hook):
             # 1. every ``self.interval`` iterations
             # 2. reach the last iteration of training
             if self.every_n_iters(runner, self.interval) or \
-                    (self.save_last and self.is_last_iter(runner, mode='train')):
+                    (self.save_last and
+                     self.is_last_train_iter(runner)):
                 runner.logger.info(
                     f'Saving checkpoint at {runner.iter + 1} iterations')
                 self._save_checkpoint(runner)
@@ -409,25 +409,15 @@ class Hook:
         Returns:
             bool: Whether reaches the end of training epoch.
         """
-        return runner.epoch + 1 == runner.train_loop.max_epochs
+        return runner.epoch + 1 == runner.max_epochs

-    def is_last_iter(self, runner, mode='train') -> bool:
-        """Test whether current iteration is the last iteration.
+    def is_last_train_iter(self, runner) -> bool:
+        """Test whether current iteration is the last train iteration.

         Args:
-            runner (Runner): The runner of the training, validation or testing
-                process.
-            mode (str): Current mode of runner. Defaults to 'train'.
+            runner (Runner): The runner of the training process.

         Returns:
-            bool: Whether current iteration is the last iteration.
+            bool: Whether current iteration is the last train iteration.
         """
-        if mode == 'train':
-            return runner.iter + 1 == runner.train_loop.max_iters
-        elif mode == 'val':
-            return runner.iter + 1 == runner.val_loop.max_iters
-        elif mode == 'test':
-            return runner.iter + 1 == runner.test_loop.max_iters
-        else:
-            raise ValueError('mode should be train, val or test but got'
-                             f'{mode}')
+        return runner.iter + 1 == runner.max_iters
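For hook authors, the change above collapses the old mode switch into a single comparison against a runner property. A hedged usage sketch with a mocked runner, mirroring the updated unit tests further down:

from unittest.mock import Mock

from mmengine.hooks import Hook

hook = Hook()
runner = Mock()
runner.iter = 9
runner.max_iters = 10  # the property that now delegates to train_loop.max_iters
assert hook.is_last_train_iter(runner)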
@@ -98,14 +98,13 @@ class IterTimerHook(Hook):
             time_sec_avg = self.time_sec_tot / (
                 runner.iter - self.start_iter + 1)
             # Calculate eta.
-            eta_sec = time_sec_avg * (
-                runner.train_loop.max_iters - runner.iter - 1)
+            eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
             runner.message_hub.update_info('eta', eta_sec)
         else:
             if mode == 'val':
-                cur_dataloader = runner.val_loop.dataloader
+                cur_dataloader = runner.val_dataloader
             else:
-                cur_dataloader = runner.test_loop.dataloader
+                cur_dataloader = runner.test_dataloader

             eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
             runner.message_hub.update_info('eta', eta_sec)
@@ -127,13 +127,13 @@ class LoggerHook(Hook):
         # Print experiment name every n iterations.
         if self.every_n_iters(runner,
                               self.interval_exp_name) or (self.end_of_epoch(
-                                  runner.train_loop.dataloader, batch_idx)):
+                                  runner.train_dataloader, batch_idx)):
             exp_info = f'Exp name: {runner.experiment_name}'
             runner.logger.info(exp_info)
         if self.every_n_inner_iters(batch_idx, self.interval):
             tag, log_str = runner.log_processor.get_log_after_iter(
                 runner, batch_idx, 'train')
-        elif (self.end_of_epoch(runner.train_loop.dataloader, batch_idx)
+        elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
               and not self.ignore_last):
             # `runner.max_iters` may not be divisible by `self.interval`. if
             # `self.ignore_last==True`, the log of remaining iterations will
@@ -142,7 +142,7 @@ class LogProcessor:
         else:
             if mode == 'train':
                 log_str = (f'Iter({mode}) '
-                           f'[{cur_iter}/{runner.train_loop.max_iters}] ')
+                           f'[{cur_iter}/{runner.max_iters}] ')
             else:
                 log_str = (f'Iter({mode}) [{batch_idx+1}'
                            f'/{len(current_loop.dataloader)}] ')
@@ -19,7 +19,7 @@ class EpochBasedTrainLoop(BaseLoop):
         runner (Runner): A reference of runner.
         dataloader (Dataloader or dict): A dataloader object or a dict to
             build a dataloader.
-        max_epoch (int): Total training epochs.
+        max_epochs (int): Total training epochs.
     """

     def __init__(self, runner, dataloader: Union[DataLoader, Dict],
@@ -27,6 +27,8 @@ class EpochBasedTrainLoop(BaseLoop):
         super().__init__(runner, dataloader)
         self._max_epochs = max_epochs
         self._max_iters = max_epochs * len(self.dataloader)
+        self._epoch = 0
+        self._iter = 0
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -46,15 +48,25 @@ class EpochBasedTrainLoop(BaseLoop):
         """int: Total iterations to train model."""
         return self._max_iters

+    @property
+    def epoch(self):
+        """int: Current epoch."""
+        return self._epoch
+
+    @property
+    def iter(self):
+        """int: Current iteration."""
+        return self._iter
+
     def run(self) -> None:
         """Launch training."""
         self.runner.call_hook('before_train')

-        while self.runner._epoch < self._max_epochs:
+        while self._epoch < self._max_epochs:
             self.run_epoch()

-            if (self.runner.val_loop is not None and
-                    self.runner._epoch % self.runner.val_loop.interval == 0):
+            if (self.runner.val_loop is not None
+                    and self._epoch % self.runner.val_loop.interval == 0):
                 self.runner.val_loop.run()

         self.runner.call_hook('after_train')
@@ -67,7 +79,9 @@ class EpochBasedTrainLoop(BaseLoop):
             self.run_iter(idx, data_batch)

         self.runner.call_hook('after_train_epoch')
-        self.runner.epoch += 1
+        self._epoch += 1
+        # To allow components that cannot access runner to get current epoch.
+        self.runner.message_hub.update_info('epoch', self._epoch)

     def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
         """Iterate one min-batch.
@@ -90,7 +104,10 @@ class EpochBasedTrainLoop(BaseLoop):
             data_batch=data_batch,
             outputs=self.runner.outputs)

-        self.runner.iter += 1
+        self._iter += 1
+        # To allow components that cannot access runner to get current
+        # iteration.
+        self.runner.message_hub.update_info('iter', self._iter)


 @LOOPS.register_module()
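The `message_hub.update_info(...)` calls added above exist so that components without a handle on the runner can still read the current step. A sketch of the consumer side, assuming `MessageHub`'s `get_current_instance()`/`get_info()` API (the exact import path may differ across mmengine versions):

from mmengine.logging import MessageHub  # assumed import path


def current_train_iter() -> int:
    # Reads the counter the train loop publishes via update_info('iter', ...).
    message_hub = MessageHub.get_current_instance()
    return message_hub.get_info('iter')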
@@ -101,13 +118,16 @@ class IterBasedTrainLoop(BaseLoop):
         runner (Runner): A reference of runner.
         dataloader (Dataloader or dict): A dataloader object or a dict to
             build a dataloader.
-        max_iter (int): Total training iterations.
+        max_iters (int): Total training iterations.
     """

     def __init__(self, runner, dataloader: Union[DataLoader, Dict],
                  max_iters: int) -> None:
         super().__init__(runner, dataloader)
         self._max_iters = max_iters
+        self._max_epochs = 1  # for compatibility with EpochBasedTrainLoop
+        self._epoch = 0
+        self._iter = 0
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.runner.visualizer.dataset_meta = \
                 self.dataloader.dataset.metainfo
@@ -118,25 +138,40 @@ class IterBasedTrainLoop(BaseLoop):
                 'None.')
         self.dataloader = iter(self.dataloader)

+    @property
+    def max_epochs(self):
+        """int: Total epochs to train model."""
+        return self._max_epochs
+
     @property
     def max_iters(self):
         """int: Total iterations to train model."""
         return self._max_iters

+    @property
+    def epoch(self):
+        """int: Current epoch."""
+        return self._epoch
+
+    @property
+    def iter(self):
+        """int: Current iteration."""
+        return self._iter
+
     def run(self) -> None:
         """Launch training."""
         self.runner.call_hook('before_train')
         # In iteration-based training loop, we treat the whole training process
         # as a big epoch and execute the corresponding hook.
         self.runner.call_hook('before_train_epoch')
-        while self.runner._iter < self._max_iters:
+        while self._iter < self._max_iters:
             self.runner.model.train()

             data_batch = next(self.dataloader)
             self.run_iter(data_batch)

-            if (self.runner.val_loop is not None and
-                    self.runner._iter % self.runner.val_loop.interval == 0):
+            if (self.runner.val_loop is not None
+                    and self._iter % self.runner.val_interval == 0):
                 self.runner.val_loop.run()

         self.runner.call_hook('after_train_epoch')
@@ -149,9 +184,7 @@ class IterBasedTrainLoop(BaseLoop):
             data_batch (Sequence[dict]): Batch of data from dataloader.
         """
         self.runner.call_hook(
-            'before_train_iter',
-            batch_idx=self.runner._iter,
-            data_batch=data_batch)
+            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
         # outputs should be a dict containing loss tensor
         self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@@ -161,10 +194,13 @@ class IterBasedTrainLoop(BaseLoop):

         self.runner.call_hook(
             'after_train_iter',
-            batch_idx=self.runner._iter,
+            batch_idx=self._iter,
             data_batch=data_batch,
             outputs=self.runner.outputs)
-        self.runner.iter += 1
+        self._iter += 1
+        # To allow components that cannot access runner to get current
+        # iteration.
+        self.runner.message_hub.update_info('iter', self._iter)


 @LOOPS.register_module()
@@ -199,9 +199,9 @@ class Runner:
         >>> runner.test()
     """
     cfg: Config
-    train_loop: Optional[Union[BaseLoop, Dict]]
-    val_loop: Optional[Union[BaseLoop, Dict]]
-    test_loop: Optional[Union[BaseLoop, Dict]]
+    _train_loop: Optional[Union[BaseLoop, Dict]]
+    _val_loop: Optional[Union[BaseLoop, Dict]]
+    _test_loop: Optional[Union[BaseLoop, Dict]]

     def __init__(
         self,
@@ -244,9 +244,6 @@ class Runner:
         else:
             self.cfg = Config(dict())

-        self._epoch = 0
-        self._iter = 0
-
         # lazy initialization
         training_related = [train_dataloader, train_cfg, optimizer]
         if not (all(item is None for item in training_related)
@@ -257,8 +254,8 @@ class Runner:
                 f'train_dataloader={train_dataloader}, '
                 f'train_cfg={train_cfg}, '
                 f'optimizer={optimizer}.')
-        self.train_dataloader = train_dataloader
-        self.train_loop = train_cfg
+        self._train_dataloader = train_dataloader
+        self._train_loop = train_cfg
         self.optimizer = optimizer

         # If there is no need to adjust learning rate, momentum or other
@@ -284,9 +281,9 @@ class Runner:
                 'all None or not None, but got '
                 f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
                 f'val_evaluator={val_evaluator}')
-        self.val_dataloader = val_dataloader
-        self.val_loop = val_cfg
-        self.val_evaluator = val_evaluator
+        self._val_dataloader = val_dataloader
+        self._val_loop = val_cfg
+        self._val_evaluator = val_evaluator

         test_related = [test_dataloader, test_cfg, test_evaluator]
         if not (all(item is None for item in test_related)
@@ -296,9 +293,9 @@ class Runner:
                 'either all None or not None, but got '
                 f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
                 f'test_evaluator={test_evaluator}')
-        self.test_dataloader = test_dataloader
-        self.test_loop = test_cfg
-        self.test_evaluator = test_evaluator
+        self._test_dataloader = test_dataloader
+        self._test_loop = test_cfg
+        self._test_evaluator = test_evaluator

         self._launcher = launcher
         if self._launcher == 'none':
@@ -435,30 +432,25 @@ class Runner:
         """str: The working directory to save checkpoints and logs."""
         return self._work_dir

+    @property
+    def max_epochs(self):
+        """int: Total epochs to train model."""
+        return self.train_loop.max_epochs
+
+    @property
+    def max_iters(self):
+        """int: Total iterations to train model."""
+        return self.train_loop.max_iters
+
     @property
     def epoch(self):
         """int: Current epoch."""
-        return self._epoch
-
-    @epoch.setter
-    def epoch(self, epoch: int):
-        """Update epoch and synchronize epoch in :attr:`message_hub`."""
-        self._epoch = epoch
-        # To allow components that cannot access runner to get current epoch.
-        self.message_hub.update_info('epoch', epoch)
+        return self.train_loop.epoch

     @property
     def iter(self):
         """int: Current iteration."""
-        return self._iter
-
-    @iter.setter
-    def iter(self, iter: int):
-        """Update iter and synchronize iter in :attr:`message_hub`."""
-        self._iter = iter
-        # To allow components that cannot access runner to get current
-        # iteration.
-        self.message_hub.update_info('iter', iter)
+        return self.train_loop.iter

     @property
     def launcher(self):
@@ -500,6 +492,63 @@ class Runner:
         """list[:obj:`Hook`]: A list of registered hooks."""
         return self._hooks

+    @property
+    def train_loop(self):
+        """:obj:`BaseLoop`: A loop to run training."""
+        if isinstance(self._train_loop, BaseLoop) or self._train_loop is None:
+            return self._train_loop
+        else:
+            self._train_loop = self.build_train_loop(self._train_loop)
+            return self._train_loop
+
+    @property
+    def val_loop(self):
+        """:obj:`BaseLoop`: A loop to run validation."""
+        if isinstance(self._val_loop, BaseLoop) or self._val_loop is None:
+            return self._val_loop
+        else:
+            self._val_loop = self.build_val_loop(self._val_loop)
+            return self._val_loop
+
+    @property
+    def test_loop(self):
+        """:obj:`BaseLoop`: A loop to run testing."""
+        if isinstance(self._test_loop, BaseLoop) or self._test_loop is None:
+            return self._test_loop
+        else:
+            self._test_loop = self.build_test_loop(self._test_loop)
+            return self._test_loop
+
+    @property
+    def train_dataloader(self):
+        """The data loader for training."""
+        return self.train_loop.dataloader
+
+    @property
+    def val_dataloader(self):
+        """The data loader for validation."""
+        return self.val_loop.dataloader
+
+    @property
+    def test_dataloader(self):
+        """The data loader for testing."""
+        return self.test_loop.dataloader
+
+    @property
+    def val_evaluator(self):
+        """:obj:`Evaluator`: An evaluator for validation."""
+        return self.val_loop.evaluator
+
+    @property
+    def test_evaluator(self):
+        """:obj:`Evaluator`: An evaluator for testing."""
+        return self.test_loop.evaluator
+
+    @property
+    def val_interval(self):
+        """int: Interval to run validation during training."""
+        return self.val_loop.interval
+
     def setup_env(self, env_cfg: Dict) -> None:
         """Setup environment.
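The three `*_loop` properties above share one lazy-build idiom: keep whatever was passed to `__init__` (a config dict, a built loop, or None) and only build on first access, caching the result. A generic sketch of the idiom with hypothetical names, not the real Runner API:

class LazyExample:
    """Hypothetical holder illustrating build-on-first-access."""

    def __init__(self, cfg_or_obj):
        self._loop = cfg_or_obj  # dict config, built object, or None

    def _build(self, cfg: dict):
        return object()  # placeholder for LOOPS.build(cfg, ...)

    @property
    def loop(self):
        # A dict is treated as a config and replaced by the built object,
        # so the build runs at most once; None and built objects pass through.
        if isinstance(self._loop, dict):
            self._loop = self._build(self._loop)
        return self._loop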
@@ -857,7 +906,7 @@ class Runner:
                     'by_epoch', True
                 ), 'only epoch-based parameter scheduler can be ' \
                     'converted to iter-based'
-                assert isinstance(self.train_loop, BaseLoop), \
+                assert isinstance(self._train_loop, BaseLoop), \
                     'Scheduler can only be converted to iter-based ' \
                     'when train loop is built.'
                 cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
@@ -866,7 +915,7 @@ class Runner:
                         optimizer=self.optimizer,
                         **_scheduler,
                         epoch_length=len(
-                            self.train_loop.dataloader),  # type: ignore
+                            self.train_dataloader),  # type: ignore
                     ))
             else:
                 param_schedulers.append(
@@ -1043,15 +1092,15 @@ class Runner:
             loop = LOOPS.build(
                 loop_cfg,
                 default_args=dict(
-                    runner=self, dataloader=self.train_dataloader))
+                    runner=self, dataloader=self._train_dataloader))
         else:
             by_epoch = loop_cfg.pop('by_epoch')
             if by_epoch:
                 loop = EpochBasedTrainLoop(
-                    **loop_cfg, runner=self, dataloader=self.train_dataloader)
+                    **loop_cfg, runner=self, dataloader=self._train_dataloader)
             else:
                 loop = IterBasedTrainLoop(
-                    **loop_cfg, runner=self, dataloader=self.train_dataloader)
+                    **loop_cfg, runner=self, dataloader=self._train_dataloader)

         # `build_optimizer` should be called before `build_param_scheduler`
         # because the latter depends on the former
@@ -1095,13 +1144,13 @@ class Runner:
                 loop_cfg,
                 default_args=dict(
                     runner=self,
-                    dataloader=self.val_dataloader,
-                    evaluator=self.val_evaluator))
+                    dataloader=self._val_dataloader,
+                    evaluator=self._val_evaluator))
         else:
             loop = ValLoop(
                 runner=self,
-                dataloader=self.val_dataloader,
-                evaluator=self.val_evaluator,  # type: ignore
+                dataloader=self._val_dataloader,
+                evaluator=self._val_evaluator,  # type: ignore
                 **loop_cfg,
             )  # type: ignore
@@ -1122,9 +1171,6 @@ class Runner:
             loop (BaseLoop or dict): A test loop or a dict to build test loop.
                 If ``loop`` is a test loop object, just returns itself.

-        Args:
-            loop_cfg (dict): Config to build test loop.
-
         Returns:
             :obj:`BaseLoop`: Test loop object build from ``loop_cfg``.
         """
@@ -1141,13 +1187,13 @@ class Runner:
                 loop_cfg,
                 default_args=dict(
                     runner=self,
-                    dataloader=self.test_dataloader,
-                    evaluator=self.test_evaluator))
+                    dataloader=self._test_dataloader,
+                    evaluator=self._test_evaluator))
         else:
             loop = TestLoop(
                 runner=self,
-                dataloader=self.test_dataloader,
-                evaluator=self.test_evaluator)  # type: ignore
+                dataloader=self._test_dataloader,
+                evaluator=self._test_evaluator)  # type: ignore

         return loop  # type: ignore
@@ -1176,18 +1222,19 @@ class Runner:

     def train(self) -> None:
         """Launch training."""
-        if self.train_loop is None:
+        if self._train_loop is None:
             raise RuntimeError(
-                '`self.train_loop` should not be None when calling train '
+                '`self._train_loop` should not be None when calling train '
                 'method. Please provide `train_dataloader`, `train_cfg`, '
                 '`optimizer` and `param_scheduler` arguments when '
                 'initializing runner.')

-        self.train_loop = self.build_train_loop(
-            self.train_loop)  # type: ignore
+        self._train_loop = self.build_train_loop(
+            self._train_loop)  # type: ignore

-        if self.val_loop is not None:
-            self.val_loop = self.build_val_loop(self.val_loop)  # type: ignore
+        if self._val_loop is not None:
+            self._val_loop = self.build_val_loop(
+                self._val_loop)  # type: ignore

         self.load_or_resume()
@@ -1198,13 +1245,13 @@ class Runner:

     def val(self) -> None:
         """Launch validation."""
-        if self.val_loop is None:
+        if self._val_loop is None:
             raise RuntimeError(
-                '`self.val_loop` should not be None when calling val method.'
+                '`self._val_loop` should not be None when calling val method.'
                 'Please provide `val_dataloader`, `val_cfg` and '
                 '`val_evaluator` arguments when initializing runner.')

-        self.val_loop = self.build_val_loop(self.val_loop)  # type: ignore
+        self._val_loop = self.build_val_loop(self._val_loop)  # type: ignore

         self.load_or_resume()
@@ -1214,13 +1261,13 @@ class Runner:

     def test(self) -> None:
         """Launch test."""
-        if self.test_loop is None:
+        if self._test_loop is None:
             raise RuntimeError(
-                '`self.test_loop` should not be None when calling test method.'
-                'Please provide `test_dataloader`, `test_cfg` and '
+                '`self._test_loop` should not be None when calling test '
+                'method. Please provide `test_dataloader`, `test_cfg` and '
                 '`test_evaluator` arguments when initializing runner.')

-        self.test_loop = self.build_test_loop(self.test_loop)  # type: ignore
+        self._test_loop = self.build_test_loop(self._test_loop)  # type: ignore

         self.load_or_resume()
@@ -1424,8 +1471,8 @@ class Runner:
             checkpoint = self.load_checkpoint(
                 filename, map_location=map_location)

-        self._epoch = checkpoint['meta']['epoch']
-        self._iter = checkpoint['meta']['iter']
+        self.train_loop._epoch = checkpoint['meta']['epoch']
+        self.train_loop._iter = checkpoint['meta']['iter']

         if self.meta is None:
             self.meta = {}
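Because the counters now live on the loop, `resume` writes the checkpointed values into the train loop's private fields and every later read goes through the runner properties. A self-contained sketch of that round trip (toy classes, not the real Runner):

class ToyTrainLoop:
    def __init__(self):
        self._epoch, self._iter = 0, 0


class ToyResumableRunner:
    def __init__(self):
        self.train_loop = ToyTrainLoop()

    @property
    def epoch(self):
        return self.train_loop._epoch

    def resume(self, meta: dict) -> None:
        # Mirrors the hunk above: restore into the loop, read via property.
        self.train_loop._epoch = meta['epoch']
        self.train_loop._iter = meta['iter']


runner = ToyResumableRunner()
runner.resume({'epoch': 3, 'iter': 120})
assert runner.epoch == 3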
@@ -1466,7 +1513,7 @@ class Runner:

         self._has_loaded = True

-        self.logger.info(f'resumed epoch: {self._epoch}, iter: {self._iter}')
+        self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}')

     def load_checkpoint(self,
                         filename: str,
@@ -1544,13 +1591,13 @@ class Runner:
             meta.update(self.meta)

         if by_epoch:
-            # self._epoch increments 1 after
+            # self.epoch increments 1 after
             # `self.call_hook('after_train_epoch)` but `save_checkpoint` is
             # called by `after_train_epoch`` method of `CheckpointHook` so
-            # `epoch` should be `self_epoch + 1`
-            meta.update(epoch=self._epoch + 1, iter=self._iter)
+            # `epoch` should be `self.epoch + 1`
+            meta.update(epoch=self.epoch + 1, iter=self.iter)
         else:
-            meta.update(epoch=self._epoch, iter=self._iter + 1)
+            meta.update(epoch=self.epoch, iter=self.iter + 1)

         filepath = osp.join(out_dir, filename)
|
@ -1,8 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
|
||||
|
||||
@@ -176,33 +174,21 @@ class TestHook:

         # last epoch
         runner.epoch = 1
-        runner.train_loop.max_epochs = 2
+        runner.max_epochs = 2
         return_val = hook.is_last_train_epoch(runner)
         assert return_val

         # not the last epoch
-        runner.train_loop.max_epochs = 0
+        runner.max_epochs = 0
         return_val = hook.is_last_train_epoch(runner)
         assert not return_val

-    def test_is_last_iter(self):
+    def test_is_last_train_iter(self):
         hook = Hook()
         runner = Mock()

         # last iter
         runner.iter = 1
-        runner.train_loop.max_iters = 2
-        return_val = hook.is_last_iter(runner)
+        runner.max_iters = 2
+        return_val = hook.is_last_train_iter(runner)
         assert return_val
-
-        # not the last iter
-        runner.val_loop.max_iters = 0
-        return_val = hook.is_last_iter(runner, mode='val')
-        assert not return_val
-
-        runner.test_loop.max_iters = 0
-        return_val = hook.is_last_iter(runner, mode='test')
-        assert not return_val
-
-        with pytest.raises(ValueError):
-            hook.is_last_iter(runner, mode='error_mode')
@@ -49,10 +49,10 @@ class TestIterTimerHook(TestCase):
         runner = MagicMock()
         runner.log_buffer = dict()
         runner.log_processor.window_size = 10
-        runner.train_loop.max_iters = 100
+        runner.max_iters = 100
         runner.iter = 0
-        runner.test_loop.dataloader = [0] * 20
-        runner.val_loop.dataloader = [0] * 20
+        runner.test_dataloader = [0] * 20
+        runner.val_dataloader = [0] * 20
         self.hook._before_epoch(runner)
         self.hook.before_run(runner)
         self.hook._after_iter(runner, batch_idx=1)
@@ -95,7 +95,7 @@ class TestLoggerHook:
         runner.log_processor.get_log_after_iter = MagicMock(
             return_value=(dict(), 'log_str'))
         logger_hook = LoggerHook(ignore_last=False)
-        runner.train_loop.dataloader = [0] * 5
+        runner.train_dataloader = [0] * 5
         logger_hook.after_train_iter(runner, batch_idx=4)
         runner.log_processor.get_log_after_iter.assert_called()
@@ -110,7 +110,7 @@ class TestLogProcessor:
             assert out == log_str
         else:
             if mode == 'train':
-                max_iters = self.runner.train_loop.max_iters
+                max_iters = self.runner.max_iters
                 log_str = f'Iter({mode}) [11/{max_iters}] '
             else:
                 max_iters = len(cur_loop.dataloader)
@@ -227,7 +227,10 @@ class TestLogProcessor:
         runner = MagicMock()
         runner.epoch = 1
         runner.iter = 10
-        runner.train_loop.max_iters = 50
+        runner.max_iters = 50
+        runner.train_dataloader = [0] * 20
+        runner.val_dataloader = [0] * 10
+        runner.test_dataloader = [0] * 5
         runner.train_loop.dataloader = [0] * 20
         runner.val_loop.dataloader = [0] * 10
         runner.test_loop.dataloader = [0] * 5
@@ -341,9 +341,9 @@ class TestRunner(TestCase):
         self.assertEqual(runner.model_name, 'ToyModel')

         # 3. test lazy initialization
-        self.assertIsInstance(runner.train_dataloader, dict)
-        self.assertIsInstance(runner.val_dataloader, dict)
-        self.assertIsInstance(runner.test_dataloader, dict)
+        self.assertIsInstance(runner._train_dataloader, dict)
+        self.assertIsInstance(runner._val_dataloader, dict)
+        self.assertIsInstance(runner._test_dataloader, dict)
         self.assertIsInstance(runner.optimizer, dict)
         self.assertIsInstance(runner.param_schedulers[0], dict)
@@ -352,20 +352,20 @@ class TestRunner(TestCase):
         # test_dataloader should also be dict
         runner.train()

-        self.assertIsInstance(runner.train_loop, BaseLoop)
-        self.assertIsInstance(runner.train_loop.dataloader, DataLoader)
+        self.assertIsInstance(runner._train_loop, BaseLoop)
+        self.assertIsInstance(runner.train_dataloader, DataLoader)
         self.assertIsInstance(runner.optimizer, SGD)
         self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
-        self.assertIsInstance(runner.val_loop, BaseLoop)
-        self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
-        self.assertIsInstance(runner.val_loop.evaluator, Evaluator)
+        self.assertIsInstance(runner._val_loop, BaseLoop)
+        self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
+        self.assertIsInstance(runner._val_loop.evaluator, Evaluator)

         # After calling runner.test(), test_dataloader should be initialized
-        self.assertIsInstance(runner.test_loop, dict)
+        self.assertIsInstance(runner._test_loop, dict)
         runner.test()
-        self.assertIsInstance(runner.test_loop, BaseLoop)
-        self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
-        self.assertIsInstance(runner.test_loop.evaluator, Evaluator)
+        self.assertIsInstance(runner._test_loop, BaseLoop)
+        self.assertIsInstance(runner._test_loop.dataloader, DataLoader)
+        self.assertIsInstance(runner._test_loop.evaluator, Evaluator)

         # 4. initialize runner with objects rather than config
         model = ToyModel()
@@ -637,7 +637,7 @@ class TestRunner(TestCase):
             begin=1,
             end=7,
             convert_to_iter_based=True)
-        runner.train_loop = runner.build_train_loop(runner.train_loop)
+        runner._train_loop = runner.build_train_loop(runner.train_loop)
         param_schedulers = runner.build_param_scheduler(cfg)
         self.assertFalse(param_schedulers[0].by_epoch)
         self.assertEqual(param_schedulers[0].begin, 4)