diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index 7f718a333..f02b75f4b 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -68,13 +68,14 @@ class PKSampler(DistributedBatchSampler): logger.error( "PKSampler only support id_avg_prob and sample_avg_prob sample method, " "but receive {}.".format(self.sample_method)) - if sum(np.abs(self.prob_list - 1) > 0.00000001): + diff = np.abs(sum(self.prob_list) - 1) + if diff > 0.00000001: self.prob_list[-1] = 1 - sum(self.prob_list[:-1]) if self.prob_list[-1] > 1 or self.prob_list[-1] < 0: logger.error("PKSampler prob list error") else: logger.info( - "PKSampler: sum of prob list not equal to 1, change the last prob" + "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff) ) def __iter__(self):