[Enhance] Support infinite dataloader iterator wrapper for IterBasedTrainLoop. (#289)
parent
5016332588
commit
1c18f30854
|
@ -1,4 +1,5 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Sequence, Union
|
from typing import Dict, List, Sequence, Union
|
||||||
|
|
||||||
|
@ -113,6 +114,51 @@ class EpochBasedTrainLoop(BaseLoop):
|
||||||
self._iter += 1
|
self._iter += 1
|
||||||
|
|
||||||
|
|
||||||
|
class _InfiniteDataloaderIterator:
    """Wrap a ``DataLoader`` so that iteration never stops.

    Whenever the wrapped iterator is exhausted, a fresh one is created and
    iteration continues from the beginning of the dataset.  This restart is
    costly because dataloader workers are re-spawned each time, so
    ``mmengine.data.InfiniteSampler`` remains the recommended way to obtain
    an endless stream of batches; this wrapper is the fallback for
    dataloaders built with a finite sampler.
    """

    def __init__(self, dataloader: DataLoader) -> None:
        self._dataloader = dataloader
        self._iterator = iter(self._dataloader)
        # Number of completed passes over the dataset; forwarded to the
        # sampler (when it supports ``set_epoch``) so shuffling differs
        # between passes.
        self._epoch = 0

    def __iter__(self):
        return self

    def __next__(self) -> Sequence[dict]:
        try:
            return next(self._iterator)
        except StopIteration:
            warnings.warn('Reach the end of the dataloader, it will be '
                          'restarted and continue to iterate. It is '
                          'recommended to use '
                          '`mmengine.data.InfiniteSampler` to enable the '
                          'dataloader to iterate infinitely.')
            self._epoch += 1
            self._sync_epoch_to_sampler()
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self._iterator = iter(self._dataloader)
            return next(self._iterator)

    def _sync_epoch_to_sampler(self) -> None:
        """Forward the new epoch count to the dataloader's sampler.

        A ``_SingleProcessDataLoaderIter`` may expose no sampler at all, and
        PyTorch's ``SequentialSampler`` has no ``set_epoch``; in those cases
        this is a no-op.  When the dataloader was built from a batch sampler,
        the underlying sampler is wrapped as ``batch_sampler.sampler``.
        """
        sampler = getattr(self._dataloader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(self._epoch)
        elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
                self._dataloader.batch_sampler.sampler, 'set_epoch'):
            self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
|
||||||
|
|
||||||
|
|
||||||
@LOOPS.register_module()
|
@LOOPS.register_module()
|
||||||
class IterBasedTrainLoop(BaseLoop):
|
class IterBasedTrainLoop(BaseLoop):
|
||||||
"""Loop for iter-based training.
|
"""Loop for iter-based training.
|
||||||
|
@ -149,7 +195,7 @@ class IterBasedTrainLoop(BaseLoop):
|
||||||
'metainfo. ``dataset_meta`` in visualizer will be '
|
'metainfo. ``dataset_meta`` in visualizer will be '
|
||||||
'None.')
|
'None.')
|
||||||
# get the iterator of the dataloader
|
# get the iterator of the dataloader
|
||||||
self.dataloader_iterator = iter(self.dataloader)
|
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_epochs(self):
|
def max_epochs(self):
|
||||||
|
|
|
@ -29,6 +29,7 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
|
||||||
RUNNERS, Registry)
|
RUNNERS, Registry)
|
||||||
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
||||||
Runner, TestLoop, ValLoop)
|
Runner, TestLoop, ValLoop)
|
||||||
|
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
||||||
from mmengine.runner.priority import Priority, get_priority
|
from mmengine.runner.priority import Priority, get_priority
|
||||||
from mmengine.utils import is_list_of
|
from mmengine.utils import is_list_of
|
||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
|
@ -1202,6 +1203,50 @@ class TestRunner(TestCase):
|
||||||
val_batch_idx_targets):
|
val_batch_idx_targets):
|
||||||
self.assertEqual(result, target)
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
|
# 4. test iter and epoch counter of IterBasedTrainLoop and timing of
|
||||||
|
# running ValLoop without InfiniteSampler
|
||||||
|
epoch_results = []
|
||||||
|
iter_results = []
|
||||||
|
batch_idx_results = []
|
||||||
|
val_iter_results = []
|
||||||
|
val_batch_idx_results = []
|
||||||
|
iter_targets = [i for i in range(12)]
|
||||||
|
batch_idx_targets = [i for i in range(12)]
|
||||||
|
val_iter_targets = [i for i in range(4, 12)]
|
||||||
|
val_batch_idx_targets = [i for i in range(4)] * 2
|
||||||
|
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train4'
|
||||||
|
cfg.train_dataloader.sampler = dict(
|
||||||
|
type='DefaultSampler', shuffle=True)
|
||||||
|
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
|
||||||
|
cfg.train_cfg = dict(
|
||||||
|
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
with self.assertWarnsRegex(
|
||||||
|
Warning,
|
||||||
|
'Reach the end of the dataloader, it will be restarted and '
|
||||||
|
'continue to iterate.'):
|
||||||
|
runner.train()
|
||||||
|
|
||||||
|
assert isinstance(runner.train_loop, IterBasedTrainLoop)
|
||||||
|
assert isinstance(runner.train_loop.dataloader_iterator,
|
||||||
|
_InfiniteDataloaderIterator)
|
||||||
|
|
||||||
|
self.assertEqual(len(epoch_results), 1)
|
||||||
|
self.assertEqual(epoch_results[0], 0)
|
||||||
|
self.assertEqual(runner.val_interval, 4)
|
||||||
|
self.assertEqual(runner.val_begin, 4)
|
||||||
|
for result, target, in zip(iter_results, iter_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(batch_idx_results, batch_idx_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_iter_results, val_iter_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
for result, target, in zip(val_batch_idx_results,
|
||||||
|
val_batch_idx_targets):
|
||||||
|
self.assertEqual(result, target)
|
||||||
|
|
||||||
def test_val(self):
|
def test_val(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_val1'
|
cfg.experiment_name = 'test_val1'
|
||||||
|
|
Loading…
Reference in New Issue