update imbalanced sampler

Summary: add a new sampler that is useful for imbalanced or long-tailed datasets. It is based on ufoym/imbalanced-dataset-sampler.
pull/470/head
liaoxingyu 2021-04-21 17:05:10 +08:00
parent bb6ddbf8b1
commit 0c8e3d9805
8 changed files with 79 additions and 9 deletions

View File

@ -56,6 +56,8 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
elif sampler_name == "SetReWeightSampler":
set_weight = cfg.DATALOADER.SET_WEIGHT
sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
elif sampler_name == "ImbalancedDatasetSampler":
sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
else:
raise ValueError("Unknown training sampler: {}".format(sampler_name))

View File

@ -6,11 +6,13 @@
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
from .data_sampler import TrainingSampler, InferenceSampler
from .imbalance_sampler import ImbalancedDatasetSampler
__all__ = [
"BalancedIdentitySampler",
"NaiveIdentitySampler",
"SetReWeightSampler",
"TrainingSampler",
"InferenceSampler"
"InferenceSampler",
"ImbalancedDatasetSampler",
]

View File

@ -0,0 +1,67 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on:
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
import itertools
from typing import Optional, List, Callable
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from fastreid.utils import comm
class ImbalancedDatasetSampler(Sampler):
    """Infinite, class-rebalancing index sampler for imbalanced datasets.

    Each item is drawn with a weight of ``1 / (number of items sharing its
    label)``, so every label is equally likely to be picked regardless of how
    many items carry it. Indices are drawn with replacement, ``size`` at a
    time, in an endless stream.

    The stream depends only on ``seed``: instances built with the same seed
    and data produce the same stream, and each instance keeps every
    ``world_size``-th element of it, starting at its own rank.

    Arguments:
        data_source: a list of data items. Unless ``callback_get_label`` is
            given, the label of item ``i`` is read from ``data_source[i][1]``.
        size: number of samples to draw per round; defaults to
            ``len(data_source)``.
        seed: seed of the index stream. When ``None`` it is taken from
            ``comm.shared_random_seed()`` (presumably identical on all
            ranks -- confirm against ``fastreid.utils.comm``).
        callback_get_label: optional ``f(data_source, idx) -> label`` used
            instead of the default label lookup.
    """

    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Callable = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if size is not provided, draw `len(indices)` samples in each round
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # Resolve every label once; it is needed both for counting and for
        # the per-sample weight.
        labels = [self._get_label(data_source, idx) for idx in self.indices]

        # distribution of classes in the dataset
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: inverse frequency of its label
        self.weights = torch.tensor(
            [1.0 / label_to_count[label] for label in labels], dtype=torch.double
        )

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        """Return the label of ``dataset[idx]``, via the callback if one is set."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        return dataset[idx][1]

    def __iter__(self):
        # Keep every `world_size`-th element of the stream, offset by rank.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield an endless, seed-determined stream of rebalanced indices."""
        # Retained from the original implementation: it also seeds NumPy's
        # global RNG. torch.multinomial does not consume that RNG.
        np.random.seed(self._seed)
        # torch.multinomial draws from torch's RNG, so the stream is only
        # reproducible (and only identical across instances with the same
        # seed) when it uses a torch generator seeded with `self._seed`.
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        while True:
            drawn = torch.multinomial(
                self.weights, self._size, replacement=True, generator=generator
            )
            for i in drawn.tolist():
                yield self.indices[i]

View File

@ -7,7 +7,7 @@
import copy
import itertools
from collections import defaultdict
from typing import Optional
from typing import Optional, List
import numpy as np
from torch.utils.data.sampler import Sampler
@ -39,7 +39,7 @@ def reorder_index(batch_indices, world_size):
class BalancedIdentitySampler(Sampler):
def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
self.data_source = data_source
self.num_instances = num_instances
self.num_pids_per_batch = mini_batch_size // self.num_instances

View File

@ -149,7 +149,7 @@ class DefaultPredictor:
Returns:
predictions (torch.tensor): the output features of the model
"""
inputs = {"images": image}
inputs = {"images": image.to(self.model.device)}
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
predictions = self.model(inputs)
# Normalize feature to compute cosine distance

View File

@ -38,7 +38,6 @@ class ClasEvaluator(DatasetEvaluator):
def __init__(self, cfg, output_dir=None):
self.cfg = cfg
self._output_dir = output_dir
self._cpu_device = torch.device('cpu')
self._predictions = []
@ -49,7 +48,7 @@ class ClasEvaluator(DatasetEvaluator):
def process(self, inputs, outputs):
predictions = {
"logits": outputs.to(self._cpu_device, torch.float32),
"labels": inputs["targets"],
"labels": inputs["targets"].to(self._cpu_device),
}
self._predictions.append(predictions)

View File

@ -25,12 +25,12 @@ class ClasHead(EmbeddingHead):
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
# Evaluation
if not self.training: return logits * self.cls_layer.s
if not self.training: return logits.mul_(self.cls_layer.s)
cls_outputs = self.cls_layer(logits, targets)
return {
"cls_outputs": cls_outputs,
"pred_class_logits": logits * self.cls_layer.s,
"pred_class_logits": logits.mul_(self.cls_layer.s),
"features": neck_feat,
}

View File

@ -142,6 +142,6 @@ class EmbeddingHead(nn.Module):
return {
"cls_outputs": cls_outputs,
"pred_class_logits": logits * self.cls_layer.s,
"pred_class_logits": logits.mul(self.cls_layer.s),
"features": feat,
}