[Refactor] Make loop-related attributes to be runner's properties. (#236)

* [Enhance] Make loop related attributes to be runner's properties.

* move iter and epoch to loop

* resolve comments
This commit is contained in:
RangiLyu 2022-05-18 22:35:10 +08:00 committed by GitHub
parent cc8a6b86e1
commit e37f1f905b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 206 additions and 144 deletions

View File

@ -194,7 +194,8 @@ class CheckpointHook(Hook):
# 1. every ``self.interval`` iterations
# 2. reach the last iteration of training
if self.every_n_iters(runner, self.interval) or \
(self.save_last and self.is_last_iter(runner, mode='train')):
(self.save_last and
self.is_last_train_iter(runner)):
runner.logger.info(
f'Saving checkpoint at {runner.iter + 1} iterations')
self._save_checkpoint(runner)

View File

@ -409,25 +409,15 @@ class Hook:
Returns:
bool: Whether reaches the end of training epoch.
"""
return runner.epoch + 1 == runner.train_loop.max_epochs
return runner.epoch + 1 == runner.max_epochs
def is_last_iter(self, runner, mode='train') -> bool:
"""Test whether current iteration is the last iteration.
def is_last_train_iter(self, runner) -> bool:
"""Test whether current iteration is the last train iteration.
Args:
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
runner (Runner): The runner of the training process.
Returns:
bool: Whether current iteration is the last iteration.
bool: Whether current iteration is the last train iteration.
"""
if mode == 'train':
return runner.iter + 1 == runner.train_loop.max_iters
elif mode == 'val':
return runner.iter + 1 == runner.val_loop.max_iters
elif mode == 'test':
return runner.iter + 1 == runner.test_loop.max_iters
else:
raise ValueError('mode should be train, val or test but got'
f'{mode}')
return runner.iter + 1 == runner.max_iters

View File

@ -98,14 +98,13 @@ class IterTimerHook(Hook):
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
else:
if mode == 'val':
cur_dataloader = runner.val_loop.dataloader
cur_dataloader = runner.val_dataloader
else:
cur_dataloader = runner.test_loop.dataloader
cur_dataloader = runner.test_dataloader
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)

View File

@ -127,13 +127,13 @@ class LoggerHook(Hook):
# Print experiment name every n iterations.
if self.every_n_iters(runner,
self.interval_exp_name) or (self.end_of_epoch(
runner.train_loop.dataloader, batch_idx)):
runner.train_dataloader, batch_idx)):
exp_info = f'Exp name: {runner.experiment_name}'
runner.logger.info(exp_info)
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
elif (self.end_of_epoch(runner.train_loop.dataloader, batch_idx)
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
and not self.ignore_last):
# `runner.max_iters` may not be divisible by `self.interval`. if
# `self.ignore_last==True`, the log of remaining iterations will

View File

@ -142,7 +142,7 @@ class LogProcessor:
else:
if mode == 'train':
log_str = (f'Iter({mode}) '
f'[{cur_iter}/{runner.train_loop.max_iters}] ')
f'[{cur_iter}/{runner.max_iters}] ')
else:
log_str = (f'Iter({mode}) [{batch_idx+1}'
f'/{len(current_loop.dataloader)}] ')

View File

@ -19,7 +19,7 @@ class EpochBasedTrainLoop(BaseLoop):
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_epoch (int): Total training epochs.
max_epochs (int): Total training epochs.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
@ -27,6 +27,8 @@ class EpochBasedTrainLoop(BaseLoop):
super().__init__(runner, dataloader)
self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader)
self._epoch = 0
self._iter = 0
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
@ -46,15 +48,25 @@ class EpochBasedTrainLoop(BaseLoop):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> None:
"""Launch training."""
self.runner.call_hook('before_train')
while self.runner._epoch < self._max_epochs:
while self._epoch < self._max_epochs:
self.run_epoch()
if (self.runner.val_loop is not None and
self.runner._epoch % self.runner.val_loop.interval == 0):
if (self.runner.val_loop is not None
and self._epoch % self.runner.val_loop.interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
@ -67,7 +79,9 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_iter(idx, data_batch)
self.runner.call_hook('after_train_epoch')
self.runner.epoch += 1
self._epoch += 1
# To allow components that cannot access runner to get current epoch.
self.runner.message_hub.update_info('epoch', self._epoch)
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
"""Iterate one min-batch.
@ -90,7 +104,10 @@ class EpochBasedTrainLoop(BaseLoop):
data_batch=data_batch,
outputs=self.runner.outputs)
self.runner.iter += 1
self._iter += 1
# To allow components that cannot access runner to get current
# iteration.
self.runner.message_hub.update_info('iter', self._iter)
@LOOPS.register_module()
@ -101,13 +118,16 @@ class IterBasedTrainLoop(BaseLoop):
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_iter (int): Total training iterations.
max_iters (int): Total training iterations.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
max_iters: int) -> None:
super().__init__(runner, dataloader)
self._max_iters = max_iters
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
self._epoch = 0
self._iter = 0
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
@ -118,25 +138,40 @@ class IterBasedTrainLoop(BaseLoop):
'None.')
self.dataloader = iter(self.dataloader)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self._max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self._max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
def run(self) -> None:
"""Launch training."""
self.runner.call_hook('before_train')
# In iteration-based training loop, we treat the whole training process
# as a big epoch and execute the corresponding hook.
self.runner.call_hook('before_train_epoch')
while self.runner._iter < self._max_iters:
while self._iter < self._max_iters:
self.runner.model.train()
data_batch = next(self.dataloader)
self.run_iter(data_batch)
if (self.runner.val_loop is not None and
self.runner._iter % self.runner.val_loop.interval == 0):
if (self.runner.val_loop is not None
and self._iter % self.runner.val_interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train_epoch')
@ -149,9 +184,7 @@ class IterBasedTrainLoop(BaseLoop):
data_batch (Sequence[dict]): Batch of data from dataloader.
"""
self.runner.call_hook(
'before_train_iter',
batch_idx=self.runner._iter,
data_batch=data_batch)
'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
# outputs should be a dict containing loss tensor
self.runner.outputs = self.runner.model(data_batch, return_loss=True)
@ -161,10 +194,13 @@ class IterBasedTrainLoop(BaseLoop):
self.runner.call_hook(
'after_train_iter',
batch_idx=self.runner._iter,
batch_idx=self._iter,
data_batch=data_batch,
outputs=self.runner.outputs)
self.runner.iter += 1
self._iter += 1
# To allow components that cannot access runner to get current
# iteration.
self.runner.message_hub.update_info('iter', self._iter)
@LOOPS.register_module()

View File

@ -199,9 +199,9 @@ class Runner:
>>> runner.test()
"""
cfg: Config
train_loop: Optional[Union[BaseLoop, Dict]]
val_loop: Optional[Union[BaseLoop, Dict]]
test_loop: Optional[Union[BaseLoop, Dict]]
_train_loop: Optional[Union[BaseLoop, Dict]]
_val_loop: Optional[Union[BaseLoop, Dict]]
_test_loop: Optional[Union[BaseLoop, Dict]]
def __init__(
self,
@ -244,9 +244,6 @@ class Runner:
else:
self.cfg = Config(dict())
self._epoch = 0
self._iter = 0
# lazy initialization
training_related = [train_dataloader, train_cfg, optimizer]
if not (all(item is None for item in training_related)
@ -257,8 +254,8 @@ class Runner:
f'train_dataloader={train_dataloader}, '
f'train_cfg={train_cfg}, '
f'optimizer={optimizer}.')
self.train_dataloader = train_dataloader
self.train_loop = train_cfg
self._train_dataloader = train_dataloader
self._train_loop = train_cfg
self.optimizer = optimizer
# If there is no need to adjust learning rate, momentum or other
@ -284,9 +281,9 @@ class Runner:
'all None or not None, but got '
f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, '
f'val_evaluator={val_evaluator}')
self.val_dataloader = val_dataloader
self.val_loop = val_cfg
self.val_evaluator = val_evaluator
self._val_dataloader = val_dataloader
self._val_loop = val_cfg
self._val_evaluator = val_evaluator
test_related = [test_dataloader, test_cfg, test_evaluator]
if not (all(item is None for item in test_related)
@ -296,9 +293,9 @@ class Runner:
'either all None or not None, but got '
f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, '
f'test_evaluator={test_evaluator}')
self.test_dataloader = test_dataloader
self.test_loop = test_cfg
self.test_evaluator = test_evaluator
self._test_dataloader = test_dataloader
self._test_loop = test_cfg
self._test_evaluator = test_evaluator
self._launcher = launcher
if self._launcher == 'none':
@ -435,30 +432,25 @@ class Runner:
"""str: The working directory to save checkpoints and logs."""
return self._work_dir
@property
def max_epochs(self):
"""int: Total epochs to train model."""
return self.train_loop.max_epochs
@property
def max_iters(self):
"""int: Total iterations to train model."""
return self.train_loop.max_iters
@property
def epoch(self):
"""int: Current epoch."""
return self._epoch
@epoch.setter
def epoch(self, epoch: int):
"""Update epoch and synchronize epoch in :attr:`message_hub`."""
self._epoch = epoch
# To allow components that cannot access runner to get current epoch.
self.message_hub.update_info('epoch', epoch)
return self.train_loop.epoch
@property
def iter(self):
"""int: Current iteration."""
return self._iter
@iter.setter
def iter(self, iter: int):
"""Update iter and synchronize iter in :attr:`message_hub`."""
self._iter = iter
# To allow components that cannot access runner to get current
# iteration.
self.message_hub.update_info('iter', iter)
return self.train_loop.iter
@property
def launcher(self):
@ -500,6 +492,63 @@ class Runner:
"""list[:obj:`Hook`]: A list of registered hooks."""
return self._hooks
@property
def train_loop(self):
""":obj:`BaseLoop`: A loop to run training."""
if isinstance(self._train_loop, BaseLoop) or self._train_loop is None:
return self._train_loop
else:
self._train_loop = self.build_train_loop(self._train_loop)
return self._train_loop
@property
def val_loop(self):
""":obj:`BaseLoop`: A loop to run validation."""
if isinstance(self._val_loop, BaseLoop) or self._val_loop is None:
return self._val_loop
else:
self._val_loop = self.build_val_loop(self._val_loop)
return self._val_loop
@property
def test_loop(self):
""":obj:`BaseLoop`: A loop to run testing."""
if isinstance(self._test_loop, BaseLoop) or self._test_loop is None:
return self._test_loop
else:
self._test_loop = self.build_test_loop(self._test_loop)
return self._test_loop
@property
def train_dataloader(self):
"""The data loader for training."""
return self.train_loop.dataloader
@property
def val_dataloader(self):
"""The data loader for validation."""
return self.val_loop.dataloader
@property
def test_dataloader(self):
"""The data loader for testing."""
return self.test_loop.dataloader
@property
def val_evaluator(self):
""":obj:`Evaluator`: An evaluator for validation."""
return self.val_loop.evaluator
@property
def test_evaluator(self):
""":obj:`Evaluator`: An evaluator for testing."""
return self.test_loop.evaluator
@property
def val_interval(self):
"""int: Interval to run validation during training."""
return self.val_loop.interval
def setup_env(self, env_cfg: Dict) -> None:
"""Setup environment.
@ -857,7 +906,7 @@ class Runner:
'by_epoch', True
), 'only epoch-based parameter scheduler can be ' \
'converted to iter-based'
assert isinstance(self.train_loop, BaseLoop), \
assert isinstance(self._train_loop, BaseLoop), \
'Scheduler can only be converted to iter-based ' \
'when train loop is built.'
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
@ -866,7 +915,7 @@ class Runner:
optimizer=self.optimizer,
**_scheduler,
epoch_length=len(
self.train_loop.dataloader), # type: ignore
self.train_dataloader), # type: ignore
))
else:
param_schedulers.append(
@ -1043,15 +1092,15 @@ class Runner:
loop = LOOPS.build(
loop_cfg,
default_args=dict(
runner=self, dataloader=self.train_dataloader))
runner=self, dataloader=self._train_dataloader))
else:
by_epoch = loop_cfg.pop('by_epoch')
if by_epoch:
loop = EpochBasedTrainLoop(
**loop_cfg, runner=self, dataloader=self.train_dataloader)
**loop_cfg, runner=self, dataloader=self._train_dataloader)
else:
loop = IterBasedTrainLoop(
**loop_cfg, runner=self, dataloader=self.train_dataloader)
**loop_cfg, runner=self, dataloader=self._train_dataloader)
# `build_optimizer` should be called before `build_param_scheduler`
# because the latter depends on the former
@ -1095,13 +1144,13 @@ class Runner:
loop_cfg,
default_args=dict(
runner=self,
dataloader=self.val_dataloader,
evaluator=self.val_evaluator))
dataloader=self._val_dataloader,
evaluator=self._val_evaluator))
else:
loop = ValLoop(
runner=self,
dataloader=self.val_dataloader,
evaluator=self.val_evaluator, # type: ignore
dataloader=self._val_dataloader,
evaluator=self._val_evaluator, # type: ignore
**loop_cfg,
) # type: ignore
@ -1122,9 +1171,6 @@ class Runner:
loop (BaseLoop or dict): A test loop or a dict to build test loop.
If ``loop`` is a test loop object, just returns itself.
Args:
loop_cfg (dict): Config to build test loop.
Returns:
:obj:`BaseLoop`: Test loop object build from ``loop_cfg``.
"""
@ -1141,13 +1187,13 @@ class Runner:
loop_cfg,
default_args=dict(
runner=self,
dataloader=self.test_dataloader,
evaluator=self.test_evaluator))
dataloader=self._test_dataloader,
evaluator=self._test_evaluator))
else:
loop = TestLoop(
runner=self,
dataloader=self.test_dataloader,
evaluator=self.test_evaluator) # type: ignore
dataloader=self._test_dataloader,
evaluator=self._test_evaluator) # type: ignore
return loop # type: ignore
@ -1176,18 +1222,19 @@ class Runner:
def train(self) -> None:
"""Launch training."""
if self.train_loop is None:
if self._train_loop is None:
raise RuntimeError(
'`self.train_loop` should not be None when calling train '
'`self._train_loop` should not be None when calling train '
'method. Please provide `train_dataloader`, `train_cfg`, '
'`optimizer` and `param_scheduler` arguments when '
'initializing runner.')
self.train_loop = self.build_train_loop(
self.train_loop) # type: ignore
self._train_loop = self.build_train_loop(
self._train_loop) # type: ignore
if self.val_loop is not None:
self.val_loop = self.build_val_loop(self.val_loop) # type: ignore
if self._val_loop is not None:
self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore
self.load_or_resume()
@ -1198,13 +1245,13 @@ class Runner:
def val(self) -> None:
"""Launch validation."""
if self.val_loop is None:
if self._val_loop is None:
raise RuntimeError(
'`self.val_loop` should not be None when calling val method.'
'`self._val_loop` should not be None when calling val method.'
'Please provide `val_dataloader`, `val_cfg` and '
'`val_evaluator` arguments when initializing runner.')
self.val_loop = self.build_val_loop(self.val_loop) # type: ignore
self._val_loop = self.build_val_loop(self._val_loop) # type: ignore
self.load_or_resume()
@ -1214,13 +1261,13 @@ class Runner:
def test(self) -> None:
"""Launch test."""
if self.test_loop is None:
if self._test_loop is None:
raise RuntimeError(
'`self.test_loop` should not be None when calling test method.'
'Please provide `test_dataloader`, `test_cfg` and '
'`self._test_loop` should not be None when calling test '
'method. Please provide `test_dataloader`, `test_cfg` and '
'`test_evaluator` arguments when initializing runner.')
self.test_loop = self.build_test_loop(self.test_loop) # type: ignore
self._test_loop = self.build_test_loop(self._test_loop) # type: ignore
self.load_or_resume()
@ -1424,8 +1471,8 @@ class Runner:
checkpoint = self.load_checkpoint(
filename, map_location=map_location)
self._epoch = checkpoint['meta']['epoch']
self._iter = checkpoint['meta']['iter']
self.train_loop._epoch = checkpoint['meta']['epoch']
self.train_loop._iter = checkpoint['meta']['iter']
if self.meta is None:
self.meta = {}
@ -1466,7 +1513,7 @@ class Runner:
self._has_loaded = True
self.logger.info(f'resumed epoch: {self._epoch}, iter: {self._iter}')
self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}')
def load_checkpoint(self,
filename: str,
@ -1544,13 +1591,13 @@ class Runner:
meta.update(self.meta)
if by_epoch:
# self._epoch increments 1 after
# self.epoch increments 1 after
# `self.call_hook('after_train_epoch)` but `save_checkpoint` is
# called by `after_train_epoch`` method of `CheckpointHook` so
# `epoch` should be `self_epoch + 1`
meta.update(epoch=self._epoch + 1, iter=self._iter)
# `epoch` should be `self.epoch + 1`
meta.update(epoch=self.epoch + 1, iter=self.iter)
else:
meta.update(epoch=self._epoch, iter=self._iter + 1)
meta.update(epoch=self.epoch, iter=self.iter + 1)
filepath = osp.join(out_dir, filename)

View File

@ -1,8 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import pytest
from mmengine.hooks import Hook
@ -176,33 +174,21 @@ class TestHook:
# last epoch
runner.epoch = 1
runner.train_loop.max_epochs = 2
runner.max_epochs = 2
return_val = hook.is_last_train_epoch(runner)
assert return_val
# not the last epoch
runner.train_loop.max_epochs = 0
runner.max_epochs = 0
return_val = hook.is_last_train_epoch(runner)
assert not return_val
def test_is_last_iter(self):
def test_is_last_train_iter(self):
hook = Hook()
runner = Mock()
# last iter
runner.iter = 1
runner.train_loop.max_iters = 2
return_val = hook.is_last_iter(runner)
runner.max_iters = 2
return_val = hook.is_last_train_iter(runner)
assert return_val
# not the last iter
runner.val_loop.max_iters = 0
return_val = hook.is_last_iter(runner, mode='val')
assert not return_val
runner.test_loop.max_iters = 0
return_val = hook.is_last_iter(runner, mode='test')
assert not return_val
with pytest.raises(ValueError):
hook.is_last_iter(runner, mode='error_mode')

View File

@ -49,10 +49,10 @@ class TestIterTimerHook(TestCase):
runner = MagicMock()
runner.log_buffer = dict()
runner.log_processor.window_size = 10
runner.train_loop.max_iters = 100
runner.max_iters = 100
runner.iter = 0
runner.test_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 20
runner.test_dataloader = [0] * 20
runner.val_dataloader = [0] * 20
self.hook._before_epoch(runner)
self.hook.before_run(runner)
self.hook._after_iter(runner, batch_idx=1)

View File

@ -95,7 +95,7 @@ class TestLoggerHook:
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook = LoggerHook(ignore_last=False)
runner.train_loop.dataloader = [0] * 5
runner.train_dataloader = [0] * 5
logger_hook.after_train_iter(runner, batch_idx=4)
runner.log_processor.get_log_after_iter.assert_called()

View File

@ -110,7 +110,7 @@ class TestLogProcessor:
assert out == log_str
else:
if mode == 'train':
max_iters = self.runner.train_loop.max_iters
max_iters = self.runner.max_iters
log_str = f'Iter({mode}) [11/{max_iters}] '
else:
max_iters = len(cur_loop.dataloader)
@ -227,7 +227,10 @@ class TestLogProcessor:
runner = MagicMock()
runner.epoch = 1
runner.iter = 10
runner.train_loop.max_iters = 50
runner.max_iters = 50
runner.train_dataloader = [0] * 20
runner.val_dataloader = [0] * 10
runner.test_dataloader = [0] * 5
runner.train_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 10
runner.test_loop.dataloader = [0] * 5

View File

@ -341,9 +341,9 @@ class TestRunner(TestCase):
self.assertEqual(runner.model_name, 'ToyModel')
# 3. test lazy initialization
self.assertIsInstance(runner.train_dataloader, dict)
self.assertIsInstance(runner.val_dataloader, dict)
self.assertIsInstance(runner.test_dataloader, dict)
self.assertIsInstance(runner._train_dataloader, dict)
self.assertIsInstance(runner._val_dataloader, dict)
self.assertIsInstance(runner._test_dataloader, dict)
self.assertIsInstance(runner.optimizer, dict)
self.assertIsInstance(runner.param_schedulers[0], dict)
@ -352,20 +352,20 @@ class TestRunner(TestCase):
# test_dataloader should also be dict
runner.train()
self.assertIsInstance(runner.train_loop, BaseLoop)
self.assertIsInstance(runner.train_loop.dataloader, DataLoader)
self.assertIsInstance(runner._train_loop, BaseLoop)
self.assertIsInstance(runner.train_dataloader, DataLoader)
self.assertIsInstance(runner.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertIsInstance(runner.val_loop, BaseLoop)
self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
self.assertIsInstance(runner.val_loop.evaluator, Evaluator)
self.assertIsInstance(runner._val_loop, BaseLoop)
self.assertIsInstance(runner._val_loop.dataloader, DataLoader)
self.assertIsInstance(runner._val_loop.evaluator, Evaluator)
# After calling runner.test(), test_dataloader should be initialized
self.assertIsInstance(runner.test_loop, dict)
self.assertIsInstance(runner._test_loop, dict)
runner.test()
self.assertIsInstance(runner.test_loop, BaseLoop)
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
self.assertIsInstance(runner.test_loop.evaluator, Evaluator)
self.assertIsInstance(runner._test_loop, BaseLoop)
self.assertIsInstance(runner._test_loop.dataloader, DataLoader)
self.assertIsInstance(runner._test_loop.evaluator, Evaluator)
# 4. initialize runner with objects rather than config
model = ToyModel()
@ -637,7 +637,7 @@ class TestRunner(TestCase):
begin=1,
end=7,
convert_to_iter_based=True)
runner.train_loop = runner.build_train_loop(runner.train_loop)
runner._train_loop = runner.build_train_loop(runner.train_loop)
param_schedulers = runner.build_param_scheduler(cfg)
self.assertFalse(param_schedulers[0].by_epoch)
self.assertEqual(param_schedulers[0].begin, 4)