[Enhance] Support dynamic interval (#342)

* support dynamic interval in iterbasedtrainloop

* update typehint

* update typehint

* add dynamic interval in epochbasedtrainloop

* update

* fix

Co-authored-by: luochunhua.vendor <luochunhua@pjlab.org.cn>
This commit is contained in:
Cedric Luo 2022-06-30 15:08:56 +08:00 committed by GitHub
parent d65350a9da
commit 9c55b4300c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 13 deletions

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import time
import warnings
from typing import Dict, List, Sequence, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from torch.utils.data import DataLoader
@ -11,6 +12,7 @@ from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@LOOPS.register_module()
@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop):
val_begin (int): The epoch that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1) -> None:
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_epochs: int,
val_begin: int = 1,
val_interval: int = 1,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader)
@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop):
'metainfo. ``dataset_meta`` in visualizer will be '
'None.')
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop):
while self._epoch < self._max_epochs:
self.run_epoch()
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0):
@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop):
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
self.val_interval = self.dynamic_intervals[step - 1]
class _InfiniteDataloaderIterator:
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop.
@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop):
val_begin (int): The iteration that begins validating.
Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000) -> None:
def __init__(
self,
runner,
dataloader: Union[DataLoader, Dict],
max_iters: int,
val_begin: int = 1,
val_interval: int = 1000,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader)
self._max_iters = max_iters
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop):
# get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property
def max_epochs(self):
"""int: Total epochs to train model."""
@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop):
data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch)
self._decide_current_val_interval()
if (self.runner.val_loop is not None
and self._iter >= self.val_begin
and self._iter % self.val_interval == 0):
@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop):
outputs=outputs)
self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
self.val_interval = self.dynamic_intervals[step - 1]
@LOOPS.register_module()
class ValLoop(BaseLoop):

35
mmengine/runner/utils.py Normal file
View File

@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from mmengine.utils.misc import is_list_of
def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Calculate dynamic intervals.

    Args:
        start_interval (int): The interval used in the beginning.
        dynamic_interval_list (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. The tuples may be given in any
            order; they are sorted by milestone here. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: The milestones, starting with 0 and
        followed by the given milestones in ascending order, and the
        interval paired with each of them (``start_interval`` is paired
        with milestone 0).
    """
    if dynamic_interval_list is None:
        return [0], [start_interval]

    assert is_list_of(dynamic_interval_list, tuple)

    # The train loops locate the active milestone with ``bisect``, which
    # only gives a correct answer on an ascending list, so order the pairs
    # by milestone. ``sorted`` is stable: pairs that share a milestone keep
    # the order in which they were given.
    ordered_pairs = sorted(dynamic_interval_list, key=lambda pair: pair[0])

    dynamic_milestones = [0]
    dynamic_milestones.extend([pair[0] for pair in ordered_pairs])
    dynamic_intervals = [start_interval]
    dynamic_intervals.extend([pair[1] for pair in ordered_pairs])
    return dynamic_milestones, dynamic_intervals

View File

@ -1283,6 +1283,80 @@ class TestRunner(TestCase):
val_batch_idx_targets):
self.assertEqual(result, target)
# 5. test dynamic interval in IterBasedTrainLoop
max_iters = 12
interval = 5
dynamic_intervals = [(11, 2)]
iter_results = []
iter_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestIterDynamicIntervalHook(Hook):
def before_val(self, runner):
iter_results.append(runner.iter)
def before_train_iter(self, runner, batch_idx, data_batch=None):
val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train5'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestIterDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=False,
max_iters=max_iters,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
# 6. test dynamic interval in EpochBasedTrainLoop
max_epochs = 12
interval = 5
dynamic_intervals = [(11, 2)]
epoch_results = []
epoch_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestEpochDynamicIntervalHook(Hook):
def before_val_epoch(self, runner):
epoch_results.append(runner.epoch)
def before_train_epoch(self, runner):
val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train6'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestEpochDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=True,
max_epochs=max_epochs,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'