mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
[Feature] Add Decoupled KD Loss (#222)
* add DKDLoss, config * linting * linting * reset default reduction * dkd ut * Update decoupled_kd.py * Update decoupled_kd.py * Update decoupled_kd.py * fix commit * fix readme * fix comments * linting comment * rename loss params * fix docstring * Update decoupled_kd.py * fix gt from config * merge fix * fix ut & wsld * Update README.md * Update README.md * add Acknowledgement * Update README.md * Update README.md * Update README.md * Update README.md * fix readme style * fix md Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
This commit is contained in:
parent
6e8ebfd85a
commit
a1937fd5a6
48
configs/distill/mmcls/dkd/README.md
Normal file
48
configs/distill/mmcls/dkd/README.md
Normal file
@ -0,0 +1,48 @@
|
||||
# Decoupled Knowledge Distillation
|
||||
|
||||
> [Decoupled Knowledge Distillation](https://arxiv.org/pdf/2203.08679.pdf)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
State-of-the-art distillation methods are mainly based on distilling deep features from intermediate layers, while the significance of logit distillation is greatly overlooked. To provide a novel viewpoint to study logit distillation, we reformulate the classical KD loss into two parts, i.e., target class knowledge distillation (TCKD) and non-target class knowledge distillation (NCKD). We empirically investigate and prove the effects of the two parts: TCKD transfers knowledge concerning the "difficulty" of training samples, while NCKD is the prominent reason why logit distillation works. More importantly, we reveal that the classical KD loss is a coupled formulation, which (1) suppresses the effectiveness of NCKD and (2) limits the flexibility to balance these two parts. To address these issues, we present Decoupled Knowledge Distillation (DKD), enabling TCKD and NCKD to play their roles more efficiently and flexibly. Compared with complex feature-based methods, our DKD achieves comparable or even better results and has better training efficiency on CIFAR-100, ImageNet, and MS-COCO datasets for image classification and object detection tasks. This paper proves the great potential of logit distillation, and we hope it will be helpful for future research. The code is available at https://github.com/megvii-research/mdistiller.
|
||||
|
||||

|
||||
|
||||
## Results and models
|
||||
|
||||
### Classification
|
||||
|
||||
| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
|
||||
| -------- | --------- | --------- | --------- | --------- | ------------------------------------------ | ---------------------------------------------------------------------------------------------------- |
|
||||
| ImageNet | ResNet-18 | ResNet-34 | 71.368 | 90.256 | [config](dkd_logits_r34_r18_8xb32_in1k.py) | [model & log](https://autolink.sensetime.com/pages/model/share/afc68955-e25d-4488-b044-5e801b3ff62f) |
|
||||
|
||||
## Citation
|
||||
|
||||
```latex
|
||||
@article{zhao2022decoupled,
|
||||
title={Decoupled Knowledge Distillation},
|
||||
author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
|
||||
journal={arXiv preprint arXiv:2203.08679},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Download teacher ckpt from
|
||||
|
||||
https://mmclassification.readthedocs.io/en/latest/papers/resnet.html
|
||||
|
||||
### Distillation training.
|
||||
|
||||
```bash
|
||||
sh tools/slurm_train.sh $PARTITION $JOB_NAME \
|
||||
configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py \
|
||||
$DISTILLATION_WORK_DIR
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
Shout out to Davidgzx for his special contribution.
|
45
configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py
Normal file
45
configs/distill/mmcls/dkd/dkd_logits_r34_r18_8xb32_in1k.py
Normal file
@ -0,0 +1,45 @@
|
||||
_base_ = [
|
||||
'mmcls::_base_/datasets/imagenet_bs32.py',
|
||||
'mmcls::_base_/schedules/imagenet_bs256.py',
|
||||
'mmcls::_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
model = dict(
|
||||
_scope_='mmrazor',
|
||||
type='SingleTeacherDistill',
|
||||
data_preprocessor=dict(
|
||||
type='ImgDataPreprocessor',
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
bgr_to_rgb=True),
|
||||
architecture=dict(
|
||||
cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
|
||||
teacher=dict(
|
||||
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
|
||||
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=dict(
|
||||
fc=dict(type='ModuleOutputs', source='head.fc'),
|
||||
gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
|
||||
teacher_recorders=dict(
|
||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||
distill_losses=dict(
|
||||
loss_dkd=dict(
|
||||
type='DKDLoss',
|
||||
tau=1,
|
||||
beta=0.5,
|
||||
loss_weight=1,
|
||||
reduction='mean')),
|
||||
loss_forward_mappings=dict(
|
||||
loss_dkd=dict(
|
||||
preds_S=dict(from_student=True, recorder='fc'),
|
||||
preds_T=dict(from_student=False, recorder='fc'),
|
||||
gt_labels=dict(
|
||||
recorder='gt_labels', from_student=True, data_idx=1)))))
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
|
@ -20,9 +20,10 @@ model = dict(
|
||||
cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
|
||||
teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
|
||||
distiller=dict(
|
||||
type='ConfigurableDistiller',
|
||||
student_recorders=dict(
|
||||
fc=dict(type='ModuleOutputs', source='head.fc'),
|
||||
data_samples=dict(type='ModuleInputs', source='')),
|
||||
gt_labels=dict(type='ModuleInputs', source='head.loss_module')),
|
||||
teacher_recorders=dict(
|
||||
fc=dict(type='ModuleOutputs', source='head.fc')),
|
||||
distill_losses=dict(
|
||||
@ -31,8 +32,8 @@ model = dict(
|
||||
loss_wsld=dict(
|
||||
student=dict(recorder='fc', from_student=True),
|
||||
teacher=dict(recorder='fc', from_student=False),
|
||||
data_samples=dict(
|
||||
recorder='data_samples', from_student=True, data_idx=1)))))
|
||||
gt_labels=dict(
|
||||
recorder='gt_labels', from_student=True, data_idx=1)))))
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
|
BIN
docs/en/imgs/model_zoo/dkd/dkd.png
Normal file
BIN
docs/en/imgs/model_zoo/dkd/dkd.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 522 KiB |
@ -1,6 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ab_loss import ABLoss
|
||||
from .cwd import ChannelWiseDivergence
|
||||
from .decoupled_kd import DKDLoss
|
||||
from .kl_divergence import KLDivergence
|
||||
from .l2_loss import L2Loss
|
||||
from .relational_kd import AngleWiseRKD, DistanceWiseRKD
|
||||
@ -8,5 +9,5 @@ from .weighted_soft_label_distillation import WSLD
|
||||
|
||||
__all__ = [
|
||||
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
|
||||
'WSLD', 'L2Loss', 'ABLoss'
|
||||
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss'
|
||||
]
|
||||
|
157
mmrazor/models/losses/decoupled_kd.py
Normal file
157
mmrazor/models/losses/decoupled_kd.py
Normal file
@ -0,0 +1,157 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmrazor.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class DKDLoss(nn.Module):
|
||||
"""Decoupled Knowledge Distillation, CVPR2022.
|
||||
|
||||
link: https://arxiv.org/abs/2203.08679
|
||||
reformulate the classical KD loss into two parts:
|
||||
1. target class knowledge distillation (TCKD)
|
||||
2. non-target class knowledge distillation (NCKD).
|
||||
Args:
|
||||
tau (float): Temperature coefficient. Defaults to 1.0.
|
||||
alpha (float): Weight of TCKD loss. Defaults to 1.0.
|
||||
beta (float): Weight of NCKD loss. Defaults to 1.0.
|
||||
reduction (str): Specifies the reduction to apply to the loss:
|
||||
``'none'`` | ``'batchmean'`` | ``'sum'`` | ``'mean'``.
|
||||
``'none'``: no reduction will be applied,
|
||||
``'batchmean'``: the sum of the output will be divided by
|
||||
the batchsize,
|
||||
``'sum'``: the output will be summed,
|
||||
``'mean'``: the output will be divided by the number of
|
||||
elements in the output.
|
||||
Default: ``'batchmean'``
|
||||
loss_weight (float): Weight of loss. Defaults to 1.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tau: float = 1.0,
|
||||
alpha: float = 1.0,
|
||||
beta: float = 1.0,
|
||||
reduction: str = 'batchmean',
|
||||
loss_weight: float = 1.0,
|
||||
) -> None:
|
||||
super(DKDLoss, self).__init__()
|
||||
self.tau = tau
|
||||
accept_reduction = {'none', 'batchmean', 'sum', 'mean'}
|
||||
assert reduction in accept_reduction, \
|
||||
f'KLDivergence supports reduction {accept_reduction}, ' \
|
||||
f'but gets {reduction}.'
|
||||
self.reduction = reduction
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.loss_weight = loss_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
preds_S: torch.Tensor,
|
||||
preds_T: torch.Tensor,
|
||||
gt_labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""DKDLoss forward function.
|
||||
|
||||
Args:
|
||||
preds_S (torch.Tensor): The student model prediction, shape (N, C).
|
||||
preds_T (torch.Tensor): The teacher model prediction, shape (N, C).
|
||||
gt_labels (torch.Tensor): The gt label tensor, shape (N, C).
|
||||
|
||||
Return:
|
||||
torch.Tensor: The calculated loss value.
|
||||
"""
|
||||
gt_mask = self._get_gt_mask(preds_S, gt_labels)
|
||||
tckd_loss = self._get_tckd_loss(preds_S, preds_T, gt_labels, gt_mask)
|
||||
nckd_loss = self._get_nckd_loss(preds_S, preds_T, gt_mask)
|
||||
loss = self.alpha * tckd_loss + self.beta * nckd_loss
|
||||
return self.loss_weight * loss
|
||||
|
||||
def _get_nckd_loss(
|
||||
self,
|
||||
preds_S: torch.Tensor,
|
||||
preds_T: torch.Tensor,
|
||||
gt_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate non-target class knowledge distillation."""
|
||||
# implementation to mask out gt_mask, faster than index
|
||||
s_nckd = F.log_softmax(preds_S / self.tau - 1000.0 * gt_mask, dim=1)
|
||||
t_nckd = F.softmax(preds_T / self.tau - 1000.0 * gt_mask, dim=1)
|
||||
return self._kl_loss(s_nckd, t_nckd)
|
||||
|
||||
def _get_tckd_loss(
|
||||
self,
|
||||
preds_S: torch.Tensor,
|
||||
preds_T: torch.Tensor,
|
||||
gt_labels: torch.Tensor,
|
||||
gt_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate target class knowledge distillation."""
|
||||
non_gt_mask = self._get_non_gt_mask(preds_S, gt_labels)
|
||||
s_tckd = F.softmax(preds_S / self.tau, dim=1)
|
||||
t_tckd = F.softmax(preds_T / self.tau, dim=1)
|
||||
mask_student = torch.log(self._cat_mask(s_tckd, gt_mask, non_gt_mask))
|
||||
mask_teacher = self._cat_mask(t_tckd, gt_mask, non_gt_mask)
|
||||
return self._kl_loss(mask_student, mask_teacher)
|
||||
|
||||
def _kl_loss(
|
||||
self,
|
||||
preds_S: torch.Tensor,
|
||||
preds_T: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate the KL Divergence."""
|
||||
kl_loss = F.kl_div(
|
||||
preds_S, preds_T, size_average=False,
|
||||
reduction=self.reduction) * self.tau**2
|
||||
return kl_loss
|
||||
|
||||
def _cat_mask(
|
||||
self,
|
||||
tckd: torch.Tensor,
|
||||
gt_mask: torch.Tensor,
|
||||
non_gt_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate preds of target (pt) & preds of non-target (pnt)."""
|
||||
t1 = (tckd * gt_mask).sum(dim=1, keepdims=True)
|
||||
t2 = (tckd * non_gt_mask).sum(dim=1, keepdims=True)
|
||||
return torch.cat([t1, t2], dim=1)
|
||||
|
||||
def _get_gt_mask(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate groundtruth mask on logits with target class tensor.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): The prediction logits with shape (N, C).
|
||||
target (torch.Tensor): The gt_label target with shape (N, C).
|
||||
|
||||
Return:
|
||||
torch.Tensor: The masked logits.
|
||||
"""
|
||||
target = target.reshape(-1)
|
||||
return torch.zeros_like(logits).scatter_(1, target.unsqueeze(1),
|
||||
1).bool()
|
||||
|
||||
def _get_non_gt_mask(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Calculate non-groundtruth mask on logits with target class tensor.
|
||||
|
||||
Args:
|
||||
logits (torch.Tensor): The prediction logits with shape (N, C).
|
||||
target (torch.Tensor): The gt_label target with shape (N, C).
|
||||
|
||||
Return:
|
||||
torch.Tensor: The masked logits.
|
||||
"""
|
||||
target = target.reshape(-1)
|
||||
return torch.ones_like(logits).scatter_(1, target.unsqueeze(1),
|
||||
0).bool()
|
@ -27,17 +27,7 @@ class WSLD(nn.Module):
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
self.logsoftmax = nn.LogSoftmax(dim=1)
|
||||
|
||||
def forward(self, student, teacher, data_samples):
|
||||
|
||||
# Unpack data samples and pack targets
|
||||
if 'score' in data_samples[0].gt_label:
|
||||
# Batch augmentation may convert labels to one-hot format scores.
|
||||
gt_labels = torch.stack([i.gt_label.score for i in data_samples])
|
||||
one_hot_labels = gt_labels.float()
|
||||
else:
|
||||
gt_labels = torch.hstack([i.gt_label.label for i in data_samples])
|
||||
one_hot_labels = F.one_hot(
|
||||
gt_labels, num_classes=self.num_classes).float()
|
||||
def forward(self, student, teacher, gt_labels):
|
||||
|
||||
student_logits = student / self.tau
|
||||
teacher_logits = teacher / self.tau
|
||||
|
@ -3,7 +3,7 @@ from unittest import TestCase
|
||||
|
||||
import torch
|
||||
|
||||
from mmrazor.models import ABLoss
|
||||
from mmrazor.models import ABLoss, DKDLoss
|
||||
|
||||
|
||||
class TestLosses(TestCase):
|
||||
@ -14,16 +14,28 @@ class TestLosses(TestCase):
|
||||
cls.feats_2d = torch.randn(5, 2, 3)
|
||||
cls.feats_3d = torch.randn(5, 2, 3, 3)
|
||||
|
||||
def normal_test_1d(self, loss_instance):
|
||||
loss_1d = loss_instance.forward(self.feats_1d, self.feats_1d)
|
||||
num_classes = 6
|
||||
cls.labels = torch.randint(0, num_classes, [5])
|
||||
|
||||
def normal_test_1d(self, loss_instance, labels=False):
|
||||
args = tuple([self.feats_1d, self.feats_1d])
|
||||
if labels:
|
||||
args += (self.labels, )
|
||||
loss_1d = loss_instance.forward(*args)
|
||||
self.assertTrue(loss_1d.numel() == 1)
|
||||
|
||||
def normal_test_2d(self, loss_instance):
|
||||
loss_2d = loss_instance.forward(self.feats_2d, self.feats_2d)
|
||||
def normal_test_2d(self, loss_instance, labels=False):
|
||||
args = tuple([self.feats_2d, self.feats_2d])
|
||||
if labels:
|
||||
args += (self.labels, )
|
||||
loss_2d = loss_instance.forward(*args)
|
||||
self.assertTrue(loss_2d.numel() == 1)
|
||||
|
||||
def normal_test_3d(self, loss_instance):
|
||||
loss_3d = loss_instance.forward(self.feats_3d, self.feats_3d)
|
||||
def normal_test_3d(self, loss_instance, labels=False):
|
||||
args = tuple([self.feats_3d, self.feats_3d])
|
||||
if labels:
|
||||
args += (self.labels, )
|
||||
loss_3d = loss_instance.forward(*args)
|
||||
self.assertTrue(loss_3d.numel() == 1)
|
||||
|
||||
def test_ab_loss(self):
|
||||
@ -32,3 +44,9 @@ class TestLosses(TestCase):
|
||||
self.normal_test_1d(ab_loss)
|
||||
self.normal_test_2d(ab_loss)
|
||||
self.normal_test_3d(ab_loss)
|
||||
|
||||
def test_dkd_loss(self):
|
||||
dkd_loss_cfg = dict(loss_weight=1.0)
|
||||
dkd_loss = DKDLoss(**dkd_loss_cfg)
|
||||
# dkd requires label logits
|
||||
self.normal_test_1d(dkd_loss, labels=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user