[Enhance] Support infinite dataloader iterator wrapper for IterBasedTrainLoop. (#289)
parent 5016332588
commit 1c18f30854

mmengine/runner/loops.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import time
 import warnings
 from typing import Dict, List, Sequence, Union
@@ -113,6 +114,51 @@ class EpochBasedTrainLoop(BaseLoop):
         self._iter += 1


+class _InfiniteDataloaderIterator:
+    """An infinite dataloader iterator wrapper for IterBasedTrainLoop.
+
+    It resets the dataloader to continue iterating when the iterator has
+    iterated over all the data. However, this approach is not efficient, as the
+    workers need to be restarted every time the dataloader is reset. It is
+    recommended to use `mmengine.data.InfiniteSampler` to enable the dataloader
+    to iterate infinitely.
+    """
+
+    def __init__(self, dataloader: DataLoader) -> None:
+        self._dataloader = dataloader
+        self._iterator = iter(self._dataloader)
+        self._epoch = 0
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> Sequence[dict]:
+        try:
+            data = next(self._iterator)
+        except StopIteration:
+            warnings.warn('Reach the end of the dataloader, it will be '
+                          'restarted and continue to iterate. It is '
+                          'recommended to use `mmengine.data.InfiniteSampler` '
+                          'to enable the dataloader to iterate infinitely.')
+            self._epoch += 1
+            if hasattr(self._dataloader, 'sampler') and hasattr(
+                    self._dataloader.sampler, 'set_epoch'):
+                # In case the `_SingleProcessDataLoaderIter` has no sampler,
+                # or the dataloader uses `SequentialSampler` in PyTorch.
+                self._dataloader.sampler.set_epoch(self._epoch)
+
+            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
+                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
+                # In case the `_SingleProcessDataLoaderIter` has no batch
+                # sampler. The batch sampler in PyTorch wraps the sampler as
+                # one of its attributes.
+                self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self._iterator = iter(self._dataloader)
+            data = next(self._iterator)
+        return data
+
+
 @LOOPS.register_module()
 class IterBasedTrainLoop(BaseLoop):
     """Loop for iter-based training.
@@ -149,7 +195,7 @@ class IterBasedTrainLoop(BaseLoop):
                 'metainfo. ``dataset_meta`` in visualizer will be '
                 'None.')
         # get the iterator of the dataloader
-        self.dataloader_iterator = iter(self.dataloader)
+        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

     @property
     def max_epochs(self):
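
Since both the docstring and the runtime warning point at `mmengine.data.InfiniteSampler` as the preferred route, a hedged config sketch of that alternative may help. The dataset type and sizes below are placeholders; the dict layout follows the sampler configs used in the tests further down.

```python
# Sketch: an iter-based train_dataloader config that opts into
# InfiniteSampler, so the dataloader iterates forever on its own and the
# wrapper's restart-and-warn path is never taken.
train_dataloader = dict(
    dataset=dict(type='ToyDataset'),  # placeholder dataset type
    sampler=dict(type='InfiniteSampler', shuffle=True),
    batch_size=2,
    num_workers=0)
```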
tests/test_runner/test_runner.py
@@ -29,6 +29,7 @@ from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
                               RUNNERS, Registry)
 from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                              Runner, TestLoop, ValLoop)
+from mmengine.runner.loops import _InfiniteDataloaderIterator
 from mmengine.runner.priority import Priority, get_priority
 from mmengine.utils import is_list_of
 from mmengine.visualization import Visualizer
@@ -1202,6 +1203,50 @@ class TestRunner(TestCase):
                                    val_batch_idx_targets):
             self.assertEqual(result, target)

+        # 4. test iter and epoch counter of IterBasedTrainLoop and timing of
+        # running ValLoop without InfiniteSampler
+        epoch_results = []
+        iter_results = []
+        batch_idx_results = []
+        val_iter_results = []
+        val_batch_idx_results = []
+        iter_targets = [i for i in range(12)]
+        batch_idx_targets = [i for i in range(12)]
+        val_iter_targets = [i for i in range(4, 12)]
+        val_batch_idx_targets = [i for i in range(4)] * 2
+
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train4'
+        cfg.train_dataloader.sampler = dict(
+            type='DefaultSampler', shuffle=True)
+        cfg.custom_hooks = [dict(type='TestIterHook', priority=50)]
+        cfg.train_cfg = dict(
+            by_epoch=False, max_iters=12, val_interval=4, val_begin=4)
+        runner = Runner.from_cfg(cfg)
+        with self.assertWarnsRegex(
+                Warning,
+                'Reach the end of the dataloader, it will be restarted and '
+                'continue to iterate.'):
+            runner.train()
+
+        assert isinstance(runner.train_loop, IterBasedTrainLoop)
+        assert isinstance(runner.train_loop.dataloader_iterator,
+                          _InfiniteDataloaderIterator)
+
+        self.assertEqual(len(epoch_results), 1)
+        self.assertEqual(epoch_results[0], 0)
+        self.assertEqual(runner.val_interval, 4)
+        self.assertEqual(runner.val_begin, 4)
+        for result, target in zip(iter_results, iter_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(batch_idx_results, batch_idx_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(val_iter_results, val_iter_targets):
+            self.assertEqual(result, target)
+        for result, target in zip(val_batch_idx_results,
+                                   val_batch_idx_targets):
+            self.assertEqual(result, target)
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'
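
One detail the new test exercises with `shuffle=True`: on every restart the wrapper bumps the epoch via `set_epoch`, so a shuffling sampler reshuffles between passes instead of replaying the same order. A small standalone sketch of that contract, using PyTorch's `DistributedSampler` (constructed with explicit `num_replicas`/`rank` so no process group is needed) purely as a stand-in for any sampler exposing `set_epoch`:

```python
from torch.utils.data import DistributedSampler

# Any object with __len__ works as the dataset for sampling indices.
sampler = DistributedSampler(
    range(8), num_replicas=1, rank=0, shuffle=True)

sampler.set_epoch(0)
order_epoch0 = list(sampler)
sampler.set_epoch(1)
order_epoch1 = list(sampler)

# Same index set each pass, but the per-epoch seed usually changes the
# permutation, which is exactly what the wrapper relies on.
assert sorted(order_epoch0) == sorted(order_epoch1)
```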