[Feature] Add Decoupled KD Loss (#222)

* add DKDLoss, config

* linting

* linting

* reset default reduction

* dkd ut

* Update decoupled_kd.py

* Update decoupled_kd.py

* Update decoupled_kd.py

* fix commit

* fix readme

* fix comments

* linting comment

* rename loss params

* fix docstring

* Update decoupled_kd.py

* fix gt from config

* merge fix

* fix ut & wsld

* Update README.md

* Update README.md

* add Acknowledgement

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* fix readme style

* fix md

Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
spynccat 2022-08-15 14:59:24 +08:00 committed by GitHub
parent 6e8ebfd85a
commit a1937fd5a6
8 changed files with 282 additions and 22 deletions
@@ -0,0 +1,48 @@
# Decoupled Knowledge Distillation
> [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf)
<!-- [ALGORITHM] -->
## Abstract
State-of-the-art distillation methods are mainly based on distilling deep features from intermediate layers, while the significance of logit distillation is greatly overlooked. To provide a novel viewpoint to study logit distillation, we reformulate the classical KD loss into two parts, i.e., target class knowledge distillation (TCKD) and non-target class knowledge distillation (NCKD). We empirically investigate and prove the effects of the two parts: TCKD transfers knowledge concerning the "difficulty" of training samples, while NCKD is the prominent reason why logit distillation works. More importantly, we reveal that the classical KD loss is a coupled formulation, which (1) suppresses the effectiveness of NCKD and (2) limits the flexibility to balance these two parts. To address these issues, we present Decoupled Knowledge Distillation (DKD), enabling TCKD and NCKD to play their roles more efficiently and flexibly. Compared with complex feature-based methods, our DKD achieves comparable or even better results and has better training efficiency on CIFAR-100, ImageNet, and MS-COCO datasets for image classification and object detection tasks. This paper proves the great potential of logit distillation, and we hope it will be helpful for future research. The code is available at https://github.com/megvii-research/mdistiller.
![avatar](../../../../docs/en/imgs/model_zoo/dkd/dkd.png)
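For reference, the reformulation at the heart of the paper can be written as follows (a paraphrase of the paper's notation, where $p_t^{\mathcal{T}}$ is the teacher's probability on the target class):

```latex
% Classical KD is a coupled sum of the two parts:
\mathrm{KD} = \mathrm{TCKD} + (1 - p_t^{\mathcal{T}})\,\mathrm{NCKD}
% DKD replaces the coupled weight with two free hyperparameters:
\mathrm{DKD} = \alpha\,\mathrm{TCKD} + \beta\,\mathrm{NCKD}
```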
## Results and models
### Classification
| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| -------- | --------- | --------- | --------- | --------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------------- |
| ImageNet | ResNet-18 | ResNet-34 | 71.368 | 90.256 | [config](dkd_logits_r34_r18_8xb32_in1k.py) | [model & log](https://autolink.sensetime.com/pages/model/share/afc68955-e25d-4488-b044-5e801b3ff62f) |
## Citation
```latex
@article{zhao2022decoupled,
title={Decoupled Knowledge Distillation},
author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
journal={arXiv preprint arXiv:2203.08679},
year={2022}
}
```
## Getting Started
### Download the teacher checkpoint
Download the ResNet-34 teacher checkpoint from
https://mmclassification.readthedocs.io/en/latest/papers/resnet.html
### Distillation training
```bash
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
  configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
  $DISTILLATION_WORK_DIR
```
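If you are not on a Slurm cluster, the usual OpenMMLab launcher should work as well (a sketch, assuming `tools/dist_train.sh` exists in your checkout with the standard `CONFIG GPUS` signature):

```bash
# 8 GPUs; trailing arguments such as --work-dir are forwarded to tools/train.py
sh tools/dist_train.sh \
  configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py 8 \
  --work-dir $DISTILLATION_WORK_DIR
```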
## Acknowledgement
Shout out to Davidgzx for his special contribution.


@@ -0,0 +1,45 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc'),
            gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_dkd=dict(
                type='DKDLoss',
                tau=1,
                beta=0.5,
                loss_weight=1,
                reduction='mean')),
        loss_forward_mappings=dict(
            loss_dkd=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc'),
                gt_labels=dict(
                    recorder='gt_labels', from_student=True, data_idx=1)))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
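For intuition, here is a minimal, self-contained sketch (made-up shapes, not part of the config) of what `loss_forward_mappings` wires together at runtime: the recorded student/teacher `head.fc` outputs become `preds_S`/`preds_T`, and the labels recorded from the inputs of `head.loss_module` become `gt_labels`:

```python
import torch
from mmrazor.models import DKDLoss  # exported by mmrazor.models.losses

loss_fn = DKDLoss(tau=1, beta=0.5, loss_weight=1, reduction='mean')

# stand-ins for the recorded tensors (batch of 4, 1000 ImageNet classes)
preds_S = torch.randn(4, 1000)            # student head.fc outputs
preds_T = torch.randn(4, 1000)            # teacher head.fc outputs
gt_labels = torch.randint(0, 1000, (4,))  # labels picked out by the recorder

loss = loss_fn(preds_S, preds_T, gt_labels)  # scalar loss tensor
```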


@@ -20,9 +20,10 @@ model = dict(
         cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
     teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
     distiller=dict(
         type='ConfigurableDistiller',
         student_recorders=dict(
             fc=dict(type='ModuleOutputs', source='head.fc'),
-            data_samples=dict(type='ModuleInputs', source='')),
+            gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
         teacher_recorders=dict(
             fc=dict(type='ModuleOutputs', source='head.fc')),
         distill_losses=dict(
@@ -31,8 +32,8 @@ model = dict(
         loss_wsld=dict(
             student=dict(recorder='fc', from_student=True),
             teacher=dict(recorder='fc', from_student=False),
-            data_samples=dict(
-                recorder='data_samples', from_student=True, data_idx=1)))))
+            gt_labels=dict(
+                recorder='gt_labels', from_student=True, data_idx=1)))))
 find_unused_parameters = True
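Why `data_idx=1`? A `ModuleInputs` recorder captures the positional inputs of `head.loss_module` as a tuple, and in mmcls the loss module is called as `loss_module(cls_score, gt_label, ...)`, so index 1 is presumably the label tensor. A toy illustration (hypothetical stand-ins, not the mmrazor API):

```python
import torch

# stand-ins for the two positional inputs of head.loss_module
cls_score = torch.randn(4, 1000)
gt_label = torch.randint(0, 1000, (4,))

recorded_inputs = (cls_score, gt_label)  # what the recorder captures
gt_labels = recorded_inputs[1]           # what data_idx=1 selects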

(Binary file not shown: the new DKD illustration image, 522 KiB, referenced above as docs/en/imgs/model_zoo/dkd/dkd.png.)


@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ab_loss import ABLoss
 from .cwd import ChannelWiseDivergence
+from .decoupled_kd import DKDLoss
 from .kl_divergence import KLDivergence
 from .l2_loss import L2Loss
 from .relational_kd import AngleWiseRKD, DistanceWiseRKD
@@ -8,5 +9,5 @@ from .weighted_soft_label_distillation import WSLD
 
 __all__ = [
     'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
-    'WSLD', 'L2Loss', 'ABLoss'
+    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
 ]
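With the module imported here and decorated with `@MODELS.register_module()` (see `decoupled_kd.py` below), the loss becomes buildable from a config dict, which is how `ConfigurableDistiller` instantiates it. A quick sketch (assuming mmrazor is installed):

```python
import torch
from mmrazor.registry import MODELS

# build the loss from a config dict, as the distiller does internally
loss_fn = MODELS.build(dict(type='DKDLoss', tau=1, beta=0.5, loss_weight=1))
out = loss_fn(
    torch.randn(2, 10), torch.randn(2, 10), torch.randint(0, 10, (2, )))
print(out)  # a scalar loss tensor
```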


@@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class DKDLoss(nn.Module):
    """Decoupled Knowledge Distillation, CVPR2022.

    link: https://arxiv.org/abs/2203.08679
    Reformulates the classical KD loss into two parts:
    1. target class knowledge distillation (TCKD)
    2. non-target class knowledge distillation (NCKD).

    Args:
        tau (float): Temperature coefficient. Defaults to 1.0.
        alpha (float): Weight of TCKD loss. Defaults to 1.0.
        beta (float): Weight of NCKD loss. Defaults to 1.0.
        reduction (str): Specifies the reduction to apply to the loss:
            ``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
            ``'none'``: no reduction will be applied,
            ``'batchmean'``: the sum of the output will be divided by
                the batch size,
            ``'sum'``: the output will be summed,
            ``'mean'``: the output will be divided by the number of
                elements in the output.
            Defaults to ``'batchmean'``.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
    """

    def __init__(
        self,
        tau: float = 1.0,
        alpha: float = 1.0,
        beta: float = 1.0,
        reduction: str = 'batchmean',
        loss_weight: float = 1.0,
    ) -> None:
        super(DKDLoss, self).__init__()
        self.tau = tau
        accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
        assert reduction in accept_reduction, \
            f'DKDLoss supports reduction {accept_reduction}, ' \
            f'but gets {reduction}.'
        self.reduction = reduction
        self.alpha = alpha
        self.beta = beta
        self.loss_weight = loss_weight

    def forward(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_labels: torch.Tensor,
    ) -> torch.Tensor:
        """DKDLoss forward function.

        Args:
            preds_S (torch.Tensor): The student model prediction, shape (N, C).
            preds_T (torch.Tensor): The teacher model prediction, shape (N, C).
            gt_labels (torch.Tensor): The gt label tensor, shape (N,).

        Return:
            torch.Tensor: The calculated loss value.
        """
        gt_mask = self._get_gt_mask(preds_S, gt_labels)
        tckd_loss = self._get_tckd_loss(preds_S, preds_T, gt_labels, gt_mask)
        nckd_loss = self._get_nckd_loss(preds_S, preds_T, gt_mask)
        loss = self.alpha * tckd_loss + self.beta * nckd_loss
        return self.loss_weight * loss

    def _get_nckd_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate non-target class knowledge distillation."""
        # Mask out the target class by subtracting a large constant from its
        # logit before the softmax; faster than indexing the class out.
        s_nckd = F.log_softmax(preds_S / self.tau - 1000.0 * gt_mask, dim=1)
        t_nckd = F.softmax(preds_T / self.tau - 1000.0 * gt_mask, dim=1)
        return self._kl_loss(s_nckd, t_nckd)

    def _get_tckd_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
        gt_labels: torch.Tensor,
        gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate target class knowledge distillation."""
        non_gt_mask = self._get_non_gt_mask(preds_S, gt_labels)
        s_tckd = F.softmax(preds_S / self.tau, dim=1)
        t_tckd = F.softmax(preds_T / self.tau, dim=1)
        # Collapse each distribution to a binary (target, non-target) one.
        mask_student = torch.log(self._cat_mask(s_tckd, gt_mask, non_gt_mask))
        mask_teacher = self._cat_mask(t_tckd, gt_mask, non_gt_mask)
        return self._kl_loss(mask_student, mask_teacher)

    def _kl_loss(
        self,
        preds_S: torch.Tensor,
        preds_T: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the KL Divergence."""
        kl_loss = F.kl_div(
            preds_S, preds_T, reduction=self.reduction) * self.tau**2
        return kl_loss

    def _cat_mask(
        self,
        tckd: torch.Tensor,
        gt_mask: torch.Tensor,
        non_gt_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate preds of target (pt) & preds of non-target (pnt)."""
        t1 = (tckd * gt_mask).sum(dim=1, keepdim=True)
        t2 = (tckd * non_gt_mask).sum(dim=1, keepdim=True)
        return torch.cat([t1, t2], dim=1)

    def _get_gt_mask(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate groundtruth mask on logits with target class tensor.

        Args:
            logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The bool mask with shape (N, C) that is ``True``
                at each sample's target class.
        """
        target = target.reshape(-1)
        return torch.zeros_like(logits).scatter_(1, target.unsqueeze(1),
                                                 1).bool()

    def _get_non_gt_mask(
        self,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate non-groundtruth mask on logits with target class tensor.

        Args:
            logits (torch.Tensor): The prediction logits with shape (N, C).
            target (torch.Tensor): The gt_label target with shape (N,).

        Return:
            torch.Tensor: The bool mask with shape (N, C) that is ``True``
                everywhere except each sample's target class.
        """
        target = target.reshape(-1)
        return torch.ones_like(logits).scatter_(1, target.unsqueeze(1),
                                                0).bool()
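As a quick illustration of the two mask helpers (a minimal sketch with made-up logits; the outputs follow from the `scatter_` calls above):

```python
import torch
from mmrazor.models import DKDLoss

loss_fn = DKDLoss()
logits = torch.zeros(2, 4)
labels = torch.tensor([1, 3])

# _get_gt_mask is True exactly at each sample's target class ...
print(loss_fn._get_gt_mask(logits, labels))
# tensor([[False,  True, False, False],
#         [False, False, False,  True]])

# ... and _get_non_gt_mask is its complement.
print(loss_fn._get_non_gt_mask(logits, labels))
# tensor([[ True, False,  True,  True],
#         [ True,  True, False, False]])
```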


@@ -27,17 +27,7 @@ class WSLD(nn.Module):
         self.softmax = nn.Softmax(dim=1)
         self.logsoftmax = nn.LogSoftmax(dim=1)
 
-    def forward(self, student, teacher, data_samples):
-        # Unpack data samples and pack targets
-        if 'score' in data_samples[0].gt_label:
-            # Batch augmentation may convert labels to one-hot format scores.
-            gt_labels = torch.stack([i.gt_label.score for i in data_samples])
-            one_hot_labels = gt_labels.float()
-        else:
-            gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
-            one_hot_labels = F.one_hot(
-                gt_labels, num_classes=self.num_classes).float()
+    def forward(self, student, teacher, gt_labels):
         student_logits = student / self.tau
         teacher_logits = teacher / self.tau


@@ -3,7 +3,7 @@ from unittest import TestCase
 
 import torch
 
-from mmrazor.models import ABLoss
+from mmrazor.models import ABLoss, DKDLoss
 
 
 class TestLosses(TestCase):
@@ -14,16 +14,28 @@ class TestLosses(TestCase):
         cls.feats_2d = torch.randn(5, 2, 3)
         cls.feats_3d = torch.randn(5, 2, 3, 3)
 
-    def normal_test_1d(self, loss_instance):
-        loss_1d = loss_instance.forward(self.feats_1d, self.feats_1d)
+        num_classes = 6
+        cls.labels = torch.randint(0, num_classes, [5])
+
+    def normal_test_1d(self, loss_instance, labels=False):
+        args = tuple([self.feats_1d, self.feats_1d])
+        if labels:
+            args += (self.labels, )
+        loss_1d = loss_instance.forward(*args)
         self.assertTrue(loss_1d.numel() == 1)
 
-    def normal_test_2d(self, loss_instance):
-        loss_2d = loss_instance.forward(self.feats_2d, self.feats_2d)
+    def normal_test_2d(self, loss_instance, labels=False):
+        args = tuple([self.feats_2d, self.feats_2d])
+        if labels:
+            args += (self.labels, )
+        loss_2d = loss_instance.forward(*args)
         self.assertTrue(loss_2d.numel() == 1)
 
-    def normal_test_3d(self, loss_instance):
-        loss_3d = loss_instance.forward(self.feats_3d, self.feats_3d)
+    def normal_test_3d(self, loss_instance, labels=False):
+        args = tuple([self.feats_3d, self.feats_3d])
+        if labels:
+            args += (self.labels, )
+        loss_3d = loss_instance.forward(*args)
         self.assertTrue(loss_3d.numel() == 1)
 
     def test_ab_loss(self):
@@ -32,3 +44,9 @@ class TestLosses(TestCase):
         self.normal_test_1d(ab_loss)
         self.normal_test_2d(ab_loss)
         self.normal_test_3d(ab_loss)
+
+    def test_dkd_loss(self):
+        dkd_loss_cfg = dict(loss_weight=1.0)
+        dkd_loss = DKDLoss(**dkd_loss_cfg)
+        # dkd additionally requires gt labels
+        self.normal_test_1d(dkd_loss, labels=True)