From f0451a38f096bb3e92f4df293c35e6503d4ba643 Mon Sep 17 00:00:00 2001
From: Ma Zerun
Date: Mon, 21 Feb 2022 13:08:55 +0800
Subject: [PATCH] [Feature] Add data sampler (#30)

* Add DefaultSampler and InfiniteSampler

* Add unit test

* Add samplers to API reference

* Update docstring

* Improve according to comments

* Rename `num_replicas` to `world_size`

* Update docstring.

* Update mmengine/data/sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/data/sampler.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Fix typo in unit test

Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 docs/en/api.rst                 |   5 +
 docs/zh_cn/api.rst              |   5 +
 mmengine/data/__init__.py       |   4 +
 mmengine/data/sampler.py        | 164 ++++++++++++++++++++++++++++++++
 tests/test_data/test_sampler.py | 139 ++++++++++++++++++++++++++
 5 files changed, 317 insertions(+)
 create mode 100644 mmengine/data/__init__.py
 create mode 100644 mmengine/data/sampler.py
 create mode 100644 tests/test_data/test_sampler.py

diff --git a/docs/en/api.rst b/docs/en/api.rst
index 744f2348..fee9eea1 100644
--- a/docs/en/api.rst
+++ b/docs/en/api.rst
@@ -2,3 +2,8 @@ Registry
 --------
 .. automodule:: mmengine.registry
     :members:
+
+Data
+--------
+.. automodule:: mmengine.data
+    :members:
diff --git a/docs/zh_cn/api.rst b/docs/zh_cn/api.rst
index 744f2348..fee9eea1 100644
--- a/docs/zh_cn/api.rst
+++ b/docs/zh_cn/api.rst
@@ -2,3 +2,8 @@ Registry
 --------
 .. automodule:: mmengine.registry
     :members:
+
+Data
+--------
+.. automodule:: mmengine.data
+    :members:
diff --git a/mmengine/data/__init__.py b/mmengine/data/__init__.py
new file mode 100644
index 00000000..1c6b205a
--- /dev/null
+++ b/mmengine/data/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .sampler import DefaultSampler, InfiniteSampler
+
+__all__ = ['DefaultSampler', 'InfiniteSampler']
diff --git a/mmengine/data/sampler.py b/mmengine/data/sampler.py
new file mode 100644
index 00000000..3d891909
--- /dev/null
+++ b/mmengine/data/sampler.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import math
+from typing import Iterator, Optional, Sized
+
+import torch
+from torch.utils.data import Sampler
+
+from mmengine.dist import get_dist_info, sync_random_seed
+from mmengine.registry import DATA_SAMPLERS
+
+
+@DATA_SAMPLERS.register_module()
+class DefaultSampler(Sampler[int]):
+    """The default data sampler for both distributed and non-distributed
+    environments.
+
+    It has several differences from the PyTorch ``DistributedSampler``:
+
+    1. This sampler supports non-distributed environments.
+
+    2. The round-up behaviors are slightly different.
+
+       - If ``round_up=True``, this sampler will add extra samples to make
+         the number of samples evenly divisible by the world size. This is
+         the same behavior as the ``DistributedSampler`` with
+         ``drop_last=False``.
+       - If ``round_up=False``, this sampler won't remove or add any samples,
+         while the ``DistributedSampler`` with ``drop_last=True`` will remove
+         tail samples.
+
+    Args:
+        dataset (Sized): The dataset.
+        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
+        seed (int, optional): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Defaults to None.
+        round_up (bool): Whether to add extra samples to make the number of
+            samples evenly divisible by the world size. Defaults to True.
+    """
+
+    def __init__(self,
+                 dataset: Sized,
+                 shuffle: bool = True,
+                 seed: Optional[int] = None,
+                 round_up: bool = True) -> None:
+        rank, world_size = get_dist_info()
+        self.rank = rank
+        self.world_size = world_size
+
+        self.dataset = dataset
+        self.shuffle = shuffle
+        if seed is None:
+            seed = sync_random_seed()
+        self.seed = seed
+        self.epoch = 0
+        self.round_up = round_up
+
+        if self.round_up:
+            self.num_samples = math.ceil(len(self.dataset) / world_size)
+            self.total_size = self.num_samples * self.world_size
+        else:
+            self.num_samples = math.ceil(
+                (len(self.dataset) - rank) / world_size)
+            self.total_size = len(self.dataset)
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate the indices."""
+        # deterministically shuffle based on epoch and seed
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = torch.arange(len(self.dataset)).tolist()
+
+        # add extra samples to make it evenly divisible
+        if self.round_up:
+            indices = (
+                indices *
+                int(self.total_size / len(indices) + 1))[:self.total_size]
+
+        # subsample
+        indices = indices[self.rank:self.total_size:self.world_size]
+
+        return iter(indices)
+
+    def __len__(self) -> int:
+        """The number of samples in this rank."""
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas use a different
+        random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
+
+
+@DATA_SAMPLERS.register_module()
+class InfiniteSampler(Sampler[int]):
+    """The sampler designed for iteration-based runners. It yields a single
+    index each time, looping over the dataset indefinitely.
+
+    The implementation is referenced from
+    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/distributed_sampler.py
+
+    Args:
+        dataset (Sized): The dataset.
+        shuffle (bool): Whether to shuffle the dataset. Defaults to True.
+        seed (int, optional): Random seed. If None, set a random seed.
+            Defaults to None.
+ """ # noqa: W605 + + def __init__(self, + dataset: Sized, + shuffle: bool = True, + seed: Optional[int] = None) -> None: + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.world_size = world_size + self.rank = rank + self.shuffle = shuffle + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.size = len(dataset) + self.indices = self._indices_of_rank() + + def _infinite_indices(self) -> Iterator[int]: + """Infinitely yield a sequence of indices.""" + g = torch.Generator() + g.manual_seed(self.seed) + while True: + if self.shuffle: + yield from torch.randperm(self.size, generator=g).tolist() + + else: + yield from torch.arange(self.size).tolist() + + def _indices_of_rank(self) -> Iterator[int]: + """Slice the infinite indices by rank.""" + yield from itertools.islice(self._infinite_indices(), self.rank, None, + self.world_size) + + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + for idx in self.indices: + yield idx + + def __len__(self) -> int: + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch: int) -> None: + """Not supported in iteration-based runner.""" + raise NotImplementedError( + 'The `InfiniteSampler` is only used in iteration-based runner, ' + "and doesn't need `set_epoch`") diff --git a/tests/test_data/test_sampler.py b/tests/test_data/test_sampler.py new file mode 100644 index 00000000..c2494d0e --- /dev/null +++ b/tests/test_data/test_sampler.py @@ -0,0 +1,139 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from functools import partial +from unittest import TestCase +from unittest.mock import patch + +import numpy as np +import torch + +from mmengine.data import DefaultSampler, InfiniteSampler + + +class TestDefaultSampler(TestCase): + + def setUp(self): + self.data_length = 100 + self.dataset = list(range(self.data_length)) + + @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1)) + def test_non_dist(self, mock): + sampler = DefaultSampler(self.dataset) + self.assertEqual(sampler.world_size, 1) + self.assertEqual(sampler.rank, 0) + + # test round_up=True + sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False) + self.assertEqual(sampler.total_size, self.data_length) + self.assertEqual(sampler.num_samples, self.data_length) + self.assertEqual(list(sampler), list(range(self.data_length))) + + # test round_up=False + sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False) + self.assertEqual(sampler.total_size, self.data_length) + self.assertEqual(sampler.num_samples, self.data_length) + self.assertEqual(list(sampler), list(range(self.data_length))) + + @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3)) + def test_dist(self, mock): + sampler = DefaultSampler(self.dataset) + self.assertEqual(sampler.world_size, 3) + self.assertEqual(sampler.rank, 2) + + # test round_up=True + sampler = DefaultSampler(self.dataset, round_up=True, shuffle=False) + self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3)) + self.assertEqual(sampler.total_size, sampler.num_samples * 3) + self.assertEqual(len(sampler), sampler.num_samples) + self.assertEqual( + list(sampler), + list(range(self.data_length))[2::3] + [1]) + + # test round_up=False + sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False) + self.assertEqual(sampler.num_samples, + np.ceil((self.data_length - 2) / 3)) + self.assertEqual(sampler.total_size, self.data_length) + self.assertEqual(len(sampler), 
+        self.assertEqual(list(sampler), list(range(self.data_length))[2::3])
+
+    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
+    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
+    def test_shuffle(self, mock1, mock2):
+        # test seed=None
+        sampler = DefaultSampler(self.dataset, seed=None)
+        self.assertEqual(sampler.seed, 7)
+
+        # test random seed
+        sampler = DefaultSampler(self.dataset, shuffle=True, seed=0)
+        sampler.set_epoch(10)
+        g = torch.Generator()
+        g.manual_seed(10)
+        self.assertEqual(
+            list(sampler),
+            torch.randperm(len(self.dataset), generator=g).tolist())
+
+        sampler = DefaultSampler(self.dataset, shuffle=True, seed=42)
+        sampler.set_epoch(10)
+        g = torch.Generator()
+        g.manual_seed(42 + 10)
+        self.assertEqual(
+            list(sampler),
+            torch.randperm(len(self.dataset), generator=g).tolist())
+
+
+class TestInfiniteSampler(TestCase):
+
+    def setUp(self):
+        self.data_length = 100
+        self.dataset = list(range(self.data_length))
+
+    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
+    def test_non_dist(self, mock):
+        sampler = InfiniteSampler(self.dataset)
+        self.assertEqual(sampler.world_size, 1)
+        self.assertEqual(sampler.rank, 0)
+
+        # test iteration
+        sampler = InfiniteSampler(self.dataset, shuffle=False)
+        self.assertEqual(len(sampler), self.data_length)
+        self.assertEqual(sampler.size, self.data_length)
+        items = [next(iter(sampler)) for _ in range(self.data_length * 2)]
+        self.assertEqual(items, list(range(self.data_length)) * 2)
+
+    @patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3))
+    def test_dist(self, mock):
+        sampler = InfiniteSampler(self.dataset)
+        self.assertEqual(sampler.world_size, 3)
+        self.assertEqual(sampler.rank, 2)
+
+        # test iteration
+        sampler = InfiniteSampler(self.dataset, shuffle=False)
+        self.assertEqual(len(sampler), self.data_length)
+        self.assertEqual(sampler.size, self.data_length)
+        targets = (list(range(self.data_length)) * 2)[2::3]
+        samples = [next(iter(sampler)) for _ in range(len(targets))]
+        self.assertEqual(samples, targets)
+
+    @patch('mmengine.data.sampler.get_dist_info', return_value=(0, 1))
+    @patch('mmengine.data.sampler.sync_random_seed', return_value=7)
+    def test_shuffle(self, mock1, mock2):
+        # test seed=None
+        sampler = InfiniteSampler(self.dataset, seed=None)
+        self.assertEqual(sampler.seed, 7)
+
+        # test the random seed
+        sampler = InfiniteSampler(self.dataset, shuffle=True, seed=42)
+
+        sampler_iter = iter(sampler)
+        samples = [next(sampler_iter) for _ in range(self.data_length)]
+
+        g = torch.Generator()
+        g.manual_seed(42)
+        self.assertEqual(
+            samples,
+            torch.randperm(self.data_length, generator=g).tolist())
+
+    def test_set_epoch(self):
+        sampler = InfiniteSampler(self.dataset)
+        self.assertRaises(NotImplementedError, partial(sampler.set_epoch, 10))
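
A minimal usage sketch (not part of the patch): the snippet below shows one way
DefaultSampler might be wired into a plain PyTorch DataLoader. The toy dataset,
batch size, and the per-epoch `set_epoch` call are illustrative assumptions,
not code from this PR.

    # Epoch-based loop with DefaultSampler (sketch; assumes a single process,
    # where get_dist_info() reports rank 0 and world size 1).
    from torch.utils.data import DataLoader

    from mmengine.data import DefaultSampler

    dataset = list(range(10))  # any Sized, indexable dataset works
    sampler = DefaultSampler(dataset, shuffle=True, seed=0)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        # Re-seed the shuffle each epoch so all replicas stay in sync.
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step goes here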
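
A behavior sketch (not part of the patch) for the `round_up` flag, mirroring
what `test_dist` above asserts: with 100 samples, a world size of 3, and rank
2, `round_up=True` pads the index list to 102 so every rank gets 34 samples,
while `round_up=False` leaves rank 2 with 33. The `unittest.mock.patch` of
`get_dist_info` is borrowed from the unit tests to fake the distributed info.

    from unittest.mock import patch

    from mmengine.data import DefaultSampler

    with patch('mmengine.data.sampler.get_dist_info', return_value=(2, 3)):
        padded = DefaultSampler(list(range(100)), shuffle=False, round_up=True)
        exact = DefaultSampler(list(range(100)), shuffle=False, round_up=False)

    print(list(padded)[-3:])  # [95, 98, 1] -- tail padded with a recycled index
    print(list(exact)[-3:])   # [92, 95, 98] -- no padding, 33 samples on rank 2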
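
A minimal usage sketch for InfiniteSampler (not part of the patch): the sampler
never exhausts, so an iteration-based loop must bound itself; `len(loader)` is
finite (derived from `len(sampler)`), but iteration never stops on its own. The
iteration budget and batch size below are illustrative assumptions.

    # Iteration-based loop with InfiniteSampler (sketch).
    from torch.utils.data import DataLoader

    from mmengine.data import InfiniteSampler

    dataset = list(range(10))
    sampler = InfiniteSampler(dataset, shuffle=True, seed=0)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    max_iters = 5
    for i, batch in enumerate(loader):
        if i >= max_iters:
            break  # the sampler yields forever; the loop must stop itself
        pass  # training step goes here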