diff --git a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
index 6224e82a7..18bb25f2e 100644
--- a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
+++ b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
@@ -14,6 +14,7 @@
- [1.2.3 UDML](#1.2.3)
- [1.2.4 AFD](#1.2.4)
- [1.2.5 DKD](#1.2.5)
+ - [1.2.6 DIST](#1.2.6)
- [2. Usage](#2)
- [2.1 Environment setup](#2.1)
- [2.2 Data preparation](#2.2)
@@ -444,6 +445,74 @@ Loss:
- CELoss:
weight: 1.0
```
+
+
+
+#### 1.2.6 DIST
+
+##### 1.2.6.1 Introduction to the DIST algorithm
+
+Paper information:
+
+
+> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
+>
+> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
+>
+> 2022, under review
+
+When a model is distilled with the vanilla KD method, the distillation gains are often hard to improve in step with the accuracy of the teacher model. This paper proposes DIST, which uses the Pearson correlation coefficient to characterize the discrepancy between the student and the teacher, replacing the KL divergence used by default during distillation, so that the student can learn more accurate relational information.
+
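+The core computation of DIST can be written in a few lines. The snippet below is a minimal sketch that mirrors the `DISTLoss` added in `ppcls/loss/dist_loss.py`; the helper name `dist_loss` is used here only for illustration.
+
+```python
+import paddle.nn.functional as F
+
+
+def pearson_correlation(a, b, eps=1e-8):
+    # center each row, then take the row-wise cosine similarity
+    a = a - a.mean(1).unsqueeze(1)
+    b = b - b.mean(1).unsqueeze(1)
+    return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)
+
+
+def dist_loss(z_s, z_t, beta=1.0, gamma=1.0):
+    y_s = F.softmax(z_s, axis=-1)  # student probabilities, shape [N, C]
+    y_t = F.softmax(z_t, axis=-1)  # teacher probabilities, shape [N, C]
+    # inter-class relation: correlation of each sample's class distribution
+    inter = 1 - pearson_correlation(y_s, y_t).mean()
+    # intra-class relation: correlation of each class's responses over the batch
+    intra = 1 - pearson_correlation(
+        y_s.transpose([1, 0]), y_t.transpose([1, 0])).mean()
+    return beta * inter + gamma * intra
+```
+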
+On the public ImageNet1k dataset, the results are shown below.
+
+| Strategy | Backbone | Config | Top-1 acc | Download link |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99% (**+1.19%**) | - |
+
+
+##### 1.2.6.2 DIST configuration
+
+The DIST configuration is shown below. In the `Arch` field for model construction, both the student and the teacher model need to be defined; the parameters of the teacher are frozen, and its pretrained weights need to be loaded. In the `Loss` field, `DistillationDISTLoss` (the DIST loss between the student and the teacher) and `DistillationGTCELoss` (the CE loss between the student output and the ground-truth labels) need to be defined as the training loss functions.
+
+
+```yaml
+Arch:
+ name: "DistillationModel"
+  # if not null, its length should be the same as that of models
+ pretrained_list:
+  # if not null, its length should be the same as that of models
+ freeze_params_list:
+ - True
+ - False
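+  # the order corresponds to "models" below: the teacher is frozen, the student is trainable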
+ models:
+ - Teacher:
+ name: ResNet34
+ pretrained: True
+
+ - Student:
+ name: ResNet18
+ pretrained: False
+
+ infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+ Train:
+ - DistillationGTCELoss:
+ weight: 1.0
+ model_names: ["Student"]
+ - DistillationDISTLoss:
+ weight: 2.0
+ model_name_pairs:
+ - ["Student", "Teacher"]
+ Eval:
+ - CELoss:
+ weight: 1.0
+```
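+
+Inside `DistillationDISTLoss`, the inter-class and intra-class relation terms can additionally be weighted through the optional `beta` and `gamma` arguments (both default to `1.0` in `ppcls/loss/dist_loss.py`), while `weight` scales the whole DIST term relative to `DistillationGTCELoss`.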
+
+
## 2. Model training, evaluation and prediction
@@ -601,3 +670,5 @@ python3 tools/export_model.py \
[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.
+
+[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.
diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml
new file mode 100644
index 000000000..9bae5a3c1
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml
@@ -0,0 +1,152 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: ./output/r34_r18_dist
+ device: gpu
+ save_interval: 1
+ eval_during_train: True
+ eval_interval: 1
+ epochs: 100
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: ./inference
+ to_static: False
+
+# model architecture
+Arch:
+ name: "DistillationModel"
+ class_num: &class_num 1000
+  # if not null, its length should be the same as that of models
+ pretrained_list:
+  # if not null, its length should be the same as that of models
+ freeze_params_list:
+ - True
+ - False
+ infer_model_name: "Student"
+ models:
+ - Teacher:
+ name: ResNet34
+ class_num: *class_num
+ pretrained: True
+ - Student:
+ name: ResNet18
+ class_num: *class_num
+ pretrained: False
+
+# loss function config for training/eval process
+Loss:
+ Train:
+ - DistillationGTCELoss:
+ weight: 1.0
+ model_names: ["Student"]
+ - DistillationDISTLoss:
+ weight: 2.0
+ model_name_pairs:
+ - ["Student", "Teacher"]
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ weight_decay: 1e-4
+ lr:
+ name: Piecewise
+ learning_rate: 0.1
+ decay_epochs: [30, 60, 90]
+ values: [0.1, 0.01, 0.001, 0.0001]
+
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/ILSVRC2012/
+ cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: True
+ loader:
+ num_workers: 8
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/ILSVRC2012/
+ cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 256
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 8
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+ Train:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
+ Eval:
+ - DistillationTopkAcc:
+ model_key: "Student"
+ topk: [1, 5]
+
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 741eb3b61..5a62e0156 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -25,6 +25,8 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationMultiLabelLoss
+from .distillationloss import DistillationDISTLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
diff --git a/ppcls/loss/dist_loss.py b/ppcls/loss/dist_loss.py
new file mode 100644
index 000000000..78c8e12ff
--- /dev/null
+++ b/ppcls/loss/dist_loss.py
@@ -0,0 +1,52 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def cosine_similarity(a, b, eps=1e-8):
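+    """Row-wise cosine similarity between two 2-D tensors."""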
+ return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)
+
+
+def pearson_correlation(a, b, eps=1e-8):
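+    """Pearson correlation of each row: cosine similarity after mean-centering."""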
+ return cosine_similarity(a - a.mean(1).unsqueeze(1),
+ b - b.mean(1).unsqueeze(1), eps)
+
+
+def inter_class_relation(y_s, y_t):
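+    """Inter-class relation: 1 - mean per-sample Pearson correlation."""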
+ return 1 - pearson_correlation(y_s, y_t).mean()
+
+
+def intra_class_relation(y_s, y_t):
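+    """Intra-class relation: inter-class relation on the transposed predictions."""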
+ return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))
+
+
+class DISTLoss(nn.Layer):
+ # DISTLoss
+ # paper [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
+ # code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
+ def __init__(self, beta=1.0, gamma=1.0):
+ super().__init__()
+ self.beta = beta
+ self.gamma = gamma
+
+ def forward(self, z_s, z_t):
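+        # z_s / z_t are the student / teacher logits; compare them as softmax probabilities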
+ y_s = F.softmax(z_s, axis=-1)
+ y_t = F.softmax(z_t, axis=-1)
+ inter_loss = inter_class_relation(y_s, y_t)
+ intra_loss = intra_class_relation(y_s, y_t)
+ kd_loss = self.beta * inter_loss + self.gamma * intra_loss
+ return kd_loss
diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index 4f72777f4..8537fc548 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
+from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
@@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
for key in loss:
loss_dict["{}_{}".format(key, name)] = loss[key]
return loss_dict
+
+
+class DistillationDISTLoss(DISTLoss):
+ """
+ DistillationDISTLoss
+ """
+
+ def __init__(self,
+ model_name_pairs=[],
+ key=None,
+ beta=1.0,
+ gamma=1.0,
+ name="loss_dist"):
+ super().__init__(beta=beta, gamma=gamma)
+ self.key = key
+ self.model_name_pairs = model_name_pairs
+ self.name = name
+
+ def forward(self, predicts, batch):
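+        # "batch" is not used by the DIST loss; it is kept for a uniform loss interface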
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ loss = super().forward(out1, out2)
+ loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+ return loss_dict