fix sampler problem #217

Fix the periodic long waiting time that occurs when the number of images is large.

In the old version, the sampler prepared the index list for an entire epoch whenever an epoch finished.
It now prepares only the indices for the current batch.
pull/228/head
liaoxingyu 2020-08-14 13:57:22 +08:00
parent ae7c9288cf
commit e0cf8ac56e
1 changed files with 64 additions and 72 deletions

View File

@ -48,47 +48,6 @@ class BalancedIdentitySampler(Sampler):
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
def _get_epoch_indices(self):
    """Build one epoch's worth of dataset indices.

    Identities are visited in a random order; for each one an anchor image
    is drawn, followed by ``num_instances - 1`` companion images of the
    same identity, preferring images captured by a *different* camera than
    the anchor.  Identities that cannot fill a complete batch are dropped.

    Returns:
        list[int]: dataset indices, ``num_instances`` consecutive entries
        per selected identity.
    """
    # Random visiting order over all identities.
    shuffled = np.random.permutation(self.num_identities)
    # Drop the tail that cannot fill a whole batch of identities.
    leftover = self.num_identities % self.num_pids_per_batch
    if leftover:
        shuffled = shuffled[:-leftover]

    epoch_indices = []
    for kid in shuffled:
        # Anchor image for this identity.
        anchor = np.random.choice(self.pid_index[self.pids[kid]])
        _, anchor_pid, anchor_cam = self.data_source[anchor]
        epoch_indices.append(anchor)

        pid = self.index_pid[anchor]
        cams = self.pid_cam[pid]
        candidates = self.pid_index[pid]

        # Positions (within `candidates`) of same-identity images shot by
        # other cameras than the anchor's.
        other_cam_positions = no_index(cams, anchor_cam)
        if other_cam_positions:
            # Sample without replacement when enough cross-camera images
            # exist; otherwise allow repeats.
            repeat = len(other_cam_positions) < self.num_instances
            picked = np.random.choice(
                other_cam_positions, size=self.num_instances - 1, replace=repeat)
        else:
            # No cross-camera image: fall back to any other image of the
            # same identity, excluding the anchor itself.
            other_positions = no_index(candidates, anchor)
            if not other_positions:
                # Only one image exists for this identity: repeat position 0.
                picked = [0] * (self.num_instances - 1)
            else:
                repeat = len(other_positions) < self.num_instances
                picked = np.random.choice(
                    other_positions, size=self.num_instances - 1, replace=repeat)
        for pos in picked:
            epoch_indices.append(candidates[pos])
    return epoch_indices
def __iter__(self):
    """Yield this rank's share of the infinite index stream.

    Rank ``r`` takes indices ``r, r + world_size, r + 2 * world_size, ...``
    so each distributed worker consumes a disjoint sub-stream.
    """
    stream = self._infinite_indices()
    yield from itertools.islice(stream, self._rank, None, self._world_size)
@ -96,8 +55,47 @@ class BalancedIdentitySampler(Sampler):
def _infinite_indices(self):
    """Endlessly yield training indices, one batch at a time.

    Repeatedly shuffles the identity order, then for each identity picks an
    anchor image plus ``num_instances - 1`` companions (preferring other
    cameras of the same identity), yielding as soon as a full batch of
    ``batch_size`` indices has accumulated.  Producing indices batch by
    batch avoids the long pause the old per-epoch preparation caused on
    large datasets.
    """
    np.random.seed(self._seed)
    while True:
        # Shuffle identity list
        identities = np.random.permutation(self.num_identities)
        # If remaining identities cannot be enough for a batch,
        # just drop the remaining parts
        drop_indices = self.num_identities % self.num_pids_per_batch
        if drop_indices:
            identities = identities[:-drop_indices]

        ret = []
        for kid in identities:
            # Anchor image for this identity.
            i = np.random.choice(self.pid_index[self.pids[kid]])
            _, i_pid, i_cam = self.data_source[i]
            ret.append(i)
            pid_i = self.index_pid[i]
            cams = self.pid_cam[pid_i]
            index = self.pid_index[pid_i]
            # Prefer companions captured by a different camera.
            select_cams = no_index(cams, i_cam)
            if select_cams:
                if len(select_cams) >= self.num_instances:
                    cam_indexes = np.random.choice(
                        select_cams, size=self.num_instances - 1, replace=False)
                else:
                    cam_indexes = np.random.choice(
                        select_cams, size=self.num_instances - 1, replace=True)
                for kk in cam_indexes:
                    ret.append(index[kk])
            else:
                # No cross-camera image: fall back to other images of the
                # same identity (excluding the anchor).
                select_indexes = no_index(index, i)
                if not select_indexes:
                    # only one image for this identity
                    ind_indexes = [0] * (self.num_instances - 1)
                elif len(select_indexes) >= self.num_instances:
                    ind_indexes = np.random.choice(
                        select_indexes, size=self.num_instances - 1, replace=False)
                else:
                    ind_indexes = np.random.choice(
                        select_indexes, size=self.num_instances - 1, replace=True)
                for kk in ind_indexes:
                    ret.append(index[kk])

            # BUG FIX: the original compared the list itself to an int
            # (`ret == self.batch_size`), which is never true, so the
            # generator never yielded anything.  Compare the accumulated
            # *length* instead.
            if len(ret) == self.batch_size:
                yield from ret
                ret = []
