diff --git a/mmocr/datasets/__init__.py b/mmocr/datasets/__init__.py
index 97cdb16b..54a9ea7f 100644
--- a/mmocr/datasets/__init__.py
+++ b/mmocr/datasets/__init__.py
@@ -4,6 +4,7 @@ from .icdar_dataset import IcdarDataset
 from .ocr_dataset import OCRDataset
 from .recog_lmdb_dataset import RecogLMDBDataset
 from .recog_text_dataset import RecogTextDataset
+from .samplers import *  # NOQA
 from .transforms import *  # NOQA
 from .wildreceipt_dataset import WildReceiptDataset
diff --git a/mmocr/datasets/samplers/__init__.py b/mmocr/datasets/samplers/__init__.py
new file mode 100644
index 00000000..9b99ca07
--- /dev/null
+++ b/mmocr/datasets/samplers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .repeat_aug import RepeatAugSampler
+
+__all__ = ['RepeatAugSampler']
diff --git a/mmocr/datasets/samplers/repeat_aug.py b/mmocr/datasets/samplers/repeat_aug.py
new file mode 100644
index 00000000..767b58eb
--- /dev/null
+++ b/mmocr/datasets/samplers/repeat_aug.py
@@ -0,0 +1,100 @@
+import math
+from typing import Iterator, Optional, Sized
+
+import torch
+from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
+from torch.utils.data import Sampler
+
+from mmocr.registry import DATA_SAMPLERS
+
+
+@DATA_SAMPLERS.register_module()
+class RepeatAugSampler(Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for
+    distributed training, with repeated augmentation. It ensures that each
+    augmented version of a sample will be visible to a different process
+    (GPU). Heavily based on ``torch.utils.data.DistributedSampler``.
+
+    This sampler was taken from
+    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
+    Copyright (c) 2015-present, Facebook, Inc.
+
+    Args:
+        dataset (Sized): The dataset.
+        shuffle (bool): Whether to shuffle the dataset or not.
+            Defaults to True.
+        num_repeats (int): The number of times each sample is repeated.
+            Defaults to 3.
+        seed (int, optional): Random seed used to shuffle the sampler if
+            :attr:`shuffle=True`. This number should be identical across all
+            processes in the distributed group. Defaults to None.
+    """
+
+    def __init__(self,
+                 dataset: Sized,
+                 shuffle: bool = True,
+                 num_repeats: int = 3,
+                 seed: Optional[int] = None):
+        rank, world_size = get_dist_info()
+        self.rank = rank
+        self.world_size = world_size
+
+        self.dataset = dataset
+        self.shuffle = shuffle
+        if not self.shuffle and is_main_process():
+            from mmengine.logging import MMLogger
+            logger = MMLogger.get_current_instance()
+            logger.warning('The RepeatAugSampler always picks a '
+                           'fixed part of data if `shuffle=False`.')
+
+        if seed is None:
+            seed = sync_random_seed()
+        self.seed = seed
+        self.epoch = 0
+        self.num_repeats = num_repeats
+
+        # The number of repeated samples in the rank
+        self.num_samples = math.ceil(
+            len(self.dataset) * num_repeats / world_size)
+        # The total number of repeated samples in all ranks.
+        self.total_size = self.num_samples * world_size
+        # The number of selected samples in the rank
+        self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
+
+    def __iter__(self) -> Iterator[int]:
+        """Iterate the indices."""
+        # deterministically shuffle based on epoch and seed
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2, ...]
+        indices = [x for x in indices for _ in range(self.num_repeats)]
+        # add extra samples to make it evenly divisible
+        padding_size = self.total_size - len(indices)
+        indices += indices[:padding_size]
+        assert len(indices) == self.total_size
+
+        # subsample per rank
+        indices = indices[self.rank:self.total_size:self.world_size]
+        assert len(indices) == self.num_samples
+
+        # return up to num selected samples
+        return iter(indices[:self.num_selected_samples])
+
+    def __len__(self) -> int:
+        """The number of samples in this rank."""
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        """Sets the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas use a different
+        random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
diff --git a/tests/test_datasets/test_samplers/test_repeat_aug.py b/tests/test_datasets/test_samplers/test_repeat_aug.py
new file mode 100644
index 00000000..4fb973fe
--- /dev/null
+++ b/tests/test_datasets/test_samplers/test_repeat_aug.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import math
+from unittest import TestCase
+from unittest.mock import patch
+
+import torch
+from mmengine.logging import MMLogger
+
+from mmocr.datasets import RepeatAugSampler
+
+file = 'mmocr.datasets.samplers.repeat_aug.'
+
+
+class MockDist:
+
+    def __init__(self, dist_info=(0, 1), seed=7):
+        self.dist_info = dist_info
+        self.seed = seed
+
+    def get_dist_info(self):
+        return self.dist_info
+
+    def sync_random_seed(self):
+        return self.seed
+
+    def is_main_process(self):
+        return self.dist_info[0] == 0
+
+
+class TestRepeatAugSampler(TestCase):
+
+    def setUp(self):
+        self.data_length = 100
+        self.dataset = list(range(self.data_length))
+
+    @patch(file + 'get_dist_info', return_value=(0, 1))
+    def test_non_dist(self, mock):
+        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
+        self.assertEqual(sampler.world_size, 1)
+        self.assertEqual(sampler.rank, 0)
+        self.assertEqual(sampler.total_size, self.data_length * 3)
+        self.assertEqual(sampler.num_samples, self.data_length * 3)
+        self.assertEqual(sampler.num_selected_samples, self.data_length)
+        self.assertEqual(len(sampler), sampler.num_selected_samples)
+        indices = [x for x in range(self.data_length) for _ in range(3)]
+        self.assertEqual(list(sampler), indices[:self.data_length])
+
+        logger = MMLogger.get_current_instance()
+        with self.assertLogs(logger, 'WARN') as log:
+            sampler = RepeatAugSampler(self.dataset, shuffle=False)
+        self.assertIn('always picks a fixed part', log.output[0])
+
+    @patch(file + 'get_dist_info', return_value=(2, 3))
+    @patch(file + 'is_main_process', return_value=False)
+    def test_dist(self, mock1, mock2):
+        sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
+        self.assertEqual(sampler.world_size, 3)
+        self.assertEqual(sampler.rank, 2)
+        self.assertEqual(sampler.num_samples, self.data_length)
+        self.assertEqual(sampler.total_size, self.data_length * 3)
+        self.assertEqual(sampler.num_selected_samples,
+                         math.ceil(self.data_length / 3))
+        self.assertEqual(len(sampler), sampler.num_selected_samples)
+        indices = [x for x in range(self.data_length) for _ in range(3)]
+        self.assertEqual(
+            list(sampler), indices[2::3][:sampler.num_selected_samples])
+
+        logger = MMLogger.get_current_instance()
+        with patch.object(logger, 'warning') as mock_log:
+            sampler = RepeatAugSampler(self.dataset, shuffle=False)
+        mock_log.assert_not_called()
+
+    @patch(file + 'get_dist_info', return_value=(0, 1))
+    @patch(file + 'sync_random_seed', return_value=7)
+    def test_shuffle(self, mock1, mock2):
+        # test seed=None
+        sampler = RepeatAugSampler(self.dataset, seed=None)
+        self.assertEqual(sampler.seed, 7)
+
+        # test random seed
+        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
+        sampler.set_epoch(10)
+        g = torch.Generator()
+        g.manual_seed(10)
+        indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        indices = [x for x in indices
+                   for _ in range(3)][:sampler.num_selected_samples]
+        self.assertEqual(list(sampler), indices)
+
+        sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
+        sampler.set_epoch(10)
+        g = torch.Generator()
+        g.manual_seed(42 + 10)
+        indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        indices = [x for x in indices
+                   for _ in range(3)][:sampler.num_selected_samples]
+        self.assertEqual(list(sampler), indices)
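
To make the index arithmetic in `__iter__` concrete, here is a minimal standalone sketch (not part of the diff; `simulate_rank_indices` is a hypothetical helper for illustration) that reproduces the sampler's selection logic in the `shuffle=False` case and shows how the repeated copies of each sample are spread across ranks:

```python
# Standalone sketch of RepeatAugSampler's index arithmetic (shuffle=False).
# `simulate_rank_indices` is a hypothetical helper, not part of this PR.
import math


def simulate_rank_indices(dataset_len, world_size, rank, num_repeats=3):
    num_samples = math.ceil(dataset_len * num_repeats / world_size)
    total_size = num_samples * world_size
    num_selected = math.ceil(dataset_len / world_size)
    # Repeat every index num_repeats times: [0, 0, 0, 1, 1, 1, ...]
    indices = [i for i in range(dataset_len) for _ in range(num_repeats)]
    # Pad so the list splits evenly, take this rank's strided slice,
    # then truncate to roughly one epoch's worth of real samples.
    indices += indices[:total_size - len(indices)]
    return indices[rank:total_size:world_size][:num_selected]


# 4 samples, 3 repeats, 3 GPUs: the three copies of sample 0 occupy
# positions 0, 1 and 2 of the repeated list, so they land on ranks
# 0, 1 and 2 respectively. Each GPU augments the same image once.
for rank in range(3):
    print(rank, simulate_rank_indices(4, world_size=3, rank=rank))
# 0 [0, 1]
# 1 [0, 1]
# 2 [0, 1]
```

Note that each rank yields only `ceil(len(dataset) / world_size)` indices per epoch, so with `shuffle=True` the subset of distinct samples actually trained on changes from epoch to epoch.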
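Since the class is registered in `DATA_SAMPLERS`, it should be selectable by type name from a dataloader config. The fragment below is a hedged sketch of such usage, assuming the usual MMEngine convention that the runner builds the sampler from this dict and injects the `dataset` argument; the surrounding field values are placeholders, not part of this PR:

```python
# Hypothetical config fragment: pick RepeatAugSampler by its registered name.
train_dataloader = dict(
    batch_size=32,
    num_workers=4,
    sampler=dict(type='RepeatAugSampler', shuffle=True, num_repeats=3),
    dataset=...,  # the dataset config goes here
)
```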