parent
0cc9870eb3
commit
a38c087bcb
|
@ -39,7 +39,7 @@ Architecture:
|
|||
in_channels: 3
|
||||
Transform:
|
||||
Backbone:
|
||||
name: PPHGNetV2_B4
|
||||
name: PPHGNetV2_B4_Formula
|
||||
class_num: 1024
|
||||
|
||||
Head:
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 75
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/PP-OCRv5_server_rec
|
||||
save_epoch_step: 1
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
calc_epoch_interval: 1
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ./ppocr/utils/dict/ppocrv5_dict.txt
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_ppocrv5.txt
|
||||
d2s_train_image_shape: [3, 48, 320]
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0005
|
||||
warmup_epoch: 1
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 3.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SVTR_HGNet
|
||||
Transform:
|
||||
Backbone:
|
||||
name: PPHGNetV2_B4
|
||||
text_rec: True
|
||||
Head:
|
||||
name: MultiHead
|
||||
head_list:
|
||||
- CTCHead:
|
||||
Neck:
|
||||
name: svtr
|
||||
dims: 120
|
||||
depth: 2
|
||||
hidden_dims: 120
|
||||
kernel_size: [1, 3]
|
||||
use_guide: True
|
||||
Head:
|
||||
fc_decay: 0.00001
|
||||
- NRTRHead:
|
||||
nrtr_dim: 384
|
||||
max_text_length: *max_text_length
|
||||
|
||||
Loss:
|
||||
name: MultiLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
- NRTRLoss:
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: MultiScaleDataSet
|
||||
ds_width: false
|
||||
data_dir: ./train_data/
|
||||
ext_op_transform_idx: 1
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
sampler:
|
||||
name: MultiScaleSampler
|
||||
scales: [[320, 32], [320, 48], [320, 64]]
|
||||
first_bs: &bs 64
|
||||
fix_bs: false
|
||||
divided_factor: [8, 16] # w, h
|
||||
is_training: True
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: *bs
|
||||
drop_last: true
|
||||
num_workers: 16
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- MultiLabelEncode:
|
||||
gtc_encode: NRTRLabelEncode
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 48, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label_ctc
|
||||
- label_gtc
|
||||
- length
|
||||
- valid_ratio
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 4
|
|
@ -71,7 +71,7 @@ def build_backbone(config, model_type):
|
|||
from .rec_repvit import RepSVTR
|
||||
from .rec_svtrv2 import SVTRv2
|
||||
from .rec_vary_vit import Vary_VIT_B, Vary_VIT_B_Formula
|
||||
from .rec_pphgnetv2 import PPHGNetV2_B4
|
||||
from .rec_pphgnetv2 import PPHGNetV2_B4, PPHGNetV2_B4_Formula
|
||||
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance",
|
||||
|
@ -101,6 +101,7 @@ def build_backbone(config, model_type):
|
|||
"DonutSwinModel",
|
||||
"Vary_VIT_B",
|
||||
"PPHGNetV2_B4",
|
||||
"PPHGNetV2_B4_Formula",
|
||||
"Vary_VIT_B_Formula",
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
|
|
|
@ -1061,7 +1061,13 @@ class StemBlock(TheseusLayer):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_channels, mid_channels, out_channels, use_lab=False, lr_mult=1.0
|
||||
self,
|
||||
in_channels,
|
||||
mid_channels,
|
||||
out_channels,
|
||||
use_lab=False,
|
||||
lr_mult=1.0,
|
||||
text_rec=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.stem1 = ConvBNAct(
|
||||
|
@ -1094,7 +1100,7 @@ class StemBlock(TheseusLayer):
|
|||
in_channels=mid_channels * 2,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
stride=1 if text_rec else 2,
|
||||
use_lab=use_lab,
|
||||
lr_mult=lr_mult,
|
||||
)
|
||||
|
@ -1230,6 +1236,7 @@ class HGV2_Stage(TheseusLayer):
|
|||
light_block=True,
|
||||
kernel_size=3,
|
||||
use_lab=False,
|
||||
stride=2,
|
||||
lr_mult=1.0,
|
||||
):
|
||||
|
||||
|
@ -1240,7 +1247,7 @@ class HGV2_Stage(TheseusLayer):
|
|||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
stride=stride,
|
||||
groups=in_channels,
|
||||
use_act=False,
|
||||
use_lab=use_lab,
|
||||
|
@ -1298,13 +1305,20 @@ class PPHGNetV2(TheseusLayer):
|
|||
dropout_prob=0.0,
|
||||
class_num=1000,
|
||||
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
|
||||
det=False,
|
||||
text_rec=False,
|
||||
out_indices=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.det = det
|
||||
self.text_rec = text_rec
|
||||
self.use_lab = use_lab
|
||||
self.use_last_conv = use_last_conv
|
||||
self.class_expand = class_expand
|
||||
self.class_num = class_num
|
||||
self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
|
||||
self.out_channels = []
|
||||
|
||||
# stem
|
||||
self.stem = StemBlock(
|
||||
|
@ -1313,6 +1327,7 @@ class PPHGNetV2(TheseusLayer):
|
|||
out_channels=stem_channels[2],
|
||||
use_lab=use_lab,
|
||||
lr_mult=lr_mult_list[0],
|
||||
text_rec=text_rec,
|
||||
)
|
||||
|
||||
# stages
|
||||
|
@ -1327,6 +1342,7 @@ class PPHGNetV2(TheseusLayer):
|
|||
light_block,
|
||||
kernel_size,
|
||||
layer_num,
|
||||
stride,
|
||||
) = stage_config[k]
|
||||
self.stages.append(
|
||||
HGV2_Stage(
|
||||
|
@ -1339,9 +1355,14 @@ class PPHGNetV2(TheseusLayer):
|
|||
light_block,
|
||||
kernel_size,
|
||||
use_lab,
|
||||
stride,
|
||||
lr_mult=lr_mult_list[i + 1],
|
||||
)
|
||||
)
|
||||
if i in self.out_indices:
|
||||
self.out_channels.append(out_channels)
|
||||
if not self.det:
|
||||
self.out_channels = stage_config["stage4"][2]
|
||||
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
|
||||
|
@ -1378,8 +1399,19 @@ class PPHGNetV2(TheseusLayer):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.stem(x)
|
||||
for stage in self.stages:
|
||||
out = []
|
||||
for i, stage in enumerate(self.stages):
|
||||
x = stage(x)
|
||||
if self.det and i in self.out_indices:
|
||||
out.append(x)
|
||||
if self.det:
|
||||
return out
|
||||
|
||||
if self.text_rec:
|
||||
if self.training:
|
||||
x = F.adaptive_avg_pool2d(x, [1, 40])
|
||||
else:
|
||||
x = F.avg_pool2d(x, [3, 2])
|
||||
return x
|
||||
|
||||
|
||||
|
@ -1479,6 +1511,42 @@ def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
    """
    Build a `PPHGNetV2` network using the B4 stem and stage layout.

    Args:
        pretrained (bool/str): Documented upstream as a flag (or path) for loading
            pretrained parameters; this function body does not read it.
        use_ssld (bool): Documented upstream as selecting the ssld pretrained model;
            this function body does not read it.
        det (bool): When truthy, the detection stage table is selected; otherwise the
            recognition stage table is used. Also forwarded to `PPHGNetV2`.
        text_rec (bool): Forwarded unchanged to `PPHGNetV2`.
        **kwargs: Extra keyword arguments forwarded unchanged to `PPHGNetV2`.
    Returns:
        model: the constructed `PPHGNetV2` instance.
    """
    # Column order of every stage row:
    # in_channels, mid_channels, out_channels, num_blocks, is_downsample,
    # light_block, kernel_size, layer_num, stride
    if det:
        # Detection table: scalar stride of 2 in every stage; stage1 has
        # is_downsample=False.
        stage_table = {
            "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
            "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
            "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
            "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
        }
    else:
        # Recognition table: each stride is a two-element list, so the two
        # spatial axes are strided independently per stage.
        stage_table = {
            "stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
            "stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
            "stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
            "stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
        }

    return PPHGNetV2(
        stem_channels=[3, 32, 48],
        stage_config=stage_table,
        use_lab=False,
        det=det,
        text_rec=text_rec,
        **kwargs,
    )
