mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Support validation only after some epoch/iteration in ValLoop (#257)
* add the epoch/iter that begins validating * fix lint * add property and fix unit test * minor changes * fix typos and add unit test * add unit test about begin * fix docstring
This commit is contained in:
parent
08a3adb5d7
commit
40daf46a45
@ -66,6 +66,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self.run_epoch()
|
self.run_epoch()
|
||||||
|
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
|
and self._epoch >= self.runner.val_loop.begin
|
||||||
and self._epoch % self.runner.val_loop.interval == 0):
|
and self._epoch % self.runner.val_loop.interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
@ -162,6 +163,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
self.run_iter(data_batch)
|
self.run_iter(data_batch)
|
||||||
|
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
|
and self._iter >= self.runner.val_begin
|
||||||
and self._iter % self.runner.val_interval == 0):
|
and self._iter % self.runner.val_interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
@ -197,13 +199,15 @@ class ValLoop(BaseLoop):
|
|||||||
build a dataloader.
|
build a dataloader.
|
||||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||||
interval (int): Validation interval. Defaults to 1.
|
interval (int): Validation interval. Defaults to 1.
|
||||||
|
begin (int): The epoch/iteration that begins validating. Defaults to 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
runner,
|
runner,
|
||||||
dataloader: Union[DataLoader, Dict],
|
dataloader: Union[DataLoader, Dict],
|
||||||
evaluator: Union[Evaluator, Dict, List],
|
evaluator: Union[Evaluator, Dict, List],
|
||||||
interval: int = 1) -> None:
|
interval: int = 1,
|
||||||
|
begin: int = 1) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
|
|
||||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||||
@ -220,6 +224,7 @@ class ValLoop(BaseLoop):
|
|||||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
self.begin = begin
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Launch validation."""
|
"""Launch validation."""
|
||||||
|
@ -561,6 +561,12 @@ class Runner:
|
|||||||
"""int: Interval to run validation during training."""
|
"""int: Interval to run validation during training."""
|
||||||
return self.val_loop.interval
|
return self.val_loop.interval
|
||||||
|
|
||||||
|
@property
|
||||||
|
def val_begin(self):
|
||||||
|
"""int: The epoch/iteration to start running validation during
|
||||||
|
training."""
|
||||||
|
return self.val_loop.begin
|
||||||
|
|
||||||
def setup_env(self, env_cfg: Dict) -> None:
|
def setup_env(self, env_cfg: Dict) -> None:
|
||||||
"""Setup environment.
|
"""Setup environment.
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ class TestRunner(TestCase):
|
|||||||
val_evaluator=dict(type='ToyMetric1'),
|
val_evaluator=dict(type='ToyMetric1'),
|
||||||
test_evaluator=dict(type='ToyMetric1'),
|
test_evaluator=dict(type='ToyMetric1'),
|
||||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||||
val_cfg=dict(interval=1),
|
val_cfg=dict(interval=1, begin=1),
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
custom_hooks=[],
|
custom_hooks=[],
|
||||||
default_hooks=dict(
|
default_hooks=dict(
|
||||||
@ -403,7 +403,7 @@ class TestRunner(TestCase):
|
|||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
|
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
|
||||||
val_cfg=dict(interval=1),
|
val_cfg=dict(interval=1, begin=1),
|
||||||
val_dataloader=val_dataloader,
|
val_dataloader=val_dataloader,
|
||||||
val_evaluator=ToyMetric1(),
|
val_evaluator=ToyMetric1(),
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
@ -762,12 +762,12 @@ class TestRunner(TestCase):
|
|||||||
runner.build_test_loop('invalid-type')
|
runner.build_test_loop('invalid-type')
|
||||||
|
|
||||||
# input is a dict and contains type key
|
# input is a dict and contains type key
|
||||||
cfg = dict(type='ValLoop', interval=1)
|
cfg = dict(type='ValLoop', interval=1, begin=1)
|
||||||
loop = runner.build_test_loop(cfg)
|
loop = runner.build_test_loop(cfg)
|
||||||
self.assertIsInstance(loop, ValLoop)
|
self.assertIsInstance(loop, ValLoop)
|
||||||
|
|
||||||
# input is a dict but does not contain type key
|
# input is a dict but does not contain type key
|
||||||
cfg = dict(interval=1)
|
cfg = dict(interval=1, begin=1)
|
||||||
loop = runner.build_val_loop(cfg)
|
loop = runner.build_val_loop(cfg)
|
||||||
self.assertIsInstance(loop, ValLoop)
|
self.assertIsInstance(loop, ValLoop)
|
||||||
|
|
||||||
@ -846,13 +846,16 @@ class TestRunner(TestCase):
|
|||||||
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
with self.assertRaisesRegex(RuntimeError, 'should not be None'):
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
# 2. test iter and epoch counter of EpochBasedTrainLoop
|
# 2. test iter and epoch counter of EpochBasedTrainLoop and timing of
|
||||||
|
# running ValLoop
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
epoch_targets = [i for i in range(3)]
|
epoch_targets = [i for i in range(3)]
|
||||||
iter_results = []
|
iter_results = []
|
||||||
iter_targets = [i for i in range(4 * 3)]
|
iter_targets = [i for i in range(4 * 3)]
|
||||||
batch_idx_results = []
|
batch_idx_results = []
|
||||||
batch_idx_targets = [i for i in range(4)] * 3 # train and val
|
batch_idx_targets = [i for i in range(4)] * 3 # train and val
|
||||||
|
val_epoch_results = []
|
||||||
|
val_epoch_targets = [i for i in range(2, 4)]
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class TestEpochHook(Hook):
|
class TestEpochHook(Hook):
|
||||||
@ -864,9 +867,13 @@ class TestRunner(TestCase):
|
|||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
batch_idx_results.append(batch_idx)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
|
def before_val_epoch(self, runner):
|
||||||
|
val_epoch_results.append(runner.epoch)
|
||||||
|
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_train2'
|
cfg.experiment_name = 'test_train2'
|
||||||
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||||
|
cfg.val_cfg = dict(begin=2)
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
runner.train()
|
runner.train()
|
||||||
@ -879,13 +886,20 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_epoch_results, val_epoch_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
# 3. test iter and epoch counter of IterBasedTrainLoop
|
# 3. test iter and epoch counter of IterBasedTrainLoop and timing of
|
||||||
|
# running ValLoop
|
||||||
epoch_results = []
|
epoch_results = []
|
||||||
iter_results = []
|
iter_results = []
|
||||||
batch_idx_results = []
|
batch_idx_results = []
|
||||||
|
val_iter_results = []
|
||||||
|
val_batch_idx_results = []
|
||||||
iter_targets = [i for i in range(12)]
|
iter_targets = [i for i in range(12)]
|
||||||
batch_idx_targets = [i for i in range(12)]
|
batch_idx_targets = [i for i in range(12)]
|
||||||
|
val_iter_targets = [i for i in range(4, 12)]
|
||||||
|
val_batch_idx_targets = [i for i in range(4)] * 2
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class TestIterHook(Hook):
|
class TestIterHook(Hook):
|
||||||
@ -897,10 +911,14 @@ class TestRunner(TestCase):
|
|||||||
iter_results.append(runner.iter)
|
iter_results.append(runner.iter)
|
||||||
batch_idx_results.append(batch_idx)
|
batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
|
def before_val_iter(self, runner, batch_idx, data_batch=None):
|
||||||
|
val_epoch_results.append(runner.iter)
|
||||||
|
val_batch_idx_results.append(batch_idx)
|
||||||
|
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
cfg.experiment_name = 'test_train3'
|
cfg.experiment_name = 'test_train3'
|
||||||
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||||
cfg.val_cfg = dict(interval=4)
|
cfg.val_cfg = dict(interval=4, begin=4)
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
@ -912,6 +930,11 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_iter_results, val_iter_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_batch_idx_results,
|
||||||
|
val_batch_idx_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user