add dist distillation loss (#466)

* add dist

* update

* update

* update readme

* update config

---------

Co-authored-by: liukai <your_email@abc.example>
LKJacky 2023-03-09 16:33:37 +08:00 committed by GitHub
parent 5a9aa24c16
commit 01f671c72d
5 changed files with 145 additions and 2 deletions

.gitignore

@@ -11,7 +11,6 @@ __pycache__/
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/

configs/distill/mmcls/dist/README.md

@@ -0,0 +1,45 @@
# KD
> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/abs/2205.10536)
<!-- [ALGORITHM] -->
## Abstract
Unlike existing knowledge distillation methods, which focus on baseline settings where the teacher models and training strategies are not as strong and competitive as state-of-the-art approaches, this paper presents a method dubbed DIST to distill better from a stronger teacher. We empirically find that the discrepancy between the predictions of the student and a stronger teacher tends to be fairly severe. As a result, an exact match of predictions under KL divergence would disturb the training and make existing methods perform poorly. In this paper, we show that simply preserving the relations between the predictions of teacher and student suffices, and we propose a correlation-based loss to capture the intrinsic inter-class relations from the teacher explicitly. Besides, considering that different instances have different semantic similarities to each class, we also extend this relational match to the intra-class level. Our method is simple yet practical, and extensive experiments demonstrate that it adapts well to various architectures, model sizes and training strategies, and can achieve state-of-the-art performance consistently on image classification, object detection, and semantic segmentation tasks.
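
In brief (our notation, summarizing the formulation above and the implementation added in this commit): given temperature-softened predictions $Y^s, Y^t \in \mathbb{R}^{B \times C}$ for a batch of $B$ samples over $C$ classes, DIST replaces the exact KL match with Pearson-correlation distances computed row-wise (inter-class relation) and column-wise (intra-class relation):

$$
d_p(u, v) = 1 - \rho_p(u, v), \qquad
\mathcal{L}_{\mathrm{inter}} = \frac{1}{B}\sum_{i=1}^{B} d_p\big(Y^s_{i,:}, Y^t_{i,:}\big), \qquad
\mathcal{L}_{\mathrm{intra}} = \frac{1}{C}\sum_{j=1}^{C} d_p\big(Y^s_{:,j}, Y^t_{:,j}\big),
$$

where $\rho_p$ is the Pearson correlation coefficient and the distillation loss is a weighted sum of the two terms.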
## Results and models
### Classification
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :---------------: | :---------------: | :---: | :----: | :----: | :-----------------: | :--------------------------------------------------------------- |
| logits | ImageNet | [resnet34][r34_c] | [resnet18][r18_c] | 71.61 | 73.62 | 69.90 | [config][distill_c] | [teacher][r34_pth] \| [model][distill_pth] \| [log][distill_log] |

**Note**

The results of experiments with the DIST loss fluctuate from run to run. For example, we ran the official DIST code three times and obtained three different results:

| Run | Top-1 |
| --- | ----- |
| 1st | 71.69 |
| 2nd | 71.82 |
| 3rd | 71.90 |
## Citation
```latex
@article{huang2022knowledge,
title={Knowledge Distillation from A Stronger Teacher},
author={Huang, Tao and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang},
journal={arXiv preprint arXiv:2205.10536},
year={2022}
}
```
[distill_c]: ./dist_logits_resnet34_resnet18_8xb32_in1k.py
[distill_log]: https://download.openmmlab.com/mmrazor/v1/distillation/dist_logits_resnet34_resnet18_8xb32_in1k.json
[distill_pth]: https://download.openmmlab.com/mmrazor/v1/distillation/dist_logits_resnet34_resnet18_8xb32_in1k.pth
[r18_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet18_8xb32_in1k.py
[r34_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet34_8xb32_in1k.py
[r34_pth]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth

configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py

@@ -0,0 +1,45 @@
_base_ = [
    'mmcls::_base_/datasets/imagenet_bs32.py',
    'mmcls::_base_/schedules/imagenet_bs256.py',
    'mmcls::_base_/default_runtime.py'
]

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa: E501

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(
                type='DISTLoss',
                inter_loss_weight=1.0,
                intra_loss_weight=1.0,
                tau=1,
                loss_weight=2,
            )),
        loss_forward_mappings=dict(
            loss_kl=dict(
                logits_S=dict(from_student=True, recorder='fc'),
                logits_T=dict(from_student=False, recorder='fc')))))

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

optim_wrapper = dict(optimizer=dict(nesterov=True))
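
A rough usage sketch, not part of this commit: a distillation config like the one above would typically be launched through MMEngine's `Runner`. The config path and work directory below are assumptions for illustration.

```python
# Hypothetical launch script; assumes mmrazor, mmcls and mmengine are installed.
from mmengine.config import Config
from mmengine.runner import Runner

# Path assumed from this PR's layout; adjust to wherever the config lives.
cfg = Config.fromfile(
    'configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py')
cfg.work_dir = './work_dirs/dist_logits_resnet34_resnet18_8xb32_in1k'

runner = Runner.from_cfg(cfg)  # builds the SingleTeacherDistill model, data and loops
runner.train()
```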

mmrazor/models/losses/__init__.py

@@ -6,6 +6,7 @@ from .cross_entropy_loss import CrossEntropyLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
from .dist_loss import DISTLoss
from .factor_transfer_loss import FTLoss
from .fbkd_loss import FBKDLoss
from .kd_soft_ce_loss import KDSoftCELoss
@@ -22,5 +23,6 @@ __all__ = [
    'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
    'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
    'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss',
    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss'
    'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss',
    'DISTLoss'
]

mmrazor/models/losses/dist_loss.py

@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmrazor.registry import MODELS


def cosine_similarity(a, b, eps=1e-8):
    # Row-wise cosine similarity between two 2-D tensors.
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    # Row-wise Pearson correlation: cosine similarity of the centered rows.
    return cosine_similarity(a - a.mean(1, keepdim=True),
                             b - b.mean(1, keepdim=True), eps)


def inter_class_relation(y_s, y_t):
    # One minus the mean per-sample correlation across classes.
    return 1 - pearson_correlation(y_s, y_t).mean()


def intra_class_relation(y_s, y_t):
    # Same relation computed per class across the batch (transposed inputs).
    return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1))


@MODELS.register_module()
class DISTLoss(nn.Module):

    def __init__(
        self,
        inter_loss_weight=1.0,
        intra_loss_weight=1.0,
        tau=1.0,
        loss_weight: float = 1.0,
        teacher_detach: bool = True,
    ):
        super(DISTLoss, self).__init__()
        self.inter_loss_weight = inter_loss_weight
        self.intra_loss_weight = intra_loss_weight
        self.tau = tau
        self.loss_weight = loss_weight
        self.teacher_detach = teacher_detach

    def forward(self, logits_S, logits_T: torch.Tensor):
        if self.teacher_detach:
            logits_T = logits_T.detach()
        # Soften both distributions with the temperature tau.
        y_s = (logits_S / self.tau).softmax(dim=1)
        y_t = (logits_T / self.tau).softmax(dim=1)
        inter_loss = self.tau**2 * inter_class_relation(y_s, y_t)
        intra_loss = self.tau**2 * intra_class_relation(y_s, y_t)
        kd_loss = self.inter_loss_weight * inter_loss + self.intra_loss_weight * intra_loss  # noqa
        return kd_loss * self.loss_weight
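
For reference, a minimal standalone sketch of exercising `DISTLoss` on random logits (assuming mmrazor is installed and the class is exported from `mmrazor.models.losses` as in the `__init__.py` change above):

```python
# Illustrative only; temperature and loss weight mirror the config above.
import torch

from mmrazor.models.losses import DISTLoss

loss_fn = DISTLoss(inter_loss_weight=1.0, intra_loss_weight=1.0, tau=1, loss_weight=2)
logits_s = torch.randn(8, 1000, requires_grad=True)  # student logits (B, C)
logits_t = torch.randn(8, 1000)  # teacher logits, detached inside forward
loss = loss_fn(logits_s, logits_t)
loss.backward()  # gradients flow only to the student logits
print(loss.item())
```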