[Feature] Add data sampler (#30)

* Add DefaultSampler and InfiniteSampler

* Add unit test

* Add samplers to API reference

* Update docstring

* Improve according to comments

* Rename `num_replicas` to `world_size`

* Update docstring.

* Update mmengine/data/sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/data/sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix typo in unit test

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Ma Zerun 2022-02-21 13:08:55 +08:00 committed by GitHub
parent 7353778b7c
commit f0451a38f0
5 changed files with 321 additions and 0 deletions


@@ -2,3 +2,8 @@ Registry
--------
.. automodule:: mmengine.registry
    :members:

Data
--------
.. automodule:: mmengine.data
    :members:


@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .sampler import DefaultSampler, InfiniteSampler

__all__ = ['DefaultSampler', 'InfiniteSampler']


@@ -0,0 +1,168 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import math
from typing import Iterator, Optional, Sized

import torch
from torch.utils.data import Sampler

from mmengine.dist import get_dist_info, sync_random_seed
from mmengine.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class DefaultSampler(Sampler[int]):
    """The default data sampler for both distributed and non-distributed
    environments.

    It differs from the PyTorch ``DistributedSampler`` in the following ways:

    1. This sampler supports non-distributed environments.
    2. The round-up behaviors are slightly different.

       - If ``round_up=True``, this sampler adds extra samples to make the
         number of samples evenly divisible by the world size, which is the
         same behavior as ``DistributedSampler`` with ``drop_last=False``.
       - If ``round_up=False``, this sampler won't remove or add any samples,
         while ``DistributedSampler`` with ``drop_last=True`` removes tail
         samples.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
        round_up (bool): Whether to add extra samples to make the number of
            samples evenly divisible by the world size. Defaults to True.
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None,
                 round_up: bool = True) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.round_up = round_up

        if self.round_up:
            self.num_samples = math.ceil(len(self.dataset) / world_size)
            self.total_size = self.num_samples * self.world_size
        else:
            self.num_samples = math.ceil(
                (len(self.dataset) - rank) / world_size)
            self.total_size = len(self.dataset)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        if self.round_up:
            indices = (
                indices *
                int(self.total_size / len(indices) + 1))[:self.total_size]

        # subsample
        indices = indices[self.rank:self.total_size:self.world_size]

        return iter(indices)

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch


@DATA_SAMPLERS.register_module()
class InfiniteSampler(Sampler[int]):
    """A sampler designed for iteration-based runners. It yields indices
    endlessly, so iteration never stops at epoch boundaries.

    The implementation is adapted from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
        seed (int, optional): Random seed. If None, a random seed is
            synchronized across processes. Defaults to None.
    """  # noqa: W605

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 seed: Optional[int] = None) -> None:
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self) -> Iterator[int]:
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self) -> Iterator[int]:
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank,
                                    None, self.world_size)

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        for idx in self.indices:
            yield idx

    def __len__(self) -> int:
        """Length of the base dataset."""
        return self.size

    def set_epoch(self, epoch: int) -> None:
        """Not supported in iteration-based runners."""
        raise NotImplementedError(
            'The `InfiniteSampler` is only used in iteration-based runners, '
            "and doesn't need `set_epoch`")


@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from unittest import TestCase
from unittest.mock import patch

import numpy as np
import torch

from mmengine.data import DefaultSampler, InfiniteSampler


class TestDefaultSampler(TestCase):

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # test round_up=True
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(list(sampler), list(range(self.data_length)))

        # test round_up=False
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(sampler.num_samples, self.data_length)
        self.assertEqual(list(sampler), list(range(self.data_length)))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        sampler = DefaultSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        # test round_up=True
        sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False)
        self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3))
        self.assertEqual(sampler.total_size, sampler.num_samples * 3)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(
            list(sampler),
            list(range(self.data_length))[2::3] + [1])

        # test round_up=False
        sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
        self.assertEqual(sampler.num_samples,
                         np.ceil((self.data_length - 2) / 3))
        self.assertEqual(sampler.total_size, self.data_length)
        self.assertEqual(len(sampler), sampler.num_samples)
        self.assertEqual(list(sampler), list(range(self.data_length))[2::3])

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        # test seed=None
        sampler = DefaultSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test random seed
        sampler = DefaultSampler(self.dataset, shuffle=True, seed=0)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(10)
        self.assertEqual(
            list(sampler),
            torch.randperm(len(self.dataset), generator=g).tolist())

        sampler = DefaultSampler(self.dataset, shuffle=True, seed=42)
        sampler.set_epoch(10)
        g = torch.Generator()
        g.manual_seed(42 + 10)
        self.assertEqual(
            list(sampler),
            torch.randperm(len(self.dataset), generator=g).tolist())


class TestInfiniteSampler(TestCase):

    def setUp(self):
        self.data_length = 100
        self.dataset = list(range(self.data_length))

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    def test_non_dist(self, mock):
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 1)
        self.assertEqual(sampler.rank, 0)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)
        # every `iter(sampler)` wraps the same shared index generator, so
        # repeated `next(iter(...))` calls keep advancing the stream
        items = [next(iter(sampler)) for _ in range(self.data_length * 2)]
        self.assertEqual(items, list(range(self.data_length)) * 2)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
    def test_dist(self, mock):
        sampler = InfiniteSampler(self.dataset)
        self.assertEqual(sampler.world_size, 3)
        self.assertEqual(sampler.rank, 2)

        # test iteration
        sampler = InfiniteSampler(self.dataset, shuffle=False)
        self.assertEqual(len(sampler), self.data_length)
        self.assertEqual(sampler.size, self.data_length)
        targets = (list(range(self.data_length)) * 2)[2::3]
        samples = [next(iter(sampler)) for _ in range(len(targets))]
        self.assertEqual(samples, targets)

    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
    def test_shuffle(self, mock1, mock2):
        # test seed=None
        sampler = InfiniteSampler(self.dataset, seed=None)
        self.assertEqual(sampler.seed, 7)

        # test the random seed
        sampler = InfiniteSampler(self.dataset, shuffle=True, seed=42)
        sampler_iter = iter(sampler)
        samples = [next(sampler_iter) for _ in range(self.data_length)]
        g = torch.Generator()
        g.manual_seed(42)
        self.assertEqual(
            samples,
            torch.randperm(self.data_length, generator=g).tolist())

    def test_set_epoch(self):
        sampler = InfiniteSampler(self.dataset)
        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
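
As the tests above suggest, `InfiniteSampler` never raises StopIteration, so the consumer must bound the iteration itself. A minimal sketch of that pattern (the `max_iters` budget and the use of `islice` are illustrative assumptions, not part of this commit):

# Usage sketch: InfiniteSampler in an iteration-based loop.
from itertools import islice

from torch.utils.data import DataLoader

from mmengine.data import InfiniteSampler

dataset = list(range(100))
sampler = InfiniteSampler(dataset, shuffle=True, seed=42)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

max_iters = 10
for batch in islice(loader, max_iters):  # stop after a fixed budget
    pass  # training step goes here; the index stream itself never ends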