Update distillation.md

pull/642/head
littletomatodonkey 2021-03-16 12:41:47 +08:00 committed by GitHub
parent 3b1fad3fa8
commit cbd9f341ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 25 additions and 23 deletions

View File

@ -1,5 +1,7 @@
# 一、模型压缩方法简介
# 知识蒸馏
## 一、模型压缩方法简介
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
@ -11,9 +13,9 @@
![](../../../images/distillation/distillation_perform_s.jpg)
# 二、SSLD 蒸馏策略
## 二、SSLD 蒸馏策略
## 2.1 简介
### 2.1 简介
SSLD的流程图如下图所示。
@ -33,7 +35,7 @@ SSLD的流程图如下图所示。
## 2.2 数据选择
### 2.2 数据选择
* SSLD蒸馏方案的一大特色就是无需使用图像的真值标签因此可以任意扩展数据集的大小考虑到计算资源的限制我们在这里仅基于ImageNet22k数据集对蒸馏任务的训练集进行扩充。在SSLD蒸馏任务中我们使用了`Top-k per class`的数据采样方案[3]。具体步骤如下。
@ -46,11 +48,11 @@ SSLD的流程图如下图所示。
* 将该数据集与ImageNet1k的训练集融合组成最终蒸馏模型所使用的数据集数据量为500万。
# 三、实验
## 三、实验
* PaddleClas的蒸馏策略为`大数据集训练+ImageNet1k蒸馏finetune`的策略。选择合适的教师模型首先在挑选得到的500万数据集上进行训练然后在ImageNet1k训练集上进行finetune最终得到蒸馏后的学生模型。
## 3.1 教师模型的选择
### 3.1 教师模型的选择
为了验证教师模型和学生模型的模型大小差异和教师模型的模型精度对蒸馏结果的影响,我们做了几组实验验证。训练策略统一为:`cosine_decay_warmuplr=1.3, epoch=120, bs=2048`,学生模型均为从头训练。
@ -70,7 +72,7 @@ SSLD的流程图如下图所示。
因此最终在蒸馏实验中对于ResNet系列学生模型我们使用`ResNeXt101_32x16d_wsl`作为教师模型对于MobileNet系列学生模型我们使用蒸馏得到的`ResNet50_vd`作为教师模型。
## 3.2 大数据蒸馏
### 3.2 大数据蒸馏
基于PaddleClas的蒸馏策略为`大数据集训练+imagenet1k finetune`的策略。
@ -87,7 +89,7 @@ SSLD的流程图如下图所示。
| ResNet101_vd | 360 | 7e-5 | 1024/32 | 0.4 | cosine_decay_warmup | 83.41% |
| Res2Net200_vd_26w_4s | 360 | 4e-5 | 1024/32 | 0.4 | cosine_decay_warmup | 84.82% |
## 3.3 ImageNet1k训练集finetune
### 3.3 ImageNet1k训练集finetune
对于在大数据集上训练的模型其学习到的特征可能与ImageNet1k数据特征有偏因此在这里使用ImageNet1k数据集对模型进行finetune。finetune的超参和finetune的精度收益如下。
@ -103,7 +105,7 @@ SSLD的流程图如下图所示。
| Res2Net200_vd_26w_4s | 360 | 4e-5 | 1024/32 | 0.004 | cosine_decay_warmup | 85.13% |
## 3.4 数据增广以及基于Fix策略的微调
### 3.4 数据增广以及基于Fix策略的微调
* 基于前文所述的实验结论,我们在训练的过程中加入自动增广(AutoAugment)[4]同时进一步减小了l2_decay(4e-5->2e-5)最终ResNet50_vd经过SSLD蒸馏策略在ImageNet1k上的精度可以达到82.99%相比之前不加数据增广的蒸馏策略再次增加了0.6%。
@ -111,9 +113,9 @@ SSLD的流程图如下图所示。
* 对于图像分类任务在测试的时候测试尺度为训练尺度的1.15倍左右时,往往在不需要重新训练模型的情况下,模型的精度指标就可以进一步提升[5]对于82.99%的ResNet50_vd在320x320的尺度下测试精度可达83.7%我们进一步使用Fix策略即在320x320的尺度下进行训练使用与预测时相同的数据预处理方法同时固定除FC层以外的所有参数最终在320x320的预测尺度下精度可以达到**84.0%**。
## 3.4 实验过程中的一些问题
### 3.4 实验过程中的一些问题
### 3.4.1 bn的计算方法
#### 3.4.1 bn的计算方法
* 在预测过程中batch norm的平均值与方差是通过加载预训练模型得到设其模式为test mode。在训练过程中batch norm是通过统计当前batch的信息设其模式为train mode与历史保存信息进行滑动平均计算得到在蒸馏任务中我们发现通过train mode即教师模型的bn实时变化的模式去指导学生模型比通过test mode蒸馏得到的学生模型性能更好一些下面是一组实验结果。因此我们在该蒸馏方案中均使用train mode去得到教师模型的soft label。
@ -122,7 +124,7 @@ SSLD的流程图如下图所示。
| ResNet50_vd | 82.35% | MobileNetV3_large_x1_0 | 76.00% |
| ResNet50_vd | 82.35% | MobileNetV3_large_x1_0 | 75.84% |
### 3.4.2 模型名字冲突问题的解决办法
#### 3.4.2 模型名字冲突问题的解决办法
* 在蒸馏过程中如果遇到命名冲突的问题如使用ResNet50_vd蒸馏ResNet34_vd此时直接训练会提示相同variable名称不匹配的问题此时可以通过给学生模型或者教师模型中的变量名添加名称的方式解决该问题如下所示。在训练之后也可以直接根据后缀区分学生模型和教师模型各自包含的参数。
```python
@ -143,9 +145,9 @@ cd model_final # enter model dir
for var in ./*_student; do cp "$var" "../student_model/${var%_student}"; done # batch copy and rename
```
# 四、蒸馏模型的应用
## 四、蒸馏模型的应用
## 4.1 使用方法
### 4.1 使用方法
* 中间层学习率调整。蒸馏得到的模型的中间层特征图更加精细化,因此将蒸馏模型预训练应用到其他任务中时,如果采取和之前相同的学习率,容易破坏中间层特征。而如果降低整体模型训练的学习率,则会带来训练收敛速度慢的问题。因此我们使用了中间层学习率调整的策略。具体地:
* 针对ResNet50_vd我们设置一个学习率倍数列表res block之前的3个conv2d卷积参数具有统一的学习率倍数4个res block的conv2d分别有一个学习率参数共需设置5个学习率倍数的超参。在实验中发现。用于迁移学习finetune分类模型时`[0.1,0.1,0.2,0.2,0.3]`的中间层学习率倍数设置在绝大多数的任务中都性能更好;而在目标检测任务中,`[0.05,0.05,0.05,0.1,0.15]`的中间层学习率倍数设置能够带来更大的精度收益。
@ -155,7 +157,7 @@ for var in ./*_student; do cp "$var" "../student_model/${var%_student}"; done #
* 适当的l2 decay。不同分类模型在训练的时候一般都会根据模型设置不同的l2 decay大模型为了防止过拟合往往会设置更大的l2 decay如ResNet50等模型一般设置为`1e-4`而如MobileNet系列模型在训练时往往都会设置为`1e-5~4e-5`防止模型过度欠拟合在蒸馏时亦是如此。在将蒸馏模型应用到目标检测任务中时我们发现也需要调节backbone甚至特定任务模型模型的l2 decay和预训练蒸馏时的l2 decay尽可能保持一致。以Faster RCNN MobiletNetV3 FPN为例我们发现仅修改该参数在COCO2017数据集上就可以带来最多0.5%左右的精度(mAP)提升默认Faster RCNN l2 decay为1e-4我们修改为1e-5~4e-5均有0.3%~0.5%的提升)。
## 4.2 迁移学习finetune
### 4.2 迁移学习finetune
* 为验证迁移学习的效果我们在10个小的数据集上验证其效果。在这里为了保证实验的可对比性我们均使用ImageNet1k数据集训练的标准预处理过程对于蒸馏模型我们也添加了蒸馏模型中间层学习率的搜索。
* 对于ResNet50_vdbaseline为Top1 Acc 79.12%的预训练模型基于grid search搜索得到的最佳精度对比实验则为基于该精度对预训练和中间层学习率进一步搜索得到的最佳精度。下面给出10个数据集上所有baseline和蒸馏模型的精度对比。
@ -176,7 +178,7 @@ for var in ./*_student; do cp "$var" "../student_model/${var%_student}"; done #
* 可以看出在上面10个数据集上结合适当的中间层学习率倍数设置蒸馏模型平均能够带来1%以上的精度提升。
## 4.3 目标检测
### 4.3 目标检测
我们基于两阶段目标检测Faster/Cascade RCNN模型验证蒸馏得到的预训练模型的效果。
@ -193,15 +195,15 @@ for var in ./*_student; do cp "$var" "../student_model/${var%_student}"; done #
在这里可以看出,对于未蒸馏模型,过度调整中间层学习率反而降低最终检测模型的性能指标。基于该蒸馏模型,我们也提供了领先的服务端实用目标检测方案,详细的配置与训练代码均已开源,可以参考[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/rcnn_enhance)。
# 五、SSLD实战
## 五、SSLD实战
本节将基于ImageNet-1K的数据集详细介绍SSLD蒸馏实验如果想快速体验此方法可以参考[**30分钟玩转PaddleClas**](../../tutorials/quick_start.md)中基于Flowers102的SSLD蒸馏实验。
## 5.1 参数配置
### 5.1 参数配置
实战部分提供了SSLD蒸馏的示例在`ppcls/modeling/architectures/distillation_models.py`中提供了`ResNeXt101_32x16d_wsl`蒸馏`ResNet50_vd`与`ResNet50_vd_ssld`蒸馏`MobileNetV3_large_x1_0`的示例,`configs/Distillation`里分别提供了二者的配置文件,用户可以在`tools/run.sh`里直接替换配置文件的路径即可使用。
### ResNeXt101_32x16d_wsl蒸馏ResNet50_vd
#### ResNeXt101_32x16d_wsl蒸馏ResNet50_vd
`ResNeXt101_32x16d_wsl`蒸馏`ResNet50_vd`的配置如下,其中`pretrained model`指定了`ResNeXt101_32x16d_wsl`(教师模型)的预训练模型的路径,该路径也可以同时指定教师模型与学生模型的预训练模型的路径,用户只需要同时传入二者预训练的路径即可(配置中的注释部分)。
@ -215,7 +217,7 @@ pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained/"
use_distillation: True
```
### ResNet50_vd_ssld蒸馏MobileNetV3_large_x1_0
#### ResNet50_vd_ssld蒸馏MobileNetV3_large_x1_0
类似于`ResNeXt101_32x16d_wsl`蒸馏`ResNet50_vd``ResNet50_vd_ssld`蒸馏`MobileNetV3_large_x1_0`的配置如下:
@ -229,7 +231,7 @@ pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained/"
use_distillation: True
```
## 5.2 启动命令
### 5.2 启动命令
当用户配置完训练环境后,类似于训练其他分类任务,只需要将`tools/run.sh`中的配置文件替换成为相应的蒸馏配置文件即可。
@ -251,7 +253,7 @@ python -m paddle.distributed.launch \
sh tools/run.sh
```
## 5.3 注意事项
### 5.3 注意事项
* 用户在使用SSLD蒸馏之前首先需要在目标数据集上训练一个教师模型该教师模型用于指导学生模型在该数据集上的训练。
@ -267,7 +269,7 @@ sh tools/run.sh
> 如果您觉得此文档对您有帮助欢迎star我们的项目[https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas)
# 参考文献
## 参考文献
[1] Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network[J]. arXiv preprint arXiv:1503.02531, 2015.