commit
85d7d50ebe
|
@ -65,7 +65,7 @@ Loss:
|
||||||
- ["Student", "Teacher"]
|
- ["Student", "Teacher"]
|
||||||
maps_name: "thrink_maps"
|
maps_name: "thrink_maps"
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
act: "softmax"
|
# act: None
|
||||||
model_name_pairs: ["Student", "Teacher"]
|
model_name_pairs: ["Student", "Teacher"]
|
||||||
key: maps
|
key: maps
|
||||||
- DistillationDBLoss:
|
- DistillationDBLoss:
|
||||||
|
|
|
@ -60,7 +60,7 @@ Loss:
|
||||||
- ["Student", "Student2"]
|
- ["Student", "Student2"]
|
||||||
maps_name: "thrink_maps"
|
maps_name: "thrink_maps"
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
act: "softmax"
|
# act: None
|
||||||
model_name_pairs: ["Student", "Student2"]
|
model_name_pairs: ["Student", "Student2"]
|
||||||
key: maps
|
key: maps
|
||||||
- DistillationDBLoss:
|
- DistillationDBLoss:
|
||||||
|
|
|
@ -57,17 +57,27 @@ class CELoss(nn.Layer):
|
||||||
class KLJSLoss(object):
|
class KLJSLoss(object):
|
||||||
def __init__(self, mode='kl'):
|
def __init__(self, mode='kl'):
|
||||||
assert mode in ['kl', 'js', 'KL', 'JS'
|
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||||
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def __call__(self, p1, p2, reduction="mean"):
|
def __call__(self, p1, p2, reduction="mean"):
|
||||||
|
|
||||||
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
if self.mode.lower() == 'kl':
|
||||||
|
loss = paddle.multiply(p2,
|
||||||
if self.mode.lower() == "js":
|
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||||
loss += paddle.multiply(
|
loss += paddle.multiply(
|
||||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||||
loss *= 0.5
|
loss *= 0.5
|
||||||
|
elif self.mode.lower() == "js":
|
||||||
|
loss = paddle.multiply(
|
||||||
|
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||||
|
loss += paddle.multiply(
|
||||||
|
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
|
||||||
|
loss *= 0.5
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
|
||||||
|
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
loss = paddle.mean(loss, axis=[1, 2])
|
loss = paddle.mean(loss, axis=[1, 2])
|
||||||
elif reduction == "none" or reduction is None:
|
elif reduction == "none" or reduction is None:
|
||||||
|
@ -95,7 +105,7 @@ class DMLLoss(nn.Layer):
|
||||||
self.act = None
|
self.act = None
|
||||||
|
|
||||||
self.use_log = use_log
|
self.use_log = use_log
|
||||||
self.jskl_loss = KLJSLoss(mode="js")
|
self.jskl_loss = KLJSLoss(mode="kl")
|
||||||
|
|
||||||
def _kldiv(self, x, target):
|
def _kldiv(self, x, target):
|
||||||
eps = 1.0e-10
|
eps = 1.0e-10
|
||||||
|
|
Loading…
Reference in New Issue