mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
fix balanced sampler bug
Summary: fix typing error in balanced sampler with missing `len` in ret.
This commit is contained in:
parent
fd06a2819a
commit
a5550d7725
@ -83,7 +83,7 @@ class BalancedIdentitySampler(Sampler):
|
||||
else:
|
||||
select_indexes = no_index(index, i)
|
||||
if not select_indexes:
|
||||
# only one image for this identity
|
||||
# Only one image for this identity
|
||||
ind_indexes = [0] * (self.num_instances - 1)
|
||||
elif len(select_indexes) >= self.num_instances:
|
||||
ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
|
||||
@ -93,7 +93,7 @@ class BalancedIdentitySampler(Sampler):
|
||||
for kk in ind_indexes:
|
||||
ret.append(index[kk])
|
||||
|
||||
if ret == self.batch_size:
|
||||
if len(ret) == self.batch_size:
|
||||
yield from ret
|
||||
ret = []
|
||||
|
||||
@ -162,6 +162,7 @@ class NaiveIdentitySampler(Sampler):
|
||||
batch_indices.append(avai_idxs.pop(0))
|
||||
if len(avai_idxs) < self.num_instances: avai_pids.remove(pid)
|
||||
|
||||
assert len(batch_indices) == self.batch_size, "batch indices have wrong batch size"
|
||||
assert len(batch_indices) == self.batch_size, f"batch indices have wrong " \
|
||||
f"length with {len(batch_indices)}!"
|
||||
yield from batch_indices
|
||||
batch_indices = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user