mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Support validation only after some epoch/iteration in ValLoop (#257)
* add the epoch/iter that begins validating * fix lint * add property and fix unit test * minor changes * fix typos and add unit test * add unit test about begin * fix docstring
This commit is contained in:
parent
08a3adb5d7
commit
40daf46a45
@ -66,6 +66,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||
self.run_epoch()
|
||||
|
||||
if (self.runner.val_loop is not None
|
||||
and self._epoch >= self.runner.val_loop.begin
|
||||
and self._epoch % self.runner.val_loop.interval == 0):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
@ -162,6 +163,7 @@ class IterBasedTrainLoop(BaseLoop):
|
||||
self.run_iter(data_batch)
|
||||
|
||||
if (self.runner.val_loop is not None
|
||||
and self._iter >= self.runner.val_begin
|
||||
and self._iter % self.runner.val_interval == 0):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
@ -197,13 +199,15 @@ class ValLoop(BaseLoop):
|
||||
build a dataloader.
|
||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||
interval (int): Validation interval. Defaults to 1.
|
||||
begin (int): The epoch/iteration that begins validating. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runner,
|
||||
dataloader: Union[DataLoader, Dict],
|
||||
evaluator: Union[Evaluator, Dict, List],
|
||||
interval: int = 1) -> None:
|
||||
interval: int = 1,
|
||||
begin: int = 1) -> None:
|
||||
super().__init__(runner, dataloader)
|
||||
|
||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||
@ -220,6 +224,7 @@ class ValLoop(BaseLoop):
|
||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||
'visualizer will be None.')
|
||||
self.interval = interval
|
||||
self.begin = begin
|
||||
|
||||
def run(self):
|
||||
"""Launch validation."""
|
||||
|
@ -561,6 +561,12 @@ class Runner:
|
||||
"""int: Interval to run validation during training."""
|
||||
return self.val_loop.interval
|
||||
|
||||
@property
|
||||
def val_begin(self):
|
||||
"""int: The epoch/iteration to start running validation during
|
||||
training."""
|
||||
return self.val_loop.begin
|
||||
|
||||
def setup_env(self, env_cfg: Dict) -> None:
|
||||
"""Setup environment.
|
||||
|
||||
|
@ -214,7 +214,7 @@ class TestRunner(TestCase):
|
||||
val_evaluator=dict(type='ToyMetric1'),
|
||||
test_evaluator=dict(type='ToyMetric1'),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||
val_cfg=dict(interval=1),
|
||||
val_cfg=dict(interval=1, begin=1),
|
||||
test_cfg=dict(),
|
||||
custom_hooks=[],
|
||||
default_hooks=dict(
|
||||
@ -403,7 +403,7 @@ class TestRunner(TestCase):
|
||||
train_dataloader=train_dataloader,
|
||||
optimizer=optimizer,
|
||||
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
|
||||
val_cfg=dict(interval=1),
|
||||
val_cfg=dict(interval=1, begin=1),
|
||||
val_dataloader=val_dataloader,
|
||||
val_evaluator=ToyMetric1(),
|
||||
test_cfg=dict(),
|
||||
@ -762,12 +762,12 @@ class TestRunner(TestCase):
|
||||
runner.build_test_loop('invalid-type')
|
||||
|
||||
# input is a dict and contains type key
|
||||
cfg = dict(type='ValLoop', interval=1)
|
||||
cfg = dict(type='ValLoop', interval=1, begin=1)
|
||||
loop = runner.build_test_loop(cfg)
|
||||
self.assertIsInstance(loop, ValLoop)
|
||||
|
||||
# input is a dict but does not contain type key
|
||||
cfg = dict(interval=1)
|
||||
cfg = dict(interval=1, begin=1)
|
||||
loop = runner.build_val_loop(cfg)
|
||||
self.assertIsInstance(loop, ValLoop)
|
||||
|
||||
@ -846,13 +846,16 @@ class TestRunner(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
||||
runner.train()
|
||||
|
||||
# 2. test iter and epoch counter of EpochBasedTrainLoop
|
||||
# 2. test iter and epoch counter of EpochBasedTrainLoop and timing of
|
||||
# running ValLoop
|
||||
epoch_results = []
|
||||
epoch_targets = [i for i in range(3)]
|
||||
iter_results = []
|
||||
iter_targets = [i for i in range(4 * 3)]
|
||||
batch_idx_results = []
|
||||
batch_idx_targets = [i for i in range(4)] * 3 # train and val
|
||||
val_epoch_results = []
|
||||
val_epoch_targets = [i for i in range(2, 4)]
|
||||
|
||||
@HOOKS.register_module()
|
||||
class TestEpochHook(Hook):
|
||||
@ -864,9 +867,13 @@ class TestRunner(TestCase):
|
||||
iter_results.append(runner.iter)
|
||||
batch_idx_results.append(batch_idx)
|
||||
|
||||
def before_val_epoch(self, runner):
|
||||
val_epoch_results.append(runner.epoch)
|
||||
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_train2'
|
||||
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||
cfg.val_cfg = dict(begin=2)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
runner.train()
|
||||
@ -879,13 +886,20 @@ class TestRunner(TestCase):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(val_epoch_results, val_epoch_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
# 3. test iter and epoch counter of IterBasedTrainLoop
|
||||
# 3. test iter and epoch counter of IterBasedTrainLoop and timing of
|
||||
# running ValLoop
|
||||
epoch_results = []
|
||||
iter_results = []
|
||||
batch_idx_results = []
|
||||
val_iter_results = []
|
||||
val_batch_idx_results = []
|
||||
iter_targets = [i for i in range(12)]
|
||||
batch_idx_targets = [i for i in range(12)]
|
||||
val_iter_targets = [i for i in range(4, 12)]
|
||||
val_batch_idx_targets = [i for i in range(4)] * 2
|
||||
|
||||
@HOOKS.register_module()
|
||||
class TestIterHook(Hook):
|
||||
@ -897,10 +911,14 @@ class TestRunner(TestCase):
|
||||
iter_results.append(runner.iter)
|
||||
batch_idx_results.append(batch_idx)
|
||||
|
||||
def before_val_iter(self, runner, batch_idx, data_batch=None):
|
||||
val_epoch_results.append(runner.iter)
|
||||
val_batch_idx_results.append(batch_idx)
|
||||
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train3'
|
||||
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||
cfg.val_cfg = dict(interval=4)
|
||||
cfg.val_cfg = dict(interval=4, begin=4)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
@ -912,6 +930,11 @@ class TestRunner(TestCase):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(val_iter_results, val_iter_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(val_batch_idx_results,
|
||||
val_batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
def test_val(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user