PaddleOCR/ppocr/modeling/heads/rec_aster_head.py
xiaoting 6f677d609c
Fix seed export (#7502)
* fix seed export

* update doc for seed

* add stop_gradient for ones_like
2022-09-06 16:07:48 +08:00

390 lines
15 KiB
Python

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/ayumiymk/aster.pytorch/blob/master/lib/models/attention_recognition_head.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle
from paddle import nn
from paddle.nn import functional as F
class AsterHead(nn.Layer):
    """ASTER recognition head: a sequence embedding plus an attention decoder.

    In training mode the decoder runs with teacher forcing on the given
    targets; in eval mode it runs beam search.

    Args:
        in_channels: channel size of the encoder feature sequence.
        out_channels: number of output classes.
        sDim: decoder GRU hidden size.
        attDim: attention projection size.
        max_len_labels: maximum number of decoding steps in beam search.
        time_step: sequence length of the encoder features, used to size
            the flattened input of the embedding layer.
        beam_width: number of beams kept during beam search.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sDim,
                 attDim,
                 max_len_labels,
                 time_step=25,
                 beam_width=5,
                 **kwargs):
        super().__init__()
        self.num_classes = out_channels
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        # Sublayer creation order (decoder first, then embeder) is kept as-is
        # so parameter initialisation order does not change.
        self.decoder = AttentionRecognitionHead(
            in_channels, out_channels, sDim, attDim, max_len_labels)
        self.time_step = time_step
        self.embeder = Embedding(self.time_step, in_channels)
        self.beam_width = beam_width
        # NOTE(review): index of the end-of-sequence class; presumably matches
        # the label converter's class layout (num_classes - 3) -- confirm.
        self.eos = self.num_classes - 3

    def forward(self, x, targets=None, embed=None):
        """Run the head.

        Args:
            x: encoder feature sequence.
            targets: training only; a 3-item sequence whose first two items
                are the target ids and the target lengths.
            embed: unused; kept for interface compatibility.

        Returns:
            dict with 'rec_pred' and 'embedding_vectors', plus
            'rec_pred_scores' in eval mode.
        """
        vectors = self.embeder(x)

        if not self.training:
            pred, pred_scores = self.decoder.beam_search(
                x, self.beam_width, self.eos, vectors)
            pred_scores.stop_gradient = True
            pred.stop_gradient = True
            return {
                'rec_pred': pred,
                'rec_pred_scores': pred_scores,
                'embedding_vectors': vectors,
            }

        gt_ids, gt_lengths, _ = targets
        pred = self.decoder([x, gt_ids, gt_lengths], vectors)
        return {'rec_pred': pred, 'embedding_vectors': vectors}
class Embedding(nn.Layer):
    """Flatten each sample's encoder features and project them linearly.

    The linear layer expects ``in_timestep * in_planes`` input features, so
    each sample of ``x`` must flatten to that many values (presumably x is
    [b, in_timestep, in_planes] -- confirm against the encoder).

    Args:
        in_timestep: sequence length of the encoder features.
        in_planes: channel size of the encoder features.
        mid_dim: stored on the layer but not used by this class.
        embed_dim: size of the produced embedding vector.
    """

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super().__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        flat_size = in_timestep * in_planes
        # Embed encoder output to a word-embedding-like vector.
        self.eEmbed = nn.Linear(flat_size, self.embed_dim)

    def forward(self, x):
        """Return a [b, embed_dim] tensor computed from the flattened input."""
        num_samples = paddle.shape(x)[0]
        flattened = paddle.reshape(x, [num_samples, -1])
        return self.eEmbed(flattened)
class AttentionRecognitionHead(nn.Layer):
    """Attention GRU decoder with teacher-forced training and beam search.

    input: encoder feature sequence x of shape [b x T x in_planes]
    output (forward): per-step outputs of the decoder's final linear layer,
        [b x max(lengths) x num_classes] (no softmax applied here)
    output (beam_search): token ids of the best beam, [b x max_len_labels]
    """

    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
        super(AttentionRecognitionHead, self).__init__()
        self.num_classes = out_channels  # this is the output classes. So it includes the <EOS>.
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels

        self.decoder = DecoderUnit(
            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)

    def forward(self, x, embed):
        """Teacher-forced decoding (training).

        Args:
            x: list ``[feats, targets, lengths]``. ``feats`` is [b, T, C].
                ``targets[:, i - 1]`` is fed as the previous token at step i.
                Decoding runs for ``max(lengths)`` steps.
            embed: embedding vectors used to build the initial decoder state.

        Returns:
            Tensor [b, max(lengths), num_classes] of decoder outputs.
        """
        x, targets, lengths = x
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = self.decoder.get_initial_state(embed)
        outputs = []
        for i in range(max(lengths)):
            if i == 0:
                # Index num_classes is the extra <BOS> row of the decoder's
                # target embedding.
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = targets[:, i - 1]
            output, state = self.decoder(x, state, y_prev)
            # Restore the leading singleton axis of the state, as
            # get_initial_state and beam_search do. Without it the decoder's
            # squeeze(0) collapses a [1, sDim] state when batch size is 1.
            state = paddle.unsqueeze(state, axis=0)
            outputs.append(output)
        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs

    def beam_search(self, x, beam_width, eos, embed):
        """Beam-search decoding (inference).

        Args:
            x: encoder features [b, T, C].
            beam_width: number of beams k kept per sample.
            eos: class index treated as end-of-sequence.
            embed: embedding vectors used to build the initial decoder state.

        Returns:
            (q, ones): q is [b, max_len_labels] token ids of the top-ranked
            beam per sample. The second item is ``paddle.ones_like(q)``, a
            placeholder rather than real scores.
        """

        def _inflate(tensor, times, dim):
            # Tile `tensor` `times` times along axis `dim`.
            repeat_dims = [1] * tensor.dim()
            repeat_dims[dim] = times
            output = paddle.tile(tensor, repeat_dims)
            return output

        # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
        batch_size, l, d = paddle.shape(x)
        # [b, T, d] -> [1, b, T, d] -> [k, b, T, d] -> [b, k, T, d] -> [b*k, T, d]:
        # the k beams of sample i occupy rows [i*k, (i+1)*k).
        x = paddle.tile(
            paddle.transpose(
                x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
        inflated_encoder_feats = paddle.reshape(
            paddle.transpose(
                x, perm=[1, 0, 2, 3]), [-1, l, d])

        # Initialize the decoder
        state = self.decoder.get_initial_state(embed, tile_times=beam_width)

        # Offset of each sample's first beam in the flattened b*k axis.
        pos_index = paddle.reshape(
            paddle.arange(batch_size) * beam_width, shape=[-1, 1])

        # Initialize the scores: only beam 0 of each sample starts with score 0,
        # so the identical initial beams are not all expanded.
        sequence_scores = paddle.full(
            shape=[batch_size, beam_width], fill_value=-float('Inf'))
        sequence_scores[:, 0] = 0.0
        sequence_scores = paddle.reshape(
            sequence_scores, shape=[batch_size * beam_width, 1])

        # Initialize the input vector with the <BOS> index (num_classes).
        y_prev = paddle.full(
            shape=[batch_size * beam_width], fill_value=self.num_classes)

        # Store decisions for backtracking
        stored_scores = []
        stored_predecessors = []
        stored_emitted_symbols = []

        for i in range(self.max_len_labels):
            output, state = self.decoder(inflated_encoder_feats, state, y_prev)
            state = paddle.unsqueeze(state, axis=0)
            log_softmax_output = paddle.nn.functional.log_softmax(
                output, axis=1)

            sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
            sequence_scores += log_softmax_output
            # Pick the k best (beam, class) pairs per sample.
            scores, candidates = paddle.topk(
                paddle.reshape(sequence_scores, [batch_size, -1]),
                beam_width,
                axis=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            y_prev = paddle.reshape(
                candidates % self.num_classes, shape=[batch_size, beam_width])
            y_prev = paddle.reshape(y_prev, shape=[batch_size * beam_width])
            sequence_scores = paddle.reshape(
                scores, shape=[batch_size * beam_width, 1])

            # Update fields for next timestep: the source beam of each
            # candidate, as a row index into the flattened b*k axis.
            pos_index = paddle.expand(pos_index, paddle.shape(candidates))
            predecessors = paddle.cast(
                candidates / self.num_classes + pos_index, dtype='int64')
            predecessors = paddle.reshape(
                predecessors, shape=[batch_size * beam_width, 1])
            state = paddle.index_select(
                state, index=predecessors.squeeze(-1), axis=1)

            # Update sequence scores and erase scores for <eos> symbol so that
            # they aren't expanded. Only the beams that just emitted <eos> are
            # masked; all other beams, and other samples, keep their scores.
            stored_scores.append(sequence_scores.clone())
            y_prev = paddle.reshape(y_prev, shape=[-1, 1])
            eos_mask = paddle.equal(y_prev, paddle.full_like(y_prev, eos))
            sequence_scores = paddle.where(
                eos_mask,
                paddle.full_like(sequence_scores, -float('inf')),
                sequence_scores)

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            y_prev = paddle.squeeze(y_prev, axis=1)
            stored_emitted_symbols.append(y_prev)

        # ====== backtrack ======#
        # Walk the stored decisions from the last step to the first. Beams
        # that emitted <eos> earlier replace the lowest-ranked slots of their
        # sample.
        p = []

        # the last step output of the beams are not sorted
        # thus they are sorted here
        sorted_score, sorted_idx = paddle.topk(
            paddle.reshape(
                stored_scores[-1], shape=[batch_size, beam_width]),
            beam_width)

        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        batch_eos_found = paddle.zeros(
            [batch_size], dtype='int32')  # the number of EOS found
        # in the backward loop below for each batch

        t = self.max_len_labels - 1
        # initialize the back pointer with the sorted order of the last step beams.
        # add pos_index for indexing variable with b*k as the first dimension.
        t_predecessors = paddle.reshape(
            sorted_idx + pos_index.expand(paddle.shape(sorted_idx)),
            shape=[batch_size * beam_width])
        tmp_beam_width = beam_width
        while t >= 0:
            # Re-order the variables with the back pointer
            current_symbol = paddle.index_select(
                stored_emitted_symbols[t], index=t_predecessors, axis=0)
            t_predecessors = paddle.index_select(
                stored_predecessors[t].squeeze(-1),
                index=t_predecessors,
                axis=0)
            eos_indices = stored_emitted_symbols[t] == eos
            eos_indices = paddle.nonzero(eos_indices)

            stored_predecessors_t = stored_predecessors[t]
            stored_emitted_symbols_t = stored_emitted_symbols[t]
            stored_scores_t = stored_scores[t]

            if eos_indices.dim() > 0:
                for j in range(eos_indices.shape[0] - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[j]
                    b_idx = int(idx[0] / tmp_beam_width)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = tmp_beam_width - (batch_eos_found[b_idx] %
                                                  tmp_beam_width) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * tmp_beam_width + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    t_predecessors[res_idx] = stored_predecessors_t[idx[0]]
                    current_symbol[res_idx] = stored_emitted_symbols_t[idx[0]]
                    s[b_idx, res_k_idx] = stored_scores_t[idx[0], 0]

            # record the back tracked results
            p.append(current_symbol)
            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(beam_width)
        re_sorted_idx = paddle.reshape(
            re_sorted_idx + pos_index.expand(paddle.shape(re_sorted_idx)),
            [batch_size * beam_width])

        # Reverse the sequences and re-order at the same time
        # It is reversed because the backtracking happens in reverse time order
        reversed_p = p[::-1]

        q = []
        for step in reversed_p:
            q.append(
                paddle.reshape(
                    paddle.index_select(step, re_sorted_idx, 0),
                    shape=[batch_size, beam_width, -1]))

        # Keep only the top-ranked beam of each sample.
        q = paddle.concat(q, -1)[:, 0, :]
        return q, paddle.ones_like(q)
class AttentionUnit(nn.Layer):
    """Additive attention over an encoder feature sequence.

    Scores each time step as ``wEmbed(tanh(sEmbed(s) + xEmbed(x_t)))`` and
    normalises the scores with a softmax over time.

    Args:
        sDim: size of the decoder state.
        xDim: channel size of the encoder features.
        attDim: size of the shared attention projection space.
    """

    def __init__(self, sDim, xDim, attDim):
        super().__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.attDim = attDim

        # Creation order (sEmbed, xEmbed, wEmbed) is kept unchanged.
        self.sEmbed = nn.Linear(sDim, attDim)
        self.xEmbed = nn.Linear(xDim, attDim)
        self.wEmbed = nn.Linear(attDim, 1)

    def forward(self, x, sPrev):
        """Return attention weights of shape [b, T].

        Args:
            x: encoder features [b, T, xDim].
            sPrev: previous decoder state; axis 0 is squeezed before use.
        """
        n, steps, _ = x.shape

        # Feature projection: [b, T, xDim] -> [(b*T), attDim] -> [b, T, attDim]
        feat_proj = self.xEmbed(paddle.reshape(x, [-1, self.xDim]))
        feat_proj = paddle.reshape(feat_proj, [n, steps, -1])

        # State projection, broadcast over every time step: [b, T, attDim]
        state_proj = self.sEmbed(sPrev.squeeze(0))
        state_proj = paddle.expand(
            paddle.unsqueeze(state_proj, 1), [n, steps, self.attDim])

        energy = paddle.tanh(state_proj + feat_proj)
        logits = self.wEmbed(paddle.reshape(energy, [-1, self.attDim]))
        logits = paddle.reshape(logits, [n, steps])

        # attention weights for each sample in the minibatch
        return F.softmax(logits, axis=1)
class DecoderUnit(nn.Layer):
    """One decoding step: attention context + previous-token embedding -> GRU.

    Args:
        sDim: hidden size of the GRU decoder state.
        xDim: channel size of the encoder feature sequence.
        yDim: number of output classes (output size of ``fc``).
        attDim: attention projection size; also used as the target
            embedding size (``emdDim``).
    """

    def __init__(self, sDim, xDim, yDim, attDim):
        super(DecoderUnit, self).__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.yDim = yDim
        self.attDim = attDim
        self.emdDim = attDim

        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
        self.tgt_embedding = nn.Embedding(
            yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
                std=0.01))  # the last is used for <BOS>
        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
        self.fc = nn.Linear(
            sDim,
            yDim,
            weight_attr=nn.initializer.Normal(std=0.01),
            bias_attr=nn.initializer.Constant(value=0))
        # Maps the 300-d embedding vector to the initial decoder state.
        self.embed_fc = nn.Linear(300, self.sDim)

    def get_initial_state(self, embed, tile_times=1):
        """Build the initial decoder state from embedding vectors.

        Args:
            embed: tensor [N, 300].
            tile_times: repeat each row this many times consecutively, so the
                k beams of one sample share the same start state.

        Returns:
            Tensor [1, N * tile_times, sDim].
        """
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            # [N, sDim] -> [N, 1, sDim] -> [1, N, sDim] -> [tile, N, sDim]
            # -> [N, tile, sDim] -> [N * tile, sDim]
            state = state.unsqueeze(1)
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
            trans_state = paddle.transpose(state, perm=[1, 0, 2])
            state = paddle.reshape(trans_state, shape=[-1, self.sDim])
        state = state.unsqueeze(0)  # 1 * N * sDim
        return state

    def forward(self, x, sPrev, yPrev):
        """Advance the decoder by one step.

        Args:
            x: feature sequence from the image encoder, [b, T, xDim].
            sPrev: previous decoder state with a leading singleton axis,
                e.g. [1, b, sDim] as returned by ``get_initial_state``.
            yPrev: previous token ids [b]; cast to int64 before lookup.

        Returns:
            (output, state): ``output`` is the ``fc`` projection of the GRU
            output ([b, yDim]); ``state`` is the new GRU state.
        """
        alpha = self.attention_unit(x, sPrev)  # [b, T]
        # Attention-weighted sum of encoder features: [b, 1, T] @ [b, T, xDim]
        context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
        yPrev = paddle.cast(yPrev, dtype="int64")
        yProj = self.tgt_embedding(yPrev)
        concat_context = paddle.concat([yProj, context], 1)
        sPrev = paddle.squeeze(sPrev, 0)
        # Per the Paddle GRUCell API the output is already [b, sDim]. The
        # former squeeze(axis=1) was a no-op, except for sDim == 1, where it
        # dropped the feature axis and broke ``fc``.
        output, state = self.gru(concat_context, sPrev)
        output = self.fc(output)
        return output, state