PaddleClas/ppcls/loss/dkdloss.py

69 lines
2.5 KiB
Python

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class DKDLoss(nn.Layer):
"""
DKDLoss
Reference: https://arxiv.org/abs/2203.08679
Code was heavily based on https://github.com/megvii-research/mdistiller
"""
def __init__(self,
temperature=1.0,
alpha=1.0,
beta=1.0,
use_target_as_gt=False):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta
self.use_target_as_gt = use_target_as_gt
def forward(self, logits_student, logits_teacher, target=None):
if target is None or self.use_target_as_gt:
target = logits_teacher.argmax(axis=-1)
gt_mask = _get_gt_mask(logits_student, target)
other_mask = 1 - gt_mask
pred_student = F.softmax(logits_student / self.temperature, axis=1)
pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
pred_student = cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = paddle.log(pred_student)
tckd_loss = (F.kl_div(
log_pred_student, pred_teacher,
reduction='sum') * (self.temperature**2) / target.shape[0])
pred_teacher_part2 = F.softmax(
logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
log_pred_student_part2 = F.log_softmax(
logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
nckd_loss = (F.kl_div(
log_pred_student_part2, pred_teacher_part2,
reduction='sum') * (self.temperature**2) / target.shape[0])
return self.alpha * tckd_loss + self.beta * nckd_loss
def _get_gt_mask(logits, target):
target = target.reshape([-1]).unsqueeze(1)
updates = paddle.ones_like(target)
mask = scatter(
paddle.zeros_like(logits), target, updates.astype('float32'))
return mask
def cat_mask(t, mask1, mask2):
t1 = (t * mask1).sum(axis=1, keepdim=True)
t2 = (t * mask2).sum(axis=1, keepdim=True)
rt = paddle.concat([t1, t2], axis=1)
return rt
def scatter(x, index, updates):
i, j = index.shape
grid_x, grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j))
index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1)
updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1)
updates = paddle.gather_nd(updates, index=updates_index)
return paddle.scatter_nd_add(x, index, updates)