diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
new file mode 100644
index 00000000..c161c568
--- /dev/null
+++ b/mmrazor/models/losses/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .cwd import ChannelWiseDivergence
+from .kl_divergence import KLDivergence
+from .weighted_soft_label_distillation import WSLD
+
+__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
diff --git a/mmrazor/models/losses/cwd.py b/mmrazor/models/losses/cwd.py
new file mode 100644
index 00000000..a71bdc3d
--- /dev/null
+++ b/mmrazor/models/losses/cwd.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+@LOSSES.register_module()
+class ChannelWiseDivergence(nn.Module):
+    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation.
+
+    `_.
+
+    Args:
+        tau (float): Temperature coefficient. Defaults to 1.0.
+        loss_weight (float): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        tau=1.0,
+        loss_weight=1.0,
+    ):
+        super(ChannelWiseDivergence, self).__init__()
+        self.tau = tau
+        self.loss_weight = loss_weight
+
+    def forward(self, preds_S, preds_T):
+        """Forward computation.
+
+        Args:
+            preds_S (torch.Tensor): The student model prediction with
+                shape (N, C, W, H).
+            preds_T (torch.Tensor): The teacher model prediction with
+                shape (N, C, W, H).
+
+        Return:
+            torch.Tensor: The calculated loss value.
+        """
+        assert preds_S.shape[-2:] == preds_T.shape[-2:]
+        N, C, W, H = preds_S.shape
+
+        softmax_pred_T = F.softmax(preds_T.view(-1, W * H) / self.tau, dim=1)
+
+        logsoftmax = torch.nn.LogSoftmax(dim=1)
+        loss = torch.sum(softmax_pred_T *
+                         logsoftmax(preds_T.view(-1, W * H) / self.tau) -
+                         softmax_pred_T *
+                         logsoftmax(preds_S.view(-1, W * H) / self.tau)) * (
+                             self.tau**2)
+
+        loss = self.loss_weight * loss / (C * N)
+
+        return loss
diff --git a/mmrazor/models/losses/kl_divergence.py b/mmrazor/models/losses/kl_divergence.py
new file mode 100644
index 00000000..26b8712e
--- /dev/null
+++ b/mmrazor/models/losses/kl_divergence.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+@LOSSES.register_module()
+class KLDivergence(nn.Module):
+    """A measure of how one probability distribution Q is different from a
+    second, reference probability distribution P.
+
+    Args:
+        tau (float): Temperature coefficient. Defaults to 1.0.
+        loss_weight (float): Weight of loss. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        tau=1.0,
+        loss_weight=1.0,
+    ):
+        super(KLDivergence, self).__init__()
+        self.tau = tau
+        self.loss_weight = loss_weight
+
+    def forward(self, preds_S, preds_T):
+        """Forward computation.
+
+        Args:
+            preds_S (torch.Tensor): The student model prediction with
+                shape (N, C).
+            preds_T (torch.Tensor): The teacher model prediction with
+                shape (N, C).
+
+        Return:
+            torch.Tensor: The calculated loss value.
+        """
+        N, C = preds_S.shape
+        preds_T = preds_T.detach()
+        softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)
+
+        logsoftmax = torch.nn.LogSoftmax(dim=1)
+        loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S / self.tau)) * (
+            self.tau**2)
+        return self.loss_weight * loss / N
diff --git a/mmrazor/models/losses/weighted_soft_label_distillation.py b/mmrazor/models/losses/weighted_soft_label_distillation.py
new file mode 100644
index 00000000..2ddb0ebf
--- /dev/null
+++ b/mmrazor/models/losses/weighted_soft_label_distillation.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+@LOSSES.register_module()
+class WSLD(nn.Module):
+
+    def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
+        """PyTorch version of `Rethinking Soft Labels for Knowledge
+        Distillation: A Bias-Variance Tradeoff Perspective
+        `_.
+
+        Args:
+            tau (float): Temperature coefficient. Defaults to 1.0.
+            loss_weight (float): Weight of loss. Defaults to 1.0.
+            num_classes (int): Defaults to 1000.
+        """
+        super(WSLD, self).__init__()
+
+        self.tau = tau
+        self.loss_weight = loss_weight
+        self.num_classes = num_classes
+        self.softmax = nn.Softmax(dim=1).cuda()
+        self.logsoftmax = nn.LogSoftmax(dim=1).cuda()
+
+    def forward(self, student, teacher):
+
+        gt_labels = self.current_data['gt_label']
+
+        student_logits = student / self.tau
+        teacher_logits = teacher / self.tau
+
+        teacher_probs = self.softmax(teacher_logits)
+
+        ce_loss = -torch.sum(
+            teacher_probs * self.logsoftmax(student_logits), 1, keepdim=True)
+
+        student_detach = student.detach()
+        teacher_detach = teacher.detach()
+        log_softmax_s = self.logsoftmax(student_detach)
+        log_softmax_t = self.logsoftmax(teacher_detach)
+        one_hot_labels = F.one_hot(
+            gt_labels, num_classes=self.num_classes).float()
+        ce_loss_s = -torch.sum(one_hot_labels * log_softmax_s, 1, keepdim=True)
+        ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)
+
+        focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
+        ratio_lower = torch.zeros(1).cuda()
+        focal_weight = torch.max(focal_weight, ratio_lower)
+        focal_weight = 1 - torch.exp(-focal_weight)
+        ce_loss = focal_weight * ce_loss
+
+        loss = (self.tau**2) * torch.mean(ce_loss)
+
+        loss = self.loss_weight * loss
+
+        return loss