diff --git a/configs/rec/rec_resnet_rfl_att.yml b/configs/rec/rec_resnet_rfl_att.yml
new file mode 100644
index 000000000..f8332e082
--- /dev/null
+++ b/configs/rec/rec_resnet_rfl_att.yml
@@ -0,0 +1,113 @@
+Global:
+  use_gpu: True
+  epoch_num: 6
+  log_smooth_window: 20
+  print_batch_step: 50
+  save_model_dir: ./output/rec/rec_resnet_rfl_att/
+  save_epoch_step: 1
+  # evaluation is run every 5000 iterations, starting from iteration 0
+  eval_batch_step: [0, 5000]
+  cal_metric_during_train: True
+  pretrained_model: ./pretrain_models/rec_resnet_rfl_visual/best_accuracy.pdparams
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  # for data or label process
+  character_dict_path:
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/rec_resnet_rfl_att.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.0
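+  # clip gradients by their global norm (paddle.nn.ClipGradByGlobalNorm)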
+  clip_norm_global: 5.0
+  lr:
+    name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
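+# RFL couples a character-counting (cnt) branch with an attention (seq)
+# branch that exchange features through the RFAdaptor neck. This config
+# enables both branches and warm-starts from the counting-only model trained
+# with rec_resnet_rfl_visual.yml (see pretrained_model above).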
+Architecture:
+  model_type: rec
+  algorithm: RFL
+  in_channels: 1
+  Transform:
+    name: TPS
+    num_fiducial: 20
+    loc_lr: 1.0
+    model_name: large
+  Backbone:
+    name: ResNetRFL
+    use_cnt: True
+    use_seq: True
+  Neck:
+    name: RFAdaptor
+    use_v2s: True
+    use_s2v: True
+  Head:
+    name: RFLHead
+    in_channels: 512
+    hidden_size: 256
+    batch_max_length: 25
+    out_channels: 38
+    use_cnt: True
+    use_seq: True
+
+Loss:
+  name: RFLLoss
+  # ignore_index: 0
+
+PostProcess:
+  name: RFLLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/rfl_dataset2/training
+
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          padding: false
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 64
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/rfl_dataset2/evaluation
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          padding: false
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 256
+    num_workers: 8
diff --git a/configs/rec/rec_resnet_rfl_visual.yml b/configs/rec/rec_resnet_rfl_visual.yml
new file mode 100644
index 000000000..438d2ef0c
--- /dev/null
+++ b/configs/rec/rec_resnet_rfl_visual.yml
@@ -0,0 +1,110 @@
+Global:
+  use_gpu: True
+  epoch_num: 6
+  log_smooth_window: 20
+  print_batch_step: 50
+  save_model_dir: ./output/rec/rec_resnet_rfl_visual/
+  save_epoch_step: 1
+  # evaluation is run every 5000 iterations, starting from iteration 0
+  eval_batch_step: [0, 5000]
+  cal_metric_during_train: False
+  pretrained_model:
+  checkpoints: 
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  # for data or label process
+  character_dict_path:
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/rec_resnet_rfl_visual.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.0
+  clip_norm_global: 5.0
+  lr:
+    name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
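+# First-stage RFL config: only the counting branch is trained (use_seq and
+# both adaptor directions are off); its best checkpoint initializes
+# rec_resnet_rfl_att.yml.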
+Architecture:
+  model_type: rec
+  algorithm: RFL
+  in_channels: 1
+  Transform:
+    name: TPS
+    num_fiducial: 20
+    loc_lr: 1.0
+    model_name: large
+  Backbone:
+    name: ResNetRFL
+    use_cnt: True
+    use_seq: False
+  Neck:
+    name: RFAdaptor
+    use_v2s: False
+    use_s2v: False
+  Head:
+    name: RFLHead
+    in_channels: 512
+    hidden_size: 256
+    batch_max_length: 25
+    out_channels: 38
+    use_cnt: True
+    use_seq: False
+
+Loss:
+  name: RFLLoss
+
+PostProcess:
+  name: RFLLabelDecode
+
+Metric:
+  name: CNTMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/rfl_dataset2/training
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          padding: false
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 64
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: LMDBDataSet
+    data_dir: ./train_data/rfl_dataset2/evaluation
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          padding: false
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 256
+    num_workers: 8
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 863988ccc..db0a489d5 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -26,7 +26,8 @@ from .make_pse_gt import MakePseGt
 
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
     SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
-    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg
+    ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
+    RFLRecResizeImg
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index dbfb93176..590caf1c4 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -488,6 +488,62 @@ class AttnLabelEncode(BaseRecLabelEncode):
         return idx
 
 
+class RFLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(RFLLabelEncode, self).__init__(
+            max_text_length, character_dict_path, use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
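+    # The counting target is a vector with one slot per character class; the
+    # i-th entry counts how often class i occurs in the encoded label.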
+    def encode_cnt(self, text):
+        cnt_label = [0.0] * len(self.character)
+        for char_ in text:
+            cnt_label[char_] += 1
+        return np.array(cnt_label)
+
+    def __call__(self, data):
+        text = data['label']
+        text = self.encode(text)
+        if text is None:
+            return None
+        if len(text) >= self.max_text_len:
+            return None
+        cnt_label = self.encode_cnt(text)
+        data['length'] = np.array(len(text))
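+        # pad to max_text_len as [sos] + text + [eos] + trailing padding,
+        # using the sos/eos indices appended in add_special_char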
+        text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
+                                                               - len(text) - 2)
+        if len(text) != self.max_text_len:
+            return None
+        data['label'] = np.array(text)
+        data['cnt_label'] = cnt_label
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx
+
+
 class SEEDLabelEncode(BaseRecLabelEncode):
     """ Convert between text-label and text-index """
 
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 89022d85a..e22153bde 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -237,6 +237,33 @@ class VLRecResizeImg(object):
         return data
 
 
+class RFLRecResizeImg(object):
+    def __init__(self, image_shape, padding=True, interpolation=1, **kwargs):
+        self.image_shape = image_shape
+        self.padding = padding
+
+        self.interpolation = interpolation
+        if self.interpolation == 0:
+            self.interpolation = cv2.INTER_NEAREST
+        elif self.interpolation == 1:
+            self.interpolation = cv2.INTER_LINEAR
+        elif self.interpolation == 2:
+            self.interpolation = cv2.INTER_CUBIC
+        elif self.interpolation == 3:
+            self.interpolation = cv2.INTER_AREA
+        else:
+            raise Exception("Unsupported interpolation type !!!")
+
+    def __call__(self, data):
+        img = data['image']
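+        # RFL takes single-channel input (in_channels: 1), so convert the
+        # decoded BGR image to grayscale before resizing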
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        norm_img, valid_ratio = resize_norm_img(
+            img, self.image_shape, self.padding, self.interpolation)
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape
@@ -414,8 +441,13 @@ class SVTRRecResizeImg(object):
         data['valid_ratio'] = valid_ratio
         return data
 
+
 class RobustScannerRecResizeImg(object):
-    def __init__(self, image_shape, max_text_length, width_downsample_ratio=0.25, **kwargs):
+    def __init__(self,
+                 image_shape,
+                 max_text_length,
+                 width_downsample_ratio=0.25,
+                 **kwargs):
         self.image_shape = image_shape
         self.width_downsample_ratio = width_downsample_ratio
         self.max_text_length = max_text_length
@@ -432,6 +464,7 @@ class RobustScannerRecResizeImg(object):
         data['word_positons'] = word_positons
         return data
 
+
 def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     imgC, imgH, imgW_min, imgW_max = image_shape
     h = img.shape[0]
@@ -467,13 +500,16 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     return padding_im, resize_shape, pad_shape, valid_ratio
 
 
-def resize_norm_img(img, image_shape, padding=True):
+def resize_norm_img(img,
+                    image_shape,
+                    padding=True,
+                    interpolation=cv2.INTER_LINEAR):
     imgC, imgH, imgW = image_shape
     h = img.shape[0]
     w = img.shape[1]
     if not padding:
         resized_image = cv2.resize(
-            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            img, (imgW, imgH), interpolation=interpolation)
         resized_w = imgW
     else:
         ratio = w / float(h)
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 02525b3d5..ffee0a93e 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -38,6 +38,7 @@ from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
 from .rec_vl_loss import VLLoss
 from .rec_spin_att_loss import SPINAttentionLoss
+from .rec_rfl_loss import RFLLoss
 
 # cls loss
 from .cls_loss import ClsLoss
@@ -69,7 +70,7 @@ def build_loss(config):
         'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
         'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
         'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
-        'SLALoss', 'CTLoss'
+        'SLALoss', 'CTLoss', 'RFLLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/losses/rec_rfl_loss.py b/ppocr/losses/rec_rfl_loss.py
new file mode 100644
index 000000000..8e9d7d039
--- /dev/null
+++ b/ppocr/losses/rec_rfl_loss.py
@@ -0,0 +1,61 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+from .basic_loss import CELoss, DistanceLoss
+
+
+class RFLLoss(nn.Layer):
+    def __init__(self, ignore_index=-100, **kwargs):
+        super().__init__()
+
+        self.cnt_loss = nn.MSELoss(**kwargs)
+        self.seq_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
+
+    def forward(self, predicts, batch):
+
+        self.total_loss = {}
+        total_loss = 0.0
+        # batch [image, label, length, cnt_label]
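+        # counting branch: MSE between the predicted and ground-truth
+        # character-count vectors (batch[3] = cnt_label)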
+        if predicts[0] is not None:
+            cnt_loss = self.cnt_loss(predicts[0],
+                                     paddle.cast(batch[3], paddle.float32))
+            self.total_loss['cnt_loss'] = cnt_loss
+            total_loss += cnt_loss
+
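+        # sequence branch: cross entropy with a one-step shift (inputs drop
+        # the final time step, targets drop the leading sos token)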
+        if predicts[1] is not None:
+            targets = batch[1].astype("int64")
+            label_lengths = batch[2].astype('int64')
+            batch_size, num_steps, num_classes = predicts[1].shape
+            assert len(targets.shape) == len(list(predicts[1].shape)) - 1, \
+                "The targets should be shaped [N, num_steps] when the " \
+                "inputs are shaped [N, num_steps, num_classes]"
+
+            inputs = predicts[1][:, :-1, :]
+            targets = targets[:, 1:]
+
+            inputs = paddle.reshape(inputs, [-1, inputs.shape[-1]])
+            targets = paddle.reshape(targets, [-1])
+            seq_loss = self.seq_loss(inputs, targets)
+            self.total_loss['seq_loss'] = seq_loss
+            total_loss += seq_loss
+
+        self.total_loss['loss'] = total_loss
+        return self.total_loss
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index a39d0a464..20aea8b59 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -22,7 +22,7 @@ import copy
 __all__ = ["build_metric"]
 
 from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric
+from .rec_metric import RecMetric, CNTMetric
 from .cls_metric import ClsMetric
 from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
@@ -38,7 +38,7 @@ def build_metric(config):
     support_dict = [
         "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
         "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
-        'VQAReTokenMetric', 'SRMetric', 'CTMetric'
+        'VQAReTokenMetric', 'SRMetric', 'CTMetric', 'CNTMetric'
     ]
 
     config = copy.deepcopy(config)
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 986397811..34d6ff3a3 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -16,7 +16,6 @@ from rapidfuzz.distance import Levenshtein
 import string
 
 
-
 class RecMetric(object):
     def __init__(self,
                  main_indicator='acc',
@@ -74,3 +73,42 @@ class RecMetric(object):
         self.correct_num = 0
         self.all_num = 0
         self.norm_edit_dis = 0
+
+
+class CNTMetric(object):
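+    """Exact-match accuracy of the predicted character count against the
+    ground-truth text length, used to evaluate the RFL counting branch."""
+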
+    def __init__(self, main_indicator='acc', **kwargs):
+        self.main_indicator = main_indicator
+        self.eps = 1e-5
+        self.reset()
+
+    def _normalize_text(self, text):
+        text = ''.join(
+            filter(lambda x: x in (string.digits + string.ascii_letters), text))
+        return text.lower()
+
+    def __call__(self, pred_label, *args, **kwargs):
+        preds, labels = pred_label
+        correct_num = 0
+        all_num = 0
+        for pred, target in zip(preds, labels):
+            if pred == target:
+                correct_num += 1
+            all_num += 1
+        self.correct_num += correct_num
+        self.all_num += all_num
+        return {'acc': correct_num / (all_num + self.eps), }
+
+    def get_metric(self):
+        """
+        return metrics {
+                 'acc': 0,
+            }
+        """
+        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+        self.reset()
+        return {'acc': acc}
+
+    def reset(self):
+        self.correct_num = 0
+        self.all_num = 0
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 6fdcc4a75..84892fa9c 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -42,10 +42,11 @@ def build_backbone(config, model_type):
         from .rec_efficientb3_pren import EfficientNetb3_PREN
         from .rec_svtrnet import SVTRNet
         from .rec_vitstr import ViTSTR
+        from .rec_resnet_rfl import ResNetRFL
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
-            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
+            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL'
         ]
     elif model_type == 'e2e':
         from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_rfl.py b/ppocr/modeling/backbones/rec_resnet_rfl.py
new file mode 100644
index 000000000..fd317c6ea
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_rfl.py
@@ -0,0 +1,348 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+class BasicBlock(nn.Layer):
+    """Res-net Basic Block"""
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 norm_type='BN',
+                 **kwargs):
+        """
+        Args:
+            inplanes (int): input channel
+            planes (int): channels of the middle feature
+            stride (int): stride of the convolution
+            downsample (nn.Layer): down-sampling module for the residual branch
+            norm_type (str): type of the normalization
+            **kwargs (None): backup parameter
+        """
+        super(BasicBlock, self).__init__()
+        self.conv1 = self._conv3x3(inplanes, planes)
+        self.bn1 = nn.BatchNorm(planes)
+        self.conv2 = self._conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm(planes)
+        self.relu = nn.ReLU()
+        self.downsample = downsample
+        self.stride = stride
+
+    def _conv3x3(self, in_planes, out_planes, stride=1):
+
+        return nn.Conv2D(
+            in_planes,
+            out_planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias_attr=False)
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ResNetRFL(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels=512,
+                 use_cnt=True,
+                 use_seq=True):
+        """
+
+        Args:
+            in_channels (int): input channel
+            out_channels (int): output channel
+        """
+        super(ResNetRFL, self).__init__()
+        assert use_cnt or use_seq
+        self.use_cnt, self.use_seq = use_cnt, use_seq
+        self.backbone = RFLBase(in_channels)
+
+        self.out_channels = out_channels
+        self.out_channels_block = [
+            int(self.out_channels / 4), int(self.out_channels / 2),
+            self.out_channels, self.out_channels
+        ]
+        block = BasicBlock
+        layers = [1, 2, 5, 3]
+        self.inplanes = int(self.out_channels // 2)
+
+        self.relu = nn.ReLU()
+        if self.use_seq:
+            self.maxpool3 = nn.MaxPool2D(
+                kernel_size=2, stride=(2, 1), padding=(0, 1))
+            self.layer3 = self._make_layer(
+                block, self.out_channels_block[2], layers[2], stride=1)
+            self.conv3 = nn.Conv2D(
+                self.out_channels_block[2],
+                self.out_channels_block[2],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias_attr=False)
+            self.bn3 = nn.BatchNorm(self.out_channels_block[2])
+
+            self.layer4 = self._make_layer(
+                block, self.out_channels_block[3], layers[3], stride=1)
+            self.conv4_1 = nn.Conv2D(
+                self.out_channels_block[3],
+                self.out_channels_block[3],
+                kernel_size=2,
+                stride=(2, 1),
+                padding=(0, 1),
+                bias_attr=False)
+            self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
+            self.conv4_2 = nn.Conv2D(
+                self.out_channels_block[3],
+                self.out_channels_block[3],
+                kernel_size=2,
+                stride=1,
+                padding=0,
+                bias_attr=False)
+            self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
+
+        if self.use_cnt:
+            self.inplanes = int(self.out_channels // 2)
+            self.v_maxpool3 = nn.MaxPool2D(
+                kernel_size=2, stride=(2, 1), padding=(0, 1))
+            self.v_layer3 = self._make_layer(
+                block, self.out_channels_block[2], layers[2], stride=1)
+            self.v_conv3 = nn.Conv2D(
+                self.out_channels_block[2],
+                self.out_channels_block[2],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias_attr=False)
+            self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
+
+            self.v_layer4 = self._make_layer(
+                block, self.out_channels_block[3], layers[3], stride=1)
+            self.v_conv4_1 = nn.Conv2D(
+                self.out_channels_block[3],
+                self.out_channels_block[3],
+                kernel_size=2,
+                stride=(2, 1),
+                padding=(0, 1),
+                bias_attr=False)
+            self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
+            self.v_conv4_2 = nn.Conv2D(
+                self.out_channels_block[3],
+                self.out_channels_block[3],
+                kernel_size=2,
+                stride=1,
+                padding=0,
+                bias_attr=False)
+            self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2D(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias_attr=False),
+                nn.BatchNorm(planes * block.expansion), )
+
+        layers = list()
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, inputs):
+        x_1 = self.backbone(inputs)
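+        # the shared RFLBase trunk feeds two parallel ResNet tails: the v_*
+        # layers build the counting (visual) feature, the unprefixed layers
+        # the sequence feature; a disabled branch simply yields None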
+
+        if self.use_cnt:
+            v_x = self.v_maxpool3(x_1)
+            v_x = self.v_layer3(v_x)
+            v_x = self.v_conv3(v_x)
+            v_x = self.v_bn3(v_x)
+            visual_feature_2 = self.relu(v_x)
+
+            v_x = self.v_layer4(visual_feature_2)
+            v_x = self.v_conv4_1(v_x)
+            v_x = self.v_bn4_1(v_x)
+            v_x = self.relu(v_x)
+            v_x = self.v_conv4_2(v_x)
+            v_x = self.v_bn4_2(v_x)
+            visual_feature_3 = self.relu(v_x)
+        else:
+            visual_feature_3 = None
+        if self.use_seq:
+            x = self.maxpool3(x_1)
+            x = self.layer3(x)
+            x = self.conv3(x)
+            x = self.bn3(x)
+            x_2 = self.relu(x)
+
+            x = self.layer4(x_2)
+            x = self.conv4_1(x)
+            x = self.bn4_1(x)
+            x = self.relu(x)
+            x = self.conv4_2(x)
+            x = self.bn4_2(x)
+            x_3 = self.relu(x)
+        else:
+            x_3 = None
+
+        return [visual_feature_3, x_3]
+
+
+class ResNetBase(nn.Layer):
+    def __init__(self, in_channels, out_channels, block, layers):
+        super(ResNetBase, self).__init__()
+
+        self.out_channels_block = [
+            int(out_channels / 4), int(out_channels / 2), out_channels,
+            out_channels
+        ]
+
+        self.inplanes = int(out_channels / 8)
+        self.conv0_1 = nn.Conv2D(
+            in_channels,
+            int(out_channels / 16),
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
+        self.conv0_2 = nn.Conv2D(
+            int(out_channels / 16),
+            self.inplanes,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn0_2 = nn.BatchNorm(self.inplanes)
+        self.relu = nn.ReLU()
+
+        self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.layer1 = self._make_layer(block, self.out_channels_block[0],
+                                       layers[0])
+        self.conv1 = nn.Conv2D(
+            self.out_channels_block[0],
+            self.out_channels_block[0],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm(self.out_channels_block[0])
+
+        self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+        self.layer2 = self._make_layer(
+            block, self.out_channels_block[1], layers[1], stride=1)
+        self.conv2 = nn.Conv2D(
+            self.out_channels_block[1],
+            self.out_channels_block[1],
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn2 = nn.BatchNorm(self.out_channels_block[1])
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2D(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias_attr=False),
+                nn.BatchNorm(planes * block.expansion), )
+
+        layers = list()
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv0_1(x)
+        x = self.bn0_1(x)
+        x = self.relu(x)
+        x = self.conv0_2(x)
+        x = self.bn0_2(x)
+        x = self.relu(x)
+
+        x = self.maxpool1(x)
+        x = self.layer1(x)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.maxpool2(x)
+        x = self.layer2(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        return x
+
+
+class RFLBase(nn.Layer):
+    """ Reciprocal feature learning share backbone network"""
+
+    def __init__(self, in_channels, out_channels=512):
+        super(RFLBase, self).__init__()
+        self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
+                                  [1, 2, 5, 3])
+
+    def forward(self, inputs):
+        return self.ConvNet(inputs)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 751757e5f..ba180566c 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -38,6 +38,7 @@ def build_head(config):
     from .rec_abinet_head import ABINetHead
     from .rec_robustscanner_head import RobustScannerHead
     from .rec_visionlan_head import VLHead
+    from .rec_rfl_head import RFLHead
 
     # cls head
     from .cls_head import ClsHead
@@ -53,7 +54,7 @@ def build_head(config):
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
         'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
-        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
+        'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head', 'RFLHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/rec_att_head.py b/ppocr/modeling/heads/rec_att_head.py
index ab8b119fe..d5cf1cd16 100644
--- a/ppocr/modeling/heads/rec_att_head.py
+++ b/ppocr/modeling/heads/rec_att_head.py
@@ -149,6 +149,8 @@ class AttentionLSTM(nn.Layer):
         else:
             targets = paddle.zeros(shape=[batch_size], dtype="int32")
             probs = None
+            char_onehots = None
+            alpha = None
 
             for i in range(num_steps):
                 char_onehots = self._char_to_onehot(
diff --git a/ppocr/modeling/heads/rec_rfl_head.py b/ppocr/modeling/heads/rec_rfl_head.py
new file mode 100644
index 000000000..b5452ec1c
--- /dev/null
+++ b/ppocr/modeling/heads/rec_rfl_head.py
@@ -0,0 +1,109 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/sequence_heads/counting_head.py
+"""
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+from .rec_att_head import AttentionLSTM
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
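+# The counting head flattens the visual feature map and regresses one
+# occurrence count per character class (out_channels entries).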
+class CNTHead(nn.Layer):
+    def __init__(self,
+                 embed_size=512,
+                 encode_length=26,
+                 out_channels=38,
+                 **kwargs):
+        super(CNTHead, self).__init__()
+
+        self.out_channels = out_channels
+
+        self.Wv_fusion = nn.Linear(embed_size, embed_size, bias_attr=False)
+        self.Prediction_visual = nn.Linear(encode_length * embed_size,
+                                           self.out_channels)
+
+    def forward(self, visual_feature):
+
+        b, c, h, w = visual_feature.shape
+        visual_feature = visual_feature.reshape([b, c, h * w]).transpose(
+            [0, 2, 1])
+        visual_feature_num = self.Wv_fusion(visual_feature)  # batch * 26 * 512
+        b, n, c = visual_feature_num.shape
+        # using visual feature directly calculate the text length
+        visual_feature_num = visual_feature_num.reshape([b, n * c])
+        prediction_visual = self.Prediction_visual(visual_feature_num)
+
+        return prediction_visual
+
+
+class RFLHead(nn.Layer):
+    def __init__(self,
+                 in_channels=512,
+                 hidden_size=256,
+                 batch_max_length=25,
+                 out_channels=38,
+                 use_cnt=True,
+                 use_seq=True,
+                 **kwargs):
+
+        super(RFLHead, self).__init__()
+        assert use_cnt or use_seq
+        self.use_cnt = use_cnt
+        self.use_seq = use_seq
+        if self.use_cnt:
+            self.cnt_head = CNTHead(
+                embed_size=in_channels,
+                encode_length=batch_max_length + 1,
+                out_channels=out_channels,
+                **kwargs)
+        if self.use_seq:
+            self.seq_head = AttentionLSTM(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                hidden_size=hidden_size,
+                **kwargs)
+        self.batch_max_length = batch_max_length
+        self.num_class = out_channels
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            kaiming_init_(m.weight)
+            if m.bias is not None:
+                zeros_(m.bias)
+
+    def forward(self, x, targets=None):
+        cnt_inputs, seq_inputs = x
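+        # x is the (counting feature, sequence feature) pair from the neck;
+        # at inference the attention head decodes greedily without targets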
+        if self.use_cnt:
+            cnt_outputs = self.cnt_head(cnt_inputs)
+        else:
+            cnt_outputs = None
+        if self.use_seq:
+            if self.training:
+                seq_outputs = self.seq_head(seq_inputs, targets[0],
+                                            self.batch_max_length)
+            else:
+                seq_outputs = self.seq_head(seq_inputs, None,
+                                            self.batch_max_length)
+        else:
+            seq_outputs = None
+
+        return cnt_outputs, seq_outputs
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index c7e8dd068..a94d223a1 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -27,9 +27,11 @@ def build_neck(config):
     from .pren_fpn import PRENFPN
     from .csp_pan import CSPPAN
     from .ct_fpn import CTFPN
+    from .rf_adaptor import RFAdaptor
     support_dict = [
         'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
-        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
+        'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN',
+        'RFAdaptor'
     ]
 
     module_name = config.pop('name')
diff --git a/ppocr/modeling/necks/rf_adaptor.py b/ppocr/modeling/necks/rf_adaptor.py
new file mode 100644
index 000000000..94590127b
--- /dev/null
+++ b/ppocr/modeling/necks/rf_adaptor.py
@@ -0,0 +1,137 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is adapted from:
+https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/connects/single_block/RFAdaptor.py
+"""
+
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
+
+kaiming_init_ = KaimingNormal()
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+class S2VAdaptor(nn.Layer):
+    """ Semantic to Visual adaptation module"""
+
+    def __init__(self, in_channels=512):
+        super(S2VAdaptor, self).__init__()
+
+        self.in_channels = in_channels  # 512
+
+        # feature strengthen module, channel attention
+        self.channel_inter = nn.Linear(
+            self.in_channels, self.in_channels, bias_attr=False)
+        self.channel_bn = nn.BatchNorm1D(self.in_channels)
+        self.channel_act = nn.ReLU()
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            kaiming_init_(m.weight)
+            if m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, (nn.BatchNorm, nn.BatchNorm2D, nn.BatchNorm1D)):
+            zeros_(m.bias)
+            ones_(m.weight)
+
+    def forward(self, semantic):
+        semantic_source = semantic  # batch, channel, height, width
+
+        # feature transformation
+        semantic = semantic.squeeze(2).transpose(
+            [0, 2, 1])  # batch, width, channel
+        channel_att = self.channel_inter(semantic)  # batch, width, channel
+        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
+        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
+        channel_att = self.channel_act(channel_bn)  # batch, channel, width
+
+        # Feature enhancement
+        channel_output = semantic_source * channel_att.unsqueeze(
+            -2)  # batch, channel, 1, width
+
+        return channel_output
+
+
+class V2SAdaptor(nn.Layer):
+    """ Visual to Semantic adaptation module"""
+
+    def __init__(self, in_channels=512, return_mask=False):
+        super(V2SAdaptor, self).__init__()
+
+        # parameter initialization
+        self.in_channels = in_channels
+        self.return_mask = return_mask
+
+        # output transformation
+        self.channel_inter = nn.Linear(
+            self.in_channels, self.in_channels, bias_attr=False)
+        self.channel_bn = nn.BatchNorm1D(self.in_channels)
+        self.channel_act = nn.ReLU()
+
+    def forward(self, visual):
+        # Feature enhancement
+        visual = visual.squeeze(2).transpose([0, 2, 1])  # batch, width, channel
+        channel_att = self.channel_inter(visual)  # batch, width, channel
+        channel_att = channel_att.transpose([0, 2, 1])  # batch, channel, width
+        channel_bn = self.channel_bn(channel_att)  # batch, channel, width
+        channel_att = self.channel_act(channel_bn)  # batch, channel, width
+
+        # size alignment
+        channel_output = channel_att.unsqueeze(-2)  # batch, channel, 1, width
+
+        if self.return_mask:
+            return channel_output, channel_att
+        return channel_output
+
+
+class RFAdaptor(nn.Layer):
+    def __init__(self, in_channels=512, use_v2s=True, use_s2v=True, **kwargs):
+        super(RFAdaptor, self).__init__()
+        if use_v2s is True:
+            self.neck_v2s = V2SAdaptor(in_channels=in_channels, **kwargs)
+        else:
+            self.neck_v2s = None
+        if use_s2v is True:
+            self.neck_s2v = S2VAdaptor(in_channels=in_channels, **kwargs)
+        else:
+            self.neck_s2v = None
+        self.out_channels = in_channels
+
+    def forward(self, x):
+        visual_feature, rcg_feature = x
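+        # reciprocal adaptation: V2S strengthens the recognition (semantic)
+        # feature with visual cues, while S2V feeds semantic cues back into
+        # the visual counting feature; either direction can be disabled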
+        if visual_feature is not None:
+            batch, source_channels, v_source_height, v_source_width = visual_feature.shape
+            visual_feature = visual_feature.reshape(
+                [batch, source_channels, 1, v_source_height * v_source_width])
+
+        if self.neck_v2s is not None:
+            v_rcg_feature = rcg_feature * self.neck_v2s(visual_feature)
+        else:
+            v_rcg_feature = rcg_feature
+
+        if self.neck_s2v is not None:
+            v_visual_feature = visual_feature + self.neck_s2v(rcg_feature)
+        else:
+            v_visual_feature = visual_feature
+        if v_rcg_feature is not None:
+            batch, source_channels, source_height, source_width = v_rcg_feature.shape
+            v_rcg_feature = v_rcg_feature.reshape(
+                [batch, source_channels, 1, source_height * source_width])
+
+            v_rcg_feature = v_rcg_feature.squeeze(2).transpose([0, 2, 1])
+        return v_visual_feature, v_rcg_feature
diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py
index a6bd2ebb4..b92954c9c 100644
--- a/ppocr/optimizer/__init__.py
+++ b/ppocr/optimizer/__init__.py
@@ -53,6 +53,9 @@ def build_optimizer(config, epochs, step_each_epoch, model):
     if 'clip_norm' in config:
         clip_norm = config.pop('clip_norm')
         grad_clip = paddle.nn.ClipGradByNorm(clip_norm=clip_norm)
+    elif 'clip_norm_global' in config:
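+        # clip using the global norm computed over all gradients jointly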
+        clip_norm = config.pop('clip_norm_global')
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=clip_norm)
     else:
         grad_clip = None
     optim = getattr(optimizer, optim_name)(learning_rate=lr,
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 35b7a6800..b5715967b 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
     SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
-    SPINLabelDecode, VLLabelDecode
+    SPINLabelDecode, VLLabelDecode, RFLLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
@@ -49,7 +49,7 @@ def build_post_process(config, global_config=None):
         'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
         'TableMasterLabelDecode', 'SPINLabelDecode',
         'DistillationSerPostProcess', 'DistillationRePostProcess',
-        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
+        'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess', 'RFLLabelDecode'
     ]
 
     if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 749060a05..e754c950b 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -242,6 +242,92 @@ class AttnLabelDecode(BaseRecLabelDecode):
         return idx
 
 
+class RFLLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(RFLLabelDecode, self).__init__(character_dict_path,
+                                             use_space_char)
+
+    def add_special_char(self, dict_character):
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        [beg_idx, end_idx] = ignored_tokens
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            char_list = []
+            conf_list = []
+            for idx in range(len(text_index[batch_idx])):
+                if text_index[batch_idx][idx] in ignored_tokens:
+                    continue
+                if int(text_index[batch_idx][idx]) == int(end_idx):
+                    break
+                if is_remove_duplicate:
+                    # only for predict
+                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+                            batch_idx][idx]:
+                        continue
+                char_list.append(self.character[int(text_index[batch_idx][
+                    idx])])
+                if text_prob is not None:
+                    conf_list.append(text_prob[batch_idx][idx])
+                else:
+                    conf_list.append(1)
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        cnt_pred, preds = preds
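+        # with the sequence branch present, decode its logits to text;
+        # otherwise fall back to the counting branch and return the predicted
+        # text lengths (rounded sum of each count vector)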
+        if preds is not None:
+            if isinstance(preds, paddle.Tensor):
+                preds = preds.numpy()
+            preds_idx = preds.argmax(axis=2)
+            preds_prob = preds.max(axis=2)
+            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+            if label is None:
+                return text
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+
+        else:
+            cnt_length = []
+            for lens in cnt_pred:
+                length = round(paddle.sum(lens).item())
+                cnt_length.append(length)
+            if label is None:
+                return cnt_length
+            label = self.decode(label, is_remove_duplicate=False)
+            length = [len(res[0]) for res in label]
+            return cnt_length, length
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupport type %s in get_beg_end_flag_idx" \
+                          % beg_or_end
+        return idx
+
+
 class SEEDLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
 
diff --git a/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml
new file mode 100644
index 000000000..b4f18f5c0
--- /dev/null
+++ b/test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml
@@ -0,0 +1,111 @@
+Global:
+  use_gpu: True
+  epoch_num: 6
+  log_smooth_window: 20
+  print_batch_step: 50
+  save_model_dir: ./output/rec/rec_resnet_rfl/
+  save_epoch_step: 1
+  # evaluation is run every 5000 iterations, starting from iteration 0
+  eval_batch_step: [0, 5000]
+  cal_metric_during_train: False
+  pretrained_model:
+  checkpoints: 
+  save_inference_dir:
+  use_visualdl: False
+  infer_img: doc/imgs_words_en/word_10.png
+  # for data or label process
+  character_dict_path:
+  max_text_length: 25
+  infer_mode: False
+  use_space_char: False
+  save_res_path: ./output/rec/rec_resnet_rfl.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  weight_decay: 0.0
+  clip_norm_global: 5.0
+  lr:
+    name: Piecewise
+    decay_epochs: [3, 4, 5]
+    values: [0.001, 0.0003, 0.00009, 0.000027]
+
+Architecture:
+  model_type: rec
+  algorithm: RFL
+  in_channels: 1
+  Transform:
+    name: TPS
+    num_fiducial: 20
+    loc_lr: 1.0
+    model_name: large
+  Backbone:
+    name: ResNetRFL
+    use_cnt: True
+    use_seq: True
+  Neck:
+    name: RFAdaptor
+    use_v2s: True
+    use_s2v: True
+  Head:
+    name: RFLHead
+    in_channels: 512
+    hidden_size: 256
+    batch_max_length: 25
+    out_channels: 38
+    use_cnt: True
+    use_seq: True
+
+Loss:
+  name: RFLLoss
+
+PostProcess:
+  name: RFLLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data/
+    label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: True
+    batch_size_per_card: 64
+    drop_last: True
+    num_workers: 8
+
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/ic15_data
+    label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - RFLLabelEncode: # Class handling label
+      - RFLRecResizeImg:
+          image_shape: [1, 32, 100]
+          interpolation: 2
+      - KeepKeys:
+          keep_keys: ['image', 'label', 'length', 'cnt_label'] # dataloader will return list in this order
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 256
+    num_workers: 8
diff --git a/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt
new file mode 100644
index 000000000..091e962b2
--- /dev/null
+++ b/test_tipc/configs/rec_resnet_rfl/train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:rec_resnet_rfl
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:null
+Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
+Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./inference/rec_inference
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Global.checkpoints:
+norm_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+quant_export:null
+fpgm_export:null
+distill_export:null
+export1:null
+export2:null
+##
+train_model:./inference/rec_resnet_rfl_train/best_accuracy
+infer_export:tools/export_model.py -c test_tipc/configs/rec_resnet_rfl/rec_resnet_rfl.yml -o
+infer_quant:False
+inference:tools/infer/predict_rec.py --rec_image_shape="1,32,100" --rec_algorithm="RFL" --min_subgraph_size=5
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--rec_model_dir:
+--image_dir:./inference/rec_inference
+--save_log_path:./test/output/
+--benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[1,32,100]}]
diff --git a/tools/export_model.py b/tools/export_model.py
index 8610df83e..9c23060ee 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -99,7 +99,7 @@ def export_single_model(model,
         ]
         # print([None, 3, 32, 128])
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
+    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
         other_shape = [
             paddle.static.InputSpec(
                 shape=[None, 1, 32, 100], dtype="float32"),
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 176e2c68e..697e9da43 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -100,6 +100,12 @@ class TextRecognizer(object):
                 "use_space_char": args.use_space_char,
                 "rm_symbol": True
             }
+        elif self.rec_algorithm == 'RFL':
+            postprocess_params = {
+                'name': 'RFLLabelDecode',
+                "character_dict_path": None,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -143,6 +149,16 @@ class TextRecognizer(object):
             else:
                 norm_img = norm_img.astype(np.float32) / 128. - 1.
             return norm_img
+        elif self.rec_algorithm == 'RFL':
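+            # mirror RFLRecResizeImg: grayscale input, cubic resize to the
+            # fixed 32x100 shape, then normalization to [-1, 1]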
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_CUBIC)
+            resized_image = resized_image.astype('float32')
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+            resized_image -= 0.5
+            resized_image /= 0.5
+            return resized_image
 
         assert imgC == img.shape[2]
         imgW = int((imgH * max_wh_ratio))
diff --git a/tools/program.py b/tools/program.py
index 9117d51b9..129e926be 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -217,7 +217,7 @@ def train(config,
     use_srn = config['Architecture']['algorithm'] == "SRN"
     extra_input_models = [
         "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
-        "RobustScanner"
+        "RobustScanner", "RFL"
     ]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
@@ -625,7 +625,7 @@ def preprocess(is_train=False):
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
         'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
-        'Gestalt', 'SLANet', 'RobustScanner', 'CT'
+        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL'
     ]
 
     if use_xpu: