mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add data sampler (#30)
* Add DefaultSampler and InfiniteSampler * Add unit test * Add samplers to API reference * Update docstring * Improve according to comments * Rename `num_replicas` to `world_size` * Update docstring. * Update mmengine/data/sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/data/sampler.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Fix typo in unit test Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
7353778b7c
commit
f0451a38f0
@ -2,3 +2,8 @@ Registry
|
|||||||
--------
|
--------
|
||||||
.. automodule:: mmengine.registry
|
.. automodule:: mmengine.registry
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
Data
|
||||||
|
--------
|
||||||
|
.. automodule:: mmengine.data
|
||||||
|
:members:
|
||||||
|
@ -2,3 +2,8 @@ Registry
|
|||||||
--------
|
--------
|
||||||
.. automodule:: mmengine.registry
|
.. automodule:: mmengine.registry
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
Data
|
||||||
|
--------
|
||||||
|
.. automodule:: mmengine.data
|
||||||
|
:members:
|
||||||
|
4
mmengine/data/__init__.py
Normal file
4
mmengine/data/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .sampler import DefaultSampler, InfiniteSampler
|
||||||
|
|
||||||
|
__all__ = ['DefaultSampler', 'InfiniteSampler']
|
168
mmengine/data/sampler.py
Normal file
168
mmengine/data/sampler.py
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
from typing import Iterator, Optional, Sized
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
|
from mmengine.dist import get_dist_info, sync_random_seed
|
||||||
|
from mmengine.registry import DATA_SAMPLERS
|
||||||
|
|
||||||
|
|
||||||
|
@DATA_SAMPLERS.register_module()
class DefaultSampler(Sampler[int]):
    """The default data sampler for both distributed and non-distributed
    environment.

    It has several differences from the PyTorch ``DistributedSampler`` as
    below:

    1. This sampler supports non-distributed environment.

    2. The round up behaviors are a little different.

       - If ``round_up=True``, this sampler will add extra samples to make
         the number of samples evenly divisible by the world size. And
         this behavior is the same as the ``DistributedSampler`` with
         ``drop_last=False``.
       - If ``round_up=False``, this sampler won't remove or add any samples
         while the ``DistributedSampler`` with ``drop_last=True`` will remove
         tail samples.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. If None, the value returned
            by ``sync_random_seed()`` is used. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            # Every rank gets the same number of samples; the index list is
            # padded up to ``total_size`` in ``__iter__``.
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            # No padding: this rank takes every ``world_size``-th index
            # starting at ``rank``, so ranks may differ by one sample.
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices assigned to this rank.

        Returns:
            Iterator[int]: Dataset indices for the current rank and epoch.
        """
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible; an empty dataset
        # has nothing to repeat, and skipping it avoids dividing by zero
        if self.round_up and indices:
            repeats = self.total_size // len(indices) + 1
            indices = (indices * repeats)[:self.total_size]

        # subsample
        indices = indices[self.rank:self.total_size:self.world_size]

        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
|
||||||
|
|
||||||
|
|
||||||
|
@DATA_SAMPLERS.register_module()
class InfiniteSampler(Sampler[int]):
    """It's designed for iteration-based runner and yields dataset indices
    endlessly, one index at a time.

    A single index stream is created when the sampler is constructed and
    every call to ``iter()`` continues from that same stream rather than
    restarting it.

    The implementation logic is referred to
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
        seed (int, optional): Random seed. If None, the value returned by
            ``sync_random_seed()`` is used. Defaults to None.
    """  # noqa: W605

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.size = len(dataset)
        # Created once so that repeated ``iter(sampler)`` calls resume the
        # same stream.
        self.indices = self._indices_of_rank()

    def _infinite_indices(self) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        if self.size == 0:
            # Nothing to yield; without this guard the loop below would
            # spin forever and never produce a value.
            return
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        # Take every ``world_size``-th index, starting at ``rank``.
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        for idx in self.indices:
            yield idx

    def __len__(self) -> int:
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch: int) -> None:
        """Not supported in iteration-based runner."""
        raise NotImplementedError(
            'The `InfiniteSampler` is only used in iteration-based runner, '
            "and doesn't need `set_epoch`")
|
139
tests/test_data/test_sampler.py
Normal file
139
tests/test_data/test_sampler.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.data import DefaultSampler, InfiniteSampler
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultSampler(TestCase):
    """Unit tests for ``DefaultSampler`` on a 100-element list dataset."""

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        """With a single process every index is produced in order."""
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # round_up makes no difference when the world size is 1
        all_indices = list(range(self.data_length))
        for round_up in (True, False):
            sampler = DefaultSampler(
                self.dataset, round_up=round_up, shuffle=False)
            self.assertEqual(sampler.total_size, self.data_length)
            self.assertEqual(sampler.num_samples, self.data_length)
            self.assertEqual(list(sampler), all_indices)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        """Rank 2 of 3 receives every third index starting from 2."""
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        rank_indices = list(range(self.data_length))[2::3]

        # test round_up=True: the index list is padded by wrapping around,
        # so this rank gets one extra sample taken from the front
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3))
        self.assertEqual(sampler.total_size, sampler.num_samples * 3)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), rank_indices + [1])

        # test round_up=False: nothing is added or removed
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        self.assertEqual(sampler.num_samples,
                         np.ceil((self.data_length - 2) / 3))
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), rank_indices)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        """The shuffled order is reproducible from ``seed + epoch``."""
        # test seed=None
        sampler = DefaultSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed
        epoch = 10
        for seed in (0, 42):
            sampler = DefaultSampler(self.dataset, shuffle=True, seed=seed)
            sampler.set_epoch(epoch)
            g = torch.Generator()
            g.manual_seed(seed + epoch)
            expected = torch.randperm(len(self.dataset), generator=g).tolist()
            self.assertEqual(list(sampler), expected)
|
||||||
|
|
||||||
|
|
||||||
|
class TestInfiniteSampler(TestCase):
    """Unit tests for ``InfiniteSampler`` on a 100-element list dataset."""

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        """A single process cycles through the whole dataset repeatedly."""
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)

        # a new iterator is requested for every item on purpose: each one
        # must carry on from where the previous one stopped
        items = []
        for _ in range(self.data_length * 2):
            items.append(next(iter(sampler)))
        self.assertEqual(items, list(range(self.data_length)) * 2)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        """Rank 2 of 3 sees every third index of the repeated sequence."""
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)

        two_passes = list(range(self.data_length)) * 2
        targets = two_passes[2::3]
        samples = []
        for _ in range(len(targets)):
            samples.append(next(iter(sampler)))
        self.assertEqual(samples, targets)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        """The first shuffled pass is reproducible from the seed."""
        # test seed=None
        sampler = InfiniteSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test the random seed
        sampler = InfiniteSampler(self.dataset, shuffle=True, seed=42)
        sampler_iter = iter(sampler)
        samples = [next(sampler_iter) for _ in range(self.data_length)]

        g = torch.Generator()
        g.manual_seed(42)
        expected = torch.randperm(self.data_length, generator=g).tolist()
        self.assertEqual(samples, expected)

    def test_set_epoch(self):
        """``set_epoch`` is rejected by the infinite sampler."""
        sampler = InfiniteSampler(self.dataset)
        with self.assertRaises(NotImplementedError):
            sampler.set_epoch(10)
|
Loading…
x
Reference in New Issue
Block a user