fix conflicts
commit
8308f33274
|
@ -0,0 +1,106 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 400
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/seed
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations, starting from iteration 0
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 100
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
eval_filter: True
|
||||
save_res_path: ./output/rec/predicts_seed.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adadelta
|
||||
weight_deacy: 0.0
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [4,5,8]
|
||||
values: [1.0, 0.1, 0.01]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 2.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: seed
|
||||
algorithm: ASTER
|
||||
Transform:
|
||||
name: STN_ON
|
||||
tps_inputsize: [32, 64]
|
||||
tps_outputsize: [32, 100]
|
||||
num_control_points: 20
|
||||
tps_margins: [0.05,0.05]
|
||||
stn_activation: none
|
||||
Backbone:
|
||||
name: ResNet_ASTER
|
||||
Head:
|
||||
name: AsterHead # AttentionHead
|
||||
sDim: 512
|
||||
attDim: 512
|
||||
max_len_labels: 100
|
||||
|
||||
Loss:
|
||||
name: AsterLoss
|
||||
|
||||
PostProcess:
|
||||
name: SEEDLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- Fasttext:
|
||||
path: "./cc.en.300.bin"
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 6
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: True
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, SEEDResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
|
|
|
@ -106,6 +106,7 @@ class BaseRecLabelEncode(object):
|
|||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unknown = "UNKNOWN"
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -174,6 +175,7 @@ class NRTRLabelEncode(BaseRecLabelEncode):
|
|||
super(NRTRLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
|
@ -185,10 +187,12 @@ class NRTRLabelEncode(BaseRecLabelEncode):
|
|||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class CTCLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -337,6 +341,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelEncode(BaseRecLabelEncode):
    """Label encoder for SEED: maps a text label to a fixed-length index array.

    The only special token added to the dictionary is ``"eos"``, appended
    last; the same index is used to pad labels up to ``max_text_len``.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        # Extra config keys arrive through **kwargs and are ignored here.
        super(SEEDLabelEncode, self).__init__(
            max_text_length, character_dict_path, character_type,
            use_space_char)

    def add_special_char(self, dict_character):
        """Return ``dict_character`` with the end-of-sequence token appended."""
        self.beg_str = "sos"
        self.end_str = "eos"
        return dict_character + [self.end_str]

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        Returns ``None`` (sample dropped) when the label cannot be encoded or
        when it has ``max_text_len`` or more characters; otherwise returns
        ``data`` with ``'label'`` replaced by the padded index array and
        ``'length'`` set to the text length plus one.
        """
        encoded = self.encode(data['label'])
        if encoded is None or len(encoded) >= self.max_text_len:
            return None

        # NOTE(review): assumes the base class keeps the order returned by
        # add_special_char, so the last index is "eos" - confirm.
        eos_index = len(self.character) - 1
        pad_count = self.max_text_len - len(encoded)

        data['length'] = np.array(len(encoded)) + 1  # +1 to include eos
        data['label'] = np.array(encoded + [eos_index] * pad_count)
        return data
|
||||
|
||||
|
||||
class SRNLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -416,7 +453,6 @@ class TableLabelEncode(object):
|
|||
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\r\n")
|
||||
list_character.append(character)
|
||||
|
@ -588,7 +624,7 @@ class SARLabelEncode(BaseRecLabelEncode):
|
|||
data['length'] = np.array(len(text))
|
||||
target = [self.start_idx] + text + [self.end_idx]
|
||||
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
||||
|
||||
|
||||
padded_text[:len(target)] = target
|
||||
data['label'] = np.array(padded_text)
|
||||
return data
|
||||
|
|
|
@ -23,6 +23,7 @@ import sys
|
|||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import fasttext
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
|
@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
|
|||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
@ -133,6 +135,17 @@ class ToCHWImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class Fasttext(object):
    """Data transform that attaches a fastText vector for the sample's label.

    Args:
        path: model file handed to ``fasttext.load_model``. The default is
            the literal string ``"None"``, so a real path must be supplied.
        **kwargs: ignored.
    """

    def __init__(self, path="None", **kwargs):
        self.fast_model = fasttext.load_model(path)

    def __call__(self, data):
        """Store ``fast_model[data['label']]`` under ``data['fast_label']``."""
        data['fast_label'] = self.fast_model[data['label']]
        return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
|
|
@ -82,6 +82,18 @@ class RecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SEEDResize(object):
    """Data transform that resizes ``data['image']`` to a fixed shape.

    Args:
        image_shape: ``[C, H, W]`` target passed to ``resize_no_padding_img``.
        infer_mode: stored on the instance; not read by ``__call__``.
        **kwargs: ignored.
    """

    def __init__(self, image_shape, infer_mode=False, **kwargs):
        self.image_shape = image_shape
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Replace ``data['image']`` with its resized, normalized version."""
        data['image'] = resize_no_padding_img(data['image'], self.image_shape)
        return data
|
||||
|
||||
|
||||
class SRNRecResizeImg(object):
|
||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
@ -109,7 +121,8 @@ class SARRecResizeImg(object):
|
|||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
|
||||
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
|
||||
img, self.image_shape, self.width_downsample_ratio)
|
||||
data['image'] = norm_img
|
||||
data['resized_shape'] = resize_shape
|
||||
data['pad_shape'] = pad_shape
|
||||
|
@ -175,6 +188,17 @@ def resize_norm_img(img, image_shape):
|
|||
return padding_im
|
||||
|
||||
|
||||
def resize_no_padding_img(img, image_shape):
    """Resize ``img`` to ``image_shape`` without keeping the aspect ratio.

    Args:
        img: image array; the transpose below requires three axes, with
            channels last.
        image_shape: ``(C, H, W)``; only ``H`` and ``W`` are used.

    Returns:
        float32 array in channel-first layout, scaled from [0, 255] to
        [-1, 1] via ``(x / 255 - 0.5) / 0.5``.
    """
    _, target_h, target_w = image_shape
    stretched = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    chw = stretched.astype('float32').transpose((2, 0, 1)) / 255
    chw -= 0.5
    chw /= 0.5
    return chw
|
||||
|
||||
|
||||
def resize_norm_img_chinese(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
|
|
|
@ -42,10 +42,14 @@ from .combined_loss import CombinedLoss
|
|||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
from .rec_aster_loss import AsterLoss
|
||||
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss',
|
||||
'SARLoss', 'AsterLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class CosineEmbeddingLoss(nn.Layer):
    """Cosine embedding loss between two batches of vectors.

    For pairs whose ``target`` equals 1 the per-sample loss is
    ``1 - cos(x1, x2)``; for all other pairs it is
    ``max(0, cos(x1, x2) - margin)``. The result is the mean over samples.

    Args:
        margin: margin applied to pairs whose target is not 1.
    """

    def __init__(self, margin=0.):
        super(CosineEmbeddingLoss, self).__init__()
        self.margin = margin
        # Keeps the denominator non-zero for all-zero vectors.
        self.epsilon = 1e-12

    def forward(self, x1, x2, target):
        """Return the mean cosine embedding loss.

        Args:
            x1, x2: tensors compared along their last axis.
            target: tensor of per-sample labels; entries equal to 1 mark
                pairs that should be similar.
        """
        # paddle.sum / paddle.mean replace the legacy
        # paddle.fluid.layers.reduce_sum / reduce_mean calls.
        dot = paddle.sum(x1 * x2, axis=-1)
        norms = paddle.norm(x1, axis=-1) * paddle.norm(x2, axis=-1)
        similarity = dot / (norms + self.epsilon)

        one_list = paddle.full_like(target, fill_value=1)
        per_sample = paddle.where(
            paddle.equal(target, one_list), 1. - similarity,
            paddle.maximum(
                paddle.zeros_like(similarity), similarity - self.margin))
        return paddle.mean(per_sample)
|
||||
|
||||
|
||||
class AsterLoss(nn.Layer):
    """Training loss for the ASTER/SEED recognizer.

    Total loss = masked negative log-likelihood of the recognition output
    plus 0.1 times the cosine embedding loss between the predicted
    embedding vectors and the semantic target.
    """

    def __init__(self,
                 weight=None,
                 size_average=True,
                 ignore_index=-100,
                 sequence_normalize=False,
                 sample_normalize=True,
                 **kwargs):
        super(AsterLoss, self).__init__()
        # weight, size_average and ignore_index are stored but not read by
        # forward().
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.sequence_normalize = sequence_normalize
        self.sample_normalize = sample_normalize
        self.loss_sem = CosineEmbeddingLoss()
        self.is_cosin_loss = True
        # Constructed but not used by forward().
        self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')

    def forward(self, predicts, batch):
        """Compute the combined loss.

        Args:
            predicts: dict with ``'embedding_vectors'`` and ``'rec_pred'``.
            batch: sequence where ``batch[1]`` holds label indices,
                ``batch[2]`` label lengths and ``batch[3]`` the semantic
                target vectors.

        Returns:
            ``{'loss': tensor}``.
        """
        targets = batch[1].astype("int64")
        label_lengths = batch[2].astype('int64')
        sem_target = batch[3].astype('float32')
        embedding_vectors = predicts['embedding_vectors']
        rec_pred = predicts['rec_pred']

        # semantic loss
        if self.is_cosin_loss:
            # Every (prediction, target) pair is treated as a positive pair.
            label_target = paddle.ones([embedding_vectors.shape[0]])
            sem_loss = paddle.sum(
                self.loss_sem(embedding_vectors, sem_target, label_target))
        else:
            sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))

        # recognition loss
        batch_size = targets.shape[0]
        def_max_length = targets.shape[1]

        # 1 for positions inside each label (length includes eos), 0 for
        # padding.
        mask = paddle.zeros([batch_size, def_max_length])
        for row in range(batch_size):
            mask[row, :label_lengths[row]] = 1
        mask = paddle.cast(mask, "float32")

        max_length = max(label_lengths)
        assert max_length == rec_pred.shape[1]
        targets = targets[:, :max_length]
        mask = mask[:, :max_length]

        # Flatten to [(batch * steps), classes] and pick the log-probability
        # of the target class at each step.
        flat_logits = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
        log_probs = nn.functional.log_softmax(flat_logits, axis=1)
        targets = paddle.reshape(targets, [-1, 1])
        mask = paddle.reshape(mask, [-1, 1])
        rec_loss = paddle.sum(
            -paddle.index_sample(log_probs, index=targets) * mask)

        if self.sequence_normalize:
            rec_loss = rec_loss / paddle.sum(mask)
        if self.sample_normalize:
            rec_loss = rec_loss / batch_size

        return {'loss': rec_loss + sem_loss * 0.1}
|
|
@ -13,13 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import Levenshtein
|
||||
import string
|
||||
|
||||
|
||||
class RecMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.is_filter = is_filter
|
||||
self.reset()
|
||||
|
||||
def _normalize_text(self, text):
|
||||
text = ''.join(
|
||||
filter(lambda x: x in (string.digits + string.ascii_letters), text))
|
||||
return text.lower()
|
||||
|
||||
def __call__(self, pred_label, *args, **kwargs):
|
||||
preds, labels = pred_label
|
||||
correct_num = 0
|
||||
|
@ -28,6 +35,9 @@ class RecMetric(object):
|
|||
for (pred, pred_conf), (target, _) in zip(preds, labels):
|
||||
pred = pred.replace(" ", "")
|
||||
target = target.replace(" ", "")
|
||||
if self.is_filter:
|
||||
pred = self._normalize_text(pred)
|
||||
target = self._normalize_text(target)
|
||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||
len(pred), len(target), 1)
|
||||
if pred == target:
|
||||
|
@ -57,4 +67,3 @@ class RecMetric(object):
|
|||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
self.norm_edit_dis = 0
|
||||
|
||||
|
|
|
@ -29,7 +29,8 @@ def build_backbone(config, model_type):
|
|||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_resnet_31 import ResNet31
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
"ResNet31"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
@ -38,6 +39,9 @@ def build_backbone(config, model_type):
|
|||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
elif model_type == "seed":
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
support_dict = ["ResNet_ASTER"]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2D`` with padding 1."""
    conv = nn.Conv2D(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias_attr=False)
    return conv
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2D`` (no padding)."""
    conv = nn.Conv2D(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias_attr=False)
    return conv
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
    """Build a ``[n_position, feat_dim]`` sinusoidal position table.

    Even feature columns hold ``sin(pos / wave_length**(2i / feat_dim))``
    and odd columns the matching ``cos``.

    NOTE(review): ``paddle.pow`` is called with a Python number as its first
    argument - confirm this is accepted by the installed Paddle version.
    """
    # [n_position]
    pos = paddle.arange(0, n_position)
    # [feat_dim]
    freq = paddle.arange(0, feat_dim)
    freq = paddle.pow(wave_length, 2 * (freq // 2) / feat_dim)
    # [n_position, feat_dim]
    table = paddle.unsqueeze(pos, axis=1) / paddle.unsqueeze(freq, axis=0)
    table = paddle.cast(table, "float32")
    table[:, 0::2] = paddle.sin(table[:, 0::2])
    table[:, 1::2] = paddle.cos(table[:, 1::2])
    return table
|
||||
|
||||
|
||||
class AsterBlock(nn.Layer):
    """Residual block: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut.

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: stride of the leading 1x1 convolution.
        downsample: optional layer applied to the input to form the
            shortcut; the raw input is used when it is ``None``.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(AsterBlock, self).__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2D(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2D(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class ResNet_ASTER(nn.Layer):
    """ResNet feature extractor for ASTER/CRNN-style recognizers.

    Args:
        with_lstm: when True, a 2-layer bidirectional LSTM is applied to the
            CNN feature sequence.
        n_group: stored on the instance; not used by the visible code.
        in_channels: channel count of the input image.

    ``out_channels`` is 512 in both configurations (2 * 256 with the LSTM).
    """

    def __init__(self, with_lstm=True, n_group=1, in_channels=3):
        super(ResNet_ASTER, self).__init__()
        self.with_lstm = with_lstm
        self.n_group = n_group

        stem_conv = nn.Conv2D(
            in_channels,
            32,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            bias_attr=False)
        self.layer0 = nn.Sequential(stem_conv, nn.BatchNorm2D(32), nn.ReLU())

        # _make_layer reads and updates self.inplanes, so the stages must be
        # built in this order. The trailing sizes are the [H, W] of each
        # stage's output for a 32x100 input.
        self.inplanes = 32
        self.layer1 = self._make_layer(32, 3, [2, 2])  # [16, 50]
        self.layer2 = self._make_layer(64, 4, [2, 2])  # [8, 25]
        self.layer3 = self._make_layer(128, 6, [2, 1])  # [4, 25]
        self.layer4 = self._make_layer(256, 6, [2, 1])  # [2, 25]
        self.layer5 = self._make_layer(512, 3, [2, 1])  # [1, 25]

        if with_lstm:
            self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
            self.out_channels = 2 * 256
        else:
            self.out_channels = 512

    def _make_layer(self, planes, blocks, stride):
        """Stack ``blocks`` AsterBlocks; only the first one applies ``stride``."""
        downsample = None
        if stride != [1, 1] or self.inplanes != planes:
            # Projection shortcut so the residual matches the new shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))

        stage = [AsterBlock(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        stage.extend(AsterBlock(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return a ``[N, W, C]`` feature sequence for image batch ``x``."""
        feat = x
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5):
            feat = stage(feat)

        # Drop the height axis -> [N, C, W], then move channels last.
        cnn_feat = paddle.transpose(feat.squeeze(2), perm=[0, 2, 1])
        if not self.with_lstm:
            return cnn_feat
        rnn_feat, _ = self.rnn(cnn_feat)
        return rnn_feat
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: push a random batch through the backbone and print the
    # shape of the resulting feature sequence.
    dummy_batch = paddle.randn([3, 3, 32, 100])
    backbone = ResNet_ASTER()
    print(backbone(dummy_batch).shape)
|
|
@ -28,12 +28,14 @@ def build_head(config):
|
|||
from .rec_srn_head import SRNHead
|
||||
from .rec_nrtr_head import Transformer
|
||||
from .rec_sar_head import SARHead
|
||||
from .rec_aster_head import AsterHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead'
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead', 'Transformer',
|
||||
'AsterHead', 'SARHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,390 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
class AsterHead(nn.Layer):
    """Recognition head pairing an attention decoder with a semantic embedder.

    Args:
        in_channels: feature size of each encoder time step.
        out_channels: number of output classes; the last index is used as
            the end-of-sequence id.
        sDim: decoder state size.
        attDim: attention projection size.
        max_len_labels: maximum number of decoding steps.
        time_step: number of encoder time steps expected by the embedder.
        beam_width: beam size used at evaluation time.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sDim,
                 attDim,
                 max_len_labels,
                 time_step=25,
                 beam_width=5,
                 **kwargs):
        super(AsterHead, self).__init__()
        self.num_classes = out_channels
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        self.decoder = AttentionRecognitionHead(
            in_channels, out_channels, sDim, attDim, max_len_labels)
        self.time_step = time_step
        self.embeder = Embedding(self.time_step, in_channels)
        self.beam_width = beam_width
        self.eos = self.num_classes - 1

    def forward(self, x, targets=None, embed=None):
        """Decode encoder features ``x``.

        In training mode ``targets`` must unpack into
        ``(rec_targets, rec_lengths, _)`` and decoding is teacher-forced;
        otherwise beam search is used and ``'rec_pred_scores'`` is added to
        the result. ``embed`` is accepted but not used.
        """
        embedding_vectors = self.embeder(x)

        if self.training:
            rec_targets, rec_lengths, _ = targets
            rec_pred = self.decoder([x, rec_targets, rec_lengths],
                                    embedding_vectors)
            return {
                'rec_pred': rec_pred,
                'embedding_vectors': embedding_vectors,
            }

        rec_pred, rec_pred_scores = self.decoder.beam_search(
            x, self.beam_width, self.eos, embedding_vectors)
        return {
            'rec_pred': rec_pred,
            'rec_pred_scores': rec_pred_scores,
            'embedding_vectors': embedding_vectors,
        }
|
||||
|
||||
|
||||
class Embedding(nn.Layer):
    """Project the whole encoder output to a word-embedding-like vector.

    Args:
        in_timestep: number of encoder time steps.
        in_planes: feature size per time step.
        mid_dim: stored on the instance; not used by the visible code.
        embed_dim: size of the produced embedding.
    """

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super(Embedding, self).__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        # Single linear map from the flattened sequence to the embedding.
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        """Flatten ``x`` per sample and return its ``embed_dim`` projection."""
        flat = paddle.reshape(x, [paddle.shape(x)[0], -1])
        return self.eEmbed(flat)
|
||||
|
||||
|
||||
class AttentionRecognitionHead(nn.Layer):
    """
    Attention-based sequence decoder.

    input: [b x 16 x 64 x in_planes]
    output: probability sequence: [b x T x num_classes]
    """

    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
        super(AttentionRecognitionHead, self).__init__()
        self.num_classes = out_channels  # this is the output classes. So it includes the <EOS>.
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels

        self.decoder = DecoderUnit(
            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)

    def forward(self, x, embed):
        """Teacher-forced decoding.

        Args:
            x: ``[features, targets, lengths]``.
            embed: embedding vectors used to initialise the decoder state.

        Returns:
            Per-step decoder outputs stacked along axis 1; the number of
            steps is ``max(lengths)``.
        """
        x, targets, lengths = x
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = self.decoder.get_initial_state(embed)
        outputs = []
        for i in range(max(lengths)):
            if i == 0:
                # Index num_classes is the <BOS> input symbol.
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = targets[:, i - 1]
            output, state = self.decoder(x, state, y_prev)
            outputs.append(output)
        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs

    # inference stage.
    def sample(self, x):
        """Greedy decoding for ``max_len_labels`` steps.

        Args:
            x: triple whose first element is the encoder feature tensor.

        Returns:
            ``(predicted_ids, predicted_scores)``, each the per-step results
            concatenated along axis 1.
        """
        x, _, _ = x
        # Tensor.size is not callable in Paddle; read the batch dim from shape.
        batch_size = x.shape[0]
        # Decoder
        state = paddle.zeros([1, batch_size, self.sDim])

        predicted_ids, predicted_scores = [], []
        for i in range(self.max_len_labels):
            if i == 0:
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = predicted

            output, state = self.decoder(x, state, y_prev)
            output = F.softmax(output, axis=1)
            # Paddle's max() returns values only, so the index needs argmax.
            score = paddle.max(output, axis=1)
            predicted = paddle.argmax(output, axis=1)
            predicted_ids.append(predicted.unsqueeze(1))
            predicted_scores.append(score.unsqueeze(1))
        # concat takes the tensor list and the axis as separate arguments.
        predicted_ids = paddle.concat(predicted_ids, 1)
        predicted_scores = paddle.concat(predicted_scores, 1)
        # return predicted_ids.squeeze(), predicted_scores.squeeze()
        return predicted_ids, predicted_scores

    def beam_search(self, x, beam_width, eos, embed):
        """Beam-search decoding.

        Args:
            x: encoder features, unpacked as ``[batch, length, dim]``.
            beam_width: number of beams kept per sample.
            eos: index of the end-of-sequence symbol.
            embed: embedding vectors used to initialise the decoder state.

        Returns:
            ``(p, ones_like(p))`` where ``p`` holds the symbols of the
            top-ranked beam per sample; the second element is a placeholder
            of ones, not real scores.
        """

        def _inflate(tensor, times, dim):
            repeat_dims = [1] * tensor.dim()
            repeat_dims[dim] = times
            output = paddle.tile(tensor, repeat_dims)
            return output

        # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
        batch_size, l, d = x.shape
        # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
        x = paddle.tile(
            paddle.transpose(
                x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
        inflated_encoder_feats = paddle.reshape(
            paddle.transpose(
                x, perm=[1, 0, 2, 3]), [-1, l, d])

        # Initialize the decoder
        state = self.decoder.get_initial_state(embed, tile_times=beam_width)

        pos_index = paddle.reshape(
            paddle.arange(batch_size) * beam_width, shape=[-1, 1])

        # Initialize the scores: only the first beam of every sample starts
        # at 0, the others at -inf
        sequence_scores = paddle.full(
            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
        index = [i * beam_width for i in range(0, batch_size)]
        sequence_scores[index] = 0.0

        # Initialize the input vector
        y_prev = paddle.full(
            shape=[batch_size * beam_width], fill_value=self.num_classes)

        # Store decisions for backtracking
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()

        for i in range(self.max_len_labels):
            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
            state = paddle.unsqueeze(state, axis=0)
            log_softmax_output = paddle.nn.functional.log_softmax(
                output, axis=1)

            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
            sequence_scores += log_softmax_output
            scores, candidates = paddle.topk(
                paddle.reshape(sequence_scores, [batch_size, -1]),
                beam_width,
                axis=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            y_prev = paddle.reshape(
                candidates % self.num_classes, shape=[batch_size * beam_width])
            sequence_scores = paddle.reshape(
                scores, shape=[batch_size * beam_width, 1])

            # Update fields for next timestep
            pos_index = paddle.expand_as(pos_index, candidates)
            predecessors = paddle.cast(
                candidates / self.num_classes + pos_index, dtype='int64')
            predecessors = paddle.reshape(
                predecessors, shape=[batch_size * beam_width, 1])
            state = paddle.index_select(
                state, index=predecessors.squeeze(), axis=1)

            # Update sequence scores and erase scores for <eos> symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
            eos_prev = paddle.full_like(y_prev, fill_value=eos)
            mask = eos_prev == y_prev
            mask = paddle.nonzero(mask)
            if mask.dim() > 0:
                sequence_scores = sequence_scores.numpy()
                mask = mask.numpy()
                sequence_scores[mask] = -float('inf')
                sequence_scores = paddle.to_tensor(sequence_scores)

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            y_prev = paddle.squeeze(y_prev)
            stored_emitted_symbols.append(y_prev)

        # Do backtracking to return the optimal values
        #====== backtrack ======#
        # Initialize return variables given different types
        p = list()
        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
             ]  # Placeholder for lengths of top-k sequences

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = paddle.topk(
            paddle.reshape(
                stored_scores[-1], shape=[batch_size, beam_width]),
            beam_width)

        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        batch_eos_found = [0] * batch_size  # the number of EOS found
        # in the backward loop below for each batch
        t = self.max_len_labels - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = paddle.reshape(
            sorted_idx + pos_index.expand_as(sorted_idx),
            shape=[batch_size * beam_width])
        while t >= 0:
            # Re-order the variables with the back pointer
            current_symbol = paddle.index_select(
                stored_emitted_symbols[t], index=t_predecessors, axis=0)
            t_predecessors = paddle.index_select(
                stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
            eos_indices = stored_emitted_symbols[t] == eos
            eos_indices = paddle.nonzero(eos_indices)

            if eos_indices.dim() > 0:
                for i in range(eos_indices.shape[0] - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i]
                    b_idx = int(idx[0] / beam_width)
                    # The indices of the replacing position: ended sequences
                    # overwrite the beams of their sample starting from the
                    # last one
                    res_k_idx = beam_width - (batch_eos_found[b_idx] %
                                              beam_width) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * beam_width + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
                    current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
                    s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            p.append(current_symbol)
            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(beam_width)
        for b_idx in range(batch_size):
            l[b_idx] = [
                l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
            ]

        re_sorted_idx = paddle.reshape(
            re_sorted_idx + pos_index.expand_as(re_sorted_idx),
            [batch_size * beam_width])

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        p = [
            paddle.reshape(
                paddle.index_select(step, re_sorted_idx, 0),
                shape=[batch_size, beam_width, -1]) for step in reversed(p)
        ]
        # Keep only the best beam of every sample.
        p = paddle.concat(p, -1)[:, 0, :]
        return p, paddle.ones_like(p)
|
||||
|
||||
|
||||
class AttentionUnit(nn.Layer):
    """Additive attention scorer over a feature sequence.

    Projects the previous decoder state and every time step of the feature
    sequence into a shared ``attDim`` space, sums them, applies ``tanh`` and
    reduces each time step to a single score that is softmax-normalised over
    the time axis.

    Args:
        sDim: size of the decoder state vector.
        xDim: size of each feature vector in the input sequence.
        attDim: size of the shared attention space.
    """

    def __init__(self, sDim, xDim, attDim):
        super(AttentionUnit, self).__init__()

        self.sDim = sDim
        self.xDim = xDim
        self.attDim = attDim

        # Layer creation order is kept as-is: state, feature, score.
        self.sEmbed = nn.Linear(sDim, attDim)
        self.xEmbed = nn.Linear(xDim, attDim)
        self.wEmbed = nn.Linear(attDim, 1)

    def forward(self, x, sPrev):
        """Return attention weights of shape [b x T].

        Args:
            x: feature sequence, [b x T x xDim].
            sPrev: previous decoder state; its leading axis is squeezed
                away before projection, leaving [b x sDim].
        """
        num_samples, seq_len, _ = x.shape  # [b x T x xDim]

        # Project every feature vector: [(b x T) x xDim] -> [b x T x attDim]
        flat_feats = paddle.reshape(x, [-1, self.xDim])
        feat_proj = paddle.reshape(
            self.xEmbed(flat_feats), [num_samples, seq_len, -1])

        # Project the state and broadcast it along time: [b x T x attDim]
        state_proj = self.sEmbed(sPrev.squeeze(0))  # [b x attDim]
        state_proj = paddle.expand(
            paddle.unsqueeze(state_proj, 1),
            [num_samples, seq_len, self.attDim])

        mixed = paddle.reshape(
            paddle.tanh(state_proj + feat_proj), [-1, self.attDim])

        # One scalar score per time step: [(b x T) x 1] -> [b x T]
        scores = paddle.reshape(self.wEmbed(mixed), [num_samples, seq_len])

        # attention weights for each sample in the minibatch
        return F.softmax(scores, axis=1)
|
||||
|
||||
|
||||
class DecoderUnit(nn.Layer):
    """Single attention-GRU decoding step.

    Combines an ``AttentionUnit``, a target-symbol embedding, a ``GRUCell``
    and an output projection. The initial GRU state is produced from a
    300-dimensional embedding via ``embed_fc``.

    Args:
        sDim: GRU hidden-state size.
        xDim: size of each feature vector in the input sequence.
        yDim: number of output classes.
        attDim: attention space size; also used as the symbol embedding size.
    """

    def __init__(self, sDim, xDim, yDim, attDim):
        super(DecoderUnit, self).__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.yDim = yDim
        self.attDim = attDim
        self.emdDim = attDim

        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
        # yDim + 1 rows: the last is used for <BOS>
        self.tgt_embedding = nn.Embedding(
            yDim + 1,
            self.emdDim,
            weight_attr=nn.initializer.Normal(std=0.01))
        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
        self.fc = nn.Linear(
            sDim,
            yDim,
            weight_attr=nn.initializer.Normal(std=0.01),
            bias_attr=nn.initializer.Constant(value=0))
        self.embed_fc = nn.Linear(300, self.sDim)

    def get_initial_state(self, embed, tile_times=1):
        """Map a [N x 300] embedding to an initial state of shape 1 * M * sDim.

        When ``tile_times`` is not 1, every sample's state is repeated
        ``tile_times`` times consecutively, so M = N * tile_times.
        """
        assert embed.shape[1] == 300
        init_state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            # [N, sDim] -> [1, N, sDim] -> tile on axis 0 -> [N, tile, sDim]
            as_row = paddle.transpose(init_state.unsqueeze(1), perm=[1, 0, 2])
            repeated = paddle.tile(as_row, repeat_times=[tile_times, 1, 1])
            per_sample = paddle.transpose(repeated, perm=[1, 0, 2])
            init_state = paddle.reshape(per_sample, shape=[-1, self.sDim])
        return init_state.unsqueeze(0)  # 1 * N * sDim

    def forward(self, x, sPrev, yPrev):
        """Run one decoding step.

        Args:
            x: feature sequence from the image decoder.
            sPrev: previous state, with a leading axis that is squeezed
                before being fed to the GRU cell.
            yPrev: previous symbol ids; cast to int64 before embedding.

        Returns:
            ``(output, state)``: the class scores from ``fc`` and the new
            GRU state.
        """
        weights = self.attention_unit(x, sPrev)
        # Weighted sum of the features over the sequence axis.
        glimpse = paddle.squeeze(
            paddle.matmul(weights.unsqueeze(1), x), axis=1)

        prev_symbol = paddle.cast(yPrev, dtype="int64")
        symbol_embed = self.tgt_embedding(prev_symbol)

        gru_input = paddle.squeeze(
            paddle.concat([symbol_embed, glimpse], 1), 1)
        hidden = paddle.squeeze(sPrev, 0)

        output, state = self.gru(gru_input, hidden)
        output = self.fc(paddle.squeeze(output, axis=1))
        return output, state
|
|
@ -17,8 +17,9 @@ __all__ = ['build_transform']
|
|||
|
||||
def build_transform(config):
|
||||
from .tps import TPS
|
||||
from .tps import STN_ON
|
||||
|
||||
support_dict = ['TPS']
|
||||
support_dict = ['TPS', 'STN_ON']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
def conv3x3_block(in_channels, out_channels, stride=1):
    """Build a 3x3 Conv2D -> BatchNorm2D -> ReLU block.

    Convolution weights are drawn from a normal distribution with
    std = sqrt(2 / (3 * 3 * out_channels)); the bias starts at zero.
    """
    fan = 3 * 3 * out_channels
    init_std = math.sqrt(2. / fan)
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        weight_attr=nn.initializer.Normal(
            mean=0.0, std=init_std),
        bias_attr=nn.initializer.Constant(0))
    return nn.Sequential(conv, nn.BatchNorm2D(out_channels), nn.ReLU())
|
||||
|
||||
|
||||
class STN(nn.Layer):
    """Localisation network predicting ``num_ctrlpoints`` 2-D control points.

    Six 3x3 conv blocks with five 2x2 max-pools in between feed a 512-unit
    fully connected layer; a final linear layer emits the control points.
    The final layer starts with zero weights and a bias equal to a default
    set of points placed along the top and bottom edges, so the initial
    prediction is that default layout.

    Args:
        in_channels: channels of the input image.
        num_ctrlpoints: total number of control points (top + bottom).
        activation: 'none' or 'sigmoid'; with 'sigmoid' the output is
            squashed and the initial bias is stored as the matching logit.
    """

    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
        super(STN, self).__init__()
        self.in_channels = in_channels
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation

        # (in, out) channels of the conv blocks that are followed by a pool.
        # Source comments give the feature-map size as 32x64 at the input,
        # halved by each pool down to 1x2 for the last block.
        pooled_stages = [(in_channels, 32), (32, 64), (64, 128), (128, 256),
                         (256, 256)]
        layers = []
        for c_in, c_out in pooled_stages:
            layers.append(conv3x3_block(c_in, c_out))
            layers.append(nn.MaxPool2D(kernel_size=2, stride=2))
        layers.append(conv3x3_block(256, 256))  # 1*2
        self.stn_convnet = nn.Sequential(*layers)

        self.stn_fc1 = nn.Sequential(
            nn.Linear(
                2 * 256,
                512,
                weight_attr=nn.initializer.Normal(0, 0.001),
                bias_attr=nn.initializer.Constant(0)),
            nn.BatchNorm1D(512),
            nn.ReLU())

        default_points = self.init_stn()
        self.stn_fc2 = nn.Linear(
            512,
            num_ctrlpoints * 2,
            weight_attr=nn.initializer.Constant(0.0),
            bias_attr=nn.initializer.Assign(default_points))

    def init_stn(self):
        """Return the flattened default control points used as fc2's bias."""
        margin = 0.01
        per_side = int(self.num_ctrlpoints / 2)
        xs = np.linspace(margin, 1. - margin, per_side)
        top_ys = np.ones(per_side) * margin
        bottom_ys = np.ones(per_side) * (1 - margin)
        top = np.stack([xs, top_ys], axis=1)
        bottom = np.stack([xs, bottom_ys], axis=1)
        ctrl_points = np.concatenate([top, bottom], axis=0).astype(np.float32)
        if self.activation == 'sigmoid':
            # Inverse sigmoid, so that sigmoid(bias) gives the points back.
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        ctrl_points = paddle.to_tensor(ctrl_points)
        flat_len = ctrl_points.shape[0] * ctrl_points.shape[1]
        return paddle.reshape(ctrl_points, shape=[flat_len])

    def forward(self, x):
        """Return ``(img_feat, ctrl_points)``.

        ``img_feat`` is the output of ``stn_fc1``; ``ctrl_points`` has shape
        [-1, num_ctrlpoints, 2].
        """
        feat_map = self.stn_convnet(x)
        batch_size = feat_map.shape[0]
        flat = paddle.reshape(feat_map, shape=(batch_size, -1))
        img_feat = self.stn_fc1(flat)
        points = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            points = F.sigmoid(points)
        points = paddle.reshape(points, shape=[-1, self.num_ctrlpoints, 2])
        return img_feat, points
|
|
@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
|
|||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from .tps_spatial_transformer import TPSSpatialTransformer
|
||||
from .stn import STN
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
|
@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
|
|||
""" Return inv_delta_C which is needed to calculate T """
|
||||
F = self.F
|
||||
hat_eye = paddle.eye(F, dtype='float64') # F x F
|
||||
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = paddle.norm(
|
||||
C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = (hat_C**2) * paddle.log(hat_C)
|
||||
delta_C = paddle.concat( # F+3 x F+3
|
||||
[
|
||||
|
@ -301,3 +305,25 @@ class TPS(nn.Layer):
|
|||
[-1, image.shape[2], image.shape[3], 2])
|
||||
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
|
||||
return batch_I_r
|
||||
|
||||
|
||||
class STN_ON(nn.Layer):
    """Rectification transform: STN control-point head + TPS sampler.

    Args:
        in_channels: channels of the input image; also exposed as
            ``out_channels``.
        tps_inputsize: size the image is resized to before it is fed to
            the control-point head.
        tps_outputsize: (height, width) of the rectified output.
        num_control_points: number of TPS control points.
        tps_margins: (x, y) margins for the target control points.
        stn_activation: activation name forwarded to ``STN``.
    """

    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
                 num_control_points, tps_margins, stn_activation):
        super(STN_ON, self).__init__()
        self.tps = TPSSpatialTransformer(
            output_image_size=tuple(tps_outputsize),
            num_control_points=num_control_points,
            margins=tuple(tps_margins))
        self.stn_head = STN(
            in_channels=in_channels,
            num_ctrlpoints=num_control_points,
            activation=stn_activation)
        self.tps_inputsize = tps_inputsize
        self.out_channels = in_channels

    def forward(self, image):
        """Predict control points on a resized copy, then warp ``image``."""
        resized = paddle.nn.functional.interpolate(
            image, self.tps_inputsize, mode="bilinear", align_corners=True)
        # The head also returns its image feature; only the points are used.
        _, ctrl_points = self.stn_head(resized)
        # The TPS sampler reads from the original image, not the resized one.
        rectified, _ = self.tps(image, ctrl_points)
        return rectified
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at ``grid``, optionally compositing over ``canvas``.

    Sets ``input.stop_gradient = False`` as a side effect. When ``canvas``
    is given, an all-ones mask is sampled with the same grid and used to
    blend the sampled output with the canvas.
    """
    input.stop_gradient = False
    sampled = F.grid_sample(input, grid)
    if canvas is None:
        return sampled

    ones = paddle.ones(shape=input.shape)
    coverage = F.grid_sample(ones, grid)
    return sampled * coverage + canvas * (1 - coverage)
|
||||
|
||||
|
||||
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
|
||||
def compute_partial_repr(input_points, control_points):
    """Return the N x M radial-basis matrix between two 2-D point sets.

    Entry (i, j) is 0.5 * d * log(d) with d the squared Euclidean distance
    between input point i and control point j, i.e. r^2 * log(r). NaN
    entries (from 0 * log(0)) are replaced by 0.
    """
    num_inputs = input_points.shape[0]
    num_controls = control_points.shape[0]
    inputs_col = paddle.reshape(input_points, shape=[num_inputs, 1, 2])
    controls_row = paddle.reshape(control_points, shape=[1, num_controls, 2])
    delta = inputs_col - controls_row
    # Square and add the two coordinates by hand; the source notes the
    # sum-based original implementation was very slow.
    delta_sq = delta * delta
    sq_dist = delta_sq[:, :, 0] + delta_sq[:, :, 1]
    repr_matrix = 0.5 * sq_dist * paddle.log(sq_dist)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    # (NaN is the only value that differs from itself)
    nan_mask = repr_matrix != repr_matrix
    repr_matrix[nan_mask] = 0
    return repr_matrix
|
||||
|
||||
|
||||
# output_ctrl_pts are specified, according to our task.
|
||||
def build_output_control_points(num_control_points, margins):
    """Return the target control points as a tensor of (x, y) rows.

    Half of the points are evenly spaced along the top edge at
    y = margin_y and half along the bottom edge at y = 1 - margin_y, with
    x running from margin_x to 1 - margin_x. Top points come first.
    """
    margin_x, margin_y = margins
    per_side = num_control_points // 2
    xs = np.linspace(margin_x, 1.0 - margin_x, per_side)
    top = np.stack([xs, np.ones(per_side) * margin_y], axis=1)
    bottom = np.stack([xs, np.ones(per_side) * (1.0 - margin_y)], axis=1)
    stacked = np.concatenate([top, bottom], axis=0)
    return paddle.to_tensor(stacked)
|
||||
|
||||
|
||||
class TPSSpatialTransformer(nn.Layer):
    """Thin-plate-spline sampler with a fixed set of target control points.

    The constructor precomputes the inverse TPS kernel for the target
    control points and the representation of every output pixel
    coordinate. ``forward`` then only needs two matrix products to turn
    predicted source control points into a sampling grid.

    Args:
        output_image_size: (height, width) of the rectified output.
        num_control_points: total number of control points.
        margins: (x, y) margins passed to ``build_output_control_points``.
    """

    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins

        self.target_height, self.target_width = output_image_size
        target_control_points = build_output_control_points(num_control_points,
                                                            margins)
        N = num_control_points

        # create padded kernel matrix: the N x N radial-basis block plus
        # three extra rows/columns (a ones column and the x, y coordinates)
        forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
        target_control_partial_repr = compute_partial_repr(
            target_control_points, target_control_points)
        target_control_partial_repr = paddle.cast(target_control_partial_repr,
                                                  forward_kernel.dtype)
        forward_kernel[:N, :N] = target_control_partial_repr
        forward_kernel[:N, -3] = 1
        forward_kernel[-3, :N] = 1
        target_control_points = paddle.cast(target_control_points,
                                            forward_kernel.dtype)
        forward_kernel[:N, -2:] = target_control_points
        forward_kernel[-2:, :N] = paddle.transpose(
            target_control_points, perm=[1, 0])
        # compute inverse matrix
        inverse_kernel = paddle.inverse(forward_kernel)

        # create target coordinate matrix: one (y, x) row per output pixel
        HW = self.target_height * self.target_width
        target_coordinate = list(
            itertools.product(
                range(self.target_height), range(self.target_width)))
        target_coordinate = paddle.to_tensor(target_coordinate)  # HW x 2
        Y, X = paddle.split(
            target_coordinate, target_coordinate.shape[1], axis=1)
        # normalise pixel indices to [0, 1]
        Y = Y / (self.target_height - 1)
        X = X / (self.target_width - 1)
        target_coordinate = paddle.concat(
            [X, Y], axis=1)  # convert from (y, x) to (x, y)
        target_coordinate_partial_repr = compute_partial_repr(
            target_coordinate, target_control_points)
        target_coordinate_repr = paddle.concat(
            [
                target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
                target_coordinate
            ],
            axis=1)

        # register precomputed matrices
        self.inverse_kernel = inverse_kernel
        self.padding_matrix = paddle.zeros(shape=[3, 2])
        self.target_coordinate_repr = target_coordinate_repr
        self.target_control_points = target_control_points

    def forward(self, input, source_control_points):
        """Warp ``input`` according to ``source_control_points``.

        Args:
            input: image batch to sample from.
            source_control_points: tensor of shape
                [batch, num_control_points, 2].

        Returns:
            ``(output_maps, source_coordinate)``: the sampled images and the
            unclipped source coordinates of every output pixel.
        """
        assert source_control_points.ndimension() == 3
        assert source_control_points.shape[1] == self.num_control_points
        assert source_control_points.shape[2] == 2
        batch_size = paddle.shape(source_control_points)[0]

        # Expand into a local: the stored [3, 2] template must stay
        # unbatched. Overwriting self.padding_matrix would leave a
        # [batch, 3, 2] tensor behind, and the next call with a different
        # batch size could not expand it.
        padding_matrix = paddle.expand(
            self.padding_matrix, shape=[batch_size, 3, 2])
        Y = paddle.concat([source_control_points, padding_matrix], 1)
        mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
        source_coordinate = paddle.matmul(self.target_coordinate_repr,
                                          mapping_matrix)

        grid = paddle.reshape(
            source_coordinate,
            shape=[-1, self.target_height, self.target_width, 2])
        grid = paddle.clip(grid, 0,
                           1)  # the source_control_points may be out of [0, 1].
        # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
        grid = 2.0 * grid - 1.0
        output_maps = grid_sample(input, grid, canvas=None)
        return output_maps, source_coordinate
|
|
@ -127,3 +127,34 @@ class RMSProp(object):
|
|||
grad_clip=self.grad_clip,
|
||||
parameters=parameters)
|
||||
return opt
|
||||
|
||||
|
||||
class Adadelta(object):
    """Configuration holder that builds an ``optim.Adadelta`` on call.

    The constructor only stores its arguments. Extra keyword arguments are
    accepted and ignored. ``parameter_list`` is stored but not forwarded
    by ``__call__``, which uses the ``parameters`` it is given instead.
    """

    def __init__(self,
                 learning_rate=0.001,
                 epsilon=1e-08,
                 rho=0.95,
                 parameter_list=None,
                 weight_decay=None,
                 grad_clip=None,
                 name=None,
                 **kwargs):
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.rho = rho
        self.parameter_list = parameter_list
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.name = name

    def __call__(self, parameters):
        """Return an ``optim.Adadelta`` over ``parameters``."""
        return optim.Adadelta(
            learning_rate=self.learning_rate,
            epsilon=self.epsilon,
            rho=self.rho,
            weight_decay=self.weight_decay,
            grad_clip=self.grad_clip,
            name=self.name,
            parameters=parameters)
|
||||
|
|
|
@ -24,17 +24,19 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
|
||||
TableLabelDecode, SARLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -303,6 +303,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='ch',
                 use_space_char=False,
                 **kwargs):
        super(SEEDLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)

    def add_special_char(self, dict_character):
        """Append the end-of-sequence token to the character list.

        Only "eos" is appended; "sos" is recorded as ``beg_str`` but is not
        added to the returned list.
        """
        self.beg_str = "sos"
        self.end_str = "eos"
        return dict_character + [self.end_str]

    def get_ignored_tokens(self):
        """Return a one-element list holding the index of "eos"."""
        end_idx = self.get_beg_end_flag_idx("eos")
        return [end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        """Look up the dictionary index of the "sos" or "eos" token.

        Raises:
            AssertionError: if ``beg_or_end`` is neither "sos" nor "eos".
        """
        if beg_or_end == "sos":
            idx = np.array(self.dict[self.beg_str])
        elif beg_or_end == "eos":
            idx = np.array(self.dict[self.end_str])
        else:
            # Raised explicitly so the check survives ``python -O``, which
            # strips ``assert`` statements.
            raise AssertionError("unsupport type %s in get_beg_end_flag_idx" %
                                 beg_or_end)
        return idx

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label.

        Each sequence is read up to (not including) the first "eos" index.

        Args:
            text_index: batch of index sequences.
            text_prob: optional per-position confidences with the same
                layout as ``text_index``; 1 is used for every kept
                character when omitted.
            is_remove_duplicate: skip a position whose index equals the
                previous position's index.

        Returns:
            A list of ``(text, mean_confidence)`` tuples, one per sequence.
        """
        result_list = []
        [end_idx] = self.get_ignored_tokens()
        end_idx = int(end_idx)
        for batch_idx in range(len(text_index)):
            seq = text_index[batch_idx]
            char_list = []
            conf_list = []
            for idx in range(len(seq)):
                if int(seq[idx]) == end_idx:
                    break
                if is_remove_duplicate:
                    # only for predict
                    if idx > 0 and seq[idx - 1] == seq[idx]:
                        continue
                char_list.append(self.character[int(seq[idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            if not conf_list:
                # Empty prediction (sequence starts with eos): report a
                # confidence of 0 instead of letting np.mean([]) return NaN
                # with a RuntimeWarning.
                conf_list = [0]
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode model output, and the label when one is given.

        Args:
            preds: dict with "rec_pred". If "rec_pred_scores" is present,
                "rec_pred" is used as indices and the scores as
                confidences; otherwise indices and confidences are the
                argmax / max of "rec_pred" along axis 2.
            label: optional batch of ground-truth index sequences.

        Returns:
            The decoded predictions, or ``(predictions, labels)`` when
            ``label`` is not None.
        """
        preds_idx = preds["rec_pred"]
        if isinstance(preds_idx, paddle.Tensor):
            preds_idx = preds_idx.numpy()
        if "rec_pred_scores" in preds:
            preds_prob = preds["rec_pred_scores"]
            if isinstance(preds_prob, paddle.Tensor):
                preds_prob = preds_prob.numpy()
        else:
            # Take max before argmax: both read the same raw scores.
            preds_prob = preds_idx.max(axis=2)
            preds_idx = preds_idx.argmax(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
        if label is None:
            return text
        label = self.decode(label, is_remove_duplicate=False)
        return text, label
|
||||
|
||||
|
||||
class SRNLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -188,10 +188,12 @@ def train(config,
|
|||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||
use_sar = config['Architecture']['algorithm'] == 'SAR'
|
||||
use_seed = config['Architecture']['algorithm'] == 'SEED'
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
model_type = None
|
||||
algorithm = config['Architecture']['algorithm']
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
@ -215,7 +217,7 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table' or use_nrtr or use_sar:
|
||||
if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -402,7 +404,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR'
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'ASTER'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue