From e37f1f905b09802c4abd45e920a49d2bc8bdb49c Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 18 May 2022 22:35:10 +0800 Subject: [PATCH] [Refactor] Make loop-related attributes to be runner's properties. (#236) * [Enhance] Make loop related attributes to be runner's properties. * move iter and epoch to loop * resolve comments --- mmengine/hooks/checkpoint_hook.py | 3 +- mmengine/hooks/hook.py | 22 +-- mmengine/hooks/iter_timer_hook.py | 7 +- mmengine/hooks/logger_hook.py | 4 +- mmengine/logging/log_processor.py | 2 +- mmengine/runner/loops.py | 66 +++++++-- mmengine/runner/runner.py | 181 ++++++++++++++--------- tests/test_hook/test_hook.py | 24 +-- tests/test_hook/test_iter_timer_hook.py | 6 +- tests/test_hook/test_logger_hook.py | 2 +- tests/test_logging/test_log_processor.py | 7 +- tests/test_runner/test_runner.py | 26 ++-- 12 files changed, 206 insertions(+), 144 deletions(-) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 3ed332bd..9eb94eb1 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -194,7 +194,8 @@ class CheckpointHook(Hook): # 1. every ``self.interval`` iterations # 2. reach the last iteration of training if self.every_n_iters(runner, self.interval) or \ - (self.save_last and self.is_last_iter(runner, mode='train')): + (self.save_last and + self.is_last_train_iter(runner)): runner.logger.info( f'Saving checkpoint at {runner.iter + 1} iterations') self._save_checkpoint(runner) diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 5e96bcd8..9fde8174 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -409,25 +409,15 @@ class Hook: Returns: bool: Whether reaches the end of training epoch. """ - return runner.epoch + 1 == runner.train_loop.max_epochs + return runner.epoch + 1 == runner.max_epochs - def is_last_iter(self, runner, mode='train') -> bool: - """Test whether current iteration is the last iteration. + def is_last_train_iter(self, runner) -> bool: + """Test whether current iteration is the last train iteration. Args: - runner (Runner): The runner of the training, validation or testing - process. - mode (str): Current mode of runner. Defaults to 'train'. + runner (Runner): The runner of the training process. Returns: - bool: Whether current iteration is the last iteration. + bool: Whether current iteration is the last train iteration. """ - if mode == 'train': - return runner.iter + 1 == runner.train_loop.max_iters - elif mode == 'val': - return runner.iter + 1 == runner.val_loop.max_iters - elif mode == 'test': - return runner.iter + 1 == runner.test_loop.max_iters - else: - raise ValueError('mode should be train, val or test but got' - f'{mode}') + return runner.iter + 1 == runner.max_iters diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 56c9cee9..ee5be0a7 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -98,14 +98,13 @@ class IterTimerHook(Hook): time_sec_avg = self.time_sec_tot / ( runner.iter - self.start_iter + 1) # Calculate eta. - eta_sec = time_sec_avg * ( - runner.train_loop.max_iters - runner.iter - 1) + eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) runner.message_hub.update_info('eta', eta_sec) else: if mode == 'val': - cur_dataloader = runner.val_loop.dataloader + cur_dataloader = runner.val_dataloader else: - cur_dataloader = runner.test_loop.dataloader + cur_dataloader = runner.test_dataloader eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1) runner.message_hub.update_info('eta', eta_sec) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 5b529648..b5e15014 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -127,13 +127,13 @@ class LoggerHook(Hook): # Print experiment name every n iterations. if self.every_n_iters(runner, self.interval_exp_name) or (self.end_of_epoch( - runner.train_loop.dataloader, batch_idx)): + runner.train_dataloader, batch_idx)): exp_info = f'Exp name: {runner.experiment_name}' runner.logger.info(exp_info) if self.every_n_inner_iters(batch_idx, self.interval): tag, log_str = runner.log_processor.get_log_after_iter( runner, batch_idx, 'train') - elif (self.end_of_epoch(runner.train_loop.dataloader, batch_idx) + elif (self.end_of_epoch(runner.train_dataloader, batch_idx) and not self.ignore_last): # `runner.max_iters` may not be divisible by `self.interval`. if # `self.ignore_last==True`, the log of remaining iterations will diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py index 84541123..bcfd539d 100644 --- a/mmengine/logging/log_processor.py +++ b/mmengine/logging/log_processor.py @@ -142,7 +142,7 @@ class LogProcessor: else: if mode == 'train': log_str = (f'Iter({mode}) ' - f'[{cur_iter}/{runner.train_loop.max_iters}] ') + f'[{cur_iter}/{runner.max_iters}] ') else: log_str = (f'Iter({mode}) [{batch_idx+1}' f'/{len(current_loop.dataloader)}] ') diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 116560a9..2b670486 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -19,7 +19,7 @@ class EpochBasedTrainLoop(BaseLoop): runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. - max_epoch (int): Total training epochs. + max_epochs (int): Total training epochs. """ def __init__(self, runner, dataloader: Union[DataLoader, Dict], @@ -27,6 +27,8 @@ class EpochBasedTrainLoop(BaseLoop): super().__init__(runner, dataloader) self._max_epochs = max_epochs self._max_iters = max_epochs * len(self.dataloader) + self._epoch = 0 + self._iter = 0 if hasattr(self.dataloader.dataset, 'metainfo'): self.runner.visualizer.dataset_meta = \ self.dataloader.dataset.metainfo @@ -46,15 +48,25 @@ class EpochBasedTrainLoop(BaseLoop): """int: Total iterations to train model.""" return self._max_iters + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + def run(self) -> None: """Launch training.""" self.runner.call_hook('before_train') - while self.runner._epoch < self._max_epochs: + while self._epoch < self._max_epochs: self.run_epoch() - if (self.runner.val_loop is not None and - self.runner._epoch % self.runner.val_loop.interval == 0): + if (self.runner.val_loop is not None + and self._epoch % self.runner.val_loop.interval == 0): self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -67,7 +79,9 @@ class EpochBasedTrainLoop(BaseLoop): self.run_iter(idx, data_batch) self.runner.call_hook('after_train_epoch') - self.runner.epoch += 1 + self._epoch += 1 + # To allow components that cannot access runner to get current epoch. + self.runner.message_hub.update_info('epoch', self._epoch) def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """Iterate one min-batch. @@ -90,7 +104,10 @@ class EpochBasedTrainLoop(BaseLoop): data_batch=data_batch, outputs=self.runner.outputs) - self.runner.iter += 1 + self._iter += 1 + # To allow components that cannot access runner to get current + # iteration. + self.runner.message_hub.update_info('iter', self._iter) @LOOPS.register_module() @@ -101,13 +118,16 @@ class IterBasedTrainLoop(BaseLoop): runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. - max_iter (int): Total training iterations. + max_iters (int): Total training iterations. """ def __init__(self, runner, dataloader: Union[DataLoader, Dict], max_iters: int) -> None: super().__init__(runner, dataloader) self._max_iters = max_iters + self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop + self._epoch = 0 + self._iter = 0 if hasattr(self.dataloader.dataset, 'metainfo'): self.runner.visualizer.dataset_meta = \ self.dataloader.dataset.metainfo @@ -118,25 +138,40 @@ class IterBasedTrainLoop(BaseLoop): 'None.') self.dataloader = iter(self.dataloader) + @property + def max_epochs(self): + """int: Total epochs to train model.""" + return self._max_epochs + @property def max_iters(self): """int: Total iterations to train model.""" return self._max_iters + @property + def epoch(self): + """int: Current epoch.""" + return self._epoch + + @property + def iter(self): + """int: Current iteration.""" + return self._iter + def run(self) -> None: """Launch training.""" self.runner.call_hook('before_train') # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. self.runner.call_hook('before_train_epoch') - while self.runner._iter < self._max_iters: + while self._iter < self._max_iters: self.runner.model.train() data_batch = next(self.dataloader) self.run_iter(data_batch) - if (self.runner.val_loop is not None and - self.runner._iter % self.runner.val_loop.interval == 0): + if (self.runner.val_loop is not None + and self._iter % self.runner.val_interval == 0): self.runner.val_loop.run() self.runner.call_hook('after_train_epoch') @@ -149,9 +184,7 @@ class IterBasedTrainLoop(BaseLoop): data_batch (Sequence[dict]): Batch of data from dataloader. """ self.runner.call_hook( - 'before_train_iter', - batch_idx=self.runner._iter, - data_batch=data_batch) + 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) # outputs should be a dict containing loss tensor self.runner.outputs = self.runner.model(data_batch, return_loss=True) @@ -161,10 +194,13 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook( 'after_train_iter', - batch_idx=self.runner._iter, + batch_idx=self._iter, data_batch=data_batch, outputs=self.runner.outputs) - self.runner.iter += 1 + self._iter += 1 + # To allow components that cannot access runner to get current + # iteration. + self.runner.message_hub.update_info('iter', self._iter) @LOOPS.register_module() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 2c65a72b..4f4a1296 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -199,9 +199,9 @@ class Runner: >>> runner.test() """ cfg: Config - train_loop: Optional[Union[BaseLoop, Dict]] - val_loop: Optional[Union[BaseLoop, Dict]] - test_loop: Optional[Union[BaseLoop, Dict]] + _train_loop: Optional[Union[BaseLoop, Dict]] + _val_loop: Optional[Union[BaseLoop, Dict]] + _test_loop: Optional[Union[BaseLoop, Dict]] def __init__( self, @@ -244,9 +244,6 @@ class Runner: else: self.cfg = Config(dict()) - self._epoch = 0 - self._iter = 0 - # lazy initialization training_related = [train_dataloader, train_cfg, optimizer] if not (all(item is None for item in training_related) @@ -257,8 +254,8 @@ class Runner: f'train_dataloader={train_dataloader}, ' f'train_cfg={train_cfg}, ' f'optimizer={optimizer}.') - self.train_dataloader = train_dataloader - self.train_loop = train_cfg + self._train_dataloader = train_dataloader + self._train_loop = train_cfg self.optimizer = optimizer # If there is no need to adjust learning rate, momentum or other @@ -284,9 +281,9 @@ class Runner: 'all None or not None, but got ' f'val_dataloader={val_dataloader}, val_cfg={val_cfg}, ' f'val_evaluator={val_evaluator}') - self.val_dataloader = val_dataloader - self.val_loop = val_cfg - self.val_evaluator = val_evaluator + self._val_dataloader = val_dataloader + self._val_loop = val_cfg + self._val_evaluator = val_evaluator test_related = [test_dataloader, test_cfg, test_evaluator] if not (all(item is None for item in test_related) @@ -296,9 +293,9 @@ class Runner: 'either all None or not None, but got ' f'test_dataloader={test_dataloader}, test_cfg={test_cfg}, ' f'test_evaluator={test_evaluator}') - self.test_dataloader = test_dataloader - self.test_loop = test_cfg - self.test_evaluator = test_evaluator + self._test_dataloader = test_dataloader + self._test_loop = test_cfg + self._test_evaluator = test_evaluator self._launcher = launcher if self._launcher == 'none': @@ -435,30 +432,25 @@ class Runner: """str: The working directory to save checkpoints and logs.""" return self._work_dir + @property + def max_epochs(self): + """int: Total epochs to train model.""" + return self.train_loop.max_epochs + + @property + def max_iters(self): + """int: Total iterations to train model.""" + return self.train_loop.max_iters + @property def epoch(self): """int: Current epoch.""" - return self._epoch - - @epoch.setter - def epoch(self, epoch: int): - """Update epoch and synchronize epoch in :attr:`message_hub`.""" - self._epoch = epoch - # To allow components that cannot access runner to get current epoch. - self.message_hub.update_info('epoch', epoch) + return self.train_loop.epoch @property def iter(self): """int: Current iteration.""" - return self._iter - - @iter.setter - def iter(self, iter: int): - """Update iter and synchronize iter in :attr:`message_hub`.""" - self._iter = iter - # To allow components that cannot access runner to get current - # iteration. - self.message_hub.update_info('iter', iter) + return self.train_loop.iter @property def launcher(self): @@ -500,6 +492,63 @@ class Runner: """list[:obj:`Hook`]: A list of registered hooks.""" return self._hooks + @property + def train_loop(self): + """:obj:`BaseLoop`: A loop to run training.""" + if isinstance(self._train_loop, BaseLoop) or self._train_loop is None: + return self._train_loop + else: + self._train_loop = self.build_train_loop(self._train_loop) + return self._train_loop + + @property + def val_loop(self): + """:obj:`BaseLoop`: A loop to run validation.""" + if isinstance(self._val_loop, BaseLoop) or self._val_loop is None: + return self._val_loop + else: + self._val_loop = self.build_val_loop(self._val_loop) + return self._val_loop + + @property + def test_loop(self): + """:obj:`BaseLoop`: A loop to run testing.""" + if isinstance(self._test_loop, BaseLoop) or self._test_loop is None: + return self._test_loop + else: + self._test_loop = self.build_test_loop(self._test_loop) + return self._test_loop + + @property + def train_dataloader(self): + """The data loader for training.""" + return self.train_loop.dataloader + + @property + def val_dataloader(self): + """The data loader for validation.""" + return self.val_loop.dataloader + + @property + def test_dataloader(self): + """The data loader for testing.""" + return self.test_loop.dataloader + + @property + def val_evaluator(self): + """:obj:`Evaluator`: An evaluator for validation.""" + return self.val_loop.evaluator + + @property + def test_evaluator(self): + """:obj:`Evaluator`: An evaluator for testing.""" + return self.test_loop.evaluator + + @property + def val_interval(self): + """int: Interval to run validation during training.""" + return self.val_loop.interval + def setup_env(self, env_cfg: Dict) -> None: """Setup environment. @@ -857,7 +906,7 @@ class Runner: 'by_epoch', True ), 'only epoch-based parameter scheduler can be ' \ 'converted to iter-based' - assert isinstance(self.train_loop, BaseLoop), \ + assert isinstance(self._train_loop, BaseLoop), \ 'Scheduler can only be converted to iter-based ' \ 'when train loop is built.' cls = PARAM_SCHEDULERS.get(_scheduler.pop('type')) @@ -866,7 +915,7 @@ class Runner: optimizer=self.optimizer, **_scheduler, epoch_length=len( - self.train_loop.dataloader), # type: ignore + self.train_dataloader), # type: ignore )) else: param_schedulers.append( @@ -1043,15 +1092,15 @@ class Runner: loop = LOOPS.build( loop_cfg, default_args=dict( - runner=self, dataloader=self.train_dataloader)) + runner=self, dataloader=self._train_dataloader)) else: by_epoch = loop_cfg.pop('by_epoch') if by_epoch: loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self.train_dataloader) + **loop_cfg, runner=self, dataloader=self._train_dataloader) else: loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self.train_dataloader) + **loop_cfg, runner=self, dataloader=self._train_dataloader) # `build_optimizer` should be called before `build_param_scheduler` # because the latter depends on the former @@ -1095,13 +1144,13 @@ class Runner: loop_cfg, default_args=dict( runner=self, - dataloader=self.val_dataloader, - evaluator=self.val_evaluator)) + dataloader=self._val_dataloader, + evaluator=self._val_evaluator)) else: loop = ValLoop( runner=self, - dataloader=self.val_dataloader, - evaluator=self.val_evaluator, # type: ignore + dataloader=self._val_dataloader, + evaluator=self._val_evaluator, # type: ignore **loop_cfg, ) # type: ignore @@ -1122,9 +1171,6 @@ class Runner: loop (BaseLoop or dict): A test loop or a dict to build test loop. If ``loop`` is a test loop object, just returns itself. - Args: - loop_cfg (dict): Config to build test loop. - Returns: :obj:`BaseLoop`: Test loop object build from ``loop_cfg``. """ @@ -1141,13 +1187,13 @@ class Runner: loop_cfg, default_args=dict( runner=self, - dataloader=self.test_dataloader, - evaluator=self.test_evaluator)) + dataloader=self._test_dataloader, + evaluator=self._test_evaluator)) else: loop = TestLoop( runner=self, - dataloader=self.test_dataloader, - evaluator=self.test_evaluator) # type: ignore + dataloader=self._test_dataloader, + evaluator=self._test_evaluator) # type: ignore return loop # type: ignore @@ -1176,18 +1222,19 @@ class Runner: def train(self) -> None: """Launch training.""" - if self.train_loop is None: + if self._train_loop is None: raise RuntimeError( - '`self.train_loop` should not be None when calling train ' + '`self._train_loop` should not be None when calling train ' 'method. Please provide `train_dataloader`, `train_cfg`, ' '`optimizer` and `param_scheduler` arguments when ' 'initializing runner.') - self.train_loop = self.build_train_loop( - self.train_loop) # type: ignore + self._train_loop = self.build_train_loop( + self._train_loop) # type: ignore - if self.val_loop is not None: - self.val_loop = self.build_val_loop(self.val_loop) # type: ignore + if self._val_loop is not None: + self._val_loop = self.build_val_loop( + self._val_loop) # type: ignore self.load_or_resume() @@ -1198,13 +1245,13 @@ class Runner: def val(self) -> None: """Launch validation.""" - if self.val_loop is None: + if self._val_loop is None: raise RuntimeError( - '`self.val_loop` should not be None when calling val method.' + '`self._val_loop` should not be None when calling val method.' 'Please provide `val_dataloader`, `val_cfg` and ' '`val_evaluator` arguments when initializing runner.') - self.val_loop = self.build_val_loop(self.val_loop) # type: ignore + self._val_loop = self.build_val_loop(self._val_loop) # type: ignore self.load_or_resume() @@ -1214,13 +1261,13 @@ class Runner: def test(self) -> None: """Launch test.""" - if self.test_loop is None: + if self._test_loop is None: raise RuntimeError( - '`self.test_loop` should not be None when calling test method.' - 'Please provide `test_dataloader`, `test_cfg` and ' + '`self._test_loop` should not be None when calling test ' + 'method. Please provide `test_dataloader`, `test_cfg` and ' '`test_evaluator` arguments when initializing runner.') - self.test_loop = self.build_test_loop(self.test_loop) # type: ignore + self._test_loop = self.build_test_loop(self._test_loop) # type: ignore self.load_or_resume() @@ -1424,8 +1471,8 @@ class Runner: checkpoint = self.load_checkpoint( filename, map_location=map_location) - self._epoch = checkpoint['meta']['epoch'] - self._iter = checkpoint['meta']['iter'] + self.train_loop._epoch = checkpoint['meta']['epoch'] + self.train_loop._iter = checkpoint['meta']['iter'] if self.meta is None: self.meta = {} @@ -1466,7 +1513,7 @@ class Runner: self._has_loaded = True - self.logger.info(f'resumed epoch: {self._epoch}, iter: {self._iter}') + self.logger.info(f'resumed epoch: {self.epoch}, iter: {self.iter}') def load_checkpoint(self, filename: str, @@ -1544,13 +1591,13 @@ class Runner: meta.update(self.meta) if by_epoch: - # self._epoch increments 1 after + # self.epoch increments 1 after # `self.call_hook('after_train_epoch)` but `save_checkpoint` is # called by `after_train_epoch`` method of `CheckpointHook` so - # `epoch` should be `self_epoch + 1` - meta.update(epoch=self._epoch + 1, iter=self._iter) + # `epoch` should be `self.epoch + 1` + meta.update(epoch=self.epoch + 1, iter=self.iter) else: - meta.update(epoch=self._epoch, iter=self._iter + 1) + meta.update(epoch=self.epoch, iter=self.iter + 1) filepath = osp.join(out_dir, filename) diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py index 771c54f6..7cdfe6c5 100644 --- a/tests/test_hook/test_hook.py +++ b/tests/test_hook/test_hook.py @@ -1,8 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import Mock -import pytest - from mmengine.hooks import Hook @@ -176,33 +174,21 @@ class TestHook: # last epoch runner.epoch = 1 - runner.train_loop.max_epochs = 2 + runner.max_epochs = 2 return_val = hook.is_last_train_epoch(runner) assert return_val # not the last epoch - runner.train_loop.max_epochs = 0 + runner.max_epochs = 0 return_val = hook.is_last_train_epoch(runner) assert not return_val - def test_is_last_iter(self): + def test_is_last_train_iter(self): hook = Hook() runner = Mock() # last iter runner.iter = 1 - runner.train_loop.max_iters = 2 - return_val = hook.is_last_iter(runner) + runner.max_iters = 2 + return_val = hook.is_last_train_iter(runner) assert return_val - - # not the last iter - runner.val_loop.max_iters = 0 - return_val = hook.is_last_iter(runner, mode='val') - assert not return_val - - runner.test_loop.max_iters = 0 - return_val = hook.is_last_iter(runner, mode='test') - assert not return_val - - with pytest.raises(ValueError): - hook.is_last_iter(runner, mode='error_mode') diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index 8d3dfb9d..60365072 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -49,10 +49,10 @@ class TestIterTimerHook(TestCase): runner = MagicMock() runner.log_buffer = dict() runner.log_processor.window_size = 10 - runner.train_loop.max_iters = 100 + runner.max_iters = 100 runner.iter = 0 - runner.test_loop.dataloader = [0] * 20 - runner.val_loop.dataloader = [0] * 20 + runner.test_dataloader = [0] * 20 + runner.val_dataloader = [0] * 20 self.hook._before_epoch(runner) self.hook.before_run(runner) self.hook._after_iter(runner, batch_idx=1) diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 2e05b84f..2cf75dc4 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -95,7 +95,7 @@ class TestLoggerHook: runner.log_processor.get_log_after_iter = MagicMock( return_value=(dict(), 'log_str')) logger_hook = LoggerHook(ignore_last=False) - runner.train_loop.dataloader = [0] * 5 + runner.train_dataloader = [0] * 5 logger_hook.after_train_iter(runner, batch_idx=4) runner.log_processor.get_log_after_iter.assert_called() diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py index b10cac48..44c22018 100644 --- a/tests/test_logging/test_log_processor.py +++ b/tests/test_logging/test_log_processor.py @@ -110,7 +110,7 @@ class TestLogProcessor: assert out == log_str else: if mode == 'train': - max_iters = self.runner.train_loop.max_iters + max_iters = self.runner.max_iters log_str = f'Iter({mode}) [11/{max_iters}] ' else: max_iters = len(cur_loop.dataloader) @@ -227,7 +227,10 @@ class TestLogProcessor: runner = MagicMock() runner.epoch = 1 runner.iter = 10 - runner.train_loop.max_iters = 50 + runner.max_iters = 50 + runner.train_dataloader = [0] * 20 + runner.val_dataloader = [0] * 10 + runner.test_dataloader = [0] * 5 runner.train_loop.dataloader = [0] * 20 runner.val_loop.dataloader = [0] * 10 runner.test_loop.dataloader = [0] * 5 diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 192264a0..7b50db55 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -341,9 +341,9 @@ class TestRunner(TestCase): self.assertEqual(runner.model_name, 'ToyModel') # 3. test lazy initialization - self.assertIsInstance(runner.train_dataloader, dict) - self.assertIsInstance(runner.val_dataloader, dict) - self.assertIsInstance(runner.test_dataloader, dict) + self.assertIsInstance(runner._train_dataloader, dict) + self.assertIsInstance(runner._val_dataloader, dict) + self.assertIsInstance(runner._test_dataloader, dict) self.assertIsInstance(runner.optimizer, dict) self.assertIsInstance(runner.param_schedulers[0], dict) @@ -352,20 +352,20 @@ class TestRunner(TestCase): # test_dataloader should also be dict runner.train() - self.assertIsInstance(runner.train_loop, BaseLoop) - self.assertIsInstance(runner.train_loop.dataloader, DataLoader) + self.assertIsInstance(runner._train_loop, BaseLoop) + self.assertIsInstance(runner.train_dataloader, DataLoader) self.assertIsInstance(runner.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) - self.assertIsInstance(runner.val_loop, BaseLoop) - self.assertIsInstance(runner.val_loop.dataloader, DataLoader) - self.assertIsInstance(runner.val_loop.evaluator, Evaluator) + self.assertIsInstance(runner._val_loop, BaseLoop) + self.assertIsInstance(runner._val_loop.dataloader, DataLoader) + self.assertIsInstance(runner._val_loop.evaluator, Evaluator) # After calling runner.test(), test_dataloader should be initialized - self.assertIsInstance(runner.test_loop, dict) + self.assertIsInstance(runner._test_loop, dict) runner.test() - self.assertIsInstance(runner.test_loop, BaseLoop) - self.assertIsInstance(runner.test_loop.dataloader, DataLoader) - self.assertIsInstance(runner.test_loop.evaluator, Evaluator) + self.assertIsInstance(runner._test_loop, BaseLoop) + self.assertIsInstance(runner._test_loop.dataloader, DataLoader) + self.assertIsInstance(runner._test_loop.evaluator, Evaluator) # 4. initialize runner with objects rather than config model = ToyModel() @@ -637,7 +637,7 @@ class TestRunner(TestCase): begin=1, end=7, convert_to_iter_based=True) - runner.train_loop = runner.build_train_loop(runner.train_loop) + runner._train_loop = runner.build_train_loop(runner.train_loop) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) self.assertEqual(param_schedulers[0].begin, 4)