import math
from typing import Iterator, Optional, Sized

import torch
from mmengine.dist import get_dist_info, is_main_process, sync_random_seed
from torch.utils.data import Sampler

from mmpretrain.registry import DATA_SAMPLERS


@DATA_SAMPLERS.register_module()
class RepeatAugSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset for
    distributed training, with repeated augmentation. It ensures that each
    augmented version of a sample will be visible to a different process
    (GPU). Heavily based on torch.utils.data.DistributedSampler.

    This sampler was taken from
    https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
    Copyright (c) 2015-present, Facebook, Inc.

    Args:
        dataset (Sized): The dataset.
        shuffle (bool): Whether to shuffle the dataset or not.
            Defaults to True.
        num_repeats (int): The repeat times of every sample. Defaults to 3.
        seed (int, optional): Random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Defaults to None.
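
    Examples:
        A minimal sketch of how this sampler is usually enabled through the
        dataloader config (every field except ``sampler`` is an illustrative
        placeholder)::

            train_dataloader = dict(
                batch_size=64,
                sampler=dict(type='RepeatAugSampler', shuffle=True),
                dataset=...,
            )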
    """

    def __init__(self,
                 dataset: Sized,
                 shuffle: bool = True,
                 num_repeats: int = 3,
                 seed: Optional[int] = None):
        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

        self.dataset = dataset
        self.shuffle = shuffle
        if not self.shuffle and is_main_process():
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.warning('The RepeatAugSampler always picks a '
                           'fixed part of data if `shuffle=False`.')

        if seed is None:
            seed = sync_random_seed()
        self.seed = seed
        self.epoch = 0
        self.num_repeats = num_repeats

        # The number of repeated samples in the rank
        self.num_samples = math.ceil(
            len(self.dataset) * num_repeats / world_size)
        # The total number of repeated samples in all ranks.
        self.total_size = self.num_samples * world_size
        # The number of selected samples in the rank
        self.num_selected_samples = math.ceil(len(self.dataset) / world_size)
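        # Only `num_selected_samples` indices are yielded per rank (see
        # `__iter__`), so one pass over this sampler covers roughly
        # len(dataset) / num_repeats distinct samples across all ranks.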

    def __iter__(self) -> Iterator[int]:
        """Iterate the indices."""
        # deterministically shuffle based on epoch and seed
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
        indices = [x for x in indices for _ in range(self.num_repeats)]
        # add extra samples to make it evenly divisible
        padding_size = self.total_size - len(indices)
        indices += indices[:padding_size]
        assert len(indices) == self.total_size

        # subsample per rank
        indices = indices[self.rank:self.total_size:self.world_size]
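        # Because the `num_repeats` copies of each sample are adjacent in
        # `indices`, this strided slice tends to place different copies of
        # the same sample on different ranks.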
        assert len(indices) == self.num_samples

        # return up to num selected samples
        return iter(indices[:self.num_selected_samples])

    def __len__(self) -> int:
        """The number of samples in this rank."""
        return self.num_selected_samples

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        When :attr:`shuffle=True`, this ensures all replicas use a different
        random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch