rm deprecated code
parent
ec44c2e2f6
commit
90ca844a9f
27
samplers.py
27
samplers.py
|
@ -10,33 +10,6 @@ import torch
|
|||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
""" Deprecated
|
||||
class RandomIdentitySampler(Sampler):
|
||||
def __init__(self, data_source, num_instances=4):
|
||||
self.data_source = data_source
|
||||
self.num_instances = num_instances
|
||||
self.index_dic = defaultdict(list)
|
||||
for index, (_, pid, _) in enumerate(data_source):
|
||||
self.index_dic[pid].append(index)
|
||||
self.pids = list(self.index_dic.keys())
|
||||
self.num_identities = len(self.pids)
|
||||
|
||||
def __iter__(self):
|
||||
indices = torch.randperm(self.num_identities)
|
||||
ret = []
|
||||
for i in indices:
|
||||
pid = self.pids[i]
|
||||
t = self.index_dic[pid]
|
||||
replace = False if len(t) >= self.num_instances else True
|
||||
t = np.random.choice(t, size=self.num_instances, replace=replace)
|
||||
ret.extend(t)
|
||||
return iter(ret)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_identities * self.num_instances
|
||||
"""
|
||||
|
||||
|
||||
class RandomIdentitySampler(Sampler):
|
||||
"""
|
||||
Randomly sample N identities, then for each identity,
|
||||
|
|
Loading…
Reference in New Issue