diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 383c985a..1bf00b5a 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -66,6 +66,7 @@ class EpochBasedTrainLoop(BaseLoop): self.run_epoch() if (self.runner.val_loop is not None + and self._epoch >= self.runner.val_loop.begin and self._epoch % self.runner.val_loop.interval == 0): self.runner.val_loop.run() @@ -162,6 +163,7 @@ class IterBasedTrainLoop(BaseLoop): self.run_iter(data_batch) if (self.runner.val_loop is not None + and self._iter >= self.runner.val_begin and self._iter % self.runner.val_interval == 0): self.runner.val_loop.run() @@ -197,13 +199,15 @@ class ValLoop(BaseLoop): build a dataloader. evaluator (Evaluator or dict or list): Used for computing metrics. interval (int): Validation interval. Defaults to 1. + begin (int): The epoch/iteration that begins validating. Defaults to 1. """ def __init__(self, runner, dataloader: Union[DataLoader, Dict], evaluator: Union[Evaluator, Dict, List], - interval: int = 1) -> None: + interval: int = 1, + begin: int = 1) -> None: super().__init__(runner, dataloader) if isinstance(evaluator, dict) or is_list_of(evaluator, dict): @@ -220,6 +224,7 @@ class ValLoop(BaseLoop): 'metainfo. ``dataset_meta`` in evaluator, metric and ' 'visualizer will be None.') self.interval = interval + self.begin = begin def run(self): """Launch validation.""" diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 58eaeea0..bfa50603 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -561,6 +561,12 @@ class Runner: """int: Interval to run validation during training.""" return self.val_loop.interval + @property + def val_begin(self): + """int: The epoch/iteration to start running validation during + training.""" + return self.val_loop.begin + def setup_env(self, env_cfg: Dict) -> None: """Setup environment. diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index eb382bcf..06b12b5a 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -214,7 +214,7 @@ class TestRunner(TestCase): val_evaluator=dict(type='ToyMetric1'), test_evaluator=dict(type='ToyMetric1'), train_cfg=dict(by_epoch=True, max_epochs=3), - val_cfg=dict(interval=1), + val_cfg=dict(interval=1, begin=1), test_cfg=dict(), custom_hooks=[], default_hooks=dict( @@ -403,7 +403,7 @@ class TestRunner(TestCase): train_dataloader=train_dataloader, optimizer=optimizer, param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]), - val_cfg=dict(interval=1), + val_cfg=dict(interval=1, begin=1), val_dataloader=val_dataloader, val_evaluator=ToyMetric1(), test_cfg=dict(), @@ -762,12 +762,12 @@ class TestRunner(TestCase): runner.build_test_loop('invalid-type') # input is a dict and contains type key - cfg = dict(type='ValLoop', interval=1) + cfg = dict(type='ValLoop', interval=1, begin=1) loop = runner.build_test_loop(cfg) self.assertIsInstance(loop, ValLoop) # input is a dict but does not contain type key - cfg = dict(interval=1) + cfg = dict(interval=1, begin=1) loop = runner.build_val_loop(cfg) self.assertIsInstance(loop, ValLoop) @@ -846,13 +846,16 @@ class TestRunner(TestCase): with self.assertRaisesRegex(RuntimeError, 'should not be None'): runner.train() - # 2. test iter and epoch counter of EpochBasedTrainLoop + # 2. test iter and epoch counter of EpochBasedTrainLoop and timing of + # running ValLoop epoch_results = [] epoch_targets = [i for i in range(3)] iter_results = [] iter_targets = [i for i in range(4 * 3)] batch_idx_results = [] batch_idx_targets = [i for i in range(4)] * 3 # train and val + val_epoch_results = [] + val_epoch_targets = [i for i in range(2, 4)] @HOOKS.register_module() class TestEpochHook(Hook): @@ -864,9 +867,13 @@ class TestRunner(TestCase): iter_results.append(runner.iter) batch_idx_results.append(batch_idx) + def before_val_epoch(self, runner): + val_epoch_results.append(runner.epoch) + cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train2' cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)] + cfg.val_cfg = dict(begin=2) runner = Runner.from_cfg(cfg) runner.train() @@ -879,13 +886,20 @@ class TestRunner(TestCase): self.assertEqual(result, target) for result, target, in zip(batch_idx_results, batch_idx_targets): self.assertEqual(result, target) + for result, target, in zip(val_epoch_results, val_epoch_targets): + self.assertEqual(result, target) - # 3. test iter and epoch counter of IterBasedTrainLoop + # 3. test iter and epoch counter of IterBasedTrainLoop and timing of + # running ValLoop epoch_results = [] iter_results = [] batch_idx_results = [] + val_iter_results = [] + val_batch_idx_results = [] iter_targets = [i for i in range(12)] batch_idx_targets = [i for i in range(12)] + val_iter_targets = [i for i in range(4, 12)] + val_batch_idx_targets = [i for i in range(4)] * 2 @HOOKS.register_module() class TestIterHook(Hook): @@ -897,10 +911,14 @@ class TestRunner(TestCase): iter_results.append(runner.iter) batch_idx_results.append(batch_idx) + def before_val_iter(self, runner, batch_idx, data_batch=None): + val_epoch_results.append(runner.iter) + val_batch_idx_results.append(batch_idx) + cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train3' cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.val_cfg = dict(interval=4) + cfg.val_cfg = dict(interval=4, begin=4) runner = Runner.from_cfg(cfg) runner.train() @@ -912,6 +930,11 @@ class TestRunner(TestCase): self.assertEqual(result, target) for result, target, in zip(batch_idx_results, batch_idx_targets): self.assertEqual(result, target) + for result, target, in zip(val_iter_results, val_iter_targets): + self.assertEqual(result, target) + for result, target, in zip(val_batch_idx_results, + val_batch_idx_targets): + self.assertEqual(result, target) def test_val(self): cfg = copy.deepcopy(self.epoch_based_cfg)