add distill det doc
parent
04c44974b1
commit
7af3065789
|
@ -279,6 +279,276 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
|
||||||
|
|
||||||
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
||||||
|
|
||||||
|
|
||||||
### 2.2 检测配置文件解析
|
### 2.2 检测配置文件解析
|
||||||
|
|
||||||
* coming soon!
|
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
|
||||||
|
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
|
||||||
|
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
|
||||||
|
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
|
||||||
|
|
||||||
|
|
||||||
|
#### 2.2.1 模型结构
|
||||||
|
|
||||||
|
知识蒸馏任务中,模型结构配置如下所示:
|
||||||
|
|
||||||
|
```
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
|
||||||
|
algorithm: Distillation # 算法名称
|
||||||
|
Models: # 模型,包含子网络的配置信息
|
||||||
|
Student: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false # 是否需要固定参数
|
||||||
|
return_all_feats: false # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Teacher: # 另外一个子网络,这里给的是普通大模型蒸小模型的蒸馏示例,
|
||||||
|
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||||
|
freeze_params: true # Teacher模型是训练好的,不需要参与训练,freeze_params设置为True
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 18
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 256
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
如果是采用DML,即两个小模型互相学习的方法,上述配置文件里的Teacher网络结构需要设置为Student模型一样的配置,具体参考配置文件[ch_PP-OCRv2_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml)。
|
||||||
|
|
||||||
|
下面介绍[ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml)的配置文件参数:
|
||||||
|
|
||||||
|
```
|
||||||
|
Architecture:
|
||||||
|
name: DistillationModel
|
||||||
|
algorithm: Distillation
|
||||||
|
model_type: det
|
||||||
|
Models:
|
||||||
|
Teacher: # CML蒸馏的Teacher模型配置
|
||||||
|
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||||
|
freeze_params: true # Teacher 不训练
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 18
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 256
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Student: # CML蒸馏的Student模型配置
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
Student2: # CML蒸馏的Student2模型配置
|
||||||
|
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||||
|
freeze_params: false
|
||||||
|
return_all_feats: false
|
||||||
|
model_type: det
|
||||||
|
algorithm: DB
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
disable_se: True
|
||||||
|
Neck:
|
||||||
|
name: DBFPN
|
||||||
|
out_channels: 96
|
||||||
|
Head:
|
||||||
|
name: DBHead
|
||||||
|
k: 50
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)。
|
||||||
|
|
||||||
|
最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student`与`Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。
|
||||||
|
|
||||||
|
在蒸馏任务中,为了方便添加蒸馏损失函数,每个网络的输出保存为`dict`,其中包含子模块输出。每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out`,`value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"Teacher": {
|
||||||
|
"backbone_out": tensor,
|
||||||
|
"neck_out": tensor,
|
||||||
|
"head_out": tensor,
|
||||||
|
},
|
||||||
|
"Student": {
|
||||||
|
"backbone_out": tensor,
|
||||||
|
"neck_out": tensor,
|
||||||
|
"head_out": tensor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2.1.2 损失函数
|
||||||
|
|
||||||
|
知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
|
||||||
|
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
|
||||||
|
- DistillationDilaDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss
|
||||||
|
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||||
|
model_name_pairs: # 对于蒸馏模型的预测结果,提取这两个子网络的输出,计算Teacher模型和Student模型输出的loss
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
key: maps # 取子网络输出dict中,该key对应的tensor
|
||||||
|
balance_loss: true # 以下几个参数为标准DBloss的配置参数
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
- DistillationDBLoss: # 基于蒸馏的DB损失函数,继承自标准的DBloss,用于计算Student和GT之间的loss
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student"] # 模型名字只有Student,表示计算Student和GT之间的loss
|
||||||
|
name: DBLoss
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
```
|
||||||
|
|
||||||
|
同理,检测ch_PP-OCRv2_det_cml.yml蒸馏损失函数配置如下所示。相比较于ch_PP-OCRv2_det_distill.yml的损失函数配置,cml蒸馏的损失函数配置做了3个改动:
|
||||||
|
```yaml
|
||||||
|
Loss:
|
||||||
|
name: CombinedLoss
|
||||||
|
loss_config_list:
|
||||||
|
- DistillationDilaDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Teacher"]
|
||||||
|
- ["Student2", "Teacher"] # 改动1,计算两个Student和Teacher的损失
|
||||||
|
key: maps
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
- DistillationDMLLoss: # 改动2,增加计算两个Student之间的损失
|
||||||
|
model_name_pairs:
|
||||||
|
- ["Student", "Student2"]
|
||||||
|
maps_name: "thrink_maps"
|
||||||
|
weight: 1.0
|
||||||
|
# act: None
|
||||||
|
key: maps
|
||||||
|
- DistillationDBLoss:
|
||||||
|
weight: 1.0
|
||||||
|
model_name_list: ["Student", "Student2"] # 改动3,计算两个Student和GT之间的损失
|
||||||
|
balance_loss: true
|
||||||
|
main_loss_type: DiceLoss
|
||||||
|
alpha: 5
|
||||||
|
beta: 10
|
||||||
|
ohem_ratio: 3
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
关于`DistillationDilaDBLoss`更加具体的实现可以参考: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185)。关于`DistillationDBLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148)。
|
||||||
|
|
||||||
|
|
||||||
|
#### 2.1.3 后处理
|
||||||
|
|
||||||
|
知识蒸馏任务中,检测蒸馏后处理配置如下所示。
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
PostProcess:
|
||||||
|
name: DistillationDBPostProcess # DB检测蒸馏任务的CTC解码后处理,继承自标准的DBPostProcess类
|
||||||
|
model_name: ["Student", "Student2", "Teacher"] # 对于蒸馏模型的预测结果,提取多个子网络的输出,进行解码,不需要后处理的网络可以不在model_name中设置
|
||||||
|
thresh: 0.3
|
||||||
|
box_thresh: 0.6
|
||||||
|
max_candidates: 1000
|
||||||
|
unclip_ratio: 1.5
|
||||||
|
```
|
||||||
|
|
||||||
|
以上述配置为例,最终会同时计算`Student`,`Student2`和`Teacher` 3个子网络的输出做后处理计算。同时,由于有多个输入,后处理返回的输出也有多个,
|
||||||
|
|
||||||
|
关于`DistillationDBPostProcess`更加具体的实现可以参考: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195)
|
||||||
|
|
||||||
|
|
||||||
|
#### 2.1.4 蒸馏指标计算
|
||||||
|
|
||||||
|
知识蒸馏任务中,检测蒸馏指标计算配置如下所示。
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
Metric:
|
||||||
|
name: DistillationMetric
|
||||||
|
base_metric_name: DetMetric
|
||||||
|
main_indicator: hmean
|
||||||
|
key: "Student"
|
||||||
|
```
|
||||||
|
|
||||||
|
由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,`key`字段设置为`Student`则表示只计算`Student`网络的精度。
|
||||||
|
|
||||||
|
|
||||||
|
#### 2.1.5 检测蒸馏模型finetune
|
||||||
|
|
||||||
|
检测蒸馏有三种方式:
|
||||||
|
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
|
||||||
|
- 采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
|
||||||
|
- 采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。
|
||||||
|
|
||||||
|
在具体finetune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。
|
||||||
|
|
||||||
|
在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
|
||||||
|
|
||||||
|
|
||||||
|
另外,由于PaddleOCR提供的蒸馏预训练模型包含了多个模型的参数,如果您希望提取Student模型的参数,可以参考如下代码:
|
||||||
|
```
|
||||||
|
# 下载蒸馏训练模型的参数
|
||||||
|
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
import paddle
|
||||||
|
# 加载预训练模型
|
||||||
|
all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams")
|
||||||
|
# 查看权重参数的keys
|
||||||
|
print(all_params.keys())
|
||||||
|
# 学生模型的权重提取
|
||||||
|
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
|
||||||
|
# 查看学生模型权重参数的keys
|
||||||
|
print(s_params.keys())
|
||||||
|
# 保存
|
||||||
|
paddle.save(s_params, "ch_PP-OCRv2_det_distill_train/student.pdparams")
|
||||||
|
```
|
||||||
|
|
||||||
|
最终`Student`模型的参数将会保存在`ch_PP-OCRv2_det_distill_train/student.pdparams`中,用于模型的fine-tune。
|
||||||
|
|
Loading…
Reference in New Issue