diff --git a/fastreid/data/build.py b/fastreid/data/build.py index 34ad31a..096551c 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -56,6 +56,8 @@ def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=N elif sampler_name == "SetReWeightSampler": set_weight = cfg.DATALOADER.SET_WEIGHT sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight) + elif sampler_name == "ImbalancedDatasetSampler": + sampler = samplers.ImbalancedDatasetSampler(train_set.img_items) else: raise ValueError("Unknown training sampler: {}".format(sampler_name)) diff --git a/fastreid/data/samplers/__init__.py b/fastreid/data/samplers/__init__.py index 0f5f268..37ffa15 100644 --- a/fastreid/data/samplers/__init__.py +++ b/fastreid/data/samplers/__init__.py @@ -6,11 +6,13 @@ from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler from .data_sampler import TrainingSampler, InferenceSampler +from .imbalance_sampler import ImbalancedDatasetSampler __all__ = [ "BalancedIdentitySampler", "NaiveIdentitySampler", "SetReWeightSampler", "TrainingSampler", - "InferenceSampler" + "InferenceSampler", + "ImbalancedDatasetSampler", ] diff --git a/fastreid/data/samplers/imbalance_sampler.py b/fastreid/data/samplers/imbalance_sampler.py new file mode 100644 index 0000000..90da772 --- /dev/null +++ b/fastreid/data/samplers/imbalance_sampler.py @@ -0,0 +1,67 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +# based on: +# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py + + +import itertools +from typing import Optional, List, Callable + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +from fastreid.utils import comm + + +class ImbalancedDatasetSampler(Sampler): + """Samples elements randomly from a given list of indices for imbalanced dataset + Arguments: + data_source: a list of data items + size: number of samples to draw + """ + + def __init__(self, data_source: List, size: int = None, seed: Optional[int] = None, + callback_get_label: Callable = None): + self.data_source = data_source + # consider all elements in the dataset + self.indices = list(range(len(data_source))) + # if num_samples is not provided, draw `len(indices)` samples in each iteration + self._size = len(self.indices) if size is None else size + self.callback_get_label = callback_get_label + + # distribution of classes in the dataset + label_to_count = {} + for idx in self.indices: + label = self._get_label(data_source, idx) + label_to_count[label] = label_to_count.get(label, 0) + 1 + + # weight for each sample + weights = [1.0 / label_to_count[self._get_label(data_source, idx)] for idx in self.indices] + self.weights = torch.DoubleTensor(weights) + + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + def _get_label(self, dataset, idx): + if self.callback_get_label: + return self.callback_get_label(dataset, idx) + else: + return dataset[idx][1] + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + np.random.seed(self._seed) + while True: + for i in torch.multinomial(self.weights, self._size, replacement=True): + yield self.indices[i] diff --git a/fastreid/data/samplers/triplet_sampler.py b/fastreid/data/samplers/triplet_sampler.py index 1b5a0ad..72b17d0 100644 --- a/fastreid/data/samplers/triplet_sampler.py +++ b/fastreid/data/samplers/triplet_sampler.py @@ -7,7 +7,7 @@ import copy import itertools from collections import defaultdict -from typing import Optional +from typing import Optional, List import numpy as np from torch.utils.data.sampler import Sampler @@ -39,7 +39,7 @@ def reorder_index(batch_indices, world_size): class BalancedIdentitySampler(Sampler): - def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, seed: Optional[int] = None): + def __init__(self, data_source: List, mini_batch_size: int, num_instances: int, seed: Optional[int] = None): self.data_source = data_source self.num_instances = num_instances self.num_pids_per_batch = mini_batch_size // self.num_instances diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py index e11e8b0..2fcf900 100644 --- a/fastreid/engine/defaults.py +++ b/fastreid/engine/defaults.py @@ -149,7 +149,7 @@ class DefaultPredictor: Returns: predictions (torch.tensor): the output features of the model """ - inputs = {"images": image} + inputs = {"images": image.to(self.model.device)} with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 predictions = self.model(inputs) # Normalize feature to compute cosine distance diff --git a/fastreid/evaluation/clas_evaluator.py b/fastreid/evaluation/clas_evaluator.py index d433131..f40cc8e 100644 --- a/fastreid/evaluation/clas_evaluator.py +++ b/fastreid/evaluation/clas_evaluator.py @@ -38,7 +38,6 @@ class ClasEvaluator(DatasetEvaluator): def __init__(self, cfg, output_dir=None): self.cfg = cfg self._output_dir = output_dir - self._cpu_device = torch.device('cpu') self._predictions = [] @@ -49,7 +48,7 @@ class ClasEvaluator(DatasetEvaluator): def process(self, inputs, outputs): predictions = { "logits": outputs.to(self._cpu_device, torch.float32), - "labels": inputs["targets"], + "labels": inputs["targets"].to(self._cpu_device), } self._predictions.append(predictions) diff --git a/fastreid/modeling/heads/clas_head.py b/fastreid/modeling/heads/clas_head.py index 9b6aa6b..9660a38 100644 --- a/fastreid/modeling/heads/clas_head.py +++ b/fastreid/modeling/heads/clas_head.py @@ -25,12 +25,12 @@ class ClasHead(EmbeddingHead): logits = F.linear(F.normalize(neck_feat), F.normalize(self.weight)) # Evaluation - if not self.training: return logits * self.cls_layer.s + if not self.training: return logits.mul_(self.cls_layer.s) cls_outputs = self.cls_layer(logits, targets) return { "cls_outputs": cls_outputs, - "pred_class_logits": logits * self.cls_layer.s, + "pred_class_logits": logits.mul_(self.cls_layer.s), "features": neck_feat, } diff --git a/fastreid/modeling/heads/embedding_head.py b/fastreid/modeling/heads/embedding_head.py index 5fce900..1189b60 100644 --- a/fastreid/modeling/heads/embedding_head.py +++ b/fastreid/modeling/heads/embedding_head.py @@ -142,6 +142,6 @@ class EmbeddingHead(nn.Module): return { "cls_outputs": cls_outputs, - "pred_class_logits": logits * self.cls_layer.s, + "pred_class_logits": logits.mul(self.cls_layer.s), "features": feat, }