parent d4e286c3f3
commit 14d6b7efa4
@@ -14,6 +14,7 @@
    - [1.2.3 UDML](#1.2.3)
    - [1.2.4 AFD](#1.2.4)
    - [1.2.5 DKD](#1.2.5)
    - [1.2.6 DIST](#1.2.6)
- [2. Usage](#2)
  - [2.1 Environment Configuration](#2.1)
  - [2.2 Data Preparation](#2.2)
@@ -444,6 +445,74 @@ Loss:
    - CELoss:
        weight: 1.0
```

<a name='1.2.6'></a>

#### 1.2.6 DIST

##### 1.2.6.1 Introduction to the DIST Algorithm

Paper information:

> [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
>
> Tao Huang, Shan You, Fei Wang, Chen Qian, Chang Xu
>
> 2022, under review

When a model is distilled with the standard KD method, improvements in teacher accuracy often fail to bring matching improvements in distillation performance. This paper proposes DIST, which uses the Pearson correlation coefficient, in place of the KL divergence used by default during distillation, to characterize the discrepancy between the student's and the teacher's predictions, so that the student learns more accurate relational information.
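
To make the objective concrete, let $Y^s, Y^t \in \mathbb{R}^{B \times C}$ denote the student's and the teacher's prediction matrices for a batch of $B$ samples over $C$ classes, and let $\rho_p(u, v)$ be the Pearson correlation coefficient between vectors $u$ and $v$. Following the paper's formulation, DIST turns the correlation into a distance and applies it both across rows (per-sample, inter-class relation) and across columns (per-class, intra-class relation):

$$
d_p(u, v) := 1 - \rho_p(u, v), \qquad
L_{\mathrm{inter}} = \frac{1}{B} \sum_{i=1}^{B} d_p\left(Y^s_{i,:},\, Y^t_{i,:}\right), \qquad
L_{\mathrm{intra}} = \frac{1}{C} \sum_{j=1}^{C} d_p\left(Y^s_{:,j},\, Y^t_{:,j}\right)
$$

The overall training loss is $L = \alpha L_{\mathrm{CE}} + \beta L_{\mathrm{inter}} + \gamma L_{\mathrm{intra}}$, which in the configuration below corresponds to `DistillationGTCELoss` together with `DistillationDISTLoss` (whose `beta` and `gamma` default to 1.0).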

Results on the public ImageNet-1k dataset are shown below.

| Strategy | Backbone | Config file | Top-1 acc | Download link |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| DIST | ResNet18 | [resnet34_distill_resnet18_dist.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml) | 71.99% (**+1.19%**) | - |

##### 1.2.6.2 DIST Configuration

The DIST configuration is shown below. In the `Arch` field, both the student and the teacher model must be defined; the teacher's parameters are frozen, and its pretrained weights must be loaded. The entries of `freeze_params_list` correspond one-to-one with the `models` list, so `True` freezes the Teacher and `False` leaves the Student trainable. In the `Loss` field, `DistillationDISTLoss` (the DIST loss between the student and the teacher) and `DistillationGTCELoss` (the student's CE loss against the ground-truth labels) must both be defined as the training losses.

```yaml
Arch:
  name: "DistillationModel"
  # if not null, its length should equal the number of models
  pretrained_list:
  # if not null, its length should equal the number of models
  freeze_params_list:
  - True
  - False
  models:
    - Teacher:
        name: ResNet34
        pretrained: True
    - Student:
        name: ResNet18
        pretrained: False

  infer_model_name: "Student"


# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0
```

<a name="2"></a>

## 2. Model Training, Evaluation, and Prediction

@@ -601,3 +670,5 @@ python3 tools/export_model.py \

[11] Zhao B, Cui Q, Song R, et al. Decoupled knowledge distillation[J]. arXiv preprint arXiv:2203.08679, 2022.

[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.

[13] Huang T, You S, Wang F, et al. Knowledge distillation from a stronger teacher[J]. arXiv preprint arXiv:2205.10536, 2022.

ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dist.yaml
@@ -0,0 +1,152 @@
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/r34_r18_dist
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 100
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  to_static: False

# model architecture
Arch:
  name: "DistillationModel"
  class_num: &class_num 1000
  # if not null, its length should equal the number of models
  pretrained_list:
  # if not null, its length should equal the number of models
  freeze_params_list:
  - True
  - False
  infer_model_name: "Student"
  models:
    - Teacher:
        name: ResNet34
        class_num: *class_num
        pretrained: True
    - Student:
        name: ResNet18
        class_num: *class_num
        pretrained: False

# loss function config for training/eval process
Loss:
  Train:
    - DistillationGTCELoss:
        weight: 1.0
        model_names: ["Student"]
    - DistillationDISTLoss:
        weight: 2.0
        model_name_pairs:
        - ["Student", "Teacher"]
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  weight_decay: 1e-4
  lr:
    name: Piecewise
    learning_rate: 0.1
    decay_epochs: [30, 60, 90]
    values: [0.1, 0.01, 0.001, 0.0001]

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 8
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 8
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]
  Eval:
    - DistillationTopkAcc:
        model_key: "Student"
        topk: [1, 5]

ppcls/loss/__init__.py
@@ -25,6 +25,8 @@ from .distillationloss import DistillationRKDLoss
from .distillationloss import DistillationKLDivLoss
from .distillationloss import DistillationDKDLoss
from .distillationloss import DistillationMultiLabelLoss
from .distillationloss import DistillationDISTLoss

from .multilabelloss import MultiLabelLoss
from .afdloss import AFDLoss

ppcls/loss/dist_loss.py
@@ -0,0 +1,52 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def cosine_similarity(a, b, eps=1e-8):
    # row-wise cosine similarity between two [N, C] tensors
    return (a * b).sum(1) / (a.norm(axis=1) * b.norm(axis=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    # Pearson correlation is cosine similarity after mean-centering each row
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(y_s, y_t):
    # per-sample relation: match each row (one sample's class distribution)
    # of the student's predictions against the teacher's
    return 1 - pearson_correlation(y_s, y_t).mean()


def intra_class_relation(y_s, y_t):
    # per-class relation: transpose to [C, N] and match each class's
    # scores across the batch
    return inter_class_relation(y_s.transpose([1, 0]), y_t.transpose([1, 0]))


class DISTLoss(nn.Layer):
    # DISTLoss
    # paper: [Knowledge Distillation from A Stronger Teacher](https://arxiv.org/pdf/2205.10536v1.pdf)
    # code reference: https://github.com/hunto/image_classification_sota/blob/d4f15a0494/lib/models/losses/dist_kd.py
    def __init__(self, beta=1.0, gamma=1.0):
        super().__init__()
        self.beta = beta  # weight of the inter-class relation loss
        self.gamma = gamma  # weight of the intra-class relation loss

    def forward(self, z_s, z_t):
        # z_s, z_t: student / teacher logits of shape [N, num_classes]
        y_s = F.softmax(z_s, axis=-1)
        y_t = F.softmax(z_t, axis=-1)
        inter_loss = inter_class_relation(y_s, y_t)
        intra_loss = intra_class_relation(y_s, y_t)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
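
To sanity-check the loss above, here is a minimal usage sketch, assuming random logits as stand-ins for real student/teacher outputs:

```python
import paddle

# hypothetical inputs: a batch of 8 samples over 1000 classes
student_logits = paddle.randn([8, 1000])
teacher_logits = paddle.randn([8, 1000])

loss_fn = DISTLoss(beta=1.0, gamma=1.0)
kd_loss = loss_fn(student_logits, teacher_logits)  # 0-D tensor
print(float(kd_loss))  # around 2.0 for uncorrelated random inputs
```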
ppcls/loss/distillationloss.py
@@ -22,6 +22,7 @@ from .distanceloss import DistanceLoss
from .rkdloss import RKdAngle, RkdDistance
from .kldivloss import KLDivLoss
from .dkdloss import DKDLoss
from .dist_loss import DISTLoss
from .multilabelloss import MultiLabelLoss

@@ -289,3 +290,32 @@ class DistillationMultiLabelLoss(MultiLabelLoss):
            for key in loss:
                loss_dict["{}_{}".format(key, name)] = loss[key]
        return loss_dict


class DistillationDISTLoss(DISTLoss):
    """
    DistillationDISTLoss
    """

    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 beta=1.0,
                 gamma=1.0,
                 name="loss_dist"):
        super().__init__(beta=beta, gamma=gamma)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            # fetch the two models' outputs by name
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                # outputs may be dicts; select the tensor to distill on
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            loss_dict[f"{self.name}_{pair[0]}_{pair[1]}"] = loss
        return loss_dict
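
A sketch of how this wrapper is invoked during training (the `predicts` dict is normally produced by `DistillationModel`; random tensors stand in for real outputs here):

```python
import paddle

# outputs keyed by model name, matching model_name_pairs in the config
predicts = {
    "Student": paddle.randn([8, 1000]),
    "Teacher": paddle.randn([8, 1000]),
}

loss_fn = DistillationDISTLoss(model_name_pairs=[["Student", "Teacher"]])
loss_dict = loss_fn(predicts, batch=None)  # batch is not used by this loss
# -> {"loss_dist_Student_Teacher": <0-D tensor>}
```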