add text recognition algorithm rflearning
parent
6a8a0eeb6e
commit
0002349df3
|
@ -0,0 +1,113 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 6
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 50
|
||||||
|
save_model_dir: ./output/rec/rec_resnet_rfl_att/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 5000]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words_en/word_10.png
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
save_res_path: ./output/rec/rec_resnet_rfl.txt
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
weight_decay: 0.0
|
||||||
|
clip_norm_global: 5.0
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs : [3, 4, 5]
|
||||||
|
values : [0.001, 0.0003, 0.00009, 0.000027]
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RFL
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 1.0
|
||||||
|
model_name: large
|
||||||
|
Backbone:
|
||||||
|
name: ResNetRFL
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: True
|
||||||
|
Neck:
|
||||||
|
name: RFAdaptor
|
||||||
|
use_v2s: True
|
||||||
|
use_s2v: True
|
||||||
|
Head:
|
||||||
|
name: RFLHead
|
||||||
|
in_channels: 512
|
||||||
|
hidden_size: 256
|
||||||
|
batch_max_legnth: 25
|
||||||
|
out_channels: 38
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: True
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: RFLLoss
|
||||||
|
# ignore_index: 0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: RFLLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/rfl_dataset2/training
|
||||||
|
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
padding: false
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/rfl_dataset2/evaluation
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
padding: false
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 8
|
|
@ -0,0 +1,110 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 6
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 50
|
||||||
|
save_model_dir: ./output/rec/rec_resnet_rfl_visual/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 5000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words_en/word_10.png
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
save_res_path: ./output/rec/rec_resnet_rfl_visual.txt
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
weight_decay: 0.0
|
||||||
|
clip_norm_global: 5.0
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs : [3, 4, 5]
|
||||||
|
values : [0.001, 0.0003, 0.00009, 0.000027]
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RFL
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 1.0
|
||||||
|
model_name: large
|
||||||
|
Backbone:
|
||||||
|
name: ResNetRFL
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: False
|
||||||
|
Neck:
|
||||||
|
name: RFAdaptor
|
||||||
|
use_v2s: False
|
||||||
|
use_s2v: False
|
||||||
|
Head:
|
||||||
|
name: RFLHead
|
||||||
|
in_channels: 512
|
||||||
|
hidden_size: 256
|
||||||
|
batch_max_legnth: 25
|
||||||
|
out_channels: 38
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: False
|
||||||
|
Loss:
|
||||||
|
name: RFLLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: RFLLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: CNTMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/rfl_dataset2/training
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
padding: false
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/rfl_dataset2/evaluation
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
padding: false
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 8
|
|
@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt
|
||||||
|
|
||||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
|
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
|
||||||
|
RFLRecResizeImg
|
||||||
from .ssl_img_aug import SSLRotateResize
|
from .ssl_img_aug import SSLRotateResize
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .copy_paste import CopyPaste
|
from .copy_paste import CopyPaste
|
||||||
|
|
|
@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class RFLLabelEncode(BaseRecLabelEncode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_text_length,
|
||||||
|
character_dict_path=None,
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(RFLLabelEncode, self).__init__(
|
||||||
|
max_text_length, character_dict_path, use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def encode_cnt(self, text):
|
||||||
|
cnt_label = [0.0] * len(self.character)
|
||||||
|
for char_ in text:
|
||||||
|
cnt_label[char_] += 1
|
||||||
|
return np.array(cnt_label)
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
|
text = self.encode(text)
|
||||||
|
if text is None:
|
||||||
|
return None
|
||||||
|
if len(text) >= self.max_text_len:
|
||||||
|
return None
|
||||||
|
cnt_label = self.encode_cnt(text)
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||||
|
- len(text) - 2)
|
||||||
|
if len(text) != self.max_text_len:
|
||||||
|
return None
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
data['cnt_label'] = cnt_label
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
class SEEDLabelEncode(BaseRecLabelEncode):
|
class SEEDLabelEncode(BaseRecLabelEncode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
|
|
@ -237,6 +237,33 @@ class VLRecResizeImg(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class RFLRecResizeImg(object):
|
||||||
|
def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
self.interpolation = interpolation
|
||||||
|
if self.interpolation == 0:
|
||||||
|
self.interpolation = cv2.INTER_NEAREST
|
||||||
|
elif self.interpolation == 1:
|
||||||
|
self.interpolation = cv2.INTER_LINEAR
|
||||||
|
elif self.interpolation == 2:
|
||||||
|
self.interpolation = cv2.INTER_CUBIC
|
||||||
|
elif self.interpolation == 3:
|
||||||
|
self.interpolation = cv2.INTER_AREA
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported interpolation type !!!")
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
norm_img, valid_ratio = resize_norm_img(
|
||||||
|
img, self.image_shape, self.padding, self.interpolation)
|
||||||
|
data['image'] = norm_img
|
||||||
|
data['valid_ratio'] = valid_ratio
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class SRNRecResizeImg(object):
|
class SRNRecResizeImg(object):
|
||||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||||
self.image_shape = image_shape
|
self.image_shape = image_shape
|
||||||
|
@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
|
||||||
data['valid_ratio'] = valid_ratio
|
data['valid_ratio'] = valid_ratio
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class RobustScannerRecResizeImg(object):
|
class RobustScannerRecResizeImg(object):
|
||||||
def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
|
def __init__(self,
|
||||||
|
image_shape,
|
||||||
|
max_text_length,
|
||||||
|
width_downsample_ratio=0.25,
|
||||||
|
**kwargs):
|
||||||
self.image_shape = image_shape
|
self.image_shape = image_shape
|
||||||
self.width_downsample_ratio = width_downsample_ratio
|
self.width_downsample_ratio = width_downsample_ratio
|
||||||
self.max_text_length = max_text_length
|
self.max_text_length = max_text_length
|
||||||
|
@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
|
||||||
data['word_positons'] = word_positons
|
data['word_positons'] = word_positons
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
|
@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||||
|
|
||||||
|
|
||||||
def resize_norm_img(img, image_shape, padding=True):
|
def resize_norm_img(img,
|
||||||
|
image_shape,
|
||||||
|
padding=True,
|
||||||
|
interpolation=cv2.INTER_LINEAR):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
w = img.shape[1]
|
w = img.shape[1]
|
||||||
if not padding:
|
if not padding:
|
||||||
resized_image = cv2.resize(
|
resized_image = cv2.resize(
|
||||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
img, (imgW, imgH), interpolation=interpolation)
|
||||||
resized_w = imgW
|
resized_w = imgW
|
||||||
else:
|
else:
|
||||||
ratio = w / float(h)
|
ratio = w / float(h)
|
||||||
|
|
|
@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
|
||||||
from .rec_multi_loss import MultiLoss
|
from .rec_multi_loss import MultiLoss
|
||||||
from .rec_vl_loss import VLLoss
|
from .rec_vl_loss import VLLoss
|
||||||
from .rec_spin_att_loss import SPINAttentionLoss
|
from .rec_spin_att_loss import SPINAttentionLoss
|
||||||
|
from .rec_rfl_loss import RFLLoss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
@ -69,7 +70,7 @@ def build_loss(config):
|
||||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
|
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
|
||||||
'SLALoss', 'CTLoss'
|
'SLALoss', 'CTLoss', 'RFLLoss'
|
||||||
]
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from .basic_loss import CELoss, DistanceLoss
|
||||||
|
|
||||||
|
|
||||||
|
class RFLLoss(nn.Layer):
|
||||||
|
def __init__(self, ignore_index=-100, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cnt_loss = nn.MSELoss(**kwargs)
|
||||||
|
self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
|
||||||
|
self.total_loss = {}
|
||||||
|
total_loss = 0.0
|
||||||
|
# batch [image, label, length, cnt_label]
|
||||||
|
if predicts[0] is not None:
|
||||||
|
cnt_loss = self.cnt_loss(predicts[0],
|
||||||
|
paddle.cast(batch[3], paddle.float32))
|
||||||
|
self.total_loss['cnt_loss'] = cnt_loss
|
||||||
|
total_loss += cnt_loss
|
||||||
|
|
||||||
|
if predicts[1] is not None:
|
||||||
|
targets = batch[1].astype("int64")
|
||||||
|
label_lengths = batch[2].astype('int64')
|
||||||
|
batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
|
||||||
|
1].shape[1], predicts[1].shape[2]
|
||||||
|
assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
|
||||||
|
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||||
|
|
||||||
|
inputs = predicts[1][:, :-1, :]
|
||||||
|
targets = targets[:, 1:]
|
||||||
|
|
||||||
|
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
|
||||||
|
targets = paddle.reshape(targets, [-1])
|
||||||
|
seq_loss = self.seq_loss(inputs, targets)
|
||||||
|
self.total_loss['seq_loss'] = seq_loss
|
||||||
|
total_loss += seq_loss
|
||||||
|
|
||||||
|
self.total_loss['loss'] = total_loss
|
||||||
|
return self.total_loss
|
|
@ -22,7 +22,7 @@ import copy
|
||||||
__all__ = ["build_metric"]
|
__all__ = ["build_metric"]
|
||||||
|
|
||||||
from .det_metric import DetMetric, DetFCEMetric
|
from .det_metric import DetMetric, DetFCEMetric
|
||||||
from .rec_metric import RecMetric
|
from .rec_metric import RecMetric, CNTMetric
|
||||||
from .cls_metric import ClsMetric
|
from .cls_metric import ClsMetric
|
||||||
from .e2e_metric import E2EMetric
|
from .e2e_metric import E2EMetric
|
||||||
from .distillation_metric import DistillationMetric
|
from .distillation_metric import DistillationMetric
|
||||||
|
@ -38,7 +38,7 @@ def build_metric(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
|
||||||
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
|
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
|
||||||
'VQAReTokenMetric', 'SRMetric', 'CTMetric'
|
'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
|
||||||
import string
|
import string
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RecMetric(object):
|
class RecMetric(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
main_indicator='acc',
|
main_indicator='acc',
|
||||||
|
@ -74,3 +73,42 @@ class RecMetric(object):
|
||||||
self.correct_num = 0
|
self.correct_num = 0
|
||||||
self.all_num = 0
|
self.all_num = 0
|
||||||
self.norm_edit_dis = 0
|
self.norm_edit_dis = 0
|
||||||
|
|
||||||
|
|
||||||
|
class CNTMetric(object):
|
||||||
|
def __init__(self, main_indicator='acc', **kwargs):
|
||||||
|
self.main_indicator = main_indicator
|
||||||
|
self.eps = 1e-5
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def _normalize_text(self, text):
|
||||||
|
text = ''.join(
|
||||||
|
filter(lambda x: x in (string.digits + string.ascii_letters), text))
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
def __call__(self, pred_label, *args, **kwargs):
|
||||||
|
preds, labels = pred_label
|
||||||
|
correct_num = 0
|
||||||
|
all_num = 0
|
||||||
|
for pred, target in zip(preds, labels):
|
||||||
|
if pred == target:
|
||||||
|
correct_num += 1
|
||||||
|
all_num += 1
|
||||||
|
self.correct_num += correct_num
|
||||||
|
self.all_num += all_num
|
||||||
|
return {'acc': correct_num / (all_num + self.eps), }
|
||||||
|
|
||||||
|
def get_metric(self):
|
||||||
|
"""
|
||||||
|
return metrics {
|
||||||
|
'acc': 0,
|
||||||
|
'norm_edit_dis': 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
|
||||||
|
self.reset()
|
||||||
|
return {'acc': acc}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.correct_num = 0
|
||||||
|
self.all_num = 0
|
||||||
|
|
|
@ -42,10 +42,11 @@ def build_backbone(config, model_type):
|
||||||
from .rec_efficientb3_pren import EfficientNetb3_PREN
|
from .rec_efficientb3_pren import EfficientNetb3_PREN
|
||||||
from .rec_svtrnet import SVTRNet
|
from .rec_svtrnet import SVTRNet
|
||||||
from .rec_vitstr import ViTSTR
|
from .rec_vitstr import ViTSTR
|
||||||
|
from .rec_resnet_rfl import ResNetRFL
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||||
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
|
||||||
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
|
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
|
||||||
]
|
]
|
||||||
elif model_type == 'e2e':
|
elif model_type == 'e2e':
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
|
|
|
@ -0,0 +1,348 @@
|
||||||
|
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This code is refer from:
|
||||||
|
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
|
||||||
|
|
||||||
|
kaiming_init_ = KaimingNormal()
|
||||||
|
zeros_ = Constant(value=0.)
|
||||||
|
ones_ = Constant(value=1.)
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer):
|
||||||
|
"""Res-net Basic Block"""
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
inplanes,
|
||||||
|
planes,
|
||||||
|
stride=1,
|
||||||
|
downsample=None,
|
||||||
|
norm_type='BN',
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inplanes (int): input channel
|
||||||
|
planes (int): channels of the middle feature
|
||||||
|
stride (int): stride of the convolution
|
||||||
|
downsample (int): type of the down_sample
|
||||||
|
norm_type (str): type of the normalization
|
||||||
|
**kwargs (None): backup parameter
|
||||||
|
"""
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv1 = self._conv3x3(inplanes, planes)
|
||||||
|
self.bn1 = nn.BatchNorm(planes)
|
||||||
|
self.conv2 = self._conv3x3(planes, planes)
|
||||||
|
self.bn2 = nn.BatchNorm(planes)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.downsample = downsample
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
def _conv3x3(self, in_planes, out_planes, stride=1):
|
||||||
|
|
||||||
|
return nn.Conv2D(
|
||||||
|
in_planes,
|
||||||
|
out_planes,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
out = self.conv1(x)
|
||||||
|
out = self.bn1(out)
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
out = self.conv2(out)
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.downsample is not None:
|
||||||
|
residual = self.downsample(x)
|
||||||
|
out += residual
|
||||||
|
out = self.relu(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetRFL(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels=512,
|
||||||
|
use_cnt=True,
|
||||||
|
use_seq=True):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): input channel
|
||||||
|
out_channels (int): output channel
|
||||||
|
"""
|
||||||
|
super(ResNetRFL, self).__init__()
|
||||||
|
assert use_cnt or use_seq
|
||||||
|
self.use_cnt, self.use_seq = use_cnt, use_seq
|
||||||
|
self.backbone = RFLBase(in_channels)
|
||||||
|
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.out_channels_block = [
|
||||||
|
int(self.out_channels / 4), int(self.out_channels / 2),
|
||||||
|
self.out_channels, self.out_channels
|
||||||
|
]
|
||||||
|
block = BasicBlock
|
||||||
|
layers = [1, 2, 5, 3]
|
||||||
|
self.inplanes = int(self.out_channels // 2)
|
||||||
|
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
if self.use_seq:
|
||||||
|
self.maxpool3 = nn.MaxPool2D(
|
||||||
|
kernel_size=2, stride=(2, 1), padding=(0, 1))
|
||||||
|
self.layer3 = self._make_layer(
|
||||||
|
block, self.out_channels_block[2], layers[2], stride=1)
|
||||||
|
self.conv3 = nn.Conv2D(
|
||||||
|
self.out_channels_block[2],
|
||||||
|
self.out_channels_block[2],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn3 = nn.BatchNorm(self.out_channels_block[2])
|
||||||
|
|
||||||
|
self.layer4 = self._make_layer(
|
||||||
|
block, self.out_channels_block[3], layers[3], stride=1)
|
||||||
|
self.conv4_1 = nn.Conv2D(
|
||||||
|
self.out_channels_block[3],
|
||||||
|
self.out_channels_block[3],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=(2, 1),
|
||||||
|
padding=(0, 1),
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
|
||||||
|
self.conv4_2 = nn.Conv2D(
|
||||||
|
self.out_channels_block[3],
|
||||||
|
self.out_channels_block[3],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
|
||||||
|
|
||||||
|
if self.use_cnt:
|
||||||
|
self.inplanes = int(self.out_channels // 2)
|
||||||
|
self.v_maxpool3 = nn.MaxPool2D(
|
||||||
|
kernel_size=2, stride=(2, 1), padding=(0, 1))
|
||||||
|
self.v_layer3 = self._make_layer(
|
||||||
|
block, self.out_channels_block[2], layers[2], stride=1)
|
||||||
|
self.v_conv3 = nn.Conv2D(
|
||||||
|
self.out_channels_block[2],
|
||||||
|
self.out_channels_block[2],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
|
||||||
|
|
||||||
|
self.v_layer4 = self._make_layer(
|
||||||
|
block, self.out_channels_block[3], layers[3], stride=1)
|
||||||
|
self.v_conv4_1 = nn.Conv2D(
|
||||||
|
self.out_channels_block[3],
|
||||||
|
self.out_channels_block[3],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=(2, 1),
|
||||||
|
padding=(0, 1),
|
||||||
|
bias_attr=False)
|
||||||
|
self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
|
||||||
|
self.v_conv4_2 = nn.Conv2D(
|
||||||
|
self.out_channels_block[3],
|
||||||
|
self.out_channels_block[3],
|
||||||
|
kernel_size=2,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
bias_attr=False)
|
||||||
|
self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2D(
|
||||||
|
self.inplanes,
|
||||||
|
planes * block.expansion,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias_attr=False),
|
||||||
|
nn.BatchNorm(planes * block.expansion), )
|
||||||
|
|
||||||
|
layers = list()
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
x_1 = self.backbone(inputs)
|
||||||
|
|
||||||
|
if self.use_cnt:
|
||||||
|
v_x = self.v_maxpool3(x_1)
|
||||||
|
v_x = self.v_layer3(v_x)
|
||||||
|
v_x = self.v_conv3(v_x)
|
||||||
|
v_x = self.v_bn3(v_x)
|
||||||
|
visual_feature_2 = self.relu(v_x)
|
||||||
|
|
||||||
|
v_x = self.v_layer4(visual_feature_2)
|
||||||
|
v_x = self.v_conv4_1(v_x)
|
||||||
|
v_x = self.v_bn4_1(v_x)
|
||||||
|
v_x = self.relu(v_x)
|
||||||
|
v_x = self.v_conv4_2(v_x)
|
||||||
|
v_x = self.v_bn4_2(v_x)
|
||||||
|
visual_feature_3 = self.relu(v_x)
|
||||||
|
else:
|
||||||
|
visual_feature_3 = None
|
||||||
|
if self.use_seq:
|
||||||
|
x = self.maxpool3(x_1)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = self.bn3(x)
|
||||||
|
x_2 = self.relu(x)
|
||||||
|
|
||||||
|
x = self.layer4(x_2)
|
||||||
|
x = self.conv4_1(x)
|
||||||
|
x = self.bn4_1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.conv4_2(x)
|
||||||
|
x = self.bn4_2(x)
|
||||||
|
x_3 = self.relu(x)
|
||||||
|
else:
|
||||||
|
x_3 = None
|
||||||
|
|
||||||
|
return [visual_feature_3, x_3]
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetBase(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, block, layers):
|
||||||
|
super(ResNetBase, self).__init__()
|
||||||
|
|
||||||
|
self.out_channels_block = [
|
||||||
|
int(out_channels / 4), int(out_channels / 2), out_channels,
|
||||||
|
out_channels
|
||||||
|
]
|
||||||
|
|
||||||
|
self.inplanes = int(out_channels / 8)
|
||||||
|
self.conv0_1 = nn.Conv2D(
|
||||||
|
in_channels,
|
||||||
|
int(out_channels / 16),
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
|
||||||
|
self.conv0_2 = nn.Conv2D(
|
||||||
|
int(out_channels / 16),
|
||||||
|
self.inplanes,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn0_2 = nn.BatchNorm(self.inplanes)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||||
|
self.layer1 = self._make_layer(block, self.out_channels_block[0],
|
||||||
|
layers[0])
|
||||||
|
self.conv1 = nn.Conv2D(
|
||||||
|
self.out_channels_block[0],
|
||||||
|
self.out_channels_block[0],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn1 = nn.BatchNorm(self.out_channels_block[0])
|
||||||
|
|
||||||
|
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||||
|
self.layer2 = self._make_layer(
|
||||||
|
block, self.out_channels_block[1], layers[1], stride=1)
|
||||||
|
self.conv2 = nn.Conv2D(
|
||||||
|
self.out_channels_block[1],
|
||||||
|
self.out_channels_block[1],
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias_attr=False)
|
||||||
|
self.bn2 = nn.BatchNorm(self.out_channels_block[1])
|
||||||
|
|
||||||
|
def _make_layer(self, block, planes, blocks, stride=1):
|
||||||
|
downsample = None
|
||||||
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||||
|
downsample = nn.Sequential(
|
||||||
|
nn.Conv2D(
|
||||||
|
self.inplanes,
|
||||||
|
planes * block.expansion,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=stride,
|
||||||
|
bias_attr=False),
|
||||||
|
nn.BatchNorm(planes * block.expansion), )
|
||||||
|
|
||||||
|
layers = list()
|
||||||
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||||
|
self.inplanes = planes * block.expansion
|
||||||
|
for _ in range(1, blocks):
|
||||||
|
layers.append(block(self.inplanes, planes))
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv0_1(x)
|
||||||
|
x = self.bn0_1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
x = self.conv0_2(x)
|
||||||
|
x = self.bn0_2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
x = self.maxpool1(x)
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
x = self.maxpool2(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.bn2(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class RFLBase(nn.Layer):
|
||||||
|
""" Reciprocal feature learning share backbone network"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels=512):
|
||||||
|
super(RFLBase, self).__init__()
|
||||||
|
self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
|
||||||
|
[1, 2, 5, 3])
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
return self.ConvNet(inputs)
|
|
@ -38,6 +38,7 @@ def build_head(config):
|
||||||
from .rec_abinet_head import ABINetHead
|
from .rec_abinet_head import ABINetHead
|
||||||
from .rec_robustscanner_head import RobustScannerHead
|
from .rec_robustscanner_head import RobustScannerHead
|
||||||
from .rec_visionlan_head import VLHead
|
from .rec_visionlan_head import VLHead
|
||||||
|
from .rec_rfl_head import RFLHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
|
@ -53,7 +54,7 @@ def build_head(config):
|
||||||
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||||
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
|
||||||
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
|
||||||
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
|
'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
|
||||||
]
|
]
|
||||||
|
|
||||||
#table head
|
#table head
|
||||||
|
|
|
@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
|
||||||
else:
|
else:
|
||||||
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
probs = None
|
probs = None
|
||||||
|
char_onehots = None
|
||||||
|
alpha = None
|
||||||
|
|
||||||
for i in range(num_steps):
|
for i in range(num_steps):
|
||||||
char_onehots = self._char_to_onehot(
|
char_onehots = self._char_to_onehot(
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This code is refer from:
|
||||||
|
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
|
||||||
|
"""
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
|
||||||
|
|
||||||
|
from .rec_att_head import AttentionLSTM
|
||||||
|
|
||||||
|
kaiming_init_ = KaimingNormal()
|
||||||
|
zeros_ = Constant(value=0.)
|
||||||
|
ones_ = Constant(value=1.)
|
||||||
|
|
||||||
|
|
||||||
|
class CNTHead(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
embed_size=512,
|
||||||
|
encode_length=26,
|
||||||
|
out_channels=38,
|
||||||
|
**kwargs):
|
||||||
|
super(CNTHead, self).__init__()
|
||||||
|
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
|
||||||
|
self.Prediction_visual = nn.Linear(encode_length * embed_size,
|
||||||
|
self.out_channels)
|
||||||
|
|
||||||
|
def forward(self, visual_feature):
|
||||||
|
|
||||||
|
b, c, h, w = visual_feature.shape
|
||||||
|
visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
|
||||||
|
[0, 2, 1])
|
||||||
|
visual_feature_num = self.Wv_fusion(visual_feature) # batch * 26 * 512
|
||||||
|
b, n, c = visual_feature_num.shape
|
||||||
|
# using visual feature directly calculate the text length
|
||||||
|
visual_feature_num = visual_feature_num.reshape([b, n * c])
|
||||||
|
prediction_visual = self.Prediction_visual(visual_feature_num)
|
||||||
|
|
||||||
|
return prediction_visual
|
||||||
|
|
||||||
|
|
||||||
|
class RFLHead(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=512,
|
||||||
|
hidden_size=256,
|
||||||
|
batch_max_legnth=25,
|
||||||
|
out_channels=38,
|
||||||
|
use_cnt=True,
|
||||||
|
use_seq=True,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
super(RFLHead, self).__init__()
|
||||||
|
assert use_cnt or use_seq
|
||||||
|
self.use_cnt = use_cnt
|
||||||
|
self.use_seq = use_seq
|
||||||
|
if self.use_cnt:
|
||||||
|
self.cnt_head = CNTHead(
|
||||||
|
embed_size=in_channels,
|
||||||
|
encode_length=batch_max_legnth + 1,
|
||||||
|
out_channels=out_channels,
|
||||||
|
**kwargs)
|
||||||
|
if self.use_seq:
|
||||||
|
self.seq_head = AttentionLSTM(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
**kwargs)
|
||||||
|
self.batch_max_legnth = batch_max_legnth
|
||||||
|
self.num_class = out_channels
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
kaiming_init_(m.weight)
|
||||||
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||||
|
zeros_(m.bias)
|
||||||
|
|
||||||
|
def forward(self, x, targets=None):
|
||||||
|
cnt_inputs, seq_inputs = x
|
||||||
|
if self.use_cnt:
|
||||||
|
cnt_outputs = self.cnt_head(cnt_inputs)
|
||||||
|
else:
|
||||||
|
cnt_outputs = None
|
||||||
|
if self.use_seq:
|
||||||
|
if self.training:
|
||||||
|
seq_outputs = self.seq_head(seq_inputs, targets[0],
|
||||||
|
self.batch_max_legnth)
|
||||||
|
else:
|
||||||
|
seq_outputs = self.seq_head(seq_inputs, None,
|
||||||
|
self.batch_max_legnth)
|
||||||
|
else:
|
||||||
|
seq_outputs = None
|
||||||
|
|
||||||
|
return cnt_outputs, seq_outputs
|
|
@ -27,9 +27,11 @@ def build_neck(config):
|
||||||
from .pren_fpn import PRENFPN
|
from .pren_fpn import PRENFPN
|
||||||
from .csp_pan import CSPPAN
|
from .csp_pan import CSPPAN
|
||||||
from .ct_fpn import CTFPN
|
from .ct_fpn import CTFPN
|
||||||
|
from .rf_adaptor import RFAdaptor
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
|
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
|
||||||
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
|
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
|
||||||
|
'RFAdaptor'
|
||||||
]
|
]
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,137 @@
|
||||||
|
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This code is refer from:
|
||||||
|
https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
|
||||||
|
|
||||||
|
kaiming_init_ = KaimingNormal()
|
||||||
|
zeros_ = Constant(value=0.)
|
||||||
|
ones_ = Constant(value=1.)
|
||||||
|
|
||||||
|
|
||||||
|
class S2VAdaptor(nn.Layer):
|
||||||
|
""" Semantic to Visual adaptation module"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels=512):
|
||||||
|
super(S2VAdaptor, self).__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels # 512
|
||||||
|
|
||||||
|
# feature strengthen module, channel attention
|
||||||
|
self.channel_inter = nn.Linear(
|
||||||
|
self.in_channels, self.in_channels, bias_attr=False)
|
||||||
|
self.channel_bn = nn.BatchNorm1D(self.in_channels)
|
||||||
|
self.channel_act = nn.ReLU()
|
||||||
|
self.apply(self.init_weights)
|
||||||
|
|
||||||
|
def init_weights(self, m):
|
||||||
|
if isinstance(m, nn.Conv2D):
|
||||||
|
kaiming_init_(m.weight)
|
||||||
|
if isinstance(m, nn.Conv2D) and m.bias is not None:
|
||||||
|
zeros_(m.bias)
|
||||||
|
elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
|
||||||
|
zeros_(m.bias)
|
||||||
|
ones_(m.weight)
|
||||||
|
|
||||||
|
def forward(self, semantic):
|
||||||
|
semantic_source = semantic # batch, channel, height, width
|
||||||
|
|
||||||
|
# feature transformation
|
||||||
|
semantic = semantic.squeeze(2).transpose(
|
||||||
|
[0, 2, 1]) # batch, width, channel
|
||||||
|
channel_att = self.channel_inter(semantic) # batch, width, channel
|
||||||
|
channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width
|
||||||
|
channel_bn = self.channel_bn(channel_att) # batch, channel, width
|
||||||
|
channel_att = self.channel_act(channel_bn) # batch, channel, width
|
||||||
|
|
||||||
|
# Feature enhancement
|
||||||
|
channel_output = semantic_source * channel_att.unsqueeze(
|
||||||
|
-2) # batch, channel, 1, width
|
||||||
|
|
||||||
|
return channel_output
|
||||||
|
|
||||||
|
|
||||||
|
class V2SAdaptor(nn.Layer):
|
||||||
|
""" Visual to Semantic adaptation module"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels=512, return_mask=False):
|
||||||
|
super(V2SAdaptor, self).__init__()
|
||||||
|
|
||||||
|
# parameter initialization
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.return_mask = return_mask
|
||||||
|
|
||||||
|
# output transformation
|
||||||
|
self.channel_inter = nn.Linear(
|
||||||
|
self.in_channels, self.in_channels, bias_attr=False)
|
||||||
|
self.channel_bn = nn.BatchNorm1D(self.in_channels)
|
||||||
|
self.channel_act = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, visual):
|
||||||
|
# Feature enhancement
|
||||||
|
visual = visual.squeeze(2).transpose([0, 2, 1]) # batch, width, channel
|
||||||
|
channel_att = self.channel_inter(visual) # batch, width, channel
|
||||||
|
channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width
|
||||||
|
channel_bn = self.channel_bn(channel_att) # batch, channel, width
|
||||||
|
channel_att = self.channel_act(channel_bn) # batch, channel, width
|
||||||
|
|
||||||
|
# size alignment
|
||||||
|
channel_output = channel_att.unsqueeze(-2) # batch, width, channel
|
||||||
|
|
||||||
|
if self.return_mask:
|
||||||
|
return channel_output, channel_att
|
||||||
|
return channel_output
|
||||||
|
|
||||||
|
|
||||||
|
class RFAdaptor(nn.Layer):
|
||||||
|
def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
|
||||||
|
super(RFAdaptor, self).__init__()
|
||||||
|
if use_v2s is True:
|
||||||
|
self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs)
|
||||||
|
else:
|
||||||
|
self.neck_v2s = None
|
||||||
|
if use_s2v is True:
|
||||||
|
self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs)
|
||||||
|
else:
|
||||||
|
self.neck_s2v = None
|
||||||
|
self.out_channels = in_channels
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
visual_feature, rcg_feature = x
|
||||||
|
if visual_feature is not None:
|
||||||
|
batch, source_channels, v_source_height, v_source_width = visual_feature.shape
|
||||||
|
visual_feature = visual_feature.reshape(
|
||||||
|
[batch, source_channels, 1, v_source_height * v_source_width])
|
||||||
|
|
||||||
|
if self.neck_v2s is not None:
|
||||||
|
v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
|
||||||
|
else:
|
||||||
|
v_rcg_feature = rcg_feature
|
||||||
|
|
||||||
|
if self.neck_s2v is not None:
|
||||||
|
v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)
|
||||||
|
else:
|
||||||
|
v_visual_feature = visual_feature
|
||||||
|
if v_rcg_feature is not None:
|
||||||
|
batch, source_channels, source_height, source_width = v_rcg_feature.shape
|
||||||
|
v_rcg_feature = v_rcg_feature.reshape(
|
||||||
|
[batch, source_channels, 1, source_height * source_width])
|
||||||
|
|
||||||
|
v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
|
||||||
|
return v_visual_feature, v_rcg_feature
|
|
@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
|
||||||
if 'clip_norm' in config:
|
if 'clip_norm' in config:
|
||||||
clip_norm = config.pop('clip_norm')
|
clip_norm = config.pop('clip_norm')
|
||||||
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
|
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
|
||||||
|
elif 'clip_norm_global' in config:
|
||||||
|
clip_norm = config.pop('clip_norm_global')
|
||||||
|
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
|
||||||
else:
|
else:
|
||||||
grad_clip = None
|
grad_clip = None
|
||||||
optim = getattr(optimizer, optim_name)(learning_rate=lr,
|
optim = getattr(optimizer, optim_name)(learning_rate=lr,
|
||||||
|
|
|
@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||||
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||||
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
|
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
|
||||||
SPINLabelDecode, VLLabelDecode
|
SPINLabelDecode, VLLabelDecode, RFLLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
from .pg_postprocess import PGPostProcess
|
from .pg_postprocess import PGPostProcess
|
||||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
|
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
|
||||||
|
@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
|
||||||
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
|
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
|
||||||
'TableMasterLabelDecode', 'SPINLabelDecode',
|
'TableMasterLabelDecode', 'SPINLabelDecode',
|
||||||
'DistillationSerPostProcess', 'DistillationRePostProcess',
|
'DistillationSerPostProcess', 'DistillationRePostProcess',
|
||||||
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
|
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
|
||||||
]
|
]
|
||||||
|
|
||||||
if config['name'] == 'PSEPostProcess':
|
if config['name'] == 'PSEPostProcess':
|
||||||
|
|
|
@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(RFLLabelDecode, self).__init__(character_dict_path,
|
||||||
|
use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = dict_character
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||||
|
break
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list).tolist()))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
cnt_pred, preds = preds
|
||||||
|
if preds is not None:
|
||||||
|
|
||||||
|
if isinstance(preds, paddle.Tensor):
|
||||||
|
preds = preds.numpy()
|
||||||
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
preds_prob = preds.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
else:
|
||||||
|
cnt_length = []
|
||||||
|
for lens in cnt_pred:
|
||||||
|
length = round(paddle.sum(lens).item())
|
||||||
|
cnt_length.append(length)
|
||||||
|
if label is None:
|
||||||
|
return cnt_length
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
length = [len(res[0]) for res in label]
|
||||||
|
return cnt_length, length
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
class SEEDLabelDecode(BaseRecLabelDecode):
|
class SEEDLabelDecode(BaseRecLabelDecode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 6
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 50
|
||||||
|
save_model_dir: ./output/rec/rec_resnet_rfl/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 5000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words_en/word_10.png
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
save_res_path: ./output/rec/rec_resnet_rfl.txt
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: AdamW
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
weight_decay: 0.0
|
||||||
|
clip_norm_global: 5.0
|
||||||
|
lr:
|
||||||
|
name: Piecewise
|
||||||
|
decay_epochs : [3, 4, 5]
|
||||||
|
values : [0.001, 0.0003, 0.00009, 0.000027]
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RFL
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 1.0
|
||||||
|
model_name: large
|
||||||
|
Backbone:
|
||||||
|
name: ResNetRFL
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: True
|
||||||
|
Neck:
|
||||||
|
name: RFAdaptor
|
||||||
|
use_v2s: True
|
||||||
|
use_s2v: True
|
||||||
|
Head:
|
||||||
|
name: RFLHead
|
||||||
|
in_channels: 512
|
||||||
|
hidden_size: 256
|
||||||
|
batch_max_legnth: 25
|
||||||
|
out_channels: 38
|
||||||
|
use_cnt: True
|
||||||
|
use_seq: True
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: RFLLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: RFLLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/ic15_data/
|
||||||
|
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: ./train_data/ic15_data
|
||||||
|
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- RFLLabelEncode: # Class handling label
|
||||||
|
- RFLRecResizeImg:
|
||||||
|
image_shape: [1, 32, 100]
|
||||||
|
interpolation: 2
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 8
|
|
@ -0,0 +1,53 @@
|
||||||
|
===========================train_params===========================
|
||||||
|
model_name:rec_resnet_rfl
|
||||||
|
python:python3.7
|
||||||
|
gpu_list:0|0,1
|
||||||
|
Global.use_gpu:True|True
|
||||||
|
Global.auto_cast:null
|
||||||
|
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
|
||||||
|
Global.save_model_dir:./output/
|
||||||
|
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
|
||||||
|
Global.pretrained_model:null
|
||||||
|
train_model_name:latest
|
||||||
|
train_infer_img_dir:./inference/rec_inference
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
trainer:norm_train
|
||||||
|
norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
|
||||||
|
pact_train:null
|
||||||
|
fpgm_train:null
|
||||||
|
distill_train:null
|
||||||
|
null:null
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================eval_params===========================
|
||||||
|
eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
|
||||||
|
null:null
|
||||||
|
##
|
||||||
|
===========================infer_params===========================
|
||||||
|
Global.save_inference_dir:./output/
|
||||||
|
Global.checkpoints:
|
||||||
|
norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
|
||||||
|
quant_export:null
|
||||||
|
fpgm_export:null
|
||||||
|
distill_export:null
|
||||||
|
export1:null
|
||||||
|
export2:null
|
||||||
|
##
|
||||||
|
train_model:./inference/rec_resnet_rfl_train/best_accuracy
|
||||||
|
infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
|
||||||
|
infer_quant:False
|
||||||
|
inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
|
||||||
|
--use_gpu:True|False
|
||||||
|
--enable_mkldnn:False
|
||||||
|
--cpu_threads:6
|
||||||
|
--rec_batch_num:1
|
||||||
|
--use_tensorrt:False
|
||||||
|
--precision:fp32
|
||||||
|
--rec_model_dir:
|
||||||
|
--image_dir:./inference/rec_inference
|
||||||
|
--save_log_path:./test/output/
|
||||||
|
--benchmark:True
|
||||||
|
null:null
|
||||||
|
===========================infer_benchmark_params==========================
|
||||||
|
random_infer_input:[{float32,[1,32,100]}]
|
|
@ -99,7 +99,7 @@ def export_single_model(model,
|
||||||
]
|
]
|
||||||
# print([None, 3, 32, 128])
|
# print([None, 3, 32, 128])
|
||||||
model = to_static(model, input_spec=other_shape)
|
model = to_static(model, input_spec=other_shape)
|
||||||
elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
|
elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
|
||||||
other_shape = [
|
other_shape = [
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None, 1, 32, 100], dtype="float32"),
|
shape=[None, 1, 32, 100], dtype="float32"),
|
||||||
|
|
|
@ -100,6 +100,12 @@ class TextRecognizer(object):
|
||||||
"use_space_char": args.use_space_char,
|
"use_space_char": args.use_space_char,
|
||||||
"rm_symbol": True
|
"rm_symbol": True
|
||||||
}
|
}
|
||||||
|
elif self.rec_algorithm == 'RFL':
|
||||||
|
postprocess_params = {
|
||||||
|
'name': 'RFLLabelDecode',
|
||||||
|
"character_dict_path": None,
|
||||||
|
"use_space_char": args.use_space_char
|
||||||
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
|
@ -143,6 +149,16 @@ class TextRecognizer(object):
|
||||||
else:
|
else:
|
||||||
norm_img = norm_img.astype(np.float32) / 128. - 1.
|
norm_img = norm_img.astype(np.float32) / 128. - 1.
|
||||||
return norm_img
|
return norm_img
|
||||||
|
elif self.rec_algorithm == 'RFL':
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
resized_image = cv2.resize(
|
||||||
|
img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
|
||||||
|
resized_image = resized_image.astype('float32')
|
||||||
|
resized_image = resized_image / 255
|
||||||
|
resized_image = resized_image[np.newaxis, :]
|
||||||
|
resized_image -= 0.5
|
||||||
|
resized_image /= 0.5
|
||||||
|
return resized_image
|
||||||
|
|
||||||
assert imgC == img.shape[2]
|
assert imgC == img.shape[2]
|
||||||
imgW = int((imgH * max_wh_ratio))
|
imgW = int((imgH * max_wh_ratio))
|
||||||
|
|
|
@ -217,7 +217,7 @@ def train(config,
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
extra_input_models = [
|
extra_input_models = [
|
||||||
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
|
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
|
||||||
"RobustScanner"
|
"RobustScanner", "RFL"
|
||||||
]
|
]
|
||||||
extra_input = False
|
extra_input = False
|
||||||
if config['Architecture']['algorithm'] == 'Distillation':
|
if config['Architecture']['algorithm'] == 'Distillation':
|
||||||
|
@ -625,7 +625,7 @@ def preprocess(is_train=False):
|
||||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||||
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
|
||||||
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
|
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
|
||||||
'Gestalt', 'SLANet', 'RobustScanner', 'CT'
|
'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
|
||||||
]
|
]
|
||||||
|
|
||||||
if use_xpu:
|
if use_xpu:
|
||||||
|
|
Loading…
Reference in New Issue