parent d4e286c3f3
commit 14d6b7efa4

@ -14,6 +14,7 @@
    - [1.2.3 UDML](#1.2.3)
    - [1.2.4 AFD](#1.2.4)
    - [1.2.5 DKD](#1.2.5)
    - [1.2.6 DIST](#1.2.6)
- [2. Usage](#2)
  - [2.1 Environment Setup](#2.1)
  - [2.2 Data Preparation](#2.2)
@ -444,6 +445,74 @@ Loss:
    - CELoss:
        weight: 1.0
```

<a name='1.2.6'></a>

#### 1.2.6 DIST

##### 1.2.6.1 Introduction to the DIST algorithm

Paper information:

> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
>
> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
>
> 2022, under review

When distilling with the standard KD method, the gains from distillation often fail to improve in step with the teacher model's accuracy. This paper proposes DIST, which uses the Pearson correlation coefficient instead of the default KL divergence to characterize the discrepancy between the student and the teacher, so that the student learns more accurate relational information.
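
Written out (a sketch based on the paper and the `DISTLoss` implementation added later in this PR; $\beta$ and $\gamma$ weight the two terms and both default to 1.0), with $\mathbf{Y}^s, \mathbf{Y}^t \in \mathbb{R}^{B \times C}$ the softmax outputs of student and teacher for a batch of $B$ samples and $C$ classes:

$$
d_p(\mathbf{u}, \mathbf{v}) = 1 - \rho_p(\mathbf{u}, \mathbf{v}), \qquad
\mathcal{L}_{\mathrm{DIST}} = \frac{\beta}{B} \sum_{i=1}^{B} d_p\big(\mathbf{Y}^s_{i,:}, \mathbf{Y}^t_{i,:}\big) + \frac{\gamma}{C} \sum_{j=1}^{C} d_p\big(\mathbf{Y}^s_{:,j}, \mathbf{Y}^t_{:,j}\big)
$$

where $\rho_p(\cdot,\cdot)$ is the Pearson correlation coefficient. The first (inter-class) term aligns each sample's class distribution between student and teacher; the second (intra-class) term aligns, for each class, how its score varies across the batch.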

The results on the ImageNet1k dataset are shown below.

| Strategy | Backbone | Config | Top-1 acc | Download link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99%(**+1.19%**) | - |

##### 1.2.6.2 DIST configuration

The DIST configuration is shown below. In the `Arch` field, both the student model and the teacher model need to be defined; the teacher's parameters are frozen and its pretrained weights need to be loaded. In the `Loss` field, `DistillationDISTLoss` (the DIST loss between student and teacher) and `DistillationGTCELoss` (the CE loss between the student and the ground-truth labels) need to be defined as the training losses.

```yaml
Arch:
  name: "DistillationModel"
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True

    - Student:
        name: ResNet18
        pretrained: False

  infer_model_name: "Student"


# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0
```

<a name="2"></a>

## 2. Model Training, Evaluation and Prediction

@ -601,3 +670,5 @@ python3 tools/export_model.py \

[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.

[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.

[13] Huang T, You S, Wang F, et al. Knowledge Distillation from A Stronger Teacher[J]. arXiv preprint arXiv:2205.10536, 2022.
@ -0,0 +1,152 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/r34_r18_dist
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its lengths should be same as models
  pretrained_list:
  # if not null, its lengths should be same as models
  freeze_params_list:
  - True
  - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet34
        class_num: *class_num
        pretrained: True
    - Student:
        name: ResNet18
        class_num: *class_num
        pretrained: False

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
@ -25,6 +25,8 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss

from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss
@ -0,0 +1,52 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(y_s, y_t):
    return 1 - pearson_correlation(y_s, y_t).mean()


def intra_class_relation(y_s, y_t):
    return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))


class DISTLoss(nn.Layer):
    # DISTLoss
    # paper [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
    # code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
    def __init__(self, beta=1.0, gamma=1.0):
        super().__init__()
        self.beta = beta
        self.gamma = gamma

    def forward(self, z_s, z_t):
        # convert logits to probability distributions
        y_s = F.softmax(z_s, axis=-1)
        y_t = F.softmax(z_t, axis=-1)
        # inter-class: match each sample's class distribution
        inter_loss = inter_class_relation(y_s, y_t)
        # intra-class: match each class's scores across the batch
        intra_loss = intra_class_relation(y_s, y_t)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
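As a quick sanity check of `DISTLoss` (a sketch, not part of the PR; the `ppcls.loss.dist_loss` module path is assumed from the `from .dist_loss import DISTLoss` line added below), identical student and teacher logits should give a loss close to 0, while unrelated logits give a clearly positive value:

```python
import paddle

# assumed import path, based on "from .dist_loss import DISTLoss" added in this PR
from ppcls.loss.dist_loss import DISTLoss

loss_fn = DISTLoss(beta=1.0, gamma=1.0)

z_t = paddle.randn([4, 1000])       # mock teacher logits, batch of 4
z_s_same = z_t.clone()              # student identical to teacher
z_s_rand = paddle.randn([4, 1000])  # unrelated student logits

print(float(loss_fn(z_s_same, z_t)))  # ~0: perfectly correlated predictions
print(float(loss_fn(z_s_rand, z_t)))  # noticeably larger
```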
@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss
@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
            for key in loss:
                loss_dict["{}_{}".format(key, name)] = loss[key]
        return loss_dict


class DistillationDISTLoss(DISTLoss):
    """
    DistillationDISTLoss
    Computes DISTLoss between the outputs of the sub-models named in model_name_pairs.
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 beta=1.0,
                 gamma=1.0,
                 name="loss_dist"):
        super().__init__(beta=beta, gamma=gamma)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict
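
For illustration, the wrapper consumes the dict of sub-model outputs that `DistillationModel` produces, keyed by the names from the `Arch` config. A minimal sketch with a mocked `predicts` dict (the `ppcls.loss.distillationloss` import path is assumed):

```python
import paddle

# assumed import path for the class added above
from ppcls.loss.distillationloss import DistillationDISTLoss

# mock the DistillationModel output dict with random logits
predicts = {
    "Student": paddle.randn([4, 1000]),
    "Teacher": paddle.randn([4, 1000]),
}

loss_fn = DistillationDISTLoss(
    model_name_pairs=[["Student", "Teacher"]], beta=1.0, gamma=1.0)
loss_dict = loss_fn(predicts, batch=None)  # batch is unused by this loss
print(loss_dict)  # {"loss_dist_Student_Teacher": Tensor(...)}
```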