mirror of https://github.com/JDAI-CV/fast-reid.git
update imbalanced sampler
Summary: add a new sampler that is useful for imbalanced or long-tailed datasets. It is based on ufoym/imbalanced-dataset-sampler. (pull/470/head)
parent
bb6ddbf8b1
commit
0c8e3d9805
fastreid
engine
evaluation
modeling/heads
|
@ -56,6 +56,8 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
|
|||
elif sampler_name == "SetReWeightSampler":
|
||||
set_weight = cfg.DATALOADER.SET_WEIGHT
|
||||
sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
|
||||
elif sampler_name == "ImbalancedDatasetSampler":
|
||||
sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
|
||||
else:
|
||||
raise ValueError("Unknown training sampler: {}".format(sampler_name))
|
||||
|
||||
|
|
|
@ -6,11 +6,13 @@
|
|||
|
||||
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
|
||||
from .data_sampler import TrainingSampler, InferenceSampler
|
||||
from .imbalance_sampler import ImbalancedDatasetSampler
|
||||
|
||||
__all__ = [
|
||||
"BalancedIdentitySampler",
|
||||
"NaiveIdentitySampler",
|
||||
"SetReWeightSampler",
|
||||
"TrainingSampler",
|
||||
"InferenceSampler"
|
||||
"InferenceSampler",
|
||||
"ImbalancedDatasetSampler",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
# encoding: utf-8
|
||||
"""
|
||||
@author: xingyu liao
|
||||
@contact: sherlockliao01@gmail.com
|
||||
"""
|
||||
|
||||
# based on:
|
||||
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
|
||||
|
||||
|
||||
import itertools
|
||||
from typing import Optional, List, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from fastreid.utils import comm
|
||||
|
||||
|
||||
class ImbalancedDatasetSampler(Sampler):
    """Infinite sampler that re-balances an imbalanced (long-tail) dataset.

    Each item is drawn (with replacement) with a weight of
    ``1 / (number of items sharing its label)``, so every label has the same
    total probability mass regardless of how many items it owns.

    The index stream is infinite. In ``__iter__`` each process keeps only every
    ``world_size``-th element of the stream, starting at its own rank.

    Arguments:
        data_source: a list of data items
        size: number of samples drawn per multinomial round; defaults to
            ``len(data_source)`` when not provided
        seed: seed of the sampling stream; when ``None`` the value returned by
            ``comm.shared_random_seed()`` is used
        callback_get_label: optional callable ``(data_source, idx) -> label``;
            when not provided the label is read from ``data_source[idx][1]``
    """

    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Optional[Callable] = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if size is not provided, draw `len(indices)` samples in each round
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # resolve every label exactly once; it is needed both for counting
        # and for computing the per-sample weight
        labels = [self._get_label(data_source, idx) for idx in self.indices]

        # distribution of classes in the dataset
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: inverse frequency of its label
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        """Return the label of ``dataset[idx]``, via the user callback if one was given."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            return dataset[idx][1]

    def __iter__(self):
        # Every rank walks the same seeded stream and keeps the elements at
        # positions rank, rank + world_size, rank + 2 * world_size, ...
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield an endless, seed-determined stream of dataset indices."""
        # torch.multinomial draws from torch's RNG, not NumPy's, so the stream
        # must be seeded through a torch.Generator. A dedicated generator keeps
        # the stream reproducible for a given seed and independent of whatever
        # else consumes the global RNG in this process.
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            drawn = torch.multinomial(self.weights, self._size, replacement=True, generator=g)
            # tolist() yields plain Python ints instead of 0-dim tensors
            for i in drawn.tolist():
                yield self.indices[i]
|
|
@ -7,7 +7,7 @@
|
|||
import copy
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
@ -39,7 +39,7 @@ def reorder_index(batch_indices, world_size):
|
|||
|
||||
|
||||
class BalancedIdentitySampler(Sampler):
|
||||
def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
|
||||
def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
|
||||
self.data_source = data_source
|
||||
self.num_instances = num_instances
|
||||
self.num_pids_per_batch = mini_batch_size // self.num_instances
|
||||
|
|
|
@ -149,7 +149,7 @@ class DefaultPredictor:
|
|||
Returns:
|
||||
predictions (torch.tensor): the output features of the model
|
||||
"""
|
||||
inputs = {"images": image}
|
||||
inputs = {"images": image.to(self.model.device)}
|
||||
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
|
||||
predictions = self.model(inputs)
|
||||
# Normalize feature to compute cosine distance
|
||||
|
|
|
@ -38,7 +38,6 @@ class ClasEvaluator(DatasetEvaluator):
|
|||
def __init__(self, cfg, output_dir=None):
|
||||
self.cfg = cfg
|
||||
self._output_dir = output_dir
|
||||
|
||||
self._cpu_device = torch.device('cpu')
|
||||
|
||||
self._predictions = []
|
||||
|
@ -49,7 +48,7 @@ class ClasEvaluator(DatasetEvaluator):
|
|||
def process(self, inputs, outputs):
|
||||
predictions = {
|
||||
"logits": outputs.to(self._cpu_device, torch.float32),
|
||||
"labels": inputs["targets"],
|
||||
"labels": inputs["targets"].to(self._cpu_device),
|
||||
}
|
||||
self._predictions.append(predictions)
|
||||
|
||||
|
|
|
@ -25,12 +25,12 @@ class ClasHead(EmbeddingHead):
|
|||
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
||||
|
||||
# Evaluation
|
||||
if not self.training: return logits * self.cls_layer.s
|
||||
if not self.training: return logits.mul_(self.cls_layer.s)
|
||||
|
||||
cls_outputs = self.cls_layer(logits, targets)
|
||||
|
||||
return {
|
||||
"cls_outputs": cls_outputs,
|
||||
"pred_class_logits": logits * self.cls_layer.s,
|
||||
"pred_class_logits": logits.mul_(self.cls_layer.s),
|
||||
"features": neck_feat,
|
||||
}
|
||||
|
|
|
@ -142,6 +142,6 @@ class EmbeddingHead(nn.Module):
|
|||
|
||||
return {
|
||||
"cls_outputs": cls_outputs,
|
||||
"pred_class_logits": logits * self.cls_layer.s,
|
||||
"pred_class_logits": logits.mul(self.cls_layer.s),
|
||||
"features": feat,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue