add ppocr v5 (#15121)

Co-authored-by: zhangyubo0722 <zangyubo0722@163.com>
zhangyubo0722 2025-05-12 21:55:26 +08:00 committed by GitHub
parent 0cc9870eb3
commit a38c087bcb
6 changed files with 18600 additions and 13 deletions


@@ -39,7 +39,7 @@ Architecture:
   in_channels: 3
   Transform:
   Backbone:
-    name: PPHGNetV2_B4
+    name: PPHGNetV2_B4_Formula
     class_num: 1024
   Head:


@@ -0,0 +1,135 @@
Global:
  debug: false
  use_gpu: true
  epoch_num: 75
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/PP-OCRv5_server_rec
  save_epoch_step: 1
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  calc_epoch_interval: 1
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ./ppocr/utils/dict/ppocrv5_dict.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv5.txt
  d2s_train_image_shape: [3, 48, 320]

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.0005
    warmup_epoch: 1
  regularizer:
    name: L2
    factor: 3.0e-05

Architecture:
  model_type: rec
  algorithm: SVTR_HGNet
  Transform:
  Backbone:
    name: PPHGNetV2_B4
    text_rec: True
  Head:
    name: MultiHead
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 120
            depth: 2
            hidden_dims: 120
            kernel_size: [1, 3]
            use_guide: True
          Head:
            fc_decay: 0.00001
      - NRTRHead:
          nrtr_dim: 384
          max_text_length: *max_text_length

Loss:
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - NRTRLoss:

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: MultiScaleDataSet
    ds_width: false
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
      - ./train_data/train_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - RecAug:
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  sampler:
    name: MultiScaleSampler
    scales: [[320, 32], [320, 48], [320, 64]]
    first_bs: &bs 64
    fix_bs: false
    divided_factor: [8, 16] # w, h
    is_training: True
  loader:
    shuffle: true
    batch_size_per_card: *bs
    drop_last: true
    num_workers: 16

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/val_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - RecResizeImg:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
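This new config wires the PPHGNetV2_B4 backbone (with text_rec: True) into an SVTR_HGNet recognition model with a combined CTC + NRTR multi-head; training is launched through PaddleOCR's usual entry point, e.g. python3 tools/train.py -c <path to this config>. Below is a minimal sketch, not part of the commit, that loads the YAML and shows how the &max_text_length/*max_text_length and &bs/*bs anchors resolve; the local filename is an assumption.

# Illustrative sketch only (not part of this commit). Assumes the file above
# has been saved locally as "PP-OCRv5_server_rec.yml"; the in-repo path may differ.
import yaml

with open("PP-OCRv5_server_rec.yml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# YAML anchors/aliases resolve to plain values after loading
print(cfg["Global"]["max_text_length"])               # 25
print(cfg["Architecture"]["Backbone"])                # {'name': 'PPHGNetV2_B4', 'text_rec': True}
print(cfg["Train"]["loader"]["batch_size_per_card"])  # 64, aliased from first_bs
nrtr_head = cfg["Architecture"]["Head"]["head_list"][1]["NRTRHead"]
print(nrtr_head["max_text_length"])                   # 25, aliased from Global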


@@ -71,7 +71,7 @@ def build_backbone(config, model_type):
         from .rec_repvit import RepSVTR
         from .rec_svtrv2 import SVTRv2
         from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
-        from .rec_pphgnetv2 import PPHGNetV2_B4
+        from .rec_pphgnetv2 import PPHGNetV2_B4, PPHGNetV2_B4_Formula

         support_dict = [
             "MobileNetV1Enhance",
@@ -101,6 +101,7 @@ def build_backbone(config, model_type):
             "DonutSwinModel",
             "Vary_VIT_B",
             "PPHGNetV2_B4",
+            "PPHGNetV2_B4_Formula",
             "Vary_VIT_B_Formula",
         ]
     elif model_type == "e2e":
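Registering the new name in support_dict is what lets a config's Backbone section select it by string. A minimal usage sketch, not part of the commit, assuming build_backbone keeps its usual pattern of popping "name" from the config dict and forwarding the remaining keys to the constructor:

# Sketch only: exercising the registration the way the architecture builder does,
# assuming the remaining config keys are passed straight to the backbone class.
from ppocr.modeling.backbones import build_backbone

# mirrors the "Backbone" section of the new PP-OCRv5 recognition config
rec_backbone = build_backbone({"name": "PPHGNetV2_B4", "text_rec": True}, model_type="rec")

# the renamed formula backbone is looked up the same way (cf. the formula config hunk above)
formula_backbone = build_backbone(
    {"name": "PPHGNetV2_B4_Formula", "class_num": 1024}, model_type="rec"
)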


@@ -1061,7 +1061,13 @@ class StemBlock(TheseusLayer):
     """

     def __init__(
-        self, in_channels, mid_channels, out_channels, use_lab=False, lr_mult=1.0
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        use_lab=False,
+        lr_mult=1.0,
+        text_rec=False,
     ):
         super().__init__()
         self.stem1 = ConvBNAct(
@@ -1094,7 +1100,7 @@ class StemBlock(TheseusLayer):
             in_channels=mid_channels * 2,
             out_channels=mid_channels,
             kernel_size=3,
-            stride=2,
+            stride=1 if text_rec else 2,
             use_lab=use_lab,
             lr_mult=lr_mult,
         )
@@ -1230,6 +1236,7 @@ class HGV2_Stage(TheseusLayer):
         light_block=True,
         kernel_size=3,
         use_lab=False,
+        stride=2,
         lr_mult=1.0,
     ):
@@ -1240,7 +1247,7 @@ class HGV2_Stage(TheseusLayer):
                 in_channels=in_channels,
                 out_channels=in_channels,
                 kernel_size=3,
-                stride=2,
+                stride=stride,
                 groups=in_channels,
                 use_act=False,
                 use_lab=use_lab,
@@ -1298,13 +1305,20 @@ class PPHGNetV2(TheseusLayer):
         dropout_prob=0.0,
         class_num=1000,
         lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+        det=False,
+        text_rec=False,
+        out_indices=None,
         **kwargs,
     ):
         super().__init__()
+        self.det = det
+        self.text_rec = text_rec
         self.use_lab = use_lab
         self.use_last_conv = use_last_conv
         self.class_expand = class_expand
         self.class_num = class_num
+        self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
+        self.out_channels = []

         # stem
         self.stem = StemBlock(
@@ -1313,6 +1327,7 @@ class PPHGNetV2(TheseusLayer):
             out_channels=stem_channels[2],
             use_lab=use_lab,
             lr_mult=lr_mult_list[0],
+            text_rec=text_rec,
         )

         # stages
@@ -1327,6 +1342,7 @@ class PPHGNetV2(TheseusLayer):
                 light_block,
                 kernel_size,
                 layer_num,
+                stride,
             ) = stage_config[k]
             self.stages.append(
                 HGV2_Stage(
@@ -1339,9 +1355,14 @@ class PPHGNetV2(TheseusLayer):
                     light_block,
                     kernel_size,
                     use_lab,
+                    stride,
                     lr_mult=lr_mult_list[i + 1],
                 )
             )
+            if i in self.out_indices:
+                self.out_channels.append(out_channels)
+        if not self.det:
+            self.out_channels = stage_config["stage4"][2]

         self.avg_pool = AdaptiveAvgPool2D(1)
@@ -1378,8 +1399,19 @@
     def forward(self, x):
         x = self.stem(x)
-        for stage in self.stages:
+        out = []
+        for i, stage in enumerate(self.stages):
             x = stage(x)
+            if self.det and i in self.out_indices:
+                out.append(x)
+        if self.det:
+            return out
+
+        if self.text_rec:
+            if self.training:
+                x = F.adaptive_avg_pool2d(x, [1, 40])
+            else:
+                x = F.avg_pool2d(x, [3, 2])
+            return x
@@ -1479,6 +1511,42 @@ def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
     return model


+def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
+    """
+    PPHGNetV2_B4
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+        If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
+    """
+    stage_config_rec = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
+        "stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
+        "stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
+        "stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
+        "stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
+    }
+    stage_config_det = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
+        "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
+        "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
+        "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
+    }
+    model = PPHGNetV2(
+        stem_channels=[3, 32, 48],
+        stage_config=stage_config_det if det else stage_config_rec,
+        use_lab=False,
+        det=det,
+        text_rec=text_rec,
+        **kwargs,
+    )
+    return model
+
+
 def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
     """
     PPHGNetV2_B5
@@ -1527,7 +1595,7 @@ def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
     return model


-class PPHGNetV2_B4(nn.Layer):
+class PPHGNetV2_B4_Formula(nn.Layer):
     """
     PPHGNetV2_B4
     Args:
@@ -1543,10 +1611,10 @@ class PPHGNetV2_B4(nn.Layer):
         self.out_channels = 2048
         stage_config = {
             # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
-            "stage1": [48, 48, 128, 1, False, False, 3, 6],
-            "stage2": [128, 96, 512, 1, True, False, 3, 6],
-            "stage3": [512, 192, 1024, 3, True, True, 5, 6],
-            "stage4": [1024, 384, 2048, 1, True, True, 5, 6],
+            "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
+            "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
+            "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
+            "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
         }
         self.pphgnet_b4 = PPHGNetV2(
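The recognition stage configs use asymmetric strides ([2, 1] halves only the height, [1, 2] only the width), and with text_rec the stem downsamples only once (its second stride-2 conv becomes stride 1), so a 3x48x320 crop should leave stage4 as roughly an [N, 2048, 3, 80] feature map that the new forward branch pools into a 40-step sequence. A plain-Paddle shape sketch, not code from this commit; the shapes are inferred from the strides above:

# Shape sketch only (not from the PR). The [1, 2048, 3, 80] feature shape is an
# inference from the stage strides for a 3x48x320 input, not a printed result.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# a stride of [2, 1] halves the height but keeps the width, as in "stage1"
x = paddle.randn([1, 128, 24, 160])
down = nn.Conv2D(128, 128, kernel_size=3, stride=[2, 1], padding=1, groups=128)
print(down(x).shape)  # [1, 128, 12, 160]

# the text_rec branch of PPHGNetV2.forward collapses the final feature map
# into a 40-step sequence in both training and inference
feat = paddle.randn([1, 2048, 3, 80])
print(F.adaptive_avg_pool2d(feat, [1, 40]).shape)  # [1, 2048, 1, 40] (training)
print(F.avg_pool2d(feat, [3, 2]).shape)            # [1, 2048, 1, 40] (inference)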

File diff suppressed because it is too large.


@@ -9,7 +9,7 @@ import pytest
 current_dir = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
 from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModel, DonutSwinModelOutput
-from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4
+from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4_Formula
 from ppocr.modeling.backbones.rec_vary_vit import Vary_VIT_B_Formula
 from ppocr.modeling.heads.rec_unimernet_head import UniMERNetHead
 from ppocr.modeling.heads.rec_ppformulanet_head import PPFormulaNet_Head
@@ -106,7 +106,7 @@ def test_ppformulanet_s_backbone(sample_image_ppformulanet_s):
     Args:
         sample_image_ppformulanet_s: sample image to be processed.
     """
-    backbone = PPHGNetV2_B4(
+    backbone = PPHGNetV2_B4_Formula(
         class_num=1024,
     )
     backbone.eval()
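The test changes above only follow the formula-backbone rename. A hypothetical companion test for the new text-recognition path could look like the sketch below; the test name and the expected [1, 2048, 1, 40] output shape are assumptions inferred from the backbone changes, not part of this commit.

# Hypothetical companion test (sketch only, not included in this commit).
# Assumes PPHGNetV2_B4(text_rec=True) can be built standalone with defaults and
# that a 3x48x320 crop yields a [1, 2048, 1, 40] feature map in eval mode.
import paddle
from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4


def test_pphgnetv2_b4_text_rec_backbone():
    backbone = PPHGNetV2_B4(text_rec=True)
    backbone.eval()
    dummy = paddle.randn([1, 3, 48, 320])
    with paddle.no_grad():
        feats = backbone(dummy)
    assert list(feats.shape) == [1, 2048, 1, 40]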