From 0002349df3baed66b5149698cbb8cb5d4c1df3ee Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Tue, 27 Sep 2022 10:54:31 +0800 Subject: [PATCH 01/10] add text recognition algorithm rflearning --- configs/rec/rec_resnet_rfl_att.yml | 113 ++++++ configs/rec/rec_resnet_rfl_visual.yml | 110 ++++++ ppocr/data/imaug/__init__.py | 3 +- ppocr/data/imaug/label_ops.py | 56 +++ ppocr/data/imaug/rec_img_aug.py | 42 ++- ppocr/losses/__init__.py | 3 +- ppocr/losses/rec_rfl_loss.py | 61 +++ ppocr/metrics/__init__.py | 4 +- ppocr/metrics/rec_metric.py | 40 +- ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/backbones/rec_resnet_rfl.py | 348 ++++++++++++++++++ ppocr/modeling/heads/__init__.py | 3 +- ppocr/modeling/heads/rec_att_head.py | 2 + ppocr/modeling/heads/rec_rfl_head.py | 109 ++++++ ppocr/modeling/necks/__init__.py | 4 +- ppocr/modeling/necks/rf_adaptor.py | 137 +++++++ ppocr/optimizer/__init__.py | 3 + ppocr/postprocess/__init__.py | 4 +- ppocr/postprocess/rec_postprocess.py | 86 +++++ .../configs/rec_resnet_rfl/rec_resnet_rfl.yml | 111 ++++++ .../rec_resnet_rfl/train_infer_python.txt | 53 +++ tools/export_model.py | 2 +- tools/infer/predict_rec.py | 16 + tools/program.py | 4 +- 24 files changed, 1301 insertions(+), 16 deletions(-) create mode 100644 configs/rec/rec_resnet_rfl_att.yml create mode 100644 configs/rec/rec_resnet_rfl_visual.yml create mode 100644 ppocr/losses/rec_rfl_loss.py create mode 100644 ppocr/modeling/backbones/rec_resnet_rfl.py create mode 100644 ppocr/modeling/heads/rec_rfl_head.py create mode 100644 ppocr/modeling/necks/rf_adaptor.py create mode 100644 test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml create mode 100644 test_tipc/configs/rec_resnet_rfl/train_infer_python.txt diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml new file mode 100644 index 000000000..f8332e082 --- /dev/null +++ b/configs/rec/rec_resnet_rfl_att.yml @@ -0,0 +1,113 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 20 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_resnet_rfl_att/ + save_epoch_step: 1 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + cal_metric_during_train: True + pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/rec_resnet_rfl.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.0 + clip_norm_global: 5.0 + lr: + name: Piecewise + decay_epochs : [3, 4, 5] + values : [0.001, 0.0003, 0.00009, 0.000027] + +Architecture: + model_type: rec + algorithm: RFL + in_channels: 1 + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 1.0 + model_name: large + Backbone: + name: ResNetRFL + use_cnt: True + use_seq: True + Neck: + name: RFAdaptor + use_v2s: True + use_s2v: True + Head: + name: RFLHead + in_channels: 512 + hidden_size: 256 + batch_max_legnth: 25 + out_channels: 38 + use_cnt: True + use_seq: True + +Loss: + name: RFLLoss + # ignore_index: 0 + +PostProcess: + name: RFLLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/rfl_dataset2/training + + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling 
label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + padding: false + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 64 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/rfl_dataset2/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + padding: false + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/configs/rec/rec_resnet_rfl_visual.yml b/configs/rec/rec_resnet_rfl_visual.yml new file mode 100644 index 000000000..438d2ef0c --- /dev/null +++ b/configs/rec/rec_resnet_rfl_visual.yml @@ -0,0 +1,110 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 20 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_resnet_rfl_visual/ + save_epoch_step: 1 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/rec_resnet_rfl_visual.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.0 + clip_norm_global: 5.0 + lr: + name: Piecewise + decay_epochs : [3, 4, 5] + values : [0.001, 0.0003, 0.00009, 0.000027] + +Architecture: + model_type: rec + algorithm: RFL + in_channels: 1 + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 1.0 + model_name: large + Backbone: + name: ResNetRFL + use_cnt: True + use_seq: False + Neck: + name: RFAdaptor + use_v2s: False + use_s2v: False + Head: + name: RFLHead + in_channels: 512 + hidden_size: 256 + batch_max_legnth: 25 + out_channels: 38 + use_cnt: True + use_seq: False +Loss: + name: RFLLoss + +PostProcess: + name: RFLLabelDecode + +Metric: + name: CNTMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/rfl_dataset2/training + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + padding: false + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 64 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/rfl_dataset2/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + padding: false + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 863988ccc..db0a489d5 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt 
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ - ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg + ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \ + RFLRecResizeImg from .ssl_img_aug import SSLRotateResize from .randaugment import RandAugment from .copy_paste import CopyPaste diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index dbfb93176..590caf1c4 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode): return idx +class RFLLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length, + character_dict_path=None, + use_space_char=False, + **kwargs): + super(RFLLabelEncode, self).__init__( + max_text_length, character_dict_path, use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def encode_cnt(self, text): + cnt_label = [0.0] * len(self.character) + for char_ in text: + cnt_label[char_] += 1 + return np.array(cnt_label) + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + if text is None: + return None + if len(text) >= self.max_text_len: + return None + cnt_label = self.encode_cnt(text) + data['length'] = np.array(len(text)) + text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len + - len(text) - 2) + if len(text) != self.max_text_len: + return None + data['label'] = np.array(text) + data['cnt_label'] = cnt_label + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + class SEEDLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 89022d85a..e22153bde 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -237,6 +237,33 @@ class VLRecResizeImg(object): return data +class RFLRecResizeImg(object): + def __init__(self, image_shape, padding=True, interpolation=1, **kwargs): + self.image_shape = image_shape + self.padding = padding + + self.interpolation = interpolation + if self.interpolation == 0: + self.interpolation = cv2.INTER_NEAREST + elif self.interpolation == 1: + self.interpolation = cv2.INTER_LINEAR + elif self.interpolation == 2: + self.interpolation = cv2.INTER_CUBIC + elif self.interpolation == 3: + self.interpolation = cv2.INTER_AREA + else: + raise Exception("Unsupported interpolation type !!!") + + def __call__(self, data): + img = data['image'] + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + norm_img, valid_ratio = resize_norm_img( + img, self.image_shape, self.padding, self.interpolation) + data['image'] = norm_img + data['valid_ratio'] = valid_ratio + return data + + class SRNRecResizeImg(object): def __init__(self, 
image_shape, num_heads, max_text_length, **kwargs): self.image_shape = image_shape @@ -414,8 +441,13 @@ class SVTRRecResizeImg(object): data['valid_ratio'] = valid_ratio return data + class RobustScannerRecResizeImg(object): - def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs): + def __init__(self, + image_shape, + max_text_length, + width_downsample_ratio=0.25, + **kwargs): self.image_shape = image_shape self.width_downsample_ratio = width_downsample_ratio self.max_text_length = max_text_length @@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object): data['word_positons'] = word_positons return data + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] @@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): return padding_im, resize_shape, pad_shape, valid_ratio -def resize_norm_img(img, image_shape, padding=True): +def resize_norm_img(img, + image_shape, + padding=True, + interpolation=cv2.INTER_LINEAR): imgC, imgH, imgW = image_shape h = img.shape[0] w = img.shape[1] if not padding: resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + img, (imgW, imgH), interpolation=interpolation) resized_w = imgW else: ratio = w / float(h) diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 02525b3d5..ffee0a93e 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss from .rec_multi_loss import MultiLoss from .rec_vl_loss import VLLoss from .rec_spin_att_loss import SPINAttentionLoss +from .rec_rfl_loss import RFLLoss # cls loss from .cls_loss import ClsLoss @@ -69,7 +70,7 @@ def build_loss(config): 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', - 'SLALoss', 'CTLoss' + 'SLALoss', 'CTLoss', 'RFLLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py new file mode 100644 index 000000000..8e9d7d039 --- /dev/null +++ b/ppocr/losses/rec_rfl_loss.py @@ -0,0 +1,61 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + +from .basic_loss import CELoss, DistanceLoss + + +class RFLLoss(nn.Layer): + def __init__(self, ignore_index=-100, **kwargs): + super().__init__() + + self.cnt_loss = nn.MSELoss(**kwargs) + self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index) + + def forward(self, predicts, batch): + + self.total_loss = {} + total_loss = 0.0 + # batch [image, label, length, cnt_label] + if predicts[0] is not None: + cnt_loss = self.cnt_loss(predicts[0], + paddle.cast(batch[3], paddle.float32)) + self.total_loss['cnt_loss'] = cnt_loss + total_loss += cnt_loss + + if predicts[1] is not None: + targets = batch[1].astype("int64") + label_lengths = batch[2].astype('int64') + batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[ + 1].shape[1], predicts[1].shape[2] + assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \ + "The target's shape and inputs's shape is [N, d] and [N, num_steps]" + + inputs = predicts[1][:, :-1, :] + targets = targets[:, 1:] + + inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]]) + targets = paddle.reshape(targets, [-1]) + seq_loss = self.seq_loss(inputs, targets) + self.total_loss['seq_loss'] = seq_loss + total_loss += seq_loss + + self.total_loss['loss'] = total_loss + return self.total_loss diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a39d0a464..20aea8b59 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -22,7 +22,7 @@ import copy __all__ = ["build_metric"] from .det_metric import DetMetric, DetFCEMetric -from .rec_metric import RecMetric +from .rec_metric import RecMetric, CNTMetric from .cls_metric import ClsMetric from .e2e_metric import E2EMetric from .distillation_metric import DistillationMetric @@ -38,7 +38,7 @@ def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric', 'SRMetric', 'CTMetric' + 'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 986397811..34d6ff3a3 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein import string - class RecMetric(object): def __init__(self, main_indicator='acc', @@ -74,3 +73,42 @@ class RecMetric(object): self.correct_num = 0 self.all_num = 0 self.norm_edit_dis = 0 + + +class CNTMetric(object): + def __init__(self, main_indicator='acc', **kwargs): + self.main_indicator = main_indicator + self.eps = 1e-5 + self.reset() + + def _normalize_text(self, text): + text = ''.join( + filter(lambda x: x in (string.digits + string.ascii_letters), text)) + return text.lower() + + def __call__(self, pred_label, *args, **kwargs): + preds, labels = pred_label + correct_num = 0 + all_num = 0 + for pred, target in zip(preds, labels): + if pred == target: + correct_num += 1 + all_num += 1 + self.correct_num += correct_num + self.all_num += all_num + return {'acc': correct_num / (all_num + self.eps), } + + def get_metric(self): + """ + return metrics { + 'acc': 0, + 'norm_edit_dis': 0, + } + """ + acc = 1.0 * self.correct_num / (self.all_num + self.eps) + self.reset() + return {'acc': acc} + + def reset(self): + self.correct_num = 0 + self.all_num = 0 diff --git 
a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 6fdcc4a75..84892fa9c 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -42,10 +42,11 @@ def build_backbone(config, model_type): from .rec_efficientb3_pren import EfficientNetb3_PREN from .rec_svtrnet import SVTRNet from .rec_vitstr import ViTSTR + from .rec_resnet_rfl import ResNetRFL support_dict = [ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', - 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32' + 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL' ] elif model_type == 'e2e': from .e2e_resnet_vd_pg import ResNet diff --git a/ppocr/modeling/backbones/rec_resnet_rfl.py b/ppocr/modeling/backbones/rec_resnet_rfl.py new file mode 100644 index 000000000..fd317c6ea --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_rfl.py @@ -0,0 +1,348 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + +from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal + +kaiming_init_ = KaimingNormal() +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) 
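+
+# Note: the initializer objects above are callable and modify a parameter in
+# place, e.g. `kaiming_init_(conv.weight)`. They appear unused in this
+# backbone and presumably mirror the DAVAR reference implementation; RFLHead
+# and S2VAdaptor apply equivalent helpers via `self.apply(self.init_weights)`.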
+ + +class BasicBlock(nn.Layer): + """Res-net Basic Block""" + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + norm_type='BN', + **kwargs): + """ + Args: + inplanes (int): input channel + planes (int): channels of the middle feature + stride (int): stride of the convolution + downsample (int): type of the down_sample + norm_type (str): type of the normalization + **kwargs (None): backup parameter + """ + super(BasicBlock, self).__init__() + self.conv1 = self._conv3x3(inplanes, planes) + self.bn1 = nn.BatchNorm(planes) + self.conv2 = self._conv3x3(planes, planes) + self.bn2 = nn.BatchNorm(planes) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def _conv3x3(self, in_planes, out_planes, stride=1): + + return nn.Conv2D( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias_attr=False) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + + return out + + +class ResNetRFL(nn.Layer): + def __init__(self, + in_channels, + out_channels=512, + use_cnt=True, + use_seq=True): + """ + + Args: + in_channels (int): input channel + out_channels (int): output channel + """ + super(ResNetRFL, self).__init__() + assert use_cnt or use_seq + self.use_cnt, self.use_seq = use_cnt, use_seq + self.backbone = RFLBase(in_channels) + + self.out_channels = out_channels + self.out_channels_block = [ + int(self.out_channels / 4), int(self.out_channels / 2), + self.out_channels, self.out_channels + ] + block = BasicBlock + layers = [1, 2, 5, 3] + self.inplanes = int(self.out_channels // 2) + + self.relu = nn.ReLU() + if self.use_seq: + self.maxpool3 = nn.MaxPool2D( + kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.layer3 = self._make_layer( + block, self.out_channels_block[2], layers[2], stride=1) + self.conv3 = nn.Conv2D( + self.out_channels_block[2], + self.out_channels_block[2], + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn3 = nn.BatchNorm(self.out_channels_block[2]) + + self.layer4 = self._make_layer( + block, self.out_channels_block[3], layers[3], stride=1) + self.conv4_1 = nn.Conv2D( + self.out_channels_block[3], + self.out_channels_block[3], + kernel_size=2, + stride=(2, 1), + padding=(0, 1), + bias_attr=False) + self.bn4_1 = nn.BatchNorm(self.out_channels_block[3]) + self.conv4_2 = nn.Conv2D( + self.out_channels_block[3], + self.out_channels_block[3], + kernel_size=2, + stride=1, + padding=0, + bias_attr=False) + self.bn4_2 = nn.BatchNorm(self.out_channels_block[3]) + + if self.use_cnt: + self.inplanes = int(self.out_channels // 2) + self.v_maxpool3 = nn.MaxPool2D( + kernel_size=2, stride=(2, 1), padding=(0, 1)) + self.v_layer3 = self._make_layer( + block, self.out_channels_block[2], layers[2], stride=1) + self.v_conv3 = nn.Conv2D( + self.out_channels_block[2], + self.out_channels_block[2], + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.v_bn3 = nn.BatchNorm(self.out_channels_block[2]) + + self.v_layer4 = self._make_layer( + block, self.out_channels_block[3], layers[3], stride=1) + self.v_conv4_1 = nn.Conv2D( + self.out_channels_block[3], + self.out_channels_block[3], + kernel_size=2, + stride=(2, 1), + padding=(0, 1), + bias_attr=False) + self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3]) + self.v_conv4_2 = nn.Conv2D( + 
self.out_channels_block[3], + self.out_channels_block[3], + kernel_size=2, + stride=1, + padding=0, + bias_attr=False) + self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias_attr=False), + nn.BatchNorm(planes * block.expansion), ) + + layers = list() + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, inputs): + x_1 = self.backbone(inputs) + + if self.use_cnt: + v_x = self.v_maxpool3(x_1) + v_x = self.v_layer3(v_x) + v_x = self.v_conv3(v_x) + v_x = self.v_bn3(v_x) + visual_feature_2 = self.relu(v_x) + + v_x = self.v_layer4(visual_feature_2) + v_x = self.v_conv4_1(v_x) + v_x = self.v_bn4_1(v_x) + v_x = self.relu(v_x) + v_x = self.v_conv4_2(v_x) + v_x = self.v_bn4_2(v_x) + visual_feature_3 = self.relu(v_x) + else: + visual_feature_3 = None + if self.use_seq: + x = self.maxpool3(x_1) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x_2 = self.relu(x) + + x = self.layer4(x_2) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x_3 = self.relu(x) + else: + x_3 = None + + return [visual_feature_3, x_3] + + +class ResNetBase(nn.Layer): + def __init__(self, in_channels, out_channels, block, layers): + super(ResNetBase, self).__init__() + + self.out_channels_block = [ + int(out_channels / 4), int(out_channels / 2), out_channels, + out_channels + ] + + self.inplanes = int(out_channels / 8) + self.conv0_1 = nn.Conv2D( + in_channels, + int(out_channels / 16), + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn0_1 = nn.BatchNorm(int(out_channels / 16)) + self.conv0_2 = nn.Conv2D( + int(out_channels / 16), + self.inplanes, + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn0_2 = nn.BatchNorm(self.inplanes) + self.relu = nn.ReLU() + + self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer1 = self._make_layer(block, self.out_channels_block[0], + layers[0]) + self.conv1 = nn.Conv2D( + self.out_channels_block[0], + self.out_channels_block[0], + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn1 = nn.BatchNorm(self.out_channels_block[0]) + + self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.layer2 = self._make_layer( + block, self.out_channels_block[1], layers[1], stride=1) + self.conv2 = nn.Conv2D( + self.out_channels_block[1], + self.out_channels_block[1], + kernel_size=3, + stride=1, + padding=1, + bias_attr=False) + self.bn2 = nn.BatchNorm(self.out_channels_block[1]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2D( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias_attr=False), + nn.BatchNorm(planes * block.expansion), ) + + layers = list() + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv0_1(x) + x = 
self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + return x + + +class RFLBase(nn.Layer): + """ Reciprocal feature learning share backbone network""" + + def __init__(self, in_channels, out_channels=512): + super(RFLBase, self).__init__() + self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock, + [1, 2, 5, 3]) + + def forward(self, inputs): + return self.ConvNet(inputs) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 751757e5f..ba180566c 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -38,6 +38,7 @@ def build_head(config): from .rec_abinet_head import ABINetHead from .rec_robustscanner_head import RobustScannerHead from .rec_visionlan_head import VLHead + from .rec_rfl_head import RFLHead # cls head from .cls_head import ClsHead @@ -53,7 +54,7 @@ def build_head(config): 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', - 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head' + 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead' ] #table head diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index ab8b119fe..d5cf1cd16 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer): else: targets = paddle.zeros(shape=[batch_size], dtype="int32") probs = None + char_onehots = None + alpha = None for i in range(num_steps): char_onehots = self._char_to_onehot( diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py new file mode 100644 index 000000000..b5452ec1c --- /dev/null +++ b/ppocr/modeling/heads/rec_rfl_head.py @@ -0,0 +1,109 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py +""" +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal + +from .rec_att_head import AttentionLSTM + +kaiming_init_ = KaimingNormal() +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) 
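+
+# Shape walk-through for CNTHead below, assuming the default 1x32x100 input
+# (for which ResNetRFL emits a (B, 512, 1, 26) visual feature, so H*W equals
+# encode_length = batch_max_legnth + 1 = 26):
+#   reshape + transpose:  (B, C, H, W) -> (B, H*W, C)
+#   Wv_fusion:            (B, 26, 512) -> (B, 26, 512)
+#   flatten:              (B, 26, 512) -> (B, 26 * 512)
+#   Prediction_visual:    (B, 26 * 512) -> (B, out_channels)
+# The counting output is trained with MSE against the per-character count
+# vector built by RFLLabelEncode.encode_cnt.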
+ + +class CNTHead(nn.Layer): + def __init__(self, + embed_size=512, + encode_length=26, + out_channels=38, + **kwargs): + super(CNTHead, self).__init__() + + self.out_channels = out_channels + + self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False) + self.Prediction_visual = nn.Linear(encode_length * embed_size, + self.out_channels) + + def forward(self, visual_feature): + + b, c, h, w = visual_feature.shape + visual_feature = visual_feature.reshape([b, c, h * w]).transpose( + [0, 2, 1]) + visual_feature_num = self.Wv_fusion(visual_feature) # batch * 26 * 512 + b, n, c = visual_feature_num.shape + # using visual feature directly calculate the text length + visual_feature_num = visual_feature_num.reshape([b, n * c]) + prediction_visual = self.Prediction_visual(visual_feature_num) + + return prediction_visual + + +class RFLHead(nn.Layer): + def __init__(self, + in_channels=512, + hidden_size=256, + batch_max_legnth=25, + out_channels=38, + use_cnt=True, + use_seq=True, + **kwargs): + + super(RFLHead, self).__init__() + assert use_cnt or use_seq + self.use_cnt = use_cnt + self.use_seq = use_seq + if self.use_cnt: + self.cnt_head = CNTHead( + embed_size=in_channels, + encode_length=batch_max_legnth + 1, + out_channels=out_channels, + **kwargs) + if self.use_seq: + self.seq_head = AttentionLSTM( + in_channels=in_channels, + out_channels=out_channels, + hidden_size=hidden_size, + **kwargs) + self.batch_max_legnth = batch_max_legnth + self.num_class = out_channels + self.apply(self.init_weights) + + def init_weights(self, m): + if isinstance(m, nn.Linear): + kaiming_init_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + + def forward(self, x, targets=None): + cnt_inputs, seq_inputs = x + if self.use_cnt: + cnt_outputs = self.cnt_head(cnt_inputs) + else: + cnt_outputs = None + if self.use_seq: + if self.training: + seq_outputs = self.seq_head(seq_inputs, targets[0], + self.batch_max_legnth) + else: + seq_outputs = self.seq_head(seq_inputs, None, + self.batch_max_legnth) + else: + seq_outputs = None + + return cnt_outputs, seq_outputs diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index c7e8dd068..a94d223a1 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -27,9 +27,11 @@ def build_neck(config): from .pren_fpn import PRENFPN from .csp_pan import CSPPAN from .ct_fpn import CTFPN + from .rf_adaptor import RFAdaptor support_dict = [ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', - 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN' + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN', + 'RFAdaptor' ] module_name = config.pop('name') diff --git a/ppocr/modeling/necks/rf_adaptor.py b/ppocr/modeling/necks/rf_adaptor.py new file mode 100644 index 000000000..94590127b --- /dev/null +++ b/ppocr/modeling/necks/rf_adaptor.py @@ -0,0 +1,137 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py +""" + +import paddle +import paddle.nn as nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal + +kaiming_init_ = KaimingNormal() +zeros_ = Constant(value=0.) +ones_ = Constant(value=1.) + + +class S2VAdaptor(nn.Layer): + """ Semantic to Visual adaptation module""" + + def __init__(self, in_channels=512): + super(S2VAdaptor, self).__init__() + + self.in_channels = in_channels # 512 + + # feature strengthen module, channel attention + self.channel_inter = nn.Linear( + self.in_channels, self.in_channels, bias_attr=False) + self.channel_bn = nn.BatchNorm1D(self.in_channels) + self.channel_act = nn.ReLU() + self.apply(self.init_weights) + + def init_weights(self, m): + if isinstance(m, nn.Conv2D): + kaiming_init_(m.weight) + if isinstance(m, nn.Conv2D) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)): + zeros_(m.bias) + ones_(m.weight) + + def forward(self, semantic): + semantic_source = semantic # batch, channel, height, width + + # feature transformation + semantic = semantic.squeeze(2).transpose( + [0, 2, 1]) # batch, width, channel + channel_att = self.channel_inter(semantic) # batch, width, channel + channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width + channel_bn = self.channel_bn(channel_att) # batch, channel, width + channel_att = self.channel_act(channel_bn) # batch, channel, width + + # Feature enhancement + channel_output = semantic_source * channel_att.unsqueeze( + -2) # batch, channel, 1, width + + return channel_output + + +class V2SAdaptor(nn.Layer): + """ Visual to Semantic adaptation module""" + + def __init__(self, in_channels=512, return_mask=False): + super(V2SAdaptor, self).__init__() + + # parameter initialization + self.in_channels = in_channels + self.return_mask = return_mask + + # output transformation + self.channel_inter = nn.Linear( + self.in_channels, self.in_channels, bias_attr=False) + self.channel_bn = nn.BatchNorm1D(self.in_channels) + self.channel_act = nn.ReLU() + + def forward(self, visual): + # Feature enhancement + visual = visual.squeeze(2).transpose([0, 2, 1]) # batch, width, channel + channel_att = self.channel_inter(visual) # batch, width, channel + channel_att = channel_att.transpose([0, 2, 1]) # batch, channel, width + channel_bn = self.channel_bn(channel_att) # batch, channel, width + channel_att = self.channel_act(channel_bn) # batch, channel, width + + # size alignment + channel_output = channel_att.unsqueeze(-2) # batch, width, channel + + if self.return_mask: + return channel_output, channel_att + return channel_output + + +class RFAdaptor(nn.Layer): + def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs): + super(RFAdaptor, self).__init__() + if use_v2s is True: + self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs) + else: + self.neck_v2s = None + if use_s2v is True: + self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs) + else: + self.neck_s2v = None + self.out_channels = in_channels + + def forward(self, x): + visual_feature, rcg_feature = x + if visual_feature is not None: + batch, source_channels, v_source_height, v_source_width = visual_feature.shape + visual_feature = visual_feature.reshape( + [batch, source_channels, 1, v_source_height 
* v_source_width]) + + if self.neck_v2s is not None: + v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature) + else: + v_rcg_feature = rcg_feature + + if self.neck_s2v is not None: + v_visual_feature = visual_feature + self.neck_s2v(rcg_feature) + else: + v_visual_feature = visual_feature + if v_rcg_feature is not None: + batch, source_channels, source_height, source_width = v_rcg_feature.shape + v_rcg_feature = v_rcg_feature.reshape( + [batch, source_channels, 1, source_height * source_width]) + + v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1]) + return v_visual_feature, v_rcg_feature diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py index a6bd2ebb4..b92954c9c 100644 --- a/ppocr/optimizer/__init__.py +++ b/ppocr/optimizer/__init__.py @@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model): if 'clip_norm' in config: clip_norm = config.pop('clip_norm') grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm) + elif 'clip_norm_global' in config: + clip_norm = config.pop('clip_norm_global') + grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm) else: grad_clip = None optim = getattr(optimizer, optim_name)(learning_rate=lr, diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 35b7a6800..b5715967b 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \ - SPINLabelDecode, VLLabelDecode + SPINLabelDecode, VLLabelDecode, RFLLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess @@ -49,7 +49,7 @@ def build_post_process(config, global_config=None): 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode', 'DistillationSerPostProcess', 'DistillationRePostProcess', - 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess' + 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 749060a05..e754c950b 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode): return idx +class RFLLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(RFLLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def add_special_char(self, dict_character): + self.beg_str = "sos" + self.end_str = "eos" + dict_character = dict_character + dict_character = [self.beg_str] + dict_character + [self.end_str] + return dict_character + + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): + """ convert text-index into text-label. 
""" + result_list = [] + ignored_tokens = self.get_ignored_tokens() + [beg_idx, end_idx] = self.get_ignored_tokens() + batch_size = len(text_index) + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if int(text_index[batch_idx][idx]) == int(end_idx): + break + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list).tolist())) + return result_list + + def __call__(self, preds, label=None, *args, **kwargs): + cnt_pred, preds = preds + if preds is not None: + + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=2) + preds_prob = preds.max(axis=2) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + else: + cnt_length = [] + for lens in cnt_pred: + length = round(paddle.sum(lens).item()) + cnt_length.append(length) + if label is None: + return cnt_length + label = self.decode(label, is_remove_duplicate=False) + length = [len(res[0]) for res in label] + return cnt_length, length + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx + + class SEEDLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml new file mode 100644 index 000000000..b4f18f5c0 --- /dev/null +++ b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml @@ -0,0 +1,111 @@ +Global: + use_gpu: True + epoch_num: 6 + log_smooth_window: 20 + print_batch_step: 50 + save_model_dir: ./output/rec/rec_resnet_rfl/ + save_epoch_step: 1 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + cal_metric_during_train: False + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/rec_resnet_rfl.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + weight_decay: 0.0 + clip_norm_global: 5.0 + lr: + name: Piecewise + decay_epochs : [3, 4, 5] + values : [0.001, 0.0003, 0.00009, 0.000027] + +Architecture: + model_type: rec + algorithm: RFL + in_channels: 1 + Transform: + name: TPS + num_fiducial: 20 + loc_lr: 1.0 + model_name: large + Backbone: + name: ResNetRFL + use_cnt: True + use_seq: True + Neck: + name: RFAdaptor + use_v2s: True + use_s2v: True + Head: + name: RFLHead + in_channels: 512 + hidden_size: 256 + batch_max_legnth: 25 + out_channels: 38 + use_cnt: True + use_seq: True + +Loss: + name: RFLLoss + +PostProcess: + name: 
RFLLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data/ + label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 64 + drop_last: True + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/ic15_data + label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RFLLabelEncode: # Class handling label + - RFLRecResizeImg: + image_shape: [1, 32, 100] + interpolation: 2 + - KeepKeys: + keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 8 diff --git a/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt new file mode 100644 index 000000000..091e962b2 --- /dev/null +++ b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:rec_resnet_rfl +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./inference/rec_inference +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/rec_resnet_rfl_train/best_accuracy +infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o +infer_quant:False +inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5 +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--rec_model_dir: +--image_dir:./inference/rec_inference +--save_log_path:./test/output/ +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[1,32,100]}] diff --git a/tools/export_model.py b/tools/export_model.py index 8610df83e..9c23060ee 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -99,7 +99,7 @@ def export_single_model(model, ] # print([None, 3, 32, 128]) model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] in ["NRTR", "SPIN"]: + elif arch_config["algorithm"] in 
["NRTR", "SPIN", 'RFL']: other_shape = [ paddle.static.InputSpec( shape=[None, 1, 32, 100], dtype="float32"), diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 176e2c68e..697e9da43 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -100,6 +100,12 @@ class TextRecognizer(object): "use_space_char": args.use_space_char, "rm_symbol": True } + elif self.rec_algorithm == 'RFL': + postprocess_params = { + 'name': 'RFLLabelDecode', + "character_dict_path": None, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ utility.create_predictor(args, 'rec', logger) @@ -143,6 +149,16 @@ class TextRecognizer(object): else: norm_img = norm_img.astype(np.float32) / 128. - 1. return norm_img + elif self.rec_algorithm == 'RFL': + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_CUBIC) + resized_image = resized_image.astype('float32') + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image assert imgC == img.shape[2] imgW = int((imgH * max_wh_ratio)) diff --git a/tools/program.py b/tools/program.py index 9117d51b9..129e926be 100755 --- a/tools/program.py +++ b/tools/program.py @@ -217,7 +217,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", - "RobustScanner" + "RobustScanner", "RFL" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -625,7 +625,7 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner', 'CT' + 'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL' ] if use_xpu: From 394ee0117767a28e22964a66ace799f14e3e75b0 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Wed, 28 Sep 2022 17:40:43 +0800 Subject: [PATCH 02/10] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/rec/rec_resnet_rfl_att.yml | 4 ++-- configs/rec/rec_resnet_rfl_visual.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml index f8332e082..3c9c1f886 100644 --- a/configs/rec/rec_resnet_rfl_att.yml +++ b/configs/rec/rec_resnet_rfl_att.yml @@ -72,7 +72,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: ./train_data/rfl_dataset2/training + data_dir: ./train_data/data_lmdb_release/training transforms: - DecodeImage: # load image @@ -94,7 +94,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ./train_data/rfl_dataset2/evaluation + data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image img_mode: BGR diff --git a/configs/rec/rec_resnet_rfl_visual.yml b/configs/rec/rec_resnet_rfl_visual.yml index 438d2ef0c..5eaea08ce 100644 --- a/configs/rec/rec_resnet_rfl_visual.yml +++ b/configs/rec/rec_resnet_rfl_visual.yml @@ -70,7 +70,7 @@ Metric: Train: dataset: name: LMDBDataSet - data_dir: ./train_data/rfl_dataset2/training + data_dir: 
./train_data/data_lmdb_release/training transforms: - DecodeImage: # load image img_mode: BGR @@ -91,7 +91,7 @@ Train: Eval: dataset: name: LMDBDataSet - data_dir: ./train_data/rfl_dataset2/evaluation + data_dir: ./train_data/data_lmdb_release/evaluation transforms: - DecodeImage: # load image img_mode: BGR From 5a380afb1ecb49a44fbfbe011dc1e1827d43affa Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Wed, 28 Sep 2022 17:42:13 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=BC=95=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/losses/rec_rfl_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py index 8e9d7d039..0921406c1 100644 --- a/ppocr/losses/rec_rfl_loss.py +++ b/ppocr/losses/rec_rfl_loss.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +This code is refer from: +https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_common/models/loss/cross_entropy_loss.py +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function From b580fa0517bec79b6fceab2262d125376da564d1 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Wed, 28 Sep 2022 17:43:00 +0800 Subject: [PATCH 04/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/metrics/rec_metric.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index 34d6ff3a3..4758e71d0 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -81,11 +81,6 @@ class CNTMetric(object): self.eps = 1e-5 self.reset() - def _normalize_text(self, text): - text = ''.join( - filter(lambda x: x in (string.digits + string.ascii_letters), text)) - return text.lower() - def __call__(self, pred_label, *args, **kwargs): preds, labels = pred_label correct_num = 0 @@ -102,7 +97,6 @@ class CNTMetric(object): """ return metrics { 'acc': 0, - 'norm_edit_dis': 0, } """ acc = 1.0 * self.correct_num / (self.all_num + self.eps) From 154f42f1b07f273fd2ebb550c7074934c81d16e8 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Wed, 28 Sep 2022 17:44:13 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E6=8E=A8=E7=90=86=E6=97=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0softmax?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/modeling/heads/rec_att_head.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py index d5cf1cd16..6349ee0c2 100644 --- a/ppocr/modeling/heads/rec_att_head.py +++ b/ppocr/modeling/heads/rec_att_head.py @@ -169,7 +169,8 @@ class AttentionLSTM(nn.Layer): next_input = probs_step.argmax(axis=1) targets = next_input - + if not self.training: + probs = paddle.nn.functional.softmax(probs, axis=2) return probs From 035d7e39069e9f18109812401a5cc39c6a90b016 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Wed, 28 Sep 2022 17:50:28 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=A9=BA=E8=A1=8C?= MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/rec/rec_resnet_rfl_att.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml index 3c9c1f886..b9fb74176 100644 --- a/configs/rec/rec_resnet_rfl_att.yml +++ b/configs/rec/rec_resnet_rfl_att.yml @@ -73,7 +73,6 @@ Train: dataset: name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training - transforms: - DecodeImage: # load image img_mode: BGR From 6735cced20bf905b88bb9fe9d63f811cca896481 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Thu, 29 Sep 2022 14:21:47 +0800 Subject: [PATCH 07/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0RFL=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc/doc_ch/algorithm_overview.md | 2 + doc/doc_ch/algorithm_rec_rfl.md | 161 ++++++++++++++++++++++++++++ doc/doc_en/algorithm_overview_en.md | 3 +- doc/doc_en/algorithm_rec_rfl_en.md | 143 ++++++++++++++++++++++++ 4 files changed, 308 insertions(+), 1 deletion(-) create mode 100644 doc/doc_ch/algorithm_rec_rfl.md create mode 100644 doc/doc_en/algorithm_rec_rfl_en.md diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index ecb0e9dfe..4351fdbfc 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -79,6 +79,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 - [x] [VisionLAN](./algorithm_rec_visionlan.md) - [x] [SPIN](./algorithm_rec_spin.md) - [x] [RobustScanner](./algorithm_rec_robustscanner.md) +- [x] [RFL](./algorithm_rec_rfl.md) 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -102,6 +103,7 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | +|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) | diff --git a/doc/doc_ch/algorithm_rec_rfl.md b/doc/doc_ch/algorithm_rec_rfl.md new file mode 100644 index 000000000..5135e77a1 --- /dev/null +++ b/doc/doc_ch/algorithm_rec_rfl.md @@ -0,0 +1,161 @@ +# 场景文本识别算法-RFL + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf) +> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan +> ICDAR, 2021 + + + +`RFL`使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下: + +|模型|骨干网络|配置文件|Acc|下载链接| +| --- | --- | --- | --- | --- | +|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)| +|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)| + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 
模型训练、评估、预测
+
+
+### 3.1 模型训练
+
+请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`RFL`识别模型时需要**更换配置文件**为`RFL`的[配置文件](../../configs/rec/rec_resnet_rfl_att.yml)。
+
+#### 启动训练
+
+具体地,在完成数据准备后,便可以启动训练,训练命令如下:
+```shell
+#step1:训练CNT分支
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#step2:联合训练CNT和Att分支,注意将pretrained_model的路径设置为本地路径。
+#单卡训练(训练周期长,不建议)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy
+
+#多卡训练,通过--gpus参数指定卡号
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_visual/best_accuracy
+```
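+
+训练CNT分支时,监督信号是每类字符在文本中的出现次数(即数据处理中`RFLLabelEncode`生成的`cnt_label`)。下面给出一个帮助理解的示意片段(其中字符集与编码细节为本文假设,并非PaddleOCR源码):
+
+```python
+import numpy as np
+
+# 示意字符集:'0'-'9' 与 'a'-'z' 共36类,实际以配置的字典文件为准
+charset = list("0123456789abcdefghijklmnopqrstuvwxyz")
+
+def make_cnt_label(text, charset):
+    # 统计每类字符在文本中出现的次数,CNT分支以该计数向量为学习目标
+    cnt = np.zeros(len(charset), dtype=np.float32)
+    for ch in text.lower():
+        if ch in charset:
+            cnt[charset.index(ch)] += 1.0
+    return cnt
+
+cnt_label = make_cnt_label("pain", charset)
+print(int(cnt_label.sum()))  # 4:计数向量求和即文本长度,推理时对CNT输出求和取整同理
+```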
+
+### 3.2 评估
+
+可下载已训练完成的[模型文件](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar),使用如下命令进行评估:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
+```
+
+### 3.3 预测
+
+使用如下命令进行单张图片预测:
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy
+# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
+```
+
+
+## 4. 推理部署
+
+### 4.1 Python推理
+首先将训练得到的best模型转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)),可以使用如下命令进行转换:
+
+```shell
+# 注意将pretrained_model的路径设置为本地路径。
+python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model=./output/rec/rec_resnet_rfl_att/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att/
+```
+**注意:**
+- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请检查配置文件中的`character_dict_path`是否为所需要的字典文件。
+- 如果您修改了训练时的输入大小,请修改`tools/export_model.py`文件中RFL对应的`infer_shape`。
+
+转换成功后,在目录下有三个文件:
+```
+/inference/rec_resnet_rfl_att/
+    ├── inference.pdiparams         # 识别inference模型的参数文件
+    ├── inference.pdiparams.info    # 识别inference模型的参数信息,可忽略
+    └── inference.pdmodel           # 识别inference模型的program文件
+```
+
+执行如下命令进行模型推理:
+
+```shell
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
+# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='./doc/imgs_words_en/'。
+```
+
+![](../imgs_words_en/word_10.png)
+
+执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
+```
+
+**注意**:
+
+- 训练上述模型采用的图像分辨率是[1,32,100],需要通过参数`rec_image_shape`设置为您训练时的识别图像形状。
+- 在推理时需要设置参数`rec_char_dict_path`指定字典,如果您修改了字典,请修改该参数为您的字典文件。
+- 如果您修改了预处理方法,需修改`tools/infer/predict_rec.py`中RFL的预处理为您的预处理方法。
+
+
+### 4.2 C++推理部署
+
+由于C++预处理、后处理还未支持RFL,所以暂未支持。
+
+### 4.3 Serving服务化部署
+
+暂不支持
+
+### 4.4 更多推理部署
+
+暂不支持
+
+## 5. FAQ
+
+
+## 引用
+
+```bibtex
+@inproceedings{2021Reciprocal,
+  title = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
+  author = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.},
+  booktitle = {ICDAR},
+  year = {2021},
+  url = {https://arxiv.org/abs/2105.06229}
+}
+```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index bca22f784..f7ef7ad4b 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -76,6 +76,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
 - [x] [VisionLAN](./algorithm_rec_visionlan_en.md)
 - [x] [SPIN](./algorithm_rec_spin_en.md)
 - [x] [RobustScanner](./algorithm_rec_robustscanner_en.md)
+- [x] [RFL](./algorithm_rec_rfl_en.md)
 
 Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
 
@@ -99,7 +100,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
 |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) |
 |SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
 |RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon |
-
+|RFL|ResNetRFL| 88.63% | rec_resnet_rfl_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar) |
 
 
diff --git a/doc/doc_en/algorithm_rec_rfl_en.md b/doc/doc_en/algorithm_rec_rfl_en.md
new file mode 100644
index 000000000..8f0adfbe3
--- /dev/null
+++ b/doc/doc_en/algorithm_rec_rfl_en.md
@@ -0,0 +1,143 @@
+# RFL
+
+- [1. Introduction](#1)
+- [2. Environment](#2)
+- [3. Model Training / Evaluation / Prediction](#3)
+  - [3.1 Training](#3-1)
+  - [3.2 Evaluation](#3-2)
+  - [3.3 Prediction](#3-3)
+- [4. Inference and Deployment](#4)
+  - [4.1 Python Inference](#4-1)
+  - [4.2 C++ Inference](#4-2)
+  - [4.3 Serving](#4-3)
+  - [4.4 More](#4-4)
+- [5. FAQ](#5)
+
+
+## 1. Introduction
+
+Paper:
+> [Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition](https://arxiv.org/abs/2105.06229.pdf)
+> Hui Jiang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Wenqi Ren, Fei Wu, and Wenming Tan
+> ICDAR, 2021
+
+Trained on the MJSynth and SynthText text recognition datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets, the reproduced algorithm performs as follows:
+
+|Model|Backbone|config|Acc|Download link|
+| --- | --- | --- | --- | --- |
+|RFL-CNT|ResNetRFL|[rec_resnet_rfl_visual.yml](../../configs/rec/rec_resnet_rfl_visual.yml)|93.40%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+|RFL-Att|ResNetRFL|[rec_resnet_rfl_att.yml](../../configs/rec/rec_resnet_rfl_att.yml)|88.63%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)|
+
+
+## 2. Environment
+Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
+
+
+## 3. Model Training / Evaluation / Prediction
+
+Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
+
+Training:
+
+Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
+
+```
+#step1: train the CNT branch
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_visual.yml
+
+#step2: joint training of the CNT and Att branches; set pretrained_model to the best CNT weights from step1
+#Single GPU training (long training period, not recommended)
+python3 tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+
+#Multi GPU training, specify the gpu number through the --gpus parameter
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
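+
+For the visual-only stage, evaluation accuracy is simply the fraction of images whose predicted character count exactly matches the ground-truth text length. A minimal numpy sketch of this metric (an illustration of the eps-smoothed ratio used by `CNTMetric`, not the PaddleOCR source):
+
+```python
+import numpy as np
+
+def cnt_accuracy(pred_lengths, gt_lengths, eps=1e-5):
+    # Exact-match ratio between predicted and ground-truth character
+    # counts; eps avoids division by zero on an empty batch.
+    pred = np.asarray(pred_lengths)
+    gt = np.asarray(gt_lengths)
+    correct = int((pred == gt).sum())
+    return 1.0 * correct / (len(gt) + eps)
+
+print(cnt_accuracy([4, 7, 5], [4, 7, 6]))  # ~0.667: 2 of 3 counts correct
+```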
+
+Evaluation:
+
+```
+# GPU evaluation
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+Prediction:
+
+```
+# The configuration file used for prediction must match the training
+python3 tools/infer_rec.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model={path/to/weights}/best_accuracy
+```
+
+
+## 4. Inference and Deployment
+
+
+### 4.1 Python Inference
+First, convert the model saved during RFL text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/contribution/rec_resnet_rfl.tar)). The following command can be used for the conversion:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_resnet_rfl_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_resnet_rfl_att
+```
+
+**Note:**
+- If you are training the model on your own dataset and have modified the dictionary file, please pay attention to modify the `character_dict_path` in the configuration file to the modified dictionary file.
+- If you modified the input size during training, please modify the `infer_shape` corresponding to RFL in the `tools/export_model.py` file.
+
+After the conversion is successful, there are three files in the directory:
+```
+/inference/rec_resnet_rfl_att/
+    ├── inference.pdiparams
+    ├── inference.pdiparams.info
+    └── inference.pdmodel
+```
+
+
+For RFL text recognition model inference, the following commands can be executed:
+
+```
+python3 tools/infer/predict_rec.py --image_dir='./doc/imgs_words_en/word_10.png' --rec_model_dir='./inference/rec_resnet_rfl_att/' --rec_algorithm='RFL' --rec_image_shape='1,32,100'
+```
+
+![](../imgs_words_en/word_10.png)
+
+After executing the command, the prediction result (recognized text and score) of the image above is printed to the screen. An example is as follows:
+```shell
+Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999927282333374)
+```
+
+
+### 4.2 C++ Inference
+
+Not supported
+
+
+### 4.3 Serving
+
+Not supported
+
+
+### 4.4 More
+
+Not supported
+
+
+## 5. FAQ
+
+## Citation
+
+```bibtex
+@inproceedings{2021Reciprocal,
+  title = {Reciprocal Feature Learning via Explicit and Implicit Tasks in Scene Text Recognition},
+  author = {Jiang, H. and Xu, Y. and Cheng, Z. and Pu, S. and Niu, Y. and Ren, W. and Wu, F. and Tan, W.
}, + booktitle = {ICDAR}, + year = {2021}, + url = {https://arxiv.org/abs/2105.06229} +} +``` From c459b7256538b9e788c81c540b615f4fe7911b81 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Sat, 8 Oct 2022 11:20:36 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E6=B7=BB=E5=8A=A0RFL=20CNT=E5=88=86?= =?UTF-8?q?=E6=94=AFinfer=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/modeling/heads/rec_rfl_head.py | 5 ++--- ppocr/postprocess/rec_postprocess.py | 10 ++++++---- tools/infer_rec.py | 14 ++++++++++---- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py index b5452ec1c..1ded8cde9 100644 --- a/ppocr/modeling/heads/rec_rfl_head.py +++ b/ppocr/modeling/heads/rec_rfl_head.py @@ -103,7 +103,6 @@ class RFLHead(nn.Layer): else: seq_outputs = self.seq_head(seq_inputs, None, self.batch_max_legnth) + return cnt_outputs, seq_outputs else: - seq_outputs = None - - return cnt_outputs, seq_outputs + return cnt_outputs diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index e754c950b..40ba5c208 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode): return result_list def __call__(self, preds, label=None, *args, **kwargs): - cnt_pred, preds = preds - if preds is not None: - + if len(preds) == 2: + cnt_pred, preds = preds if isinstance(preds, paddle.Tensor): preds = preds.numpy() preds_idx = preds.argmax(axis=2) @@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode): return text, label else: + cnt_pred = preds + if isinstance(cnt_pred, paddle.Tensor): + cnt_pred = cnt_pred.numpy() cnt_length = [] for lens in cnt_pred: - length = round(paddle.sum(lens).item()) + length = round(np.sum(lens)) cnt_length.append(length) if label is None: return cnt_length diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 14b14544e..cb8a6ec30 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -97,7 +97,8 @@ def main(): elif config['Architecture']['algorithm'] == "SAR": op[op_name]['keep_keys'] = ['image', 'valid_ratio'] elif config['Architecture']['algorithm'] == "RobustScanner": - op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons'] + op[op_name][ + 'keep_keys'] = ['image', 'valid_ratio', 'word_positons'] else: op[op_name]['keep_keys'] = ['image'] transforms.append(op) @@ -136,9 +137,10 @@ def main(): if config['Architecture']['algorithm'] == "RobustScanner": valid_ratio = np.expand_dims(batch[1], axis=0) word_positons = np.expand_dims(batch[2], axis=0) - img_metas = [paddle.to_tensor(valid_ratio), - paddle.to_tensor(word_positons), - ] + img_metas = [ + paddle.to_tensor(valid_ratio), + paddle.to_tensor(word_positons), + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) if config['Architecture']['algorithm'] == "SRN": @@ -160,6 +162,10 @@ def main(): "score": float(post_result[key][0][1]), } info = json.dumps(rec_info, ensure_ascii=False) + elif isinstance(post_result, list) and isinstance(post_result[0], + int): + # for RFLearning CNT branch + info = str(post_result[0]) else: if len(post_result[0]) >= 2: info = post_result[0][0] + "\t" + str(post_result[0][1]) From 06e734a51bfb922ddb8e7def62ba6b6e4f1da253 Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Sun, 9 Oct 2022 10:02:41 +0800 Subject: [PATCH 09/10] 
=?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=AF=B4=E6=98=8E?= =?UTF-8?q?=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- doc/doc_ch/algorithm_rec_rfl.md | 2 +- doc/doc_en/algorithm_rec_rfl_en.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/doc_ch/algorithm_rec_rfl.md b/doc/doc_ch/algorithm_rec_rfl.md index 5135e77a1..0906d4570 100644 --- a/doc/doc_ch/algorithm_rec_rfl.md +++ b/doc/doc_ch/algorithm_rec_rfl.md @@ -41,7 +41,7 @@ ### 3.1 模型训练 -请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`RFL`识别模型时需要**更换配置文件**为`RFL`的[配置文件](../../configs/rec/rec_resnet_rfl_att.yml)。 +PaddleOCR对代码进行了模块化,训练`RFL`识别模型时需要**更换配置文件**为`RFL`的[配置文件](../../configs/rec/rec_resnet_rfl_att.yml)。 #### 启动训练 diff --git a/doc/doc_en/algorithm_rec_rfl_en.md b/doc/doc_en/algorithm_rec_rfl_en.md index 8f0adfbe3..273210c6c 100644 --- a/doc/doc_en/algorithm_rec_rfl_en.md +++ b/doc/doc_en/algorithm_rec_rfl_en.md @@ -36,7 +36,7 @@ Please refer to ["Environment Preparation"](./environment_en.md) to configure th ## 3. Model Training / Evaluation / Prediction -Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. +PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**. Training: From 483e50382627e807d0e1c6adad9438f00945b9cc Mon Sep 17 00:00:00 2001 From: zhiminzhang0830 <452516515@qq.com> Date: Mon, 10 Oct 2022 12:12:47 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E9=80=9A=E8=BF=87=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E5=88=A4=E6=96=AD=E6=98=AF=E5=90=A6=E6=98=AF?= =?UTF-8?q?visual?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ppocr/losses/rec_rfl_loss.py | 18 +++++++++++------- ppocr/postprocess/rec_postprocess.py | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py index 0921406c1..be0f06d90 100644 --- a/ppocr/losses/rec_rfl_loss.py +++ b/ppocr/losses/rec_rfl_loss.py @@ -36,22 +36,26 @@ class RFLLoss(nn.Layer): self.total_loss = {} total_loss = 0.0 + if isinstance(predicts, tuple) or isinstance(predicts, list): + cnt_outputs, seq_outputs = predicts + else: + cnt_outputs, seq_outputs = predicts, None # batch [image, label, length, cnt_label] - if predicts[0] is not None: - cnt_loss = self.cnt_loss(predicts[0], + if cnt_outputs is not None: + cnt_loss = self.cnt_loss(cnt_outputs, paddle.cast(batch[3], paddle.float32)) self.total_loss['cnt_loss'] = cnt_loss total_loss += cnt_loss - if predicts[1] is not None: + if seq_outputs is not None: targets = batch[1].astype("int64") label_lengths = batch[2].astype('int64') - batch_size, num_steps, num_classes = predicts[1].shape[0], predicts[ - 1].shape[1], predicts[1].shape[2] - assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \ + batch_size, num_steps, num_classes = seq_outputs.shape[ + 0], seq_outputs.shape[1], seq_outputs.shape[2] + assert len(targets.shape) == len(list(seq_outputs.shape)) - 1, \ "The target's shape and inputs's shape is [N, d] and [N, num_steps]" - inputs = predicts[1][:, :-1, :] + inputs = seq_outputs[:, :-1, :] targets = targets[:, 1:] inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]]) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 74f4e880b..59b5254e4 100644 --- 
a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -287,12 +287,13 @@ class RFLLabelDecode(BaseRecLabelDecode): return result_list def __call__(self, preds, label=None, *args, **kwargs): - if len(preds) == 2: - cnt_pred, preds = preds - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) + # if seq_outputs is not None: + if isinstance(preds, tuple) or isinstance(preds, list): + cnt_outputs, seq_outputs = preds + if isinstance(seq_outputs, paddle.Tensor): + seq_outputs = seq_outputs.numpy() + preds_idx = seq_outputs.argmax(axis=2) + preds_prob = seq_outputs.max(axis=2) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) if label is None: @@ -301,11 +302,11 @@ class RFLLabelDecode(BaseRecLabelDecode): return text, label else: - cnt_pred = preds - if isinstance(cnt_pred, paddle.Tensor): - cnt_pred = cnt_pred.numpy() + cnt_outputs = preds + if isinstance(cnt_outputs, paddle.Tensor): + cnt_outputs = cnt_outputs.numpy() cnt_length = [] - for lens in cnt_pred: + for lens in cnt_outputs: length = round(np.sum(lens)) cnt_length.append(length) if label is None:
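
With this final patch applied, the convention is uniform across `RFLLoss` and `RFLLabelDecode`: the joint CNT+Att model returns a `(cnt_outputs, seq_outputs)` tuple, while the visual-only CNT model returns a bare tensor, so both classes branch on the type of `preds`. A minimal numpy sketch of the postprocess side of that convention (illustrative only; decoding indices into characters is omitted):

```python
import numpy as np

def rfl_postprocess(preds):
    # Joint model: (cnt_outputs, seq_outputs) -> decode the attention branch.
    if isinstance(preds, (tuple, list)):
        cnt_outputs, seq_outputs = preds
        preds_idx = seq_outputs.argmax(axis=2)   # [B, T] character ids
        preds_prob = seq_outputs.max(axis=2)     # [B, T] per-step scores
        return preds_idx, preds_prob
    # Visual-only model: sum and round the count head -> predicted lengths.
    return [round(float(np.sum(lens))) for lens in preds]

# Toy CNT output for a batch of two images (row sums 4.1 and 6.8):
print(rfl_postprocess(np.array([[1.0, 2.0, 1.1], [3.0, 2.0, 1.8]])))  # [4, 7]
```

Keying the branch on the output type lets a single loss and a single postprocess class serve both the visual pretraining stage and the joint stage, which is what allowed the earlier unconditional tuple unpacking to be dropped.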