fix kljs loss
parent
c1eaa17a95
commit
a479ca67a4
|
@ -65,7 +65,7 @@ Loss:
|
|||
- ["Student", "Teacher"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Teacher"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
|
|
|
@ -60,7 +60,7 @@ Loss:
|
|||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
|
|
|
@ -57,17 +57,27 @@ class CELoss(nn.Layer):
|
|||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
if self.mode.lower() == 'kl':
|
||||
loss = paddle.multiply(p2,
|
||||
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
elif self.mode.lower() == "js":
|
||||
loss = paddle.multiply(
|
||||
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
else:
|
||||
raise ValueError(
|
||||
"The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
|
||||
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1, 2])
|
||||
elif reduction == "none" or reduction is None:
|
||||
|
@ -95,7 +105,7 @@ class DMLLoss(nn.Layer):
|
|||
self.act = None
|
||||
|
||||
self.use_log = use_log
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
self.jskl_loss = KLJSLoss(mode="kl")
|
||||
|
||||
def _kldiv(self, x, target):
|
||||
eps = 1.0e-10
|
||||
|
|
Loading…
Reference in New Issue