add kd doc (#1997)

* add kd doc

* fix

* add ssld doc

* fix ssld

* fix ssld

* Update knowledge_distillation.md

* fix doc

* fix dist export

* fix

* add dist doc

* fix speed info

* Update ssld.md

* Update ssld.md
pull/2009/head
littletomatodonkey 2022-06-09 14:52:50 +08:00 committed by GitHub
parent fed4ea6920
commit 794af8c06f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 925 additions and 187 deletions

View File

@ -44,14 +44,14 @@
| 模型 | Top-1 Acc% | 延时ms | 存储M | 策略 |
|-------|-----------|----------|---------------|---------------|
| SwinTranformer_tiny | 98.11 | 87.19 | 111 | 使用ImageNet预训练模型 |
| MobileNetV3_large_x1_0 | 97.79 | 5.59 | 23 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 97.78 | 2.67 | 8.2 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 97.84 | 2.67 | 8.2 | 使用SSLD预训练模型 |
| PPLCNet_x1_0 | 98.14 | 2.67 | 8.2 | 使用SSLD预训练模型+EDA策略|
| <b>PPLCNet_x1_0<b> | <b>98.35<b> | <b>2.67<b> | <b>8.2<b> | 使用SSLD预训练模型+EDA策略+SKL-UGI知识蒸馏策略|
| SwinTranformer_tiny | 98.11 | 89.45 | 111 | 使用ImageNet预训练模型 |
| MobileNetV3_large_x1_0 | 97.79 | 4.81 | 23 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 97.78 | 2.10 | 8.2 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 97.84 | 2.10 | 8.2 | 使用SSLD预训练模型 |
| PPLCNet_x1_0 | 98.14 | 2.10 | 8.2 | 使用SSLD预训练模型+EDA策略|
| <b>PPLCNet_x1_0<b> | <b>98.35<b> | <b>2.10<b> | <b>8.2<b> | 使用SSLD预训练模型+EDA策略+SKL-UGI知识蒸馏策略|
从表中可以看出backbone 为 SwinTranformer_tiny 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_large_x1_0 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时精度低0.01%,但是速度提升 2 倍左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.06%进一步地当融合EDA策略后精度可以再提升 0.3%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.21%。此时PPLCNet_x1_0 的精度超越了SwinTranformer_tiny速度快32倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
从表中可以看出backbone 为 SwinTranformer_tiny 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_large_x1_0 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时精度低0.01%,但是速度提升 1 倍左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.06%进一步地当融合EDA策略后精度可以再提升 0.3%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.21%。此时PPLCNet_x1_0 的精度超越了SwinTranformer_tiny速度快 41 倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
**备注:**

View File

@ -44,15 +44,15 @@
| 模型 | ma% | 延时ms | 存储M | 策略 |
|-------|-----------|----------|---------------|---------------|
| Res2Net200_vd_26w_4s | 91.36 | 66.58 | 293 | 使用ImageNet预训练模型 |
| ResNet50 | 89.98 | 12.74 | 92 | 使用ImageNet预训练模型 |
| MobileNetV3_large_x1_0 | 89.77 | 5.59 | 23 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 89.57 | 2.56 | 8.2 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 90.07 | 2.56 | 8.2 | 使用SSLD预训练模型 |
| PPLCNet_x1_0 | 90.59 | 2.56 | 8.2 | 使用SSLD预训练模型+EDA策略|
| <b>PPLCNet_x1_0<b> | <b>90.81<b> | <b>2.56<b> | <b>8.2<b> | 使用SSLD预训练模型+EDA策略+SKL-UGI知识蒸馏策略|
| Res2Net200_vd_26w_4s | 91.36 | 79.46 | 293 | 使用ImageNet预训练模型 |
| ResNet50 | 89.98 | 12.83 | 92 | 使用ImageNet预训练模型 |
| MobileNetV3_large_x1_0 | 89.77 | 5.09 | 23 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 89.57 | 2.36 | 8.2 | 使用ImageNet预训练模型 |
| PPLCNet_x1_0 | 90.07 | 2.36 | 8.2 | 使用SSLD预训练模型 |
| PPLCNet_x1_0 | 90.59 | 2.36 | 8.2 | 使用SSLD预训练模型+EDA策略|
| <b>PPLCNet_x1_0<b> | <b>90.81<b> | <b>2.36<b> | <b>8.2<b> | 使用SSLD预训练模型+EDA策略+SKL-UGI知识蒸馏策略|
从表中可以看出backbone 为 Res2Net200_vd_26w_4s 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_large_x1_0 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时精度低0.2%,但是速度提升 2 倍左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.5%进一步地当融合EDA策略后精度可以再提升 0.52%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.23%。此时PPLCNet_x1_0 的精度与 Res2Net200_vd_26w_4s 仅相差0.55%但是速度快26倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
从表中可以看出backbone 为 Res2Net200_vd_26w_4s 时精度较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_large_x1_0 后,速度可以大幅提升,但是精度下降明显。将 backbone 替换为 PPLCNet_x1_0 时精度低0.2%,但是速度提升 1 倍左右。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升约 0.5%进一步地当融合EDA策略后精度可以再提升 0.52%,最后,在使用 SKL-UGI 知识蒸馏后,精度可以继续提升 0.23%。此时PPLCNet_x1_0 的精度与 Res2Net200_vd_26w_4s 仅相差0.55%,但是速度快32倍。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。
**备注:**

View File

@ -1,209 +1,339 @@
# 知识蒸馏
# 知识蒸馏实战
## 目录
- [1. 模型压缩与知识蒸馏方法简介](#1)
- [2. SSLD 蒸馏策略](#2)
- [2.1 简介](#2.1)
- [2.2 数据选择](#2.2)
- [3. 实验](#3)
- [3.1 教师模型的选择](#3.1)
- [3.2 大数据蒸馏](#3.2)
- [3.3 ImageNet1k 训练集 finetune](#3.3)
- [3.4 数据增广以及基于 Fix 策略的微调](#3.4)
- [3.5 实验过程中的一些问题](#3.5)
- [4. 蒸馏模型的应用](#4)
- [4.1 使用方法](#4.1)
- [4.2 迁移学习 finetune](#4.2)
- [4.3 目标检测](#4.3)
- [5. SSLD 实战](#5)
- [5.1 参数配置](#5.1)
- [5.2 启动命令](#5.2)
- [5.3 注意事项](#5.3)
- [6. 参考文献](#6)
- [1. 算法介绍](#1)
- [1.1 知识蒸馏简介](#1.1)
- [1.1.1 Response based distillation](#1.1.1)
- [1.1.2 Feature based distillation](#1.1.2)
- [1.1.3 Relation based distillation](#1.1.3)
- [1.2 PaddleClas支持的知识蒸馏算法](#1.2)
- [1.2.1 SSLD](#1.2.1)
- [1.2.2 DML](#1.2.2)
- [1.2.3 AFD](#1.2.3)
- [1.2.4 DKD](#1.2.4)
- [2. 使用方法](#2)
- [2.1 环境配置](#2.1)
- [2.2 数据准备](#2.2)
- [2.3 模型训练](#2.3)
- [2.4 模型评估](#2.4)
- [2.5 模型预测](#2.5)
- [2.6 模型导出与推理](#2.6)
- [3. 参考文献](#3)
<a name="1"></a>
## 1. 模型压缩与知识蒸馏方法简介
## 1. 算法介绍
<a name="1.1"></a>
### 1.1 知识蒸馏简介
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升,甚至获得与大模型相似的精度指标 [1]。 PaddleClas 融合已有的蒸馏方法 [2,3],提供了一种简单的半监督标签知识蒸馏方案(SSLDSimple Semi-supervised Label Distillation),基于 ImageNet1k 分类数据集,在 ResNet_vd 以及 MobileNet 系列上的精度均有超过 3% 的绝对精度提升,具体指标如下图所示。
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升,甚至获得与大模型相似的精度指标 [1]。
![](../../images/distillation/distillation_perform_s.jpg)
<a name="2"></a>
## 2. SSLD 蒸馏策略
根据蒸馏方式的不同可以将知识蒸馏方法分为3个不同的类别Response based distillation、Feature based distillation、Relation based distillation。下面进行详细介绍。
<a name="2.1"></a>
### 2.1 简介
<a name='1.1.1'></a>
SSLD 的流程图如下图所示。
#### 1.1.1 Response based distillation
![](../../images/distillation/ppcls_distillation.png)
首先,我们从 ImageNet22k 中挖掘出了近 400 万张图片,同时与 ImageNet-1k 训练集整合在一起,得到了一个新的包含 500 万张图片的数据集。然后,我们将学生模型与教师模型组合成一个新的网络,该网络分别输出学生模型和教师模型的预测分布,与此同时,固定教师模型整个网络的梯度,而学生模型可以做正常的反向传播。最后,我们将两个模型的 logits 经过 softmax 激活函数转换为 soft label并将二者的 soft label 做 JS 散度作为损失函数,用于蒸馏模型训练。下面以 MobileNetV3该模型直接训练精度为 75.3%的知识蒸馏为例介绍该方案的核心关键点baseline 为 79.12% 的 ResNet50_vd 模型蒸馏 MobileNetV3训练集为 ImageNet1k 训练集loss 为 cross entropy loss迭代轮数为 120epoch精度指标为 75.6%)。
最早的知识蒸馏算法 KD由 Hinton 提出,训练的损失函数中除了 gt loss 之外,还引入了学生模型与教师模型输出的 KL 散度,最终精度超过单纯使用 gt loss 训练的精度。这里需要注意的是,在训练的时候,需要首先训练得到一个更大的教师模型,来指导学生模型的训练过程。
* 教师模型的选择。在进行知识蒸馏时,如果教师模型与学生模型的结构差异太大,蒸馏得到的结果反而不会有太大收益。相同结构下,精度更高的教师模型对结果也有很大影响。相比于 79.12% 的 ResNet50_vd 教师模型,使用 82.4% 的 ResNet50_vd 教师模型可以带来 0.4% 的绝对精度收益(`75.6%->76.0%`)
PaddleClas 中提出了一种简单使用的 SSLD 知识蒸馏算法 [6],在训练的时候去除了对 gt label 的依赖,结合大量无标注数据,最终蒸馏训练得到的预训练模型在 15 个模型上的精度提升平均高达 3%
* 改进 loss 计算方法。分类 loss 计算最常用的方法就是 cross entropy loss我们经过实验发现在使用 soft label 进行训练时,相对于 cross entropy lossKL div loss 对模型性能提升几乎无帮助,但是使用具有对称特性的 JS div loss 时,在多个蒸馏任务上相比 cross entropy loss 均有 0.2% 左右的收益(`76.0%->76.2%`)SSLD 中也基于 JS div loss 展开实验
上述标准的蒸馏方法是通过一个大模型作为教师模型来指导学生模型提升效果,而后来又发展出 DML(Deep Mutual Learning)互学习蒸馏方法 [7],即通过两个结构相同的模型互相学习。具体的。相比于 KD 等依赖于大的教师模型的知识蒸馏算法DML 脱离了对大的教师模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些
* 更多的迭代轮数。蒸馏的 baseline 实验只迭代了 120 个 epoch 。实验发现,迭代轮数越多,蒸馏效果越好,最终我们迭代了 360 epoch精度指标可以达到 77.1%(`76.2%->77.1%`)。
<a name='1.1.2'></a>
* 无需数据集的真值标签,很容易扩展训练集。 SSLD 的 loss 在计算过程中,仅涉及到教师和学生模型对于相同图片的处理结果(经过 softmax 激活函数处理之后的 soft label因此即使图片数据不包含真值标签也可以用来进行训练并提升模型性能。该蒸馏方案的无标签蒸馏策略也大大提升了学生模型的性能上限(`77.1%->78.5%`)。
#### 1.1.2 Feature based distillation
* ImageNet1k 蒸馏 finetune 。 我们仅使用 ImageNet1k 数据,使用蒸馏方法对上述模型进行 finetune最终仍然可以获得 0.4% 的性能提升(`78.5%->78.9%`)。
Heo 等人提出了 OverHaul [8], 计算学生模型与教师模型的 feature map distance作为蒸馏的 loss在这里使用了学生模型、教师模型的转移来保证二者的 feature map 可以正常地进行 distance 的计算。
基于 feature map distance 的知识蒸馏方法也能够和 `3.1 章节` 中的基于 response 的知识蒸馏算法融合在一起,同时对学生模型的输出结果和中间层 feature map 进行监督。而对于 DML 方法来说,这种融合过程更为简单,因为不需要对学生和教师模型的 feature map 进行转换,便可以完成对齐(alignment)过程。PP-OCRv2 系统中便使用了这种方法,最终大幅提升了 OCR 文字识别模型的精度。
<a name='1.1.3'></a>
#### 1.1.3 Relation based distillation
[1.1.1](#1.1.1) 和 [1.1.2](#1.1.2) 章节中的论文中主要是考虑到学生模型与教师模型的输出或者中间层 feature map这些知识蒸馏算法只关注个体的输出结果没有考虑到个体之间的输出关系。
Park 等人提出了 RKD [10]基于关系的知识蒸馏算法RKD 中进一步考虑个体输出之间的关系,使用 2 种损失函数二阶的距离损失distance-wise和三阶的角度损失angle-wise
<a name="2.2"></a>
### 2.2 数据选择
本论文提出的算法关系知识蒸馏RKD迁移教师模型得到的输出结果间的结构化关系给学生模型不同于之前的只关注个体输出结果RKD 算法使用两种损失函数:二阶的距离损失(distance-wise)和三阶的角度损失(angle-wise)。在最终计算蒸馏损失函数的时候,同时考虑 KD loss 和 RKD loss。最终精度优于单独使用 KD loss 蒸馏得到的模型精度。
* SSLD 蒸馏方案的一大特色就是无需使用图像的真值标签,因此可以任意扩展数据集的大小,考虑到计算资源的限制,我们在这里仅基于 ImageNet22k 数据集对蒸馏任务的训练集进行扩充。在 SSLD 蒸馏任务中,我们使用了 `Top-k per class` 的数据采样方案 [3] 。具体步骤如下。
* 训练集去重。我们首先基于 SIFT 特征相似度匹配的方式对 ImageNet22k 数据集与 ImageNet1k 验证集进行去重,防止添加的 ImageNet22k 训练集中包含 ImageNet1k 验证集图像,最终去除了 4511 张相似图片。部分过滤的相似图片如下所示。
<a name='1.2'></a>
![](../../images/distillation/22k_1k_val_compare_w_sift.png)
### 1.2 PaddleClas支持的知识蒸馏算法
* 大数据集 soft label 获取,对于去重后的 ImageNet22k 数据集,我们使用 `ResNeXt101_32x16d_wsl` 模型进行预测,得到每张图片的 soft label 。
* Top-k 数据选择ImageNet1k 数据共有 1000 类,对于每一类,找出属于该类并且得分最高的 `k` 张图片,最终得到一个数据量不超过 `1000*k` 的数据集(某些类上得到的图片数量可能少于 `k` 张)。
* 将该数据集与 ImageNet1k 的训练集融合组成最终蒸馏模型所使用的数据集,数据量为 500 万。
<a name='1.2.1'></a>
<a name="3"></a>
## 3. 实验
#### 1.2.1 SSLD
* PaddleClas 的蒸馏策略为`大数据集训练 + ImageNet1k 蒸馏 finetune` 的策略。选择合适的教师模型,首先在挑选得到的 500 万数据集上进行训练,然后在 ImageNet1k 训练集上进行 finetune最终得到蒸馏后的学生模型。
##### 1.2.1.1 SSLD 算法介绍
<a name="3.1"></a>
### 3.1 教师模型的选择
论文信息:
为了验证教师模型和学生模型的模型大小差异和教师模型的模型精度对蒸馏结果的影响,我们做了几组实验验证。训练策略统一为:`cosine_decay_warmuplr=1.3, epoch=120, bs=2048`,学生模型均为从头训练。
|Teacher Model | Teacher Top1 | Student Model | Student Top1|
|- |:-: |:-: | :-: |
| ResNeXt101_32x16d_wsl | 84.2% | MobileNetV3_large_x1_0 | 75.78% |
| ResNet50_vd | 79.12% | MobileNetV3_large_x1_0 | 75.60% |
| ResNet50_vd | 82.35% | MobileNetV3_large_x1_0 | 76.00% |
从表中可以看出
> 教师模型结构相同时,其精度越高,最终的蒸馏效果也会更好一些。
> [Beyond Self-Supervision: A Simple Yet Effective Network Distillation Alternative to Improve Backbones
](https://arxiv.org/abs/2103.05959)
>
> 教师模型与学生模型的模型大小差异不宜过大,否则反而会影响蒸馏结果的精度。
> Cheng Cui, Ruoyu Guo, Yuning Du, Dongliang He, Fu Li, Zewu Wu, Qiwen Liu, Shilei Wen, Jizhou Huang, Xiaoguang Hu, Dianhai Yu, Errui Ding, Yanjun Ma
>
> arxiv, 2021
SSLD是百度于2021年提出的一种简单的半监督知识蒸馏方案通过设计一种改进的JS散度作为损失函数结合基于ImageNet22k数据集的数据挖掘策略最终帮助15个骨干网络模型的精度平均提升超过3%。
更多关于SSLD的原理、模型库与使用介绍请参考[SSLD知识蒸馏算法介绍](./ssld.md)。
因此最终在蒸馏实验中,对于 ResNet 系列学生模型,我们使用 `ResNeXt101_32x16d_wsl` 作为教师模型;对于 MobileNet 系列学生模型,我们使用蒸馏得到的 `ResNet50_vd` 作为教师模型。
##### 1.2.1.2 SSLD 配置
<a name="3.2"></a>
### 3.2 大数据蒸馏
SSLD配置如下所示。在模型构建Arch字段中需要同时定义学生模型与教师模型教师模型固定梯度并且加载预训练参数。在损失函数Loss字段中需要定义`DistillationDMLLoss`,作为训练的损失函数。
基于 PaddleClas 的蒸馏策略为`大数据集训练 + imagenet1k finetune` 的策略。
```yaml
# model architecture
Arch:
name: "DistillationModel" # 模型名称,这里使用的是蒸馏模型,
class_num: &class_num 1000 # 类别数量对于ImageNet1k数据集来说类别数为1000
pretrained_list: # 预训练模型列表,因为在下面的子网络中指定了预训练模型,这里无需指定
freeze_params_list: # 固定网络参数列表为True时表示固定该index对应的网络
- True
- False
infer_model_name: "Student" # 在模型导出的时候会导出Student子网络
models: # 子网络列表
- Teacher: # 教师模型
name: ResNet50_vd # 模型名称
class_num: *class_num # 类别数
pretrained: True # 预训练模型路径如果为True则会从官网下载默认的预训练模型
use_ssld: True # 是否使用SSLD蒸馏得到的预训练模型精度会更高一些
- Student: # 学生模型
name: PPLCNet_x2_5 # 模型名称
class_num: *class_num # 类别数
pretrained: False # 预训练模型路径可以指定为bool值或者字符串这里为False表示学生模型默认不加载预训练模型
针对从 ImageNet22k 挑选出的 400 万数据,融合 imagenet1k 训练集,组成共 500 万的训练集进行训练,具体地,在不同模型上的训练超参及效果如下。
# loss function config for traing/eval process
Loss: # 定义损失函数
Train: # 定义训练的损失函数,为列表形式
- DistillationDMLLoss: # 蒸馏的DMLLoss对DMLLoss进行封装支持蒸馏结果(dict形式)的损失函数计算
weight: 1.0 # loss权重
model_name_pairs: # 用于计算的模型对这里表示计算Student和Teacher输出的损失函数
- ["Student", "Teacher"]
Eval: # 定义评估时的损失函数
- CELoss:
weight: 1.0
```
<a name='1.2.2'></a>
#### 1.2.2 DML
##### 1.2.2.1 DML 算法介绍
论文信息:
> [Deep Mutual Learning](https://openaccess.thecvf.com/content_cvpr_2018/html/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.html)
>
> Ying Zhang, Tao Xiang, Timothy M. Hospedales, Huchuan Lu
>
> CVPR, 2018
DML论文中在蒸馏的过程中不依赖于教师模型两个结构相同的模型互相学习计算彼此输出logits的KL散度最终完成训练过程。
|Student Model | num_epoch | l2_ecay | batch size/gpu cards | base lr | learning rate decay | top1 acc |
| - |:-: |:-: | :-: |:-: |:-: |:-: |
| MobileNetV1 | 360 | 3e-5 | 4096/8 | 1.6 | cosine_decay_warmup | 77.65% |
| MobileNetV2 | 360 | 1e-5 | 3072/8 | 0.54 | cosine_decay_warmup | 76.34% |
| MobileNetV3_large_x1_0 | 360 | 1e-5 | 5760/24 | 3.65625 | cosine_decay_warmup | 78.54% |
| MobileNetV3_small_x1_0 | 360 | 1e-5 | 5760/24 | 3.65625 | cosine_decay_warmup | 70.11% |
| ResNet50_vd | 360 | 7e-5 | 1024/32 | 0.4 | cosine_decay_warmup | 82.07% |
| ResNet101_vd | 360 | 7e-5 | 1024/32 | 0.4 | cosine_decay_warmup | 83.41% |
| Res2Net200_vd_26w_4s | 360 | 4e-5 | 1024/32 | 0.4 | cosine_decay_warmup | 84.82% |
在ImageNet1k公开数据集上效果如下所示。
<a name="3.3"></a>
### 3.3 ImageNet1k 训练集 finetune
对于在大数据集上训练的模型,其学习到的特征可能与 ImageNet1k 数据特征有偏,因此在这里使用 ImageNet1k 数据集对模型进行 finetune。 finetune 的超参和 finetune 的精度收益如下。
| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
| --- | --- | --- | --- | --- |
| baseline | PPLCNet_x2_5 | [PPLCNet_x2_5.yaml](../../../ppcls/configs/ImageNet/PPLCNet/PPLCNet_x2_5.yaml) | 74.93% | - |
| DML | PPLCNet_x2_5 | [PPLCNet_x2_5_dml.yaml](../../../ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_dml.yaml) | 76.68%(**+1.75%**) | - |
|Student Model | num_epoch | l2_ecay | batch size/gpu cards | base lr | learning rate decay | top1 acc |
| - |:-: |:-: | :-: |:-: |:-: |:-: |
| MobileNetV1 | 30 | 3e-5 | 4096/8 | 0.016 | cosine_decay_warmup | 77.89% |
| MobileNetV2 | 30 | 1e-5 | 3072/8 | 0.0054 | cosine_decay_warmup | 76.73% |
| MobileNetV3_large_x1_0 | 30 | 1e-5 | 2048/8 | 0.008 | cosine_decay_warmup | 78.96% |
| MobileNetV3_small_x1_0 | 30 | 1e-5 | 6400/32 | 0.025 | cosine_decay_warmup | 71.28% |
| ResNet50_vd | 60 | 7e-5 | 1024/32 | 0.004 | cosine_decay_warmup | 82.39% |
| ResNet101_vd | 30 | 7e-5 | 1024/32 | 0.004 | cosine_decay_warmup | 83.73% |
| Res2Net200_vd_26w_4s | 360 | 4e-5 | 1024/32 | 0.004 | cosine_decay_warmup | 85.13% |
<a name="3.4"></a>
### 3.4 数据增广以及基于 Fix 策略的微调
* 基于前文所述的实验结论,我们在训练的过程中加入自动增广(AutoAugment)[4],同时进一步减小了 l2_decay(4e-5->2e-5),最终 ResNet50_vd 经过 SSLD 蒸馏策略,在 ImageNet1k 上的精度可以达到 82.99%,相比之前不加数据增广的蒸馏策略再次增加了 0.6% 。
* 注完整的PPLCNet_x2_5模型训练了360epoch这里为了方便对比baseline和DML均训练了100epoch因此指标比官网最终开源出来的模型精度76.60%)低一些。
* 对于图像分类任务,在测试的时候,测试尺度为训练尺度的 1.15 倍左右时,往往在不需要重新训练模型的情况下,模型的精度指标就可以进一步提升 [5],对于 82.99% 的 ResNet50_vd 在 320x320 的尺度下测试,精度可达 83.7%,我们进一步使用 Fix 策略,即在 320x320 的尺度下进行训练,使用与预测时相同的数据预处理方法,同时固定除 FC 层以外的所有参数,最终在 320x320 的预测尺度下,精度可以达到 **84.0%**
##### 1.2.2.2 DML 配置
<a name="3.5"></a>
### 3.5 实验过程中的一些问题
DML配置如下所示。在模型构建Arch字段中需要同时定义学生模型与教师模型教师模型与学生模型均保持梯度更新状态。在损失函数Loss字段中需要定义`DistillationDMLLoss`学生与教师之间的JS-Div loss以及`DistillationGTCELoss`学生与教师关于真值标签的CE loss作为训练的损失函数。
* 在预测过程中batch norm 的平均值与方差是通过加载预训练模型得到(设其模式为 test mode。在训练过程中batch norm 是通过统计当前 batch 的信息(设其模式为 train mode与历史保存信息进行滑动平均计算得到在蒸馏任务中我们发现通过 train mode即教师模型的均值与方差实时变化的模式去指导学生模型比通过 test mode 蒸馏,得到的学生模型性能更好一些,下面是一组实验结果。因此我们在该蒸馏方案中,均使用 train mode 去得到教师模型的 soft label 。
```yaml
Arch:
name: "DistillationModel"
class_num: &class_num 1000
pretrained_list:
freeze_params_list: # 两个模型互相学习,因此这里两个子网络的参数均不能固定
- False
- False
models:
- Teacher:
name: PPLCNet_x2_5 # 两个模型互学习,因此均没有加载预训练模型
class_num: *class_num
pretrained: False
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
|Teacher Model | Teacher Top1 | Student Model | Student Top1|
|- |:-: |:-: | :-: |
| ResNet50_vd | 82.35% | MobileNetV3_large_x1_0 | 76.00% |
| ResNet50_vd | 82.35% | MobileNetV3_large_x1_0 | 75.84% |
Loss:
Train:
- DistillationGTCELoss: # 因为2个子网络均没有加载预训练模型这里需要同时计算不同子网络的输出与真值标签之间的CE loss
weight: 1.0
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
```
<a name="4"></a>
## 4. 蒸馏模型的应用
<a name='1.2.3'></a>
<a name="4.1"></a>
### 4.1 使用方法
#### 1.2.3 AFD
* 中间层学习率调整。蒸馏得到的模型的中间层特征图更加精细化,因此将蒸馏模型预训练应用到其他任务中时,如果采取和之前相同的学习率,容易破坏中间层特征。而如果降低整体模型训练的学习率,则会带来训练收敛速度慢的问题。因此我们使用了中间层学习率调整的策略。具体地:
* 针对 ResNet50_vd我们设置一个学习率倍数列表res block 之前的 3 个 conv2d 卷积参数具有统一的学习率倍数4 个 res block 的 conv2d 分别有一个学习率参数,共需设置 5 个学习率倍数的超参。在实验中发现。用于迁移学习 finetune 分类模型时,`[0.1,0.1,0.2,0.2,0.3]` 的中间层学习率倍数设置在绝大多数的任务中都性能更好;而在目标检测任务中,`[0.05,0.05,0.05,0.1,0.15]` 的中间层学习率倍数设置能够带来更大的精度收益。
* 对于 MoblileNetV3_large_x1_0由于其包含 15 个 block我们设置每 3 个 block 共享一个学习率倍数参数,因此需要共 5 个学习率倍数的参数,最终发现在分类和检测任务中,`[0.25,0.25,0.5,0.5,0.75]` 的中间层学习率倍数能够带来更大的精度收益。
##### 1.2.3.1 AFD 算法介绍
论文信息:
* 适当的 l2 decay 。不同分类模型在训练的时候一般都会根据模型设置不同的 l2 decay大模型为了防止过拟合往往会设置更大的 l2 decay如 ResNet50 等模型,一般设置为 `1e-4` ;而如 MobileNet 系列模型,在训练时往往都会设置为 `1e-5~4e-5`,防止模型过度欠拟合,在蒸馏时亦是如此。在将蒸馏模型应用到目标检测任务中时,我们发现也需要调节 backbone 甚至特定任务模型模型的 l2 decay和预训练蒸馏时的 l2 decay 尽可能保持一致。以 Faster RCNN MobiletNetV3 FPN 为例,我们发现仅修改该参数,在 COCO2017 数据集上就可以带来最多 0.5% 左右的精度(mAP)提升(默认 Faster RCNN l2 decay 为 1e-4我们修改为 1e-5~4e-5 均有 0.3%~0.5% 的提升)。
> [Show, attend and distill: Knowledge distillation via attention-based feature matching](https://arxiv.org/abs/2102.02973)
>
> Mingi Ji, Byeongho Heo, Sungrae Park
>
> AAAI, 2018
<a name="4.2"></a>
### 4.2 迁移学习 finetune
* 为验证迁移学习的效果,我们在 10 个小的数据集上验证其效果。在这里为了保证实验的可对比性,我们均使用 ImageNet1k 数据集训练的标准预处理过程,对于蒸馏模型我们也添加了蒸馏模型中间层学习率的搜索。
* 对于 ResNet50_vd, baseline 为 Top1 Acc 79.12% 的预训练模型基于 grid search 搜索得到的最佳精度,对比实验则为基于该精度对预训练和中间层学习率进一步搜索得到的最佳精度。下面给出 10 个数据集上所有 baseline 和蒸馏模型的精度对比。
AFD提出在蒸馏的过程中利用基于注意力的元网络学习特征之间的相对相似性并应用识别的相似关系来控制所有可能的特征图pair的蒸馏强度。
在ImageNet1k公开数据集上效果如下所示。
| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| AFD | ResNet18 | [resnet34_distill_resnet18_afd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_afd.yaml) | 71.68%(**+0.88%**) | - |
注意这里为了与论文的训练配置保持对齐设置训练的迭代轮数为100epoch因此baseline精度低于PaddleClas中开源出的模型精度71.0%
##### 1.2.3.2 AFD 配置
AFD配置如下所示。在模型构建Arch字段中需要同时定义学生模型与教师模型固定教师模型的权重。这里需要对从教师模型获取的特征进行变换进而与学生模型进行损失函数的计算。在损失函数Loss字段中需要定义`DistillationKLDivLoss`学生与教师之间的KL-Div loss、`AFDLoss`学生与教师之间的AFD loss以及`DistillationGTCELoss`学生与教师关于真值标签的CE loss作为训练的损失函数。
```yaml
Arch:
name: "DistillationModel"
pretrained_list:
freeze_params_list:
models:
- Teacher:
name: AttentionModel # 包含若干个串行的网络,后面的网络会将前面的网络输出作为输入并进行处理
pretrained_list:
freeze_params_list:
- True
- False
models:
# AttentionModel 的基础网络
- ResNet34:
name: ResNet34
pretrained: True
# return_patterns表示除了返回输出的logits也会返回对应名称的中间层feature map
return_patterns: &t_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]",
"blocks[8]", "blocks[9]", "blocks[10]", "blocks[11]",
"blocks[12]", "blocks[13]", "blocks[14]", "blocks[15]"]
# AttentionModel的变换网络会对基础子网络的特征进行变换
- LinearTransformTeacher:
name: LinearTransformTeacher
qk_dim: 128
keys: *t_keys
t_shapes: &t_shapes [[64, 56, 56], [64, 56, 56], [64, 56, 56], [128, 28, 28],
[128, 28, 28], [128, 28, 28], [128, 28, 28], [256, 14, 14],
[256, 14, 14], [256, 14, 14], [256, 14, 14], [256, 14, 14],
[256, 14, 14], [512, 7, 7], [512, 7, 7], [512, 7, 7]]
- Student:
name: AttentionModel
pretrained_list:
freeze_params_list:
- False
- False
models:
- ResNet18:
name: ResNet18
pretrained: False
return_patterns: &s_keys ["blocks[0]", "blocks[1]", "blocks[2]", "blocks[3]",
"blocks[4]", "blocks[5]", "blocks[6]", "blocks[7]"]
- LinearTransformStudent:
name: LinearTransformStudent
qk_dim: 128
keys: *s_keys
s_shapes: &s_shapes [[64, 56, 56], [64, 56, 56], [128, 28, 28], [128, 28, 28],
[256, 14, 14], [256, 14, 14], [512, 7, 7], [512, 7, 7]]
t_shapes: *t_shapes
infer_model_name: "Student"
| Dataset | Model | Baseline Top1 Acc | Distillation Model Finetune |
|- |:-: |:-: | :-: |
| Oxford102 flowers | ResNete50_vd | 97.18% | 97.41% |
| caltech-101 | ResNete50_vd | 92.57% | 93.21% |
| Oxford-IIIT-Pets | ResNete50_vd | 94.30% | 94.76% |
| DTD | ResNete50_vd | 76.48% | 77.71% |
| fgvc-aircraft-2013b | ResNete50_vd | 88.98% | 90.00% |
| Stanford-Cars | ResNete50_vd | 92.65% | 92.76% |
| SUN397 | ResNete50_vd | 64.02% | 68.36% |
| cifar100 | ResNete50_vd | 86.50% | 87.58% |
| cifar10 | ResNete50_vd | 97.72% | 97.94% |
| Food-101 | ResNete50_vd | 89.58% | 89.99% |
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
key: logits
- DistillationKLDivLoss: # 蒸馏的KL-Div loss会根据model_name_pairs中的模型名称去提取对应模型的输出特征计算loss
weight: 0.9 # 该loss的权重
model_name_pairs: [["Student", "Teacher"]]
temperature: 4
key: logits
- AFDLoss: # AFD loss
weight: 50.0
model_name_pair: ["Student", "Teacher"]
student_keys: ["bilinear_key", "value"]
teacher_keys: ["query", "value"]
s_shapes: *s_shapes
t_shapes: *t_shapes
Eval:
- CELoss:
weight: 1.0
```
* 可以看出在上面 10 个数据集上,结合适当的中间层学习率倍数设置,蒸馏模型平均能够带来 1% 以上的精度提升。
**注意(** 上述在网络中指定`return_patterns`返回中间层特征的功能是基于TheseusLayer更多关于TheseusLayer的使用说明请参考[TheseusLayer 使用说明](./theseus_layer.md)
<a name="4.3"></a>
### 4.3 目标检测
<a name='1.2.4'></a>
我们基于两阶段目标检测 Faster/Cascade RCNN 模型验证蒸馏得到的预训练模型的效果。
#### 1.2.4 DKD
* ResNet50_vd
##### 1.2.4.1 DKD 算法介绍
设置训练与评测的尺度均为 640x640最终 COCO 上检测指标如下。
论文信息:
| Model | train/test scale | pretrain top1 acc | feature map lr | coco mAP |
|- |:-: |:-: | :-: | :-: |
| Faster RCNN R50_vd FPN | 640/640 | 79.12% | [1.0,1.0,1.0,1.0,1.0] | 34.8% |
| Faster RCNN R50_vd FPN | 640/640 | 79.12% | [0.05,0.05,0.1,0.1,0.15] | 34.3% |
| Faster RCNN R50_vd FPN | 640/640 | 82.18% | [0.05,0.05,0.1,0.1,0.15] | 36.3% |
在这里可以看出,对于未蒸馏模型,过度调整中间层学习率反而降低最终检测模型的性能指标。基于该蒸馏模型,我们也提供了领先的服务端实用目标检测方案,详细的配置与训练代码均已开源,可以参考 [PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/rcnn_enhance)。
> [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679)
>
> Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, Jiajun Liang
>
> CVPR, 2022
<a name="5"></a>
## 5. SSLD 实战
DKD将蒸馏中常用的 KD Loss 进行了解耦成为Target Class Knowledge Distillation(TCKD目标类知识蒸馏)以及Non-target Class Knowledge Distillation(NCKD非目标类知识蒸馏)两个部分,对两个部分的作用分别研究,并使它们各自的权重可以独立调节,提升了蒸馏的精度和灵活性。
本节将基于 ImageNet-1K 的数据集详细介绍 SSLD 蒸馏实验,如果想快速体验此方法,可以参考 [**30 分钟玩转 PaddleClas进阶版**](../quick_start/quick_start_classification_professional.md)中基于 CIFAR100 的 SSLD 蒸馏实验。
在ImageNet1k公开数据集上效果如下所示
<a name="5.1"></a>
### 5.1 参数配置
| 策略 | 骨干网络 | 配置文件 | Top-1 acc | 下载链接 |
| --- | --- | --- | --- | --- |
| baseline | ResNet18 | [ResNet18.yaml](../../../ppcls/configs/ImageNet/ResNet/ResNet18.yaml) | 70.8% | - |
| AFD | ResNet18 | [resnet34_distill_resnet18_dkd.yaml](../../../ppcls/configs/ImageNet/Distillation/resnet34_distill_resnet18_dkd.yaml) | 72.59%(**+1.79%**) | - |
##### 1.2.4.2 DKD 配置
DKD 配置如下所示。在模型构建Arch字段中需要同时定义学生模型与教师模型教师模型固定参数且需要加载预训练模型。在损失函数Loss字段中需要定义`DistillationDKDLoss`学生与教师之间的DKD loss以及`DistillationGTCELoss`学生与教师关于真值标签的CE loss作为训练的损失函数。
实战部分提供了 SSLD 蒸馏的示例,在 `ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml` 中提供了 `MobileNetV3_large_x1_0` 蒸馏 `MobileNetV3_small_x1_0` 的配置文件,用户可以在 `tools/train.sh` 里直接替换配置文件的路径即可使用。
```yaml
Arch:
@ -216,53 +346,165 @@ Arch:
- False
models:
- Teacher:
name: MobileNetV3_large_x1_0
name: ResNet34
pretrained: True
use_ssld: True
- Student:
name: MobileNetV3_small_x1_0
name: ResNet18
pretrained: False
infer_model_name: "Student"
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student"]
- DistillationDKDLoss:
weight: 1.0
model_name_pairs: [["Student", "Teacher"]]
temperature: 1
alpha: 1.0
beta: 1.0
Eval:
- CELoss:
weight: 1.0
```
<a name="2"></a>
## 2. 模型训练、评估和预测
<a name="2.1"></a>
### 2.1 环境配置
* 安装:请先参考 [Paddle 安装教程](../installation/install_paddle.md) 以及 [PaddleClas 安装教程](../installation/install_paddleclas.md) 配置 PaddleClas 运行环境。
<a name="2.2"></a>
### 2.2 数据准备
请在[ImageNet 官网](https://www.image-net.org/)准备 ImageNet-1k 相关的数据。
进入 PaddleClas 目录。
```
cd path_to_PaddleClas
```
在参数配置中,`freeze_params_list` 中需要指定模型是否需要冻结参数,`models` 中需要指定 Teacher 模型和 Student 模型,其中 Teacher 模型需要加载预训练模型。用户可以直接在此处更改模型。
进入 `dataset/` 目录,将下载好的数据命名为 `ILSVRC2012` ,存放于此。 `ILSVRC2012` 目录中具有以下数据:
<a name="5.2"></a>
### 5.2 启动命令
```
├── train
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
├── train_list.txt
...
├── val
│ ├── ILSVRC2012_val_00000001.JPEG
│ ├── ILSVRC2012_val_00000002.JPEG
├── val_list.txt
```
当用户配置完训练环境后,类似于训练其他分类任务,只需要将 `tools/train.sh` 中的配置文件替换成为相应的蒸馏配置文件即可。
其中 `train/``val/` 分别为训练集和验证集。`train_list.txt` 和 `val_list.txt` 分别为训练集和验证集的标签文件
其中 `train.sh` 中的内容如下:
```bash
如果包含与训练集场景相似的无标注数据则也可以按照与训练集标注完全相同的方式进行整理将文件与当前有标注的数据集放在相同目录下将其标签值记为0假设整理的标签文件名为`train_list_unlabel.txt`则可以通过下面的命令生成用于SSLD训练的标签文件。
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
--log_dir=mv3_large_x1_0_distill_mv3_small_x1_0 \
```shell
cat train_list.txt train_list_unlabel.txt > train_list_all.txt
```
**备注:**
* 关于 `train_list.txt`、`val_list.txt`的格式说明,可以参考[PaddleClas分类数据集格式说明](../data_preparation/classification_dataset.md#1-数据集格式说明) 。
<a name="2.3"></a>
### 2.3 模型训练
以SSLD知识蒸馏算法为例介绍知识蒸馏算法的模型训练、评估、预测等过程。配置文件为 [PPLCNet_x2_5_ssld.yaml](../../../ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml) ,使用下面的命令可以完成模型训练。
```shell
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/ImageNet/Distillation/mv3_large_x1_0_distill_mv3_small_x1_0.yaml
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml
```
运行 `train.sh`
<a name="2.4"></a>
### 2.4 模型评估
训练好模型之后,可以通过以下命令实现对模型指标的评估。
```bash
sh tools/train.sh
python3 tools/eval.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model
```
<a name="5.3"></a>
### 5.3 注意事项
其中 `-o Global.pretrained_model="output/DistillationModel/best_model"` 指定了当前最佳权重所在的路径,如果指定其他权重,只需替换对应的路径即可。
* 用户在使用 SSLD 蒸馏之前,首先需要在目标数据集上训练一个教师模型,该教师模型用于指导学生模型在该数据集上的训练。
<a name="2.5"></a>
* 如果学生模型没有加载预训练模型,训练的其他超参数可以参考该学生模型在 ImageNet-1k 上训练的超参数,如果学生模型加载了预训练模型,学习率可以调整到原来的 1/10 或者 1/100 。
### 2.5 模型预测
* 在 SSLD 蒸馏的过程中,学生模型只学习 soft-label 导致训练目标变的更加复杂,建议可以适当的调小 `l2_decay` 的值来获得更高的验证集准确率。
模型训练完成之后,可以加载训练得到的预训练模型,进行模型预测。在模型库的 `tools/infer.py` 中提供了完整的示例,只需执行下述命令即可完成模型预测:
* 若用户准备添加无标签的训练数据,只需要将新的训练数据放置在原本训练数据的路径下,生成新的数据 list 即可,另外,新生成的数据 list 需要将无标签的数据添加伪标签(只是为了统一读数据)。
```python
python3 tools/infer.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model
```
<a name="6"></a>
## 6. 参考文献
输出结果如下:
```
[{'class_ids': [8, 7, 86, 82, 21], 'scores': [0.87908, 0.12091, 0.0, 0.0, 0.0], 'file_name': 'docs/images/inference_deployment/whl_demo.jpg', 'label_names': ['hen', 'cock', 'partridge', 'ruffed grouse, partridge, Bonasa umbellus', 'kite']}]
```
**备注:**
* 这里`-o Global.pretrained_model="output/ResNet50/best_model"` 指定了当前最佳权重所在的路径,如果指定其他权重,只需替换对应的路径即可。
* 默认是对 `docs/images/inference_deployment/whl_demo.jpg` 进行预测,此处也可以通过增加字段 `-o Infer.infer_imgs=xxx` 对其他图片预测。
<a name="2.6"></a>
### 2.6 模型导出与推理
Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端提供高性能的推理能力。相比于直接基于预训练模型进行预测Paddle Inference可使用MKLDNN、CUDNN、TensorRT 进行预测加速从而实现更优的推理性能。更多关于Paddle Inference推理引擎的介绍可以参考[Paddle Inference官网教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html)。
在模型推理之前需要先导出模型。对于知识蒸馏训练得到的模型,在导出时需要指定`-o Global.infer_model_name=Student`,来表示导出的模型为学生模型。具体命令如下所示。
```shell
python3 tools/export_model.py \
-c ppcls/configs/ImageNet/Distillation/PPLCNet_x2_5_ssld.yaml \
-o Global.pretrained_model=./output/DistillationModel/best_model \
-o Arch.infer_model_name=Student
```
最终在`inference`目录下会产生`inference.pdiparams`、`inference.pdiparams.info`、`inference.pdmodel` 3个文件。
关于更多模型推理相关的教程,请参考:[Python 预测推理](../inference_deployment/python_deploy.md)。
<a name="3"></a>
## 3. 参考文献
[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
@ -273,3 +515,17 @@ sh tools/train.sh
[4] Cubuk E D, Zoph B, Mane D, et al. Autoaugment: Learning augmentation strategies from data[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2019: 113-123.
[5] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[C]//Advances in Neural Information Processing Systems. 2019: 8250-8260.
[6] Cui C, Guo R, Du Y, et al. Beyond Self-Supervision: A Simple Yet Effective Network Distillation Alternative to Improve Backbones[J]. arXiv preprint arXiv:2103.05959, 2021.
[7] Zhang Y, Xiang T, Hospedales T M, et al. Deep mutual learning[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 4320-4328.
[8] Heo B, Kim J, Yun S, et al. A comprehensive overhaul of feature distillation[C]//Proceedings of the IEEE/CVF International Conference on Computer Vision. 2019: 1921-1930.
[9] Du Y, Li C, Guo R, et al. PP-OCRv2: Bag of Tricks for Ultra Lightweight OCR System[J]. arXiv preprint arXiv:2109.03144, 2021.
[10] Park W, Kim D, Lu Y, et al. Relational knowledge distillation[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019: 3967-3976.
[11] Zhao B, Cui Q, Song R, et al. Decoupled Knowledge Distillation[J]. arXiv preprint arXiv:2203.08679, 2022.
[12] Ji M, Heo B, Park S. Show, attend and distill: Knowledge distillation via attention-based feature matching[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2021, 35(9): 7945-7952.

View File

@ -0,0 +1,157 @@
# SSLD 知识蒸馏实战
## 目录
- [1. 算法介绍](#1)
- [1.1 知识蒸馏简介](#1.1)
- [1.2 SSLD蒸馏策略](#1.2)
- [2. SSLD预训练模型库](#2)
- [3. SSLD使用](#3)
- [3.1 加载SSLD模型进行微调](#3.1)
- [3.2 使用SSLD方案进行知识蒸馏](#3.2)
- [4. 参考文献](#4)
<a name="1"></a>
## 1. 算法介绍
### 1.1 简介
PaddleClas 融合已有的知识蒸馏方法 [2,3],提供了一种简单的半监督标签知识蒸馏方案(SSLDSimple Semi-supervised Label Distillation),基于 ImageNet1k 分类数据集,在 ResNet_vd 以及 MobileNet 系列上的精度均有超过 3% 的绝对精度提升,具体指标如下图所示。
<div align="center">
<img src="../../images/distillation/distillation_perform_s.jpg" width = "800" />
</div>
### 1.2 SSLD蒸馏策略
SSLD 的流程图如下图所示。
<div align="center">
<img src="../../images/distillation/ppcls_distillation.png" width = "800" />
</div>
首先,我们从 ImageNet22k 中挖掘出了近 400 万张图片,同时与 ImageNet-1k 训练集整合在一起,得到了一个新的包含 500 万张图片的数据集。然后,我们将学生模型与教师模型组合成一个新的网络,该网络分别输出学生模型和教师模型的预测分布,与此同时,固定教师模型整个网络的梯度,而学生模型可以做正常的反向传播。最后,我们将两个模型的 logits 经过 softmax 激活函数转换为 soft label并将二者的 soft label 做 JS 散度作为损失函数,用于蒸馏模型训练。
以 MobileNetV3该模型直接训练精度为 75.3%)的知识蒸馏为例,该方案的核心策略优化点如下所示。
| 实验ID | 策略 | Top-1 acc |
|:------:|:---------:|:--------:|
| 1 | baseline | 75.60% |
| 2 | 更换教师模型精度为82.4%的权重 | 76.00% |
| 3 | 使用改进的JS散度损失函数 | 76.20% |
| 4 | 迭代轮数增加至360epoch | 77.10% |
| 5 | 添加400W挖掘得到的无标注数据 | 78.50% |
| 6 | 基于ImageNet1k数据微调 | 78.90% |
* 注其中baseline的训练条件为
* 训练数据ImageNet1k数据集
* 损失函数Cross Entropy Loss
* 迭代轮数120epoch
SSLD 蒸馏方案的一大特色就是无需使用图像的真值标签,因此可以任意扩展数据集的大小,考虑到计算资源的限制,我们在这里仅基于 ImageNet22k 数据集对蒸馏任务的训练集进行扩充。在 SSLD 蒸馏任务中,我们使用了 `Top-k per class` 的数据采样方案 [3] 。具体步骤如下。
1训练集去重。我们首先基于 SIFT 特征相似度匹配的方式对 ImageNet22k 数据集与 ImageNet1k 验证集进行去重,防止添加的 ImageNet22k 训练集中包含 ImageNet1k 验证集图像,最终去除了 4511 张相似图片。部分过滤的相似图片如下所示。
<div align="center">
<img src="../../images/distillation/22k_1k_val_compare_w_sift.png" width = "600" />
</div>
2大数据集 soft label 获取,对于去重后的 ImageNet22k 数据集,我们使用 `ResNeXt101_32x16d_wsl` 模型进行预测,得到每张图片的 soft label 。
3Top-k 数据选择ImageNet1k 数据共有 1000 类,对于每一类,找出属于该类并且得分最高的 `k` 张图片,最终得到一个数据量不超过 `1000*k` 的数据集(某些类上得到的图片数量可能少于 `k` 张)。
4将该数据集与 ImageNet1k 的训练集融合组成最终蒸馏模型所使用的数据集,数据量为 500 万。
<a name="2"></a>
## 2. 预训练模型库
移动端预训练模型库列表如下所示。
| 模型 | FLOPs(M) | Params(M) | top-1 acc | SSLD top-1 acc | 精度收益 | 下载链接 |
|-------------------|----------|-----------|----------|---------------|--------|------|
| PPLCNetV2_base | 604.16 | 6.54 | 77.04% | 80.10% | +3.06% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_ssld_pretrained.pdparams) |
| PPLCNet_x2_5 | 906.49 | 9.04 | 76.60% | 80.82% | +4.22% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_ssld_pretrained.pdparams) |
| PPLCNet_x1_0 | 160.81 | 2.96 | 71.32% | 74.39% | +3.07% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_ssld_pretrained.pdparams) |
| PPLCNet_x0_5 | 47.28 | 1.89 | 63.14% | 66.10% | +2.96% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_ssld_pretrained.pdparams) |
| PPLCNet_x0_25 | 18.43 | 1.52 | 51.86% | 53.43% | +1.57% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_ssld_pretrained.pdparams) |
| MobileNetV1 | 578.88 | 4.19 | 71.00% | 77.90% | +6.90% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV1_ssld_pretrained.pdparams) |
| MobileNetV2 | 327.84 | 3.44 | 72.20% | 76.74% | +4.54% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV2_ssld_pretrained.pdparams) |
| MobileNetV3_large_x1_0 | 229.66 | 5.47 | 75.30% | 79.00% | +3.70% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_large_x1_0_ssld_pretrained.pdparams) |
| MobileNetV3_small_x1_0 | 63.67 | 2.94 | 68.20% | 71.30% | +3.10% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x1_0_ssld_pretrained.pdparams) |
| MobileNetV3_small_x0_35 | 14.56 | 1.66 | 53.00% | 55.60% | +2.60% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/MobileNetV3_small_x0_35_ssld_pretrained.pdparams) |
| GhostNet_x1_3_ssld | 236.89 | 7.30 | 75.70% | 79.40% | +3.70% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/GhostNet_x1_3_ssld_pretrained.pdparams) |
* 注:其中的`top-1 acc`表示使用普通训练方式得到的模型精度,`SSLD top-1 acc`表示使用SSLD知识蒸馏训练策略得到的模型精度。
服务端预训练模型库列表如下所示。
| 模型 | FLOPs(G) | Params(M) | top-1 acc | SSLD top-1 acc | 精度收益 | 下载链接 |
|----------------------|----------|-----------|----------|---------------|--------|-------------------------------------------------------------------------------------------|
| PPHGNet_base | 25.14 | 71.62 | - | 85.00% | - | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_base_ssld_pretrained.pdparams) |
| PPHGNet_small | 8.53 | 24.38 | 81.50% | 83.80% | +2.30% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_ssld_pretrained.pdparams) |
| PPHGNet_tiny | 4.54 | 14.75 | 79.83% | 81.95% | +2.12% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_ssld_pretrained.pdparams) |
| ResNet50_vd | 8.67 | 25.58 | 79.10% | 83.00% | +3.90% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams) |
| ResNet101_vd | 16.1 | 44.57 | 80.20% | 83.70% | +3.50% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_ssld_pretrained.pdparams) |
| ResNet34_vd | 7.39 | 21.82 | 76.00% | 79.70% | +3.70% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_vd_ssld_pretrained.pdparams) |
| Res2Net50_vd_26w_4s | 8.37 | 25.06 | 79.80% | 83.10% | +3.30% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/Res2Net50_vd_26w_4s_ssld_pretrained.pdparams) |
| Res2Net101_vd_26w_4s | 16.67 | 45.22 | 80.60% | 83.90% | +3.30% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/Res2Net101_vd_26w_4s_ssld_pretrained.pdparams) |
| Res2Net200_vd_26w_4s | 31.49 | 76.21 | 81.20% | 85.10% | +3.90% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/Res2Net200_vd_26w_4s_ssld_pretrained.pdparams) |
| HRNet_W18_C | 4.14 | 21.29 | 76.90% | 81.60% | +4.70% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/_ssld_pretrained.pdparams) |
| HRNet_W48_C | 34.58 | 77.47 | 79.00% | 83.60% | +4.60% | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/HRNet_W48_C_ssld_pretrained.pdparams) |
| SE_HRNet_W64_C | 57.83 | 128.97 | - | 84.70% | - | [链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/SE_HRNet_W64_C_ssld_pretrained.pdparams) |
<a name="3"></a>
## 3. SSLD使用方法
<a name="3.1"></a>
### 3.1 加载SSLD模型进行微调
如果希望直接使用预训练模型,可以在训练的时候,加入参数`-o Arch.pretrained=True -o Arch.use_ssld=True`表示使用基于SSLD的预训练模型示例如下所示。
```shell
# 单机单卡训练
python3 tools/train.py -c ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml -o Arch.pretrained=True -o Arch.use_ssld=True
# 单机多卡训练
python3 -m paddle.distributed.launch --gpus="0,1,2,3" tools/train.py -c ppcls/configs/ImageNet/ResNet/ResNet50_vd.yaml -o Arch.pretrained=True -o Arch.use_ssld=True
```
<a name="3.2"></a>
### 3.2 使用SSLD方案进行知识蒸馏
相比于其他大多数知识蒸馏算法SSLD摆脱对数据标注的依赖通过引入无标注数据可以进一步提升模型精度。
对于无标注数据,需要按照与有标注数据完全相同的整理方式,将文件与当前有标注的数据集放在相同目录下,将其标签值记为`0`,假设整理的标签文件名为`train_list_unlabel.txt`则可以通过下面的命令生成用于SSLD训练的标签文件。
```shell
cat train_list.txt train_list_unlabel.txt > train_list_all.txt
```
更多关于图像分类任务的数据标签说明,请参考:[PaddleClas图像分类数据集格式说明](../data_preparation/classification_dataset.md#1-数据集格式说明)
PaddleClas中集成了PULC超轻量图像分类实用方案里面包含SSLD ImageNet预训练模型的使用以及更加通用的无标签数据的知识蒸馏方案更多详细信息请参考[PULC超轻量图像分类实用方案使用教程](../PULC/PULC_train.md)。
<a name="4"></a>
## 4. 参考文献
[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.
[2] Bagherinezhad H, Horton M, Rastegari M, et al. Label refinery: Improving imagenet classification through label progression[J]. arXiv preprint arXiv:1805.02641, 2018.
[3] Yalniz I Z, Jégou H, Chen K, et al. Billion-scale semi-supervised learning for image classification[J]. arXiv preprint arXiv:1905.00546, 2019.
[4] Touvron H, Vedaldi A, Douze M, et al. Fixing the train-test resolution discrepancy[C]//Advances in Neural Information Processing Systems. 2019: 8250-8260.

View File

@ -42,7 +42,7 @@
PaddleClas 中提出了一种简单使用的 SSLD 知识蒸馏算法 [6],在训练的时候去除了对 gt label 的依赖,结合大量无标注数据,最终蒸馏训练得到的预训练模型在 15 个模型上的精度提升平均高达 3%。
上述标准的蒸馏方法是通过一个大模型作为教师模型来指导学生模型提升效果,而后来又发展出 DML(Deep Mutual Learning)互学习蒸馏方法 [7],即通过两个结构相同的模型互相学习。具体的。相比于 KD 等依赖于大的教师模型的知识蒸馏算法DML 脱离了对大的教师模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些。
上述标准的蒸馏方法是通过一个大模型作为教师模型来指导学生模型提升效果,而后来又发展出 DML(Deep Mutual Learning)互学习蒸馏方法 [7],即通过两个结构相同的模型互相学习。具体的。相比于 KD 等依赖于大的教师模型的知识蒸馏算法DML 脱离了对大的教师模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些。
<a name='3.2'></a>
### 3.2 Feature based distillation

View File

@ -42,11 +42,21 @@ python3 -m paddle.distributed.launch \
## 3. 性能效果测试
* 在4机8卡V100的机器上基于[SSLD知识蒸馏训练策略](../advanced_tutorials/knowledge_distillation.md)数据量500W进行模型训练不同模型的训练耗时以及多机加速比情况如下所示。
* 在单机8卡V100的机器上基于[SSLD知识蒸馏训练策略](../advanced_tutorials/ssld.md)数据量500W进行模型训练不同模型的训练耗时以及单机8卡加速比情况如下所示。
| 模型 | 精度 | 单机单卡耗时 | 单机8卡耗时 | 加速比 |
|:---------:|:--------:|:--------:|:--------:|:------:|
| PPHGNet-base_ssld | 85.00% | 133.2d | 18.96d | **7.04** |
| PPLCNetv2-base_ssld | 80.10% | 31.6d | 6.4d | **4.93** |
| PPLCNet_x0_25_ssld | 53.43% | 21.8d | 6.2d | **3.99** |
* 在4机8卡V100的机器上基于[SSLD知识蒸馏训练策略](../advanced_tutorials/ssld.md)数据量500W进行模型训练不同模型的训练耗时以及多机加速比情况如下所示。
| 模型 | 精度 | 单机8卡耗时 | 4机8卡耗时 | 加速比 |
|:---------:|:--------:|:--------:|:--------:|:------:|
| PPHGNet-base_ssld | 85.00% | 15.74d | 4.86d | **3.23** |
| PPHGNet-base_ssld | 85.00% | 18.96d | 4.86d | **3.90** |
| PPLCNetv2-base_ssld | 80.10% | 6.4d | 1.67d | **3.83** |
| PPLCNet_x0_25_ssld | 53.43% | 6.2d | 1.78d | **3.48** |

View File

@ -0,0 +1,158 @@
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output_lcnet_x2_5_dml
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- False
- False
infer_model_name: "Student"
models:
- Teacher:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
# loss function config for traing/eval process
Loss:
Train:
- DistillationGTCELoss:
weight: 1.0
model_names: ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.4
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]

View File

@ -0,0 +1,157 @@
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output_r50_vd_distill
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
to_static: True
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: "DistillationModel"
class_num: &class_num 1000
# if not null, its lengths should be same as models
pretrained_list:
# if not null, its lengths should be same as models
freeze_params_list:
- True
- False
infer_model_name: "Student"
models:
- Teacher:
name: ResNet50_vd
class_num: *class_num
pretrained: True
use_ssld: True
- Student:
name: PPLCNet_x2_5
class_num: *class_num
pretrained: False
# loss function config for traing/eval process
Loss:
Train:
- DistillationDMLLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.2
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 256
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]
Eval:
- DistillationTopkAcc:
model_key: "Student"
topk: [1, 5]

View File

@ -441,6 +441,8 @@ class Engine(object):
if isinstance(out, list):
out = out[0]
if isinstance(out, dict) and "Student" in out:
out = out["Student"]
if isinstance(out, dict) and "logits" in out:
out = out["logits"]
if isinstance(out, dict) and "output" in out:

View File

@ -97,8 +97,6 @@ class Attention(nn.Layer):
super().__init__()
self.qk_dim = qk_dim
self.n_t = n_t
# self.linear_trans_s = LinearTransformStudent(qk_dim, t_shapes, s_shapes, unique_t_shapes)
# self.linear_trans_t = LinearTransformTeacher(qk_dim, t_shapes)
self.p_t = self.create_parameter(
shape=[len(t_shapes), qk_dim],

View File

@ -59,7 +59,7 @@ def search_strategy():
configs = config.get_config(
args.config, overrides=args.override, show=False)
base_config_file = configs["base_config_file"]
distill_config_file = configs["distill_config_file"]
distill_config_file = configs.get("distill_config_file", None)
model_name = config.get_config(base_config_file)["Arch"]["name"]
gpus = configs["gpus"]
gpus = ",".join([str(i) for i in gpus])