# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/loss/transformer_english_decomposition.py
"""

import copy
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def subsequent_mask(size):
    """Generate a square subsequent (look-ahead) mask of shape [1, size, size].

    Positions a step may attend to (itself and all earlier steps) are True;
    future positions are False.
    """
    mask = paddle.ones([1, size, size], dtype='float32')
    mask_inf = paddle.triu(
        paddle.full(
            shape=[1, size, size], dtype='float32', fill_value='-inf'),
        diagonal=1)
    mask = mask + mask_inf
    padding_mask = paddle.equal(mask, paddle.to_tensor(1, dtype=mask.dtype))
    return padding_mask
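
# Illustration: subsequent_mask(3) is the boolean lower-triangular pattern
#   [[ True, False, False],
#    [ True,  True, False],
#    [ True,  True,  True]]
# so position i can attend to positions 0..i only.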


def clones(module, N):
    """Produce a LayerList containing N deep copies of `module`."""
    return nn.LayerList([copy.deepcopy(module) for _ in range(N)])


def masked_fill(x, mask, value):
    """Fill elements of `x` with `value` where `mask` is True (analogous to
    torch.Tensor.masked_fill), implemented with paddle.where."""
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)


def attention(query, key, value, mask=None, dropout=None, attention_map=None):
    """Compute scaled dot-product attention over the last two dimensions.

    `query`, `key` and `value` are [batch, head, seq_len, d_k] tensors. Where
    the optional `mask` is 0/False, the scores are set to -inf so those
    positions receive (close to) zero attention weight.
    """
    d_k = query.shape[-1]
    scores = paddle.matmul(query,
                           paddle.transpose(key, [0, 1, 3, 2])) / math.sqrt(d_k)

    if mask is not None:
        scores = masked_fill(scores, mask == 0, float('-inf'))

    p_attn = F.softmax(scores, axis=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return paddle.matmul(p_attn, value), p_attn
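
# attention() computes Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V,
# returning both the weighted values and the attention weights p_attn.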


class MultiHeadedAttention(nn.Layer):
    """Multi-head attention with `h` heads over a `d_model`-dimensional input."""

    def __init__(self, h, d_model, dropout=0.1, compress_attention=False):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # Dimension per head.
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus one output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")
        self.compress_attention = compress_attention
        self.compress_attention_linear = nn.Linear(h, 1)

    def forward(self, query, key, value, mask=None, attention_map=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.shape[0]

        # Project the inputs, then split d_model into h heads of size d_k:
        # [batch, seq, d_model] -> [batch, h, seq, d_k].
        query, key, value = [
            paddle.transpose(
                linear(x).reshape([nbatches, -1, self.h, self.d_k]),
                [0, 2, 1, 3])
            for linear, x in zip(self.linears, (query, key, value))
        ]

        x, attention_map = attention(
            query,
            key,
            value,
            mask=mask,
            dropout=self.dropout,
            attention_map=attention_map)

        # Merge the heads back: [batch, h, seq, d_k] -> [batch, seq, h * d_k].
        x = paddle.reshape(
            paddle.transpose(x, [0, 2, 1, 3]),
            [nbatches, -1, self.h * self.d_k])

        return self.linears[-1](x), attention_map
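
# With the h=16, d_model=1024 configuration used by Decoder below, each head
# works on d_k = 1024 // 16 = 64 channels, and a [batch, seq, 1024] input is
# processed as a [batch, 16, seq, 64] tensor before the heads are re-merged.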


class ResNet(nn.Layer):
    def __init__(self, num_in, block, layers):
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2D(num_in, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(64, use_global_stats=True)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2D((2, 2), (2, 2))

        self.conv2 = nn.Conv2D(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(128, use_global_stats=True)
        self.relu2 = nn.ReLU()

        self.layer1_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer1 = self._make_layer(block, 128, 256, layers[0])
        self.layer1_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer1_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer1_relu = nn.ReLU()

        self.layer2_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer2 = self._make_layer(block, 256, 256, layers[1])
        self.layer2_conv = nn.Conv2D(256, 256, 3, 1, 1)
        self.layer2_bn = nn.BatchNorm2D(256, use_global_stats=True)
        self.layer2_relu = nn.ReLU()

        self.layer3_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer3 = self._make_layer(block, 256, 512, layers[2])
        self.layer3_conv = nn.Conv2D(512, 512, 3, 1, 1)
        self.layer3_bn = nn.BatchNorm2D(512, use_global_stats=True)
        self.layer3_relu = nn.ReLU()

        self.layer4_pool = nn.MaxPool2D((2, 2), (2, 2))
        self.layer4 = self._make_layer(block, 512, 512, layers[3])
        self.layer4_conv2 = nn.Conv2D(512, 1024, 3, 1, 1)
        self.layer4_conv2_bn = nn.BatchNorm2D(1024, use_global_stats=True)
        self.layer4_conv2_relu = nn.ReLU()

    def _make_layer(self, block, inplanes, planes, blocks):
        if inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2D(inplanes, planes, 3, 1, 1),
                nn.BatchNorm2D(planes, use_global_stats=True))
        else:
            downsample = None
        layers = []
        layers.append(block(inplanes, planes, downsample))
        for _ in range(1, blocks):
            layers.append(block(planes, planes, downsample=None))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.layer1_pool(x)
        x = self.layer1(x)
        x = self.layer1_conv(x)
        x = self.layer1_bn(x)
        x = self.layer1_relu(x)

        # Note: layer2_pool, layer3_pool and layer4_pool are created in
        # __init__ but not applied here, so the feature map keeps 1/4 of the
        # input spatial resolution from this point on.
        x = self.layer2(x)
        x = self.layer2_conv(x)
        x = self.layer2_bn(x)
        x = self.layer2_relu(x)

        x = self.layer3(x)
        x = self.layer3_conv(x)
        x = self.layer3_bn(x)
        x = self.layer3_relu(x)

        x = self.layer4(x)
        x = self.layer4_conv2(x)
        x = self.layer4_conv2_bn(x)
        x = self.layer4_conv2_relu(x)

        return x
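
# For an input of shape [batch, num_in, H, W], ResNet.forward returns features
# of shape [batch, 1024, H / 4, W / 4]: only self.pool and self.layer1_pool
# halve the spatial size, and all convolutions use stride 1.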


class Bottleneck(nn.Layer):
    def __init__(self, input_dim):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2D(input_dim, input_dim, 1)
        self.bn1 = nn.BatchNorm2D(input_dim, use_global_stats=True)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2D(input_dim, input_dim, 3, 1, 1)
        self.bn2 = nn.BatchNorm2D(input_dim, use_global_stats=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = self.relu(out)

        return out
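
# Bottleneck is a 1x1 + 3x3 residual unit; within this file the Encoder below
# builds its ResNet with BasicBlock, so this class appears to be kept only for
# parity with the reference implementation.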


class PositionalEncoding(nn.Layer):
    "Implement the sinusoidal positional encoding (PE) function."

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout, mode="downscale_in_infer")

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = paddle.unsqueeze(pe, 0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.shape[1]]
        return self.dropout(x)
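
# The buffer pe implements the standard sinusoidal encoding:
#   PE(pos, 2i)     = sin(pos / 10000^(2i / dim))
#   PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim))
# and forward() adds the first x.shape[1] positions to the input.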


class PositionwiseFeedForward(nn.Layer):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout, mode="downscale_in_infer")

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))
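
# i.e. FFN(x) = w_2(dropout(relu(w_1(x)))), expanding d_model -> d_ff -> d_model.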


class Generator(nn.Layer):
    "Define the standard linear generation step (returns logits; no softmax here)."

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.proj(x)
        return out


class Embeddings(nn.Layer):
    "Token embedding lookup scaled by sqrt(d_model)."

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        embed = self.lut(x) * math.sqrt(self.d_model)
        return embed
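
# The sqrt(d_model) scaling follows the original Transformer formulation and
# keeps the embedding magnitudes comparable to the positional encodings they
# are later combined with in Transformer.forward.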


class LayerNorm(nn.Layer):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(1.0))
        self.b_2 = self.create_parameter(
            shape=[features],
            default_initializer=paddle.nn.initializer.Constant(0.0))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
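
# This follows the Annotated Transformer formulation (normalize by std + eps
# with learnable scale a_2 and bias b_2), which can differ slightly in detail
# from paddle.nn.LayerNorm's sqrt(var + eps) normalization.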


class Decoder(nn.Layer):
    """A single transformer decoder block: masked self-attention over the text
    features, cross-attention from text to the flattened image features, and a
    position-wise feed-forward layer, each wrapped with a residual connection
    and LayerNorm."""

    def __init__(self):
        super(Decoder, self).__init__()

        self.mask_multihead = MultiHeadedAttention(
            h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm1 = LayerNorm(1024)

        self.multihead = MultiHeadedAttention(h=16, d_model=1024, dropout=0.1)
        self.mul_layernorm2 = LayerNorm(1024)

        self.pff = PositionwiseFeedForward(1024, 2048)
        self.mul_layernorm3 = LayerNorm(1024)

    def forward(self, text, conv_feature, attention_map=None):
        text_max_length = text.shape[1]
        mask = subsequent_mask(text_max_length)
        result = text
        result = self.mul_layernorm1(result + self.mask_multihead(
            text, text, text, mask=mask)[0])
        # Flatten the conv features from [b, c, h, w] to [b, h * w, c] so they
        # can serve as keys/values for the text queries.
        b, c, h, w = conv_feature.shape
        conv_feature = paddle.transpose(
            conv_feature.reshape([b, c, h * w]), [0, 2, 1])
        word_image_align, attention_map = self.multihead(
            result,
            conv_feature,
            conv_feature,
            mask=None,
            attention_map=attention_map)
        result = self.mul_layernorm2(result + word_image_align)
        result = self.mul_layernorm3(result + self.pff(result))

        return result, attention_map


class BasicBlock(nn.Layer):
    def __init__(self, inplanes, planes, downsample):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2D(
            inplanes, planes, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2D(
            planes, planes, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2D(planes, use_global_stats=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out += residual
        out = self.relu(out)

        return out


class Encoder(nn.Layer):
    def __init__(self):
        super(Encoder, self).__init__()
        self.cnn = ResNet(num_in=1, block=BasicBlock, layers=[1, 2, 5, 3])

    def forward(self, input):
        conv_result = self.cnn(input)
        return conv_result


class Transformer(nn.Layer):
    def __init__(self, in_channels=1, alphabet='0123456789'):
        super(Transformer, self).__init__()
        self.alphabet = alphabet
        word_n_class = self.get_alphabet_len()
        self.embedding_word_with_upperword = Embeddings(512, word_n_class)
        self.pe = PositionalEncoding(dim=512, dropout=0.1, max_len=5000)

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.generator_word_with_upperword = Generator(1024, word_n_class)

        # Apply Xavier-normal initialization to every weight matrix. The
        # initializer instance must be called on the parameter; merely
        # constructing nn.initializer.XavierNormal(p) would leave the
        # parameter untouched.
        xavier_normal_ = nn.initializer.XavierNormal()
        for p in self.parameters():
            if p.dim() > 1:
                xavier_normal_(p)

    def get_alphabet_len(self):
        return len(self.alphabet)

    def forward(self, image, text_length, text_input, attention_map=None):
        # The encoder expects a single channel, so RGB inputs are converted to
        # grayscale with the standard luminance weights.
        if image.shape[1] == 3:
            R = image[:, 0:1, :, :]
            G = image[:, 1:2, :, :]
            B = image[:, 2:3, :, :]
            image = 0.299 * R + 0.587 * G + 0.114 * B

        conv_feature = self.encoder(image)  # batch, 1024, 8, 32
        max_length = max(text_length)
        text_input = text_input[:, :max_length]

        text_embedding = self.embedding_word_with_upperword(
            text_input)  # batch, text_max_length, 512
        position_embedding = self.pe(
            paddle.zeros(text_embedding.shape))  # batch, text_max_length, 512
        text_input_with_pe = paddle.concat(
            [text_embedding, position_embedding],
            2)  # batch, text_max_length, 1024
        batch, seq_len, _ = text_input_with_pe.shape

        text_input_with_pe, word_attention_map = self.decoder(
            text_input_with_pe, conv_feature)

        word_decoder_result = self.generator_word_with_upperword(
            text_input_with_pe)

        if self.training:
            # Gather the valid (unpadded) time steps of every sample into one
            # [sum(text_length), num_classes] tensor of logits.
            total_length = paddle.sum(text_length)
            probs_res = paddle.zeros([total_length, self.get_alphabet_len()])
            start = 0

            for index, length in enumerate(text_length):
                length = int(length.numpy())
                probs_res[start:start + length, :] = word_decoder_result[
                    index, 0:0 + length, :]

                start = start + length

            return probs_res, word_attention_map, None
        else:
            return word_decoder_result
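
# A minimal usage sketch (illustrative only; the shapes assume the 32x128
# grayscale inputs implied by the "batch, 1024, 8, 32" comment above, and the
# tensors here are placeholders rather than real data):
#
#   model = Transformer(alphabet='0123456789')
#   image = paddle.rand([2, 1, 32, 128])
#   text_input = paddle.zeros([2, 10], dtype='int64')
#   text_length = paddle.to_tensor([5, 8], dtype='int64')
#   model.eval()
#   logits = model(image, text_length, text_input)  # [2, 8, len(alphabet)]
#
# In training mode the same call instead returns the flattened per-step
# logits, the word attention map and None.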