add dist algo (#2133)

* add dist_kd

* add doc

* fix some typos
littletomatodonkey 2022-07-06 16:30:01 +08:00 committed by GitHub
parent d4e286c3f3
commit 14d6b7efa4
5 changed files with 307 additions and 0 deletions


@@ -14,6 +14,7 @@
- [1.2.3 UDML](#1.2.3)
- [1.2.4 AFD](#1.2.4)
- [1.2.5 DKD](#1.2.5)
- [1.2.6 DIST](#1.2.6)
- [2. Usage](#2)
- [2.1 Environment configuration](#2.1)
- [2.2 Data preparation](#2.2)
@@ -444,6 +445,74 @@ Loss:
    - CELoss:
        weight: 1.0
```
<a name='1.2.6'></a>
#### 1.2.6 DIST
##### 1.2.6.1 Introduction to the DIST algorithm

Paper information:
> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
>
> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
>
> 2022, under review
When distilling with the standard KD method, the distillation gain often fails to keep improving as the teacher model becomes more accurate. This paper proposes DIST, which uses the Pearson correlation coefficient instead of the KL divergence used by default in distillation to characterize the discrepancy between the student and teacher outputs, so that the student learns more accurate relational information.
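Concretely (a brief sketch in our own notation, consistent with the loss implemented in this PR): let $Y^s, Y^t \in \mathbb{R}^{B \times C}$ be the softmax outputs of the student and the teacher for a batch of $B$ samples over $C$ classes, and let $\rho_p(u, v)$ denote the Pearson correlation coefficient between two vectors. DIST minimizes

$$
L_{\mathrm{inter}} = \frac{1}{B}\sum_{i=1}^{B}\left(1 - \rho_p\big(Y^s_{i,:},\, Y^t_{i,:}\big)\right), \qquad
L_{\mathrm{intra}} = \frac{1}{C}\sum_{j=1}^{C}\left(1 - \rho_p\big(Y^s_{:,j},\, Y^t_{:,j}\big)\right),
$$

$$
L_{\mathrm{DIST}} = \beta\, L_{\mathrm{inter}} + \gamma\, L_{\mathrm{intra}},
$$

with $\beta = \gamma = 1$ by default.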
Results on the public ImageNet-1k dataset are shown below.

| Strategy | Backbone | Config | Top-1 acc | Download |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99%(**+1.19%**) | - |
##### 1.2.6.2 DIST configuration

The DIST configuration is shown below. In the architecture (`Arch`) field, both the student and the teacher model need to be defined; the teacher's parameters are frozen and its pretrained weights are loaded. In the loss (`Loss`) field, `DistillationDISTLoss` (the DIST loss between the student and the teacher) and `DistillationGTCELoss` (the CE loss between the student and the ground-truth labels) need to be defined as the training losses.
```yaml
Arch:
  name: "DistillationModel"
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False
  infer_model_name: "Student"

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0
```
<a name="2"></a> <a name="2"></a>
## 2. 模型训练、评估和预测 ## 2. 模型训练、评估和预测
@@ -601,3 +670,5 @@ python3 tools/export_model.py \
[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.
[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.


@@ -0,0 +1,152 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/r34_r18_dist
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet34
        class_num: *class_num
        pretrained: True
    - Student:
        name: ResNet18
        class_num: *class_num
        pretrained: False

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]


@@ -25,6 +25,8 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss
from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss


@@ -0,0 +1,52 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def cosine_similarity(a, b, eps=1e-8):
    # row-wise cosine similarity between two [N, C] tensors
    return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    # Pearson correlation = cosine similarity of the mean-centered rows
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(y_s, y_t):
    # 1 - mean Pearson correlation over samples (rows)
    return 1 - pearson_correlation(y_s, y_t).mean()


def intra_class_relation(y_s, y_t):
    # same relation computed over classes (columns), via transpose
    return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))


class DISTLoss(nn.Layer):
    # DISTLoss
    # paper [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
    # code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
    def __init__(self, beta=1.0, gamma=1.0):
        super().__init__()
        self.beta = beta  # weight of the inter-class (per-sample) term
        self.gamma = gamma  # weight of the intra-class (per-class) term

    def forward(self, z_s, z_t):
        y_s = F.softmax(z_s, axis=-1)
        y_t = F.softmax(z_t, axis=-1)
        inter_loss = inter_class_relation(y_s, y_t)
        intra_loss = intra_class_relation(y_s, y_t)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
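As a quick sanity check of the `DISTLoss` class above, here is a minimal sketch (assuming a working PaddlePaddle install; the shapes and values are illustrative only, not part of this PR):

```python
import paddle

# illustrative logits: 4 samples, 10 classes
z_student = paddle.randn([4, 10])
z_teacher = paddle.randn([4, 10])

loss_fn = DISTLoss(beta=1.0, gamma=1.0)  # DISTLoss defined above
kd_loss = loss_fn(z_student, z_teacher)  # scalar tensor
print(float(kd_loss))
```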


@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
@@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
            for key in loss:
                loss_dict["{}_{}".format(key, name)] = loss[key]
        return loss_dict

class DistillationDISTLoss(DISTLoss):
    """
    DistillationDISTLoss: DIST loss computed on the output logits of each
    (student, teacher) model pair.
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 beta=1.0,
                 gamma=1.0,
                 name="loss_dist"):
        super().__init__(beta=beta, gamma=gamma)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict
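And a rough usage sketch of `DistillationDISTLoss`: the keys "Student"/"Teacher" mirror the config above, but treating the model output as a plain dict of per-model logits is an assumption made here for illustration, not something this PR specifies.

```python
import paddle

# hypothetical output of a distillation forward pass:
# one logits tensor per sub-model, keyed by model name
predicts = {
    "Student": paddle.randn([4, 1000]),
    "Teacher": paddle.randn([4, 1000]),
}

loss_fn = DistillationDISTLoss(model_name_pairs=[["Student", "Teacher"]])
loss_dict = loss_fn(predicts, batch=None)  # batch is not used by this loss
# -> {"loss_dist_Student_Teacher": <scalar tensor>}
print(loss_dict)
```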