fix pksampler prob list shuffle bug

This commit is contained in:
zengshao0622 2022-12-22 03:38:38 +00:00 committed by zengshao0622
parent 5ae888cf0d
commit 7cdae10bcf

View File

@ -104,6 +104,8 @@ class PKSampler(DistributedBatchSampler):
rank = dist.get_rank()
np.random.RandomState(rank * self.total_epochs +
self.epoch).shuffle(self.label_list)
np.random.RandomState(rank * self.total_epochs +
self.epoch).shuffle(self.prob_list)
self.epoch += 1
label_per_batch = self.batch_size // self.sample_per_id