[Enhance] Support infinite dataloader iterator wrapper for IterBasedTrainLoop. (#289)

pull/306/head
RangiLyu 2022-06-14 14:52:59 +08:00 committed by GitHub
parent 5016332588
commit 1c18f30854
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 1 deletions

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
import warnings
from typing import Dict, List, Sequence, Union
@@ -113,6 +114,51 @@ class EpochBasedTrainLoop(BaseLoop):
self._iter += 1
class _InfiniteDataloaderIterator:
"""An infinite dataloader iterator wrapper for IterBasedTrainLoop.
It resets the dataloader to continue iterating when the iterator has
iterated over all the data. However, this approach is not efficient, as the
workers need to be restarted every time the dataloader is reset. It is
recommended to use `mmengine.data.InfiniteSampler` to enable the dataloader
to iterate infinitely.
"""
def __init__(self, dataloader: DataLoader) -> None:
self._dataloader = dataloader
self._iterator = iter(self._dataloader)
self._epoch = 0
def __iter__(self):
return self
def __next__(self) -> Sequence[dict]:
try:
data = next(self._iterator)
except StopIteration:
warnings.warn('Reach the end of the dataloader, it will be '
'restarted and continue to iterate. It is '
'recommended to use `mmengine.data.InfiniteSampler` '
'to enable the dataloader to iterate infinitely.')
self._epoch += 1
if hasattr(self._dataloader, 'sampler') and hasattr(
self._dataloader.sampler, 'set_epoch'):
# In case the` _SingleProcessDataLoaderIter` has no sampler,
# or data loader uses `SequentialSampler` in Pytorch.
self._dataloader.sampler.set_epoch(self._epoch)
elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
self._dataloader.batch_sampler.sampler, 'set_epoch'):
# In case the` _SingleProcessDataLoaderIter` has no batch
# sampler. batch sampler in pytorch warps the sampler as its
# attributes.
self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
time.sleep(2) # Prevent possible deadlock during epoch transition
self._iterator = iter(self._dataloader)
data = next(self._iterator)
return data
@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
"""Loop for iter-based training.
@@ -149,7 +195,7 @@ class IterBasedTrainLoop(BaseLoop):
'metainfo. ``dataset_meta`` in visualizer will be '
'None.')
# get the iterator of the dataloader
self.dataloader_iterator = iter(self.dataloader)
self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)
@property
def max_epochs(self):

View File

@@ -29,6 +29,7 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
RUNNERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.loops import _InfiniteDataloaderIterator
from mmengine.runner.priority import Priority, get_priority
from mmengine.utils import is_list_of
from mmengine.visualization import Visualizer
@@ -1202,6 +1203,50 @@ class TestRunner(TestCase):
val_batch_idx_targets):
self.assertEqual(result, target)
# 4. test iter and epoch counter of IterBasedTrainLoop and timing of
# running ValLoop without InfiniteSampler
epoch_results = []
iter_results = []
batch_idx_results = []
val_iter_results = []
val_batch_idx_results = []
iter_targets = [i for i in range(12)]
batch_idx_targets = [i for i in range(12)]
val_iter_targets = [i for i in range(4, 12)]
val_batch_idx_targets = [i for i in range(4)] * 2
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train4'
cfg.train_dataloader.sampler = dict(
type='DefaultSampler', shuffle=True)
cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
cfg.train_cfg = dict(
by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
runner = Runner.from_cfg(cfg)
with self.assertWarnsRegex(
Warning,
'Reach the end of the dataloader, it will be restarted and '
'continue to iterate.'):
runner.train()
assert isinstance(runner.train_loop, IterBasedTrainLoop)
assert isinstance(runner.train_loop.dataloader_iterator,
_InfiniteDataloaderIterator)
self.assertEqual(len(epoch_results), 1)
self.assertEqual(epoch_results[0], 0)
self.assertEqual(runner.val_interval, 4)
self.assertEqual(runner.val_begin, 4)
for result, target, in zip(iter_results, iter_targets):
self.assertEqual(result, target)
for result, target, in zip(batch_idx_results, batch_idx_targets):
self.assertEqual(result, target)
for result, target, in zip(val_iter_results, val_iter_targets):
self.assertEqual(result, target)
for result, target, in zip(val_batch_idx_results,
val_batch_idx_targets):
self.assertEqual(result, target)
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'