# mirror of https://github.com/YifanXu74/MQ-Det.git
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import itertools

import torch
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import Sampler

class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.
    It enforces that elements from the same group appear in groups of
    batch_size. It also tries to provide mini-batches that follow an
    ordering as close as possible to the ordering from the original sampler.

    Arguments:
        sampler (Sampler): Base sampler.
        group_ids (list[int]): Group id of each element in the dataset.
        batch_size (int): Size of mini-batch.
        drop_uneven (bool): If ``True``, the sampler will drop the batches
            whose size is less than ``batch_size``.
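
    Example (an illustrative sketch, not from the original docstring; it
    assumes six elements in two groups and a sequential base sampler)::

        >>> from torch.utils.data.sampler import SequentialSampler
        >>> group_ids = [0, 0, 1, 0, 1, 1]
        >>> sampler = SequentialSampler(range(len(group_ids)))
        >>> list(GroupedBatchSampler(sampler, group_ids, batch_size=2))
        [[0, 1], [2, 4], [3], [5]]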
"""

    def __init__(self, sampler, group_ids, batch_size, drop_uneven=False):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of "
                "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = torch.as_tensor(group_ids)
        assert self.group_ids.dim() == 1
        self.batch_size = batch_size
        self.drop_uneven = drop_uneven

        # sorted tensor of the distinct group ids present in group_ids
        self.groups = torch.unique(self.group_ids).sort(0)[0]

        self._can_reuse_batches = False

    def _prepare_batches(self):
        dataset_size = len(self.group_ids)
        # get the sampled indices from the sampler
        sampled_ids = torch.as_tensor(list(self.sampler))
        # potentially not all elements of the dataset were sampled
        # by the sampler (e.g., DistributedSampler).
        # construct a tensor which contains -1 if the element was
        # not sampled, and a non-negative number indicating the
        # order in which the element was sampled.
        # for example, if sampled_ids = [3, 1] and dataset_size = 5,
        # the order is [-1, 1, -1, 0, -1]
        order = torch.full((dataset_size,), -1, dtype=torch.int64)
        order[sampled_ids] = torch.arange(len(sampled_ids))

        # get a mask with the elements that were sampled
        mask = order >= 0

        # find the elements that belong to each individual cluster
        clusters = [(self.group_ids == i) & mask for i in self.groups]
        # get the relative order of the elements inside each cluster
        # that follows the order from the sampler
        relative_order = [order[cluster] for cluster in clusters]
        # with the relative order, find the absolute order in the
        # sampled space (s[s.sort()[1]] is simply s sorted ascending)
        permutation_ids = [s[s.sort()[1]] for s in relative_order]
        # permute each cluster so that it follows the order from
        # the sampler
        permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]

        # drop clusters that received no samples from the sampler
        new_permuted_clusters = []
        for c in permuted_clusters:
            if len(c) != 0:
                new_permuted_clusters.append(c)
        permuted_clusters = new_permuted_clusters

        # split each cluster into chunks of batch_size, and merge them
        # into a single flat tuple of tensors
        splits = [c.split(self.batch_size) for c in permuted_clusters]
        merged = tuple(itertools.chain.from_iterable(splits))
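        # illustrative example (assumed values, not from the original code):
        # with batch_size = 2 and permuted_clusters =
        # [tensor([3, 5, 7]), tensor([0, 2])], this yields
        # merged = (tensor([3, 5]), tensor([7]), tensor([0, 2]))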

        # now each batch internally has the right order, but the batches are
        # grouped by cluster. Find the permutation between different batches
        # that brings them as close as possible to the order that we have in
        # the sampler. For that, we will consider the ordering as coming from
        # the first element of each batch, and sort correspondingly
        first_element_of_batch = [t[0].item() for t in merged]
        # get an inverse mapping from sampled indices to the position where
        # they occur (as returned by the sampler)
        inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
        # from the first element in each batch, get a relative ordering
        first_index_of_batch = torch.as_tensor(
            [inv_sampled_ids_map[s] for s in first_element_of_batch]
        )

        # permute the batches so that they approximately follow the order
        # from the sampler
        permutation_order = first_index_of_batch.sort(0)[1].tolist()
        # finally, permute the batches
        batches = [merged[i].tolist() for i in permutation_order]

        if self.drop_uneven:
            kept = []
            for batch in batches:
                if len(batch) == self.batch_size:
                    kept.append(batch)
            batches = kept
        return batches

    def __iter__(self):
        if self._can_reuse_batches:
            batches = self._batches
            self._can_reuse_batches = False
        else:
            batches = self._prepare_batches()
            self._batches = batches
        return iter(batches)

    def __len__(self):
        if not hasattr(self, "_batches"):
            self._batches = self._prepare_batches()
            self._can_reuse_batches = True
        return len(self._batches)
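

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # file). Six elements in two groups, a sequential base sampler, and
    # drop_uneven=True so that undersized batches are discarded.
    from torch.utils.data.sampler import SequentialSampler

    group_ids = [0, 0, 1, 0, 1, 1]
    base_sampler = SequentialSampler(range(len(group_ids)))
    batch_sampler = GroupedBatchSampler(
        base_sampler, group_ids, batch_size=2, drop_uneven=True
    )
    # the single-element batches [3] and [5] are dropped
    print(list(batch_sampler))  # [[0, 1], [2, 4]]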