diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index cbbc0919f..11d1ac8e6 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -104,6 +104,8 @@ class PKSampler(DistributedBatchSampler): rank = dist.get_rank() np.random.RandomState(rank * self.total_epochs + self.epoch).shuffle(self.label_list) + np.random.RandomState(rank * self.total_epochs + + self.epoch).shuffle(self.prob_list) self.epoch += 1 label_per_batch = self.batch_size // self.sample_per_id