[Feature] Add PREN Scene Text Recognition Model(Accepted in CVPR2021) (#5563)
* [Feature] add PREN scene text recognition model * [Patch] Optimize yml File * [Patch] Save Label/Pred Preprocess Time Cost * [BugFix] Modify Shape Conversion to Fit for Inference Model Exportion * [Patch] ? * [Patch] ? * 啥情况...pull/5368/head^2
parent
3df64d502b
commit
6e607a0fa1
|
@ -0,0 +1,92 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 8
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/rec/pren_new
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
max_text_length: &max_text_length 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_pren.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adadelta
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [2, 5, 7]
|
||||
values: [0.5, 0.1, 0.01, 0.001]
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: PREN
|
||||
in_channels: 3
|
||||
Backbone:
|
||||
name: EfficientNetb3_PREN
|
||||
Neck:
|
||||
name: PRENFPN
|
||||
n_r: 5
|
||||
d_model: 384
|
||||
max_len: *max_text_length
|
||||
dropout: 0.1
|
||||
Head:
|
||||
name: PRENHead
|
||||
|
||||
Loss:
|
||||
name: PRENLoss
|
||||
|
||||
PostProcess:
|
||||
name: PRENLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- PRENLabelEncode:
|
||||
- RecAug:
|
||||
- PRENResizeImg:
|
||||
image_shape: [64, 256] # h,w
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 128
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- PRENLabelEncode:
|
||||
- PRENResizeImg:
|
||||
image_shape: [64, 256] # h,w
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 64
|
||||
num_workers: 8
|
|
@ -22,7 +22,8 @@ from .make_shrink_map import MakeShrinkMap
|
|||
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
||||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .ColorJitter import ColorJitter
|
||||
|
|
|
@ -785,6 +785,53 @@ class SARLabelEncode(BaseRecLabelEncode):
|
|||
return [self.padding_idx]
|
||||
|
||||
|
||||
class PRENLabelEncode(BaseRecLabelEncode):
    """Turn a text label into a fixed-length index sequence for PREN.

    The vocabulary is prefixed with three special tokens:
    ``<PAD>`` (index 0), ``<EOS>`` (index 1) and ``<UNK>`` (index 2).
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path,
                 use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super(PRENLabelEncode, self).__init__(
            max_text_length, character_dict_path, use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the special tokens prepended.

        Also records the indices of the special tokens on the instance.
        """
        self.padding_idx, self.end_idx, self.unknown_idx = 0, 1, 2
        special_tokens = ['<PAD>', '<EOS>', '<UNK>']
        return special_tokens + dict_character

    def encode(self, text):
        """Encode ``text`` as a list of ``self.max_text_len`` indices.

        Returns ``None`` when the text is empty or has at least
        ``self.max_text_len`` characters (one slot is needed for ``<EOS>``).
        Characters missing from ``self.dict`` map to the ``<UNK>`` index.
        """
        text_len = len(text)
        if text_len == 0 or text_len >= self.max_text_len:
            return None
        if self.lower:
            text = text.lower()

        encoded = [
            self.dict[ch] if ch in self.dict else self.unknown_idx
            for ch in text
        ]
        encoded.append(self.end_idx)

        # Right-pad with <PAD> up to the fixed sequence length.
        pad_count = self.max_text_len - len(encoded)
        if pad_count > 0:
            encoded.extend([self.padding_idx] * pad_count)
        return encoded

    def __call__(self, data):
        """Replace ``data['label']`` with its encoded numpy array.

        Returns ``None`` when the label cannot be encoded.
        """
        label_indices = self.encode(data['label'])
        if label_indices is None:
            return None
        data['label'] = np.array(label_indices)
        return data
|
||||
|
||||
|
||||
class VQATokenLabelEncode(object):
|
||||
"""
|
||||
Label encode for NLP VQA methods
|
||||
|
|
|
@ -141,6 +141,25 @@ class SARRecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class PRENResizeImg(object):
    """Resize an image to a fixed size and normalize it for PREN."""

    def __init__(self, image_shape, **kwargs):
        """
        Following the original paper's implementation, this is a hard resize
        (the aspect ratio is not preserved), so you may want to adapt it to
        fit your own task better.

        Args:
            image_shape: target size as ``(height, width)``.
        """
        target_h, target_w = image_shape
        self.dst_h = target_h
        self.dst_w = target_w

    def __call__(self, data):
        """Resize ``data['image']``, move channels first, scale to [-1, 1].

        The result is stored back into ``data['image']`` as float32.
        """
        src = data['image']
        # cv2.resize takes the target size as (width, height).
        resized = cv2.resize(
            src, (self.dst_w, self.dst_h), interpolation=cv2.INTER_LINEAR)
        # HWC -> CHW, then map [0, 255] to [-1, 1].
        chw = resized.transpose((2, 0, 1)) / 255
        normalized = (chw - 0.5) / 0.5
        data['image'] = normalized.astype(np.float32)
        return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
|
|
|
@ -32,6 +32,7 @@ from .rec_srn_loss import SRNLoss
|
|||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
from .rec_aster_loss import AsterLoss
|
||||
from .rec_pren_loss import PRENLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
@ -58,7 +59,7 @@ def build_loss(config):
|
|||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
|
||||
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
|
||||
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput'
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class PRENLoss(nn.Layer):
    """Cross-entropy loss for PREN that ignores padded label positions."""

    def __init__(self, **kwargs):
        super(PRENLoss, self).__init__()
        # Label index 0 is the padding token, so it is excluded from the loss.
        self.loss_func = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')

    def forward(self, predicts, batch):
        """Compute the loss between ``predicts`` and the labels in ``batch[1]``.

        Returns:
            dict with a single key ``'loss'``.
        """
        targets = batch[1].astype('int64')
        return {'loss': self.loss_func(predicts, targets)}
|
|
@ -30,9 +30,10 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_31 import ResNet31
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
from .rec_micronet import MicroNet
|
||||
from .rec_efficientb3_pren import EfficientNetb3_PREN
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
"ResNet31", "ResNet_ASTER", 'MicroNet'
|
||||
"ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN'
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
from collections import namedtuple
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
# Must name a class that actually exists in this module; a stale name here
# makes `from <module> import *` raise AttributeError.
__all__ = ['EfficientNetb3_PREN']
|
||||
|
||||
|
||||
class EffB3Params:
    """Static hyper-parameter tables for the EfficientNet-B3 backbone."""

    @staticmethod
    def get_global_params():
        """
        Return the network-wide EfficientNet-B3 hyper-parameters.

        These are the standard B3 settings, except that the resolution
        (image_size) is changed from 300 to 64 to fit the scene text
        recognition task.
        """
        field_names = ('drop_connect_rate', 'width_coefficient',
                       'depth_coefficient', 'depth_divisor', 'image_size')
        GlobalParams = namedtuple('GlobalParams', field_names)
        return GlobalParams(
            drop_connect_rate=0.3,
            width_coefficient=1.2,
            depth_coefficient=1.4,
            depth_divisor=8,
            image_size=64)

    @staticmethod
    def get_block_params():
        """Return the per-stage block settings as a list of namedtuples.

        The filter and repeat counts listed here are the unscaled values.
        """
        field_names = ('kernel_size', 'num_repeat', 'input_filters',
                       'output_filters', 'expand_ratio', 'id_skip',
                       'se_ratio', 'stride')
        BlockParams = namedtuple('BlockParams', field_names)
        # One row per stage, in the same order as `field_names`.
        stage_table = (
            (3, 1, 32, 16, 1, True, 0.25, 1),
            (3, 2, 16, 24, 6, True, 0.25, 2),
            (5, 2, 24, 40, 6, True, 0.25, 2),
            (3, 3, 40, 80, 6, True, 0.25, 2),
            (5, 3, 80, 112, 6, True, 0.25, 1),
            (5, 4, 112, 192, 6, True, 0.25, 2),
            (3, 1, 192, 320, 6, True, 0.25, 1),
        )
        return [BlockParams(*row) for row in stage_table]
|
||||
|
||||
|
||||
class EffUtils:
    """Helpers that scale filter and repeat counts by the model coefficients."""

    @staticmethod
    def round_filters(filters, global_params):
        """Scale a filter count by the width coefficient and round it.

        The scaled value is rounded to the nearest multiple of
        ``global_params.depth_divisor``; if rounding would lose more than
        10% of the scaled value, one more divisor is added. When the width
        coefficient is falsy, ``filters`` is returned unchanged.
        """
        width_mult = global_params.width_coefficient
        if not width_mult:
            return filters
        divisor = global_params.depth_divisor
        scaled = filters * width_mult
        rounded = int(scaled + divisor / 2) // divisor * divisor
        if rounded < 0.9 * scaled:
            rounded += divisor
        return int(rounded)

    @staticmethod
    def round_repeats(repeats, global_params):
        """Scale a block repeat count by the depth coefficient, rounding up.

        When the depth coefficient is falsy, ``repeats`` is returned unchanged.
        """
        depth_mult = global_params.depth_coefficient
        if depth_mult:
            return int(math.ceil(depth_mult * repeats))
        return repeats
|
||||
|
||||
|
||||
class ConvBlock(nn.Layer):
    """One EfficientNet building block.

    The block applies, in order: an optional 1x1 expansion conv, a depthwise
    conv, an optional squeeze-and-excitation gate, and a 1x1 projection conv.
    A residual connection is added when the input and output shapes match.

    Args:
        block_params: namedtuple-like object providing ``kernel_size``,
            ``input_filters``, ``output_filters``, ``expand_ratio``,
            ``id_skip``, ``se_ratio`` and ``stride``.
    """

    def __init__(self, block_params):
        super(ConvBlock, self).__init__()
        self.block_args = block_params
        # Squeeze-and-excitation is enabled only for a ratio in (0, 1].
        self.has_se = (self.block_args.se_ratio is not None) and \
            (0 < self.block_args.se_ratio <= 1)
        self.id_skip = block_params.id_skip

        # expansion phase
        self.input_filters = self.block_args.input_filters
        # NOTE: from here on `output_filters` is the *expanded* channel count,
        # not the block's final output width (that is `self.final_oup`).
        output_filters = \
            self.block_args.input_filters * self.block_args.expand_ratio
        # With expand_ratio == 1 no expansion layers are created at all.
        if self.block_args.expand_ratio != 1:
            self.expand_conv = nn.Conv2D(
                self.input_filters, output_filters, 1, bias_attr=False)
            self.bn0 = nn.BatchNorm(output_filters)

        # depthwise conv phase
        k = self.block_args.kernel_size
        s = self.block_args.stride
        # groups == channels makes this a depthwise convolution.
        self.depthwise_conv = nn.Conv2D(
            output_filters,
            output_filters,
            groups=output_filters,
            kernel_size=k,
            stride=s,
            padding='same',
            bias_attr=False)
        self.bn1 = nn.BatchNorm(output_filters)

        # squeeze and excitation layer, if desired
        if self.has_se:
            # The bottleneck width is derived from the block's input filters
            # (not the expanded count) and is at least 1.
            num_squeezed_channels = max(1,
                                        int(self.block_args.input_filters *
                                            self.block_args.se_ratio))
            self.se_reduce = nn.Conv2D(output_filters, num_squeezed_channels, 1)
            self.se_expand = nn.Conv2D(num_squeezed_channels, output_filters, 1)

        # output phase
        self.final_oup = self.block_args.output_filters
        self.project_conv = nn.Conv2D(
            output_filters, self.final_oup, 1, bias_attr=False)
        self.bn2 = nn.BatchNorm(self.final_oup)
        self.swish = nn.Swish()

    def drop_connect(self, inputs, p, training):
        """Randomly zero whole samples of ``inputs`` with probability ``p``.

        Outside training the input is returned untouched. During training
        the kept samples are divided by ``1 - p``.

        Args:
            inputs: tensor whose first dimension is the batch.
            p: drop probability.
            training: whether the layer is in training mode.
        """
        if not training:
            return inputs

        batch_size = inputs.shape[0]
        keep_prob = 1 - p
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0,
        # giving one keep/drop decision per sample.
        random_tensor = keep_prob
        random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
        # NOTE(review): re-wrapping appears intended to place the mask on the
        # same device as `inputs` -- confirm it is still needed.
        random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
        binary_tensor = paddle.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output

    def forward(self, inputs, drop_connect_rate=None):
        """Run the block on ``inputs``.

        Args:
            inputs: input feature map.
            drop_connect_rate: optional drop probability applied to the
                block output before the residual addition; ignored when
                falsy or when the block has no residual connection.
        """
        # expansion and depthwise conv
        x = inputs
        if self.block_args.expand_ratio != 1:
            x = self.swish(self.bn0(self.expand_conv(inputs)))
        x = self.swish(self.bn1(self.depthwise_conv(x)))

        # squeeze and excitation
        if self.has_se:
            # Global average pool -> reduce -> expand -> sigmoid channel gate.
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self.se_expand(self.swish(self.se_reduce(x_squeezed)))
            x = F.sigmoid(x_squeezed) * x
        x = self.bn2(self.project_conv(x))

        # skip connection and drop connect
        # The residual is only valid when stride and channel count leave the
        # output with the same shape as the input.
        if self.id_skip and self.block_args.stride == 1 and \
                self.input_filters == self.final_oup:
            if drop_connect_rate:
                x = self.drop_connect(
                    x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x
|
||||
|
||||
|
||||
class EfficientNetb3_PREN(nn.Layer):
    """EfficientNet-B3 backbone that returns three intermediate feature maps.

    The outputs of the blocks whose 0-based positions are listed in
    ``concerned_block_idxes`` are collected by ``forward``, and their channel
    counts are exposed, in the same order, as ``out_channels``.

    Args:
        in_channels: number of channels of the input image.
    """

    def __init__(self, in_channels):
        super(EfficientNetb3_PREN, self).__init__()
        self.blocks_params = EffB3Params.get_block_params()
        self.global_params = EffB3Params.get_global_params()
        self.out_channels = []
        # stem
        stem_channels = EffUtils.round_filters(32, self.global_params)
        self.conv_stem = nn.Conv2D(
            in_channels, stem_channels, 3, 2, padding='same', bias_attr=False)
        self.bn0 = nn.BatchNorm(stem_channels)

        self.blocks = []
        # to extract three feature maps for fpn based on efficientnetb3
        # backbone; these are 0-based positions in `self.blocks`, the same
        # convention `forward` uses.
        self.concerned_block_idxes = [7, 17, 25]
        for i, block_params in enumerate(self.blocks_params):
            # Scale the stage's widths and depth by the model coefficients.
            block_params = block_params._replace(
                input_filters=EffUtils.round_filters(block_params.input_filters,
                                                     self.global_params),
                output_filters=EffUtils.round_filters(
                    block_params.output_filters, self.global_params),
                num_repeat=EffUtils.round_repeats(block_params.num_repeat,
                                                  self.global_params))
            self._append_block("{}-0".format(i), block_params)
            # Only the first block of a stage changes channels / resolution;
            # the repeats keep the output width and use stride 1.
            if block_params.num_repeat > 1:
                block_params = block_params._replace(
                    input_filters=block_params.output_filters, stride=1)
            for j in range(block_params.num_repeat - 1):
                self._append_block('{}-{}'.format(i, j + 1), block_params)

        self.swish = nn.Swish()

    def _append_block(self, name, block_params):
        """Create, register and store one ConvBlock.

        If the block's 0-based position is one of ``concerned_block_idxes``,
        its output width is recorded in ``out_channels``. Previously a
        1-based running count was compared against the same list, which only
        matched ``forward`` by coincidence.
        """
        block_idx = len(self.blocks)
        self.blocks.append(self.add_sublayer(name, ConvBlock(block_params)))
        if block_idx in self.concerned_block_idxes:
            self.out_channels.append(block_params.output_filters)

    def forward(self, inputs):
        """Return the list of feature maps taken at ``concerned_block_idxes``."""
        outs = []

        x = self.swish(self.bn0(self.conv_stem(inputs)))
        for idx, block in enumerate(self.blocks):
            drop_connect_rate = self.global_params.drop_connect_rate
            if drop_connect_rate:
                # The drop rate grows linearly with the block's depth.
                drop_connect_rate *= float(idx) / len(self.blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if idx in self.concerned_block_idxes:
                outs.append(x)
        return outs
|
|
@ -30,6 +30,7 @@ def build_head(config):
|
|||
from .rec_nrtr_head import Transformer
|
||||
from .rec_sar_head import SARHead
|
||||
from .rec_aster_head import AsterHead
|
||||
from .rec_pren_head import PRENHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
|
@ -42,7 +43,7 @@ def build_head(config):
|
|||
support_dict = [
|
||||
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
|
||||
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead'
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
class PRENHead(nn.Layer):
    """Linear classification head for PREN.

    Args:
        in_channels: size of the last dimension of the input features.
        out_channels: number of output classes.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PRENHead, self).__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, targets=None):
        """Project ``x`` to class scores.

        Raw scores are returned in training mode; otherwise a softmax over
        axis 2 is applied. ``targets`` is accepted but unused.
        """
        logits = self.linear(x)
        if self.training:
            return logits
        return F.softmax(logits, axis=2)
|
|
@ -23,7 +23,11 @@ def build_neck(config):
|
|||
from .pg_fpn import PGFPN
|
||||
from .table_fpn import TableFPN
|
||||
from .fpn import FPN
|
||||
support_dict = ['FPN','DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||
from .pren_fpn import PRENFPN
|
||||
support_dict = [
|
||||
'FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN',
|
||||
'TableFPN', 'PRENFPN'
|
||||
]
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is adapted from:
|
||||
https://github.com/RuijieJ/pren/blob/main/Nets/Aggregation.py
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class PoolAggregate(nn.Layer):
    """Aggregate a feature map into ``n_r`` vectors with pooled conv branches.

    Each of the ``n_r`` branches is conv-bn-swish-conv-bn followed by global
    average pooling, producing one ``d_out``-dimensional vector per branch.

    Args:
        n_r: number of branches (output vectors).
        d_in: input channel count.
        d_middle: hidden channel count; defaults to ``d_in``.
        d_out: output channel count; defaults to ``d_in``.
    """

    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(PoolAggregate, self).__init__()
        self.d_in = d_in
        self.d_middle = d_middle if d_middle else d_in
        self.d_out = d_out if d_out else d_in
        # A single activation instance is shared by every branch.
        self.act = nn.Swish()

        self.n_r = n_r
        self.aggs = self._build_aggs()

    def _build_aggs(self):
        """Build and register the ``n_r`` branches; return them as a list."""
        branches = []
        for branch_idx in range(self.n_r):
            stages = (
                ('conv1', nn.Conv2D(
                    self.d_in, self.d_middle, 3, 2, 1, bias_attr=False)),
                ('bn1', nn.BatchNorm(self.d_middle)),
                ('act', self.act),
                ('conv2', nn.Conv2D(
                    self.d_middle, self.d_out, 3, 2, 1, bias_attr=False)),
                ('bn2', nn.BatchNorm(self.d_out)),
            )
            branch = self.add_sublayer(
                '{}'.format(branch_idx), nn.Sequential(*stages))
            branches.append(branch)
        return branches

    def forward(self, x):
        """Return the branch outputs concatenated along axis 1."""
        batch = x.shape[0]
        pooled = []
        for branch in self.aggs:
            vec = F.adaptive_avg_pool2d(branch(x), 1)
            pooled.append(vec.reshape((batch, 1, self.d_out)))
        return paddle.concat(pooled, 1)
|
||||
|
||||
|
||||
class WeightAggregate(nn.Layer):
    """Aggregate a feature map into ``n_r`` vectors with learned spatial weights.

    One conv stack predicts ``n_r`` sigmoid-gated spatial maps and another
    predicts ``d_out``-channel features; the output is their batched matrix
    product over the flattened spatial positions.

    Args:
        n_r: number of spatial maps (output vectors).
        d_in: input channel count.
        d_middle: hidden channel count of the feature stack; defaults to ``d_in``.
        d_out: output channel count; defaults to ``d_in``.
    """

    def __init__(self, n_r, d_in, d_middle=None, d_out=None):
        super(WeightAggregate, self).__init__()
        d_middle = d_middle or d_in
        d_out = d_out or d_in

        self.n_r = n_r
        self.d_out = d_out
        # The same activation instance is reused by both stacks.
        self.act = nn.Swish()

        weight_map_stages = [
            ('conv1', nn.Conv2D(d_in, d_in, 3, 1, 1, bias_attr=False)),
            ('bn1', nn.BatchNorm(d_in)),
            ('act1', self.act),
            ('conv2', nn.Conv2D(d_in, n_r, 1, bias_attr=False)),
            ('bn2', nn.BatchNorm(n_r)),
            ('act2', nn.Sigmoid()),
        ]
        self.conv_n = nn.Sequential(*weight_map_stages)

        feature_stages = [
            ('conv1', nn.Conv2D(d_in, d_middle, 3, 1, 1, bias_attr=False)),
            ('bn1', nn.BatchNorm(d_middle)),
            ('act1', self.act),
            ('conv2', nn.Conv2D(d_middle, d_out, 1, bias_attr=False)),
            ('bn2', nn.BatchNorm(d_out)),
        ]
        self.conv_d = nn.Sequential(*feature_stages)

    def forward(self, x):
        """Return the product of the weight maps and the features.

        The result has one ``d_out``-sized row per weight map.
        """
        batch, _, height, width = x.shape
        num_positions = height * width

        weight_maps = self.conv_n(x).reshape((batch, self.n_r, num_positions))
        features = self.conv_d(x).reshape((batch, self.d_out, num_positions))
        return paddle.bmm(weight_maps, features.transpose((0, 2, 1)))
|
||||
|
||||
|
||||
class GCN(nn.Layer):
    """Mix information across nodes, then across feature channels.

    A kernel-size-1 ``Conv1D`` treats axis 1 of the input as its channel
    axis, mapping ``n_in`` nodes to ``n_out`` nodes; a ``Linear`` layer then
    maps the last axis from ``d_in`` to ``d_out``, followed by dropout and
    Swish.

    Args:
        d_in: size of the last (feature) axis of the input.
        n_in: size of axis 1 (node axis) of the input.
        d_out: output feature size; defaults to ``d_in``.
        n_out: output node count; defaults to ``n_in``.
        dropout: dropout probability applied after the linear layer.
    """

    def __init__(self, d_in, n_in, d_out=None, n_out=None, dropout=0.1):
        super(GCN, self).__init__()
        if not d_out:
            d_out = d_in
        if not n_out:
            # Default the node count to the input node count. This previously
            # fell back to `d_in`, mixing up the node and feature dimensions.
            n_out = n_in

        self.conv_n = nn.Conv1D(n_in, n_out, 1)
        self.linear = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.Swish()

    def forward(self, x):
        """Apply node mixing, the linear projection, dropout and Swish."""
        x = self.conv_n(x)
        x = self.dropout(self.linear(x))
        return self.act(x)
|
||||
|
||||
|
||||
class PRENFPN(nn.Layer):
    """PREN neck: aggregate three backbone feature maps into a sequence.

    Each feature map is summarised twice (by a ``PoolAggregate`` and by a
    ``WeightAggregate``), each producing ``d_model // 3`` channels. The three
    scales are concatenated, passed through one ``GCN`` per aggregation
    type, and the two results are averaged.

    Args:
        in_channels: channel counts of the three input feature maps.
        n_r: number of vectors produced by each aggregator.
        d_model: output feature size; must be divisible by 3.
        max_len: output sequence length.
        dropout: dropout probability used inside the GCN layers.
    """

    def __init__(self, in_channels, n_r, d_model, max_len, dropout):
        super(PRENFPN, self).__init__()
        assert len(in_channels) == 3, "in_channels' length must be 3."
        c1, c2, c3 = in_channels
        # build fpn
        assert d_model % 3 == 0, "{} can't be divided by 3.".format(d_model)
        # Each scale contributes an equal third of the final feature size.
        per_scale_dim = d_model // 3

        self.agg_p1 = PoolAggregate(n_r, c1, d_out=per_scale_dim)
        self.agg_p2 = PoolAggregate(n_r, c2, d_out=per_scale_dim)
        self.agg_p3 = PoolAggregate(n_r, c3, d_out=per_scale_dim)

        self.agg_w1 = WeightAggregate(n_r, c1, 4 * c1, per_scale_dim)
        self.agg_w2 = WeightAggregate(n_r, c2, 4 * c2, per_scale_dim)
        self.agg_w3 = WeightAggregate(n_r, c3, 4 * c3, per_scale_dim)

        self.gcn_pool = GCN(d_model, n_r, d_model, max_len, dropout)
        self.gcn_weight = GCN(d_model, n_r, d_model, max_len, dropout)

        self.out_channels = d_model

    def forward(self, inputs):
        """Fuse the three feature maps in ``inputs`` into one sequence."""
        f3, f5, f7 = inputs

        pooled_parts = [self.agg_p1(f3), self.agg_p2(f5), self.agg_p3(f7)]
        weighted_parts = [self.agg_w1(f3), self.agg_w2(f5), self.agg_w3(f7)]

        pooled = paddle.concat(pooled_parts, 2)  # [b,nr,d]
        weighted = paddle.concat(weighted_parts, 2)  # [b,nr,d]

        # Average the two branches.
        fused = 0.5 * (self.gcn_pool(pooled) + self.gcn_weight(weighted))
        return fused  # [b,max_len,d]
|
|
@ -24,8 +24,9 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||
SEEDLabelDecode, PRENLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
|
||||
|
@ -39,7 +40,7 @@ def build_post_process(config, global_config=None):
|
|||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess'
|
||||
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -11,8 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import string
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
import re
|
||||
|
@ -652,3 +652,63 @@ class SARLabelDecode(BaseRecLabelDecode):
|
|||
|
||||
def get_ignored_tokens(self):
|
||||
return [self.padding_idx]
|
||||
|
||||
|
||||
class PRENLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self, character_dict_path=None, use_space_char=False,
                 **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super(PRENLabelDecode, self).__init__(character_dict_path,
                                              use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the special tokens prepended.

        The special tokens are ``<PAD>`` (0), ``<EOS>`` (1) and ``<UNK>`` (2);
        their indices are also recorded on the instance.
        """
        padding_str = '<PAD>'  # 0
        end_str = '<EOS>'  # 1
        unknown_str = '<UNK>'  # 2

        dict_character = [padding_str, end_str, unknown_str] + dict_character
        self.padding_idx = 0
        self.end_idx = 1
        self.unknown_idx = 2

        return dict_character

    def decode(self, text_index, text_prob=None):
        """ convert text-index into text-label.

        Args:
            text_index: batch of index sequences.
            text_prob: optional per-position confidences with the same
                layout as ``text_index``; when omitted every kept character
                gets confidence 1.

        Returns:
            A list with one ``(text, confidence)`` tuple per sample, where
            the confidence is the mean over the kept characters.
        """
        result_list = []
        skipped = (self.padding_idx, self.unknown_idx)

        for batch_idx, indices in enumerate(text_index):
            char_list = []
            conf_list = []
            for pos, char_idx in enumerate(indices):
                # Everything after the first <EOS> is ignored.
                if char_idx == self.end_idx:
                    break
                if char_idx in skipped:
                    continue
                char_list.append(self.character[int(char_idx)])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][pos])
                else:
                    conf_list.append(1)

            text = ''.join(char_list)
            if len(text) > 0:
                result_list.append((text, np.mean(conf_list)))
            else:
                # here confidence of empty recog result is 1
                result_list.append(('', 1))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode predictions and, if given, labels.

        Args:
            preds: class scores, either a tensor exposing ``.numpy()`` or an
                array that already supports ``argmax`` / ``max`` on axis 2.
            label: optional batch of ground-truth index sequences.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is provided.
        """
        # Only convert when needed: calling `.numpy()` on an array that is
        # already numpy would raise AttributeError.
        if hasattr(preds, 'numpy'):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob)
        if label is None:
            return text
        label = self.decode(label)
        return text, label
|
||||
|
|
|
@ -28,7 +28,6 @@ from ppocr.modeling.architectures import build_model
|
|||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import load_model
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
||||
|
|
|
@ -55,6 +55,12 @@ def export_single_model(model, arch_config, save_path, logger):
|
|||
shape=[None, 3, 48, 160], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "PREN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 64, 512], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
else:
|
||||
infer_shape = [3, -1, -1]
|
||||
if arch_config["model_type"] == "rec":
|
||||
|
|
|
@ -541,7 +541,7 @@ def preprocess(is_train=False):
|
|||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
|
||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN'
|
||||
]
|
||||
|
||||
device = 'cpu'
|
||||
|
|
Loading…
Reference in New Issue