diff --git a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
index 6224e82a7..18bb25f2e 100644
--- a/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
+++ b/docs/zh_CN/advanced_tutorials/knowledge_distillation.md
@@ -14,6 +14,7 @@
     - [1.2.3 UDML](#1.2.3)
     - [1.2.4 AFD](#1.2.4)
     - [1.2.5 DKD](#1.2.5)
+    - [1.2.6 DIST](#1.2.6)
 - [2. Usage](#2)
   - [2.1 Environment setup](#2.1)
   - [2.2 Data preparation](#2.2)
@@ -444,6 +445,74 @@ Loss:
     - CELoss:
         weight: 1.0
 ```
+
+
+#### 1.2.6 DIST
+
+##### 1.2.6.1 Introduction to the DIST algorithm
+
+Paper information:
+
+> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
+>
+> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
+>
+> 2022, under review
+
+With the vanilla KD method, the gain from distillation often fails to improve in step with the accuracy of the teacher model. This paper proposes DIST, which uses the Pearson correlation coefficient instead of the KL divergence used by default during distillation to characterize the discrepancy between student and teacher predictions, so that the student learns more accurate relational information.
+
+The results on the public ImageNet-1k dataset are shown below.
+
+| Strategy | Backbone | Config file | Top-1 acc | Download link |
+| --- | --- | --- | --- | --- |
+| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
+| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99%(**+1.19%**) | - |
+
+
+##### 1.2.6.2 DIST configuration
+
+The DIST configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen and its pretrained weights are loaded. In the `Loss` field, `DistillationDISTLoss` (the DIST loss between student and teacher) and `DistillationGTCELoss` (the CE loss between the student output and the ground-truth labels) must be defined as the training losses.
+
+```yaml
+Arch:
+  name: "DistillationModel"
+  # if not null, its lengths should be same as models
+  pretrained_list:
+  # if not null, its lengths should be same as models
+  freeze_params_list:
+  - True
+  - False
+  models:
+    - Teacher:
+        name: ResNet34
+        pretrained: True
+    - Student:
+        name: ResNet18
+        pretrained: False
+
+  infer_model_name: "Student"
+
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationDISTLoss:
+        weight: 2.0
+        model_name_pairs:
+        - ["Student", "Teacher"]
+  Eval:
+    - CELoss:
+        weight: 1.0
+```
+
+
 ## 2. Model training, evaluation and prediction
@@ -601,3 +670,5 @@ python3 tools/export_model.py \
 [11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
 
 [12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.
+
+[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.
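To make the two relation terms optimized by DIST concrete, the short NumPy sketch below (an illustration added for this review, not part of the patch) mirrors the definitions introduced in `ppcls/loss/dist_loss.py` further down: the inter-class term correlates each sample's class distribution between student and teacher (rows), while the intra-class term correlates each class's responses across the batch (columns).

```python
# Illustrative sketch of the DIST relation terms on toy logits.
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def pearson(a, b, eps=1e-8):
    # Center each row, then take the row-wise cosine similarity,
    # which equals the Pearson correlation coefficient.
    a = a - a.mean(axis=1, keepdims=True)
    b = b - b.mean(axis=1, keepdims=True)
    return (a * b).sum(1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps)


rng = np.random.default_rng(0)
z_s = rng.normal(size=(4, 10))  # fake student logits: 4 samples, 10 classes
z_t = rng.normal(size=(4, 10))  # fake teacher logits
y_s, y_t = softmax(z_s), softmax(z_t)

inter_loss = 1 - pearson(y_s, y_t).mean()      # per-sample class distributions
intra_loss = 1 - pearson(y_s.T, y_t.T).mean()  # per-class responses across the batch
print(inter_loss, intra_loss)
```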
diff --git a/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml
new file mode 100644
index 000000000..9bae5a3c1
--- /dev/null
+++ b/ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml
@@ -0,0 +1,152 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/r34_r18_dist
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  to_static: False
+
+# model architecture
+Arch:
+  name: "DistillationModel"
+  class_num: &class_num 1000
+  # if not null, its lengths should be same as models
+  pretrained_list:
+  # if not null, its lengths should be same as models
+  freeze_params_list:
+  - True
+  - False
+  infer_model_name: "Student"
+  models:
+    - Teacher:
+        name: ResNet34
+        class_num: *class_num
+        pretrained: True
+    - Student:
+        name: ResNet18
+        class_num: *class_num
+        pretrained: False
+
+# loss function config for training/eval process
+Loss:
+  Train:
+    - DistillationGTCELoss:
+        weight: 1.0
+        model_names: ["Student"]
+    - DistillationDISTLoss:
+        weight: 2.0
+        model_name_pairs:
+        - ["Student", "Teacher"]
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  weight_decay: 1e-4
+  lr:
+    name: Piecewise
+    learning_rate: 0.1
+    decay_epochs: [30, 60, 90]
+    values: [0.1, 0.01, 0.001, 0.0001]
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/ILSVRC2012/
+      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 256
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
+  Eval:
+    - DistillationTopkAcc:
+        model_key: "Student"
+        topk: [1, 5]
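The `Optimizer.lr` entry above uses a `Piecewise` schedule: the learning rate starts at 0.1 and is divided by 10 at epochs 30, 60 and 90. The minimal raw-Paddle sketch below (illustration only; PaddleClas builds its scheduler from `decay_epochs`/`values` internally, and the per-epoch stepping here is an assumption for readability) shows the resulting schedule shape:

```python
import paddle

# Same breakpoints and values as the Piecewise entry in the config above.
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[30, 60, 90], values=[0.1, 0.01, 0.001, 0.0001])

for epoch in range(100):
    # ... one training epoch with the Momentum optimizer would run here ...
    if epoch in (0, 30, 60, 90):
        print(f"epoch {epoch}: lr = {scheduler.get_lr()}")
    scheduler.step()  # stepped once per epoch in this sketch
```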
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 741eb3b61..5a62e0156 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -25,6 +25,8 @@
 from .distillationloss import DistillationRKDLoss
 from .distillationloss import DistillationKLDivLoss
 from .distillationloss import DistillationDKDLoss
 from .distillationloss import DistillationMultiLabelLoss
+from .distillationloss import DistillationDISTLoss
+
 from .multilabelloss import MultiLabelLoss
 from .afdloss import AFDLoss
diff --git a/ppcls/loss/dist_loss.py b/ppcls/loss/dist_loss.py
new file mode 100644
index 000000000..78c8e12ff
--- /dev/null
+++ b/ppcls/loss/dist_loss.py
@@ -0,0 +1,52 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def cosine_similarity(a, b, eps=1e-8):
+    return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)
+
+
+def pearson_correlation(a, b, eps=1e-8):
+    return cosine_similarity(a - a.mean(1).unsqueeze(1),
+                             b - b.mean(1).unsqueeze(1), eps)
+
+
+def inter_class_relation(y_s, y_t):
+    return 1 - pearson_correlation(y_s, y_t).mean()
+
+
+def intra_class_relation(y_s, y_t):
+    return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))
+
+
+class DISTLoss(nn.Layer):
+    # DISTLoss
+    # paper [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
+    # code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
+    def __init__(self, beta=1.0, gamma=1.0):
+        super().__init__()
+        self.beta = beta
+        self.gamma = gamma
+
+    def forward(self, z_s, z_t):
+        y_s = F.softmax(z_s, axis=-1)
+        y_t = F.softmax(z_t, axis=-1)
+        inter_loss = inter_class_relation(y_s, y_t)
+        intra_loss = intra_class_relation(y_s, y_t)
+        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
+        return kd_loss
diff --git a/ppcls/loss/distillationloss.py b/ppcls/loss/distillationloss.py
index 4f72777f4..8537fc548 100644
--- a/ppcls/loss/distillationloss.py
+++ b/ppcls/loss/distillationloss.py
@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
 from .rkdloss import RKdAngle, RkdDistance
 from .kldivloss import KLDivLoss
 from .dkdloss import DKDLoss
+from .dist_loss import DISTLoss
 from .multilabelloss import MultiLabelLoss
 
 
@@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
         for key in loss:
             loss_dict["{}_{}".format(key, name)] = loss[key]
         return loss_dict
+
+
+class DistillationDISTLoss(DISTLoss):
+    """
+    DistillationDISTLoss
+    """
+
+    def __init__(self,
+                 model_name_pairs=[],
+                 key=None,
+                 beta=1.0,
+                 gamma=1.0,
+                 name="loss_dist"):
+        super().__init__(beta=beta, gamma=gamma)
+        self.key = key
+        self.model_name_pairs = model_name_pairs
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx, pair in enumerate(self.model_name_pairs):
+            out1 = predicts[pair[0]]
+            out2 = predicts[pair[1]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
+        return loss_dict
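With a PaddleClas checkout on the `PYTHONPATH`, the new loss can be exercised standalone. The sketch below is an assumed usage example (not code from the patch) that feeds fake student/teacher logits through `DistillationDISTLoss` with the same `model_name_pairs` as the config above:

```python
import paddle

from ppcls.loss.distillationloss import DistillationDISTLoss

paddle.seed(0)

criterion = DistillationDISTLoss(
    model_name_pairs=[["Student", "Teacher"]], beta=1.0, gamma=1.0)

# Fake logits for a batch of 8 images over 1000 classes; in training these
# come from the "Student" and "Teacher" outputs of DistillationModel.
predicts = {
    "Student": paddle.randn([8, 1000]),
    "Teacher": paddle.randn([8, 1000]),
}

# The second argument (the batch / ground-truth labels) is unused by this loss.
loss_dict = criterion(predicts, None)
print(loss_dict)  # {'loss_dist_Student_Teacher': Tensor(...)}
```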