rm deprecated code
parent
ec44c2e2f6
commit
90ca844a9f
27
samplers.py
27
samplers.py
|
@ -10,33 +10,6 @@ import torch
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
|
|
||||||
""" Deprecated
|
|
||||||
class RandomIdentitySampler(Sampler):
|
|
||||||
def __init__(self, data_source, num_instances=4):
|
|
||||||
self.data_source = data_source
|
|
||||||
self.num_instances = num_instances
|
|
||||||
self.index_dic = defaultdict(list)
|
|
||||||
for index, (_, pid, _) in enumerate(data_source):
|
|
||||||
self.index_dic[pid].append(index)
|
|
||||||
self.pids = list(self.index_dic.keys())
|
|
||||||
self.num_identities = len(self.pids)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
indices = torch.randperm(self.num_identities)
|
|
||||||
ret = []
|
|
||||||
for i in indices:
|
|
||||||
pid = self.pids[i]
|
|
||||||
t = self.index_dic[pid]
|
|
||||||
replace = False if len(t) >= self.num_instances else True
|
|
||||||
t = np.random.choice(t, size=self.num_instances, replace=replace)
|
|
||||||
ret.extend(t)
|
|
||||||
return iter(ret)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.num_identities * self.num_instances
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class RandomIdentitySampler(Sampler):
|
class RandomIdentitySampler(Sampler):
|
||||||
"""
|
"""
|
||||||
Randomly sample N identities, then for each identity,
|
Randomly sample N identities, then for each identity,
|
||||||
|
|
Loading…
Reference in New Issue