Support validation only after some epoch/iteration in ValLoop (#257)

* add the epoch/iter that begins validating

* fix lint

* add property and fix unit test

* minor changes

* fix typos and add unit test

* add unit test about begin

* fix docstring
This commit is contained in:
Jingwei Zhang 2022-05-27 15:10:12 +08:00 committed by GitHub
parent 08a3adb5d7
commit 40daf46a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 8 deletions

View File

@ -66,6 +66,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_epoch()
if (self.runner.val_loop is not None
and self._epoch >= self.runner.val_loop.begin
and self._epoch % self.runner.val_loop.interval == 0):
self.runner.val_loop.run()
@ -162,6 +163,7 @@ class IterBasedTrainLoop(BaseLoop):
self.run_iter(data_batch)
if (self.runner.val_loop is not None
and self._iter >= self.runner.val_begin
and self._iter % self.runner.val_interval == 0):
self.runner.val_loop.run()
@ -197,13 +199,15 @@ class ValLoop(BaseLoop):
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
interval (int): Validation interval. Defaults to 1.
begin (int): The epoch/iteration that begins validating. Defaults to 1.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
interval: int = 1) -> None:
interval: int = 1,
begin: int = 1) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@ -220,6 +224,7 @@ class ValLoop(BaseLoop):
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
self.interval = interval
self.begin = begin
def run(self):
"""Launch validation."""

View File

@ -561,6 +561,12 @@ class Runner:
"""int: Interval to run validation during training."""
return self.val_loop.interval
@property
def val_begin(self):
"""int: The epoch/iteration to start running validation during
training."""
return self.val_loop.begin
def setup_env(self, env_cfg: Dict) -> None:
"""Setup environment.

View File

@ -214,7 +214,7 @@ class TestRunner(TestCase):
val_evaluator=dict(type='ToyMetric1'),
test_evaluator=dict(type='ToyMetric1'),
train_cfg=dict(by_epoch=True, max_epochs=3),
val_cfg=dict(interval=1),
val_cfg=dict(interval=1, begin=1),
test_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
@ -403,7 +403,7 @@ class TestRunner(TestCase):
train_dataloader=train_dataloader,
optimizer=optimizer,
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
val_cfg=dict(interval=1),
val_cfg=dict(interval=1, begin=1),
val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(),
test_cfg=dict(),
@ -762,12 +762,12 @@ class TestRunner(TestCase):
runner.build_test_loop('invalid-type')
# input is a dict and contains type key
cfg = dict(type='ValLoop', interval=1)
cfg = dict(type='ValLoop', interval=1, begin=1)
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, ValLoop)
# input is a dict but does not contain type key
cfg = dict(interval=1)
cfg = dict(interval=1, begin=1)
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)
@ -846,13 +846,16 @@ class TestRunner(TestCase):
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
runner.train()
# 2. test iter and epoch counter of EpochBasedTrainLoop
# 2. test iter and epoch counter of EpochBasedTrainLoop and timing of
# running ValLoop
epoch_results = []
epoch_targets = [i for i in range(3)]
iter_results = []
iter_targets = [i for i in range(4 * 3)]
batch_idx_results = []
batch_idx_targets = [i for i in range(4)] * 3 # train and val
val_epoch_results = []
val_epoch_targets = [i for i in range(2, 4)]
@HOOKS.register_module()
class TestEpochHook(Hook):
@ -864,9 +867,13 @@ class TestRunner(TestCase):
iter_results.append(runner.iter)
batch_idx_results.append(batch_idx)
def before_val_epoch(self, runner):
val_epoch_results.append(runner.epoch)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train2'
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
cfg.val_cfg = dict(begin=2)
runner = Runner.from_cfg(cfg)
runner.train()
@ -879,13 +886,20 @@ class TestRunner(TestCase):
self.assertEqual(result, target)
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
for result, target, in zip(val_epoch_results, val_epoch_targets):
self.assertEqual(result, target)
# 3. test iter and epoch counter of IterBasedTrainLoop
# 3. test iter and epoch counter of IterBasedTrainLoop and timing of
# running ValLoop
epoch_results = []
iter_results = []
batch_idx_results = []
val_iter_results = []
val_batch_idx_results = []
iter_targets = [i for i in range(12)]
batch_idx_targets = [i for i in range(12)]
val_iter_targets = [i for i in range(4, 12)]
val_batch_idx_targets = [i for i in range(4)] * 2
@HOOKS.register_module()
class TestIterHook(Hook):
@ -897,10 +911,14 @@ class TestRunner(TestCase):
iter_results.append(runner.iter)
batch_idx_results.append(batch_idx)
def before_val_iter(self, runner, batch_idx, data_batch=None):
val_epoch_results.append(runner.iter)
val_batch_idx_results.append(batch_idx)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train3'
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
cfg.val_cfg = dict(interval=4)
cfg.val_cfg = dict(interval=4, begin=4)
runner = Runner.from_cfg(cfg)
runner.train()
@ -912,6 +930,11 @@ class TestRunner(TestCase):
self.assertEqual(result, target)
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
for result, target, in zip(val_iter_results, val_iter_targets):
self.assertEqual(result, target)
for result, target, in zip(val_batch_idx_results,
val_batch_idx_targets):
self.assertEqual(result, target)
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)