mirror of https://github.com/JDAI-CV/fast-reid.git
update imbalanced sampler
Summary: add a new sampler that is useful for imbalanced or long-tailed datasets. It is based on ufoym/imbalanced-dataset-sampler. pull/470/head
parent
bb6ddbf8b1
commit
0c8e3d9805
|
@ -56,6 +56,8 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N
|
||||||
elif sampler_name == "SetReWeightSampler":
|
elif sampler_name == "SetReWeightSampler":
|
||||||
set_weight = cfg.DATALOADER.SET_WEIGHT
|
set_weight = cfg.DATALOADER.SET_WEIGHT
|
||||||
sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
|
sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
|
||||||
|
elif sampler_name == "ImbalancedDatasetSampler":
|
||||||
|
sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown training sampler: {}".format(sampler_name))
|
raise ValueError("Unknown training sampler: {}".format(sampler_name))
|
||||||
|
|
||||||
|
|
|
@ -6,11 +6,13 @@
|
||||||
|
|
||||||
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
|
from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
|
||||||
from .data_sampler import TrainingSampler, InferenceSampler
|
from .data_sampler import TrainingSampler, InferenceSampler
|
||||||
|
from .imbalance_sampler import ImbalancedDatasetSampler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BalancedIdentitySampler",
|
"BalancedIdentitySampler",
|
||||||
"NaiveIdentitySampler",
|
"NaiveIdentitySampler",
|
||||||
"SetReWeightSampler",
|
"SetReWeightSampler",
|
||||||
"TrainingSampler",
|
"TrainingSampler",
|
||||||
"InferenceSampler"
|
"InferenceSampler",
|
||||||
|
"ImbalancedDatasetSampler",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# encoding: utf-8
|
||||||
|
"""
|
||||||
|
@author: xingyu liao
|
||||||
|
@contact: sherlockliao01@gmail.com
|
||||||
|
"""
|
||||||
|
|
||||||
|
# based on:
|
||||||
|
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
|
||||||
|
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from typing import Optional, List, Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
|
from fastreid.utils import comm
|
||||||
|
|
||||||
|
|
||||||
|
class ImbalancedDatasetSampler(Sampler):
    """Infinite index sampler that re-balances an imbalanced dataset.

    Each item is drawn (with replacement) with a weight of ``1 / count``,
    where ``count`` is the number of items that share the item's label, so
    every label receives the same total sampling mass.

    The sampler yields an endless stream of indices. Each rank keeps every
    ``world_size``-th element of that stream, starting at its own rank.

    Arguments:
        data_source: a list of data items. Unless ``callback_get_label`` is
            given, the label of item ``i`` is read from ``data_source[i][1]``.
            Labels are used as dictionary keys and must be hashable.
        size: number of samples drawn per multinomial round; defaults to
            ``len(data_source)``.
        seed: seed of the sampling stream. When ``None`` the value returned
            by ``comm.shared_random_seed()`` is used (presumably identical on
            all ranks -- verify against ``fastreid.utils.comm``).
        callback_get_label: optional callable ``f(data_source, idx)``
            returning the label of item ``idx``.

    Raises:
        ValueError: if ``data_source`` is empty or ``size`` is not positive.
    """

    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Optional[Callable] = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        if not self.indices:
            raise ValueError("ImbalancedDatasetSampler requires a non-empty data_source")

        # if size is not provided, draw `len(indices)` samples in each round
        self._size = len(self.indices) if size is None else size
        if self._size <= 0:
            raise ValueError("size must be a positive integer, got {}".format(self._size))

        self.callback_get_label = callback_get_label

        # Look every label up exactly once; the callback may be expensive.
        labels = [self._get_label(data_source, idx) for idx in self.indices]

        # distribution of classes in the dataset
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: inverse frequency of its label
        self.weights = torch.DoubleTensor([1.0 / label_to_count[label] for label in labels])

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        """Return the label of ``dataset[idx]``.

        Uses ``callback_get_label`` when one was supplied, otherwise reads the
        second field of the item.
        """
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        return dataset[idx][1]

    def __iter__(self):
        """Yield this rank's share of the infinite index stream."""
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield dataset indices forever, reproducibly for a given seed.

        ``torch.multinomial`` does not consult NumPy's random state, so the
        stream is driven by a private ``torch.Generator`` seeded with
        ``self._seed``. This makes the sequence depend only on the seed and
        leaves the global NumPy/PyTorch RNG state untouched.
        """
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        while True:
            drawn = torch.multinomial(self.weights, self._size, replacement=True, generator=generator)
            # tolist() converts the 0-dim tensors into plain Python ints.
            for i in drawn.tolist():
                yield self.indices[i]
|
|
@ -7,7 +7,7 @@
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
@ -39,7 +39,7 @@ def reorder_index(batch_indices, world_size):
|
||||||
|
|
||||||
|
|
||||||
class BalancedIdentitySampler(Sampler):
|
class BalancedIdentitySampler(Sampler):
|
||||||
def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
|
def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, seed: Optional[int] = None):
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
self.num_instances = num_instances
|
self.num_instances = num_instances
|
||||||
self.num_pids_per_batch = mini_batch_size // self.num_instances
|
self.num_pids_per_batch = mini_batch_size // self.num_instances
|
||||||
|
|
|
@ -149,7 +149,7 @@ class DefaultPredictor:
|
||||||
Returns:
|
Returns:
|
||||||
predictions (torch.tensor): the output features of the model
|
predictions (torch.tensor): the output features of the model
|
||||||
"""
|
"""
|
||||||
inputs = {"images": image}
|
inputs = {"images": image.to(self.model.device)}
|
||||||
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
|
with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258
|
||||||
predictions = self.model(inputs)
|
predictions = self.model(inputs)
|
||||||
# Normalize feature to compute cosine distance
|
# Normalize feature to compute cosine distance
|
||||||
|
|
|
@ -38,7 +38,6 @@ class ClasEvaluator(DatasetEvaluator):
|
||||||
def __init__(self, cfg, output_dir=None):
|
def __init__(self, cfg, output_dir=None):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._output_dir = output_dir
|
self._output_dir = output_dir
|
||||||
|
|
||||||
self._cpu_device = torch.device('cpu')
|
self._cpu_device = torch.device('cpu')
|
||||||
|
|
||||||
self._predictions = []
|
self._predictions = []
|
||||||
|
@ -49,7 +48,7 @@ class ClasEvaluator(DatasetEvaluator):
|
||||||
def process(self, inputs, outputs):
|
def process(self, inputs, outputs):
|
||||||
predictions = {
|
predictions = {
|
||||||
"logits": outputs.to(self._cpu_device, torch.float32),
|
"logits": outputs.to(self._cpu_device, torch.float32),
|
||||||
"labels": inputs["targets"],
|
"labels": inputs["targets"].to(self._cpu_device),
|
||||||
}
|
}
|
||||||
self._predictions.append(predictions)
|
self._predictions.append(predictions)
|
||||||
|
|
||||||
|
|
|
@ -25,12 +25,12 @@ class ClasHead(EmbeddingHead):
|
||||||
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight))
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
if not self.training: return logits * self.cls_layer.s
|
if not self.training: return logits.mul_(self.cls_layer.s)
|
||||||
|
|
||||||
cls_outputs = self.cls_layer(logits, targets)
|
cls_outputs = self.cls_layer(logits, targets)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cls_outputs": cls_outputs,
|
"cls_outputs": cls_outputs,
|
||||||
"pred_class_logits": logits * self.cls_layer.s,
|
"pred_class_logits": logits.mul_(self.cls_layer.s),
|
||||||
"features": neck_feat,
|
"features": neck_feat,
|
||||||
}
|
}
|
||||||
|
|
|
@ -142,6 +142,6 @@ class EmbeddingHead(nn.Module):
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cls_outputs": cls_outputs,
|
"cls_outputs": cls_outputs,
|
||||||
"pred_class_logits": logits * self.cls_layer.s,
|
"pred_class_logits": logits.mul(self.cls_layer.s),
|
||||||
"features": feat,
|
"features": feat,
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue