fix balanced sampler bug

Summary: fix typing error in balanced sampler with missing `len` in ret.
This commit is contained in:
liaoxingyu 2020-08-19 16:28:10 +08:00
parent fd06a2819a
commit a5550d7725

View File

@ -83,7 +83,7 @@ class BalancedIdentitySampler(Sampler):
else:
select_indexes = no_index(index, i)
if not select_indexes:
# only one image for this identity
# Only one image for this identity
ind_indexes = [0] * (self.num_instances - 1)
elif len(select_indexes) >= self.num_instances:
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
@ -93,7 +93,7 @@ class BalancedIdentitySampler(Sampler):
for kk in ind_indexes:
ret.append(index[kk])
if ret == self.batch_size:
if len(ret) == self.batch_size:
yield from ret
ret = []
@ -162,6 +162,7 @@ class NaiveIdentitySampler(Sampler):
batch_indices.append(avai_idxs.pop(0))
if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
assert len(batch_indices) == self.batch_size, "batch indices have wrong batch size"
assert len(batch_indices) == self.batch_size, f"batch indices have wrong " \
f"length with {len(batch_indices)}!"
yield from batch_indices
batch_indices = []