From cf533b65c5acb9c1482d23c8c5201c8431d00d90 Mon Sep 17 00:00:00 2001
From: andyjpaddle <jiangkaitao@baidu.com>
Date: Tue, 19 Jul 2022 12:38:54 +0000
Subject: [PATCH] add vl

---
 ppocr/data/imaug/__init__.py                 |   3 +-
 ppocr/data/imaug/label_ops.py                |  63 ++-
 ppocr/data/imaug/rec_img_aug.py              |  35 ++
 ppocr/data/imaug/text_image_aug/__init__.py  |   3 +-
 ppocr/data/imaug/text_image_aug/vl_aug.py    | 460 +++++++++++++++++
 ppocr/losses/__init__.py                     |   4 +-
 ppocr/modeling/backbones/__init__.py         |   4 +-
 ppocr/modeling/backbones/rec_resnet_aster.py | 111 +++++
 ppocr/modeling/heads/__init__.py             |   3 +-
 ppocr/modeling/heads/rec_visionlan_head.py   | 498 +++++++++++++++++++
 ppocr/postprocess/__init__.py                |   4 +-
 ppocr/postprocess/rec_postprocess.py         |  70 ++-
 tools/eval.py                                |   2 +-
 tools/export_model.py                        |   2 +-
 tools/infer/predict_rec.py                   |  20 +
 tools/program.py                             |  42 +-
 16 files changed, 1297 insertions(+), 27 deletions(-)
 create mode 100644 ppocr/data/imaug/text_image_aug/vl_aug.py
 create mode 100644 ppocr/modeling/heads/rec_visionlan_head.py

diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index f0fd578f6..20719e022 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
 
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, VLRecResizeImg
+from .text_image_aug import VLAug
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 02a5187da..304e190dc 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -23,6 +23,7 @@ import string
 from shapely.geometry import LineString, Point, Polygon
 import json
 import copy
+from random import sample
 
 from ppocr.utils.logging import get_logger
 
@@ -443,7 +444,9 @@ class KieLabelEncode(object):
             elif 'key_cls' in anno.keys():
                 labels.append(anno['key_cls'])
             else:
