openocr compti code (#12033)
* openocr compti code * update config and repsvtr * svtrv2 docpull/11999/head^2
parent
3e5934de62
commit
38c0c9ee77
ppocr
losses
modeling
heads
|
@ -0,0 +1,169 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: &epoch_num 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 100
|
||||
save_model_dir: ./output/det_repsvtr_db
|
||||
save_epoch_step: 10
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1000
|
||||
cal_metric_during_train: false
|
||||
checkpoints:
|
||||
pretrained_model:
|
||||
save_inference_dir: null
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./checkpoints/det_db/predicts_db.txt
|
||||
distributed: true
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: RepSVTR_det
|
||||
Neck:
|
||||
name: RSEFPN
|
||||
out_channels: 96
|
||||
shortcut: True
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 #(8*8c)
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 5.0e-05
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- CopyPaste: null
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- type: Fliplr
|
||||
args:
|
||||
p: 0.5
|
||||
- type: Affine
|
||||
args:
|
||||
rotate:
|
||||
- -10
|
||||
- 10
|
||||
- type: Resize
|
||||
args:
|
||||
size:
|
||||
- 0.5
|
||||
- 3
|
||||
- EastRandomCropData:
|
||||
size:
|
||||
- 640
|
||||
- 640
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
total_epoch: *epoch_num
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
total_epoch: *epoch_num
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- threshold_map
|
||||
- threshold_mask
|
||||
- shrink_map
|
||||
- shrink_mask
|
||||
loader:
|
||||
shuffle: true
|
||||
drop_last: false
|
||||
batch_size_per_card: 8
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- DetLabelEncode: null
|
||||
- DetResizeForTest:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean:
|
||||
- 0.485
|
||||
- 0.456
|
||||
- 0.406
|
||||
std:
|
||||
- 0.229
|
||||
- 0.224
|
||||
- 0.225
|
||||
order: hwc
|
||||
- ToCHWImage: null
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- shape
|
||||
- polys
|
||||
- ignore_tags
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 1
|
||||
num_workers: 2
|
||||
profiler_options: null
|
|
@ -0,0 +1,134 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_repsvtr_gtc
|
||||
save_epoch_step: 10
|
||||
eval_batch_step: [0, 1000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_repsvtr.txt
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
epsilon: 1.e-8
|
||||
weight_decay: 0.025
|
||||
no_weight_decay_name: norm
|
||||
one_dim_param_no_weight_decay: True
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 # 8gpus 192bs
|
||||
warmup_epoch: 5
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR_HGNet
|
||||
Transform:
|
||||
Backbone:
|
||||
name: RepSVTR
|
||||
Head:
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 256
|
||||
depth: 2
|
||||
hidden_dims: 256
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 384
|
||||
max_text_length: *max_text_length
|
||||
num_decoder_layers: 2
|
||||
|
||||
Loss:
|
||||
name: MultiLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
- NRTRLoss:
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: MultiScaleDataSet
|
||||
ds_width: false
|
||||
data_dir: ./train_data/
|
||||
ext_op_transform_idx: 1
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
sampler:
|
||||
name: MultiScaleSampler
|
||||
scales: [[320, 32], [320, 48], [320, 64]]
|
||||
first_bs: &bs 192
|
||||
fix_bs: false
|
||||
divided_factor: [8, 16] # w, h
|
||||
is_training: True
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: *bs
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 4
|
|
@ -0,0 +1,145 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_svtrv2_gtc
|
||||
save_epoch_step: 10
|
||||
eval_batch_step: [0, 1000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_svrtv2.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
epsilon: 1.e-8
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm
|
||||
one_dim_param_no_weight_decay: True
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001 # 8gpus 192bs
|
||||
warmup_epoch: 5
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR_HGNet
|
||||
Transform:
|
||||
Backbone:
|
||||
name: SVTRv2
|
||||
use_pos_embed: False
|
||||
dims: [128, 256, 384]
|
||||
depths: [6, 6, 6]
|
||||
num_heads: [4, 8, 12]
|
||||
mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
|
||||
local_k: [[5, 5], [5, 5], [-1, -1]]
|
||||
sub_k: [[2, 1], [2, 1], [-1, -1]]
|
||||
last_stage: False
|
||||
use_pool: True
|
||||
Head:
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 256
|
||||
depth: 2
|
||||
hidden_dims: 256
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 384
|
||||
max_text_length: *max_text_length
|
||||
num_decoder_layers: 2
|
||||
|
||||
Loss:
|
||||
name: MultiLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
- NRTRLoss:
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: MultiScaleDataSet
|
||||
ds_width: false
|
||||
data_dir: ./train_data/
|
||||
ext_op_transform_idx: 1
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
sampler:
|
||||
name: MultiScaleSampler
|
||||
scales: [[320, 32], [320, 48], [320, 64]]
|
||||
first_bs: &bs 192
|
||||
fix_bs: false
|
||||
divided_factor: [8, 16] # w, h
|
||||
is_training: True
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: *bs
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 4
|
|
@ -0,0 +1,208 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 100
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_svtrv2_gtc_distill_lr00002/
|
||||
save_epoch_step: 5
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 1000
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_svtrv2_gtc_distill.txt
|
||||
Optimizer:
|
||||
name: AdamW
|
||||
beta1: 0.9
|
||||
beta2: 0.99
|
||||
epsilon: 1.e-8
|
||||
weight_decay: 0.05
|
||||
no_weight_decay_name: norm pos_embed patch_embed downsample
|
||||
one_dim_param_no_weight_decay: True
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0002 # 8gpus 192bs
|
||||
warmup_epoch: 5
|
||||
Architecture:
|
||||
model_type: rec
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained: ./output/rec_svtrv2_gtc/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
algorithm: SVTR_LCNet
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: SVTRv2
|
||||
use_pos_embed: False
|
||||
dims: [128, 256, 384]
|
||||
depths: [6, 6, 6]
|
||||
num_heads: [4, 8, 12]
|
||||
mixer: [['Conv','Conv','Conv','Conv','Conv','Conv'],['Conv','Conv','Global','Global','Global','Global'],['Global','Global','Global','Global','Global','Global']]
|
||||
local_k: [[5, 5], [5, 5], [-1, -1]]
|
||||
sub_k: [[2, 1], [2, 1], [-1, -1]]
|
||||
last_stage: False
|
||||
use_pool: True
|
||||
Head:
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 256
|
||||
depth: 2
|
||||
hidden_dims: 256
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 384
|
||||
num_decoder_layers: 2
|
||||
max_text_length: *max_text_length
|
||||
Student:
|
||||
pretrained: ./output/rec_repsvtr_gtc/best_accuracy
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
algorithm: SVTR_LCNet
|
||||
Transform: null
|
||||
Backbone:
|
||||
name: repvit_svtr
|
||||
Head:
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 256
|
||||
depth: 2
|
||||
hidden_dims: 256
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 384
|
||||
num_decoder_layers: 2
|
||||
max_text_length: *max_text_length
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDKDLoss:
|
||||
weight: 0.1
|
||||
model_name_pairs:
|
||||
- - Student
|
||||
- Teacher
|
||||
key: head_out
|
||||
multi_head: true
|
||||
alpha: 1.0
|
||||
beta: 2.0
|
||||
dis_head: gtc
|
||||
name: dkd
|
||||
- DistillationCTCLoss:
|
||||
weight: 1.0
|
||||
model_name_list:
|
||||
- Student
|
||||
key: head_out
|
||||
multi_head: true
|
||||
- DistillationNRTRLoss:
|
||||
weight: 1.0
|
||||
smoothing: false
|
||||
model_name_list:
|
||||
- Student
|
||||
key: head_out
|
||||
multi_head: true
|
||||
- DistillCTCLogits:
|
||||
weight: 1.0
|
||||
reduction: mean
|
||||
model_name_pairs:
|
||||
- - Student
|
||||
- Teacher
|
||||
key: head_out
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name:
|
||||
- Student
|
||||
key: head_out
|
||||
multi_head: true
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: RecMetric
|
||||
main_indicator: acc
|
||||
key: Student
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: MultiScaleDataSet
|
||||
ds_width: false
|
||||
data_dir: ./train_data/
|
||||
ext_op_transform_idx: 1
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
sampler:
|
||||
name: MultiScaleSampler
|
||||
scales: [[320, 32], [320, 48], [320, 64]]
|
||||
first_bs: &bs 192
|
||||
fix_bs: false
|
||||
divided_factor: [8, 16] # w, h
|
||||
is_training: True
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: *bs
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 4
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
论文信息:
|
||||
> [SVTR: Scene Text Recognition with a Single Visual Model](https://arxiv.org/abs/2205.00159)
|
||||
> Yongkun Du and Zhineng Chen and Caiyan Jia Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang
|
||||
> Yongkun Du and Zhineng Chen and Caiyan Jia and Xiaoting Yin and Tianlun Zheng and Chenxia Li and Yuning Du and Yu-Gang Jiang
|
||||
> IJCAI, 2022
|
||||
|
||||
场景文本识别旨在将自然图像中的文本转录为数字字符序列,从而传达对场景理解至关重要的高级语义。这项任务由于文本变形、字体、遮挡、杂乱背景等方面的变化具有一定的挑战性。先前的方法为提高识别精度做出了许多工作。然而文本识别器除了准确度外,还因为实际需求需要考虑推理速度等因素。
|
||||
|
@ -102,7 +102,7 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_none_ctc_en_train/rec_svtr_tiny_6l
|
|||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型和配置文件下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
|
||||
首先将训练得到best模型,转换成inference model。下面以`SVTR-T`在英文数据集训练的模型为例([模型和配置文件下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
# 场景文本识别算法-SVTRv2
|
||||
|
||||
- [1. 算法简介](#1)
|
||||
- [2. 环境配置](#2)
|
||||
- [3. 模型训练、评估、预测](#3)
|
||||
- [3.1 训练](#3-1)
|
||||
- [3.2 评估](#3-2)
|
||||
- [3.3 预测](#3-3)
|
||||
- [4. 推理部署](#4)
|
||||
- [4.1 Python推理](#4-1)
|
||||
- [4.2 C++推理](#4-2)
|
||||
- [4.3 Serving服务化部署](#4-3)
|
||||
- [4.4 更多推理部署](#4-4)
|
||||
- [5. FAQ](#5)
|
||||
|
||||
<a name="1"></a>
|
||||
## 1. 算法简介
|
||||
|
||||
### SVTRv2算法简介
|
||||
|
||||
<a name="1"></a>
|
||||
[PaddleOCR 算法模型挑战赛 - 赛题一:OCR 端到端识别任务](https://aistudio.baidu.com/competition/detail/1131/0/introduction)排行榜第一算法。主要思路:1、检测和识别模型的Backbone升级为RepSVTR;2、识别教师模型升级为SVTRv2,可识别长文本。
|
||||
|
||||
|
||||
<a name="2"></a>
|
||||
## 2. 环境配置
|
||||
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
|
||||
|
||||
|
||||
<a name="3"></a>
|
||||
## 3. 模型训练、评估、预测
|
||||
|
||||
<a name="3-1"></a>
|
||||
### 3.1 模型训练
|
||||
|
||||
|
||||
训练命令:
|
||||
```shell
|
||||
#单卡训练(训练周期长,不建议)
|
||||
python3 tools/train.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml
|
||||
|
||||
#多卡训练,通过--gpus参数指定卡号
|
||||
# Rec 学生模型
|
||||
python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml
|
||||
# Rec 教师模型
|
||||
python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_svtrv2_gtc.yml
|
||||
# Rec 蒸馏训练
|
||||
python -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/SVTRv2/rec_svtrv2_gtc_distill.yml
|
||||
```
|
||||
|
||||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy
|
||||
```
|
||||
|
||||
<a name="3-3"></a>
|
||||
### 3.3 预测
|
||||
|
||||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c tools/eval.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy Global.infer_img='./doc/imgs_words_en/word_10.png'
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
||||
<a name="4"></a>
|
||||
## 4. 推理部署
|
||||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model,以RepSVTR为例,可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/SVTRv2/rec_repsvtr_gtc.yml -o Global.pretrained_model=output/rec_repsvtr_gtc/best_accuracy Global.save_inference_dir=./inference/rec_repsvtr_infer
|
||||
```
|
||||
|
||||
**注意:**
|
||||
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
./inference/rec_repsvtr_infer/
|
||||
├── inference.pdiparams # 识别inference模型的参数文件
|
||||
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 识别inference模型的program文件
|
||||
```
|
||||
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
||||
```shell
|
||||
python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_repsvtr_infer/'
|
||||
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
|
||||
```
|
||||

|
||||
|
||||
执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
|
||||
结果如下:
|
||||
```shell
|
||||
Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
|
||||
```
|
||||
|
||||
**注意**:
|
||||
|
||||
- 如果您调整了训练时的输入分辨率,需要通过参数`rec_image_shape`设置为您需要的识别图像形状。
|
||||
- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
|
||||
- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中SVTR的预处理为您的预处理方法。
|
||||
|
||||
<a name="4-2"></a>
|
||||
### 4.2 C++推理部署
|
||||
|
||||
由于C++预处理后处理还未支持SVTRv2
|
||||
|
||||
<a name="4-3"></a>
|
||||
### 4.3 Serving服务化部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="4-4"></a>
|
||||
### 4.4 更多推理部署
|
||||
|
||||
暂不支持
|
||||
|
||||
<a name="5"></a>
|
||||
## 5. FAQ
|
||||
|
||||
## 引用
|
||||
|
||||
```bibtex
|
||||
@article{Du2022SVTR,
|
||||
title = {SVTR: Scene Text Recognition with a Single Visual Model},
|
||||
author = {Du, Yongkun and Chen, Zhineng and Jia, Caiyan and Yin, Xiaoting and Zheng, Tianlun and Li, Chenxia and Du, Yuning and Jiang, Yu-Gang},
|
||||
booktitle = {IJCAI},
|
||||
year = {2022},
|
||||
url = {https://arxiv.org/abs/2205.00159}
|
||||
}
|
||||
```
|
|
@ -55,7 +55,7 @@ class MultiLoss(nn.Layer):
|
|||
)
|
||||
elif name == "NRTRLoss":
|
||||
loss = (
|
||||
loss_func(predicts["nrtr"], batch[:1] + batch[2:])["loss"]
|
||||
loss_func(predicts["gtc"], batch[:1] + batch[2:])["loss"]
|
||||
* self.weight_2
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -25,6 +25,7 @@ def build_backbone(config, model_type):
|
|||
from .rec_lcnetv3 import PPLCNetV3
|
||||
from .rec_hgnet import PPHGNet_small
|
||||
from .rec_vit import ViT
|
||||
from .rec_repvit import RepSVTR_det
|
||||
|
||||
support_dict = [
|
||||
"MobileNetV3",
|
||||
|
@ -34,6 +35,7 @@ def build_backbone(config, model_type):
|
|||
"PPLCNet",
|
||||
"PPLCNetV3",
|
||||
"PPHGNet_small",
|
||||
"RepSVTR_det",
|
||||
]
|
||||
if model_type == "table":
|
||||
from .table_master_resnet import TableResNetExtra
|
||||
|
@ -59,6 +61,8 @@ def build_backbone(config, model_type):
|
|||
from .rec_lcnetv3 import PPLCNetV3
|
||||
from .rec_hgnet import PPHGNet_small
|
||||
from .rec_vit_parseq import ViTParseQ
|
||||
from .rec_repvit import RepSVTR
|
||||
from .rec_svtrv2 import SVTRv2
|
||||
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance",
|
||||
|
@ -81,6 +85,8 @@ def build_backbone(config, model_type):
|
|||
"PPHGNet_small",
|
||||
"ViTParseQ",
|
||||
"ViT",
|
||||
"RepSVTR",
|
||||
"SVTRv2",
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,363 @@
|
|||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This code is refer from:
|
||||
https://github.com/THU-MIG/RepViT
|
||||
"""
|
||||
|
||||
import paddle.nn as nn
|
||||
import paddle
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=0.02)
|
||||
normal_ = Normal
|
||||
zeros_ = Constant(value=0.0)
|
||||
ones_ = Constant(value=1.0)
|
||||
|
||||
|
||||
def _make_divisible(v, divisor, min_value=None):
|
||||
"""
|
||||
This function is taken from the original tf repo.
|
||||
It ensures that all layers have a channel number that is divisible by 8
|
||||
It can be seen here:
|
||||
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
:param v:
|
||||
:param divisor:
|
||||
:param min_value:
|
||||
:return:
|
||||
"""
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
# from timm.models.layers import SqueezeExcite
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
|
||||
min_value = min_value or divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
# Make sure that round down does not go down by more than 10%.
|
||||
if new_v < round_limit * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
"""SE Module as defined in original SE-Nets with a few additions
|
||||
Additions include:
|
||||
* divisor can be specified to keep channels % div == 0 (default: 8)
|
||||
* reduction channels can be specified directly by arg (if rd_channels is set)
|
||||
* reduction channels can be specified by float rd_ratio (default: 1/16)
|
||||
* global max pooling can be added to the squeeze aggregation
|
||||
* customizable activation, normalization, and gate layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
rd_ratio=1.0 / 16,
|
||||
rd_channels=None,
|
||||
rd_divisor=8,
|
||||
act_layer=nn.ReLU,
|
||||
):
|
||||
super(SEModule, self).__init__()
|
||||
if not rd_channels:
|
||||
rd_channels = make_divisible(
|
||||
channels * rd_ratio, rd_divisor, round_limit=0.0
|
||||
)
|
||||
self.fc1 = nn.Conv2D(channels, rd_channels, kernel_size=1, bias_attr=True)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Conv2D(rd_channels, channels, kernel_size=1, bias_attr=True)
|
||||
|
||||
def forward(self, x):
|
||||
x_se = x.mean((2, 3), keepdim=True)
|
||||
x_se = self.fc1(x_se)
|
||||
x_se = self.act(x_se)
|
||||
x_se = self.fc2(x_se)
|
||||
return x * nn.functional.sigmoid(x_se)
|
||||
|
||||
|
||||
class Conv2D_BN(nn.Sequential):
|
||||
def __init__(
|
||||
self,
|
||||
a,
|
||||
b,
|
||||
ks=1,
|
||||
stride=1,
|
||||
pad=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bn_weight_init=1,
|
||||
resolution=-10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.add_sublayer(
|
||||
"c", nn.Conv2D(a, b, ks, stride, pad, dilation, groups, bias_attr=False)
|
||||
)
|
||||
self.add_sublayer("bn", nn.BatchNorm2D(b))
|
||||
if bn_weight_init == 1:
|
||||
ones_(self.bn.weight)
|
||||
else:
|
||||
zeros_(self.bn.weight)
|
||||
zeros_(self.bn.bias)
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self.c, self.bn
|
||||
w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn._mean * bn.weight / (bn._variance + bn._epsilon) ** 0.5
|
||||
m = nn.Conv2D(
|
||||
w.shape[1] * self.c._groups,
|
||||
w.shape[0],
|
||||
w.shape[2:],
|
||||
stride=self.c._stride,
|
||||
padding=self.c._padding,
|
||||
dilation=self.c._dilation,
|
||||
groups=self.c._groups,
|
||||
)
|
||||
m.weight.set_value(w)
|
||||
m.bias.set_value(b)
|
||||
return m
|
||||
|
||||
|
||||
class Residual(nn.Layer):
|
||||
def __init__(self, m, drop=0.0):
|
||||
super().__init__()
|
||||
self.m = m
|
||||
self.drop = drop
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.drop > 0:
|
||||
return (
|
||||
x
|
||||
+ self.m(x)
|
||||
* paddle.rand(x.size(0), 1, 1, 1)
|
||||
.ge_(self.drop)
|
||||
.div(1 - self.drop)
|
||||
.detach()
|
||||
)
|
||||
else:
|
||||
return x + self.m(x)
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
if isinstance(self.m, Conv2D_BN):
|
||||
m = self.m.fuse()
|
||||
assert m._groups == m.in_channels
|
||||
identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
|
||||
identity = nn.functional.pad(identity, [1, 1, 1, 1])
|
||||
m.weight += identity
|
||||
return m
|
||||
elif isinstance(self.m, nn.Conv2D):
|
||||
m = self.m
|
||||
assert m._groups != m.in_channels
|
||||
identity = paddle.ones([m.weight.shape[0], m.weight.shape[1], 1, 1])
|
||||
identity = nn.functional.pad(identity, [1, 1, 1, 1])
|
||||
m.weight += identity
|
||||
return m
|
||||
else:
|
||||
return self
|
||||
|
||||
|
||||
class RepVGGDW(nn.Layer):
|
||||
def __init__(self, ed) -> None:
|
||||
super().__init__()
|
||||
self.conv = Conv2D_BN(ed, ed, 3, 1, 1, groups=ed)
|
||||
self.conv1 = nn.Conv2D(ed, ed, 1, 1, 0, groups=ed)
|
||||
self.dim = ed
|
||||
self.bn = nn.BatchNorm2D(ed)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bn((self.conv(x) + self.conv1(x)) + x)
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
conv = self.conv.fuse()
|
||||
conv1 = self.conv1
|
||||
|
||||
conv_w = conv.weight
|
||||
conv_b = conv.bias
|
||||
conv1_w = conv1.weight
|
||||
conv1_b = conv1.bias
|
||||
|
||||
conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])
|
||||
|
||||
identity = nn.functional.pad(
|
||||
paddle.ones([conv1_w.shape[0], conv1_w.shape[1], 1, 1]), [1, 1, 1, 1]
|
||||
)
|
||||
|
||||
final_conv_w = conv_w + conv1_w + identity
|
||||
final_conv_b = conv_b + conv1_b
|
||||
|
||||
conv.weight.set_value(final_conv_w)
|
||||
conv.bias.set_value(final_conv_b)
|
||||
|
||||
bn = self.bn
|
||||
w = bn.weight / (bn._variance + bn._epsilon) ** 0.5
|
||||
w = conv.weight * w[:, None, None, None]
|
||||
b = (
|
||||
bn.bias
|
||||
+ (conv.bias - bn._mean) * bn.weight / (bn._variance + bn._epsilon) ** 0.5
|
||||
)
|
||||
conv.weight.set_value(w)
|
||||
conv.bias.set_value(b)
|
||||
return conv
|
||||
|
||||
|
||||
class RepViTBlock(nn.Layer):
|
||||
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
|
||||
super(RepViTBlock, self).__init__()
|
||||
|
||||
self.identity = stride == 1 and inp == oup
|
||||
assert hidden_dim == 2 * inp
|
||||
|
||||
if stride != 1:
|
||||
self.token_mixer = nn.Sequential(
|
||||
Conv2D_BN(
|
||||
inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp
|
||||
),
|
||||
SEModule(inp, 0.25) if use_se else nn.Identity(),
|
||||
Conv2D_BN(inp, oup, ks=1, stride=1, pad=0),
|
||||
)
|
||||
self.channel_mixer = Residual(
|
||||
nn.Sequential(
|
||||
# pw
|
||||
Conv2D_BN(oup, 2 * oup, 1, 1, 0),
|
||||
nn.GELU() if use_hs else nn.GELU(),
|
||||
# pw-linear
|
||||
Conv2D_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert self.identity
|
||||
self.token_mixer = nn.Sequential(
|
||||
RepVGGDW(inp),
|
||||
SEModule(inp, 0.25) if use_se else nn.Identity(),
|
||||
)
|
||||
self.channel_mixer = Residual(
|
||||
nn.Sequential(
|
||||
# pw
|
||||
Conv2D_BN(inp, hidden_dim, 1, 1, 0),
|
||||
nn.GELU() if use_hs else nn.GELU(),
|
||||
# pw-linear
|
||||
Conv2D_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.channel_mixer(self.token_mixer(x))
|
||||
|
||||
|
||||
class RepViT(nn.Layer):
|
||||
def __init__(self, cfgs, in_channels=3, out_indices=None):
|
||||
super(RepViT, self).__init__()
|
||||
# setting of inverted residual blocks
|
||||
self.cfgs = cfgs
|
||||
|
||||
# building first layer
|
||||
input_channel = self.cfgs[0][2]
|
||||
patch_embed = nn.Sequential(
|
||||
Conv2D_BN(in_channels, input_channel // 2, 3, 2, 1),
|
||||
nn.GELU(),
|
||||
Conv2D_BN(input_channel // 2, input_channel, 3, 2, 1),
|
||||
)
|
||||
layers = [patch_embed]
|
||||
# building inverted residual blocks
|
||||
block = RepViTBlock
|
||||
for k, t, c, use_se, use_hs, s in self.cfgs:
|
||||
output_channel = _make_divisible(c, 8)
|
||||
exp_size = _make_divisible(input_channel * t, 8)
|
||||
layers.append(
|
||||
block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)
|
||||
)
|
||||
input_channel = output_channel
|
||||
self.features = nn.LayerList(layers)
|
||||
self.out_indices = out_indices
|
||||
if out_indices is not None:
|
||||
self.out_channels = [self.cfgs[ids - 1][2] for ids in out_indices]
|
||||
else:
|
||||
self.out_channels = self.cfgs[-1][2]
|
||||
|
||||
def forward(self, x):
|
||||
if self.out_indices is not None:
|
||||
return self.forward_det(x)
|
||||
return self.forward_rec(x)
|
||||
|
||||
def forward_det(self, x):
|
||||
outs = []
|
||||
for i, f in enumerate(self.features):
|
||||
x = f(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
return outs
|
||||
|
||||
def forward_rec(self, x):
|
||||
for f in self.features:
|
||||
x = f(x)
|
||||
h = x.shape[2]
|
||||
x = nn.functional.avg_pool2d(x, [h, 2])
|
||||
return x
|
||||
|
||||
|
||||
def RepSVTR(in_channels=3):
|
||||
"""
|
||||
Constructs a MobileNetV3-Large model
|
||||
"""
|
||||
# k, t, c, SE, HS, s
|
||||
cfgs = [
|
||||
[3, 2, 96, 1, 0, 1],
|
||||
[3, 2, 96, 0, 0, 1],
|
||||
[3, 2, 96, 0, 0, 1],
|
||||
[3, 2, 192, 0, 1, (2, 1)],
|
||||
[3, 2, 192, 1, 1, 1],
|
||||
[3, 2, 192, 0, 1, 1],
|
||||
[3, 2, 192, 1, 1, 1],
|
||||
[3, 2, 192, 0, 1, 1],
|
||||
[3, 2, 192, 1, 1, 1],
|
||||
[3, 2, 192, 0, 1, 1],
|
||||
[3, 2, 384, 0, 1, (2, 1)],
|
||||
[3, 2, 384, 1, 1, 1],
|
||||
[3, 2, 384, 0, 1, 1],
|
||||
]
|
||||
return RepViT(cfgs, in_channels=in_channels)
|
||||
|
||||
|
||||
def RepSVTR_det(in_channels=3, out_indices=[2, 5, 10, 13]):
|
||||
"""
|
||||
Constructs a MobileNetV3-Large model
|
||||
"""
|
||||
# k, t, c, SE, HS, s
|
||||
cfgs = [
|
||||
[3, 2, 48, 1, 0, 1],
|
||||
[3, 2, 48, 0, 0, 1],
|
||||
[3, 2, 96, 0, 0, 2],
|
||||
[3, 2, 96, 1, 0, 1],
|
||||
[3, 2, 96, 0, 0, 1],
|
||||
[3, 2, 192, 0, 1, 2],
|
||||
[3, 2, 192, 1, 1, 1],
|
||||
[3, 2, 192, 0, 1, 1],
|
||||
[3, 2, 192, 1, 1, 1],
|
||||
[3, 2, 192, 0, 1, 1],
|
||||
[3, 2, 384, 0, 1, 2],
|
||||
[3, 2, 384, 1, 1, 1],
|
||||
[3, 2, 384, 0, 1, 1],
|
||||
]
|
||||
return RepViT(cfgs, in_channels=in_channels, out_indices=out_indices)
|
|
@ -0,0 +1,575 @@
|
|||
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle import ParamAttr
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=0.02)
|
||||
normal_ = Normal
|
||||
zeros_ = Constant(value=0.0)
|
||||
ones_ = Constant(value=1.0)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob=0.0, training=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
|
||||
shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
|
||||
random_tensor = paddle.floor(random_tensor) # binarize
|
||||
output = x.divide(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Layer):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
|
||||
class Identity(nn.Layer):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
class Mlp(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias_attr=False,
|
||||
groups=1,
|
||||
act=nn.GELU,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
|
||||
bias_attr=bias_attr,
|
||||
)
|
||||
self.norm = nn.BatchNorm2D(out_channels)
|
||||
self.act = act()
|
||||
|
||||
def forward(self, inputs):
|
||||
out = self.conv(inputs)
|
||||
out = self.norm(out)
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.dim = dim
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape((0, -1, 3, self.num_heads, self.head_dim))
|
||||
.transpose((2, 0, 3, 1, 4))
|
||||
)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
|
||||
attn = nn.functional.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
epsilon=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim, epsilon=epsilon)
|
||||
self.mixer = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
||||
self.norm2 = norm_layer(dim, epsilon=epsilon)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm1(x + self.drop_path(self.mixer(x)))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
return x
|
||||
|
||||
|
||||
class ConvBlock(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
epsilon=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.norm1 = norm_layer(dim, epsilon=epsilon)
|
||||
self.mixer = nn.Conv2D(
|
||||
dim,
|
||||
dim,
|
||||
5,
|
||||
1,
|
||||
2,
|
||||
groups=num_heads,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
|
||||
self.norm2 = norm_layer(dim, epsilon=epsilon)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
C, H, W = x.shape[1:]
|
||||
x = x + self.drop_path(self.mixer(x))
|
||||
x = self.norm1(x.flatten(2).transpose([0, 2, 1]))
|
||||
x = self.norm2(x + self.drop_path(self.mlp(x)))
|
||||
x = x.transpose([0, 2, 1]).reshape([0, C, H, W])
|
||||
return x
|
||||
|
||||
|
||||
class FlattenTranspose(nn.Layer):
|
||||
def forward(self, x):
|
||||
return x.flatten(2).transpose([0, 2, 1])
|
||||
|
||||
|
||||
class SubSample2D(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=[2, 1],
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, sz):
|
||||
# print(x.shape)
|
||||
x = self.conv(x)
|
||||
C, H, W = x.shape[1:]
|
||||
x = self.norm(x.flatten(2).transpose([0, 2, 1]))
|
||||
x = x.transpose([0, 2, 1]).reshape([0, C, H, W])
|
||||
return x, [H, W]
|
||||
|
||||
|
||||
class SubSample1D(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride=[2, 1],
|
||||
):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
)
|
||||
self.norm = nn.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, sz):
|
||||
C = x.shape[-1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
|
||||
x = self.conv(x)
|
||||
C, H, W = x.shape[1:]
|
||||
x = self.norm(x.flatten(2).transpose([0, 2, 1]))
|
||||
return x, [H, W]
|
||||
|
||||
|
||||
class IdentitySize(nn.Layer):
|
||||
def forward(self, x, sz):
|
||||
return x, sz
|
||||
|
||||
|
||||
class SVTRStage(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
dim=64,
|
||||
out_dim=256,
|
||||
depth=3,
|
||||
mixer=["Local"] * 3,
|
||||
sub_k=[2, 1],
|
||||
num_heads=2,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path=[0.1] * 3,
|
||||
norm_layer=nn.LayerNorm,
|
||||
act=nn.GELU,
|
||||
eps=1e-6,
|
||||
downsample=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
conv_block_num = sum([1 if mix == "Conv" else 0 for mix in mixer])
|
||||
blocks = []
|
||||
for i in range(depth):
|
||||
if mixer[i] == "Conv":
|
||||
blocks.append(
|
||||
ConvBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate,
|
||||
act_layer=act,
|
||||
drop_path=drop_path[i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=eps,
|
||||
)
|
||||
)
|
||||
else:
|
||||
blocks.append(
|
||||
Block(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
act_layer=act,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=drop_path[i],
|
||||
norm_layer=norm_layer,
|
||||
epsilon=eps,
|
||||
)
|
||||
)
|
||||
if i == conv_block_num - 1 and mixer[-1] != "Conv":
|
||||
blocks.append(FlattenTranspose())
|
||||
self.blocks = nn.Sequential(*blocks)
|
||||
if downsample:
|
||||
if mixer[-1] == "Conv":
|
||||
self.downsample = SubSample2D(dim, out_dim, stride=sub_k)
|
||||
elif mixer[-1] == "Global":
|
||||
self.downsample = SubSample1D(dim, out_dim, stride=sub_k)
|
||||
else:
|
||||
self.downsample = IdentitySize()
|
||||
|
||||
def forward(self, x, sz):
|
||||
x = self.blocks(x)
|
||||
x, sz = self.downsample(x, sz)
|
||||
return x, sz
|
||||
|
||||
|
||||
class ADDPosEmbed(nn.Layer):
|
||||
def __init__(self, feat_max_size=[8, 32], embed_dim=768):
|
||||
super().__init__()
|
||||
pos_embed = paddle.zeros(
|
||||
[1, feat_max_size[0] * feat_max_size[1], embed_dim], dtype=paddle.float32
|
||||
)
|
||||
trunc_normal_(pos_embed)
|
||||
pos_embed = pos_embed.transpose([0, 2, 1]).reshape(
|
||||
[1, embed_dim, feat_max_size[0], feat_max_size[1]]
|
||||
)
|
||||
self.pos_embed = self.create_parameter(
|
||||
[1, embed_dim, feat_max_size[0], feat_max_size[1]]
|
||||
)
|
||||
self.add_parameter("pos_embed", self.pos_embed)
|
||||
self.pos_embed.set_value(pos_embed)
|
||||
|
||||
def forward(self, x):
|
||||
sz = x.shape[2:]
|
||||
x = x + self.pos_embed[:, :, : sz[0], : sz[1]]
|
||||
return x
|
||||
|
||||
|
||||
class POPatchEmbed(nn.Layer):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
feat_max_size=[8, 32],
|
||||
embed_dim=768,
|
||||
use_pos_embed=False,
|
||||
flatten=False,
|
||||
):
|
||||
super().__init__()
|
||||
patch_embed = [
|
||||
ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=embed_dim // 2,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None,
|
||||
),
|
||||
ConvBNLayer(
|
||||
in_channels=embed_dim // 2,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=nn.GELU,
|
||||
bias_attr=None,
|
||||
),
|
||||
]
|
||||
if use_pos_embed:
|
||||
patch_embed.append(ADDPosEmbed(feat_max_size, embed_dim))
|
||||
if flatten:
|
||||
patch_embed.append(FlattenTranspose())
|
||||
self.patch_embed = nn.Sequential(*patch_embed)
|
||||
|
||||
def forward(self, x):
|
||||
sz = x.shape[2:]
|
||||
x = self.patch_embed(x)
|
||||
return x, [sz[0] // 4, sz[1] // 4]
|
||||
|
||||
|
||||
class LastStage(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, last_drop, out_char_num):
|
||||
super().__init__()
|
||||
self.last_conv = nn.Linear(in_channels, out_channels, bias_attr=False)
|
||||
self.hardswish = nn.Hardswish()
|
||||
self.dropout = nn.Dropout(p=last_drop, mode="downscale_in_infer")
|
||||
|
||||
def forward(self, x, sz):
|
||||
x = x.reshape([0, sz[0], sz[1], x.shape[-1]])
|
||||
x = x.mean(1)
|
||||
x = self.last_conv(x)
|
||||
x = self.hardswish(x)
|
||||
x = self.dropout(x)
|
||||
return x, [1, sz[1]]
|
||||
|
||||
|
||||
class OutPool(nn.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, sz):
|
||||
C = x.shape[-1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
|
||||
x = nn.functional.avg_pool2d(x, [sz[0], 2])
|
||||
return x, [1, sz[1] // 2]
|
||||
|
||||
|
||||
class Feat2D(nn.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, sz):
|
||||
C = x.shape[-1]
|
||||
x = x.transpose([0, 2, 1]).reshape([0, C, sz[0], sz[1]])
|
||||
return x, sz
|
||||
|
||||
|
||||
class SVTRv2(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
max_sz=[32, 128],
|
||||
in_channels=3,
|
||||
out_channels=192,
|
||||
out_char_num=25,
|
||||
depths=[3, 6, 3],
|
||||
dims=[64, 128, 256],
|
||||
mixer=[["Conv"] * 3, ["Conv"] * 3 + ["Global"] * 3, ["Global"] * 3],
|
||||
use_pos_embed=False,
|
||||
sub_k=[[1, 1], [2, 1], [1, 1]],
|
||||
num_heads=[2, 4, 8],
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
last_drop=0.1,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm,
|
||||
act=nn.GELU,
|
||||
last_stage=False,
|
||||
eps=1e-6,
|
||||
use_pool=False,
|
||||
feat2d=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
num_stages = len(depths)
|
||||
self.num_features = dims[-1]
|
||||
|
||||
feat_max_size = [max_sz[0] // 4, max_sz[1] // 4]
|
||||
self.pope = POPatchEmbed(
|
||||
in_channels=in_channels,
|
||||
feat_max_size=feat_max_size,
|
||||
embed_dim=dims[0],
|
||||
use_pos_embed=use_pos_embed,
|
||||
flatten=mixer[0][0] != "Conv",
|
||||
)
|
||||
|
||||
dpr = np.linspace(0, drop_path_rate, sum(depths)) # stochastic depth decay rule
|
||||
|
||||
self.stages = nn.LayerList()
|
||||
for i_stage in range(num_stages):
|
||||
stage = SVTRStage(
|
||||
dim=dims[i_stage],
|
||||
out_dim=dims[i_stage + 1] if i_stage < num_stages - 1 else 0,
|
||||
depth=depths[i_stage],
|
||||
mixer=mixer[i_stage],
|
||||
sub_k=sub_k[i_stage],
|
||||
num_heads=num_heads[i_stage],
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_stage]) : sum(depths[: i_stage + 1])],
|
||||
norm_layer=norm_layer,
|
||||
act=act,
|
||||
downsample=False if i_stage == num_stages - 1 else True,
|
||||
eps=eps,
|
||||
)
|
||||
self.stages.append(stage)
|
||||
|
||||
self.out_channels = self.num_features
|
||||
self.last_stage = last_stage
|
||||
if last_stage:
|
||||
self.out_channels = out_channels
|
||||
self.stages.append(
|
||||
LastStage(self.num_features, out_channels, last_drop, out_char_num)
|
||||
)
|
||||
if use_pool:
|
||||
self.stages.append(OutPool())
|
||||
|
||||
if feat2d:
|
||||
self.stages.append(Feat2D())
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
zeros_(m.bias)
|
||||
ones_(m.weight)
|
||||
|
||||
def forward(self, x):
|
||||
x, sz = self.pope(x)
|
||||
for stage in self.stages:
|
||||
x, sz = stage(x, sz)
|
||||
return x
|
|
@ -149,5 +149,5 @@ class MultiHead(nn.Layer):
|
|||
head_out["sar"] = sar_out
|
||||
else:
|
||||
gtc_out = self.gtc_head(self.before_gtc(x), targets[1:])
|
||||
head_out["nrtr"] = gtc_out
|
||||
head_out["gtc"] = gtc_out
|
||||
return head_out
|
||||
|
|
Loading…
Reference in New Issue