dbg
parent
7c6e76e50d
commit
ee1bc18f3a
|
@ -68,13 +68,14 @@ class PKSampler(DistributedBatchSampler):
|
||||||
logger.error(
|
logger.error(
|
||||||
"PKSampler only support id_avg_prob and sample_avg_prob sample method, "
|
"PKSampler only support id_avg_prob and sample_avg_prob sample method, "
|
||||||
"but receive {}.".format(self.sample_method))
|
"but receive {}.".format(self.sample_method))
|
||||||
if sum(np.abs(self.prob_list - 1) > 0.00000001):
|
diff = np.abs(sum(self.prob_list) - 1)
|
||||||
|
if diff > 0.00000001:
|
||||||
self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
|
self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
|
||||||
if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
|
if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
|
||||||
logger.error("PKSampler prob list error")
|
logger.error("PKSampler prob list error")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"PKSampler: sum of prob list not equal to 1, change the last prob"
|
"PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff)
|
||||||
)
|
)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
|
|
Loading…
Reference in New Issue