fix kldiv when stop grad is true (#5643)
parent db60893201
commit 5b33340647
@@ -95,9 +95,15 @@ class DMLLoss(nn.Layer):
             self.act = None
 
         self.use_log = use_log
 
         self.jskl_loss = KLJSLoss(mode="js")
 
+    def _kldiv(self, x, target):
+        eps = 1.0e-10
+        loss = target * (paddle.log(target + eps) - x)
+        # batch mean loss
+        loss = paddle.sum(loss) / loss.shape[0]
+        return loss
+
     def forward(self, out1, out2):
         if self.act is not None:
             out1 = self.act(out1)
@@ -106,9 +112,8 @@ class DMLLoss(nn.Layer):
             # for recognition distillation, log is needed for feature map
             log_out1 = paddle.log(out1)
             log_out2 = paddle.log(out2)
-            loss = (F.kl_div(
-                log_out1, out2, reduction='batchmean') + F.kl_div(
-                    log_out2, out1, reduction='batchmean')) / 2.0
+            loss = (
+                self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
         else:
             # for detection distillation log is not needed
             loss = self.jskl_loss(out1, out2)
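For reference, a minimal sketch (not part of the commit) of what the new _kldiv computes: the batch-mean KL divergence sum(target * (log(target + eps) - log_pred)) / N, the same quantity as F.kl_div(log_pred, target, reduction='batchmean'), written out with elementary ops so it also behaves when the target tensor has stop_gradient set to True. The tensor shapes and variable names below are illustrative assumptions, not taken from the commit.

import paddle
import paddle.nn.functional as F

# Illustrative inputs (shapes are an assumption): two batches of
# probability distributions, as produced by a softmax activation.
pred = F.softmax(paddle.randn([4, 10]), axis=-1)
target = F.softmax(paddle.randn([4, 10]), axis=-1)
target.stop_gradient = True  # the "stop grad is true" case from the title

eps = 1.0e-10
log_pred = paddle.log(pred)

# Hand-written batch-mean KL, mirroring DMLLoss._kldiv
manual = paddle.sum(target * (paddle.log(target + eps) - log_pred)) / log_pred.shape[0]

# Built-in form that the commit replaces, shown only for comparison
builtin = F.kl_div(log_pred, target, reduction='batchmean')

print(float(manual), float(builtin))  # the two values should agree up to eps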