class NaiveIdentitySampler(Sampler):
@ -137,33 +135,6 @@ class NaiveIdentitySampler(Sampler):
self._rank = comm.get_rank()
self._world_size = comm.get_world_size()
def _get_epoch_indices(self):
    """Build one epoch's worth of indices for naive PK sampling.

    Every identity's image-index list is shuffled and cut into chunks of
    ``num_instances``; batches are then assembled by repeatedly drawing
    ``num_pids_per_batch`` identities that still have chunks left.

    Returns:
        list[int]: dataset indices grouped into complete PK batches.
    """
    # Pre-chunk each identity's (shuffled) image indices.
    chunks_per_pid = defaultdict(list)
    for pid in self.pids:
        idxs = copy.deepcopy(self.pid_index[pid])
        if len(idxs) < self.num_instances:
            # Too few images for one chunk: oversample with replacement.
            idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
        np.random.shuffle(idxs)
        chunk = []
        for idx in idxs:
            chunk.append(idx)
            if len(chunk) == self.num_instances:
                chunks_per_pid[pid].append(chunk)
                chunk = []

    # Draw batches until too few identities remain for a full batch.
    remaining_pids = copy.deepcopy(self.pids)
    epoch_indices = []
    while len(remaining_pids) >= self.num_pids_per_batch:
        chosen = np.random.choice(remaining_pids, self.num_pids_per_batch, replace=False)
        for pid in chosen:
            epoch_indices.extend(chunks_per_pid[pid].pop(0))
            # Identity exhausted: stop drawing it this epoch.
            if len(chunks_per_pid[pid]) == 0:
                remaining_pids.remove(pid)
    return epoch_indices
def __iter__(self):
    """Yield this rank's share of the infinite index stream.

    Rank ``r`` consumes indices ``r, r + world_size, r + 2 * world_size, ...``
    giving every distributed worker a disjoint sub-stream.
    """
    stream = self._infinite_indices()
    yield from itertools.islice(stream, self._rank, None, self._world_size)
@ -171,5 +142,26 @@ class NaiveIdentitySampler(Sampler):
def _infinite_indices(self):
    """Endlessly yield naive-PK batches of dataset indices.

    On each pass over the identity pool, a pid's image list is shuffled
    lazily the first time that pid is drawn; every draw then hands out
    ``num_instances`` of its images, and a batch is emitted after
    ``num_pids_per_batch`` draws.  Building indices batch by batch avoids
    the long per-epoch preparation pause on large datasets.
    """
    np.random.seed(self._seed)
    while True:
        remaining_pids = copy.deepcopy(self.pids)
        shuffled_idxs = {}
        batch_indices = []
        while len(remaining_pids) >= self.num_pids_per_batch:
            chosen = np.random.choice(remaining_pids, self.num_pids_per_batch, replace=False)
            for pid in chosen:
                # Register pid's shuffled image indices on first use.
                if pid not in shuffled_idxs:
                    idxs = copy.deepcopy(self.pid_index[pid])
                    if len(idxs) < self.num_instances:
                        # Too few images: oversample with replacement.
                        idxs = np.random.choice(
                            idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    shuffled_idxs[pid] = idxs
                avai = shuffled_idxs[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai.pop(0))
                # Not enough images left for another full draw: retire pid.
                # (Ensures the pops above never underflow.)
                if len(avai) < self.num_instances:
                    remaining_pids.remove(pid)
            assert len(batch_indices) == self.batch_size, "batch indices have wrong batch size"
            yield from batch_indices
            batch_indices = []