[Refactor] Refactor `RepeatAugSampler`.
parent
e8d69cf2ff
commit
12c982f939
|
@ -1,8 +1,8 @@
|
||||||
import math
|
import math
|
||||||
|
from typing import Iterator, Optional, Sized
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from mmcv.runner import get_dist_info
|
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
|
||||||
from mmengine.dist import sync_random_seed
|
|
||||||
from torch.utils.data import Sampler
|
from torch.utils.data import Sampler
|
||||||
|
|
||||||
from mmcls.registry import DATA_SAMPLERS
|
from mmcls.registry import DATA_SAMPLERS
|
||||||
|
@ -19,69 +19,54 @@ class RepeatAugSampler(Sampler):
|
||||||
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
|
https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
|
||||||
Used in
|
Used in
|
||||||
Copyright (c) 2015-present, Facebook, Inc.
|
Copyright (c) 2015-present, Facebook, Inc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (Sized): The dataset.
|
||||||
|
shuffle (bool): Whether shuffle the dataset or not. Defaults to True.
|
||||||
|
num_repeats (int): The repeat times of every sample. Defaults to 3.
|
||||||
|
seed (int, optional): Random seed used to shuffle the sampler if
|
||||||
|
:attr:`shuffle=True`. This number should be identical across all
|
||||||
|
processes in the distributed group. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
dataset,
|
dataset: Sized,
|
||||||
num_replicas=None,
|
shuffle: bool = True,
|
||||||
rank=None,
|
num_repeats: int = 3,
|
||||||
shuffle=True,
|
seed: Optional[int] = None):
|
||||||
num_repeats=3,
|
rank, world_size = get_dist_info()
|
||||||
selected_round=256,
|
self.rank = rank
|
||||||
selected_ratio=0,
|
self.world_size = world_size
|
||||||
seed=0):
|
|
||||||
default_rank, default_world_size = get_dist_info()
|
|
||||||
rank = default_rank if rank is None else rank
|
|
||||||
num_replicas = (
|
|
||||||
default_world_size if num_replicas is None else num_replicas)
|
|
||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.num_replicas = num_replicas
|
|
||||||
self.rank = rank
|
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.num_repeats = num_repeats
|
if not self.shuffle and is_main_process():
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
|
logger = MMLogger.get_current_instance()
|
||||||
|
logger.warning('The RepeatAugSampler always picks a '
|
||||||
|
'fixed part of data if `shuffle=False`.')
|
||||||
|
|
||||||
|
if seed is None:
|
||||||
|
seed = sync_random_seed()
|
||||||
|
self.seed = seed
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.num_samples = int(
|
self.num_repeats = num_repeats
|
||||||
math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
|
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
|
||||||
# Determine the number of samples to select per epoch for each rank.
|
|
||||||
# num_selected logic defaults to be the same as original RASampler
|
|
||||||
# impl, but this one can be tweaked
|
|
||||||
# via selected_ratio and selected_round args.
|
|
||||||
selected_ratio = selected_ratio or num_replicas # ratio to reduce
|
|
||||||
# selected samples by, num_replicas if 0
|
|
||||||
if selected_round:
|
|
||||||
self.num_selected_samples = int(
|
|
||||||
math.floor(
|
|
||||||
len(self.dataset) // selected_round * selected_round /
|
|
||||||
selected_ratio))
|
|
||||||
else:
|
|
||||||
self.num_selected_samples = int(
|
|
||||||
math.ceil(len(self.dataset) / selected_ratio))
|
|
||||||
|
|
||||||
# In distributed sampling, different ranks should sample
|
# The number of repeated samples in the rank
|
||||||
# non-overlapped data in the dataset. Therefore, this function
|
self.num_samples = math.ceil(
|
||||||
# is used to make sure that each rank shuffles the data indices
|
len(self.dataset) * num_repeats / world_size)
|
||||||
# in the same order based on the same seed. Then different ranks
|
# The total number of repeated samples in all ranks.
|
||||||
# could use different indices to select non-overlapped data from the
|
self.total_size = self.num_samples * world_size
|
||||||
# same data list.
|
# The number of selected samples in the rank
|
||||||
self.seed = sync_random_seed(seed)
|
self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[int]:
|
||||||
# deterministically shuffle based on epoch
|
"""Iterate the indices."""
|
||||||
|
# deterministically shuffle based on epoch and seed
|
||||||
if self.shuffle:
|
if self.shuffle:
|
||||||
if self.num_replicas > 1: # In distributed environment
|
|
||||||
# deterministically shuffle based on epoch
|
|
||||||
g = torch.Generator()
|
g = torch.Generator()
|
||||||
# When :attr:`shuffle=True`, this ensures all replicas
|
g.manual_seed(self.seed + self.epoch)
|
||||||
# use a different random ordering for each epoch.
|
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||||
# Otherwise, the next iteration of this sampler will
|
|
||||||
# yield the same ordering.
|
|
||||||
g.manual_seed(self.epoch + self.seed)
|
|
||||||
indices = torch.randperm(
|
|
||||||
len(self.dataset), generator=g).tolist()
|
|
||||||
else:
|
|
||||||
indices = torch.randperm(len(self.dataset)).tolist()
|
|
||||||
else:
|
else:
|
||||||
indices = list(range(len(self.dataset)))
|
indices = list(range(len(self.dataset)))
|
||||||
|
|
||||||
|
@ -93,14 +78,24 @@ class RepeatAugSampler(Sampler):
|
||||||
assert len(indices) == self.total_size
|
assert len(indices) == self.total_size
|
||||||
|
|
||||||
# subsample per rank
|
# subsample per rank
|
||||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
indices = indices[self.rank:self.total_size:self.world_size]
|
||||||
assert len(indices) == self.num_samples
|
assert len(indices) == self.num_samples
|
||||||
|
|
||||||
# return up to num selected samples
|
# return up to num selected samples
|
||||||
return iter(indices[:self.num_selected_samples])
|
return iter(indices[:self.num_selected_samples])
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
|
"""The number of samples in this rank."""
|
||||||
return self.num_selected_samples
|
return self.num_selected_samples
|
||||||
|
|
||||||
def set_epoch(self, epoch):
|
def set_epoch(self, epoch: int) -> None:
|
||||||
|
"""Sets the epoch for this sampler.
|
||||||
|
|
||||||
|
When :attr:`shuffle=True`, this ensures all replicas use a different
|
||||||
|
random ordering for each epoch. Otherwise, the next iteration of this
|
||||||
|
sampler will yield the same ordering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (int): Epoch number.
|
||||||
|
"""
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
|
import math
|
||||||
|
from unittest import TestCase
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mmengine.logging import MMLogger
|
||||||
|
|
||||||
|
from mmcls.datasets import RepeatAugSampler
|
||||||
|
|
||||||
|
file = 'mmcls.datasets.samplers.repeat_aug.'
|
||||||
|
|
||||||
|
|
||||||
|
class MockDist:
|
||||||
|
|
||||||
|
def __init__(self, dist_info=(0, 1), seed=7):
|
||||||
|
self.dist_info = dist_info
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def get_dist_info(self):
|
||||||
|
return self.dist_info
|
||||||
|
|
||||||
|
def sync_random_seed(self):
|
||||||
|
return self.seed
|
||||||
|
|
||||||
|
def is_main_process(self):
|
||||||
|
return self.dist_info[0] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestRepeatAugSampler(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.data_length = 100
|
||||||
|
self.dataset = list(range(self.data_length))
|
||||||
|
|
||||||
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||||
|
def test_non_dist(self, mock):
|
||||||
|
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||||
|
self.assertEqual(sampler.world_size, 1)
|
||||||
|
self.assertEqual(sampler.rank, 0)
|
||||||
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||||
|
self.assertEqual(sampler.num_samples, self.data_length * 3)
|
||||||
|
self.assertEqual(sampler.num_selected_samples, self.data_length)
|
||||||
|
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
||||||
|
indices = [x for x in range(self.data_length) for _ in range(3)]
|
||||||
|
self.assertEqual(list(sampler), indices[:self.data_length])
|
||||||
|
|
||||||
|
logger = MMLogger.get_current_instance()
|
||||||
|
with self.assertLogs(logger, 'WARN') as log:
|
||||||
|
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
||||||
|
self.assertIn('always picks a fixed part', log.output[0])
|
||||||
|
|
||||||
|
@patch(file + 'get_dist_info', return_value=(2, 3))
|
||||||
|
@patch(file + 'is_main_process', return_value=False)
|
||||||
|
def test_dist(self, mock1, mock2):
|
||||||
|
sampler = RepeatAugSampler(self.dataset, num_repeats=3, shuffle=False)
|
||||||
|
self.assertEqual(sampler.world_size, 3)
|
||||||
|
self.assertEqual(sampler.rank, 2)
|
||||||
|
self.assertEqual(sampler.num_samples, self.data_length)
|
||||||
|
self.assertEqual(sampler.total_size, self.data_length * 3)
|
||||||
|
self.assertEqual(sampler.num_selected_samples,
|
||||||
|
math.ceil(self.data_length / 3))
|
||||||
|
self.assertEqual(len(sampler), sampler.num_selected_samples)
|
||||||
|
indices = [x for x in range(self.data_length) for _ in range(3)]
|
||||||
|
self.assertEqual(
|
||||||
|
list(sampler), indices[2::3][:sampler.num_selected_samples])
|
||||||
|
|
||||||
|
logger = MMLogger.get_current_instance()
|
||||||
|
with patch.object(logger, 'warning') as mock_log:
|
||||||
|
sampler = RepeatAugSampler(self.dataset, shuffle=False)
|
||||||
|
mock_log.assert_not_called()
|
||||||
|
|
||||||
|
@patch(file + 'get_dist_info', return_value=(0, 1))
|
||||||
|
@patch(file + 'sync_random_seed', return_value=7)
|
||||||
|
def test_shuffle(self, mock1, mock2):
|
||||||
|
# test seed=None
|
||||||
|
sampler = RepeatAugSampler(self.dataset, seed=None)
|
||||||
|
self.assertEqual(sampler.seed, 7)
|
||||||
|
|
||||||
|
# test random seed
|
||||||
|
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=0)
|
||||||
|
sampler.set_epoch(10)
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(10)
|
||||||
|
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||||
|
indices = [x for x in indices
|
||||||
|
for _ in range(3)][:sampler.num_selected_samples]
|
||||||
|
self.assertEqual(list(sampler), indices)
|
||||||
|
|
||||||
|
sampler = RepeatAugSampler(self.dataset, shuffle=True, seed=42)
|
||||||
|
sampler.set_epoch(10)
|
||||||
|
g = torch.Generator()
|
||||||
|
g.manual_seed(42 + 10)
|
||||||
|
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||||
|
indices = [x for x in indices
|
||||||
|
for _ in range(3)][:sampler.num_selected_samples]
|
||||||
|
self.assertEqual(list(sampler), indices)
|
Loading…
Reference in New Issue