From 0002349df3baed66b5149698cbb8cb5d4c1df3ee Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Tue, 27 Sep 2022 10:54:31 +0800
Subject: [PATCH 01/10] Add text recognition algorithm RF-Learning (RFL)
---
configs/rec/rec_resnet_rfl_att.yml | 113 ++++++
configs/rec/rec_resnet_rfl_visual.yml | 110 ++++++
ppocr/data/imaug/__init__.py | 3 +-
ppocr/data/imaug/label_ops.py | 56 +++
ppocr/data/imaug/rec_img_aug.py | 42 ++-
ppocr/losses/__init__.py | 3 +-
ppocr/losses/rec_rfl_loss.py | 61 +++
ppocr/metrics/__init__.py | 4 +-
ppocr/metrics/rec_metric.py | 40 +-
ppocr/modeling/backbones/__init__.py | 3 +-
ppocr/modeling/backbones/rec_resnet_rfl.py | 348 ++++++++++++++++++
ppocr/modeling/heads/__init__.py | 3 +-
ppocr/modeling/heads/rec_att_head.py | 2 +
ppocr/modeling/heads/rec_rfl_head.py | 109 ++++++
ppocr/modeling/necks/__init__.py | 4 +-
ppocr/modeling/necks/rf_adaptor.py | 137 +++++++
ppocr/optimizer/__init__.py | 3 +
ppocr/postprocess/__init__.py | 4 +-
ppocr/postprocess/rec_postprocess.py | 86 +++++
.../configs/rec_resnet_rfl/rec_resnet_rfl.yml | 111 ++++++
.../rec_resnet_rfl/train_infer_python.txt | 53 +++
tools/export_model.py | 2 +-
tools/infer/predict_rec.py | 16 +
tools/program.py | 4 +-
24 files changed, 1301 insertions(+), 16 deletions(-)
create mode 100644 configs/rec/rec_resnet_rfl_att.yml
create mode 100644 configs/rec/rec_resnet_rfl_visual.yml
create mode 100644 ppocr/losses/rec_rfl_loss.py
create mode 100644 ppocr/modeling/backbones/rec_resnet_rfl.py
create mode 100644 ppocr/modeling/heads/rec_rfl_head.py
create mode 100644 ppocr/modeling/necks/rf_adaptor.py
create mode 100644 test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml
create mode 100644 test_tipc/configs/rec_resnet_rfl/train_infer_python.txt
diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml
new file mode 100644
index 000000000..f8332e082
--- /dev/null
+++ b/configs/rec/rec_resnet_rfl_att.yml
@@ -0,0 +1,113 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 20
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_resnet_rfl_att/
+ save_epoch_step: 1
+  # evaluation is run every 5000 iterations
+ eval_batch_step: [0, 5000]
+ cal_metric_during_train: True
+ pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/rec_resnet_rfl.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.0
+ clip_norm_global: 5.0
+ lr:
+ name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
+Architecture:
+ model_type: rec
+ algorithm: RFL
+ in_channels: 1
+ Transform:
+ name: TPS
+ num_fiducial: 20
+ loc_lr: 1.0
+ model_name: large
+ Backbone:
+ name: ResNetRFL
+ use_cnt: True
+ use_seq: True
+ Neck:
+ name: RFAdaptor
+ use_v2s: True
+ use_s2v: True
+ Head:
+ name: RFLHead
+ in_channels: 512
+ hidden_size: 256
+    batch_max_length: 25
+ out_channels: 38
+ use_cnt: True
+ use_seq: True
+
+Loss:
+ name: RFLLoss
+ # ignore_index: 0
+
+PostProcess:
+ name: RFLLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/rfl_dataset2/training
+
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ padding: false
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 64
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/rfl_dataset2/evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ padding: false
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 8
diff --git a/configs/rec/rec_resnet_rfl_visual.yml b/configs/rec/rec_resnet_rfl_visual.yml
new file mode 100644
index 000000000..438d2ef0c
--- /dev/null
+++ b/configs/rec/rec_resnet_rfl_visual.yml
@@ -0,0 +1,110 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 20
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_resnet_rfl_visual/
+ save_epoch_step: 1
+  # evaluation is run every 5000 iterations
+ eval_batch_step: [0, 5000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/rec_resnet_rfl_visual.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.0
+ clip_norm_global: 5.0
+ lr:
+ name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
+Architecture:
+ model_type: rec
+ algorithm: RFL
+ in_channels: 1
+ Transform:
+ name: TPS
+ num_fiducial: 20
+ loc_lr: 1.0
+ model_name: large
+ Backbone:
+ name: ResNetRFL
+ use_cnt: True
+ use_seq: False
+ Neck:
+ name: RFAdaptor
+ use_v2s: False
+ use_s2v: False
+ Head:
+ name: RFLHead
+ in_channels: 512
+ hidden_size: 256
+    batch_max_length: 25
+ out_channels: 38
+ use_cnt: True
+ use_seq: False
+Loss:
+ name: RFLLoss
+
+PostProcess:
+ name: RFLLabelDecode
+
+Metric:
+ name: CNTMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/rfl_dataset2/training
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ padding: false
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 64
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/rfl_dataset2/evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ padding: false
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 8
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 863988ccc..db0a489d5 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
- ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
+ ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
+ RFLRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index dbfb93176..590caf1c4 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
return idx
+class RFLLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(RFLLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+ def encode_cnt(self, text):
+ cnt_label = [0.0] * len(self.character)
+ for char_ in text:
+ cnt_label[char_] += 1
+ return np.array(cnt_label)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+ cnt_label = self.encode_cnt(text)
+ data['length'] = np.array(len(text))
+ text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+ - len(text) - 2)
+ if len(text) != self.max_text_len:
+ return None
+ data['label'] = np.array(text)
+ data['cnt_label'] = cnt_label
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+            assert False, "Unsupported type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
+
+
class SEEDLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 89022d85a..e22153bde 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -237,6 +237,33 @@ class VLRecResizeImg(object):
return data
+class RFLRecResizeImg(object):
+ def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
+ self.image_shape = image_shape
+ self.padding = padding
+
+ self.interpolation = interpolation
+ if self.interpolation == 0:
+ self.interpolation = cv2.INTER_NEAREST
+ elif self.interpolation == 1:
+ self.interpolation = cv2.INTER_LINEAR
+ elif self.interpolation == 2:
+ self.interpolation = cv2.INTER_CUBIC
+ elif self.interpolation == 3:
+ self.interpolation = cv2.INTER_AREA
+ else:
+ raise Exception("Unsupported interpolation type !!!")
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ norm_img, valid_ratio = resize_norm_img(
+ img, self.image_shape, self.padding, self.interpolation)
+ data['image'] = norm_img
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
@@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
data['valid_ratio'] = valid_ratio
return data
+
class RobustScannerRecResizeImg(object):
- def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
+ def __init__(self,
+ image_shape,
+ max_text_length,
+ width_downsample_ratio=0.25,
+ **kwargs):
self.image_shape = image_shape
self.width_downsample_ratio = width_downsample_ratio
self.max_text_length = max_text_length
@@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
data['word_positons'] = word_positons
return data
+
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
@@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
return padding_im, resize_shape, pad_shape, valid_ratio
-def resize_norm_img(img, image_shape, padding=True):
+def resize_norm_img(img,
+ image_shape,
+ padding=True,
+ interpolation=cv2.INTER_LINEAR):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
if not padding:
resized_image = cv2.resize(
- img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+ img, (imgW, imgH), interpolation=interpolation)
resized_w = imgW
else:
ratio = w / float(h)
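
`RFLRecResizeImg` grayscales the crop and passes the selected OpenCV interpolation flag into the now-parameterized `resize_norm_img`, which rescales the image to [-1, 1]. A minimal sketch of the equivalent steps for `interpolation: 2` (cubic) without padding, on a dummy crop:

```python
import cv2
import numpy as np

# Dummy BGR crop standing in for DecodeImage output.
img = np.random.randint(0, 256, (48, 160, 3), dtype=np.uint8)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
resized = cv2.resize(gray, (100, 32), interpolation=cv2.INTER_CUBIC)
# resize_norm_img then rescales to [-1, 1] and adds the channel axis.
norm = resized.astype('float32') / 255.0
norm = norm[np.newaxis, :]               # shape (1, 32, 100)
norm = (norm - 0.5) / 0.5
```
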
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 02525b3d5..ffee0a93e 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss
+from .rec_rfl_loss import RFLLoss
# cls loss
from .cls_loss import ClsLoss
@@ -69,7 +70,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
- 'SLALoss', 'CTLoss'
+ 'SLALoss', 'CTLoss', 'RFLLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py
new file mode 100644
index 000000000..8e9d7d039
--- /dev/null
+++ b/ppocr/losses/rec_rfl_loss.py
@@ -0,0 +1,61 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+from .basic_loss import CELoss, DistanceLoss
+
+
+class RFLLoss(nn.Layer):
+ def __init__(self, ignore_index=-100, **kwargs):
+ super().__init__()
+
+ self.cnt_loss = nn.MSELoss(**kwargs)
+ self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
+
+ def forward(self, predicts, batch):
+
+ self.total_loss = {}
+ total_loss = 0.0
+ # batch [image, label, length, cnt_label]
+ if predicts[0] is not None:
+ cnt_loss = self.cnt_loss(predicts[0],
+ paddle.cast(batch[3], paddle.float32))
+ self.total_loss['cnt_loss'] = cnt_loss
+ total_loss += cnt_loss
+
+ if predicts[1] is not None:
+ targets = batch[1].astype("int64")
+ label_lengths = batch[2].astype('int64')
+ batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
+ 1].shape[1], predicts[1].shape[2]
+ assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
+ "The target's shape and inputs's shape is [N, d] and [N, num_steps]"
+
+ inputs = predicts[1][:, :-1, :]
+ targets = targets[:, 1:]
+
+ inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
+ targets = paddle.reshape(targets, [-1])
+ seq_loss = self.seq_loss(inputs, targets)
+ self.total_loss['seq_loss'] = seq_loss
+ total_loss += seq_loss
+
+ self.total_loss['loss'] = total_loss
+ return self.total_loss
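
The loss above is the plain sum of an MSE on the count vector and a cross-entropy on the attention logits, with the last logit step and the leading `sos` target dropped so the two align. A minimal sketch on dummy tensors (the 38-class, 26-step shapes mirror the configs but are illustrative only):

```python
import paddle
import paddle.nn as nn

cnt_pred = paddle.rand([4, 38])              # counting-branch output
seq_pred = paddle.rand([4, 26, 38])          # attention-branch logits
cnt_label = paddle.rand([4, 38])             # per-class count targets
label = paddle.randint(0, 38, [4, 26])       # [sos, chars..., eos, pad...]

cnt_loss = nn.MSELoss()(cnt_pred, cnt_label.astype('float32'))
inputs = seq_pred[:, :-1, :].reshape([-1, 38])   # drop the last step
targets = label[:, 1:].reshape([-1])             # drop the leading sos
seq_loss = nn.CrossEntropyLoss(ignore_index=-100)(inputs, targets)
total = cnt_loss + seq_loss
```
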
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index a39d0a464..20aea8b59 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -22,7 +22,7 @@ import copy
__all__ = ["build_metric"]
from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric
+from .rec_metric import RecMetric, CNTMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
@@ -38,7 +38,7 @@ def build_metric(config):
support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
- 'VQAReTokenMetric', 'SRMetric', 'CTMetric'
+ 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 986397811..34d6ff3a3 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
import string
-
class RecMetric(object):
def __init__(self,
main_indicator='acc',
@@ -74,3 +73,42 @@ class RecMetric(object):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0
+
+
+class CNTMetric(object):
+ def __init__(self, main_indicator='acc', **kwargs):
+ self.main_indicator = main_indicator
+ self.eps = 1e-5
+ self.reset()
+
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ return text.lower()
+
+ def __call__(self, pred_label, *args, **kwargs):
+ preds, labels = pred_label
+ correct_num = 0
+ all_num = 0
+ for pred, target in zip(preds, labels):
+ if pred == target:
+ correct_num += 1
+ all_num += 1
+ self.correct_num += correct_num
+ self.all_num += all_num
+ return {'acc': correct_num / (all_num + self.eps), }
+
+ def get_metric(self):
+ """
+ return metrics {
+ 'acc': 0,
+ 'norm_edit_dis': 0,
+ }
+ """
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ self.reset()
+ return {'acc': acc}
+
+ def reset(self):
+ self.correct_num = 0
+ self.all_num = 0
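
`CNTMetric` scores the counting branch by exact match between predicted and ground-truth text lengths. A usage sketch on toy length lists, assuming the class is imported from `ppocr.metrics.rec_metric`:

```python
from ppocr.metrics.rec_metric import CNTMetric

metric = CNTMetric(main_indicator='acc')
batch_acc = metric(([5, 7, 3], [5, 7, 4]))   # 2 of 3 lengths match
print(batch_acc)                              # {'acc': ~0.667}
print(metric.get_metric())                    # accumulated acc, then resets
```
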
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 6fdcc4a75..84892fa9c 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -42,10 +42,11 @@ def build_backbone(config, model_type):
from .rec_efficientb3_pren import EfficientNetb3_PREN
from .rec_svtrnet import SVTRNet
from .rec_vitstr import ViTSTR
+ from .rec_resnet_rfl import ResNetRFL
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
- 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
+ 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
]
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_rfl.py b/ppocr/modeling/backbones/rec_resnet_rfl.py
new file mode 100644
index 000000000..fd317c6ea
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_rfl.py
@@ -0,0 +1,348 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+class BasicBlock(nn.Layer):
+ """Res-net Basic Block"""
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ norm_type='BN',
+ **kwargs):
+ """
+ Args:
+ inplanes (int): input channel
+ planes (int): channels of the middle feature
+ stride (int): stride of the convolution
+ downsample (int): type of the down_sample
+ norm_type (str): type of the normalization
+ **kwargs (None): backup parameter
+ """
+ super(BasicBlock, self).__init__()
+ self.conv1 = self._conv3x3(inplanes, planes)
+ self.bn1 = nn.BatchNorm(planes)
+ self.conv2 = self._conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm(planes)
+ self.relu = nn.ReLU()
+ self.downsample = downsample
+ self.stride = stride
+
+ def _conv3x3(self, in_planes, out_planes, stride=1):
+
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetRFL(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels=512,
+ use_cnt=True,
+ use_seq=True):
+ """
+
+ Args:
+ in_channels (int): input channel
+ out_channels (int): output channel
+ """
+ super(ResNetRFL, self).__init__()
+ assert use_cnt or use_seq
+ self.use_cnt, self.use_seq = use_cnt, use_seq
+ self.backbone = RFLBase(in_channels)
+
+ self.out_channels = out_channels
+ self.out_channels_block = [
+ int(self.out_channels / 4), int(self.out_channels / 2),
+ self.out_channels, self.out_channels
+ ]
+ block = BasicBlock
+ layers = [1, 2, 5, 3]
+ self.inplanes = int(self.out_channels // 2)
+
+ self.relu = nn.ReLU()
+ if self.use_seq:
+ self.maxpool3 = nn.MaxPool2D(
+ kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.layer3 = self._make_layer(
+ block, self.out_channels_block[2], layers[2], stride=1)
+ self.conv3 = nn.Conv2D(
+ self.out_channels_block[2],
+ self.out_channels_block[2],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn3 = nn.BatchNorm(self.out_channels_block[2])
+
+ self.layer4 = self._make_layer(
+ block, self.out_channels_block[3], layers[3], stride=1)
+ self.conv4_1 = nn.Conv2D(
+ self.out_channels_block[3],
+ self.out_channels_block[3],
+ kernel_size=2,
+ stride=(2, 1),
+ padding=(0, 1),
+ bias_attr=False)
+ self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
+ self.conv4_2 = nn.Conv2D(
+ self.out_channels_block[3],
+ self.out_channels_block[3],
+ kernel_size=2,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
+
+ if self.use_cnt:
+ self.inplanes = int(self.out_channels // 2)
+ self.v_maxpool3 = nn.MaxPool2D(
+ kernel_size=2, stride=(2, 1), padding=(0, 1))
+ self.v_layer3 = self._make_layer(
+ block, self.out_channels_block[2], layers[2], stride=1)
+ self.v_conv3 = nn.Conv2D(
+ self.out_channels_block[2],
+ self.out_channels_block[2],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
+
+ self.v_layer4 = self._make_layer(
+ block, self.out_channels_block[3], layers[3], stride=1)
+ self.v_conv4_1 = nn.Conv2D(
+ self.out_channels_block[3],
+ self.out_channels_block[3],
+ kernel_size=2,
+ stride=(2, 1),
+ padding=(0, 1),
+ bias_attr=False)
+ self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
+ self.v_conv4_2 = nn.Conv2D(
+ self.out_channels_block[3],
+ self.out_channels_block[3],
+ kernel_size=2,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias_attr=False),
+ nn.BatchNorm(planes * block.expansion), )
+
+ layers = list()
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ x_1 = self.backbone(inputs)
+
+ if self.use_cnt:
+ v_x = self.v_maxpool3(x_1)
+ v_x = self.v_layer3(v_x)
+ v_x = self.v_conv3(v_x)
+ v_x = self.v_bn3(v_x)
+ visual_feature_2 = self.relu(v_x)
+
+ v_x = self.v_layer4(visual_feature_2)
+ v_x = self.v_conv4_1(v_x)
+ v_x = self.v_bn4_1(v_x)
+ v_x = self.relu(v_x)
+ v_x = self.v_conv4_2(v_x)
+ v_x = self.v_bn4_2(v_x)
+ visual_feature_3 = self.relu(v_x)
+ else:
+ visual_feature_3 = None
+ if self.use_seq:
+ x = self.maxpool3(x_1)
+ x = self.layer3(x)
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x_2 = self.relu(x)
+
+ x = self.layer4(x_2)
+ x = self.conv4_1(x)
+ x = self.bn4_1(x)
+ x = self.relu(x)
+ x = self.conv4_2(x)
+ x = self.bn4_2(x)
+ x_3 = self.relu(x)
+ else:
+ x_3 = None
+
+ return [visual_feature_3, x_3]
+
+
+class ResNetBase(nn.Layer):
+ def __init__(self, in_channels, out_channels, block, layers):
+ super(ResNetBase, self).__init__()
+
+ self.out_channels_block = [
+ int(out_channels / 4), int(out_channels / 2), out_channels,
+ out_channels
+ ]
+
+ self.inplanes = int(out_channels / 8)
+ self.conv0_1 = nn.Conv2D(
+ in_channels,
+ int(out_channels / 16),
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
+ self.conv0_2 = nn.Conv2D(
+ int(out_channels / 16),
+ self.inplanes,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn0_2 = nn.BatchNorm(self.inplanes)
+ self.relu = nn.ReLU()
+
+ self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer1 = self._make_layer(block, self.out_channels_block[0],
+ layers[0])
+ self.conv1 = nn.Conv2D(
+ self.out_channels_block[0],
+ self.out_channels_block[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn1 = nn.BatchNorm(self.out_channels_block[0])
+
+ self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+ self.layer2 = self._make_layer(
+ block, self.out_channels_block[1], layers[1], stride=1)
+ self.conv2 = nn.Conv2D(
+ self.out_channels_block[1],
+ self.out_channels_block[1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias_attr=False)
+ self.bn2 = nn.BatchNorm(self.out_channels_block[1])
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias_attr=False),
+ nn.BatchNorm(planes * block.expansion), )
+
+ layers = list()
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv0_1(x)
+ x = self.bn0_1(x)
+ x = self.relu(x)
+ x = self.conv0_2(x)
+ x = self.bn0_2(x)
+ x = self.relu(x)
+
+ x = self.maxpool1(x)
+ x = self.layer1(x)
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.maxpool2(x)
+ x = self.layer2(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ return x
+
+
+class RFLBase(nn.Layer):
+ """ Reciprocal feature learning share backbone network"""
+
+ def __init__(self, in_channels, out_channels=512):
+ super(RFLBase, self).__init__()
+ self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
+ [1, 2, 5, 3])
+
+ def forward(self, inputs):
+ return self.ConvNet(inputs)
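
Both branches of `ResNetRFL` share the `RFLBase` stem and mirror each other layer for layer; with a 1x32x100 grayscale input, each branch output should come out as `[N, 512, 1, 26]`, one feature per decode step (`max_text_length + 1`). A quick shape check, assuming the module path added above:

```python
import paddle
from ppocr.modeling.backbones.rec_resnet_rfl import ResNetRFL

model = ResNetRFL(in_channels=1, out_channels=512, use_cnt=True, use_seq=True)
cnt_feat, seq_feat = model(paddle.rand([2, 1, 32, 100]))
print(cnt_feat.shape, seq_feat.shape)   # expected: [2, 512, 1, 26] for both
```
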
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 751757e5f..ba180566c 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -38,6 +38,7 @@ def build_head(config):
from .rec_abinet_head import ABINetHead
from .rec_robustscanner_head import RobustScannerHead
from .rec_visionlan_head import VLHead
+ from .rec_rfl_head import RFLHead
# cls head
from .cls_head import ClsHead
@@ -53,7 +54,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
- 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
+ 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
]
#table head
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index ab8b119fe..d5cf1cd16 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
+ char_onehots = None
+ alpha = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py
new file mode 100644
index 000000000..b5452ec1c
--- /dev/null
+++ b/ppocr/modeling/heads/rec_rfl_head.py
@@ -0,0 +1,109 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
+"""
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+from .rec_att_head import AttentionLSTM
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+class CNTHead(nn.Layer):
+ def __init__(self,
+ embed_size=512,
+ encode_length=26,
+ out_channels=38,
+ **kwargs):
+ super(CNTHead, self).__init__()
+
+ self.out_channels = out_channels
+
+ self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
+ self.Prediction_visual = nn.Linear(encode_length * embed_size,
+ self.out_channels)
+
+ def forward(self, visual_feature):
+
+ b, c, h, w = visual_feature.shape
+ visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
+ [0, 2, 1])
+ visual_feature_num = self.Wv_fusion(visual_feature) # batch * 26 * 512
+ b, n, c = visual_feature_num.shape
+ # using visual feature directly calculate the text length
+ visual_feature_num = visual_feature_num.reshape([b, n * c])
+ prediction_visual = self.Prediction_visual(visual_feature_num)
+
+ return prediction_visual
+
+
+class RFLHead(nn.Layer):
+ def __init__(self,
+ in_channels=512,
+ hidden_size=256,
+                 batch_max_length=25,
+ out_channels=38,
+ use_cnt=True,
+ use_seq=True,
+ **kwargs):
+
+ super(RFLHead, self).__init__()
+ assert use_cnt or use_seq
+ self.use_cnt = use_cnt
+ self.use_seq = use_seq
+ if self.use_cnt:
+ self.cnt_head = CNTHead(
+ embed_size=in_channels,
+                encode_length=batch_max_length + 1,
+ out_channels=out_channels,
+ **kwargs)
+ if self.use_seq:
+ self.seq_head = AttentionLSTM(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ hidden_size=hidden_size,
+ **kwargs)
+        self.batch_max_length = batch_max_length
+ self.num_class = out_channels
+ self.apply(self.init_weights)
+
+ def init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ kaiming_init_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ zeros_(m.bias)
+
+ def forward(self, x, targets=None):
+ cnt_inputs, seq_inputs = x
+ if self.use_cnt:
+ cnt_outputs = self.cnt_head(cnt_inputs)
+ else:
+ cnt_outputs = None
+ if self.use_seq:
+ if self.training:
+ seq_outputs = self.seq_head(seq_inputs, targets[0],
+                                            self.batch_max_length)
+ else:
+ seq_outputs = self.seq_head(seq_inputs, None,
+                                            self.batch_max_length)
+ else:
+ seq_outputs = None
+
+ return cnt_outputs, seq_outputs
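
`RFLHead` returns a `(counting, sequence)` pair; in eval mode the attention branch decodes greedily instead of using teacher forcing. A forward sketch on dummy neck outputs (the shapes are assumptions matching the configs above):

```python
import paddle
from ppocr.modeling.heads.rec_rfl_head import RFLHead

head = RFLHead(in_channels=512, hidden_size=256, batch_max_length=25,
               out_channels=38, use_cnt=True, use_seq=True)
head.eval()
cnt_feat = paddle.rand([2, 512, 1, 26])   # CNTHead input
seq_feat = paddle.rand([2, 26, 512])      # AttentionLSTM input
cnt_out, seq_out = head((cnt_feat, seq_feat))
print(cnt_out.shape)                      # [2, 38] per-class count prediction
```
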
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index c7e8dd068..a94d223a1 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -27,9 +27,11 @@ def build_neck(config):
from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN
from .ct_fpn import CTFPN
+ from .rf_adaptor import RFAdaptor
support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
- 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
+ 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
+ 'RFAdaptor'
]
module_name = config.pop('name')
diff --git a/ppocr/modeling/necks/rf_adaptor.py b/ppocr/modeling/necks/rf_adaptor.py
new file mode 100644
index 000000000..94590127b
--- /dev/null
+++ b/ppocr/modeling/necks/rf_adaptor.py
@@ -0,0 +1,137 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
+"""
+
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+class S2VAdaptor(nn.Layer):
+ """ Semantic to Visual adaptation module"""
+
+ def __init__(self, in_channels=512):
+ super(S2VAdaptor, self).__init__()
+
+ self.in_channels = in_channels # 512
+
+ # feature strengthen module, channel attention
+ self.channel_inter = nn.Linear(
+ self.in_channels, self.in_channels, bias_attr=False)
+ self.channel_bn = nn.BatchNorm1D(self.in_channels)
+ self.channel_act = nn.ReLU()
+ self.apply(self.init_weights)
+
+ def init_weights(self, m):
+ if isinstance(m, nn.Conv2D):
+ kaiming_init_(m.weight)
+ if isinstance(m, nn.Conv2D) and m.bias is not None:
+ zeros_(m.bias)
+ elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
+ zeros_(m.bias)
+ ones_(m.weight)
+
+ def forward(self, semantic):
+ semantic_source = semantic # batch, channel, height, width
+
+ # feature transformation
+ semantic = semantic.squeeze(2).transpose(
+ [0, 2, 1]) # batch, width, channel
+ channel_att = self.channel_inter(semantic) # batch, width, channel
+ channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width
+ channel_bn = self.channel_bn(channel_att) # batch, channel, width
+ channel_att = self.channel_act(channel_bn) # batch, channel, width
+
+ # Feature enhancement
+ channel_output = semantic_source * channel_att.unsqueeze(
+ -2) # batch, channel, 1, width
+
+ return channel_output
+
+
+class V2SAdaptor(nn.Layer):
+ """ Visual to Semantic adaptation module"""
+
+ def __init__(self, in_channels=512, return_mask=False):
+ super(V2SAdaptor, self).__init__()
+
+ # parameter initialization
+ self.in_channels = in_channels
+ self.return_mask = return_mask
+
+ # output transformation
+ self.channel_inter = nn.Linear(
+ self.in_channels, self.in_channels, bias_attr=False)
+ self.channel_bn = nn.BatchNorm1D(self.in_channels)
+ self.channel_act = nn.ReLU()
+
+ def forward(self, visual):
+ # Feature enhancement
+ visual = visual.squeeze(2).transpose([0, 2, 1]) # batch, width, channel
+ channel_att = self.channel_inter(visual) # batch, width, channel
+ channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width
+ channel_bn = self.channel_bn(channel_att) # batch, channel, width
+ channel_att = self.channel_act(channel_bn) # batch, channel, width
+
+ # size alignment
+ channel_output = channel_att.unsqueeze(-2) # batch, width, channel
+
+ if self.return_mask:
+ return channel_output, channel_att
+ return channel_output
+
+
+class RFAdaptor(nn.Layer):
+ def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
+ super(RFAdaptor, self).__init__()
+ if use_v2s is True:
+ self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs)
+ else:
+ self.neck_v2s = None
+ if use_s2v is True:
+ self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs)
+ else:
+ self.neck_s2v = None
+ self.out_channels = in_channels
+
+ def forward(self, x):
+ visual_feature, rcg_feature = x
+ if visual_feature is not None:
+ batch, source_channels, v_source_height, v_source_width = visual_feature.shape
+ visual_feature = visual_feature.reshape(
+ [batch, source_channels, 1, v_source_height * v_source_width])
+
+ if self.neck_v2s is not None:
+ v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
+ else:
+ v_rcg_feature = rcg_feature
+
+ if self.neck_s2v is not None:
+ v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)
+ else:
+ v_visual_feature = visual_feature
+ if v_rcg_feature is not None:
+ batch, source_channels, source_height, source_width = v_rcg_feature.shape
+ v_rcg_feature = v_rcg_feature.reshape(
+ [batch, source_channels, 1, source_height * source_width])
+
+ v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
+ return v_visual_feature, v_rcg_feature
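
`RFAdaptor` cross-feeds the two branches (the visual feature gates the semantic one, and the semantic feature is added back to the visual one) and flattens the sequence feature to `[B, W, C]` for the attention head. A shape sketch under the same assumptions:

```python
import paddle
from ppocr.modeling.necks.rf_adaptor import RFAdaptor

neck = RFAdaptor(in_channels=512, use_v2s=True, use_s2v=True)
visual = paddle.rand([2, 512, 1, 26])    # counting-branch feature
rcg = paddle.rand([2, 512, 1, 26])       # sequence-branch feature
v_visual, v_rcg = neck((visual, rcg))
print(v_visual.shape, v_rcg.shape)       # [2, 512, 1, 26] and [2, 26, 512]
```
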
diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py
index a6bd2ebb4..b92954c9c 100644
--- a/ppocr/optimizer/__init__.py
+++ b/ppocr/optimizer/__init__.py
@@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
if 'clip_norm' in config:
clip_norm = config.pop('clip_norm')
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+ elif 'clip_norm_global' in config:
+ clip_norm = config.pop('clip_norm_global')
+ grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(learning_rate=lr,
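
This maps a `clip_norm_global` key in the Optimizer config to Paddle's global-norm clipper, in contrast to the per-tensor `ClipGradByNorm` selected by `clip_norm`. The equivalent direct construction, sketched on a toy layer:

```python
import paddle

layer = paddle.nn.Linear(4, 4)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)   # clip_norm_global: 5.0
opt = paddle.optimizer.AdamW(learning_rate=0.001,
                             parameters=layer.parameters(),
                             grad_clip=clip)
```
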
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 35b7a6800..b5715967b 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
- SPINLabelDecode, VLLabelDecode
+ SPINLabelDecode, VLLabelDecode, RFLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
- 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
+ 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 749060a05..e754c950b 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx
+class RFLLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(RFLLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+        beg_idx, end_idx = ignored_tokens
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if int(text_index[batch_idx][idx]) == int(end_idx):
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list).tolist()))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ cnt_pred, preds = preds
+ if preds is not None:
+
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+ if label is None:
+ return text
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+ else:
+ cnt_length = []
+ for lens in cnt_pred:
+ length = round(paddle.sum(lens).item())
+ cnt_length.append(length)
+ if label is None:
+ return cnt_length
+ label = self.decode(label, is_remove_duplicate=False)
+ length = [len(res[0]) for res in label]
+ return cnt_length, length
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+            assert False, "unsupported type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
+
+
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
diff --git a/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml
new file mode 100644
index 000000000..b4f18f5c0
--- /dev/null
+++ b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml
@@ -0,0 +1,111 @@
+Global:
+ use_gpu: True
+ epoch_num: 6
+ log_smooth_window: 20
+ print_batch_step: 50
+ save_model_dir: ./output/rec/rec_resnet_rfl/
+ save_epoch_step: 1
+  # evaluation is run every 5000 iterations
+ eval_batch_step: [0, 5000]
+ cal_metric_during_train: False
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words_en/word_10.png
+ # for data or label process
+ character_dict_path:
+ max_text_length: 25
+ infer_mode: False
+ use_space_char: False
+ save_res_path: ./output/rec/rec_resnet_rfl.txt
+
+
+Optimizer:
+ name: AdamW
+ beta1: 0.9
+ beta2: 0.999
+ weight_decay: 0.0
+ clip_norm_global: 5.0
+ lr:
+ name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
+Architecture:
+ model_type: rec
+ algorithm: RFL
+ in_channels: 1
+ Transform:
+ name: TPS
+ num_fiducial: 20
+ loc_lr: 1.0
+ model_name: large
+ Backbone:
+ name: ResNetRFL
+ use_cnt: True
+ use_seq: True
+ Neck:
+ name: RFAdaptor
+ use_v2s: True
+ use_s2v: True
+ Head:
+ name: RFLHead
+ in_channels: 512
+ hidden_size: 256
+    batch_max_length: 25
+ out_channels: 38
+ use_cnt: True
+ use_seq: True
+
+Loss:
+ name: RFLLoss
+
+PostProcess:
+ name: RFLLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data/
+ label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: True
+ batch_size_per_card: 64
+ drop_last: True
+ num_workers: 8
+
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/ic15_data
+ label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - RFLLabelEncode: # Class handling label
+ - RFLRecResizeImg:
+ image_shape: [1, 32, 100]
+ interpolation: 2
+ - KeepKeys:
+ keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 256
+ num_workers: 8
diff --git a/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt
new file mode 100644
index 000000000..091e962b2
--- /dev/null
+++ b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_resnet_rfl
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_resnet_rfl_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,32,100]}]
diff --git a/tools/export_model.py b/tools/export_model.py
index 8610df83e..9c23060ee 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -99,7 +99,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
- elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
+ elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 176e2c68e..697e9da43 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -100,6 +100,12 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char,
"rm_symbol": True
}
+ elif self.rec_algorithm == 'RFL':
+ postprocess_params = {
+ 'name': 'RFLLabelDecode',
+ "character_dict_path": None,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
@@ -143,6 +149,16 @@ class TextRecognizer(object):
else:
norm_img = norm_img.astype(np.float32) / 128. - 1.
return norm_img
+ elif self.rec_algorithm == 'RFL':
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+ resized_image = resized_image.astype('float32')
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ resized_image -= 0.5
+ resized_image /= 0.5
+ return resized_image
assert imgC == img.shape[2]
imgW = int((imgH * max_wh_ratio))
diff --git a/tools/program.py b/tools/program.py
index 9117d51b9..129e926be 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -217,7 +217,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
- "RobustScanner"
+ "RobustScanner", "RFL"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
@@ -625,7 +625,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
- 'Gestalt', 'SLANet', 'RobustScanner', 'CT'
+ 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
]
if use_xpu:
From 394ee0117767a28e22964a66ace799f14e3e75b0 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Wed, 28 Sep 2022 17:40:43 +0800
Subject: [PATCH 02/10] Update dataset paths
---
configs/rec/rec_resnet_rfl_att.yml | 4 ++--
configs/rec/rec_resnet_rfl_visual.yml | 4 ++--
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml
index f8332e082..3c9c1f886 100644
--- a/configs/rec/rec_resnet_rfl_att.yml
+++ b/configs/rec/rec_resnet_rfl_att.yml
@@ -72,7 +72,7 @@ Metric:
Train:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/rfl_dataset2/training
+ data_dir: ./train_data/data_lmdb_release/training
transforms:
- DecodeImage: # load image
@@ -94,7 +94,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/rfl_dataset2/evaluation
+ data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
img_mode: BGR
diff --git a/configs/rec/rec_resnet_rfl_visual.yml b/configs/rec/rec_resnet_rfl_visual.yml
index 438d2ef0c..5eaea08ce 100644
--- a/configs/rec/rec_resnet_rfl_visual.yml
+++ b/configs/rec/rec_resnet_rfl_visual.yml
@@ -70,7 +70,7 @@ Metric:
Train:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/rfl_dataset2/training
+ data_dir: ./train_data/data_lmdb_release/training
transforms:
- DecodeImage: # load image
img_mode: BGR
@@ -91,7 +91,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
- data_dir: ./train_data/rfl_dataset2/evaluation
+ data_dir: ./train_data/data_lmdb_release/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
From 5a380afb1ecb49a44fbfbe011dc1e1827d43affa Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Wed, 28 Sep 2022 17:42:13 +0800
Subject: [PATCH 03/10] Add source reference
---
ppocr/losses/rec_rfl_loss.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py
index 8e9d7d039..0921406c1 100644
--- a/ppocr/losses/rec_rfl_loss.py
+++ b/ppocr/losses/rec_rfl_loss.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_common/models/loss/cross_entropy_loss.py
+"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
From b580fa0517bec79b6fceab2262d125376da564d1 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Wed, 28 Sep 2022 17:43:00 +0800
Subject: [PATCH 04/10] Remove unused function
---
ppocr/metrics/rec_metric.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 34d6ff3a3..4758e71d0 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -81,11 +81,6 @@ class CNTMetric(object):
self.eps = 1e-5
self.reset()
- def _normalize_text(self, text):
- text = ''.join(
- filter(lambda x: x in (string.digits + string.ascii_letters), text))
- return text.lower()
-
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
@@ -102,7 +97,6 @@ class CNTMetric(object):
"""
return metrics {
'acc': 0,
- 'norm_edit_dis': 0,
}
"""
acc = 1.0 * self.correct_num / (self.all_num + self.eps)
From 154f42f1b07f273fd2ebb550c7074934c81d16e8 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Wed, 28 Sep 2022 17:44:13 +0800
Subject: [PATCH 05/10] Apply softmax at inference time
---
ppocr/modeling/heads/rec_att_head.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index d5cf1cd16..6349ee0c2 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -169,7 +169,8 @@ class AttentionLSTM(nn.Layer):
next_input = probs_step.argmax(axis=1)
targets = next_input
-
+ if not self.training:
+ probs = paddle.nn.functional.softmax(probs, axis=2)
return probs
From 035d7e39069e9f18109812401a5cc39c6a90b016 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Wed, 28 Sep 2022 17:50:28 +0800
Subject: [PATCH 06/10] Remove redundant blank line
---
configs/rec/rec_resnet_rfl_att.yml | 1 -
1 file changed, 1 deletion(-)
diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml
index 3c9c1f886..b9fb74176 100644
--- a/configs/rec/rec_resnet_rfl_att.yml
+++ b/configs/rec/rec_resnet_rfl_att.yml
@@ -73,7 +73,6 @@ Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training
-
transforms:
- DecodeImage: # load image
img_mode: BGR
From 6735cced20bf905b88bb9fe9d63f811cca896481 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Thu, 29 Sep 2022 14:21:47 +0800
Subject: [PATCH 07/10] Add RFL documentation
---
doc/doc_ch/algorithm_overview.md | 2 +
doc/doc_ch/algorithm_rec_rfl.md | 161 ++++++++++++++++++++++++++++
doc/doc_en/algorithm_overview_en.md | 3 +-
doc/doc_en/algorithm_rec_rfl_en.md | 143 ++++++++++++++++++++++++
4 files changed, 308 insertions(+), 1 deletion(-)
create mode 100644 doc/doc_ch/algorithm_rec_rfl.md
create mode 100644 doc/doc_en/algorithm_rec_rfl_en.md
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index ecb0e9dfe..4351fdbfc 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -79,6 +79,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
- [x] [VisionLAN](./algorithm_rec_visionlan.md)
- [x] [SPIN](./algorithm_rec_spin.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner.md)
+- [x] [RFL](./algorithm_rec_rfl.md)
 Following the [DTRB](https://arxiv.org/abs/1904.01906)[3] text recognition training and evaluation protocol, models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. Results are as follows:
@@ -102,6 +103,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广
 |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
+|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
diff --git a/doc/doc_ch/algorithm_rec_rfl.md b/doc/doc_ch/algorithm_rec_rfl.md
new file mode 100644
index 000000000..5135e77a1
--- /dev/null
+++ b/doc/doc_ch/algorithm_rec_rfl.md
@@ -0,0 +1,161 @@
+# Scene Text Recognition Algorithm: RFL
+
+- [1. Algorithm Introduction](#1)
+- [2. Environment Setup](#2)
+- [3. Model Training, Evaluation, and Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving Deployment](#4-3)
+  - [4.4 More Deployment Options](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Algorithm Introduction
+
+Paper:
+> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf)
+> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan
+> ICDAR, 2021
+
+
+
+`RFL` is trained on the MJSynth and SynthText text recognition datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets; the reproduced results are as follows:
+
+|Model|Backbone|Config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+
+
+## 2. Environment Setup
+Please refer to [Environment Preparation](./environment.md) to configure the PaddleOCR environment, and to [Project Clone](./clone.md) to clone the project code.
+
+
+
+## 3. Model Training, Evaluation, and Prediction
+
+
+### 3.1 Model Training
+
+Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code; to train the `RFL` recognition model, **switch the configuration file** to the `RFL` [config](../../configs/rec/rec_resnet_rfl_att.yml).
+
+#### Start Training
+
+
+After data preparation is complete, start training with the following commands:
+```shell
+#step1: train the CNT branch
+#Single-GPU training (long training time, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#Multi-GPU training; specify GPU ids with the --gpus option
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#step2: jointly train the CNT and Att branches; set pretrained_model to a local path.
+#Single-GPU training (long training time, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy
+
+#Multi-GPU training; specify GPU ids with the --gpus option
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy
+```
+
+
+### 3.2 Evaluation
+
+Download the trained [model file](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) and evaluate it with the following command:
+
+```shell
+# Set pretrained_model to a local path.
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
+```
+
+
+### 3.3 Prediction
+
+Predict on a single image with the following command:
+```shell
+# Set pretrained_model to a local path.
+python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
+# To predict all images in a folder, set infer_img to a directory, e.g. Global.infer_img='./doc/imgs_words_en/'.
+```
+
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the best model obtained from training into an inference model. Taking the trained model as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)), run the following command:
+
+```shell
+# Set pretrained_model to a local path.
+python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att/
+```
+**Note:**
+- If you trained the model on your own dataset and modified the dictionary file, check that `character_dict_path` in the configuration file points to your dictionary.
+- If you changed the input size during training, update the `infer_shape` for RFL in `tools/export_model.py`.
+
+After a successful conversion, the directory contains three files:
+```
+/inference/rec_resnet_rfl_att/
+    ├── inference.pdiparams        # parameter file of the recognition inference model
+    ├── inference.pdiparams.info   # parameter info of the recognition inference model, can be ignored
+    └── inference.pdmodel          # program file of the recognition inference model
+```
+
+Run model inference with the following command:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
+# To predict all images in a folder, set image_dir to a directory, e.g. --image_dir='./doc/imgs_words_en/'.
+```
+
+
+
+After running the command, the prediction result (recognized text and score) for the image above is printed to the screen:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
+```
+
+**Note**:
+
+- The model above was trained with an input resolution of [1, 32, 100]; set `rec_image_shape` to the image shape you used for training.
+- At inference time, set `rec_char_dict_path` to specify the dictionary; if you modified the dictionary, point this parameter to your dictionary file.
+- If you changed the preprocessing method, update the RFL preprocessing in `tools/infer/predict_rec.py` accordingly.
+
+
+
+### 4.2 C++ Inference
+
+Not supported yet, because the C++ pre- and post-processing does not yet support RFL.
+
+
+### 4.3 Serving Deployment
+
+Not supported yet
+
+
+### 4.4 More Deployment Options
+
+Not supported yet
+
+
+## 5. FAQ
+
+
+## Citation
+
+```bibtex
+@inproceedings{2021Reciprocal,
+  title = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
+  author = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.},
+  booktitle = {ICDAR},
+  year = {2021},
+  url = {https://arxiv.org/abs/2105.06229}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index bca22f784..f7ef7ad4b 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -76,6 +76,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
- [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
+- [x] [RFL](./algorithm_rec_rfl_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -99,7 +100,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
-
+|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
diff --git a/doc/doc_en/algorithm_rec_rfl_en.md b/doc/doc_en/algorithm_rec_rfl_en.md
new file mode 100644
index 000000000..8f0adfbe3
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_rfl_en.md
@@ -0,0 +1,143 @@
+# RFL
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+ - [3.1 Training](#3-1)
+ - [3.2 Evaluation](#3-2)
+ - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+ - [4.1 Python Inference](#4-1)
+ - [4.2 C++ Inference](#4-2)
+ - [4.3 Serving](#4-3)
+ - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf)
+> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan
+> ICDAR, 2021
+
+Trained on the MJSynth and SynthText text recognition datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE datasets, the reproduced results are as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after data preparation is complete, start training with the following commands:
+
+```shell
+#step1: train the CNT branch
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#step2: jointly train the CNT and Att branches
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Evaluation:
+
+```shell
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```shell
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the model saved during RFL text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)) with the following command:
+
+```shell
+python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att
+```
+
+**Note:**
+- If you trained the model on your own dataset and modified the dictionary file, make sure `character_dict_path` in the configuration file points to your dictionary.
+- If you changed the input size during training, update the `infer_shape` for RFL in `tools/export_model.py`.
+
+After the conversion is successful, there are three files in the directory:
+```
+/inference/rec_resnet_rfl_att/
+ ├── inference.pdiparams
+ ├── inference.pdiparams.info
+ └── inference.pdmodel
+```
+
+
+For RFL text recognition model inference, run the following command:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
+```
+
+
+
+After running the command, the prediction result (recognized text and score) for the image above is printed to the screen:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+## Citation
+
+```bibtex
+@inproceedings{2021Reciprocal,
+  title = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
+  author = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.},
+  booktitle = {ICDAR},
+  year = {2021},
+  url = {https://arxiv.org/abs/2105.06229}
+}
+```
From c459b7256538b9e788c81c540b615f4fe7911b81 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Sat, 8 Oct 2022 11:20:36 +0800
Subject: [PATCH 08/10] Add inference support for the RFL CNT branch
---
ppocr/modeling/heads/rec_rfl_head.py | 5 ++---
ppocr/postprocess/rec_postprocess.py | 10 ++++++----
tools/infer_rec.py | 14 ++++++++++----
3 files changed, 18 insertions(+), 11 deletions(-)
diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py
index b5452ec1c..1ded8cde9 100644
--- a/ppocr/modeling/heads/rec_rfl_head.py
+++ b/ppocr/modeling/heads/rec_rfl_head.py
@@ -103,7 +103,6 @@ class RFLHead(nn.Layer):
else:
seq_outputs = self.seq_head(seq_inputs, None,
self.batch_max_legnth)
+ return cnt_outputs, seq_outputs
else:
- seq_outputs = None
-
- return cnt_outputs, seq_outputs
+ return cnt_outputs
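Editor's note: with this change, `RFLHead` returns a bare count tensor when only the CNT branch is enabled, instead of a `(cnt_outputs, None)` tuple. A hedged sketch of the resulting control flow (the function and argument names below are illustrative, not the exact PaddleOCR code):

```python
def rfl_head_forward(cnt_outputs, seq_inputs, seq_head, use_seq, max_len):
    """Return a tuple for the joint model, a single tensor for CNT-only."""
    if use_seq:
        seq_outputs = seq_head(seq_inputs, None, max_len)
        return cnt_outputs, seq_outputs  # joint CNT + Att model
    return cnt_outputs                   # visual (CNT-only) model
```

Downstream components can then distinguish the two cases by output type rather than by probing for `None`, which patch 10 below completes on the loss and post-processing side.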
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index e754c950b..40ba5c208 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
- cnt_pred, preds = preds
- if preds is not None:
-
+ if len(preds) == 2:
+ cnt_pred, preds = preds
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
@@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label
else:
+ cnt_pred = preds
+ if isinstance(cnt_pred, paddle.Tensor):
+ cnt_pred = cnt_pred.numpy()
cnt_length = []
for lens in cnt_pred:
- length = round(paddle.sum(lens).item())
+ length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
return cnt_length
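Editor's note: since `cnt_pred` has just been converted to a NumPy array, summing it with `paddle.sum` would fail, so the hunk switches to `np.sum`. A small sketch of the CNT decode step, assuming the counting branch emits per-character occurrence scores whose sum approximates the text length (values are made up):

```python
import numpy as np

# One image, per-class occurrence scores from the CNT branch (dummy values).
cnt_pred = np.array([[0.98, 1.02, 0.99, 0.01]])

cnt_length = [round(np.sum(lens)) for lens in cnt_pred]
print(cnt_length)  # [3] -> predicted number of characters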
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 14b14544e..cb8a6ec30 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -97,7 +97,8 @@ def main():
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner":
- op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
+ op[op_name][
+ 'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
@@ -136,9 +137,10 @@ def main():
if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
- img_metas = [paddle.to_tensor(valid_ratio),
- paddle.to_tensor(word_positons),
- ]
+ img_metas = [
+ paddle.to_tensor(valid_ratio),
+ paddle.to_tensor(word_positons),
+ ]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
@@ -160,6 +162,10 @@ def main():
"score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info, ensure_ascii=False)
+ elif isinstance(post_result, list) and isinstance(post_result[0],
+ int):
+ # for RFLearning CNT branch
+ info = str(post_result[0])
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
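Editor's note: the new `elif` recognizes the CNT-only output by its type, since the visual model's post-process returns a list of integer lengths while recognition models return `(text, score)` pairs. A simplified sketch of the dispatch (the dict case handled earlier in the real code is omitted; values are illustrative):

```python
def format_result(post_result):
    """Format one image's post-processed prediction for logging."""
    if isinstance(post_result, list) and isinstance(post_result[0], int):
        return str(post_result[0])      # RFL CNT branch: predicted length
    return post_result[0][0] + "\t" + str(post_result[0][1])  # text + score

print(format_result([3]))               # -> "3"
print(format_result([("pain", 0.99)]))  # -> "pain\t0.99"
```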
From 06e734a51bfb922ddb8e7def62ba6b6e4f1da253 Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Sun, 9 Oct 2022 10:02:41 +0800
Subject: [PATCH 09/10] Update the documentation
---
doc/doc_ch/algorithm_rec_rfl.md | 2 +-
doc/doc_en/algorithm_rec_rfl_en.md | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/doc/doc_ch/algorithm_rec_rfl.md b/doc/doc_ch/algorithm_rec_rfl.md
index 5135e77a1..0906d4570 100644
--- a/doc/doc_ch/algorithm_rec_rfl.md
+++ b/doc/doc_ch/algorithm_rec_rfl.md
@@ -41,7 +41,7 @@
### 3.1 Model Training
-Please refer to the [text recognition training tutorial](./recognition.md). PaddleOCR modularizes the code; to train the `RFL` recognition model, **switch the configuration file** to the `RFL` [config](../../configs/rec/rec_resnet_rfl_att.yml).
+PaddleOCR modularizes the code; to train the `RFL` recognition model, **switch the configuration file** to the `RFL` [config](../../configs/rec/rec_resnet_rfl_att.yml).
#### Start Training
diff --git a/doc/doc_en/algorithm_rec_rfl_en.md b/doc/doc_en/algorithm_rec_rfl_en.md
index 8f0adfbe3..273210c6c 100644
--- a/doc/doc_en/algorithm_rec_rfl_en.md
+++ b/doc/doc_en/algorithm_rec_rfl_en.md
@@ -36,7 +36,7 @@ Please refer to ["Environment Preparation"](./environment_en.md) to configure th
## 3. Model Training / Evaluation / Prediction
-Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
From 483e50382627e807d0e1c6adad9438f00945b9cc Mon Sep 17 00:00:00 2001
From: zhiminzhang0830 <452516515@qq.com>
Date: Mon, 10 Oct 2022 12:12:47 +0800
Subject: [PATCH 10/10] Determine whether the model is visual-only by variable type
---
ppocr/losses/rec_rfl_loss.py | 18 +++++++++++-------
ppocr/postprocess/rec_postprocess.py | 21 +++++++++++----------
2 files changed, 22 insertions(+), 17 deletions(-)
diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py
index 0921406c1..be0f06d90 100644
--- a/ppocr/losses/rec_rfl_loss.py
+++ b/ppocr/losses/rec_rfl_loss.py
@@ -36,22 +36,26 @@ class RFLLoss(nn.Layer):
self.total_loss = {}
total_loss = 0.0
+ if isinstance(predicts, tuple) or isinstance(predicts, list):
+ cnt_outputs, seq_outputs = predicts
+ else:
+ cnt_outputs, seq_outputs = predicts, None
# batch [image, label, length, cnt_label]
- if predicts[0] is not None:
- cnt_loss = self.cnt_loss(predicts[0],
+ if cnt_outputs is not None:
+ cnt_loss = self.cnt_loss(cnt_outputs,
paddle.cast(batch[3], paddle.float32))
self.total_loss['cnt_loss'] = cnt_loss
total_loss += cnt_loss
- if predicts[1] is not None:
+ if seq_outputs is not None:
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
- batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[
- 1].shape[1], predicts[1].shape[2]
- assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
+ batch_size, num_steps, num_classes = seq_outputs.shape[
+ 0], seq_outputs.shape[1], seq_outputs.shape[2]
+ assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
- inputs = predicts[1][:, :-1, :]
+ inputs = seq_outputs[:, :-1, :]
targets = targets[:, 1:]
inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
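Editor's note: the loss now unpacks by type as well, so a CNT-only model trains with the same class. A minimal sketch of the combined objective under the same assumptions as the diff (`cnt_loss` regresses character counts, `seq_loss` is a cross-entropy over the attention decoder's shifted sequence; names mirror the patch, but the wrapper itself is illustrative):

```python
import paddle

def rfl_loss(predicts, batch, cnt_loss, seq_loss):
    """batch = [image, label, length, cnt_label], as noted in the diff."""
    if isinstance(predicts, (tuple, list)):
        cnt_outputs, seq_outputs = predicts        # joint CNT + Att model
    else:
        cnt_outputs, seq_outputs = predicts, None  # visual (CNT-only) model

    total_loss = 0.0
    if cnt_outputs is not None:
        # Counting branch: regress the per-character count label.
        total_loss += cnt_loss(cnt_outputs, paddle.cast(batch[3], paddle.float32))
    if seq_outputs is not None:
        # Attention branch: predict step t+1 from step t, so drop the last
        # output step and the first (start) token of the target.
        inputs = paddle.reshape(seq_outputs[:, :-1, :], [-1, seq_outputs.shape[-1]])
        targets = paddle.reshape(batch[1].astype('int64')[:, 1:], [-1])
        total_loss += seq_loss(inputs, targets)
    return total_loss
```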
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 74f4e880b..59b5254e4 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode):
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
- if len(preds) == 2:
- cnt_pred, preds = preds
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds_idx = preds.argmax(axis=2)
- preds_prob = preds.max(axis=2)
+        # the joint model returns a (cnt_outputs, seq_outputs) tuple
+ if isinstance(preds, tuple) or isinstance(preds, list):
+ cnt_outputs, seq_outputs = preds
+ if isinstance(seq_outputs, paddle.Tensor):
+ seq_outputs = seq_outputs.numpy()
+ preds_idx = seq_outputs.argmax(axis=2)
+ preds_prob = seq_outputs.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
@@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode):
return text, label
else:
- cnt_pred = preds
- if isinstance(cnt_pred, paddle.Tensor):
- cnt_pred = cnt_pred.numpy()
+ cnt_outputs = preds
+ if isinstance(cnt_outputs, paddle.Tensor):
+ cnt_outputs = cnt_outputs.numpy()
cnt_length = []
- for lens in cnt_pred:
+ for lens in cnt_outputs:
length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
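Editor's note: after this patch the decoder and the loss share one convention: a tuple or list means the joint model, a bare tensor means the CNT-only model. A simplified sketch of the decode-side dispatch (shapes and the helper name are assumptions for illustration):

```python
import numpy as np

def decode_rfl(preds):
    """Dispatch on output type, loosely mirroring RFLLabelDecode.__call__."""
    if isinstance(preds, (tuple, list)):        # joint model: decode Att branch
        _, seq_outputs = preds
        preds_idx = seq_outputs.argmax(axis=2)  # [batch, steps] char indices
        preds_prob = seq_outputs.max(axis=2)    # per-step confidences
        return preds_idx, preds_prob
    # CNT-only model: sum occurrence scores into predicted text lengths.
    return [round(np.sum(lens)) for lens in preds]
```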