mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor]: Modify val_interval and val_begin to be the attributes of TrainLoop. (#274)
* Modify val_interval and val_begin to be the attributes of TrainLoop.
* update doc
* fix lint
* type hint
This commit is contained in:
parent
13606040ac
commit
70c4ea191f
@ -152,8 +152,13 @@ val_evaluator = dict(type='Accuracy')
|
|||||||
test_evaluator = dict(type='Accuracy')
|
test_evaluator = dict(type='Accuracy')
|
||||||
|
|
||||||
# 训练、验证、测试流程配置
|
# 训练、验证、测试流程配置
|
||||||
train_cfg = dict(by_epoch=True, max_epochs=100)
|
train_cfg = dict(
|
||||||
val_cfg = dict(interval=1) # 每隔一个 epoch 进行一次验证
|
by_epoch=True,
|
||||||
|
max_epochs=100,
|
||||||
|
val_begin=20, # 从第 20 个 epoch 开始验证
|
||||||
|
val_interval=1 # 每隔一个 epoch 进行一次验证
|
||||||
|
)
|
||||||
|
val_cfg = dict()
|
||||||
test_cfg = dict()
|
test_cfg = dict()
|
||||||
|
|
||||||
# 自定义钩子
|
# 自定义钩子
|
||||||
|
@ -20,15 +20,24 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
max_epochs (int): Total training epochs.
|
max_epochs (int): Total training epochs.
|
||||||
|
val_begin (int): The epoch that begins validating.
|
||||||
|
Defaults to 1.
|
||||||
|
val_interval (int): Validation interval. Defaults to 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self,
|
||||||
max_epochs: int) -> None:
|
runner,
|
||||||
|
dataloader: Union[DataLoader, Dict],
|
||||||
|
max_epochs: int,
|
||||||
|
val_begin: int = 1,
|
||||||
|
val_interval: int = 1) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_epochs = max_epochs
|
self._max_epochs = max_epochs
|
||||||
self._max_iters = max_epochs * len(self.dataloader)
|
self._max_iters = max_epochs * len(self.dataloader)
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
self._iter = 0
|
self._iter = 0
|
||||||
|
self.val_begin = val_begin
|
||||||
|
self.val_interval = val_interval
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -66,8 +75,8 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self.run_epoch()
|
self.run_epoch()
|
||||||
|
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
and self._epoch >= self.runner.val_loop.begin
|
and self._epoch >= self.val_begin
|
||||||
and self._epoch % self.runner.val_loop.interval == 0):
|
and self._epoch % self.val_interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
self.runner.call_hook('after_train')
|
self.runner.call_hook('after_train')
|
||||||
@ -111,15 +120,24 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
max_iters (int): Total training iterations.
|
max_iters (int): Total training iterations.
|
||||||
|
val_begin (int): The iteration that begins validating.
|
||||||
|
Defaults to 1.
|
||||||
|
val_interval (int): Validation interval. Defaults to 1000.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self,
|
||||||
max_iters: int) -> None:
|
runner,
|
||||||
|
dataloader: Union[DataLoader, Dict],
|
||||||
|
max_iters: int,
|
||||||
|
val_begin: int = 1,
|
||||||
|
val_interval: int = 1000) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_iters = max_iters
|
self._max_iters = max_iters
|
||||||
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
self._iter = 0
|
self._iter = 0
|
||||||
|
self.val_begin = val_begin
|
||||||
|
self.val_interval = val_interval
|
||||||
if hasattr(self.dataloader.dataset, 'metainfo'):
|
if hasattr(self.dataloader.dataset, 'metainfo'):
|
||||||
self.runner.visualizer.dataset_meta = \
|
self.runner.visualizer.dataset_meta = \
|
||||||
self.dataloader.dataset.metainfo
|
self.dataloader.dataset.metainfo
|
||||||
@ -163,8 +181,8 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
self.run_iter(data_batch)
|
self.run_iter(data_batch)
|
||||||
|
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
and self._iter >= self.runner.val_begin
|
and self._iter >= self.val_begin
|
||||||
and self._iter % self.runner.val_interval == 0):
|
and self._iter % self.val_interval == 0):
|
||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
@ -198,16 +216,10 @@ class ValLoop(BaseLoop):
|
|||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||||
interval (int): Validation interval. Defaults to 1.
|
|
||||||
begin (int): The epoch/iteration that begins validating. Defaults to 1.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
||||||
runner,
|
evaluator: Union[Evaluator, Dict, List]) -> None:
|
||||||
dataloader: Union[DataLoader, Dict],
|
|
||||||
evaluator: Union[Evaluator, Dict, List],
|
|
||||||
interval: int = 1,
|
|
||||||
begin: int = 1) -> None:
|
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
|
|
||||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||||
@ -223,8 +235,6 @@ class ValLoop(BaseLoop):
|
|||||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
self.interval = interval
|
|
||||||
self.begin = begin
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Launch validation."""
|
"""Launch validation."""
|
||||||
|
@ -189,8 +189,8 @@ class Runner:
|
|||||||
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||||
>>> val_evaluator=dict(type='ToyEvaluator'),
|
>>> val_evaluator=dict(type='ToyEvaluator'),
|
||||||
>>> test_evaluator=dict(type='ToyEvaluator'),
|
>>> test_evaluator=dict(type='ToyEvaluator'),
|
||||||
>>> train_cfg=dict(by_epoch=True, max_epochs=3),
|
>>> train_cfg=dict(by_epoch=True, max_epochs=3, val_interval=1),
|
||||||
>>> val_cfg=dict(interval=1),
|
>>> val_cfg=dict(),
|
||||||
>>> test_cfg=dict(),
|
>>> test_cfg=dict(),
|
||||||
>>> custom_hooks=[],
|
>>> custom_hooks=[],
|
||||||
>>> default_hooks=dict(
|
>>> default_hooks=dict(
|
||||||
@ -573,13 +573,13 @@ class Runner:
|
|||||||
@property
|
@property
|
||||||
def val_interval(self):
|
def val_interval(self):
|
||||||
"""int: Interval to run validation during training."""
|
"""int: Interval to run validation during training."""
|
||||||
return self.val_loop.interval
|
return self.train_loop.val_interval
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def val_begin(self):
|
def val_begin(self):
|
||||||
"""int: The epoch/iteration to start running validation during
|
"""int: The epoch/iteration to start running validation during
|
||||||
training."""
|
training."""
|
||||||
return self.val_loop.begin
|
return self.train_loop.val_begin
|
||||||
|
|
||||||
def setup_env(self, env_cfg: Dict) -> None:
|
def setup_env(self, env_cfg: Dict) -> None:
|
||||||
"""Setup environment.
|
"""Setup environment.
|
||||||
@ -1285,10 +1285,10 @@ class Runner:
|
|||||||
Examples of ``loop``:
|
Examples of ``loop``:
|
||||||
|
|
||||||
# `ValLoop` will be used
|
# `ValLoop` will be used
|
||||||
loop = dict(interval=1)
|
loop = dict()
|
||||||
|
|
||||||
# custom validation loop
|
# custom validation loop
|
||||||
loop = dict(type='CustomValLoop', interval=1)
|
loop = dict(type='CustomValLoop')
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loop (BaseLoop or dict): A validation loop or a dict to build
|
loop (BaseLoop or dict): A validation loop or a dict to build
|
||||||
@ -1317,9 +1317,7 @@ class Runner:
|
|||||||
loop = ValLoop(
|
loop = ValLoop(
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self._val_dataloader,
|
dataloader=self._val_dataloader,
|
||||||
evaluator=self._val_evaluator, # type: ignore
|
evaluator=self._val_evaluator) # type: ignore
|
||||||
**loop_cfg,
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
return loop # type: ignore
|
return loop # type: ignore
|
||||||
|
|
||||||
|
@ -82,8 +82,8 @@ class TestEMAHook(TestCase):
|
|||||||
work_dir=self.temp_dir.name,
|
work_dir=self.temp_dir.name,
|
||||||
optim_wrapper=OptimWrapper(
|
optim_wrapper=OptimWrapper(
|
||||||
torch.optim.Adam(ToyModel().parameters())),
|
torch.optim.Adam(ToyModel().parameters())),
|
||||||
train_cfg=dict(by_epoch=True, max_epochs=2),
|
train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
|
||||||
val_cfg=dict(interval=1),
|
val_cfg=dict(),
|
||||||
default_hooks=dict(logger=None),
|
default_hooks=dict(logger=None),
|
||||||
custom_hooks=[dict(type='EMAHook', )],
|
custom_hooks=[dict(type='EMAHook', )],
|
||||||
experiment_name='test1')
|
experiment_name='test1')
|
||||||
|
@ -176,7 +176,7 @@ class CustomTrainLoop(BaseLoop):
|
|||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
class CustomValLoop(BaseLoop):
|
class CustomValLoop(BaseLoop):
|
||||||
|
|
||||||
def __init__(self, runner, dataloader, evaluator, interval=1):
|
def __init__(self, runner, dataloader, evaluator):
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._runner = runner
|
self._runner = runner
|
||||||
|
|
||||||
@ -246,8 +246,9 @@ class TestRunner(TestCase):
|
|||||||
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||||
val_evaluator=dict(type='ToyMetric1'),
|
val_evaluator=dict(type='ToyMetric1'),
|
||||||
test_evaluator=dict(type='ToyMetric1'),
|
test_evaluator=dict(type='ToyMetric1'),
|
||||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
train_cfg=dict(
|
||||||
val_cfg=dict(interval=1, begin=1),
|
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
|
||||||
|
val_cfg=dict(),
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
custom_hooks=[],
|
custom_hooks=[],
|
||||||
default_hooks=dict(
|
default_hooks=dict(
|
||||||
@ -432,11 +433,12 @@ class TestRunner(TestCase):
|
|||||||
runner = Runner(
|
runner = Runner(
|
||||||
model=model,
|
model=model,
|
||||||
work_dir=self.temp_dir,
|
work_dir=self.temp_dir,
|
||||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
train_cfg=dict(
|
||||||
|
by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
|
||||||
train_dataloader=train_dataloader,
|
train_dataloader=train_dataloader,
|
||||||
optim_wrapper=optim_wrapper,
|
optim_wrapper=optim_wrapper,
|
||||||
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]),
|
||||||
val_cfg=dict(interval=1, begin=1),
|
val_cfg=dict(),
|
||||||
val_dataloader=val_dataloader,
|
val_dataloader=val_dataloader,
|
||||||
val_evaluator=ToyMetric1(),
|
val_evaluator=ToyMetric1(),
|
||||||
test_cfg=dict(),
|
test_cfg=dict(),
|
||||||
@ -888,12 +890,12 @@ class TestRunner(TestCase):
|
|||||||
runner.build_test_loop('invalid-type')
|
runner.build_test_loop('invalid-type')
|
||||||
|
|
||||||
# input is a dict and contains type key
|
# input is a dict and contains type key
|
||||||
cfg = dict(type='ValLoop', interval=1, begin=1)
|
cfg = dict(type='ValLoop')
|
||||||
loop = runner.build_test_loop(cfg)
|
loop = runner.build_test_loop(cfg)
|
||||||
self.assertIsInstance(loop, ValLoop)
|
self.assertIsInstance(loop, ValLoop)
|
||||||
|
|
||||||
# input is a dict but does not contain type key
|
# input is a dict but does not contain type key
|
||||||
cfg = dict(interval=1, begin=1)
|
cfg = dict()
|
||||||
loop = runner.build_val_loop(cfg)
|
loop = runner.build_val_loop(cfg)
|
||||||
self.assertIsInstance(loop, ValLoop)
|
self.assertIsInstance(loop, ValLoop)
|
||||||
|
|
||||||
@ -901,7 +903,7 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(id(runner.build_val_loop(loop)), id(loop))
|
self.assertEqual(id(runner.build_val_loop(loop)), id(loop))
|
||||||
|
|
||||||
# test custom validation loop
|
# test custom validation loop
|
||||||
cfg = dict(type='CustomValLoop', interval=1)
|
cfg = dict(type='CustomValLoop')
|
||||||
loop = runner.build_val_loop(cfg)
|
loop = runner.build_val_loop(cfg)
|
||||||
self.assertIsInstance(loop, CustomValLoop)
|
self.assertIsInstance(loop, CustomValLoop)
|
||||||
|
|
||||||
@ -999,7 +1001,7 @@ class TestRunner(TestCase):
|
|||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_train2'
|
cfg.experiment_name = 'test_train2'
|
||||||
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
cfg.custom_hooks = [dict(type='TestEpochHook', priority=50)]
|
||||||
cfg.val_cfg = dict(begin=2)
|
cfg.train_cfg = dict(by_epoch=True, max_epochs=3, val_begin=2)
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
runner.train()
|
runner.train()
|
||||||
@ -1044,7 +1046,8 @@ class TestRunner(TestCase):
|
|||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
cfg.experiment_name = 'test_train3'
|
cfg.experiment_name = 'test_train3'
|
||||||
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||||
cfg.val_cfg = dict(interval=4, begin=4)
|
cfg.train_cfg = dict(
|
||||||
|
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
@ -1052,6 +1055,8 @@ class TestRunner(TestCase):
|
|||||||
|
|
||||||
self.assertEqual(len(epoch_results), 1)
|
self.assertEqual(len(epoch_results), 1)
|
||||||
self.assertEqual(epoch_results[0], 0)
|
self.assertEqual(epoch_results[0], 0)
|
||||||
|
self.assertEqual(runner.val_interval, 4)
|
||||||
|
self.assertEqual(runner.val_begin, 4)
|
||||||
for result, target, in zip(iter_results, iter_targets):
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user