From 01f671c72d3064160f9f298166799cfb991e60c8 Mon Sep 17 00:00:00 2001 From: LKJacky <108643365+LKJacky@users.noreply.github.com> Date: Thu, 9 Mar 2023 16:33:37 +0800 Subject: [PATCH] add dist distillation loss (#466) * add dist * update * update * update readme * update config --------- Co-authored-by: liukai --- .gitignore | 1 - configs/distill/mmcls/dist/README.md | 45 ++++++++++++++++ ...ist_logits_resnet34_resnet18_8xb32_in1k.py | 45 ++++++++++++++++ mmrazor/models/losses/__init__.py | 4 +- mmrazor/models/losses/dist_loss.py | 52 +++++++++++++++++++ 5 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 configs/distill/mmcls/dist/README.md create mode 100644 configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py create mode 100644 mmrazor/models/losses/dist_loss.py diff --git a/.gitignore b/.gitignore index 92e7c792..99c09247 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,6 @@ __pycache__/ .Python build/ develop-eggs/ -dist/ downloads/ eggs/ .eggs/ diff --git a/configs/distill/mmcls/dist/README.md b/configs/distill/mmcls/dist/README.md new file mode 100644 index 00000000..c24e42b8 --- /dev/null +++ b/configs/distill/mmcls/dist/README.md @@ -0,0 +1,45 @@ +# KD + +> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/abs/2205.10536) + + + +## Abstract + +Unlike existing knowledge distillation methods focus on the baseline settings, where the teacher models and training strategies are not that strong and competing as state-of-the-art approaches, this paper presents a method dubbed DIST to distill better from a stronger teacher. We empirically find that the discrepancy of predictions between the student and a stronger teacher may tend to be fairly severer. As a result, the exact match of predictions in KL divergence would disturb the training and make existing methods perform poorly. In this paper, we show that simply preserving the relations between the predictions of teacher and student would suffice, and propose a correlation-based loss to capture the intrinsic inter-class relations from the teacher explicitly. Besides, considering that different instances have different semantic similarities to each class, we also extend this relational match to the intra-class level. Our method is simple yet practical, and extensive experiments demonstrate that it adapts well to various architectures, model sizes and training strategies, and can achieve state-of-the-art performance consistently on image classification, object detection, and semantic segmentation tasks. Code is available at: this https URL . + +## Results and models + +### Classification + +| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download | +| :------: | :------: | :---------------: | :---------------: | :---: | :----: | :----: | :-----------------: | :--------------------------------------------------------------- | +| logits | ImageNet | [resnet34][r34_c] | [resnet18][r18_c] | 71.61 | 73.62 | 69.90 | [config][distill_c] | [teacher][r34_pth] \| [model][distill_pth] \| [log][distill_log] | + +**Note** + +There are fluctuations in the results of the experiments of DIST loss. For example, we run three times of the official code of DIST and get three different results. + +| Time | Top-1 | +| ---- | ----- | +| 1th | 71.69 | +| 2nd | 71.82 | +| 3rd | 71.90 | + +## Citation + +```latex +@article{huang2022knowledge, + title={Knowledge Distillation from A Stronger Teacher}, + author={Huang, Tao and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang}, + journal={arXiv preprint arXiv:2205.10536}, + year={2022} +} +``` + +[distill_c]: ./dist_logits_resnet34_resnet18_8xb32_in1k.py +[distill_log]: https://download.openmmlab.com/mmrazor/v1/distillation/dist_logits_resnet34_resnet18_8xb32_in1k.json +[distill_pth]: https://download.openmmlab.com/mmrazor/v1/distillation/dist_logits_resnet34_resnet18_8xb32_in1k.pth +[r18_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet18_8xb32_in1k.py +[r34_c]: https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet34_8xb32_in1k.py +[r34_pth]: https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth diff --git a/configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py b/configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py new file mode 100644 index 00000000..5d2c7f02 --- /dev/null +++ b/configs/distill/mmcls/dist/dist_logits_resnet34_resnet18_8xb32_in1k.py @@ -0,0 +1,45 @@ +_base_ = [ + 'mmcls::_base_/datasets/imagenet_bs32.py', + 'mmcls::_base_/schedules/imagenet_bs256.py', + 'mmcls::_base_/default_runtime.py' +] + +teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth' # noqa: E501 + +model = dict( + _scope_='mmrazor', + type='SingleTeacherDistill', + data_preprocessor=dict( + type='ImgDataPreprocessor', + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + bgr_to_rgb=True), + architecture=dict( + cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False), + teacher=dict( + cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False), + teacher_ckpt=teacher_ckpt, + distiller=dict( + type='ConfigurableDistiller', + student_recorders=dict( + fc=dict(type='ModuleOutputs', source='head.fc')), + teacher_recorders=dict( + fc=dict(type='ModuleOutputs', source='head.fc')), + distill_losses=dict( + loss_kl=dict( + type='DISTLoss', + inter_loss_weight=1.0, + intra_loss_weight=1.0, + tau=1, + loss_weight=2, + )), + loss_forward_mappings=dict( + loss_kl=dict( + logits_S=dict(from_student=True, recorder='fc'), + logits_T=dict(from_student=False, recorder='fc'))))) + +val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') + +optim_wrapper = dict(optimizer=dict(nesterov=True)) diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 8c779d15..65e2108f 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -6,6 +6,7 @@ from .cross_entropy_loss import CrossEntropyLoss from .cwd import ChannelWiseDivergence from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss from .decoupled_kd import DKDLoss +from .dist_loss import DISTLoss from .factor_transfer_loss import FTLoss from .fbkd_loss import FBKDLoss from .kd_soft_ce_loss import KDSoftCELoss @@ -22,5 +23,6 @@ __all__ = [ 'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD', 'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss', 'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss', 'OFDLoss', - 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss' + 'L1Loss', 'FBKDLoss', 'CRDLoss', 'CrossEntropyLoss', 'PKDLoss', 'MGDLoss', + 'DISTLoss' ] diff --git a/mmrazor/models/losses/dist_loss.py b/mmrazor/models/losses/dist_loss.py new file mode 100644 index 00000000..4d7ac63a --- /dev/null +++ b/mmrazor/models/losses/dist_loss.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS + + +def cosine_similarity(a, b, eps=1e-8): + return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps) + + +def pearson_correlation(a, b, eps=1e-8): + return cosine_similarity(a - a.mean(1, keepdim=True), + b - b.mean(1, keepdim=True), eps) + + +def inter_class_relation(y_s, y_t): + return 1 - pearson_correlation(y_s, y_t).mean() + + +def intra_class_relation(y_s, y_t): + return inter_class_relation(y_s.transpose(0, 1), y_t.transpose(0, 1)) + + +@MODELS.register_module() +class DISTLoss(nn.Module): + + def __init__( + self, + inter_loss_weight=1.0, + intra_loss_weight=1.0, + tau=1.0, + loss_weight: float = 1.0, + teacher_detach: bool = True, + ): + super(DISTLoss, self).__init__() + self.inter_loss_weight = inter_loss_weight + self.intra_loss_weight = intra_loss_weight + self.tau = tau + + self.loss_weight = loss_weight + self.teacher_detach = teacher_detach + + def forward(self, logits_S, logits_T: torch.Tensor): + if self.teacher_detach: + logits_T = logits_T.detach() + y_s = (logits_S / self.tau).softmax(dim=1) + y_t = (logits_T / self.tau).softmax(dim=1) + inter_loss = self.tau**2 * inter_class_relation(y_s, y_t) + intra_loss = self.tau**2 * intra_class_relation(y_s, y_t) + kd_loss = self.inter_loss_weight * inter_loss + self.intra_loss_weight * intra_loss # noqa + return kd_loss * self.loss_weight