411 lines
14 KiB
Python
411 lines
14 KiB
Python
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
This code is adapted from:
|
|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
|
|
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import math
|
|
import paddle
|
|
from paddle import ParamAttr
|
|
import paddle.nn as nn
|
|
import paddle.nn.functional as F
|
|
|
|
|
|
class SAREncoder(nn.Layer):
    """Holistic-feature encoder of SAR.

    The backbone feature map is max-pooled over its full height, the
    resulting column sequence is run through a 2-layer RNN, and the RNN
    output at the last (valid) time step is projected by a linear layer,
    giving one holistic feature vector per image.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence: when valid
            ratios are supplied, take the RNN output at the last valid step
            instead of the last step overall.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM Encoder (or GRU when enc_gru is set)
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        # Separate name so the constructor's **kwargs is not shadowed.
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**rnn_kwargs)

        # global feature transformation; a bidirectional RNN doubles the
        # output width.
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """Encode a backbone feature map into one holistic vector per image.

        Args:
            feat (Tensor): backbone feature map laid out as (bsz, C, H, W).
            img_metas (list | None): presumably ``[label, valid_ratio]``
                (confirm against the caller). ``img_metas[0]`` must have one
                entry per image, and ``img_metas[-1]`` is read as the
                per-image valid width ratios when ``self.mask`` is True.

        Returns:
            Tensor: holistic feature of shape (bsz, d_enc * num_directions).
        """
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        # Collapse the height axis so each image becomes a width-wise sequence.
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                # paddle.minimum expects both operands to share a dtype, so
                # cast to T's own integer dtype instead of a hard-coded one.
                # NOTE(review): a valid ratio of 0 yields valid_step == -1,
                # i.e. the last step is selected.
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype(T.dtype)) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
|
|
|
|
|
|
class BaseDecoder(nn.Layer):
    """Base for decoders that have separate training and inference paths.

    Subclasses override ``forward_train`` and ``forward_test``. ``forward``
    stores the requested mode on ``self.train_mode`` (so helpers can consult
    it) and then dispatches to the matching path.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        """Training-time decoding; subclasses must override this."""
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        """Inference-time decoding; subclasses must override this."""
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Record ``train_mode`` and route to the training or test path.

        ``label`` is only forwarded on the training path.
        """
        self.train_mode = train_mode

        if not train_mode:
            return self.forward_test(feat, out_enc, img_metas)
        return self.forward_train(feat, out_enc, label, img_metas)
|
|
|
|
|
|
class ParallelSARDecoder(BaseDecoder):
    """Parallel SAR decoder with a 2D attention module.

    Args:
        out_channels (int): Output class number. The last two indices are
            reserved: ``out_channels - 2`` is the start token and
            ``out_channels - 1`` the padding token.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        # The last two class indices are reserved for start and padding.
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        # The decoder RNN's hidden size equals the encoder output width; a
        # bidirectional decoder doubles its output width.
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        # Separate name so the constructor's **kwargs is not shadowed.
        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        # Outputs cover indices 0 .. num_classes - 2, so the padding index
        # (num_classes - 1) can never be predicted.
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN, attend over ``feat`` and predict class scores.

        Args:
            decoder_input (Tensor): (bsz, seq_len + 1, C) RNN inputs; callers
                put the holistic feature at step 0 and token embeddings after.
            feat (Tensor): backbone feature map (bsz, d_model, h, w).
            holistic_feat (Tensor): (bsz, 1, C) encoder output; concatenated
                to every step when ``pred_concat`` is True.
            valid_ratios (Tensor | None): per-image valid width ratios;
                attention columns at or beyond ``ceil(ratio * w)`` are set to
                ``-inf`` before the softmax.

        Returns:
            Tensor: (bsz, seq_len + 1, num_classes - 1) unnormalized scores,
            with dropout applied when ``self.train_mode`` is True.
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, _ = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))

        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i in range(paddle.shape(valid_ratios)[0]):
                # paddle.minimum expects both operands to share a dtype, so
                # cast to w's own integer dtype instead of a hard-coded one.
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype(w.dtype))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w positions of each step.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * (num_classes - 1)
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Teacher-forced decoding over the ground-truth label embeddings.

        img_metas: [label, valid_ratio]

        Returns the (bsz, seq_len, num_classes - 1) scores, with the step
        produced for the holistic-feature input (step 0) dropped.
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        '''
        Greedy decoding for ``self.max_seq_len`` steps.

        The decoder input starts as [holistic feature, start token * seq_len].
        At every step i the full attention pass is re-run, the softmax of step
        i is recorded, and the embedding of its argmax is written into input
        position i + 1.

        Returns the (bsz, max_seq_len, num_classes - 1) softmax probabilities.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
|
|
|
|
|
|
class SARHead(nn.Layer):
    """SAR recognition head: ``SAREncoder`` followed by ``ParallelSARDecoder``.

    ``in_channels`` is used as ``d_model`` and ``enc_dim`` as ``d_enc`` for
    both sub-modules. ``max_text_length`` is the decoder's number of
    inference steps. The remaining ``enc_*``/``dec_*``/``d_k``/``pred_*``
    arguments are passed through unchanged. Extra ``**kwargs`` are ignored.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super().__init__()

        # encoder module: feature map -> one holistic vector per image
        self.encoder = SAREncoder(
            d_model=in_channels,
            d_enc=enc_dim,
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru)

        # decoder module: holistic vector + feature map -> per-step scores
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            pred_dropout=pred_dropout,
            pred_concat=pred_concat,
            max_text_length=max_text_length)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into per-step class scores.

        targets: [label, valid_ratio]. It is forwarded to the encoder and
        the decoder as ``img_metas``. ``targets[0]`` is used as the label
        while ``self.training`` is True, and is ignored otherwise.

        Returns a (bsz, seq_len, num_classes) tensor. It holds raw scores in
        training and softmax probabilities at inference.
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        # Teacher forcing needs the label; inference decodes greedily.
        label = targets[0] if self.training else None
        return self.decoder(
            feat,
            holistic_feat,
            label,
            img_metas=targets,
            train_mode=self.training)
|