-                raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
+                raise ValueError(
+                    "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
+                )
             edges.append(ann.get('edge', 0))
         ann_infos = dict(
             image=data['image'],
@@ -1044,3 +1047,61 @@ class MultiLabelEncode(BaseRecLabelEncode):
         data_out['label_sar'] = sar['label']
         data_out['length'] = ctc['length']
         return data_out
+
+
+class VLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(VLLabelEncode, self).__init__(max_text_length,
+                                            character_dict_path, use_space_char)
+
+    def __call__(self, data):
+        text = data['label']  # original string
+        # generate occluded text
+        len_str = len(text)
+        if len_str <= 0:
+            return None
+        change_num = 1
+        order = list(range(len_str))
+        change_id = sample(order, change_num)[0]
+        label_sub = text[change_id]
+        if change_id == (len_str - 1):
+            label_res = text[:change_id]
+        elif change_id == 0:
+            label_res = text[1:]
+        else:
+            label_res = text[:change_id] + text[change_id + 1:]
+
+        data['label_res'] = label_res  # remaining string
+        data['label_sub'] = label_sub  # occluded character
+        data['label_id'] = change_id  # character index
+        # encode label
+        text = self.encode(text)
+        if text is None:
+            return None
+        text = [i + 1 for i in text]
+        data['length'] = np.array(len(text))
+        text = text + [0] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        label_res = self.encode(label_res)
+        label_sub = self.encode(label_sub)
+        if label_res is None:
+            label_res = []
+        else:
+            label_res = [i + 1 for i in label_res]
+        if label_sub is None:
+            label_sub = []
+        else:
+            label_sub = [i + 1 for i in label_sub]
+        data['length_res'] = np.array(len(label_res))
+        data['length_sub'] = np.array(len(label_sub))
+        label_res = label_res + [0] * (self.max_text_len - len(label_res))
+        label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+        data['label_res'] = np.array(label_res)
+        data['label_sub'] = np.array(label_sub)
+        return data
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 32de2b3fc..18d57963f 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -213,6 +213,41 @@ class RecResizeImg(object):
         return data
 
 
+class VLRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 infer_mode=False,
+                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+                 padding=True,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+        self.character_dict_path = character_dict_path
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+        if self.infer_mode and self.character_dict_path is not None:
+            norm_img, valid_ratio = resize_norm_img_chinese(img,
+                                                            self.image_shape)
+        else:
+            imgC, imgH, imgW = self.image_shape
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            resized_w = imgW
+            resized_image = resized_image.astype('float32')
+            if self.image_shape[0] == 1:
+                resized_image = resized_image / 255
+                norm_img = resized_image[np.newaxis, :]
+            else:
+                norm_img = resized_image.transpose((2, 0, 1)) / 255
+            valid_ratio = min(1.0, float(resized_w / imgW))
+
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape
diff --git a/ppocr/data/imaug/text_image_aug/__init__.py b/ppocr/data/imaug/text_image_aug/__init__.py
index bca262638..ca108b287 100644
--- a/ppocr/data/imaug/text_image_aug/__init__.py
+++ b/ppocr/data/imaug/text_image_aug/__init__.py
@@ -13,5 +13,6 @@
 # limitations under the License.
 
 from .augment import tia_perspective, tia_distort, tia_stretch
+from .vl_aug import VLAug
 
-__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
+__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']
diff --git a/ppocr/data/imaug/text_image_aug/vl_aug.py b/ppocr/data/imaug/text_image_aug/vl_aug.py
new file mode 100644
index 000000000..50b066b15
--- /dev/null
+++ b/ppocr/data/imaug/text_image_aug/vl_aug.py
@@ -0,0 +1,460 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from paddle.vision import transforms
+from paddle.vision.transforms import Compose
+
+
+def sample_asym(magnitude, size=None):
+    return np.random.beta(1, 4, size) * magnitude
+
+
+def sample_sym(magnitude, size=None):
+    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+
+
+def sample_uniform(low, high, size=None):
+    return np.random.uniform(low, high, size=size)
+
+
+def get_interpolation(type='random'):
+    if type == 'random':
+        choice = [
+            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
+        ]
+        interpolation = choice[random.randint(0, len(choice) - 1)]
+    elif type == 'nearest':
+        interpolation = cv2.INTER_NEAREST
+    elif type == 'linear':
+        interpolation = cv2.INTER_LINEAR
+    elif type == 'cubic':
+        interpolation = cv2.INTER_CUBIC
+    elif type == 'area':
+        interpolation = cv2.INTER_AREA
+    else:
+        raise TypeError(
+            'Interpolation types only nearest, linear, cubic, area are supported!'
+        )
+    return interpolation
+
+
+class CVRandomRotation(object):
+    def __init__(self, degrees=15):
+        assert isinstance(degrees,
+                          numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+    @staticmethod
+    def get_params(degrees):
+        return sample_sym(degrees)
+
+    def __call__(self, img):
+        angle = self.get_params(self.degrees)
+        src_h, src_w = img.shape[:2]
+        M = cv2.getRotationMatrix2D(
+            center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
+        abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
+        dst_w = int(src_h * abs_sin + src_w * abs_cos)
+        dst_h = int(src_h * abs_cos + src_w * abs_sin)
+        M[0, 2] += (dst_w - src_w) / 2
+        M[1, 2] += (dst_h - src_h) / 2
+
+        flags = get_interpolation()
+        return cv2.warpAffine(
+            img,
+            M, (dst_w, dst_h),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomAffine(object):
+    def __init__(self, degrees, translate=None, scale=None, shear=None):
+        assert isinstance(degrees,
+                          numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError(
+                        "translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError(
+                        "If shear is a single number, it must be positive.")
+                self.shear = [shear]
+            else:
+                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+    def _get_inverse_affine_matrix(self, center, angle, translate, scale,
+                                   shear):
+        from numpy import sin, cos, tan
+
+        if isinstance(shear, numbers.Number):
+            shear = [shear, 0]
+
+        if not isinstance(shear, (tuple, list)) and len(shear) == 2:
+            raise ValueError(
+                "Shear should be a single value or a tuple/list containing " +
+                "two values. Got {}".format(shear))
+
+        rot = math.radians(angle)
+        sx, sy = [math.radians(s) for s in shear]
+
+        cx, cy = center
+        tx, ty = translate
+
+        # RSS without scaling
+        a = cos(rot - sy) / cos(sy)
+        b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
+        c = sin(rot - sy) / cos(sy)
+        d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+
+        # Inverted rotation matrix with scale and shear
+        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+        M = [d, -b, 0, -c, a, 0]
+        M = [x / scale for x in M]
+
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
+        M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+
+        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+        M[2] += cx
+        M[5] += cy
+        return M
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, height):
+        angle = sample_sym(degrees)
+        if translate is not None:
+            max_dx = translate[0] * height
+            max_dy = translate[1] * height
+            translations = (np.round(sample_sym(max_dx)),
+                            np.round(sample_sym(max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = sample_uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        if shears is not None:
+            if len(shears) == 1:
+                shear = [sample_sym(shears[0]), 0.]
+            elif len(shears) == 2:
+                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
+
+    def __call__(self, img):
+        src_h, src_w = img.shape[:2]
+        angle, translate, scale, shear = self.get_params(
+            self.degrees, self.translate, self.scale, self.shear, src_h)
+
+        M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
+                                            (0, 0), scale, shear)
+        M = np.array(M).reshape(2, 3)
+
+        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
+                       (0, src_h - 1)]
+        project = lambda x, y, a, b, c: int(a * x + b * y + c)
+        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
+                     for x, y in startpoints]
+
+        rect = cv2.minAreaRect(np.array(endpoints))
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+
+        dst_w = int(max_x - min_x)
+        dst_h = int(max_y - min_y)
+        M[0, 2] += (dst_w - src_w) / 2
+        M[1, 2] += (dst_h - src_h) / 2
+
+        # add translate
+        dst_w += int(abs(translate[0]))
+        dst_h += int(abs(translate[1]))
+        if translate[0] < 0: M[0, 2] += abs(translate[0])
+        if translate[1] < 0: M[1, 2] += abs(translate[1])
+
+        flags = get_interpolation()
+        return cv2.warpAffine(
+            img,
+            M, (dst_w, dst_h),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomPerspective(object):
+    def __init__(self, distortion=0.5):
+        self.distortion = distortion
+
+    def get_params(self, width, height, distortion):
+        offset_h = sample_asym(
+            distortion * height / 2, size=4).astype(dtype=np.int)
+        offset_w = sample_asym(
+            distortion * width / 2, size=4).astype(dtype=np.int)
+        topleft = (offset_w[0], offset_h[0])
+        topright = (width - 1 - offset_w[1], offset_h[1])
+        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+        botleft = (offset_w[3], height - 1 - offset_h[3])
+
+        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
+                       (0, height - 1)]
+        endpoints = [topleft, topright, botright, botleft]
+        return np.array(
+            startpoints, dtype=np.float32), np.array(
+                endpoints, dtype=np.float32)
+
+    def __call__(self, img):
+        height, width = img.shape[:2]
+        startpoints, endpoints = self.get_params(width, height, self.distortion)
+        M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+        # TODO: more robust way to crop image
+        rect = cv2.minAreaRect(endpoints)
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+        min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+        flags = get_interpolation()
+        img = cv2.warpPerspective(
+            img,
+            M, (max_x, max_y),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+        img = img[min_y:, min_x:]
+        return img
+
+
+class CVRescale(object):
+    def __init__(self, factor=4, base_size=(128, 512)):
+        """ Define image scales using gaussian pyramid and rescale image to target scale.
+        
+        Args:
+            factor: the decayed factor from base size, factor=4 keeps target scale by default.
+            base_size: base size the build the bottom layer of pyramid
+        """
+        if isinstance(factor, numbers.Number):
+            self.factor = round(sample_uniform(0, factor))
+        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+            self.factor = round(sample_uniform(factor[0], factor[1]))
+        else:
+            raise Exception('factor must be number or list with length 2')
+        # assert factor is valid
+        self.base_h, self.base_w = base_size[:2]
+
+    def __call__(self, img):
+        if self.factor == 0:
+            return img
+        src_h, src_w = img.shape[:2]
+        cur_w, cur_h = self.base_w, self.base_h
+        scale_img = cv2.resize(
+            img, (cur_w, cur_h), interpolation=get_interpolation())
+        for _ in range(np.int(self.factor)):
+            scale_img = cv2.pyrDown(scale_img)
+        scale_img = cv2.resize(
+            scale_img, (src_w, src_h), interpolation=get_interpolation())
+        return scale_img
+
+
+class CVGaussianNoise(object):
+    def __init__(self, mean=0, var=20):
+        self.mean = mean
+        if isinstance(var, numbers.Number):
+            self.var = max(int(sample_asym(var)), 1)
+        elif isinstance(var, (tuple, list)) and len(var) == 2:
+            self.var = int(sample_uniform(var[0], var[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+
+    def __call__(self, img):
+        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+        img = np.clip(img + noise, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVMotionBlur(object):
+    def __init__(self, degrees=12, angle=90):
+        if isinstance(degrees, numbers.Number):
+            self.degree = max(int(sample_asym(degrees)), 1)
+        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+            self.degree = int(sample_uniform(degrees[0], degrees[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+        self.angle = sample_uniform(-angle, angle)
+
+    def __call__(self, img):
+        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
+                                    self.angle, 1)
+        motion_blur_kernel = np.zeros((self.degree, self.degree))
+        motion_blur_kernel[self.degree // 2, :] = 1
+        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
+                                            (self.degree, self.degree))
+        motion_blur_kernel = motion_blur_kernel / self.degree
+        img = cv2.filter2D(img, -1, motion_blur_kernel)
+        img = np.clip(img, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVGeometry(object):
+    def __init__(self,
+                 degrees=15,
+                 translate=(0.3, 0.3),
+                 scale=(0.5, 2.),
+                 shear=(45, 15),
+                 distortion=0.5,
+                 p=0.5):
+        self.p = p
+        type_p = random.random()
+        if type_p < 0.33:
+            self.transforms = CVRandomRotation(degrees=degrees)
+        elif type_p < 0.66:
+            self.transforms = CVRandomAffine(
+                degrees=degrees, translate=translate, scale=scale, shear=shear)
+        else:
+            self.transforms = CVRandomPerspective(distortion=distortion)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVDeterioration(object):
+    def __init__(self, var, degrees, factor, p=0.5):
+        self.p = p
+        transforms = []
+        if var is not None:
+            transforms.append(CVGaussianNoise(var=var))
+        if degrees is not None:
+            transforms.append(CVMotionBlur(degrees=degrees))
+        if factor is not None:
+            transforms.append(CVRescale(factor=factor))
+
+        random.shuffle(transforms)
+        transforms = Compose(transforms)
+        self.transforms = transforms
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVColorJitter(object):
+    def __init__(self,
+                 brightness=0.5,
+                 contrast=0.5,
+                 saturation=0.5,
+                 hue=0.1,
+                 p=0.5):
+        self.p = p
+        self.transforms = transforms.ColorJitter(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class VLAug(object):
+    def __init__(self,
+                 geometry_p=0.5,
+                 Deterioration_p=0.25,
+                 ColorJitter_p=0.25,
+                 **kwargs):
+        self.Geometry = CVGeometry(
+            degrees=45,
+            translate=(0.0, 0.0),
+            scale=(0.5, 2.),
+            shear=(45, 15),
+            distortion=0.5,
+            p=geometry_p)
+        self.Deterioration = CVDeterioration(
+            var=20, degrees=6, factor=4, p=Deterioration_p)
+        self.ColorJitter = CVColorJitter(
+            brightness=0.5,
+            contrast=0.5,
+            saturation=0.5,
+            hue=0.1,
+            p=ColorJitter_p)
+
+    def __call__(self, data):
+        img = data['image']
+        img = self.Geometry(img)
+        img = self.Deterioration(img)
+        img = self.ColorJitter(img)
+        data['image'] = img
+        return data
+
+
+if __name__ == '__main__':
+
+    geo = CVGeometry(
+        degrees=45,
+        translate=(0.0, 0.0),
+        scale=(0.5, 2.),
+        shear=(45, 15),
+        distortion=0.5,
+        p=1)
+    det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
+    color = CVColorJitter(
+        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
+
+    img = np.ones((64, 256, 3))
+    img = geo(img)
+    img = det(img)
+    img = color(img)
+    # import pdb
+    # pdb.set_trace()
+    # print()
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index de8419b7c..e1e0635ae 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
+from .rec_vl_loss import VLLoss
 
 # cls loss
 from .cls_loss import ClsLoss
@@ -61,7 +62,8 @@ def build_loss(config):
         'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
         'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
         'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
-        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
+        'VLLoss'
     ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 072d6e0f8..bd6dc5ce1 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -28,14 +28,14 @@ def build_backbone(config, model_type):
         from .rec_mv1_enhance import MobileNetV1Enhance
         from .rec_nrtr_mtb import MTB
         from .rec_resnet_31 import ResNet31
-        from .rec_resnet_aster import ResNet_ASTER
+        from .rec_resnet_aster import ResNet_ASTER, ResNet45
         from .rec_micronet import MicroNet
         from .rec_efficientb3_pren import EfficientNetb3_PREN
         from .rec_svtrnet import SVTRNet
         support_dict = [
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
-            'SVTRNet'
+            'SVTRNet', 'ResNet45'
         ]
     elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
index 6a2710dfa..a59c2da20 100644
--- a/ppocr/modeling/backbones/rec_resnet_aster.py
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -20,6 +20,10 @@ import paddle.nn as nn
 
 import sys
 import math
+from paddle.nn.initializer import KaimingNormal, Constant
+
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
 
 
 def conv3x3(in_planes, out_planes, stride=1):
@@ -141,3 +145,110 @@ class ResNet_ASTER(nn.Layer):
             return rnn_feat
         else:
             return cnn_feat
+
+
+class Block(nn.Layer):
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Block, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet45(nn.Layer):
+    def __init__(self, in_channels=3, compress_layer=False):
+        super(ResNet45, self).__init__()
+        self.compress_layer = compress_layer
+
+        self.conv1_new = nn.Conv2D(
+            in_channels,
+            32,
+            kernel_size=(3, 3),
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(32)
+        self.relu = nn.ReLU()
+
+        self.inplanes = 32
+        self.layer1 = self._make_layer(32, 3, [2, 2])  # [32, 128]
+        self.layer2 = self._make_layer(64, 4, [2, 2])  # [16, 64]
+        self.layer3 = self._make_layer(128, 6, [2, 2])  # [8, 32]
+        self.layer4 = self._make_layer(256, 6, [1, 1])  # [8, 32]
+        self.layer5 = self._make_layer(512, 3, [1, 1])  # [8, 32]
+
+        if self.compress_layer:
+            self.layer6 = nn.Sequential(
+                nn.Conv2D(
+                    512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
+                                                                          1)),
+                nn.BatchNorm(256),
+                nn.ReLU())
+            self.out_channels = 256
+        else:
+            self.out_channels = 512
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            KaimingNormal(m.weight)
+        elif isinstance(m, nn.BatchNorm):
+            ones_(m.weight)
+            zeros_(m.bias)
+
+    def _make_layer(self, planes, blocks, stride):
+        downsample = None
+        if stride != [1, 1] or self.inplanes != planes:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+        layers = []
+        layers.append(Block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(Block(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1_new(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x1 = self.layer1(x)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+        x5 = self.layer5(x4)
+
+        if not self.compress_layer:
+            return x5
+        else:
+            x6 = self.layer6(x5)
+            return x6
+
+
+if __name__ == '__main__':
+    model = ResNet45()
+    x = paddle.rand([1, 3, 64, 256])
+    x = paddle.to_tensor(x)
+    print(x.shape)
+    out = model(x)
+    print(out.shape)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 1670ea38e..37ad6bd6f 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -33,6 +33,7 @@ def build_head(config):
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
     from .rec_multi_head import MultiHead
+    from .rec_visionlan_head import VLHead
 
     # cls head
     from .cls_head import ClsHead
@@ -46,7 +47,7 @@ def build_head(config):
         'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
         'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
         'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead'
+        'MultiHead', 'VLHead'
     ]
 
     #table head
diff --git a/ppocr/modeling/heads/rec_visionlan_head.py b/ppocr/modeling/heads/rec_visionlan_head.py
new file mode 100644
index 000000000..a5d605982
--- /dev/null
+++ b/ppocr/modeling/heads/rec_visionlan_head.py
@@ -0,0 +1,498 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal, XavierNormal
+import numpy as np
+from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
+
+
+class PositionalEncoding(nn.Layer):
+    def __init__(self, d_hid, n_position=200):
+        super(PositionalEncoding, self).__init__()
+        self.register_buffer(
+            'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        ''' Sinusoid position encoding table '''
+
+        def get_position_angle_vec(position):
+            return [
+                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+                for hid_j in range(d_hid)
+            ]
+
+        sinusoid_table = np.array(
+            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+        sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
+        sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
+        return sinusoid_table
+
+    def forward(self, x):
+        return x + self.pos_table[:, :x.shape[1]].clone().detach()
+
+
+class ScaledDotProductAttention(nn.Layer):
+    "Scaled Dot-Product Attention"
+
+    def __init__(self, temperature, attn_dropout=0.1):
+        super(ScaledDotProductAttention, self).__init__()
+        self.temperature = temperature
+        self.dropout = nn.Dropout(attn_dropout)
+        self.softmax = nn.Softmax(axis=2)
+
+    def forward(self, q, k, v, mask=None):
+        k = paddle.transpose(k, perm=[0, 2, 1])
+        attn = paddle.bmm(q, k)
+        attn = attn / self.temperature
+        if mask is not None:
+            attn = attn.masked_fill(mask, -1e9)
+            if mask.dim() == 3:
+                mask = paddle.unsqueeze(mask, axis=1)
+            elif mask.dim() == 2:
+                mask = paddle.unsqueeze(mask, axis=1)
+                mask = paddle.unsqueeze(mask, axis=1)
+            repeat_times = [
+                attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
+            ]
+            mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
+            attn[mask == 0] = -1e9
+        attn = self.softmax(attn)
+        attn = self.dropout(attn)
+        output = paddle.bmm(attn, v)
+        return output
+
+
+class MultiHeadAttention(nn.Layer):
+    " Multi-Head Attention module"
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+        super(MultiHeadAttention, self).__init__()
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+        self.w_qs = nn.Linear(
+            d_model,
+            n_head * d_k,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+        self.w_ks = nn.Linear(
+            d_model,
+            n_head * d_k,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+        self.w_vs = nn.Linear(
+            d_model,
+            n_head * d_v,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
+
+        self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
+                                                                        0.5))
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.fc = nn.Linear(
+            n_head * d_v,
+            d_model,
+            weight_attr=ParamAttr(initializer=XavierNormal()))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, mask=None):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+        sz_b, len_q, _ = q.shape
+        sz_b, len_k, _ = k.shape
+        sz_b, len_v, _ = v.shape
+        residual = q
+
+        q = self.w_qs(q)
+        q = paddle.reshape(
+            q, shape=[-1, len_q, n_head, d_k])  # 4*21*512 ---- 4*21*8*64
+        k = self.w_ks(k)
+        k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
+        v = self.w_vs(v)
+        v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
+
+        q = paddle.transpose(q, perm=[2, 0, 1, 3])
+        q = paddle.reshape(q, shape=[-1, len_q, d_k])  # (n*b) x lq x dk
+        k = paddle.transpose(k, perm=[2, 0, 1, 3])
+        k = paddle.reshape(k, shape=[-1, len_k, d_k])  # (n*b) x lk x dk
+        v = paddle.transpose(v, perm=[2, 0, 1, 3])
+        v = paddle.reshape(v, shape=[-1, len_v, d_v])  # (n*b) x lv x dv
+
+        mask = paddle.tile(
+            mask,
+            [n_head, 1, 1]) if mask is not None else None  # (n*b) x .. x ..
+        output = self.attention(q, k, v, mask=mask)
+        output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
+        output = paddle.transpose(output, perm=[1, 2, 0, 3])
+        output = paddle.reshape(
+            output, shape=[-1, len_q, n_head * d_v])  # b x lq x (n*dv)
+        output = self.dropout(self.fc(output))
+        output = self.layer_norm(output + residual)
+        return output
+
+
+class PositionwiseFeedForward(nn.Layer):
+    def __init__(self, d_in, d_hid, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = nn.Conv1D(d_in, d_hid, 1)  # position-wise
+        self.w_2 = nn.Conv1D(d_hid, d_in, 1)  # position-wise
+        self.layer_norm = nn.LayerNorm(d_in)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        x = self.w_2(F.relu(self.w_1(x)))
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        x = self.dropout(x)
+        x = self.layer_norm(x + residual)
+        return x
+
+
+class EncoderLayer(nn.Layer):
+    ''' Compose with two layers '''
+
+    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
+        super(EncoderLayer, self).__init__()
+        self.slf_attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, dropout=dropout)
+        self.pos_ffn = PositionwiseFeedForward(
+            d_model, d_inner, dropout=dropout)
+
+    def forward(self, enc_input, slf_attn_mask=None):
+        enc_output = self.slf_attn(
+            enc_input, enc_input, enc_input, mask=slf_attn_mask)
+        enc_output = self.pos_ffn(enc_output)
+        return enc_output
+
+
+class Transformer_Encoder(nn.Layer):
+    def __init__(self,
+                 n_layers=2,
+                 n_head=8,
+                 d_word_vec=512,
+                 d_k=64,
+                 d_v=64,
+                 d_model=512,
+                 d_inner=2048,
+                 dropout=0.1,
+                 n_position=256):
+        super(Transformer_Encoder, self).__init__()
+        self.position_enc = PositionalEncoding(
+            d_word_vec, n_position=n_position)
+        self.dropout = nn.Dropout(p=dropout)
+        self.layer_stack = nn.LayerList([
+            EncoderLayer(
+                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+            for _ in range(n_layers)
+        ])
+        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
+
+    def forward(self, enc_output, src_mask, return_attns=False):
+        enc_output = self.dropout(
+            self.position_enc(enc_output))  # position embeding
+        for enc_layer in self.layer_stack:
+            enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
+        enc_output = self.layer_norm(enc_output)
+        return enc_output
+
+
+class PP_layer(nn.Layer):
+    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
+
+        super(PP_layer, self).__init__()
+        self.character_len = N_max_character
+        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
+        self.w0 = nn.Linear(N_max_character, n_position)
+        self.wv = nn.Linear(n_dim, n_dim)
+        self.we = nn.Linear(n_dim, N_max_character)
+        self.active = nn.Tanh()
+        self.softmax = nn.Softmax(axis=2)
+
+    def forward(self, enc_output):
+        # enc_output: b,256,512
+        reading_order = paddle.arange(self.character_len, dtype='int64')
+        reading_order = reading_order.unsqueeze(0).expand(
+            [enc_output.shape[0], -1])  # (S,) -> (B, S)
+        reading_order = self.f0_embedding(reading_order)  # b,25,512
+
+        # calculate attention
+        reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
+        t = self.w0(reading_order)  # b,512,256
+        t = self.active(
+            paddle.transpose(
+                t, perm=[0, 2, 1]) + self.wv(enc_output))  # b,256,512
+        t = self.we(t)  # b,256,25
+        t = self.softmax(paddle.transpose(t, perm=[0, 2, 1]))  # b,25,256
+        g_output = paddle.bmm(t, enc_output)  # b,25,512
+        return g_output
+
+
+class Prediction(nn.Layer):
+    def __init__(self,
+                 n_dim=512,
+                 n_position=256,
+                 N_max_character=25,
+                 n_class=37):
+        super(Prediction, self).__init__()
+        self.pp = PP_layer(
+            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+        self.pp_share = PP_layer(
+            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
+        self.w_share = nn.Linear(n_dim, n_class)  # output layer
+        self.nclass = n_class
+
+    def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
+                use_mlm=True):
+        if train_mode:
+            if not use_mlm:
+                g_output = self.pp(cnn_feature)  # b,25,512
+                g_output = self.w_vrm(g_output)
+                f_res = 0
+                f_sub = 0
+                return g_output, f_res, f_sub
+            g_output = self.pp(cnn_feature)  # b,25,512
+            f_res = self.pp_share(f_res)
+            f_sub = self.pp_share(f_sub)
+            g_output = self.w_vrm(g_output)
+            f_res = self.w_share(f_res)
+            f_sub = self.w_share(f_sub)
+            return g_output, f_res, f_sub
+        else:
+            g_output = self.pp(cnn_feature)  # b,25,512
+            g_output = self.w_vrm(g_output)
+            return g_output
+
+
+class MLM(nn.Layer):
+    "Architecture of MLM"
+
+    def __init__(self, n_dim=512, n_position=256, max_text_length=25):
+        super(MLM, self).__init__()
+        self.MLM_SequenceModeling_mask = Transformer_Encoder(
+            n_layers=2, n_position=n_position)
+        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
+            n_layers=1, n_position=n_position)
+        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
+        self.w0_linear = nn.Linear(1, n_position)
+        self.wv = nn.Linear(n_dim, n_dim)
+        self.active = nn.Tanh()
+        self.we = nn.Linear(n_dim, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, label_pos):
+        # transformer unit for generating mask_c
+        feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
+        # position embedding layer
+        label_pos = paddle.to_tensor(label_pos, dtype='int64')
+        pos_emb = self.pos_embedding(label_pos)
+        pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
+        pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
+        # fusion position embedding with features V & generate mask_c
+        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
+        att_map_sub = self.we(att_map_sub)  # b,256,1
+        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+        att_map_sub = self.sigmoid(att_map_sub)  # b,1,256
+        # WCL
+        ## generate inputs for WCL
+        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+        f_res = x * (1 - att_map_sub)  # second path with remaining string
+        f_sub = x * att_map_sub  # first path with occluded character
+        ## transformer units in WCL
+        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
+        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
+        return f_res, f_sub, att_map_sub
+
+
+def trans_1d_2d(x):
+    b, w_h, c = x.shape  # b, 256, 512
+    x = paddle.transpose(x, perm=[0, 2, 1])
+    x = paddle.reshape(x, [-1, c, 32, 8])
+    x = paddle.transpose(x, perm=[0, 1, 3, 2])  # [b, c, 8, 32]
+    return x
+
+
+class MLM_VRM(nn.Layer):
+    """
+    MLM+VRM, MLM is only used in training.
+    ratio controls the occluded number in a batch.
+    The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
+    x: input image
+    label_pos: character index
+    training_step: LF or LA process
+    output
+    text_pre: prediction of VRM
+    test_rem: prediction of remaining string in MLM
+    text_mas: prediction of occluded character in MLM
+    mask_c_show: visualization of Mask_c
+    """
+
+    def __init__(self,
+                 n_layers=3,
+                 n_position=256,
+                 n_dim=512,
+                 max_text_length=25,
+                 nclass=37):
+        super(MLM_VRM, self).__init__()
+        self.MLM = MLM(n_dim=n_dim,
+                       n_position=n_position,
+                       max_text_length=max_text_length)
+        self.SequenceModeling = Transformer_Encoder(
+            n_layers=n_layers, n_position=n_position)
+        self.Prediction = Prediction(
+            n_dim=n_dim,
+            n_position=n_position,
+            N_max_character=max_text_length +
+            1,  # N_max_character = 1 eos + 25 characters
+            n_class=nclass)
+        self.nclass = nclass
+        self.max_text_length = max_text_length
+
+    def forward(self, x, label_pos, training_step, train_mode=False):
+        b, c, h, w = x.shape
+        nT = self.max_text_length
+        x = paddle.transpose(x, perm=[0, 1, 3, 2])
+        x = paddle.reshape(x, [-1, c, h * w])
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        if train_mode:
+            if training_step == 'LF_1':
+                f_res = 0
+                f_sub = 0
+                x = self.SequenceModeling(x, src_mask=None)
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True, use_mlm=False)
+                return text_pre, text_pre, text_pre, text_pre
+            elif training_step == 'LF_2':
+                # MLM
+                f_res, f_sub, mask_c = self.MLM(x, label_pos)
+                x = self.SequenceModeling(x, src_mask=None)
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True)
+                mask_c_show = trans_1d_2d(mask_c)
+                return text_pre, test_rem, text_mas, mask_c_show
+            elif training_step == 'LA':
+                # MLM
+                f_res, f_sub, mask_c = self.MLM(x, label_pos)
+                ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
+                ## ratio controls the occluded number in a batch
+                character_mask = paddle.zeros_like(mask_c)
+
+                ratio = b // 2
+                if ratio >= 1:
+                    with paddle.no_grad():
+                        character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
+                else:
+                    character_mask = mask_c
+                x = x * (1 - character_mask)
+                # VRM
+                ## transformer unit for VRM
+                x = self.SequenceModeling(x, src_mask=None)
+                ## prediction layer for MLM and VSR
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True)
+                mask_c_show = trans_1d_2d(mask_c)
+                return text_pre, test_rem, text_mas, mask_c_show
+            else:
+                raise NotImplementedError
+        else:  # VRM is only used in the testing stage
+            f_res = 0
+            f_sub = 0
+            contextual_feature = self.SequenceModeling(x, src_mask=None)
+            text_pre = self.Prediction(
+                contextual_feature,
+                f_res,
+                f_sub,
+                train_mode=False,
+                use_mlm=False)
+            text_pre = paddle.transpose(
+                text_pre, perm=[1, 0, 2])  # (26, b, 37))
+            lenText = nT
+            nsteps = nT
+            out_res = paddle.zeros(
+                shape=[lenText, b, self.nclass], dtype=x.dtype)  # (25, b, 37)
+            out_length = paddle.zeros(shape=[b], dtype=x.dtype)
+            now_step = 0
+            for _ in range(nsteps):
+                if 0 in out_length and now_step < nsteps:
+                    tmp_result = text_pre[now_step, :, :]
+                    out_res[now_step] = tmp_result
+                    tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+                    for j in range(b):
+                        if out_length[j] == 0 and tmp_result[j] == 0:
+                            out_length[j] = now_step + 1
+                    now_step += 1
+            # while 0 in out_length and now_step < nsteps:
+            #     tmp_result = text_pre[now_step, :, :]
+            #     out_res[now_step] = tmp_result
+            #     tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+            #     for j in range(b):
+            #         if out_length[j] == 0 and tmp_result[j] == 0:
+            #             out_length[j] = now_step + 1
+            #     now_step += 1
+            for j in range(0, b):
+                if int(out_length[j]) == 0:
+                    out_length[j] = nsteps
+            start = 0
+            output = paddle.zeros(
+                shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
+            for i in range(0, b):
+                cur_length = int(out_length[i])
+                output[start:start + cur_length] = out_res[0:cur_length, i, :]
+                start += cur_length
+            return output, out_length
+
+
+class VLHead(nn.Layer):
+    """
+    Architecture of VisionLAN
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels=36,
+                 n_layers=3,
+                 n_position=256,
+                 n_dim=512,
+                 max_text_length=25,
+                 training_step='LA'):
+        super(VLHead, self).__init__()
+        self.MLM_VRM = MLM_VRM(
+            n_layers=n_layers,
+            n_position=n_position,
+            n_dim=n_dim,
+            max_text_length=max_text_length,
+            nclass=out_channels + 1)
+        self.training_step = training_step
+
+    def forward(self, feat, targets=None):
+
+        if self.training:
+            label_pos = targets[-2]
+            text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
+                feat, label_pos, self.training_step, train_mode=True)
+            return text_pre, test_rem, text_mas, mask_map
+        else:
+            output, out_length = self.MLM_VRM(
+                feat, targets, self.training_step, train_mode=False)
+            return output, out_length
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index f50b5f1c5..a22b79960 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
 from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
-    SEEDLabelDecode, PRENLabelDecode
+    SEEDLabelDecode, PRENLabelDecode, VLLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
         'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
         'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode'
+        'DistillationSARLabelDecode', 'VLLabelDecode'
     ]
 
     if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index bf0fd890b..e2434cd72 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -27,7 +27,8 @@ class BaseRecLabelDecode(object):
 
         self.character_str = []
         if character_dict_path is None:
-            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            # self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
             dict_character = list(self.character_str)
         else:
             with open(character_dict_path, "rb") as fin:
@@ -752,3 +753,70 @@ class PRENLabelDecode(BaseRecLabelDecode):
             return text
         label = self.decode(label)
         return text, label
+
+
+class VLLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[
+                    batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            char_list = [
+                self.character[text_id - 1]
+                for text_id in text_index[batch_idx][selection]
+            ]
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, length=None, *args, **kwargs):
+        if len(preds) == 2:  # eval mode
+            net_out, length = preds
+        else:  # train mode
+            net_out = preds[0]
+            length = length
+            net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
+        text = []
+        if not isinstance(net_out, paddle.Tensor):
+            net_out = paddle.to_tensor(net_out, dtype='float32')
+        # import pdb 
+        # pdb.set_trace()
+        net_out = F.softmax(net_out, axis=1)
+        for i in range(0, length.shape[0]):
+            preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[1][:, 0].tolist()
+            preds_text = ''.join([
+                self.character[idx - 1]
+                if idx > 0 and idx <= len(self.character) else ''
+                for idx in preds_idx
+            ])
+            preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[0][:, 0]
+            preds_prob = paddle.exp(
+                paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+            text.append((preds_text, preds_prob))
+        if label is None:
+            return text
+        label = self.decode(label)
+        return text, label
diff --git a/tools/eval.py b/tools/eval.py
index cab283343..2fc53488e 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -73,7 +73,7 @@ def main():
             config['Architecture']["Head"]['out_channels'] = char_num
 
     model = build_model(config['Architecture'])
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
         for key in config['Architecture']["Models"]:
diff --git a/tools/export_model.py b/tools/export_model.py
index c0cbcd361..5d17410aa 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -55,7 +55,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
                 shape=[None, 3, 48, 160], dtype="float32"),
         ]
         model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "SVTR":
+    elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
         if arch_config["Head"]["name"] == 'MultiHead':
             other_shape = [
                 paddle.static.InputSpec(
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 3664ef2ca..cdfc984ce 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -69,6 +69,12 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "VisionLAN":
+            postprocess_params = {
+                'name': 'VLLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -143,6 +149,15 @@ class TextRecognizer(object):
         resized_image /= 0.5
         return resized_image
 
+    def resize_norm_img_vl(self, img, image_shape):
+
+        imgC, imgH, imgW = image_shape
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype('float32')
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        return resized_image
+
     def resize_norm_img_srn(self, img, image_shape):
         imgC, imgH, imgW = image_shape
 
@@ -300,6 +315,11 @@ class TextRecognizer(object):
                                                          self.rec_image_shape)
                     norm_img = norm_img[np.newaxis, :]
                     norm_img_batch.append(norm_img)
+                elif self.rec_algorithm == "VisionLAN":
+                    norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
+                                                       self.rec_image_shape)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
                 else:
                     norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                     max_wh_ratio)
diff --git a/tools/program.py b/tools/program.py
index aa0d2698c..bf774fd48 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -207,7 +207,7 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
     extra_input = False
     if config['Architecture']['algorithm'] == 'Distillation':
         for key in config['Architecture']["Models"]:
@@ -249,7 +249,6 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():
@@ -264,7 +263,6 @@ def train(config,
                     preds = model(batch)
                 else:
                     preds = model(images)
-
             loss = loss_class(preds, batch)
             avg_loss = loss['loss']
 
@@ -286,6 +284,9 @@ def train(config,
                                                   ]:  # for multi head loss
                         post_result = post_process_class(
                             preds['ctc'], batch[1])  # for CTC head out
+                    elif config['Loss']['name'] in ['VLLoss']:
+                        post_result = post_process_class(preds, batch[1],
+                                                         batch[-1])
                     else:
                         post_result = post_process_class(preds, batch[1])
                     eval_class(post_result, batch)
@@ -307,7 +308,8 @@ def train(config,
             train_stats.update(stats)
 
             if log_writer is not None and dist.get_rank() == 0:
-                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+                log_writer.log_metrics(
+                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)
 
             if dist.get_rank() == 0 and (
                 (global_step > 0 and global_step % print_batch_step == 0) or
@@ -354,7 +356,8 @@ def train(config,
 
                 # logger metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics=cur_metric, prefix="EVAL", step=global_step)
 
                 if cur_metric[main_indicator] >= best_model_dict[
                         main_indicator]:
@@ -377,11 +380,18 @@ def train(config,
                 logger.info(best_str)
                 # logger best metric
                 if log_writer is not None:
-                    log_writer.log_metrics(metrics={
-                        "best_{}".format(main_indicator): best_model_dict[main_indicator]
-                        }, prefix="EVAL", step=global_step)
-                    
-                    log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+                    log_writer.log_metrics(
+                        metrics={
+                            "best_{}".format(main_indicator):
+                            best_model_dict[main_indicator]
+                        },
+                        prefix="EVAL",
+                        step=global_step)
+
+                    log_writer.log_model(
+                        is_best=True,
+                        prefix="best_accuracy",
+                        metadata=best_model_dict)
 
             reader_start = time.time()
         if dist.get_rank() == 0:
@@ -413,7 +423,8 @@ def train(config,
                 epoch=epoch,
                 global_step=global_step)
             if log_writer is not None:
-                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+                log_writer.log_model(
+                    is_best=False, prefix='iter_epoch_{}'.format(epoch))
 
     best_str = 'best metric, {}'.format(', '.join(
         ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
@@ -451,7 +462,6 @@ def eval(model,
                 preds = model(batch)
             else:
                 preds = model(images)
-
             batch_numpy = []
             for item in batch:
                 if isinstance(item, paddle.Tensor):
@@ -564,7 +574,8 @@ def preprocess(is_train=False):
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
+        'VisionLAN'
     ]
 
     if use_xpu:
@@ -583,9 +594,10 @@ def preprocess(is_train=False):
     if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
         save_model_dir = config['Global']['save_model_dir']
         vdl_writer_path = '{}/vdl/'.format(save_model_dir)
-        log_writer = VDLLogger(save_model_dir)
+        log_writer = VDLLogger(vdl_writer_path)
         loggers.append(log_writer)
-    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+    if ('use_wandb' in config['Global'] and
+            config['Global']['use_wandb']) or 'wandb' in config:
         save_dir = config['Global']['save_model_dir']
         wandb_writer_path = "{}/wandb".format(save_dir)
         if "wandb" in config: