Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Enhance] Support dynamic interval (#342)
* support dynamic interval in iterbasedtrainloop
* update typehint
* update typehint
* add dynamic interval in epochbasedtrainloop
* update
* fix

Co-authored-by: luochunhua.vendor <luochunhua@pjlab.org.cn>
This commit is contained in:
parent d65350a9da
commit 9c55b4300c
@@ -1,7 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import bisect
 import time
 import warnings
-from typing import Dict, List, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from torch.utils.data import DataLoader
@@ -11,6 +12,7 @@ from mmengine.registry import LOOPS
 from mmengine.utils import is_list_of
 from .amp import autocast
 from .base_loop import BaseLoop
+from .utils import calc_dynamic_intervals


 @LOOPS.register_module()
@@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop):
         val_begin (int): The epoch that begins validating.
             Defaults to 1.
         val_interval (int): Validation interval. Defaults to 1.
+        dynamic_intervals (List[Tuple[int, int]], optional): The
+            first element in the tuple is a milestone and the second
+            element is an interval. The interval is used after the
+            corresponding milestone. Defaults to None.
     """

-    def __init__(self,
-                 runner,
-                 dataloader: Union[DataLoader, Dict],
-                 max_epochs: int,
-                 val_begin: int = 1,
-                 val_interval: int = 1) -> None:
+    def __init__(
+            self,
+            runner,
+            dataloader: Union[DataLoader, Dict],
+            max_epochs: int,
+            val_begin: int = 1,
+            val_interval: int = 1,
+            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
         super().__init__(runner, dataloader)
         self._max_epochs = max_epochs
         self._max_iters = max_epochs * len(self.dataloader)
@@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop):
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')

+        self.dynamic_milestones, self.dynamic_intervals = \
+            calc_dynamic_intervals(
+                self.val_interval, dynamic_intervals)
+
     @property
     def max_epochs(self):
         """int: Total epochs to train model."""
@@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop):
         while self._epoch < self._max_epochs:
             self.run_epoch()

+            self._decide_current_val_interval()
             if (self.runner.val_loop is not None
                     and self._epoch >= self.val_begin
                     and self._epoch % self.val_interval == 0):
@@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop):
                 outputs=outputs)
             self._iter += 1

+    def _decide_current_val_interval(self) -> None:
+        """Dynamically modify the ``val_interval``."""
+        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
+        self.val_interval = self.dynamic_intervals[step - 1]
+

 class _InfiniteDataloaderIterator:
     """An infinite dataloader iterator wrapper for IterBasedTrainLoop.
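To make the milestone logic above concrete, here is a small standalone sketch (an illustration, not code from this commit) of the lookup that _decide_current_val_interval performs: bisect.bisect counts how many milestones have already been passed, and the interval registered for the most recent milestone becomes the active val_interval.

import bisect

# Values mirroring the tests in this commit: start with interval 5 and
# switch to interval 2 once milestone 11 is reached.
dynamic_milestones = [0, 11]   # as produced by calc_dynamic_intervals(5, [(11, 2)])
dynamic_intervals = [5, 2]

for epoch in range(1, 13):
    # bisect.bisect returns how many milestones are <= the upcoming epoch,
    # so step - 1 indexes the interval attached to the last passed milestone.
    step = bisect.bisect(dynamic_milestones, epoch)
    val_interval = dynamic_intervals[step - 1]
    print(epoch, val_interval)   # epochs 1-10 print 5, epochs 11-12 print 2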
@@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop):
         val_begin (int): The iteration that begins validating.
             Defaults to 1.
         val_interval (int): Validation interval. Defaults to 1000.
+        dynamic_intervals (List[Tuple[int, int]], optional): The
+            first element in the tuple is a milestone and the second
+            element is an interval. The interval is used after the
+            corresponding milestone. Defaults to None.
     """

-    def __init__(self,
-                 runner,
-                 dataloader: Union[DataLoader, Dict],
-                 max_iters: int,
-                 val_begin: int = 1,
-                 val_interval: int = 1000) -> None:
+    def __init__(
+            self,
+            runner,
+            dataloader: Union[DataLoader, Dict],
+            max_iters: int,
+            val_begin: int = 1,
+            val_interval: int = 1000,
+            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
         super().__init__(runner, dataloader)
         self._max_iters = max_iters
         self._max_epochs = 1  # for compatibility with EpochBasedTrainLoop
@@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop):
         # get the iterator of the dataloader
         self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

+        self.dynamic_milestones, self.dynamic_intervals = \
+            calc_dynamic_intervals(
+                self.val_interval, dynamic_intervals)
+
     @property
     def max_epochs(self):
         """int: Total epochs to train model."""
@@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop):
             data_batch = next(self.dataloader_iterator)
             self.run_iter(data_batch)

+            self._decide_current_val_interval()
             if (self.runner.val_loop is not None
                     and self._iter >= self.val_begin
                     and self._iter % self.val_interval == 0):
@@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop):
                 outputs=outputs)
             self._iter += 1

+    def _decide_current_val_interval(self) -> None:
+        """Dynamically modify the ``val_interval``."""
+        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
+        self.val_interval = self.dynamic_intervals[step - 1]
+

 @LOOPS.register_module()
 class ValLoop(BaseLoop):
mmengine/runner/utils.py (new file, 35 lines added)
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+from mmengine.utils.misc import is_list_of
+
+
+def calc_dynamic_intervals(
+        start_interval: int,
+        dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
+) -> Tuple[List[int], List[int]]:
+    """Calculate dynamic intervals.
+
+    Args:
+        start_interval (int): The interval used in the beginning.
+        dynamic_interval_list (List[Tuple[int, int]], optional): The
+            first element in the tuple is a milestone and the second
+            element is an interval. The interval is used after the
+            corresponding milestone. Defaults to None.
+
+    Returns:
+        Tuple[List[int], List[int]]: a list of milestones and their
+            corresponding intervals.
+    """
+    if dynamic_interval_list is None:
+        return [0], [start_interval]
+
+    assert is_list_of(dynamic_interval_list, tuple)
+
+    dynamic_milestones = [0]
+    dynamic_milestones.extend(
+        [dynamic_interval[0] for dynamic_interval in dynamic_interval_list])
+    dynamic_intervals = [start_interval]
+    dynamic_intervals.extend(
+        [dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
+    return dynamic_milestones, dynamic_intervals
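For illustration (not part of the diff), the helper prepends the starting interval at milestone 0 and splits the user-provided (milestone, interval) pairs into two parallel lists that the loops above can index via bisect. A usage sketch, assuming the module path mmengine.runner.utils of the new file:

from mmengine.runner.utils import calc_dynamic_intervals

# Start validating every 5 epochs/iterations, then every 2 after milestone 11.
milestones, intervals = calc_dynamic_intervals(5, [(11, 2)])
print(milestones)  # [0, 11]
print(intervals)   # [5, 2]

# Without a dynamic list, the starting interval applies throughout.
print(calc_dynamic_intervals(1000))  # ([0], [1000])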
@@ -1283,6 +1283,80 @@ class TestRunner(TestCase):
                 val_batch_idx_targets):
             self.assertEqual(result, target)

+        # 5. test dynamic interval in IterBasedTrainLoop
+        max_iters = 12
+        interval = 5
+        dynamic_intervals = [(11, 2)]
+        iter_results = []
+        iter_targets = [5, 10, 12]
+        val_interval_results = []
+        val_interval_targets = [5] * 10 + [2] * 2
+
+        @HOOKS.register_module()
+        class TestIterDynamicIntervalHook(Hook):
+
+            def before_val(self, runner):
+                iter_results.append(runner.iter)
+
+            def before_train_iter(self, runner, batch_idx, data_batch=None):
+                val_interval_results.append(runner.train_loop.val_interval)
+
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train5'
+        cfg.train_dataloader.sampler = dict(
+            type='DefaultSampler', shuffle=True)
+        cfg.custom_hooks = [
+            dict(type='TestIterDynamicIntervalHook', priority=50)
+        ]
+        cfg.train_cfg = dict(
+            by_epoch=False,
+            max_iters=max_iters,
+            val_interval=interval,
+            dynamic_intervals=dynamic_intervals)
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+        for result, target, in zip(iter_results, iter_targets):
+            self.assertEqual(result, target)
+        for result, target, in zip(val_interval_results, val_interval_targets):
+            self.assertEqual(result, target)
+
+        # 6. test dynamic interval in EpochBasedTrainLoop
+        max_epochs = 12
+        interval = 5
+        dynamic_intervals = [(11, 2)]
+        epoch_results = []
+        epoch_targets = [5, 10, 12]
+        val_interval_results = []
+        val_interval_targets = [5] * 10 + [2] * 2
+
+        @HOOKS.register_module()
+        class TestEpochDynamicIntervalHook(Hook):
+
+            def before_val_epoch(self, runner):
+                epoch_results.append(runner.epoch)
+
+            def before_train_epoch(self, runner):
+                val_interval_results.append(runner.train_loop.val_interval)
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train6'
+        cfg.train_dataloader.sampler = dict(
+            type='DefaultSampler', shuffle=True)
+        cfg.custom_hooks = [
+            dict(type='TestEpochDynamicIntervalHook', priority=50)
+        ]
+        cfg.train_cfg = dict(
+            by_epoch=True,
+            max_epochs=max_epochs,
+            val_interval=interval,
+            dynamic_intervals=dynamic_intervals)
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+        for result, target, in zip(epoch_results, epoch_targets):
+            self.assertEqual(result, target)
+        for result, target, in zip(val_interval_results, val_interval_targets):
+            self.assertEqual(result, target)
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'
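Taken together, the tests suggest the user-facing usage: pass dynamic_intervals alongside val_interval in train_cfg, and the train loop validates at the initial interval until a milestone is reached, then switches to the paired interval. A minimal sketch of such a train_cfg (only the fields exercised by the test; the rest of the runner config is assumed):

# Sketch of a train_cfg using dynamic intervals; values mirror the iter-based test.
train_cfg = dict(
    by_epoch=False,                   # selects IterBasedTrainLoop
    max_iters=12,
    val_interval=5,                   # validate every 5 iterations at first
    dynamic_intervals=[(11, 2)])      # after iteration 11, validate every 2 iterations

# With these settings the test expects validation at iterations 5, 10 and 12.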