[Enhance] Support dynamic interval (#342)

* support dynamic interval in iterbasedtrainloop

* update typehint

* update typehint

* add dynamic interval in epochbasedtrainloop

* update

* fix

Co-authored-by: luochunhua.vendor <luochunhua@pjlab.org.cn>
This commit is contained in:
Cedric Luo 2022-06-30 15:08:56 +08:00 committed by GitHub
parent d65350a9da
commit 9c55b4300c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 13 deletions

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import bisect
import time import time
import warnings import warnings
from typing import Dict, List, Sequence, Union from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -11,6 +12,7 @@ from mmengine.registry import LOOPS
from mmengine.utils import is_list_of from mmengine.utils import is_list_of
from .amp import autocast from .amp import autocast
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
@LOOPS.register_module() @LOOPS.register_module()
@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop):
val_begin (int): The epoch that begins validating. val_begin (int): The epoch that begins validating.
Defaults to 1. Defaults to 1.
val_interval (int): Validation interval. Defaults to 1. val_interval (int): Validation interval. Defaults to 1.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
""" """
def __init__(self, def __init__(
runner, self,
dataloader: Union[DataLoader, Dict], runner,
max_epochs: int, dataloader: Union[DataLoader, Dict],
val_begin: int = 1, max_epochs: int,
val_interval: int = 1) -> None: val_begin: int = 1,
val_interval: int = 1,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
self._max_epochs = max_epochs self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader) self._max_iters = max_epochs * len(self.dataloader)
@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop):
'metainfo. ``dataset_meta`` in visualizer will be ' 'metainfo. ``dataset_meta`` in visualizer will be '
'None.') 'None.')
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property @property
def max_epochs(self): def max_epochs(self):
"""int: Total epochs to train model.""" """int: Total epochs to train model."""
@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop):
while self._epoch < self._max_epochs: while self._epoch < self._max_epochs:
self.run_epoch() self.run_epoch()
self._decide_current_val_interval()
if (self.runner.val_loop is not None if (self.runner.val_loop is not None
and self._epoch >= self.val_begin and self._epoch >= self.val_begin
and self._epoch % self.val_interval == 0): and self._epoch % self.val_interval == 0):
@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop):
outputs=outputs) outputs=outputs)
self._iter += 1 self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
self.val_interval = self.dynamic_intervals[step - 1]
class _InfiniteDataloaderIterator: class _InfiniteDataloaderIterator:
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop. """An infinite dataloader iterator wrapper for IterBasedTrainLoop.
@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop):
val_begin (int): The iteration that begins validating. val_begin (int): The iteration that begins validating.
Defaults to 1. Defaults to 1.
val_interval (int): Validation interval. Defaults to 1000. val_interval (int): Validation interval. Defaults to 1000.
dynamic_intervals (List[Tuple[int, int]], optional): The
first element in the tuple is a milestone and the second
element is an interval. The interval is used after the
corresponding milestone. Defaults to None.
""" """
def __init__(self, def __init__(
runner, self,
dataloader: Union[DataLoader, Dict], runner,
max_iters: int, dataloader: Union[DataLoader, Dict],
val_begin: int = 1, max_iters: int,
val_interval: int = 1000) -> None: val_begin: int = 1,
val_interval: int = 1000,
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
self._max_iters = max_iters self._max_iters = max_iters
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop):
# get the iterator of the dataloader # get the iterator of the dataloader
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader) self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
self.dynamic_milestones, self.dynamic_intervals = \
calc_dynamic_intervals(
self.val_interval, dynamic_intervals)
@property @property
def max_epochs(self): def max_epochs(self):
"""int: Total epochs to train model.""" """int: Total epochs to train model."""
@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop):
data_batch = next(self.dataloader_iterator) data_batch = next(self.dataloader_iterator)
self.run_iter(data_batch) self.run_iter(data_batch)
self._decide_current_val_interval()
if (self.runner.val_loop is not None if (self.runner.val_loop is not None
and self._iter >= self.val_begin and self._iter >= self.val_begin
and self._iter % self.val_interval == 0): and self._iter % self.val_interval == 0):
@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop):
outputs=outputs) outputs=outputs)
self._iter += 1 self._iter += 1
def _decide_current_val_interval(self) -> None:
"""Dynamically modify the ``val_interval``."""
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
self.val_interval = self.dynamic_intervals[step - 1]
@LOOPS.register_module() @LOOPS.register_module()
class ValLoop(BaseLoop): class ValLoop(BaseLoop):

35
mmengine/runner/utils.py Normal file
View File

@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple
from mmengine.utils.misc import is_list_of
def calc_dynamic_intervals(
    start_interval: int,
    dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
) -> Tuple[List[int], List[int]]:
    """Calculate dynamic intervals.

    Args:
        start_interval (int): The interval used in the beginning.
        dynamic_interval_list (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.

    Returns:
        Tuple[List[int], List[int]]: a list of milestones and their
        corresponding intervals. The first milestone is always 0 with
        ``start_interval`` as its interval, and milestones are returned
        in ascending order so that ``bisect`` lookups are correct.
    """
    if dynamic_interval_list is None:
        return [0], [start_interval]

    # Equivalent to ``is_list_of(dynamic_interval_list, tuple)``: a list
    # whose elements are all tuples.
    assert isinstance(dynamic_interval_list, list) and all(
        isinstance(pair, tuple) for pair in dynamic_interval_list)

    # Sort by milestone: callers look intervals up with ``bisect``, which
    # requires the milestone list to be ascending. Sorting makes unsorted
    # input behave correctly instead of silently returning wrong
    # intervals; already-sorted input is unaffected.
    sorted_pairs = sorted(dynamic_interval_list)
    dynamic_milestones = [0]
    dynamic_milestones.extend(
        milestone for milestone, _ in sorted_pairs)
    dynamic_intervals = [start_interval]
    dynamic_intervals.extend(
        interval for _, interval in sorted_pairs)
    return dynamic_milestones, dynamic_intervals

View File

@ -1283,6 +1283,80 @@ class TestRunner(TestCase):
val_batch_idx_targets): val_batch_idx_targets):
self.assertEqual(result, target) self.assertEqual(result, target)
# 5. test dynamic interval in IterBasedTrainLoop
max_iters = 12
interval = 5
dynamic_intervals = [(11, 2)]
iter_results = []
iter_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestIterDynamicIntervalHook(Hook):
def before_val(self, runner):
iter_results.append(runner.iter)
def before_train_iter(self, runner, batch_idx, data_batch=None):
val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train5'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestIterDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=False,
max_iters=max_iters,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
# 6. test dynamic interval in EpochBasedTrainLoop
max_epochs = 12
interval = 5
dynamic_intervals = [(11, 2)]
epoch_results = []
epoch_targets = [5, 10, 12]
val_interval_results = []
val_interval_targets = [5] * 10 + [2] * 2
@HOOKS.register_module()
class TestEpochDynamicIntervalHook(Hook):
def before_val_epoch(self, runner):
epoch_results.append(runner.epoch)
def before_train_epoch(self, runner):
val_interval_results.append(runner.train_loop.val_interval)
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_train6'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [
dict(type='TestEpochDynamicIntervalHook', priority=50)
]
cfg.train_cfg = dict(
by_epoch=True,
max_epochs=max_epochs,
val_interval=interval,
dynamic_intervals=dynamic_intervals)
runner = Runner.from_cfg(cfg)
runner.train()
for result, target, in zip(epoch_results, epoch_targets):
self.assertEqual(result, target)
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
def test_val(self): def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1' cfg.experiment_name = 'test_val1'