add vl
parent
0401e5203e
commit
a3a095150e
|
@ -0,0 +1,106 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 8
|
||||
log_smooth_window: 200
|
||||
print_batch_step: 200
|
||||
save_model_dir: /paddle/backup/visionlan/LA_v2
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model: ./pretrained_model/LF_2_ocr
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: True
|
||||
infer_img: doc/imgs_words/en/word_2.png
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict36.txt
|
||||
max_text_length: &max_text_length 25
|
||||
training_step: &training_step LA
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_visionlan.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 20.0
|
||||
group_lr: true
|
||||
training_step: *training_step
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [6]
|
||||
values: [0.0001, 0.00001]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: VisionLAN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet45
|
||||
strides: [2, 2, 2, 1, 1]
|
||||
Head:
|
||||
name: VLHead
|
||||
n_layers: 3
|
||||
n_position: 256
|
||||
n_dim: 512
|
||||
max_text_length: *max_text_length
|
||||
training_step: *training_step
|
||||
|
||||
Loss:
|
||||
name: VLLoss
|
||||
mode: *training_step
|
||||
weight_res: 0.5
|
||||
weight_mas: 0.5
|
||||
|
||||
PostProcess:
|
||||
name: VLLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
is_filter: true
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- ABINetRecAug:
|
||||
- VLLabelEncode: # Class handling label
|
||||
- VLRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 220
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- VLLabelEncode: # Class handling label
|
||||
- VLRecResizeImg:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'label_res', 'label_sub', 'label_id', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 64
|
||||
num_workers: 4
|
||||
|
|
@ -99,12 +99,13 @@ class BaseRecLabelEncode(object):
|
|||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False):
|
||||
use_space_char=False,
|
||||
lower=False):
|
||||
|
||||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.lower = False
|
||||
self.lower = lower
|
||||
|
||||
if character_dict_path is None:
|
||||
logger = get_logger()
|
||||
|
@ -1227,9 +1228,10 @@ class VLLabelEncode(BaseRecLabelEncode):
|
|||
max_text_length,
|
||||
character_dict_path=None,
|
||||
use_space_char=False,
|
||||
lower=True,
|
||||
**kwargs):
|
||||
super(VLLabelEncode, self).__init__(max_text_length,
|
||||
character_dict_path, use_space_char)
|
||||
super(VLLabelEncode, self).__init__(
|
||||
max_text_length, character_dict_path, use_space_char, lower)
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label'] # original string
|
||||
|
|
|
@ -13,6 +13,5 @@
|
|||
# limitations under the License.
|
||||
|
||||
from .augment import tia_perspective, tia_distort, tia_stretch
|
||||
from .vl_aug import VLAug
|
||||
|
||||
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective', 'VLAug']
|
||||
__all__ = ['tia_distort', 'tia_stretch', 'tia_perspective']
|
||||
|
|
|
@ -1,460 +0,0 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import numbers
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from paddle.vision import transforms
|
||||
from paddle.vision.transforms import Compose
|
||||
|
||||
|
||||
def sample_asym(magnitude, size=None):
|
||||
return np.random.beta(1, 4, size) * magnitude
|
||||
|
||||
|
||||
def sample_sym(magnitude, size=None):
|
||||
return (np.random.beta(4, 4, size=size) - 0.5) * 2 * magnitude
|
||||
|
||||
|
||||
def sample_uniform(low, high, size=None):
|
||||
return np.random.uniform(low, high, size=size)
|
||||
|
||||
|
||||
def get_interpolation(type='random'):
|
||||
if type == 'random':
|
||||
choice = [
|
||||
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA
|
||||
]
|
||||
interpolation = choice[random.randint(0, len(choice) - 1)]
|
||||
elif type == 'nearest':
|
||||
interpolation = cv2.INTER_NEAREST
|
||||
elif type == 'linear':
|
||||
interpolation = cv2.INTER_LINEAR
|
||||
elif type == 'cubic':
|
||||
interpolation = cv2.INTER_CUBIC
|
||||
elif type == 'area':
|
||||
interpolation = cv2.INTER_AREA
|
||||
else:
|
||||
raise TypeError(
|
||||
'Interpolation types only nearest, linear, cubic, area are supported!'
|
||||
)
|
||||
return interpolation
|
||||
|
||||
|
||||
class CVRandomRotation(object):
|
||||
def __init__(self, degrees=15):
|
||||
assert isinstance(degrees,
|
||||
numbers.Number), "degree should be a single number."
|
||||
assert degrees >= 0, "degree must be positive."
|
||||
self.degrees = degrees
|
||||
|
||||
@staticmethod
|
||||
def get_params(degrees):
|
||||
return sample_sym(degrees)
|
||||
|
||||
def __call__(self, img):
|
||||
angle = self.get_params(self.degrees)
|
||||
src_h, src_w = img.shape[:2]
|
||||
M = cv2.getRotationMatrix2D(
|
||||
center=(src_w / 2, src_h / 2), angle=angle, scale=1.0)
|
||||
abs_cos, abs_sin = abs(M[0, 0]), abs(M[0, 1])
|
||||
dst_w = int(src_h * abs_sin + src_w * abs_cos)
|
||||
dst_h = int(src_h * abs_cos + src_w * abs_sin)
|
||||
M[0, 2] += (dst_w - src_w) / 2
|
||||
M[1, 2] += (dst_h - src_h) / 2
|
||||
|
||||
flags = get_interpolation()
|
||||
return cv2.warpAffine(
|
||||
img,
|
||||
M, (dst_w, dst_h),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
|
||||
class CVRandomAffine(object):
|
||||
def __init__(self, degrees, translate=None, scale=None, shear=None):
|
||||
assert isinstance(degrees,
|
||||
numbers.Number), "degree should be a single number."
|
||||
assert degrees >= 0, "degree must be positive."
|
||||
self.degrees = degrees
|
||||
|
||||
if translate is not None:
|
||||
assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
|
||||
"translate should be a list or tuple and it must be of length 2."
|
||||
for t in translate:
|
||||
if not (0.0 <= t <= 1.0):
|
||||
raise ValueError(
|
||||
"translation values should be between 0 and 1")
|
||||
self.translate = translate
|
||||
|
||||
if scale is not None:
|
||||
assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
|
||||
"scale should be a list or tuple and it must be of length 2."
|
||||
for s in scale:
|
||||
if s <= 0:
|
||||
raise ValueError("scale values should be positive")
|
||||
self.scale = scale
|
||||
|
||||
if shear is not None:
|
||||
if isinstance(shear, numbers.Number):
|
||||
if shear < 0:
|
||||
raise ValueError(
|
||||
"If shear is a single number, it must be positive.")
|
||||
self.shear = [shear]
|
||||
else:
|
||||
assert isinstance(shear, (tuple, list)) and (len(shear) == 2), \
|
||||
"shear should be a list or tuple and it must be of length 2."
|
||||
self.shear = shear
|
||||
else:
|
||||
self.shear = shear
|
||||
|
||||
def _get_inverse_affine_matrix(self, center, angle, translate, scale,
|
||||
shear):
|
||||
from numpy import sin, cos, tan
|
||||
|
||||
if isinstance(shear, numbers.Number):
|
||||
shear = [shear, 0]
|
||||
|
||||
if not isinstance(shear, (tuple, list)) and len(shear) == 2:
|
||||
raise ValueError(
|
||||
"Shear should be a single value or a tuple/list containing " +
|
||||
"two values. Got {}".format(shear))
|
||||
|
||||
rot = math.radians(angle)
|
||||
sx, sy = [math.radians(s) for s in shear]
|
||||
|
||||
cx, cy = center
|
||||
tx, ty = translate
|
||||
|
||||
# RSS without scaling
|
||||
a = cos(rot - sy) / cos(sy)
|
||||
b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
|
||||
c = sin(rot - sy) / cos(sy)
|
||||
d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
|
||||
|
||||
# Inverted rotation matrix with scale and shear
|
||||
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
|
||||
M = [d, -b, 0, -c, a, 0]
|
||||
M = [x / scale for x in M]
|
||||
|
||||
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
|
||||
M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
|
||||
M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
|
||||
|
||||
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
|
||||
M[2] += cx
|
||||
M[5] += cy
|
||||
return M
|
||||
|
||||
@staticmethod
|
||||
def get_params(degrees, translate, scale_ranges, shears, height):
|
||||
angle = sample_sym(degrees)
|
||||
if translate is not None:
|
||||
max_dx = translate[0] * height
|
||||
max_dy = translate[1] * height
|
||||
translations = (np.round(sample_sym(max_dx)),
|
||||
np.round(sample_sym(max_dy)))
|
||||
else:
|
||||
translations = (0, 0)
|
||||
|
||||
if scale_ranges is not None:
|
||||
scale = sample_uniform(scale_ranges[0], scale_ranges[1])
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if shears is not None:
|
||||
if len(shears) == 1:
|
||||
shear = [sample_sym(shears[0]), 0.]
|
||||
elif len(shears) == 2:
|
||||
shear = [sample_sym(shears[0]), sample_sym(shears[1])]
|
||||
else:
|
||||
shear = 0.0
|
||||
|
||||
return angle, translations, scale, shear
|
||||
|
||||
def __call__(self, img):
|
||||
src_h, src_w = img.shape[:2]
|
||||
angle, translate, scale, shear = self.get_params(
|
||||
self.degrees, self.translate, self.scale, self.shear, src_h)
|
||||
|
||||
M = self._get_inverse_affine_matrix((src_w / 2, src_h / 2), angle,
|
||||
(0, 0), scale, shear)
|
||||
M = np.array(M).reshape(2, 3)
|
||||
|
||||
startpoints = [(0, 0), (src_w - 1, 0), (src_w - 1, src_h - 1),
|
||||
(0, src_h - 1)]
|
||||
project = lambda x, y, a, b, c: int(a * x + b * y + c)
|
||||
endpoints = [(project(x, y, *M[0]), project(x, y, *M[1]))
|
||||
for x, y in startpoints]
|
||||
|
||||
rect = cv2.minAreaRect(np.array(endpoints))
|
||||
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
||||
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
||||
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
||||
|
||||
dst_w = int(max_x - min_x)
|
||||
dst_h = int(max_y - min_y)
|
||||
M[0, 2] += (dst_w - src_w) / 2
|
||||
M[1, 2] += (dst_h - src_h) / 2
|
||||
|
||||
# add translate
|
||||
dst_w += int(abs(translate[0]))
|
||||
dst_h += int(abs(translate[1]))
|
||||
if translate[0] < 0: M[0, 2] += abs(translate[0])
|
||||
if translate[1] < 0: M[1, 2] += abs(translate[1])
|
||||
|
||||
flags = get_interpolation()
|
||||
return cv2.warpAffine(
|
||||
img,
|
||||
M, (dst_w, dst_h),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
|
||||
|
||||
class CVRandomPerspective(object):
|
||||
def __init__(self, distortion=0.5):
|
||||
self.distortion = distortion
|
||||
|
||||
def get_params(self, width, height, distortion):
|
||||
offset_h = sample_asym(
|
||||
distortion * height / 2, size=4).astype(dtype=np.int)
|
||||
offset_w = sample_asym(
|
||||
distortion * width / 2, size=4).astype(dtype=np.int)
|
||||
topleft = (offset_w[0], offset_h[0])
|
||||
topright = (width - 1 - offset_w[1], offset_h[1])
|
||||
botright = (width - 1 - offset_w[2], height - 1 - offset_h[2])
|
||||
botleft = (offset_w[3], height - 1 - offset_h[3])
|
||||
|
||||
startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1),
|
||||
(0, height - 1)]
|
||||
endpoints = [topleft, topright, botright, botleft]
|
||||
return np.array(
|
||||
startpoints, dtype=np.float32), np.array(
|
||||
endpoints, dtype=np.float32)
|
||||
|
||||
def __call__(self, img):
|
||||
height, width = img.shape[:2]
|
||||
startpoints, endpoints = self.get_params(width, height, self.distortion)
|
||||
M = cv2.getPerspectiveTransform(startpoints, endpoints)
|
||||
|
||||
# TODO: more robust way to crop image
|
||||
rect = cv2.minAreaRect(endpoints)
|
||||
bbox = cv2.boxPoints(rect).astype(dtype=np.int)
|
||||
max_x, max_y = bbox[:, 0].max(), bbox[:, 1].max()
|
||||
min_x, min_y = bbox[:, 0].min(), bbox[:, 1].min()
|
||||
min_x, min_y = max(min_x, 0), max(min_y, 0)
|
||||
|
||||
flags = get_interpolation()
|
||||
img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (max_x, max_y),
|
||||
flags=flags,
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
img = img[min_y:, min_x:]
|
||||
return img
|
||||
|
||||
|
||||
class CVRescale(object):
|
||||
def __init__(self, factor=4, base_size=(128, 512)):
|
||||
""" Define image scales using gaussian pyramid and rescale image to target scale.
|
||||
|
||||
Args:
|
||||
factor: the decayed factor from base size, factor=4 keeps target scale by default.
|
||||
base_size: base size the build the bottom layer of pyramid
|
||||
"""
|
||||
if isinstance(factor, numbers.Number):
|
||||
self.factor = round(sample_uniform(0, factor))
|
||||
elif isinstance(factor, (tuple, list)) and len(factor) == 2:
|
||||
self.factor = round(sample_uniform(factor[0], factor[1]))
|
||||
else:
|
||||
raise Exception('factor must be number or list with length 2')
|
||||
# assert factor is valid
|
||||
self.base_h, self.base_w = base_size[:2]
|
||||
|
||||
def __call__(self, img):
|
||||
if self.factor == 0:
|
||||
return img
|
||||
src_h, src_w = img.shape[:2]
|
||||
cur_w, cur_h = self.base_w, self.base_h
|
||||
scale_img = cv2.resize(
|
||||
img, (cur_w, cur_h), interpolation=get_interpolation())
|
||||
for _ in range(np.int(self.factor)):
|
||||
scale_img = cv2.pyrDown(scale_img)
|
||||
scale_img = cv2.resize(
|
||||
scale_img, (src_w, src_h), interpolation=get_interpolation())
|
||||
return scale_img
|
||||
|
||||
|
||||
class CVGaussianNoise(object):
|
||||
def __init__(self, mean=0, var=20):
|
||||
self.mean = mean
|
||||
if isinstance(var, numbers.Number):
|
||||
self.var = max(int(sample_asym(var)), 1)
|
||||
elif isinstance(var, (tuple, list)) and len(var) == 2:
|
||||
self.var = int(sample_uniform(var[0], var[1]))
|
||||
else:
|
||||
raise Exception('degree must be number or list with length 2')
|
||||
|
||||
def __call__(self, img):
|
||||
noise = np.random.normal(self.mean, self.var**0.5, img.shape)
|
||||
img = np.clip(img + noise, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
class CVMotionBlur(object):
|
||||
def __init__(self, degrees=12, angle=90):
|
||||
if isinstance(degrees, numbers.Number):
|
||||
self.degree = max(int(sample_asym(degrees)), 1)
|
||||
elif isinstance(degrees, (tuple, list)) and len(degrees) == 2:
|
||||
self.degree = int(sample_uniform(degrees[0], degrees[1]))
|
||||
else:
|
||||
raise Exception('degree must be number or list with length 2')
|
||||
self.angle = sample_uniform(-angle, angle)
|
||||
|
||||
def __call__(self, img):
|
||||
M = cv2.getRotationMatrix2D((self.degree // 2, self.degree // 2),
|
||||
self.angle, 1)
|
||||
motion_blur_kernel = np.zeros((self.degree, self.degree))
|
||||
motion_blur_kernel[self.degree // 2, :] = 1
|
||||
motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M,
|
||||
(self.degree, self.degree))
|
||||
motion_blur_kernel = motion_blur_kernel / self.degree
|
||||
img = cv2.filter2D(img, -1, motion_blur_kernel)
|
||||
img = np.clip(img, 0, 255).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
class CVGeometry(object):
|
||||
def __init__(self,
|
||||
degrees=15,
|
||||
translate=(0.3, 0.3),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=0.5):
|
||||
self.p = p
|
||||
type_p = random.random()
|
||||
if type_p < 0.33:
|
||||
self.transforms = CVRandomRotation(degrees=degrees)
|
||||
elif type_p < 0.66:
|
||||
self.transforms = CVRandomAffine(
|
||||
degrees=degrees, translate=translate, scale=scale, shear=shear)
|
||||
else:
|
||||
self.transforms = CVRandomPerspective(distortion=distortion)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
return self.transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class CVDeterioration(object):
|
||||
def __init__(self, var, degrees, factor, p=0.5):
|
||||
self.p = p
|
||||
transforms = []
|
||||
if var is not None:
|
||||
transforms.append(CVGaussianNoise(var=var))
|
||||
if degrees is not None:
|
||||
transforms.append(CVMotionBlur(degrees=degrees))
|
||||
if factor is not None:
|
||||
transforms.append(CVRescale(factor=factor))
|
||||
|
||||
random.shuffle(transforms)
|
||||
transforms = Compose(transforms)
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
return self.transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class CVColorJitter(object):
|
||||
def __init__(self,
|
||||
brightness=0.5,
|
||||
contrast=0.5,
|
||||
saturation=0.5,
|
||||
hue=0.1,
|
||||
p=0.5):
|
||||
self.p = p
|
||||
self.transforms = transforms.ColorJitter(
|
||||
brightness=brightness,
|
||||
contrast=contrast,
|
||||
saturation=saturation,
|
||||
hue=hue)
|
||||
|
||||
def __call__(self, img):
|
||||
if random.random() < self.p:
|
||||
return self.transforms(img)
|
||||
else:
|
||||
return img
|
||||
|
||||
|
||||
class VLAug(object):
|
||||
def __init__(self,
|
||||
geometry_p=0.5,
|
||||
Deterioration_p=0.25,
|
||||
ColorJitter_p=0.25,
|
||||
**kwargs):
|
||||
self.Geometry = CVGeometry(
|
||||
degrees=45,
|
||||
translate=(0.0, 0.0),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=geometry_p)
|
||||
self.Deterioration = CVDeterioration(
|
||||
var=20, degrees=6, factor=4, p=Deterioration_p)
|
||||
self.ColorJitter = CVColorJitter(
|
||||
brightness=0.5,
|
||||
contrast=0.5,
|
||||
saturation=0.5,
|
||||
hue=0.1,
|
||||
p=ColorJitter_p)
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = self.Geometry(img)
|
||||
img = self.Deterioration(img)
|
||||
img = self.ColorJitter(img)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
geo = CVGeometry(
|
||||
degrees=45,
|
||||
translate=(0.0, 0.0),
|
||||
scale=(0.5, 2.),
|
||||
shear=(45, 15),
|
||||
distortion=0.5,
|
||||
p=1)
|
||||
det = CVDeterioration(var=20, degrees=6, factor=4, p=1)
|
||||
color = CVColorJitter(
|
||||
brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1, p=1)
|
||||
|
||||
img = np.ones((64, 256, 3))
|
||||
img = geo(img)
|
||||
img = det(img)
|
||||
img = color(img)
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
# print()
|
|
@ -0,0 +1,66 @@
|
|||
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class VLLoss(nn.Layer):
|
||||
def __init__(self, mode='LF_1', weight_res=0.5, weight_mas=0.5, **kwargs):
|
||||
super(VLLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean")
|
||||
assert mode in ['LF_1', 'LF_2', 'LA']
|
||||
self.mode = mode
|
||||
self.weight_res = weight_res
|
||||
self.weight_mas = weight_mas
|
||||
|
||||
def flatten_label(self, target):
|
||||
label_flatten = []
|
||||
label_length = []
|
||||
for i in range(0, target.shape[0]):
|
||||
cur_label = target[i].tolist()
|
||||
label_flatten += cur_label[:cur_label.index(0) + 1]
|
||||
label_length.append(cur_label.index(0) + 1)
|
||||
label_flatten = paddle.to_tensor(label_flatten, dtype='int64')
|
||||
label_length = paddle.to_tensor(label_length, dtype='int32')
|
||||
return (label_flatten, label_length)
|
||||
|
||||
def _flatten(self, sources, lengths):
|
||||
return paddle.concat([t[:l] for t, l in zip(sources, lengths)])
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
text_pre = predicts[0]
|
||||
target = batch[1].astype('int64')
|
||||
label_flatten, length = self.flatten_label(target)
|
||||
text_pre = self._flatten(text_pre, length)
|
||||
if self.mode == 'LF_1':
|
||||
loss = self.loss_func(text_pre, label_flatten)
|
||||
else:
|
||||
text_rem = predicts[1]
|
||||
text_mas = predicts[2]
|
||||
target_res = batch[2].astype('int64')
|
||||
target_sub = batch[3].astype('int64')
|
||||
label_flatten_res, length_res = self.flatten_label(target_res)
|
||||
label_flatten_sub, length_sub = self.flatten_label(target_sub)
|
||||
text_rem = self._flatten(text_rem, length_res)
|
||||
text_mas = self._flatten(text_mas, length_sub)
|
||||
loss_ori = self.loss_func(text_pre, label_flatten)
|
||||
loss_res = self.loss_func(text_rem, label_flatten_res)
|
||||
loss_mas = self.loss_func(text_mas, label_flatten_sub)
|
||||
loss = loss_ori + loss_res * self.weight_res + loss_mas * self.weight_mas
|
||||
return {'loss': loss}
|
|
@ -84,11 +84,15 @@ class BasicBlock(nn.Layer):
|
|||
|
||||
|
||||
class ResNet45(nn.Layer):
|
||||
def __init__(self, block=BasicBlock, layers=[3, 4, 6, 6, 3], in_channels=3):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
block=BasicBlock,
|
||||
layers=[3, 4, 6, 6, 3],
|
||||
strides=[2, 1, 2, 1, 1]):
|
||||
self.inplanes = 32
|
||||
super(ResNet45, self).__init__()
|
||||
self.conv1 = nn.Conv2D(
|
||||
3,
|
||||
in_channels,
|
||||
32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
|
@ -98,18 +102,13 @@ class ResNet45(nn.Layer):
|
|||
self.bn1 = nn.BatchNorm2D(32)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
|
||||
self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
|
||||
self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
|
||||
self.layer1 = self._make_layer(block, 32, layers[0], stride=strides[0])
|
||||
self.layer2 = self._make_layer(block, 64, layers[1], stride=strides[1])
|
||||
self.layer3 = self._make_layer(block, 128, layers[2], stride=strides[2])
|
||||
self.layer4 = self._make_layer(block, 256, layers[3], stride=strides[3])
|
||||
self.layer5 = self._make_layer(block, 512, layers[4], stride=strides[4])
|
||||
self.out_channels = 512
|
||||
|
||||
# for m in self.modules():
|
||||
# if isinstance(m, nn.Conv2D):
|
||||
# n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
|
||||
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
|
@ -137,11 +136,9 @@ class ResNet45(nn.Layer):
|
|||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
# print(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
# print(x)
|
||||
x = self.layer4(x)
|
||||
x = self.layer5(x)
|
||||
return x
|
||||
|
|
|
@ -20,10 +20,6 @@ import paddle.nn as nn
|
|||
|
||||
import sys
|
||||
import math
|
||||
from paddle.nn.initializer import KaimingNormal, Constant
|
||||
|
||||
zeros_ = Constant(value=0.)
|
||||
ones_ = Constant(value=1.)
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
|
@ -144,111 +140,4 @@ class ResNet_ASTER(nn.Layer):
|
|||
rnn_feat, _ = self.rnn(cnn_feat)
|
||||
return rnn_feat
|
||||
else:
|
||||
return cnn_feat
|
||||
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Block, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes)
|
||||
self.bn1 = nn.BatchNorm2D(planes)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
self.bn2 = nn.BatchNorm2D(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet45(nn.Layer):
|
||||
def __init__(self, in_channels=3, compress_layer=False):
|
||||
super(ResNet45, self).__init__()
|
||||
self.compress_layer = compress_layer
|
||||
|
||||
self.conv1_new = nn.Conv2D(
|
||||
in_channels,
|
||||
32,
|
||||
kernel_size=(3, 3),
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
self.bn1 = nn.BatchNorm2D(32)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.inplanes = 32
|
||||
self.layer1 = self._make_layer(32, 3, [2, 2]) # [32, 128]
|
||||
self.layer2 = self._make_layer(64, 4, [2, 2]) # [16, 64]
|
||||
self.layer3 = self._make_layer(128, 6, [2, 2]) # [8, 32]
|
||||
self.layer4 = self._make_layer(256, 6, [1, 1]) # [8, 32]
|
||||
self.layer5 = self._make_layer(512, 3, [1, 1]) # [8, 32]
|
||||
|
||||
if self.compress_layer:
|
||||
self.layer6 = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
512, 256, kernel_size=(3, 1), padding=(0, 0), stride=(1,
|
||||
1)),
|
||||
nn.BatchNorm(256),
|
||||
nn.ReLU())
|
||||
self.out_channels = 256
|
||||
else:
|
||||
self.out_channels = 512
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Conv2D):
|
||||
KaimingNormal(m.weight)
|
||||
elif isinstance(m, nn.BatchNorm):
|
||||
ones_(m.weight)
|
||||
zeros_(m.bias)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride):
|
||||
downsample = None
|
||||
if stride != [1, 1] or self.inplanes != planes:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
|
||||
|
||||
layers = []
|
||||
layers.append(Block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Block(self.inplanes, planes))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1_new(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x1 = self.layer1(x)
|
||||
x2 = self.layer2(x1)
|
||||
x3 = self.layer3(x2)
|
||||
x4 = self.layer4(x3)
|
||||
x5 = self.layer5(x4)
|
||||
|
||||
if not self.compress_layer:
|
||||
return x5
|
||||
else:
|
||||
x6 = self.layer6(x5)
|
||||
return x6
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = ResNet45()
|
||||
x = paddle.rand([1, 3, 64, 256])
|
||||
x = paddle.to_tensor(x)
|
||||
print(x.shape)
|
||||
out = model(x)
|
||||
print(out.shape)
|
||||
return cnn_feat
|
|
@ -22,7 +22,7 @@ import paddle.nn as nn
|
|||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import Normal, XavierNormal
|
||||
import numpy as np
|
||||
from ppocr.modeling.backbones.rec_resnet_aster import ResNet45
|
||||
from ppocr.modeling.backbones.rec_resnet_45 import ResNet45
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Layer):
|
||||
|
@ -442,14 +442,6 @@ class MLM_VRM(nn.Layer):
|
|||
if out_length[j] == 0 and tmp_result[j] == 0:
|
||||
out_length[j] = now_step + 1
|
||||
now_step += 1
|
||||
# while 0 in out_length and now_step < nsteps:
|
||||
# tmp_result = text_pre[now_step, :, :]
|
||||
# out_res[now_step] = tmp_result
|
||||
# tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
|
||||
# for j in range(b):
|
||||
# if out_length[j] == 0 and tmp_result[j] == 0:
|
||||
# out_length[j] = now_step + 1
|
||||
# now_step += 1
|
||||
for j in range(0, b):
|
||||
if int(out_length[j]) == 0:
|
||||
out_length[j] = nsteps
|
||||
|
|
|
@ -77,11 +77,62 @@ class Adam(object):
|
|||
self.grad_clip = grad_clip
|
||||
self.name = name
|
||||
self.lazy_mode = lazy_mode
|
||||
self.group_lr = kwargs.get('group_lr', False)
|
||||
self.training_step = kwargs.get('training_step', None)
|
||||
|
||||
def __call__(self, model):
|
||||
train_params = [
|
||||
param for param in model.parameters() if param.trainable is True
|
||||
]
|
||||
if self.group_lr:
|
||||
if self.training_step == 'LF_2':
|
||||
import paddle
|
||||
if isinstance(model, paddle.fluid.dygraph.parallel.
|
||||
DataParallel): # multi gpu
|
||||
mlm = model._layers.head.MLM_VRM.MLM.parameters()
|
||||
pre_mlm_pp = model._layers.head.MLM_VRM.Prediction.pp_share.parameters(
|
||||
)
|
||||
pre_mlm_w = model._layers.head.MLM_VRM.Prediction.w_share.parameters(
|
||||
)
|
||||
else: # single gpu
|
||||
mlm = model.head.MLM_VRM.MLM.parameters()
|
||||
pre_mlm_pp = model.head.MLM_VRM.Prediction.pp_share.parameters(
|
||||
)
|
||||
pre_mlm_w = model.head.MLM_VRM.Prediction.w_share.parameters(
|
||||
)
|
||||
|
||||
total = []
|
||||
for param in mlm:
|
||||
total.append(id(param))
|
||||
for param in pre_mlm_pp:
|
||||
total.append(id(param))
|
||||
for param in pre_mlm_w:
|
||||
total.append(id(param))
|
||||
|
||||
group_base_params = [
|
||||
param for param in model.parameters() if id(param) in total
|
||||
]
|
||||
group_small_params = [
|
||||
param for param in model.parameters()
|
||||
if id(param) not in total
|
||||
]
|
||||
train_params = [{
|
||||
'params': group_base_params
|
||||
}, {
|
||||
'params': group_small_params,
|
||||
'learning_rate': self.learning_rate.values[0] * 0.1
|
||||
}]
|
||||
|
||||
else:
|
||||
print(
|
||||
'group lr currently only support VisionLAN in LF_2 training step'
|
||||
)
|
||||
train_params = [
|
||||
param for param in model.parameters()
|
||||
if param.trainable is True
|
||||
]
|
||||
else:
|
||||
train_params = [
|
||||
param for param in model.parameters() if param.trainable is True
|
||||
]
|
||||
|
||||
opt = optim.Adam(
|
||||
learning_rate=self.learning_rate,
|
||||
beta1=self.beta1,
|
||||
|
|
|
@ -27,8 +27,7 @@ class BaseRecLabelDecode(object):
|
|||
|
||||
self.character_str = []
|
||||
if character_dict_path is None:
|
||||
# self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
self.character_str = "abcdefghijklmnopqrstuvwxyz1234567890"
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
0
|
|
@ -60,7 +60,7 @@ def export_single_model(model,
|
|||
shape=[None, 3, 48, 160], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] in ["SVTR", "VisionLAN"]:
|
||||
elif arch_config["algorithm"] == "SVTR":
|
||||
if arch_config["Head"]["name"] == 'MultiHead':
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
|
@ -97,6 +97,12 @@ def export_single_model(model,
|
|||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "VisionLAN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 64, 256], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
|
||||
input_spec = [
|
||||
paddle.static.InputSpec(
|
||||
|
@ -217,4 +223,4 @@ def main():
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
|
@ -366,6 +366,8 @@ class TextRecognizer(object):
|
|||
elif self.rec_algorithm == "VisionLAN":
|
||||
norm_img = self.resize_norm_img_vl(img_list[indices[ino]],
|
||||
self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "ABINet":
|
||||
norm_img = self.resize_norm_img_abinet(
|
||||
img_list[indices[ino]], self.rec_image_shape)
|
||||
|
|
Loading…
Reference in New Issue