Updated Recognition Competition Model Link ()

* Updated Recognition Competition Model Link

* Updated Recognition Competition Model Link

* Updated Recognition Competition Model Link
pull/13263/head
topduke 2024-07-04 13:48:48 +08:00 committed by GitHub
parent 8f64b2ed4d
commit 661f41d484
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 495 additions and 3 deletions

View File

@ -0,0 +1,134 @@
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_repsvtr_ch
save_epoch_step: 10
eval_batch_step: [0, 1000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_repsvtr.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1.e-8
weight_decay: 0.025
no_weight_decay_name: norm
one_dim_param_no_weight_decay: True
lr:
name: Cosine
learning_rate: 0.001 # 8gpus 192bs
warmup_epoch: 5
Architecture:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: RepSVTR
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 256
depth: 2
hidden_dims: 256
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
num_decoder_layers: 2
Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- NRTRLoss:
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 192
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: *bs
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4

View File

@ -0,0 +1,143 @@
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_svtrv2_ch
save_epoch_step: 10
eval_batch_step: [0, 1000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_svrtv2.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1.e-8
weight_decay: 0.05
no_weight_decay_name: norm
one_dim_param_no_weight_decay: True
lr:
name: Cosine
learning_rate: 0.001 # 8gpus 192bs
warmup_epoch: 5
Architecture:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: SVTRv2
use_pos_embed: False
dims: [128, 256, 384]
depths: [6, 6, 6]
num_heads: [4, 8, 12]
mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
local_k: [[5, 5], [5, 5], [-1, -1]]
sub_k: [[2, 1], [2, 1], [-1, -1]]
last_stage: False
use_pool: True
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 256
depth: 2
hidden_dims: 256
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
num_decoder_layers: 2
Loss:
name: MultiLoss
loss_config_list:
- CTCLoss:
- NRTRLoss:
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 192
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: *bs
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4

View File

@ -0,0 +1,208 @@
Global:
debug: false
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_svtrv2_ch_distill_lr00002/
save_epoch_step: 5
eval_batch_step:
- 0
- 1000
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_svtrv2_ch_distill.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
epsilon: 1.e-8
weight_decay: 0.05
no_weight_decay_name: norm pos_embed patch_embed downsample
one_dim_param_no_weight_decay: True
lr:
name: Cosine
learning_rate: 0.0002 # 8gpus 192bs
warmup_epoch: 5
Architecture:
model_type: rec
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained: ./output/rec_svtrv2_ch/best_accuracy
freeze_params: true
return_all_feats: true
model_type: rec
algorithm: SVTR_LCNet
Transform: null
Backbone:
name: SVTRv2
use_pos_embed: False
dims: [128, 256, 384]
depths: [6, 6, 6]
num_heads: [4, 8, 12]
mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
local_k: [[5, 5], [5, 5], [-1, -1]]
sub_k: [[2, 1], [2, 1], [-1, -1]]
last_stage: False
use_pool: True
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 256
depth: 2
hidden_dims: 256
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
num_decoder_layers: 2
max_text_length: *max_text_length
Student:
pretrained: ./output/rec_repsvtr_ch/best_accuracy
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: SVTR_LCNet
Transform: null
Backbone:
name: RepSVTR
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 256
depth: 2
hidden_dims: 256
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
num_decoder_layers: 2
max_text_length: *max_text_length
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDKDLoss:
weight: 0.1
model_name_pairs:
- - Student
- Teacher
key: head_out
multi_head: true
alpha: 1.0
beta: 2.0
dis_head: gtc
name: dkd
- DistillationCTCLoss:
weight: 1.0
model_name_list:
- Student
key: head_out
multi_head: true
- DistillationNRTRLoss:
weight: 1.0
smoothing: false
model_name_list:
- Student
key: head_out
multi_head: true
- DistillCTCLogits:
weight: 1.0
reduction: mean
model_name_pairs:
- - Student
- Teacher
key: head_out
PostProcess:
name: DistillationCTCLabelDecode
model_name:
- Student
key: head_out
multi_head: true
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: Student
Train:
dataset:
name: MultiScaleDataSet
ds_width: false
data_dir: ./train_data/
ext_op_transform_idx: 1
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 192
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
loader:
shuffle: true
batch_size_per_card: *bs
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4

View File

@ -19,8 +19,15 @@
### SVTRv2算法简介
<a name="1"></a>
[PaddleOCR 算法模型挑战赛 - 赛题一OCR 端到端识别任务](https://aistudio.baidu.com/competition/detail/1131/0/introduction)排行榜第一算法。主要思路1、检测和识别模型的Backbone升级为RepSVTR2、识别教师模型升级为SVTRv2可识别长文本。
🔥 该算法由来自复旦大学视觉与学习实验室([FVL](https://fvl.fudan.edu.cn))的[OpenOCR](https://github.com/Topdu/OpenOCR)团队研发,其在[PaddleOCR算法模型挑战赛 - 赛题一OCR端到端识别任务](https://aistudio.baidu.com/competition/detail/1131/0/introduction)中荣获一等奖B榜端到端识别精度相比PP-OCRv4提升2.5%,推理速度持平。主要思路1、检测和识别模型的Backbone升级为RepSVTR2、识别教师模型升级为SVTRv2可识别长文本。
|模型|配置文件|端到端|下载链接|
| --- | --- | --- | --- |
|PP-OCRv4| |A榜 62.77% <br> B榜 62.51%| [Model List](../../doc/doc_ch/models_list.md) |
|SVTRv2(Rec Sever)|[configs/rec/SVTRv2/rec_svtrv2_ch.yml](../../configs/rec/SVTRv2/rec_svtrv2_ch.yml)|A榜 68.81% (使用PP-OCRv4检测模型)| [训练模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_svtrv2_ch_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_svtrv2_ch_infer.tar) |
|RepSVTR(Mobile)|[识别](../../configs/rec/SVTRv2/rec_repsvtr_ch.yml) <br> [识别蒸馏](../../configs/rec/SVTRv2/rec_svtrv2_ch_distillation.yml) <br> [检测](../../configs/det/det_repsvtr_db.yml)|B榜 65.07%| 识别: [训练模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_repsvtr_ch_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_repsvtr_ch_infer.tar) <br> 识别蒸馏: [训练模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_svtrv2_distill_ch_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/openatom/openatom_rec_svtrv2_distill_ch_infer.tar) <br> 检测: [训练模型](https://paddleocr.bj.bcebos.com/openatom/openatom_det_repsvtr_ch_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/openatom/openatom_det_repsvtr_ch_infer.tar) |
🚀 快速使用参考PP-OCR推理[说明文档](../../doc/doc_ch/inference_ppocr.md)将检测和识别模型替换为上表中对应的RepSVTR或SVTRv2推理模型即可使用。
<a name="2"></a>
## 2. 环境配置
@ -115,7 +122,7 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持SVTRv2
准备好推理模型后,参考[cpp infer](../../deploy/cpp_infer/)教程进行操作即可。
<a name="4-3"></a>
### 4.3 Serving服务化部署
@ -125,7 +132,7 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
- Paddle2ONNX推理准备好推理模型后参考[paddle2onnx](../../deploy/paddle2onnx/)教程操作。
<a name="5"></a>
## 5. FAQ