|
||||
|
||||
|
||||
def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
|
||||
"""
|
||||
PPHGNetV2_B5
|
||||
|
@ -1527,7 +1595,7 @@ def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
class PPHGNetV2_B4(nn.Layer):
|
||||
class PPHGNetV2_B4_Formula(nn.Layer):
|
||||
"""
|
||||
PPHGNetV2_B4_Formula
|
||||
Args:
|
||||
|
@ -1543,10 +1611,10 @@ class PPHGNetV2_B4(nn.Layer):
|
|||
self.out_channels = 2048
|
||||
stage_config = {
|
||||
# in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
|
||||
"stage1": [48, 48, 128, 1, False, False, 3, 6],
|
||||
"stage2": [128, 96, 512, 1, True, False, 3, 6],
|
||||
"stage3": [512, 192, 1024, 3, True, True, 5, 6],
|
||||
"stage4": [1024, 384, 2048, 1, True, True, 5, 6],
|
||||
"stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
|
||||
"stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
|
||||
"stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
|
||||
"stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
|
||||
}
|
||||
|
||||
self.pphgnet_b4 = PPHGNetV2(
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -9,7 +9,7 @@ import pytest
|
|||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
|
||||
from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModel, DonutSwinModelOutput
|
||||
from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4
|
||||
from ppocr.modeling.backbones.rec_pphgnetv2 import PPHGNetV2_B4_Formula
|
||||
from ppocr.modeling.backbones.rec_vary_vit import Vary_VIT_B_Formula
|
||||
from ppocr.modeling.heads.rec_unimernet_head import UniMERNetHead
|
||||
from ppocr.modeling.heads.rec_ppformulanet_head import PPFormulaNet_Head
|
||||
|
@ -106,7 +106,7 @@ def test_ppformulanet_s_backbone(sample_image_ppformulanet_s):
|
|||
Args:
|
||||
sample_image_ppformulanet_s: sample image to be processed.
|
||||
"""
|
||||
backbone = PPHGNetV2_B4(
|
||||
backbone = PPHGNetV2_B4_Formula(
|
||||
class_num=1024,
|
||||
)
|
||||
backbone.eval()
|
||||
|
|
Loading…
Reference in New Issue