upgrade kd doc
parent
5a08a408a8
commit
4c0cf75350
|
@ -71,7 +71,7 @@ PostProcess:
|
|||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
ignore_space: True
|
||||
ignore_space: False
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
|
|
|
@ -145,7 +145,7 @@ Metric:
|
|||
base_metric_name: RecMetric
|
||||
main_indicator: acc
|
||||
key: "Student"
|
||||
ignore_space: True
|
||||
ignore_space: False
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
|
|
|
@ -60,7 +60,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
|
|||
<a name="21"></a>
|
||||
### 2.1 识别配置文件解析
|
||||
|
||||
配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。
|
||||
配置文件在[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)。
|
||||
|
||||
<a name="211"></a>
|
||||
#### 2.1.1 模型结构
|
||||
|
@ -69,7 +69,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
|
|||
|
||||
```yaml
|
||||
Architecture:
|
||||
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
|
||||
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型相同
|
||||
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
|
||||
algorithm: Distillation # 算法名称
|
||||
Models: # 模型,包含子网络的配置信息
|
||||
|
@ -78,37 +78,55 @@ Architecture:
|
|||
freeze_params: false # 是否需要固定参数
|
||||
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
|
||||
model_type: *model_type # 模型类别
|
||||
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
|
||||
algorithm: SVTR # 子网络的算法名称,该子网络其余参数均为构造参数,与普通的模型训练配置一致
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
last_conv_stride: [1, 2]
|
||||
last_pool_type: avg
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
|
||||
pretrained: # 下面的组网参数同上
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 64
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- SARHead:
|
||||
enc_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
algorithm: SVTR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
last_conv_stride: [1, 2]
|
||||
last_pool_type: avg
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 64
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- SARHead:
|
||||
enc_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
```
|
||||
|
||||
当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student`与`Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
|
||||
|
@ -124,55 +142,82 @@ Architecture:
|
|||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
algorithm: SVTR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
last_conv_stride: [1, 2]
|
||||
last_pool_type: avg
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 64
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- SARHead:
|
||||
enc_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
algorithm: SVTR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
last_conv_stride: [1, 2]
|
||||
last_pool_type: avg
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 64
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- SARHead:
|
||||
enc_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
Student2:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
algorithm: SVTR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
last_conv_stride: [1, 2]
|
||||
last_pool_type: avg
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 64
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- SARHead:
|
||||
enc_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
```
|
||||
|
||||
最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`。
|
||||
|
@ -205,34 +250,56 @@ Architecture:
|
|||
|
||||
```yaml
|
||||
Loss:
|
||||
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
|
||||
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
|
||||
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
|
||||
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
|
||||
weight: 1.0 # 权重
|
||||
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
|
||||
use_log: true # 对输入计算log,如果函数已经
|
||||
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
multi_head: True # 是否为多头结构,我们
|
||||
dis_head: ctc # 蒸馏
|
||||
name: dml_ctc # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
|
||||
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
|
||||
weight: 1.0 # 权重
|
||||
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
|
||||
use_log: true # 对输入计算log,如果函数已经
|
||||
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
multi_head: True # 是否为多头结构,我们
|
||||
dis_head: sar # 蒸馏
|
||||
name: dml_sar # 蒸馏loss的前缀名称,避免不同loss之间的命名冲突
|
||||
- DistillationDistanceLoss: # 蒸馏的距离损失函数
|
||||
weight: 1.0 # 权重
|
||||
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
|
||||
model_name_pairs: # 用于计算distance loss的子网络名称对
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
|
||||
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationSARLoss: # 基于蒸馏的SAR损失函数,继承自标准的SARLoss
|
||||
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
multi_head: True # 是否为多头结构,为true时,取出其中的
|
||||
```
|
||||
|
||||
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
|
||||
|
||||
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
|
||||
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
|
||||
- `Student`和`Teacher`最终输出(`head_out`)的CTC分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
|
||||
- `Student`和`Teacher`最终输出(`head_out`)的SAR分支与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
|
||||
- `Student`和`Teacher`最终输出(`head_out`)的CTC分支之间的DML loss,权重为1。
|
||||
- `Student`和`Teacher`最终输出(`head_out`)SARC分支之间的DML loss,权重为1。
|
||||
- `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
|
||||
|
||||
|
||||
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。
|
||||
|
||||
<a name="213"></a>
|
||||
|
@ -245,6 +312,7 @@ PostProcess:
|
|||
name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
|
||||
model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
multi_head: True # 多头结构时,会取出其中的CTC分支进行计算
|
||||
```
|
||||
|
||||
以上述配置为例,最终会同时计算`Student`和`Teahcer` 2个子网络的CTC解码输出,返回一个`dict`,`key`为用于处理的子网络名称,`value`为用于处理的子网络列表。
|
||||
|
@ -262,6 +330,7 @@ Metric:
|
|||
base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
|
||||
main_indicator: acc # 指标的名称
|
||||
key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
|
||||
ignore_space: False # 评估时是否忽略空格的影响
|
||||
```
|
||||
|
||||
以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
|
||||
|
@ -273,15 +342,15 @@ Metric:
|
|||
|
||||
对蒸馏得到的识别蒸馏进行微调有2种方式。
|
||||
|
||||
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
|
||||
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv3_rec_distillation.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
|
||||
|
||||
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
|
||||
|
||||
* 首先下载预训练模型并解压。
|
||||
```shell
|
||||
# 下面预训练模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
|
||||
tar -xf ch_PP-OCRv2_rec_train.tar
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
|
||||
tar -xf ch_PP-OCRv3_rec_train.tar
|
||||
```
|
||||
|
||||
* 然后使用python,对其中的学生模型参数进行提取
|
||||
|
@ -289,7 +358,7 @@ tar -xf ch_PP-OCRv2_rec_train.tar
|
|||
```python
|
||||
import paddle
|
||||
# 加载预训练模型
|
||||
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
|
||||
all_params = paddle.load("ch_PP-OCRv3_rec_train/best_accuracy.pdparams")
|
||||
# 查看权重参数的keys
|
||||
print(all_params.keys())
|
||||
# 学生模型的权重提取
|
||||
|
@ -297,10 +366,10 @@ s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Stu
|
|||
# 查看学生模型权重参数的keys
|
||||
print(s_params.keys())
|
||||
# 保存
|
||||
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
|
||||
paddle.save(s_params, "ch_PP-OCRv3_rec_train/student.pdparams")
|
||||
```
|
||||
|
||||
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
||||
转化完成之后,使用[ch_PP-OCRv3_rec.yml](../../configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
||||
|
||||
<a name="22"></a>
|
||||
### 2.2 检测配置文件解析
|
||||
|
|
|
@ -49,18 +49,23 @@ def get_check_global_params(mode):
|
|||
return check_params
|
||||
|
||||
|
||||
def _check_image_file(path):
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
|
||||
return any([path.lower().endswith(e) for e in img_end])
|
||||
|
||||
|
||||
def get_image_file_list(img_file):
|
||||
imgs_lists = []
|
||||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
|
||||
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'}
|
||||
if os.path.isfile(img_file) and _check_image_file(file_path):
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
for single_file in os.listdir(img_file):
|
||||
file_path = os.path.join(img_file, single_file)
|
||||
if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
|
||||
if os.path.isfile(file_path) and _check_image_file(file_path):
|
||||
imgs_lists.append(file_path)
|
||||
if len(imgs_lists) == 0:
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
|
Loading…
Reference in New Issue