37 lines
1.3 KiB
Python
37 lines
1.3 KiB
Python
from __future__ import absolute_import

from collections import defaultdict

import numpy as np
import torch
|
|
|
|
class RandomIdentitySampler(object):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py.

    Every identity found in ``data_source`` is visited exactly once per
    iteration, in random order, and contributes ``num_instances``
    consecutive indices. An identity with fewer than ``num_instances``
    items is sampled with replacement, so its indices may repeat.

    Args:
        data_source (Dataset): dataset to sample from. Iterating it must
            yield 3-tuples whose second element is the identity (pid).
        num_instances (int): number of instances per identity.
    """

    def __init__(self, data_source, num_instances=4):
        self.data_source = data_source
        self.num_instances = num_instances
        # Map each identity to the dataset positions that belong to it.
        self.index_dic = defaultdict(list)
        for index, (_, pid, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())
        self.num_identities = len(self.pids)

    def __iter__(self):
        """Return an iterator over one freshly shuffled pass of indices."""
        # tolist() turns the permutation into plain ints so the list lookup
        # below does not depend on tensor-to-index conversion.
        order = torch.randperm(self.num_identities).tolist()
        ret = []
        for i in order:
            candidates = self.index_dic[self.pids[i]]
            # Too few items for this identity: allow repeats so that every
            # identity still contributes exactly num_instances indices.
            replace = len(candidates) < self.num_instances
            chosen = np.random.choice(
                candidates, size=self.num_instances, replace=replace)
            # Emit Python ints rather than numpy scalars.
            ret.extend(chosen.tolist())
        return iter(ret)

    def __len__(self):
        """Return the number of indices produced by one full iteration."""
        return self.num_identities * self.num_instances