add vl
parent 05a983054e
commit cf533b65c5

@@ -23,7 +23,8 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
 from .make_pse_gt import MakePseGt
 from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
-    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
+    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, VLRecResizeImg
+from .text_image_aug import VLAug
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste

@@ -23,6 +23,7 @@ import string
 from shapely.geometry import LineString, Point, Polygon
 import json
 import copy
+from random import sample

 from ppocr.utils.logging import get_logger

@@ -443,7 +444,9 @@ class KieLabelEncode(object):
            elif 'key_cls' in anno.keys():
                labels.append(anno['key_cls'])
            else:
-                raise ValueError("Cannot found 'key_cls' in ann.keys(), please check your training annotation.")
+                raise ValueError(
+                    "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
+                )
            edges.append(ann.get('edge', 0))
        ann_infos = dict(
            image=data['image'],

@@ -1044,3 +1047,61 @@ class MultiLabelEncode(BaseRecLabelEncode):
         data_out['label_sar'] = sar['label']
         data_out['length'] = ctc['length']
         return data_out
+
+
+class VLLabelEncode(BaseRecLabelEncode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self,
+                 max_text_length,
+                 character_dict_path=None,
+                 use_space_char=False,
+                 **kwargs):
+        super(VLLabelEncode, self).__init__(max_text_length,
+                                            character_dict_path, use_space_char)
+
+    def __call__(self, data):
+        text = data['label']  # original string
+        # generate occluded text
+        len_str = len(text)
+        if len_str <= 0:
+            return None
+        change_num = 1
+        order = list(range(len_str))
+        change_id = sample(order, change_num)[0]
+        label_sub = text[change_id]
+        if change_id == (len_str - 1):
+            label_res = text[:change_id]
+        elif change_id == 0:
+            label_res = text[1:]
+        else:
+            label_res = text[:change_id] + text[change_id + 1:]
+
+        data['label_res'] = label_res  # remaining string
+        data['label_sub'] = label_sub  # occluded character
+        data['label_id'] = change_id  # character index
+        # encode label
+        text = self.encode(text)
+        if text is None:
+            return None
+        text = [i + 1 for i in text]
+        data['length'] = np.array(len(text))
+        text = text + [0] * (self.max_text_len - len(text))
+        data['label'] = np.array(text)
+        label_res = self.encode(label_res)
+        label_sub = self.encode(label_sub)
+        if label_res is None:
+            label_res = []
+        else:
+            label_res = [i + 1 for i in label_res]
+        if label_sub is None:
+            label_sub = []
+        else:
+            label_sub = [i + 1 for i in label_sub]
+        data['length_res'] = np.array(len(label_res))
+        data['length_sub'] = np.array(len(label_sub))
+        label_res = label_res + [0] * (self.max_text_len - len(label_res))
+        label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
+        data['label_res'] = np.array(label_res)
+        data['label_sub'] = np.array(label_sub)
+        return data

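For intuition, here is a minimal standalone sketch of the occlusion labeling VLLabelEncode performs (a toy alphabet stands in for the real character dict; all names here are illustrative): one character position is sampled, label_sub holds the occluded character, label_res the remaining string, and every encoded index is shifted by +1 so 0 can serve as padding.

```python
import numpy as np
from random import sample

def vl_encode_demo(text, char2idx, max_text_len=25):
    # sample one position to occlude, as VLLabelEncode.__call__ does
    change_id = sample(range(len(text)), 1)[0]
    label_sub = text[change_id]                          # occluded character
    label_res = text[:change_id] + text[change_id + 1:]  # remaining string

    def enc(s):
        ids = [char2idx[c] + 1 for c in s]               # +1 shift; 0 = padding
        return np.array(ids + [0] * (max_text_len - len(ids)))

    return {'label': enc(text), 'label_res': enc(label_res),
            'label_sub': enc(label_sub), 'label_id': change_id}

out = vl_encode_demo("cat", {c: i for i, c in enumerate("abct")})
print(out['label_id'], out['label'][:4])  # e.g. 1 [3 1 4 0]
```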
@@ -213,6 +213,41 @@ class RecResizeImg(object):
         return data


+class VLRecResizeImg(object):
+    def __init__(self,
+                 image_shape,
+                 infer_mode=False,
+                 character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+                 padding=True,
+                 **kwargs):
+        self.image_shape = image_shape
+        self.infer_mode = infer_mode
+        self.character_dict_path = character_dict_path
+        self.padding = padding
+
+    def __call__(self, data):
+        img = data['image']
+        if self.infer_mode and self.character_dict_path is not None:
+            norm_img, valid_ratio = resize_norm_img_chinese(img,
+                                                            self.image_shape)
+        else:
+            imgC, imgH, imgW = self.image_shape
+            resized_image = cv2.resize(
+                img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+            resized_w = imgW
+            resized_image = resized_image.astype('float32')
+            if self.image_shape[0] == 1:
+                resized_image = resized_image / 255
+                norm_img = resized_image[np.newaxis, :]
+            else:
+                norm_img = resized_image.transpose((2, 0, 1)) / 255
+            valid_ratio = min(1.0, float(resized_w / imgW))
+
+        data['image'] = norm_img
+        data['valid_ratio'] = valid_ratio
+        return data
+
+
 class SRNRecResizeImg(object):
     def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
         self.image_shape = image_shape

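A quick usage sketch for the default (non-infer) branch with a dummy image; the input size is an arbitrary choice for illustration:

```python
import numpy as np
# VLRecResizeImg is registered through ppocr.data.imaug by this change
from ppocr.data.imaug.rec_img_aug import VLRecResizeImg

op = VLRecResizeImg(image_shape=[3, 64, 256])
data = {'image': np.random.randint(0, 255, (48, 320, 3), dtype=np.uint8)}
data = op(data)
print(data['image'].shape, data['valid_ratio'])  # (3, 64, 256) 1.0
```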
@@ -13,5 +13,6 @@
 # limitations under the License.

 from .augment import tia_perspective, tia_distort, tia_stretch
+from .vl_aug import VLAug

-__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
+__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']

@@ -0,0 +1,460 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import numbers
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from paddle.vision import transforms
+from paddle.vision.transforms import Compose
+
+
+def sample_asym(magnitude, size=None):
+    return np.random.beta(1, 4, size) * magnitude
+
+
+def sample_sym(magnitude, size=None):
+    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
+
+
+def sample_uniform(low, high, size=None):
+    return np.random.uniform(low, high, size=size)
+
+
+def get_interpolation(type='random'):
+    if type == 'random':
+        choice = [
+            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
+        ]
+        interpolation = choice[random.randint(0, len(choice) - 1)]
+    elif type == 'nearest':
+        interpolation = cv2.INTER_NEAREST
+    elif type == 'linear':
+        interpolation = cv2.INTER_LINEAR
+    elif type == 'cubic':
+        interpolation = cv2.INTER_CUBIC
+    elif type == 'area':
+        interpolation = cv2.INTER_AREA
+    else:
+        raise TypeError(
+            'Interpolation types only nearest, linear, cubic, area are supported!'
+        )
+    return interpolation
+
+
+class CVRandomRotation(object):
+    def __init__(self, degrees=15):
+        assert isinstance(degrees,
+                          numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+    @staticmethod
+    def get_params(degrees):
+        return sample_sym(degrees)
+
+    def __call__(self, img):
+        angle = self.get_params(self.degrees)
+        src_h, src_w = img.shape[:2]
+        M = cv2.getRotationMatrix2D(
+            center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
+        abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
+        dst_w = int(src_h * abs_sin + src_w * abs_cos)
+        dst_h = int(src_h * abs_cos + src_w * abs_sin)
+        M[0, 2] += (dst_w - src_w) / 2
+        M[1, 2] += (dst_h - src_h) / 2
+
+        flags = get_interpolation()
+        return cv2.warpAffine(
+            img,
+            M, (dst_w, dst_h),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomAffine(object):
+    def __init__(self, degrees, translate=None, scale=None, shear=None):
+        assert isinstance(degrees,
+                          numbers.Number), "degree should be a single number."
+        assert degrees >= 0, "degree must be positive."
+        self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError(
+                        "translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError(
+                        "If shear is a single number, it must be positive.")
+                self.shear = [shear]
+            else:
+                assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+    def _get_inverse_affine_matrix(self, center, angle, translate, scale,
+                                   shear):
+        from numpy import sin, cos, tan
+
+        if isinstance(shear, numbers.Number):
+            shear = [shear, 0]
+
+        if not isinstance(shear, (tuple, list)) and len(shear) == 2:
+            raise ValueError(
+                "Shear should be a single value or a tuple/list containing " +
+                "two values. Got {}".format(shear))
+
+        rot = math.radians(angle)
+        sx, sy = [math.radians(s) for s in shear]
+
+        cx, cy = center
+        tx, ty = translate
+
+        # RSS without scaling
+        a = cos(rot - sy) / cos(sy)
+        b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
+        c = sin(rot - sy) / cos(sy)
+        d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+
+        # Inverted rotation matrix with scale and shear
+        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+        M = [d, -b, 0, -c, a, 0]
+        M = [x / scale for x in M]
+
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
+        M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+
+        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+        M[2] += cx
+        M[5] += cy
+        return M
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, height):
+        angle = sample_sym(degrees)
+        if translate is not None:
+            max_dx = translate[0] * height
+            max_dy = translate[1] * height
+            translations = (np.round(sample_sym(max_dx)),
+                            np.round(sample_sym(max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = sample_uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        if shears is not None:
+            if len(shears) == 1:
+                shear = [sample_sym(shears[0]), 0.]
+            elif len(shears) == 2:
+                shear = [sample_sym(shears[0]), sample_sym(shears[1])]
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
+
+    def __call__(self, img):
+        src_h, src_w = img.shape[:2]
+        angle, translate, scale, shear = self.get_params(
+            self.degrees, self.translate, self.scale, self.shear, src_h)
+
+        M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
+                                            (0, 0), scale, shear)
+        M = np.array(M).reshape(2, 3)
+
+        startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
+                       (0, src_h - 1)]
+        project = lambda x, y, a, b, c: int(a * x + b * y + c)
+        endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
+                     for x, y in startpoints]
+
+        rect = cv2.minAreaRect(np.array(endpoints))
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+
+        dst_w = int(max_x - min_x)
+        dst_h = int(max_y - min_y)
+        M[0, 2] += (dst_w - src_w) / 2
+        M[1, 2] += (dst_h - src_h) / 2
+
+        # add translate
+        dst_w += int(abs(translate[0]))
+        dst_h += int(abs(translate[1]))
+        if translate[0] < 0: M[0, 2] += abs(translate[0])
+        if translate[1] < 0: M[1, 2] += abs(translate[1])
+
+        flags = get_interpolation()
+        return cv2.warpAffine(
+            img,
+            M, (dst_w, dst_h),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+
+
+class CVRandomPerspective(object):
+    def __init__(self, distortion=0.5):
+        self.distortion = distortion
+
+    def get_params(self, width, height, distortion):
+        offset_h = sample_asym(
+            distortion * height / 2, size=4).astype(dtype=np.int)
+        offset_w = sample_asym(
+            distortion * width / 2, size=4).astype(dtype=np.int)
+        topleft = (offset_w[0], offset_h[0])
+        topright = (width - 1 - offset_w[1], offset_h[1])
+        botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
+        botleft = (offset_w[3], height - 1 - offset_h[3])
+
+        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
+                       (0, height - 1)]
+        endpoints = [topleft, topright, botright, botleft]
+        return np.array(
+            startpoints, dtype=np.float32), np.array(
+                endpoints, dtype=np.float32)
+
+    def __call__(self, img):
+        height, width = img.shape[:2]
+        startpoints, endpoints = self.get_params(width, height, self.distortion)
+        M = cv2.getPerspectiveTransform(startpoints, endpoints)
+
+        # TODO: more robust way to crop image
+        rect = cv2.minAreaRect(endpoints)
+        bbox = cv2.boxPoints(rect).astype(dtype=np.int)
+        max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
+        min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
+        min_x, min_y = max(min_x, 0), max(min_y, 0)
+
+        flags = get_interpolation()
+        img = cv2.warpPerspective(
+            img,
+            M, (max_x, max_y),
+            flags=flags,
+            borderMode=cv2.BORDER_REPLICATE)
+        img = img[min_y:, min_x:]
+        return img
+
+
+class CVRescale(object):
+    def __init__(self, factor=4, base_size=(128, 512)):
+        """ Define image scales using gaussian pyramid and rescale image to target scale.
+
+        Args:
+            factor: the decayed factor from base size, factor=4 keeps target scale by default.
+            base_size: base size the build the bottom layer of pyramid
+        """
+        if isinstance(factor, numbers.Number):
+            self.factor = round(sample_uniform(0, factor))
+        elif isinstance(factor, (tuple, list)) and len(factor) == 2:
+            self.factor = round(sample_uniform(factor[0], factor[1]))
+        else:
+            raise Exception('factor must be number or list with length 2')
+        # assert factor is valid
+        self.base_h, self.base_w = base_size[:2]
+
+    def __call__(self, img):
+        if self.factor == 0:
+            return img
+        src_h, src_w = img.shape[:2]
+        cur_w, cur_h = self.base_w, self.base_h
+        scale_img = cv2.resize(
+            img, (cur_w, cur_h), interpolation=get_interpolation())
+        for _ in range(np.int(self.factor)):
+            scale_img = cv2.pyrDown(scale_img)
+        scale_img = cv2.resize(
+            scale_img, (src_w, src_h), interpolation=get_interpolation())
+        return scale_img
+
+
+class CVGaussianNoise(object):
+    def __init__(self, mean=0, var=20):
+        self.mean = mean
+        if isinstance(var, numbers.Number):
+            self.var = max(int(sample_asym(var)), 1)
+        elif isinstance(var, (tuple, list)) and len(var) == 2:
+            self.var = int(sample_uniform(var[0], var[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+
+    def __call__(self, img):
+        noise = np.random.normal(self.mean, self.var**0.5, img.shape)
+        img = np.clip(img + noise, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVMotionBlur(object):
+    def __init__(self, degrees=12, angle=90):
+        if isinstance(degrees, numbers.Number):
+            self.degree = max(int(sample_asym(degrees)), 1)
+        elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
+            self.degree = int(sample_uniform(degrees[0], degrees[1]))
+        else:
+            raise Exception('degree must be number or list with length 2')
+        self.angle = sample_uniform(-angle, angle)
+
+    def __call__(self, img):
+        M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
+                                    self.angle, 1)
+        motion_blur_kernel = np.zeros((self.degree, self.degree))
+        motion_blur_kernel[self.degree // 2, :] = 1
+        motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
+                                            (self.degree, self.degree))
+        motion_blur_kernel = motion_blur_kernel / self.degree
+        img = cv2.filter2D(img, -1, motion_blur_kernel)
+        img = np.clip(img, 0, 255).astype(np.uint8)
+        return img
+
+
+class CVGeometry(object):
+    def __init__(self,
+                 degrees=15,
+                 translate=(0.3, 0.3),
+                 scale=(0.5, 2.),
+                 shear=(45, 15),
+                 distortion=0.5,
+                 p=0.5):
+        self.p = p
+        type_p = random.random()
+        if type_p < 0.33:
+            self.transforms = CVRandomRotation(degrees=degrees)
+        elif type_p < 0.66:
+            self.transforms = CVRandomAffine(
+                degrees=degrees, translate=translate, scale=scale, shear=shear)
+        else:
+            self.transforms = CVRandomPerspective(distortion=distortion)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVDeterioration(object):
+    def __init__(self, var, degrees, factor, p=0.5):
+        self.p = p
+        transforms = []
+        if var is not None:
+            transforms.append(CVGaussianNoise(var=var))
+        if degrees is not None:
+            transforms.append(CVMotionBlur(degrees=degrees))
+        if factor is not None:
+            transforms.append(CVRescale(factor=factor))
+
+        random.shuffle(transforms)
+        transforms = Compose(transforms)
+        self.transforms = transforms
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class CVColorJitter(object):
+    def __init__(self,
+                 brightness=0.5,
+                 contrast=0.5,
+                 saturation=0.5,
+                 hue=0.1,
+                 p=0.5):
+        self.p = p
+        self.transforms = transforms.ColorJitter(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue)
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            return self.transforms(img)
+        else:
+            return img
+
+
+class VLAug(object):
+    def __init__(self,
+                 geometry_p=0.5,
+                 Deterioration_p=0.25,
+                 ColorJitter_p=0.25,
+                 **kwargs):
+        self.Geometry = CVGeometry(
+            degrees=45,
+            translate=(0.0, 0.0),
+            scale=(0.5, 2.),
+            shear=(45, 15),
+            distortion=0.5,
+            p=geometry_p)
+        self.Deterioration = CVDeterioration(
+            var=20, degrees=6, factor=4, p=Deterioration_p)
+        self.ColorJitter = CVColorJitter(
+            brightness=0.5,
+            contrast=0.5,
+            saturation=0.5,
+            hue=0.1,
+            p=ColorJitter_p)
+
+    def __call__(self, data):
+        img = data['image']
+        img = self.Geometry(img)
+        img = self.Deterioration(img)
+        img = self.ColorJitter(img)
+        data['image'] = img
+        return data
+
+
+if __name__ == '__main__':
+
+    geo = CVGeometry(
+        degrees=45,
+        translate=(0.0, 0.0),
+        scale=(0.5, 2.),
+        shear=(45, 15),
+        distortion=0.5,
+        p=1)
+    det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
+    color = CVColorJitter(
+        brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
+
+    img = np.ones((64, 256, 3))
+    img = geo(img)
+    img = det(img)
+    img = color(img)
+    # import pdb
+    # pdb.set_trace()
+    # print()

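Two details of the file above are worth calling out. The magnitude samplers are Beta-distributed: sample_asym uses Beta(1, 4), which skews toward small values (mean magnitude/5), while sample_sym recenters Beta(4, 4) to be symmetric about zero. A standalone check:

```python
import numpy as np

def sample_asym(magnitude, size=None):
    return np.random.beta(1, 4, size) * magnitude

def sample_sym(magnitude, size=None):
    return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude

a = sample_asym(10, size=200000)
s = sample_sym(10, size=200000)
print(round(a.mean(), 1), round(s.mean(), 1))  # ~2.0 and ~0.0
print(s.min() >= -10 and s.max() <= 10)        # True: bounded by the magnitude
```

Note also that CVGeometry draws its transform type once in __init__, so each operator instance commits to rotation, affine, or perspective for its lifetime; only the apply-or-skip decision (p) is made per image.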
@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
 from .rec_aster_loss import AsterLoss
 from .rec_pren_loss import PRENLoss
 from .rec_multi_loss import MultiLoss
+from .rec_vl_loss import VLLoss

 # cls loss
 from .cls_loss import ClsLoss

@@ -61,7 +62,8 @@ def build_loss(config):
        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
        'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
        'NRTRLoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
-        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
+        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
+        'VLLoss'
    ]
    config = copy.deepcopy(config)
    module_name = config.pop('name')

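With the registration in place, build_loss can instantiate the new loss from config. A sketch of the dispatch (the 'mode' keyword below is an assumption based on VisionLAN's LF/LA training stages; rec_vl_loss itself is not shown in this excerpt):

```python
from ppocr.losses import build_loss

# build_loss deep-copies the config, pops 'name', and instantiates that class
loss_config = {'name': 'VLLoss', 'mode': 'LF_1'}  # 'mode' is an assumed kwarg
vl_loss = build_loss(loss_config)
```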
@@ -28,14 +28,14 @@ def build_backbone(config, model_type):
        from .rec_mv1_enhance import MobileNetV1Enhance
        from .rec_nrtr_mtb import MTB
        from .rec_resnet_31 import ResNet31
-        from .rec_resnet_aster import ResNet_ASTER
+        from .rec_resnet_aster import ResNet_ASTER, ResNet45
        from .rec_micronet import MicroNet
        from .rec_efficientb3_pren import EfficientNetb3_PREN
        from .rec_svtrnet import SVTRNet
        support_dict = [
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            "ResNet31", "ResNet_ASTER", 'MicroNet', 'EfficientNetb3_PREN',
-            'SVTRNet'
+            'SVTRNet', 'ResNet45'
        ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet

@@ -20,6 +20,10 @@ import paddle.nn as nn

 import sys
 import math
+from paddle.nn.initializer import KaimingNormal, Constant
+
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)


 def conv3x3(in_planes, out_planes, stride=1):

@@ -141,3 +145,110 @@ class ResNet_ASTER(nn.Layer):
             return rnn_feat
         else:
             return cnn_feat
+
+
+class Block(nn.Layer):
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Block, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2D(planes)
+        self.relu = nn.ReLU()
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2D(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class ResNet45(nn.Layer):
+    def __init__(self, in_channels=3, compress_layer=False):
+        super(ResNet45, self).__init__()
+        self.compress_layer = compress_layer
+
+        self.conv1_new = nn.Conv2D(
+            in_channels,
+            32,
+            kernel_size=(3, 3),
+            stride=1,
+            padding=1,
+            bias_attr=False)
+        self.bn1 = nn.BatchNorm2D(32)
+        self.relu = nn.ReLU()
+
+        self.inplanes = 32
+        self.layer1 = self._make_layer(32, 3, [2, 2])  # [32, 128]
+        self.layer2 = self._make_layer(64, 4, [2, 2])  # [16, 64]
+        self.layer3 = self._make_layer(128, 6, [2, 2])  # [8, 32]
+        self.layer4 = self._make_layer(256, 6, [1, 1])  # [8, 32]
+        self.layer5 = self._make_layer(512, 3, [1, 1])  # [8, 32]
+
+        if self.compress_layer:
+            self.layer6 = nn.Sequential(
+                nn.Conv2D(
+                    512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
+                                                                          1)),
+                nn.BatchNorm(256),
+                nn.ReLU())
+            self.out_channels = 256
+        else:
+            self.out_channels = 512
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Conv2D):
+            KaimingNormal(m.weight)
+        elif isinstance(m, nn.BatchNorm):
+            ones_(m.weight)
+            zeros_(m.bias)
+
+    def _make_layer(self, planes, blocks, stride):
+        downsample = None
+        if stride != [1, 1] or self.inplanes != planes:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+        layers = []
+        layers.append(Block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(Block(self.inplanes, planes))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1_new(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x1 = self.layer1(x)
+        x2 = self.layer2(x1)
+        x3 = self.layer3(x2)
+        x4 = self.layer4(x3)
+        x5 = self.layer5(x4)
+
+        if not self.compress_layer:
+            return x5
+        else:
+            x6 = self.layer6(x5)
+            return x6
+
+
+if __name__ == '__main__':
+    model = ResNet45()
+    x = paddle.rand([1, 3, 64, 256])
+    x = paddle.to_tensor(x)
+    print(x.shape)
+    out = model(x)
+    print(out.shape)

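Only layers 1-3 use stride [2, 2], so height and width are each reduced 8x while channels grow to 512; the bracketed comments ([32, 128] ... [8, 32]) trace a 64x256 input through the stages. A shape check mirroring the __main__ block above:

```python
import paddle
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45

model = ResNet45(in_channels=3)
out = model(paddle.rand([1, 3, 64, 256]))  # [N, C, H, W]
print(out.shape)  # [1, 512, 8, 32]: three stride-2 stages give H/8 x W/8
```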
@@ -33,6 +33,7 @@ def build_head(config):
     from .rec_aster_head import AsterHead
     from .rec_pren_head import PRENHead
     from .rec_multi_head import MultiHead
+    from .rec_visionlan_head import VLHead

     # cls head
     from .cls_head import ClsHead

@@ -46,7 +47,7 @@ def build_head(config):
        'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
        'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
        'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
-        'MultiHead'
+        'MultiHead', 'VLHead'
    ]

    #table head

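Registration follows the same pattern as the other heads. A dispatch sketch (in the real architecture builder, in_channels is injected from the backbone's out_channels rather than written by hand):

```python
from ppocr.modeling.heads import build_head

config = {'name': 'VLHead', 'in_channels': 512, 'out_channels': 36,
          'max_text_length': 25, 'training_step': 'LA'}
head = build_head(config)  # pops 'name', passes the rest as kwargs
```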
@@ -0,0 +1,498 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import Normal, XavierNormal
+import numpy as np
+from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
+
+
+class PositionalEncoding(nn.Layer):
+    def __init__(self, d_hid, n_position=200):
+        super(PositionalEncoding, self).__init__()
+        self.register_buffer(
+            'pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))
+
+    def _get_sinusoid_encoding_table(self, n_position, d_hid):
+        ''' Sinusoid position encoding table '''
+
+        def get_position_angle_vec(position):
+            return [
+                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+                for hid_j in range(d_hid)
+            ]
+
+        sinusoid_table = np.array(
+            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+        sinusoid_table = paddle.to_tensor(sinusoid_table, dtype='float32')
+        sinusoid_table = paddle.unsqueeze(sinusoid_table, axis=0)
+        return sinusoid_table
+
+    def forward(self, x):
+        return x + self.pos_table[:, :x.shape[1]].clone().detach()
+
+
+class ScaledDotProductAttention(nn.Layer):
+    "Scaled Dot-Product Attention"
+
+    def __init__(self, temperature, attn_dropout=0.1):
+        super(ScaledDotProductAttention, self).__init__()
+        self.temperature = temperature
+        self.dropout = nn.Dropout(attn_dropout)
+        self.softmax = nn.Softmax(axis=2)
+
+    def forward(self, q, k, v, mask=None):
+        k = paddle.transpose(k, perm=[0, 2, 1])
+        attn = paddle.bmm(q, k)
+        attn = attn / self.temperature
+        if mask is not None:
+            attn = attn.masked_fill(mask, -1e9)
+            if mask.dim() == 3:
+                mask = paddle.unsqueeze(mask, axis=1)
+            elif mask.dim() == 2:
+                mask = paddle.unsqueeze(mask, axis=1)
+                mask = paddle.unsqueeze(mask, axis=1)
+            repeat_times = [
+                attn.shape[1] // mask.shape[1], attn.shape[2] // mask.shape[2]
+            ]
+            mask = paddle.tile(mask, [1, repeat_times[0], repeat_times[1], 1])
+            attn[mask == 0] = -1e9
+        attn = self.softmax(attn)
+        attn = self.dropout(attn)
+        output = paddle.bmm(attn, v)
+        return output
+
+
+class MultiHeadAttention(nn.Layer):
+    " Multi-Head Attention module"
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+        super(MultiHeadAttention, self).__init__()
+        self.n_head = n_head
+        self.d_k = d_k
+        self.d_v = d_v
+        self.w_qs = nn.Linear(
+            d_model,
+            n_head * d_k,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+        self.w_ks = nn.Linear(
+            d_model,
+            n_head * d_k,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_k)))))
+        self.w_vs = nn.Linear(
+            d_model,
+            n_head * d_v,
+            weight_attr=ParamAttr(initializer=Normal(
+                mean=0, std=np.sqrt(2.0 / (d_model + d_v)))))
+
+        self.attention = ScaledDotProductAttention(temperature=np.power(d_k,
+                                                                        0.5))
+        self.layer_norm = nn.LayerNorm(d_model)
+        self.fc = nn.Linear(
+            n_head * d_v,
+            d_model,
+            weight_attr=ParamAttr(initializer=XavierNormal()))
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, mask=None):
+        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+        sz_b, len_q, _ = q.shape
+        sz_b, len_k, _ = k.shape
+        sz_b, len_v, _ = v.shape
+        residual = q
+
+        q = self.w_qs(q)
+        q = paddle.reshape(
+            q, shape=[-1, len_q, n_head, d_k])  # 4*21*512 ---- 4*21*8*64
+        k = self.w_ks(k)
+        k = paddle.reshape(k, shape=[-1, len_k, n_head, d_k])
+        v = self.w_vs(v)
+        v = paddle.reshape(v, shape=[-1, len_v, n_head, d_v])
+
+        q = paddle.transpose(q, perm=[2, 0, 1, 3])
+        q = paddle.reshape(q, shape=[-1, len_q, d_k])  # (n*b) x lq x dk
+        k = paddle.transpose(k, perm=[2, 0, 1, 3])
+        k = paddle.reshape(k, shape=[-1, len_k, d_k])  # (n*b) x lk x dk
+        v = paddle.transpose(v, perm=[2, 0, 1, 3])
+        v = paddle.reshape(v, shape=[-1, len_v, d_v])  # (n*b) x lv x dv
+
+        mask = paddle.tile(
+            mask,
+            [n_head, 1, 1]) if mask is not None else None  # (n*b) x .. x ..
+        output = self.attention(q, k, v, mask=mask)
+        output = paddle.reshape(output, shape=[n_head, -1, len_q, d_v])
+        output = paddle.transpose(output, perm=[1, 2, 0, 3])
+        output = paddle.reshape(
+            output, shape=[-1, len_q, n_head * d_v])  # b x lq x (n*dv)
+        output = self.dropout(self.fc(output))
+        output = self.layer_norm(output + residual)
+        return output
+
+
+class PositionwiseFeedForward(nn.Layer):
+    def __init__(self, d_in, d_hid, dropout=0.1):
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = nn.Conv1D(d_in, d_hid, 1)  # position-wise
+        self.w_2 = nn.Conv1D(d_hid, d_in, 1)  # position-wise
+        self.layer_norm = nn.LayerNorm(d_in)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        residual = x
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        x = self.w_2(F.relu(self.w_1(x)))
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        x = self.dropout(x)
+        x = self.layer_norm(x + residual)
+        return x
+
+
+class EncoderLayer(nn.Layer):
+    ''' Compose with two layers '''
+
+    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
+        super(EncoderLayer, self).__init__()
+        self.slf_attn = MultiHeadAttention(
+            n_head, d_model, d_k, d_v, dropout=dropout)
+        self.pos_ffn = PositionwiseFeedForward(
+            d_model, d_inner, dropout=dropout)
+
+    def forward(self, enc_input, slf_attn_mask=None):
+        enc_output = self.slf_attn(
+            enc_input, enc_input, enc_input, mask=slf_attn_mask)
+        enc_output = self.pos_ffn(enc_output)
+        return enc_output
+
+
+class Transformer_Encoder(nn.Layer):
+    def __init__(self,
+                 n_layers=2,
+                 n_head=8,
+                 d_word_vec=512,
+                 d_k=64,
+                 d_v=64,
+                 d_model=512,
+                 d_inner=2048,
+                 dropout=0.1,
+                 n_position=256):
+        super(Transformer_Encoder, self).__init__()
+        self.position_enc = PositionalEncoding(
+            d_word_vec, n_position=n_position)
+        self.dropout = nn.Dropout(p=dropout)
+        self.layer_stack = nn.LayerList([
+            EncoderLayer(
+                d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
+            for _ in range(n_layers)
+        ])
+        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-6)
+
+    def forward(self, enc_output, src_mask, return_attns=False):
+        enc_output = self.dropout(
+            self.position_enc(enc_output))  # position embeding
+        for enc_layer in self.layer_stack:
+            enc_output = enc_layer(enc_output, slf_attn_mask=src_mask)
+        enc_output = self.layer_norm(enc_output)
+        return enc_output
+
+
+class PP_layer(nn.Layer):
+    def __init__(self, n_dim=512, N_max_character=25, n_position=256):
+
+        super(PP_layer, self).__init__()
+        self.character_len = N_max_character
+        self.f0_embedding = nn.Embedding(N_max_character, n_dim)
+        self.w0 = nn.Linear(N_max_character, n_position)
+        self.wv = nn.Linear(n_dim, n_dim)
+        self.we = nn.Linear(n_dim, N_max_character)
+        self.active = nn.Tanh()
+        self.softmax = nn.Softmax(axis=2)
+
+    def forward(self, enc_output):
+        # enc_output: b,256,512
+        reading_order = paddle.arange(self.character_len, dtype='int64')
+        reading_order = reading_order.unsqueeze(0).expand(
+            [enc_output.shape[0], -1])  # (S,) -> (B, S)
+        reading_order = self.f0_embedding(reading_order)  # b,25,512
+
+        # calculate attention
+        reading_order = paddle.transpose(reading_order, perm=[0, 2, 1])
+        t = self.w0(reading_order)  # b,512,256
+        t = self.active(
+            paddle.transpose(
+                t, perm=[0, 2, 1]) + self.wv(enc_output))  # b,256,512
+        t = self.we(t)  # b,256,25
+        t = self.softmax(paddle.transpose(t, perm=[0, 2, 1]))  # b,25,256
+        g_output = paddle.bmm(t, enc_output)  # b,25,512
+        return g_output
+
+
+class Prediction(nn.Layer):
+    def __init__(self,
+                 n_dim=512,
+                 n_position=256,
+                 N_max_character=25,
+                 n_class=37):
+        super(Prediction, self).__init__()
+        self.pp = PP_layer(
+            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+        self.pp_share = PP_layer(
+            n_dim=n_dim, N_max_character=N_max_character, n_position=n_position)
+        self.w_vrm = nn.Linear(n_dim, n_class)  # output layer
+        self.w_share = nn.Linear(n_dim, n_class)  # output layer
+        self.nclass = n_class
+
+    def forward(self, cnn_feature, f_res, f_sub, train_mode=False,
+                use_mlm=True):
+        if train_mode:
+            if not use_mlm:
+                g_output = self.pp(cnn_feature)  # b,25,512
+                g_output = self.w_vrm(g_output)
+                f_res = 0
+                f_sub = 0
+                return g_output, f_res, f_sub
+            g_output = self.pp(cnn_feature)  # b,25,512
+            f_res = self.pp_share(f_res)
+            f_sub = self.pp_share(f_sub)
+            g_output = self.w_vrm(g_output)
+            f_res = self.w_share(f_res)
+            f_sub = self.w_share(f_sub)
+            return g_output, f_res, f_sub
+        else:
+            g_output = self.pp(cnn_feature)  # b,25,512
+            g_output = self.w_vrm(g_output)
+            return g_output
+
+
+class MLM(nn.Layer):
+    "Architecture of MLM"
+
+    def __init__(self, n_dim=512, n_position=256, max_text_length=25):
+        super(MLM, self).__init__()
+        self.MLM_SequenceModeling_mask = Transformer_Encoder(
+            n_layers=2, n_position=n_position)
+        self.MLM_SequenceModeling_WCL = Transformer_Encoder(
+            n_layers=1, n_position=n_position)
+        self.pos_embedding = nn.Embedding(max_text_length, n_dim)
+        self.w0_linear = nn.Linear(1, n_position)
+        self.wv = nn.Linear(n_dim, n_dim)
+        self.active = nn.Tanh()
+        self.we = nn.Linear(n_dim, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, label_pos):
+        # transformer unit for generating mask_c
+        feature_v_seq = self.MLM_SequenceModeling_mask(x, src_mask=None)
+        # position embedding layer
+        label_pos = paddle.to_tensor(label_pos, dtype='int64')
+        pos_emb = self.pos_embedding(label_pos)
+        pos_emb = self.w0_linear(paddle.unsqueeze(pos_emb, axis=2))
+        pos_emb = paddle.transpose(pos_emb, perm=[0, 2, 1])
+        # fusion position embedding with features V & generate mask_c
+        att_map_sub = self.active(pos_emb + self.wv(feature_v_seq))
+        att_map_sub = self.we(att_map_sub)  # b,256,1
+        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+        att_map_sub = self.sigmoid(att_map_sub)  # b,1,256
+        # WCL
+        ## generate inputs for WCL
+        att_map_sub = paddle.transpose(att_map_sub, perm=[0, 2, 1])
+        f_res = x * (1 - att_map_sub)  # second path with remaining string
+        f_sub = x * att_map_sub  # first path with occluded character
+        ## transformer units in WCL
+        f_res = self.MLM_SequenceModeling_WCL(f_res, src_mask=None)
+        f_sub = self.MLM_SequenceModeling_WCL(f_sub, src_mask=None)
+        return f_res, f_sub, att_map_sub
+
+
+def trans_1d_2d(x):
+    b, w_h, c = x.shape  # b, 256, 512
+    x = paddle.transpose(x, perm=[0, 2, 1])
+    x = paddle.reshape(x, [-1, c, 32, 8])
+    x = paddle.transpose(x, perm=[0, 1, 3, 2])  # [b, c, 8, 32]
+    return x
+
+
+class MLM_VRM(nn.Layer):
+    """
+    MLM+VRM, MLM is only used in training.
+    ratio controls the occluded number in a batch.
+    The pipeline of VisionLAN in testing is very concise with only a backbone + sequence modeling(transformer unit) + prediction layer(pp layer).
+    x: input image
+    label_pos: character index
+    training_step: LF or LA process
+    output
+    text_pre: prediction of VRM
+    test_rem: prediction of remaining string in MLM
+    text_mas: prediction of occluded character in MLM
+    mask_c_show: visualization of Mask_c
+    """
+
+    def __init__(self,
+                 n_layers=3,
+                 n_position=256,
+                 n_dim=512,
+                 max_text_length=25,
+                 nclass=37):
+        super(MLM_VRM, self).__init__()
+        self.MLM = MLM(n_dim=n_dim,
+                       n_position=n_position,
+                       max_text_length=max_text_length)
+        self.SequenceModeling = Transformer_Encoder(
+            n_layers=n_layers, n_position=n_position)
+        self.Prediction = Prediction(
+            n_dim=n_dim,
+            n_position=n_position,
+            N_max_character=max_text_length +
+            1,  # N_max_character = 1 eos + 25 characters
+            n_class=nclass)
+        self.nclass = nclass
+        self.max_text_length = max_text_length
+
+    def forward(self, x, label_pos, training_step, train_mode=False):
+        b, c, h, w = x.shape
+        nT = self.max_text_length
+        x = paddle.transpose(x, perm=[0, 1, 3, 2])
+        x = paddle.reshape(x, [-1, c, h * w])
+        x = paddle.transpose(x, perm=[0, 2, 1])
+        if train_mode:
+            if training_step == 'LF_1':
+                f_res = 0
+                f_sub = 0
+                x = self.SequenceModeling(x, src_mask=None)
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True, use_mlm=False)
+                return text_pre, text_pre, text_pre, text_pre
+            elif training_step == 'LF_2':
+                # MLM
+                f_res, f_sub, mask_c = self.MLM(x, label_pos)
+                x = self.SequenceModeling(x, src_mask=None)
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True)
+                mask_c_show = trans_1d_2d(mask_c)
+                return text_pre, test_rem, text_mas, mask_c_show
+            elif training_step == 'LA':
+                # MLM
+                f_res, f_sub, mask_c = self.MLM(x, label_pos)
+                ## use the mask_c (1 for occluded character and 0 for remaining characters) to occlude input
+                ## ratio controls the occluded number in a batch
+                character_mask = paddle.zeros_like(mask_c)
+
+                ratio = b // 2
+                if ratio >= 1:
+                    with paddle.no_grad():
+                        character_mask[0:ratio, :, :] = mask_c[0:ratio, :, :]
+                else:
+                    character_mask = mask_c
+                x = x * (1 - character_mask)
+                # VRM
+                ## transformer unit for VRM
+                x = self.SequenceModeling(x, src_mask=None)
+                ## prediction layer for MLM and VSR
+                text_pre, test_rem, text_mas = self.Prediction(
+                    x, f_res, f_sub, train_mode=True)
+                mask_c_show = trans_1d_2d(mask_c)
+                return text_pre, test_rem, text_mas, mask_c_show
+            else:
+                raise NotImplementedError
+        else:  # VRM is only used in the testing stage
+            f_res = 0
+            f_sub = 0
+            contextual_feature = self.SequenceModeling(x, src_mask=None)
+            text_pre = self.Prediction(
+                contextual_feature,
+                f_res,
+                f_sub,
+                train_mode=False,
+                use_mlm=False)
+            text_pre = paddle.transpose(
+                text_pre, perm=[1, 0, 2])  # (26, b, 37))
+            lenText = nT
+            nsteps = nT
+            out_res = paddle.zeros(
+                shape=[lenText, b, self.nclass], dtype=x.dtype)  # (25, b, 37)
+            out_length = paddle.zeros(shape=[b], dtype=x.dtype)
+            now_step = 0
+            for _ in range(nsteps):
+                if 0 in out_length and now_step < nsteps:
+                    tmp_result = text_pre[now_step, :, :]
+                    out_res[now_step] = tmp_result
+                    tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+                    for j in range(b):
+                        if out_length[j] == 0 and tmp_result[j] == 0:
+                            out_length[j] = now_step + 1
+                    now_step += 1
+            # while 0 in out_length and now_step < nsteps:
+            #     tmp_result = text_pre[now_step, :, :]
+            #     out_res[now_step] = tmp_result
+            #     tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
+            #     for j in range(b):
+            #         if out_length[j] == 0 and tmp_result[j] == 0:
+            #             out_length[j] = now_step + 1
+            #     now_step += 1
+            for j in range(0, b):
+                if int(out_length[j]) == 0:
+                    out_length[j] = nsteps
+            start = 0
+            output = paddle.zeros(
+                shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
+            for i in range(0, b):
+                cur_length = int(out_length[i])
+                output[start:start + cur_length] = out_res[0:cur_length, i, :]
+                start += cur_length
+            return output, out_length
+
+
+class VLHead(nn.Layer):
+    """
+    Architecture of VisionLAN
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels=36,
+                 n_layers=3,
+                 n_position=256,
+                 n_dim=512,
+                 max_text_length=25,
+                 training_step='LA'):
+        super(VLHead, self).__init__()
+        self.MLM_VRM = MLM_VRM(
+            n_layers=n_layers,
+            n_position=n_position,
+            n_dim=n_dim,
+            max_text_length=max_text_length,
+            nclass=out_channels + 1)
+        self.training_step = training_step
+
+    def forward(self, feat, targets=None):
+
+        if self.training:
+            label_pos = targets[-2]
+            text_pre, test_rem, text_mas, mask_map = self.MLM_VRM(
+                feat, label_pos, self.training_step, train_mode=True)
+            return text_pre, test_rem, text_mas, mask_map
+        else:
+            output, out_length = self.MLM_VRM(
+                feat, targets, self.training_step, train_mode=False)
+            return output, out_length

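The three training_step values implement VisionLAN's schedule: LF_1 trains the recognizer without masking, LF_2 turns on the MLM branch, and LA additionally occludes features with mask_c for half of each batch; at inference the MLM branch is skipped entirely. A forward-shape sketch for the eval path (dummy features stand in for ResNet45 output):

```python
import paddle
from ppocr.modeling.heads.rec_visionlan_head import VLHead

head = VLHead(in_channels=512, out_channels=36, training_step='LA')
head.eval()
feat = paddle.rand([2, 512, 8, 32])  # [b, c, h, w] from the backbone
with paddle.no_grad():
    output, out_length = head(feat)  # eval path: VRM only, no MLM
print(output.shape)  # [int(out_length.sum()), 37]; 37 = out_channels + 1
```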
@@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
 from .fce_postprocess import FCEPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
-    SEEDLabelDecode, PRENLabelDecode
+    SEEDLabelDecode, PRENLabelDecode, VLLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
 from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess

@@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode'
+        'DistillationSARLabelDecode', 'VLLabelDecode'
    ]

    if config['name'] == 'PSEPostProcess':

@@ -27,7 +27,8 @@ class BaseRecLabelDecode(object):

         self.character_str = []
         if character_dict_path is None:
-            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            # self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
+            self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
             dict_character = list(self.character_str)
         else:
             with open(character_dict_path, "rb") as fin:

@@ -752,3 +753,70 @@ class PRENLabelDecode(BaseRecLabelDecode):
             return text
         label = self.decode(label)
         return text, label
+
+
+class VLLabelDecode(BaseRecLabelDecode):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, character_dict_path=None, use_space_char=False,
+                 **kwargs):
+        super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
+
+    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+        """ convert text-index into text-label. """
+        result_list = []
+        ignored_tokens = self.get_ignored_tokens()
+        batch_size = len(text_index)
+        for batch_idx in range(batch_size):
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[
+                    batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+
+            char_list = [
+                self.character[text_id - 1]
+                for text_id in text_index[batch_idx][selection]
+            ]
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
+
+            text = ''.join(char_list)
+            result_list.append((text, np.mean(conf_list).tolist()))
+        return result_list
+
+    def __call__(self, preds, label=None, length=None, *args, **kwargs):
+        if len(preds) == 2:  # eval mode
+            net_out, length = preds
+        else:  # train mode
+            net_out = preds[0]
+            length = length
+            net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
+        text = []
+        if not isinstance(net_out, paddle.Tensor):
+            net_out = paddle.to_tensor(net_out, dtype='float32')
+        # import pdb
+        # pdb.set_trace()
+        net_out = F.softmax(net_out, axis=1)
+        for i in range(0, length.shape[0]):
+            preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[1][:, 0].tolist()
+            preds_text = ''.join([
+                self.character[idx - 1]
+                if idx > 0 and idx <= len(self.character) else ''
+                for idx in preds_idx
+            ])
+            preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
+            ) + length[i])].topk(1)[0][:, 0]
+            preds_prob = paddle.exp(
+                paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
+            text.append((preds_text, preds_prob))
+        if label is None:
+            return text
+        label = self.decode(label)
+        return text, label

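The per-sequence confidence above is the geometric mean of the per-step top-1 probabilities, computed in log space (the 1e-6 merely guards against a zero-length slice). Equivalently, in plain numpy:

```python
import numpy as np

step_probs = np.array([0.9, 0.8, 0.95])  # top-1 prob at each decoded step
conf = np.exp(np.log(step_probs).sum() / (len(step_probs) + 1e-6))
print(round(float(conf), 4))  # 0.8811, the geometric mean of the three values
```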
@@ -73,7 +73,7 @@ def main():
         config['Architecture']["Head"]['out_channels'] = char_num

    model = build_model(config['Architecture'])
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:

@@ -55,7 +55,7 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
                shape=[None, 3, 48, 160], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
-    elif arch_config["algorithm"] == "SVTR":
+    elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
        if arch_config["Head"]["name"] == 'MultiHead':
            other_shape = [
                paddle.static.InputSpec(

@@ -69,6 +69,12 @@ class TextRecognizer(object):
                "character_dict_path": args.rec_char_dict_path,
                "use_space_char": args.use_space_char
            }
+        elif self.rec_algorithm == "VisionLAN":
+            postprocess_params = {
+                'name': 'VLLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)

@@ -143,6 +149,15 @@ class TextRecognizer(object):
        resized_image /= 0.5
        return resized_image

+    def resize_norm_img_vl(self, img, image_shape):
+
+        imgC, imgH, imgW = image_shape
+        resized_image = cv2.resize(
+            img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
+        resized_image = resized_image.astype('float32')
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+        return resized_image
+
    def resize_norm_img_srn(self, img, image_shape):
        imgC, imgH, imgW = image_shape

@@ -300,6 +315,11 @@ class TextRecognizer(object):
                        self.rec_image_shape)
                    norm_img = norm_img[np.newaxis, :]
                    norm_img_batch.append(norm_img)
+                elif self.rec_algorithm == "VisionLAN":
+                    norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
+                                                       self.rec_image_shape)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
                else:
                    norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                    max_wh_ratio)

@@ -207,7 +207,7 @@ def train(config,
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
-    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"]
+    extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN"]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        for key in config['Architecture']["Models"]:

@@ -249,7 +249,6 @@ def train(config,
            images = batch[0]
            if use_srn:
                model_average = True
-
            # use amp
            if scaler:
                with paddle.amp.auto_cast():

@@ -264,7 +263,6 @@ def train(config,
                    preds = model(batch)
                else:
                    preds = model(images)
-
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']

@@ -286,6 +284,9 @@ def train(config,
                ]:  # for multi head loss
                    post_result = post_process_class(
                        preds['ctc'], batch[1])  # for CTC head out
+                elif config['Loss']['name'] in ['VLLoss']:
+                    post_result = post_process_class(preds, batch[1],
+                                                     batch[-1])
                else:
                    post_result = post_process_class(preds, batch[1])
                eval_class(post_result, batch)

@@ -307,7 +308,8 @@ def train(config,
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
-                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
+                log_writer.log_metrics(
+                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or

@@ -354,7 +356,8 @@ def train(config,

                # logger metric
                if log_writer is not None:
-                    log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics=cur_metric, prefix="EVAL", step=global_step)

                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:

@@ -377,11 +380,18 @@ def train(config,
                logger.info(best_str)
                # logger best metric
                if log_writer is not None:
-                    log_writer.log_metrics(metrics={
-                        "best_{}".format(main_indicator): best_model_dict[main_indicator]
-                    }, prefix="EVAL", step=global_step)
+                    log_writer.log_metrics(
+                        metrics={
+                            "best_{}".format(main_indicator):
+                            best_model_dict[main_indicator]
+                        },
+                        prefix="EVAL",
+                        step=global_step)

-                    log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict)
+                    log_writer.log_model(
+                        is_best=True,
+                        prefix="best_accuracy",
+                        metadata=best_model_dict)

            reader_start = time.time()
            if dist.get_rank() == 0:

@@ -413,7 +423,8 @@ def train(config,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
-                log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch))
+                log_writer.log_model(
+                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))

@@ -451,7 +462,6 @@ def eval(model,
                preds = model(batch)
            else:
                preds = model(images)
-
            batch_numpy = []
            for item in batch:
                if isinstance(item, paddle.Tensor):

@@ -564,7 +574,8 @@ def preprocess(is_train=False):
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
-        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
+        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
+        'VisionLAN'
    ]

    if use_xpu:

@@ -583,9 +594,10 @@ def preprocess(is_train=False):
    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
-        log_writer = VDLLogger(save_model_dir)
+        log_writer = VDLLogger(vdl_writer_path)
        loggers.append(log_writer)
-    if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config:
+    if ('use_wandb' in config['Global'] and
+            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        wandb_writer_path = "{}/wandb".format(save_dir)
        if "wandb" in config: