mirror of https://github.com/JDAI-CV/fast-reid.git
fix sampler problem #217
Fix problem about periodic long waiting time when number of images is large. In the old version, it will prepare the whole epoch indices when finishing one epoch. Now it changes to prepare the current batch indices.pull/228/head
parent
ae7c9288cf
commit
e0cf8ac56e
|
@ -48,47 +48,6 @@ class BalancedIdentitySampler(Sampler):
|
|||
self._rank = comm.get_rank()
|
||||
self._world_size = comm.get_world_size()
|
||||
|
||||
def _get_epoch_indices(self):
|
||||
# Shuffle identity list
|
||||
identities = np.random.permutation(self.num_identities)
|
||||
|
||||
# If remaining identities cannot be enough for a batch,
|
||||
# just drop the remaining parts
|
||||
drop_indices = self.num_identities % self.num_pids_per_batch
|
||||
if drop_indices: identities = identities[:-drop_indices]
|
||||
|
||||
ret = []
|
||||
for kid in identities:
|
||||
i = np.random.choice(self.pid_index[self.pids[kid]])
|
||||
_, i_pid, i_cam = self.data_source[i]
|
||||
ret.append(i)
|
||||
pid_i = self.index_pid[i]
|
||||
cams = self.pid_cam[pid_i]
|
||||
index = self.pid_index[pid_i]
|
||||
select_cams = no_index(cams, i_cam)
|
||||
|
||||
if select_cams:
|
||||
if len(select_cams) >= self.num_instances:
|
||||
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
|
||||
else:
|
||||
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
|
||||
for kk in cam_indexes:
|
||||
ret.append(index[kk])
|
||||
else:
|
||||
select_indexes = no_index(index, i)
|
||||
if not select_indexes:
|
||||
# only one image for this identity
|
||||
ind_indexes = [0] * (self.num_instances - 1)
|
||||
elif len(select_indexes) >= self.num_instances:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
|
||||
else:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
|
||||
|
||||
for kk in ind_indexes:
|
||||
ret.append(index[kk])
|
||||
|
||||
return ret
|
||||
|
||||
def __iter__(self):
|
||||
start = self._rank
|
||||
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
|
||||
|
@ -96,8 +55,47 @@ class BalancedIdentitySampler(Sampler):
|
|||
def _infinite_indices(self):
|
||||
np.random.seed(self._seed)
|
||||
while True:
|
||||
indices = self._get_epoch_indices()
|
||||
yield from indices
|
||||
# Shuffle identity list
|
||||
identities = np.random.permutation(self.num_identities)
|
||||
|
||||
# If remaining identities cannot be enough for a batch,
|
||||
# just drop the remaining parts
|
||||
drop_indices = self.num_identities % self.num_pids_per_batch
|
||||
if drop_indices: identities = identities[:-drop_indices]
|
||||
|
||||
ret = []
|
||||
for kid in identities:
|
||||
i = np.random.choice(self.pid_index[self.pids[kid]])
|
||||
_, i_pid, i_cam = self.data_source[i]
|
||||
ret.append(i)
|
||||
pid_i = self.index_pid[i]
|
||||
cams = self.pid_cam[pid_i]
|
||||
index = self.pid_index[pid_i]
|
||||
select_cams = no_index(cams, i_cam)
|
||||
|
||||
if select_cams:
|
||||
if len(select_cams) >= self.num_instances:
|
||||
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=False)
|
||||
else:
|
||||
cam_indexes = np.random.choice(select_cams, size=self.num_instances - 1, replace=True)
|
||||
for kk in cam_indexes:
|
||||
ret.append(index[kk])
|
||||
else:
|
||||
select_indexes = no_index(index, i)
|
||||
if not select_indexes:
|
||||
# only one image for this identity
|
||||
ind_indexes = [0] * (self.num_instances - 1)
|
||||
elif len(select_indexes) >= self.num_instances:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
|
||||
else:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=True)
|
||||
|
||||
for kk in ind_indexes:
|
||||
ret.append(index[kk])
|
||||
|
||||
if ret == self.batch_size:
|
||||
yield from ret
|
||||
ret = []
|
||||
|
||||
|
||||
class NaiveIdentitySampler(Sampler):
|
||||
|
@ -137,33 +135,6 @@ class NaiveIdentitySampler(Sampler):
|
|||
self._rank = comm.get_rank()
|
||||
self._world_size = comm.get_world_size()
|
||||
|
||||
def _get_epoch_indices(self):
|
||||
batch_idxs_dict = defaultdict(list)
|
||||
|
||||
for pid in self.pids:
|
||||
idxs = copy.deepcopy(self.pid_index[pid])
|
||||
if len(idxs) < self.num_instances:
|
||||
idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
|
||||
np.random.shuffle(idxs)
|
||||
batch_idxs = []
|
||||
for idx in idxs:
|
||||
batch_idxs.append(idx)
|
||||
if len(batch_idxs) == self.num_instances:
|
||||
batch_idxs_dict[pid].append(batch_idxs)
|
||||
batch_idxs = []
|
||||
|
||||
avai_pids = copy.deepcopy(self.pids)
|
||||
final_idxs = []
|
||||
|
||||
while len(avai_pids) >= self.num_pids_per_batch:
|
||||
selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
|
||||
for pid in selected_pids:
|
||||
batch_idxs = batch_idxs_dict[pid].pop(0)
|
||||
final_idxs.extend(batch_idxs)
|
||||
if len(batch_idxs_dict[pid]) == 0: avai_pids.remove(pid)
|
||||
|
||||
return final_idxs
|
||||
|
||||
def __iter__(self):
|
||||
start = self._rank
|
||||
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
|
||||
|
@ -171,5 +142,26 @@ class NaiveIdentitySampler(Sampler):
|
|||
def _infinite_indices(self):
|
||||
np.random.seed(self._seed)
|
||||
while True:
|
||||
indices = self._get_epoch_indices()
|
||||
yield from indices
|
||||
avai_pids = copy.deepcopy(self.pids)
|
||||
batch_idxs_dict = {}
|
||||
|
||||
batch_indices = []
|
||||
while len(avai_pids) >= self.num_pids_per_batch:
|
||||
selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
|
||||
for pid in selected_pids:
|
||||
# Register pid in batch_idxs_dict if not
|
||||
if pid not in batch_idxs_dict:
|
||||
idxs = copy.deepcopy(self.pid_index[pid])
|
||||
if len(idxs) < self.num_instances:
|
||||
idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist()
|
||||
np.random.shuffle(idxs)
|
||||
batch_idxs_dict[pid] = idxs
|
||||
|
||||
avai_idxs = batch_idxs_dict[pid]
|
||||
for _ in range(self.num_instances):
|
||||
batch_indices.append(avai_idxs.pop(0))
|
||||
if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
|
||||
|
||||
assert len(batch_indices) == self.batch_size, "batch indices have wrong batch size"
|
||||
yield from batch_indices
|
||||
batch_indices = []
|
||||
|
|
Loading…
Reference in New Issue