mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor]: Modify val_interval and val_begin to be the attributes of TrainLoop. (#274)
* Modify val_interval and val_begin to be the attributes of TrainLoop.
* update doc
* fix lint
* type hint
This commit is contained in:
parent
13606040ac
commit
70c4ea191f
@ -152,8 +152,13 @@ val_evaluator = dict(type='Accuracy')
|
||||
test_evaluator = dict(type='Accuracy')
|
||||
|
||||
# 训练、验证、测试流程配置
|
||||
train_cfg = dict(by_epoch=True, max_epochs=100)
|
||||
val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
|
||||
train_cfg = dict(
|
||||
by_epoch=True,
|
||||
max_epochs=100,
|
||||
val_begin=20, # 从第 20 个 epoch 开始验证
|
||||
val_interval=1 # 每隔一个 epoch 进行一次验证
|
||||
)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# 自定义钩子
|
||||
|
@ -20,15 +20,24 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||
build a dataloader.
|
||||
max_epochs (int): Total training epochs.
|
||||
val_begin (int): The epoch that begins validating.
|
||||
Defaults to 1.
|
||||
val_interval (int): Validation interval. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||
max_epochs: int) -> None:
|
||||
def __init__(self,
|
||||
runner,
|
||||
dataloader: Union[DataLoader, Dict],
|
||||
max_epochs: int,
|
||||
val_begin: int = 1,
|
||||
val_interval: int = 1) -> None:
|
||||
super().__init__(runner, dataloader)
|
||||
self._max_epochs = max_epochs
|
||||
self._max_iters = max_epochs * len(self.dataloader)
|
||||
self._epoch = 0
|
||||
self._iter = 0
|
||||
self.val_begin = val_begin
|
||||
self.val_interval = val_interval
|
||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
@ -66,8 +75,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||
self.run_epoch()
|
||||
|
||||
if (self.runner.val_loop is not None
|
||||
and self._epoch >= self.runner.val_loop.begin
|
||||
and self._epoch % self.runner.val_loop.interval == 0):
|
||||
and self._epoch >= self.val_begin
|
||||
and self._epoch % self.val_interval == 0):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
self.runner.call_hook('after_train')
|
||||
@ -111,15 +120,24 @@ class IterBasedTrainLoop(BaseLoop):
|
||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||
build a dataloader.
|
||||
max_iters (int): Total training iterations.
|
||||
val_begin (int): The iteration that begins validating.
|
||||
Defaults to 1.
|
||||
val_interval (int): Validation interval. Defaults to 1000.
|
||||
"""
|
||||
|
||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||
max_iters: int) -> None:
|
||||
def __init__(self,
|
||||
runner,
|
||||
dataloader: Union[DataLoader, Dict],
|
||||
max_iters: int,
|
||||
val_begin: int = 1,
|
||||
val_interval: int = 1000) -> None:
|
||||
super().__init__(runner, dataloader)
|
||||
self._max_iters = max_iters
|
||||
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
||||
self._epoch = 0
|
||||
self._iter = 0
|
||||
self.val_begin = val_begin
|
||||
self.val_interval = val_interval
|
||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||
self.runner.visualizer.dataset_meta = \
|
||||
self.dataloader.dataset.metainfo
|
||||
@ -163,8 +181,8 @@ class IterBasedTrainLoop(BaseLoop):
|
||||
self.run_iter(data_batch)
|
||||
|
||||
if (self.runner.val_loop is not None
|
||||
and self._iter >= self.runner.val_begin
|
||||
and self._iter % self.runner.val_interval == 0):
|
||||
and self._iter >= self.val_begin
|
||||
and self._iter % self.val_interval == 0):
|
||||
self.runner.val_loop.run()
|
||||
|
||||
self.runner.call_hook('after_train_epoch')
|
||||
@ -198,16 +216,10 @@ class ValLoop(BaseLoop):
|
||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||
build a dataloader.
|
||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||
interval (int): Validation interval. Defaults to 1.
|
||||
begin (int): The epoch/iteration that begins validating. Defaults to 1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
runner,
|
||||
dataloader: Union[DataLoader, Dict],
|
||||
evaluator: Union[Evaluator, Dict, List],
|
||||
interval: int = 1,
|
||||
begin: int = 1) -> None:
|
||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||
evaluator: Union[Evaluator, Dict, List]) -> None:
|
||||
super().__init__(runner, dataloader)
|
||||
|
||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||
@ -223,8 +235,6 @@ class ValLoop(BaseLoop):
|
||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||
'visualizer will be None.')
|
||||
self.interval = interval
|
||||
self.begin = begin
|
||||
|
||||
def run(self):
|
||||
"""Launch validation."""
|
||||
|
@ -189,8 +189,8 @@ class Runner:
|
||||
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||
>>> val_evaluator=dict(type='ToyEvaluator'),
|
||||
>>> test_evaluator=dict(type='ToyEvaluator'),
|
||||
>>> train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||
>>> val_cfg=dict(interval=1),
|
||||
>>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
|
||||
>>> val_cfg=dict(),
|
||||
>>> test_cfg=dict(),
|
||||
>>> custom_hooks=[],
|
||||
>>> default_hooks=dict(
|
||||
@ -573,13 +573,13 @@ class Runner:
|
||||
@property
|
||||
def val_interval(self):
|
||||
"""int: Interval to run validation during training."""
|
||||
return self.val_loop.interval
|
||||
return self.train_loop.val_interval
|
||||
|
||||
@property
|
||||
def val_begin(self):
|
||||
"""int: The epoch/iteration to start running validation during
|
||||
training."""
|
||||
return self.val_loop.begin
|
||||
return self.train_loop.val_begin
|
||||
|
||||
def setup_env(self, env_cfg: Dict) -> None:
|
||||
"""Setup environment.
|
||||
@ -1285,10 +1285,10 @@ class Runner:
|
||||
Examples of ``loop``:
|
||||
|
||||
# `ValLoop` will be used
|
||||
loop = dict(interval=1)
|
||||
loop = dict()
|
||||
|
||||
# custom validation loop
|
||||
loop = dict(type='CustomValLoop', interval=1)
|
||||
loop = dict(type='CustomValLoop')
|
||||
|
||||
Args:
|
||||
loop (BaseLoop or dict): A validation loop or a dict to build
|
||||
@ -1317,9 +1317,7 @@ class Runner:
|
||||
loop = ValLoop(
|
||||
runner=self,
|
||||
dataloader=self._val_dataloader,
|
||||
evaluator=self._val_evaluator, # type: ignore
|
||||
**loop_cfg,
|
||||
) # type: ignore
|
||||
evaluator=self._val_evaluator) # type: ignore
|
||||
|
||||
return loop # type: ignore
|
||||
|
||||
|
@ -82,8 +82,8 @@ class TestEMAHook(TestCase):
|
||||
work_dir=self.temp_dir.name,
|
||||
optim_wrapper=OptimWrapper(
|
||||
torch.optim.Adam(ToyModel().parameters())),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2),
|
||||
val_cfg=dict(interval=1),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||
val_cfg=dict(),
|
||||
default_hooks=dict(logger=None),
|
||||
custom_hooks=[dict(type='EMAHook', )],
|
||||
experiment_name='test1')
|
||||
|
@ -176,7 +176,7 @@ class CustomTrainLoop(BaseLoop):
|
||||
@LOOPS.register_module()
|
||||
class CustomValLoop(BaseLoop):
|
||||
|
||||
def __init__(self, runner, dataloader, evaluator, interval=1):
|
||||
def __init__(self, runner, dataloader, evaluator):
|
||||
super().__init__(runner, dataloader)
|
||||
self._runner = runner
|
||||
|
||||
@ -246,8 +246,9 @@ class TestRunner(TestCase):
|
||||
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||
val_evaluator=dict(type='ToyMetric1'),
|
||||
test_evaluator=dict(type='ToyMetric1'),
|
||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||
val_cfg=dict(interval=1, begin=1),
|
||||
train_cfg=dict(
|
||||
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
|
||||
val_cfg=dict(),
|
||||
test_cfg=dict(),
|
||||
custom_hooks=[],
|
||||
default_hooks=dict(
|
||||
@ -432,11 +433,12 @@ class TestRunner(TestCase):
|
||||
runner = Runner(
|
||||
model=model,
|
||||
work_dir=self.temp_dir,
|
||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||
train_cfg=dict(
|
||||
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
|
||||
train_dataloader=train_dataloader,
|
||||
optim_wrapper=optim_wrapper,
|
||||
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
||||
val_cfg=dict(interval=1, begin=1),
|
||||
val_cfg=dict(),
|
||||
val_dataloader=val_dataloader,
|
||||
val_evaluator=ToyMetric1(),
|
||||
test_cfg=dict(),
|
||||
@ -888,12 +890,12 @@ class TestRunner(TestCase):
|
||||
runner.build_test_loop('invalid-type')
|
||||
|
||||
# input is a dict and contains type key
|
||||
cfg = dict(type='ValLoop', interval=1, begin=1)
|
||||
cfg = dict(type='ValLoop')
|
||||
loop = runner.build_test_loop(cfg)
|
||||
self.assertIsInstance(loop, ValLoop)
|
||||
|
||||
# input is a dict but does not contain type key
|
||||
cfg = dict(interval=1, begin=1)
|
||||
cfg = dict()
|
||||
loop = runner.build_val_loop(cfg)
|
||||
self.assertIsInstance(loop, ValLoop)
|
||||
|
||||
@ -901,7 +903,7 @@ class TestRunner(TestCase):
|
||||
self.assertEqual(id(runner.build_val_loop(loop)), id(loop))
|
||||
|
||||
# test custom validation loop
|
||||
cfg = dict(type='CustomValLoop', interval=1)
|
||||
cfg = dict(type='CustomValLoop')
|
||||
loop = runner.build_val_loop(cfg)
|
||||
self.assertIsInstance(loop, CustomValLoop)
|
||||
|
||||
@ -999,7 +1001,7 @@ class TestRunner(TestCase):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_train2'
|
||||
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||
cfg.val_cfg = dict(begin=2)
|
||||
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
runner.train()
|
||||
@ -1044,7 +1046,8 @@ class TestRunner(TestCase):
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train3'
|
||||
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||
cfg.val_cfg = dict(interval=4, begin=4)
|
||||
cfg.train_cfg = dict(
|
||||
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
@ -1052,6 +1055,8 @@ class TestRunner(TestCase):
|
||||
|
||||
self.assertEqual(len(epoch_results), 1)
|
||||
self.assertEqual(epoch_results[0], 0)
|
||||
self.assertEqual(runner.val_interval, 4)
|
||||
self.assertEqual(runner.val_begin, 4)
|
||||
for result, target, in zip(iter_results, iter_targets):
|
||||
self.assertEqual(result, target)
|
||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||
|
Loading…
x
Reference in New Issue
Block a user