diff --git a/mmcls/datasets/samplers/repeat_aug.py b/mmcls/datasets/samplers/repeat_aug.py index 58dd20249..d4b7e1e95 100644 --- a/mmcls/datasets/samplers/repeat_aug.py +++ b/mmcls/datasets/samplers/repeat_aug.py @@ -1,8 +1,8 @@ import math +from typing import Iterator, Optional, Sized import torch -from mmcv.runner import get_dist_info -from mmengine.dist import sync_random_seed +from mmengine.dist import get_dist_info, is_main_process, sync_random_seed from torch.utils.data import Sampler from mmcls.registry import DATA_SAMPLERS @@ -19,69 +19,54 @@ class RepeatAugSampler(Sampler): https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py Used in Copyright (c) 2015-present, Facebook, Inc. + + Args: + dataset (Sized): The dataset. + shuffle (bool): Whether shuffle the dataset or not. Defaults to True. + num_repeats (int): The repeat times of every sample. Defaults to 3. + seed (int, optional): Random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Defaults to None. """ def __init__(self, - dataset, - num_replicas=None, - rank=None, - shuffle=True, - num_repeats=3, - selected_round=256, - selected_ratio=0, - seed=0): - default_rank, default_world_size = get_dist_info() - rank = default_rank if rank is None else rank - num_replicas = ( - default_world_size if num_replicas is None else num_replicas) + dataset: Sized, + shuffle: bool = True, + num_repeats: int = 3, + seed: Optional[int] = None): + rank, world_size = get_dist_info() + self.rank = rank + self.world_size = world_size self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank self.shuffle = shuffle - self.num_repeats = num_repeats + if not self.shuffle and is_main_process(): + from mmengine.logging import MMLogger + logger = MMLogger.get_current_instance() + logger.warning('The RepeatAugSampler always picks a ' + 'fixed part of data if `shuffle=False`.') + + if seed is None: + seed = sync_random_seed() + self.seed = seed self.epoch = 0 - self.num_samples = int( - math.ceil(len(self.dataset) * num_repeats / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas - # Determine the number of samples to select per epoch for each rank. - # num_selected logic defaults to be the same as original RASampler - # impl, but this one can be tweaked - # via selected_ratio and selected_round args. - selected_ratio = selected_ratio or num_replicas # ratio to reduce - # selected samples by, num_replicas if 0 - if selected_round: - self.num_selected_samples = int( - math.floor( - len(self.dataset) // selected_round * selected_round / - selected_ratio)) - else: - self.num_selected_samples = int( - math.ceil(len(self.dataset) / selected_ratio)) + self.num_repeats = num_repeats - # In distributed sampling, different ranks should sample - # non-overlapped data in the dataset. Therefore, this function - # is used to make sure that each rank shuffles the data indices - # in the same order based on the same seed. Then different ranks - # could use different indices to select non-overlapped data from the - # same data list. - self.seed = sync_random_seed(seed) + # The number of repeated samples in the rank + self.num_samples = math.ceil( + len(self.dataset) * num_repeats / world_size) + # The total number of repeated samples in all ranks. + self.total_size = self.num_samples * world_size + # The number of selected samples in the rank + self.num_selected_samples = math.ceil(len(self.dataset) / world_size) - def __iter__(self): - # deterministically shuffle based on epoch + def __iter__(self) -> Iterator[int]: + """Iterate the indices.""" + # deterministically shuffle based on epoch and seed if self.shuffle: - if self.num_replicas > 1: # In distributed environment - # deterministically shuffle based on epoch - g = torch.Generator() - # When :attr:`shuffle=True`, this ensures all replicas - # use a different random ordering for each epoch. - # Otherwise, the next iteration of this sampler will - # yield the same ordering. - g.manual_seed(self.epoch + self.seed) - indices = torch.randperm( - len(self.dataset), generator=g).tolist() - else: - indices = torch.randperm(len(self.dataset)).tolist() + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = list(range(len(self.dataset))) @@ -93,14 +78,24 @@ class RepeatAugSampler(Sampler): assert len(indices) == self.total_size # subsample per rank - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank:self.total_size:self.world_size] assert len(indices) == self.num_samples # return up to num selected samples return iter(indices[:self.num_selected_samples]) - def __len__(self): + def __len__(self) -> int: + """The number of samples in this rank.""" return self.num_selected_samples - def set_epoch(self, epoch): + def set_epoch(self, epoch: int) -> None: + """Sets the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas use a different + random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ self.epoch = epoch diff --git a/tests/test_data/test_samplers/test_repeat_aug.py b/tests/test_data/test_samplers/test_repeat_aug.py new file mode 100644 index 000000000..1fce3510c --- /dev/null +++ b/tests/test_data/test_samplers/test_repeat_aug.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import math +from unittest import TestCase +from unittest.mock import patch + +import torch +from mmengine.logging import MMLogger + +from mmcls.datasets import RepeatAugSampler + +file = 'mmcls.datasets.samplers.repeat_aug.' + + +class MockDist: + + def __init__(self, dist_info=(0, 1), seed=7): + self.dist_info = dist_info + self.seed = seed + + def get_dist_info(self): + return self.dist_info + + def sync_random_seed(self): + return self.seed + + def is_main_process(self): + return self.dist_info[0] == 0 + + +class TestRepeatAugSampler(TestCase): + + def setUp(self): + self.data_length = 100 + self.dataset = list(range(self.data_length)) + + @patch(file + 'get_dist_info', return_value=(0, 1)) + def test_non_dist(self, mock): + sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False) + self.assertEqual(sampler.world_size, 1) + self.assertEqual(sampler.rank, 0) + self.assertEqual(sampler.total_size, self.data_length * 3) + self.assertEqual(sampler.num_samples, self.data_length * 3) + self.assertEqual(sampler.num_selected_samples, self.data_length) + self.assertEqual(len(sampler), sampler.num_selected_samples) + indices = [x for x in range(self.data_length) for _ in range(3)] + self.assertEqual(list(sampler), indices[:self.data_length]) + + logger = MMLogger.get_current_instance() + with self.assertLogs(logger, 'WARN') as log: + sampler = RepeatAugSampler(self.dataset, shuffle=False) + self.assertIn('always picks a fixed part', log.output[0]) + + @patch(file + 'get_dist_info', return_value=(2, 3)) + @patch(file + 'is_main_process', return_value=False) + def test_dist(self, mock1, mock2): + sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False) + self.assertEqual(sampler.world_size, 3) + self.assertEqual(sampler.rank, 2) + self.assertEqual(sampler.num_samples, self.data_length) + self.assertEqual(sampler.total_size, self.data_length * 3) + self.assertEqual(sampler.num_selected_samples, + math.ceil(self.data_length / 3)) + self.assertEqual(len(sampler), sampler.num_selected_samples) + indices = [x for x in range(self.data_length) for _ in range(3)] + self.assertEqual( + list(sampler), indices[2::3][:sampler.num_selected_samples]) + + logger = MMLogger.get_current_instance() + with patch.object(logger, 'warning') as mock_log: + sampler = RepeatAugSampler(self.dataset, shuffle=False) + mock_log.assert_not_called() + + @patch(file + 'get_dist_info', return_value=(0, 1)) + @patch(file + 'sync_random_seed', return_value=7) + def test_shuffle(self, mock1, mock2): + # test seed=None + sampler = RepeatAugSampler(self.dataset, seed=None) + self.assertEqual(sampler.seed, 7) + + # test random seed + sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0) + sampler.set_epoch(10) + g = torch.Generator() + g.manual_seed(10) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + indices = [x for x in indices + for _ in range(3)][:sampler.num_selected_samples] + self.assertEqual(list(sampler), indices) + + sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42) + sampler.set_epoch(10) + g = torch.Generator() + g.manual_seed(42 + 10) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + indices = [x for x in indices + for _ in range(3)][:sampler.num_selected_samples] + self.assertEqual(list(sampler), indices)