mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Support dynamic interval (#342)
* support dynamic interval in iterbasedtrainloop * update typehint * update typehint * add dynamic interval in epochbasedtrainloop * update * fix Co-authored-by: luochunhua.vendor <luochunhua@pjlab.org.cn>
This commit is contained in:
parent
d65350a9da
commit
9c55b4300c
@ -1,7 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import bisect
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Sequence, Union
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@ -11,6 +12,7 @@ from mmengine.registry import LOOPS
|
|||||||
from mmengine.utils import is_list_of
|
from mmengine.utils import is_list_of
|
||||||
from .amp import autocast
|
from .amp import autocast
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
|
from .utils import calc_dynamic_intervals
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
@ -25,14 +27,20 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
val_begin (int): The epoch that begins validating.
|
val_begin (int): The epoch that begins validating.
|
||||||
Defaults to 1.
|
Defaults to 1.
|
||||||
val_interval (int): Validation interval. Defaults to 1.
|
val_interval (int): Validation interval. Defaults to 1.
|
||||||
|
dynamic_intervals (List[Tuple[int, int]], optional): The
|
||||||
|
first element in the tuple is a milestone and the second
|
||||||
|
element is a interval. The interval is used after the
|
||||||
|
corresponding milestone. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
runner,
|
self,
|
||||||
dataloader: Union[DataLoader, Dict],
|
runner,
|
||||||
max_epochs: int,
|
dataloader: Union[DataLoader, Dict],
|
||||||
val_begin: int = 1,
|
max_epochs: int,
|
||||||
val_interval: int = 1) -> None:
|
val_begin: int = 1,
|
||||||
|
val_interval: int = 1,
|
||||||
|
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_epochs = max_epochs
|
self._max_epochs = max_epochs
|
||||||
self._max_iters = max_epochs * len(self.dataloader)
|
self._max_iters = max_epochs * len(self.dataloader)
|
||||||
@ -49,6 +57,10 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
'metainfo. ``dataset_meta`` in visualizer will be '
|
'metainfo. ``dataset_meta`` in visualizer will be '
|
||||||
'None.')
|
'None.')
|
||||||
|
|
||||||
|
self.dynamic_milestones, self.dynamic_intervals = \
|
||||||
|
calc_dynamic_intervals(
|
||||||
|
self.val_interval, dynamic_intervals)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_epochs(self):
|
def max_epochs(self):
|
||||||
"""int: Total epochs to train model."""
|
"""int: Total epochs to train model."""
|
||||||
@ -76,6 +88,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
while self._epoch < self._max_epochs:
|
while self._epoch < self._max_epochs:
|
||||||
self.run_epoch()
|
self.run_epoch()
|
||||||
|
|
||||||
|
self._decide_current_val_interval()
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
and self._epoch >= self.val_begin
|
and self._epoch >= self.val_begin
|
||||||
and self._epoch % self.val_interval == 0):
|
and self._epoch % self.val_interval == 0):
|
||||||
@ -114,6 +127,11 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
outputs=outputs)
|
outputs=outputs)
|
||||||
self._iter += 1
|
self._iter += 1
|
||||||
|
|
||||||
|
def _decide_current_val_interval(self) -> None:
|
||||||
|
"""Dynamically modify the ``val_interval``."""
|
||||||
|
step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
|
||||||
|
self.val_interval = self.dynamic_intervals[step - 1]
|
||||||
|
|
||||||
|
|
||||||
class _InfiniteDataloaderIterator:
|
class _InfiniteDataloaderIterator:
|
||||||
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop.
|
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop.
|
||||||
@ -172,14 +190,20 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
val_begin (int): The iteration that begins validating.
|
val_begin (int): The iteration that begins validating.
|
||||||
Defaults to 1.
|
Defaults to 1.
|
||||||
val_interval (int): Validation interval. Defaults to 1000.
|
val_interval (int): Validation interval. Defaults to 1000.
|
||||||
|
dynamic_intervals (List[Tuple[int, int]], optional): The
|
||||||
|
first element in the tuple is a milestone and the second
|
||||||
|
element is a interval. The interval is used after the
|
||||||
|
corresponding milestone. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
runner,
|
self,
|
||||||
dataloader: Union[DataLoader, Dict],
|
runner,
|
||||||
max_iters: int,
|
dataloader: Union[DataLoader, Dict],
|
||||||
val_begin: int = 1,
|
max_iters: int,
|
||||||
val_interval: int = 1000) -> None:
|
val_begin: int = 1,
|
||||||
|
val_interval: int = 1000,
|
||||||
|
dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
self._max_iters = max_iters
|
self._max_iters = max_iters
|
||||||
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
self._max_epochs = 1 # for compatibility with EpochBasedTrainLoop
|
||||||
@ -198,6 +222,10 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
# get the iterator of the dataloader
|
# get the iterator of the dataloader
|
||||||
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
|
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
|
||||||
|
|
||||||
|
self.dynamic_milestones, self.dynamic_intervals = \
|
||||||
|
calc_dynamic_intervals(
|
||||||
|
self.val_interval, dynamic_intervals)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_epochs(self):
|
def max_epochs(self):
|
||||||
"""int: Total epochs to train model."""
|
"""int: Total epochs to train model."""
|
||||||
@ -230,6 +258,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
data_batch = next(self.dataloader_iterator)
|
data_batch = next(self.dataloader_iterator)
|
||||||
self.run_iter(data_batch)
|
self.run_iter(data_batch)
|
||||||
|
|
||||||
|
self._decide_current_val_interval()
|
||||||
if (self.runner.val_loop is not None
|
if (self.runner.val_loop is not None
|
||||||
and self._iter >= self.val_begin
|
and self._iter >= self.val_begin
|
||||||
and self._iter % self.val_interval == 0):
|
and self._iter % self.val_interval == 0):
|
||||||
@ -260,6 +289,11 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
outputs=outputs)
|
outputs=outputs)
|
||||||
self._iter += 1
|
self._iter += 1
|
||||||
|
|
||||||
|
def _decide_current_val_interval(self) -> None:
|
||||||
|
"""Dynamically modify the ``val_interval``."""
|
||||||
|
step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
|
||||||
|
self.val_interval = self.dynamic_intervals[step - 1]
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
class ValLoop(BaseLoop):
|
class ValLoop(BaseLoop):
|
||||||
|
35
mmengine/runner/utils.py
Normal file
35
mmengine/runner/utils.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from mmengine.utils.misc import is_list_of
|
||||||
|
|
||||||
|
|
||||||
|
def calc_dynamic_intervals(
|
||||||
|
start_interval: int,
|
||||||
|
dynamic_interval_list: Optional[List[Tuple[int, int]]] = None
|
||||||
|
) -> Tuple[List[int], List[int]]:
|
||||||
|
"""Calculate dynamic intervals.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_interval (int): The interval used in the beginning.
|
||||||
|
dynamic_interval_list (List[Tuple[int, int]], optional): The
|
||||||
|
first element in the tuple is a milestone and the second
|
||||||
|
element is a interval. The interval is used after the
|
||||||
|
corresponding milestone. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[int], List[int]]: a list of milestone and its corresponding
|
||||||
|
intervals.
|
||||||
|
"""
|
||||||
|
if dynamic_interval_list is None:
|
||||||
|
return [0], [start_interval]
|
||||||
|
|
||||||
|
assert is_list_of(dynamic_interval_list, tuple)
|
||||||
|
|
||||||
|
dynamic_milestones = [0]
|
||||||
|
dynamic_milestones.extend(
|
||||||
|
[dynamic_interval[0] for dynamic_interval in dynamic_interval_list])
|
||||||
|
dynamic_intervals = [start_interval]
|
||||||
|
dynamic_intervals.extend(
|
||||||
|
[dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
|
||||||
|
return dynamic_milestones, dynamic_intervals
|
@ -1283,6 +1283,80 @@ class TestRunner(TestCase):
|
|||||||
val_batch_idx_targets):
|
val_batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
|
# 5. test dynamic interval in IterBasedTrainLoop
|
||||||
|
max_iters = 12
|
||||||
|
interval = 5
|
||||||
|
dynamic_intervals = [(11, 2)]
|
||||||
|
iter_results = []
|
||||||
|
iter_targets = [5, 10, 12]
|
||||||
|
val_interval_results = []
|
||||||
|
val_interval_targets = [5] * 10 + [2] * 2
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class TestIterDynamicIntervalHook(Hook):
|
||||||
|
|
||||||
|
def before_val(self, runner):
|
||||||
|
iter_results.append(runner.iter)
|
||||||
|
|
||||||
|
def before_train_iter(self, runner, batch_idx, data_batch=None):
|
||||||
|
val_interval_results.append(runner.train_loop.val_interval)
|
||||||
|
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train5'
|
||||||
|
cfg.train_dataloader.sampler = dict(
|
||||||
|
type='DefaultSampler', shuffle=True)
|
||||||
|
cfg.custom_hooks = [
|
||||||
|
dict(type='TestIterDynamicIntervalHook', priority=50)
|
||||||
|
]
|
||||||
|
cfg.train_cfg = dict(
|
||||||
|
by_epoch=False,
|
||||||
|
max_iters=max_iters,
|
||||||
|
val_interval=interval,
|
||||||
|
dynamic_intervals=dynamic_intervals)
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.train()
|
||||||
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_interval_results, val_interval_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
|
# 6. test dynamic interval in EpochBasedTrainLoop
|
||||||
|
max_epochs = 12
|
||||||
|
interval = 5
|
||||||
|
dynamic_intervals = [(11, 2)]
|
||||||
|
epoch_results = []
|
||||||
|
epoch_targets = [5, 10, 12]
|
||||||
|
val_interval_results = []
|
||||||
|
val_interval_targets = [5] * 10 + [2] * 2
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class TestEpochDynamicIntervalHook(Hook):
|
||||||
|
|
||||||
|
def before_val_epoch(self, runner):
|
||||||
|
epoch_results.append(runner.epoch)
|
||||||
|
|
||||||
|
def before_train_epoch(self, runner):
|
||||||
|
val_interval_results.append(runner.train_loop.val_interval)
|
||||||
|
|
||||||
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train6'
|
||||||
|
cfg.train_dataloader.sampler = dict(
|
||||||
|
type='DefaultSampler', shuffle=True)
|
||||||
|
cfg.custom_hooks = [
|
||||||
|
dict(type='TestEpochDynamicIntervalHook', priority=50)
|
||||||
|
]
|
||||||
|
cfg.train_cfg = dict(
|
||||||
|
by_epoch=True,
|
||||||
|
max_epochs=max_epochs,
|
||||||
|
val_interval=interval,
|
||||||
|
dynamic_intervals=dynamic_intervals)
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.train()
|
||||||
|
for result, target, in zip(epoch_results, epoch_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_interval_results, val_interval_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_val1'
|
cfg.experiment_name = 'test_val1'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user