[Refactor]: Modify val_interval and val_begin to be the attributes of TrainLoop. (#274)

* Modify val_interval and val_begin to be the attributes of TrainLoop.

* update doc

* fix lint

* type hint
This commit is contained in:
RangiLyu 2022-06-06 11:13:25 +08:00 committed by GitHub
parent 13606040ac
commit 70c4ea191f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 41 deletions

View File

@@ -152,8 +152,13 @@ val_evaluator = dict(type='Accuracy')
test_evaluator = dict(type='Accuracy')
# 训练、验证、测试流程配置
train_cfg = dict(by_epoch=True, max_epochs=100)
val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
train_cfg = dict(
by_epoch=True,
max_epochs=100,
val_begin=20, # 从第 20 个 epoch 开始验证
val_interval=1 # 每隔一个 epoch 进行一次验证
)
val_cfg = dict()
test_cfg = dict()
# 自定义钩子

View File

@@ -20,15 +20,24 @@ class EpochBasedTrainLoop(BaseLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_epochs (int): Total training epochs.
val_begin (int): The epoch that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
max_epochs: int) -> None:
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1) -> None:
super().__init__(runner, dataloader)
self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader)
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
@@ -66,8 +75,8 @@ class EpochBasedTrainLoop(BaseLoop):
self.run_epoch()
if (self.runner.val_loop is not None
and self._epoch >= self.runner.val_loop.begin
and self._epoch % self.runner.val_loop.interval == 0):
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
@@ -111,15 +120,24 @@ class IterBasedTrainLoop(BaseLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
max_iters (int): Total training iterations.
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
max_iters: int) -> None:
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000) -> None:
super().__init__(runner, dataloader)
self._max_iters = max_iters
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
self._epoch = 0
self._iter = 0
self.val_begin = val_begin
self.val_interval = val_interval
if hasattr(self.dataloader.dataset, 'metainfo'):
self.runner.visualizer.dataset_meta = \
self.dataloader.dataset.metainfo
@@ -163,8 +181,8 @@ class IterBasedTrainLoop(BaseLoop):
self.run_iter(data_batch)
if (self.runner.val_loop is not None
and self._iter >= self.runner.val_begin
and self._iter % self.runner.val_interval == 0):
and self._iter >= self.val_begin
and self._iter % self.val_interval == 0):
self.runner.val_loop.run()
self.runner.call_hook('after_train_epoch')
@@ -198,16 +216,10 @@ class ValLoop(BaseLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
interval (int): Validation interval. Defaults to 1.
begin (int): The epoch/iteration that begins validating. Defaults to 1.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
interval: int = 1,
begin: int = 1) -> None:
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List]) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@@ -223,8 +235,6 @@
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
self.interval = interval
self.begin = begin
def run(self):
"""Launch validation."""

View File

@@ -189,8 +189,8 @@ class Runner:
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
>>> val_evaluator=dict(type='ToyEvaluator'),
>>> test_evaluator=dict(type='ToyEvaluator'),
>>> train_cfg=dict(by_epoch=True, max_epochs=3),
>>> val_cfg=dict(interval=1),
>>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
>>> val_cfg=dict(),
>>> test_cfg=dict(),
>>> custom_hooks=[],
>>> default_hooks=dict(
@@ -573,13 +573,13 @@ class Runner:
@property
def val_interval(self):
"""int: Interval to run validation during training."""
return self.val_loop.interval
return self.train_loop.val_interval
@property
def val_begin(self):
"""int: The epoch/iteration to start running validation during
training."""
return self.val_loop.begin
return self.train_loop.val_begin
def setup_env(self, env_cfg: Dict) -> None:
"""Setup environment.
@@ -1285,10 +1285,10 @@ class Runner:
Examples of ``loop``:
# `ValLoop` will be used
loop = dict(interval=1)
loop = dict()
# custom validation loop
loop = dict(type='CustomValLoop', interval=1)
loop = dict(type='CustomValLoop')
Args:
loop (BaseLoop or dict): A validation loop or a dict to build
@@ -1317,9 +1317,7 @@
loop = ValLoop(
runner=self,
dataloader=self._val_dataloader,
evaluator=self._val_evaluator, # type: ignore
**loop_cfg,
) # type: ignore
evaluator=self._val_evaluator) # type: ignore
return loop # type: ignore

View File

@@ -82,8 +82,8 @@ class TestEMAHook(TestCase):
work_dir=self.temp_dir.name,
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(by_epoch=True, max_epochs=2),
val_cfg=dict(interval=1),
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
val_cfg=dict(),
default_hooks=dict(logger=None),
custom_hooks=[dict(type='EMAHook', )],
experiment_name='test1')

View File

@@ -176,7 +176,7 @@ class CustomTrainLoop(BaseLoop):
@LOOPS.register_module()
class CustomValLoop(BaseLoop):
def __init__(self, runner, dataloader, evaluator, interval=1):
def __init__(self, runner, dataloader, evaluator):
super().__init__(runner, dataloader)
self._runner = runner
@@ -246,8 +246,9 @@ class TestRunner(TestCase):
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyMetric1'),
test_evaluator=dict(type='ToyMetric1'),
train_cfg=dict(by_epoch=True, max_epochs=3),
val_cfg=dict(interval=1, begin=1),
train_cfg=dict(
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
val_cfg=dict(),
test_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
@@ -432,11 +433,12 @@ class TestRunner(TestCase):
runner = Runner(
model=model,
work_dir=self.temp_dir,
train_cfg=dict(by_epoch=True, max_epochs=3),
train_cfg=dict(
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
train_dataloader=train_dataloader,
optim_wrapper=optim_wrapper,
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
val_cfg=dict(interval=1, begin=1),
val_cfg=dict(),
val_dataloader=val_dataloader,
val_evaluator=ToyMetric1(),
test_cfg=dict(),
@@ -888,12 +890,12 @@ class TestRunner(TestCase):
runner.build_test_loop('invalid-type')
# input is a dict and contains type key
cfg = dict(type='ValLoop', interval=1, begin=1)
cfg = dict(type='ValLoop')
loop = runner.build_test_loop(cfg)
self.assertIsInstance(loop, ValLoop)
# input is a dict but does not contain type key
cfg = dict(interval=1, begin=1)
cfg = dict()
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, ValLoop)
@@ -901,7 +903,7 @@ class TestRunner(TestCase):
self.assertEqual(id(runner.build_val_loop(loop)), id(loop))
# test custom validation loop
cfg = dict(type='CustomValLoop', interval=1)
cfg = dict(type='CustomValLoop')
loop = runner.build_val_loop(cfg)
self.assertIsInstance(loop, CustomValLoop)
@@ -999,7 +1001,7 @@ class TestRunner(TestCase):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train2'
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
cfg.val_cfg = dict(begin=2)
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
runner = Runner.from_cfg(cfg)
runner.train()
@@ -1044,7 +1046,8 @@ class TestRunner(TestCase):
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train3'
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
cfg.val_cfg = dict(interval=4, begin=4)
cfg.train_cfg = dict(
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
runner = Runner.from_cfg(cfg)
runner.train()
@@ -1052,6 +1055,8 @@ class TestRunner(TestCase):
self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0)
self.assertEqual(runner.val_interval, 4)
self.assertEqual(runner.val_begin, 4)
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(batch_idx_results, batch_idx_targets):