fast-reid/fastreid/data/samplers/training_sampler.py

50 lines
1.7 KiB
Python
Raw Normal View History

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
from typing import Optional
import numpy as np
from torch.utils.data import Sampler
class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices, and all workers
    can cooperate to shuffle the indices and sample disjoint indices.

    The sampler on worker ``rank`` yields ``indices[rank::world_size]``,
    where ``indices`` is an infinite stream of indices consisting of
    ``shuffle(range(size)) + shuffle(range(size)) + ...`` (if shuffle is True)
    or ``range(size) + range(size) + ...`` (if shuffle is False).
    With the defaults ``rank=0, world_size=1`` the full stream is yielded.
    """

    def __init__(
        self,
        size: int,
        shuffle: bool = True,
        seed: Optional[int] = None,
        rank: int = 0,
        world_size: int = 1,
    ):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers so that they walk the same index stream.
                If None, a seed is drawn from numpy's global RNG in this
                process only. It is NOT synchronized across workers, so pass
                an explicit seed when ``world_size > 1``.
            rank (int): index of this worker, in ``[0, world_size)``.
            world_size (int): number of workers splitting the stream.

        Raises:
            ValueError: if ``size`` is not positive, ``world_size`` is less
                than 1, or ``rank`` is outside ``[0, world_size)``.
        """
        # An empty dataset would make the index generator spin forever
        # without yielding, so reject it explicitly (asserts vanish under -O).
        if size <= 0:
            raise ValueError(f"size must be positive, got {size}")
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
        self._size = size
        self._shuffle = shuffle
        if seed is None:
            seed = np.random.randint(2 ** 31)
        self._seed = int(seed)
        self._rank = rank
        self._world_size = world_size

    def __iter__(self):
        # Each worker takes every world_size-th index, starting at its rank.
        yield from itertools.islice(
            self._infinite_indices(), self._rank, None, self._world_size
        )

    def _infinite_indices(self):
        """Yield dataset indices forever, one full pass over ``range(size)`` per epoch."""
        # A private RandomState seeded like the global one produces the same
        # permutation sequence, but neither reads nor clobbers np.random's
        # process-global state (which other code, e.g. augmentations, uses).
        rng = np.random.RandomState(self._seed)
        while True:
            if self._shuffle:
                yield from rng.permutation(self._size)
            else:
                yield from np.arange(self._size)