[Feature] Add loss

pull/255/head
humu789 2021-12-23 04:00:59 +08:00
parent cb5cb6da05
commit 44196ae73b
4 changed files with 169 additions and 0 deletions

@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cwd import ChannelWiseDivergence
from .kl_divergence import KLDivergence
from .weighted_soft_label_distillation import WSLD

__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
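
All three classes are registered in the LOSSES registry, so a distillation config can select them by class name. A minimal, hypothetical sketch of such a config entry (only the type names and the tau / loss_weight arguments come from this commit; where the dict is consumed is an assumption based on the usual MMCV registry convention):

# Hypothetical config entry, not part of this commit: the dict would be
# handed to the LOSSES registry builder by the surrounding distiller.
distill_loss_cfg = dict(type='ChannelWiseDivergence', tau=1.0, loss_weight=1.0)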

@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES


@LOSSES.register_module()
class ChannelWiseDivergence(nn.Module):
"""PyTorch version of `Channel-wise Distillation for Semantic Segmentation.
<https://arxiv.org/abs/2011.13256>`_.
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(
self,
tau=1.0,
loss_weight=1.0,
):
super(ChannelWiseDivergence, self).__init__()
self.tau = tau
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, H, W).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, H, W).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]
        N, C, H, W = preds_S.shape

        # Normalize each channel's activation map into a spatial probability
        # distribution with a temperature-scaled softmax.
        softmax_pred_T = F.softmax(preds_T.view(-1, H * W) / self.tau, dim=1)

        logsoftmax = torch.nn.LogSoftmax(dim=1)
        # Channel-wise KL divergence between the teacher and student
        # distributions, scaled by tau^2 and averaged over samples and
        # channels.
        loss = torch.sum(softmax_pred_T *
                         logsoftmax(preds_T.view(-1, H * W) / self.tau) -
                         softmax_pred_T *
                         logsoftmax(preds_S.view(-1, H * W) / self.tau)) * (
                             self.tau**2)
        loss = self.loss_weight * loss / (C * N)
        return loss
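
A quick, illustrative sanity check of ChannelWiseDivergence on random tensors (not part of the commit; shapes and values are arbitrary):

# Illustrative only: each channel is normalized into a spatial distribution,
# so the result should be a non-negative scalar.
import torch
cwd = ChannelWiseDivergence(tau=1.0, loss_weight=1.0)
preds_S = torch.rand(2, 4, 8, 8)  # student feature map, (N, C, H, W)
preds_T = torch.rand(2, 4, 8, 8)  # teacher feature map, (N, C, H, W)
loss = cwd(preds_S, preds_T)
assert loss.dim() == 0 and loss.item() >= 0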

@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES


@LOSSES.register_module()
class KLDivergence(nn.Module):
"""A measure of how one probability distribution Q is different from a
second, reference probability distribution P.
Args:
tau (float): Temperature coefficient. Defaults to 1.0.
loss_weight (float): Weight of loss. Defaults to 1.0.
"""
def __init__(
self,
tau=1.0,
loss_weight=1.0,
):
super(KLDivergence, self).__init__()
self.tau = tau
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C).

        Returns:
            torch.Tensor: The calculated loss value.
        """
        N, C = preds_S.shape

        # The teacher prediction is used only as a soft target, so gradients
        # must not flow back into the teacher.
        preds_T = preds_T.detach()
        softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)

        logsoftmax = torch.nn.LogSoftmax(dim=1)
        # Soft cross-entropy between the teacher and student distributions,
        # scaled by tau^2 to keep gradient magnitudes comparable across
        # temperatures, and averaged over the batch.
        loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S / self.tau)) * (
            self.tau**2)
        return self.loss_weight * loss / N
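
An illustrative check of KLDivergence (not part of the commit): unlike ChannelWiseDivergence, this loss expects flat (N, C) logits such as classification outputs.

# Illustrative only: tau and the shapes are arbitrary example values.
import torch
kd = KLDivergence(tau=4.0, loss_weight=1.0)
student_logits = torch.randn(8, 10)  # (N, C)
teacher_logits = torch.randn(8, 10)  # detached inside forward
loss = kd(student_logits, teacher_logits)
assert loss.dim() == 0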

@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES


@LOSSES.register_module()
class WSLD(nn.Module):
    def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
        """PyTorch version of `Rethinking Soft Labels for Knowledge
        Distillation: A Bias-Variance Tradeoff Perspective
        <https://arxiv.org/abs/2102.00650>`_.

        Args:
            tau (float): Temperature coefficient. Defaults to 1.0.
            loss_weight (float): Weight of loss. Defaults to 1.0.
            num_classes (int): Defaults to 1000.
        """
        super(WSLD, self).__init__()

        self.tau = tau
        self.loss_weight = loss_weight
        self.num_classes = num_classes
        # Softmax / LogSoftmax hold no parameters, so they are kept
        # device-agnostic instead of being pinned to CUDA.
        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, student, teacher):
        # `current_data` is expected to be attached to this loss by the
        # surrounding distillation algorithm before `forward` is called;
        # it holds the ground-truth labels of the current batch.
        gt_labels = self.current_data['gt_label']

        student_logits = student / self.tau
        teacher_logits = teacher / self.tau

        teacher_probs = self.softmax(teacher_logits)

        # Soft cross-entropy between the teacher and student distributions.
        ce_loss = -torch.sum(
            teacher_probs * self.logsoftmax(student_logits), 1, keepdim=True)

        student_detach = student.detach()
        teacher_detach = teacher.detach()
        log_softmax_s = self.logsoftmax(student_detach)
        log_softmax_t = self.logsoftmax(teacher_detach)
        one_hot_labels = F.one_hot(
            gt_labels, num_classes=self.num_classes).float()
        # Per-sample cross-entropy of student and teacher against the hard
        # labels, computed on detached logits.
        ce_loss_s = -torch.sum(one_hot_labels * log_softmax_s, 1, keepdim=True)
        ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)

        # Sample-wise weight: the worse the student does relative to the
        # teacher on the hard label, the larger the distillation weight.
        focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
        # Clamp the ratio at zero in a device-agnostic way before mapping
        # it into (0, 1).
        ratio_lower = torch.zeros_like(focal_weight)
        focal_weight = torch.max(focal_weight, ratio_lower)
        focal_weight = 1 - torch.exp(-focal_weight)
        ce_loss = focal_weight * ce_loss

        loss = (self.tau**2) * torch.mean(ce_loss)
        loss = self.loss_weight * loss
        return loss
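
An illustrative check of WSLD (not part of the commit). The class reads the ground-truth labels from self.current_data, which it never sets itself; in normal use the surrounding distillation algorithm is expected to attach it before calling forward, so this sketch sets it by hand:

# Illustrative only: num_classes, shapes and the manual current_data
# assignment are example values, not part of the commit.
import torch
wsld = WSLD(tau=2.0, loss_weight=1.0, num_classes=10)
student = torch.randn(8, 10)  # student logits, (N, num_classes)
teacher = torch.randn(8, 10)  # teacher logits, (N, num_classes)
wsld.current_data = {'gt_label': torch.randint(0, 10, (8,))}
loss = wsld(student, teacher)
assert loss.dim() == 0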