[Feature] Add loss
parent cb5cb6da05
commit 44196ae73b

@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cwd import ChannelWiseDivergence
from .kl_divergence import KLDivergence
from .weighted_soft_label_distillation import WSLD

__all__ = ['ChannelWiseDivergence', 'KLDivergence', 'WSLD']
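Because each class exported above is registered in `LOSSES`, a distiller config can refer to it by type name. A hypothetical fragment is sketched below; only the `type` strings and argument names come from the classes added in this commit, the surrounding key names are assumptions:

# Hypothetical config fragment (key names around the dicts are illustrative).
distill_losses = dict(
    loss_cwd=dict(type='ChannelWiseDivergence', tau=1.0, loss_weight=5.0),
    loss_kl=dict(type='KLDivergence', tau=1.0, loss_weight=1.0),
    loss_wsld=dict(type='WSLD', tau=2.0, loss_weight=2.5, num_classes=1000),
)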
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class ChannelWiseDivergence(nn.Module):
    """PyTorch version of `Channel-wise Distillation for Semantic Segmentation
    <https://arxiv.org/abs/2011.13256>`_.

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau=1.0,
        loss_weight=1.0,
    ):
        super(ChannelWiseDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C, W, H).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C, W, H).

        Return:
            torch.Tensor: The calculated loss value.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]
        N, C, W, H = preds_S.shape

        # Turn each channel's activation map into a distribution over the
        # W * H spatial locations.
        softmax_pred_T = F.softmax(preds_T.view(-1, W * H) / self.tau, dim=1)

        logsoftmax = torch.nn.LogSoftmax(dim=1)
        # Per-channel KL divergence between the teacher and student spatial
        # distributions, scaled by tau^2.
        loss = torch.sum(softmax_pred_T *
                         logsoftmax(preds_T.view(-1, W * H) / self.tau) -
                         softmax_pred_T *
                         logsoftmax(preds_S.view(-1, W * H) / self.tau)) * (
                             self.tau**2)

        loss = self.loss_weight * loss / (C * N)

        return loss
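A minimal usage sketch for the loss above; all shapes and hyper-parameter values are made up for illustration, and `ChannelWiseDivergence` is the class defined in this file:

import torch

loss_cwd = ChannelWiseDivergence(tau=1.0, loss_weight=1.0)
preds_S = torch.randn(2, 16, 8, 8, requires_grad=True)  # student feature map
preds_T = torch.randn(2, 16, 8, 8)                       # teacher feature map
loss = loss_cwd(preds_S, preds_T)  # scalar >= 0 (sum of per-channel KLs)
loss.backward()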
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class KLDivergence(nn.Module):
    """A measure of how one probability distribution Q is different from a
    second, reference probability distribution P.

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau=1.0,
        loss_weight=1.0,
    ):
        super(KLDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = loss_weight

    def forward(self, preds_S, preds_T):
        """Forward computation.

        Args:
            preds_S (torch.Tensor): The student model prediction with
                shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction with
                shape (N, C).

        Return:
            torch.Tensor: The calculated loss value.
        """
        N, C = preds_S.shape
        # Teacher predictions are fixed targets; no gradient flows to them.
        preds_T = preds_T.detach()
        softmax_pred_T = F.softmax(preds_T / self.tau, dim=1)

        logsoftmax = torch.nn.LogSoftmax(dim=1)
        # Cross-entropy against the tempered teacher distribution, scaled by
        # tau^2 and averaged over the batch; equal to KL(T || S) up to a term
        # that does not depend on the student.
        loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S / self.tau)) * (
            self.tau**2)
        return self.loss_weight * loss / N
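A minimal usage sketch for the loss above; shapes and the temperature are illustrative, and `KLDivergence` is the class defined in this file. Note the inputs are plain (N, C) logits, not feature maps:

import torch

loss_kl = KLDivergence(tau=4.0, loss_weight=1.0)
student_logits = torch.randn(8, 100, requires_grad=True)
teacher_logits = torch.randn(8, 100)
loss = loss_kl(student_logits, teacher_logits)
loss.backward()  # gradients flow only into the student (teacher is detached)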
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


@LOSSES.register_module()
class WSLD(nn.Module):

    def __init__(self, tau=1.0, loss_weight=1.0, num_classes=1000):
        """PyTorch version of `Rethinking Soft Labels for Knowledge
        Distillation: A Bias-Variance Tradeoff Perspective
        <https://arxiv.org/abs/2102.00650>`_.

        Args:
            tau (float): Temperature coefficient. Defaults to 1.0.
            loss_weight (float): Weight of loss. Defaults to 1.0.
            num_classes (int): Number of classes. Defaults to 1000.
        """
        super(WSLD, self).__init__()

        self.tau = tau
        self.loss_weight = loss_weight
        self.num_classes = num_classes
        self.softmax = nn.Softmax(dim=1).cuda()
        self.logsoftmax = nn.LogSoftmax(dim=1).cuda()

    def forward(self, student, teacher):

        # `current_data` is expected to be attached to this module by the
        # distiller before `forward` is called; it carries the ground-truth
        # labels of the current batch.
        gt_labels = self.current_data['gt_label']

        student_logits = student / self.tau
        teacher_logits = teacher / self.tau

        teacher_probs = self.softmax(teacher_logits)

        # Per-sample soft cross-entropy between student and teacher.
        ce_loss = -torch.sum(
            teacher_probs * self.logsoftmax(student_logits), 1, keepdim=True)

        student_detach = student.detach()
        teacher_detach = teacher.detach()
        log_softmax_s = self.logsoftmax(student_detach)
        log_softmax_t = self.logsoftmax(teacher_detach)
        one_hot_labels = F.one_hot(
            gt_labels, num_classes=self.num_classes).float()
        ce_loss_s = -torch.sum(one_hot_labels * log_softmax_s, 1, keepdim=True)
        ce_loss_t = -torch.sum(one_hot_labels * log_softmax_t, 1, keepdim=True)

        # Weight each sample by how much worse the student fits the hard label
        # than the teacher does (the weighted-soft-label term of the paper).
        focal_weight = ce_loss_s / (ce_loss_t + 1e-7)
        ratio_lower = torch.zeros(1).cuda()
        focal_weight = torch.max(focal_weight, ratio_lower)
        focal_weight = 1 - torch.exp(-focal_weight)
        ce_loss = focal_weight * ce_loss

        loss = (self.tau**2) * torch.mean(ce_loss)

        loss = self.loss_weight * loss

        return loss
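A minimal usage sketch for the loss above; `WSLD` is the class defined in this file, the hyper-parameter values and shapes are illustrative, and a CUDA device is assumed because the loss creates tensors with `.cuda()`. Attaching `current_data` by hand mimics what the distiller is assumed to do before calling `forward`:

import torch

loss_wsld = WSLD(tau=2.0, loss_weight=2.5, num_classes=10)
student_logits = torch.randn(4, 10, device='cuda', requires_grad=True)
teacher_logits = torch.randn(4, 10, device='cuda')
# Assumed handshake: the distiller attaches the current batch before forward.
loss_wsld.current_data = {'gt_label': torch.randint(0, 10, (4,), device='cuda')}
loss = loss_wsld(student_logits, teacher_logits)
loss.backward()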