fast-reid/fastreid/data/samplers/imbalance_sampler.py

68 lines
2.2 KiB
Python

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on:
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
import itertools
from typing import Optional, List, Callable
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from fastreid.utils import comm
class ImbalancedDatasetSampler(Sampler):
    """Infinite sampler that draws dataset indices with class-balanced probability.

    Each index is weighted by ``1 / count(label)``, so the total weight of every
    label is 1 and each label is equally likely on every draw. Indices are drawn
    with replacement in rounds of ``size`` samples, forever. Every rank generates
    the same index stream from ``seed`` and keeps every ``world_size``-th element,
    starting at its own rank.

    Args:
        data_source: a list of data items.
        size: number of samples to draw per round; defaults to ``len(data_source)``.
        seed: seed for the sampling RNG. It must be identical on all ranks so that
            they share one index stream; defaults to ``comm.shared_random_seed()``.
        callback_get_label: optional ``callback(dataset, idx) -> label``. When not
            given, the label is read as ``dataset[idx][1]``.
    """

    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Optional[Callable] = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if size is not provided, draw `len(indices)` samples in each round
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # per-sample weight = inverse frequency of the sample's label
        # (must run after callback_get_label is set, since _get_label reads it)
        self.weights = self._compute_weights(data_source)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _compute_weights(self, data_source) -> torch.Tensor:
        """Return a float64 tensor with ``1 / count(label)`` for each index in ``self.indices``."""
        # Resolve every label exactly once; a user callback may be expensive.
        labels = [self._get_label(data_source, idx) for idx in self.indices]
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1
        return torch.DoubleTensor([1.0 / label_to_count[label] for label in labels])

    def _get_label(self, dataset, idx):
        """Return the label of item ``idx``, via the user callback if one was given."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        return dataset[idx][1]

    def __iter__(self):
        # All ranks walk the same stream; each keeps every world_size-th index from its rank.
        yield from itertools.islice(self._infinite_indices(), self._rank, None, self._world_size)

    def _infinite_indices(self):
        """Yield weighted indices forever, deterministically derived from ``self._seed``."""
        # torch.multinomial consumes torch's RNG, so seeding numpy had no effect on it.
        # A dedicated seeded generator makes the stream reproducible and identical on
        # every rank, without touching any global RNG state.
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        while True:
            draws = torch.multinomial(self.weights, self._size, replacement=True,
                                      generator=generator)
            for i in draws.tolist():
                yield self.indices[i]