Fix seed export (#7502)

* fix seed export
* update doc for seed
* add stop_gradient for ones_like

parent 8aa172d803
commit 6f677d609c
@@ -50,7 +50,7 @@ Architecture:
     name: AsterHead # AttentionHead
     sDim: 512
     attDim: 512
-    max_len_labels: 100
+    max_len_labels: 20

 Loss:
   name: AsterLoss
@@ -78,23 +78,34 @@ python3 tools/infer_rec.py -c configs/rec/rec_resnet_stn_bilstm_att.yml -o Globa
 <a name="4-1"></a>
 ### 4.1 Python Inference

-coming soon
+First, convert the model saved during SEED text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar)). The conversion can be done with the following command:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_resnet_stn_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=seed_infer
+```
+
+For SEED text recognition model inference, run the following command:
+
+```
+python3 tools/infer/predict_rec.py --rec_model_dir=seed_infer --image_dir=doc/imgs_words_en/word_10.png --rec_algorithm="SEED" --rec_char_dict_path=ppocr/utils/EN_symbol_dict.txt --rec_image_shape="3,64,256" --use_space_char=False
+```

 <a name="4-2"></a>
 ### 4.2 C++ Inference

-coming soon
+Not supported yet

 <a name="4-3"></a>
 ### 4.3 Serving Deployment

-coming soon
+Not supported yet

 <a name="4-4"></a>
 ### 4.4 More Inference Deployment Options

-coming soon
+Not supported yet

 <a name="5"></a>
 ## 5. FAQ

@@ -77,7 +77,17 @@ python3 tools/infer_rec.py -c configs/rec/rec_resnet_stn_bilstm_att.yml -o Globa
 <a name="4-1"></a>
 ### 4.1 Python Inference

-Not support
+First, convert the model saved during SEED text recognition training into an inference model ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar)). The conversion can be done with the following command:
+
+```
+python3 tools/export_model.py -c configs/rec/rec_resnet_stn_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=seed_infer
+```
+
+For SEED text recognition model inference, run the following command:
+
+```
+python3 tools/infer/predict_rec.py --rec_model_dir=seed_infer --image_dir=doc/imgs_words_en/word_10.png --rec_algorithm="SEED" --rec_char_dict_path=ppocr/utils/EN_symbol_dict.txt --rec_image_shape="3,64,256" --use_space_char=False
+```

 <a name="4-2"></a>
 ### 4.2 C++ Inference
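For readers checking the export end-to-end: a minimal sketch (not part of this commit) of loading the exported `seed_infer` model with Paddle's inference API, assuming the default file layout written by `tools/export_model.py`:

```
from paddle.inference import Config, create_predictor

# Default export layout from tools/export_model.py (assumption for this sketch).
config = Config("seed_infer/inference.pdmodel", "seed_infer/inference.pdiparams")
config.disable_gpu()  # CPU is enough to sanity-check the export
predictor = create_predictor(config)

input_name = predictor.get_input_names()[0]
input_handle = predictor.get_input_handle(input_name)  # expects NCHW 3x64x256 input
```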
@@ -62,10 +62,11 @@ class AsterHead(nn.Layer):
         else:
             rec_pred, rec_pred_scores = self.decoder.beam_search(
                 x, self.beam_width, self.eos, embedding_vectors)
+            rec_pred_scores.stop_gradient = True
+            rec_pred.stop_gradient = True
             return_dict['rec_pred'] = rec_pred
             return_dict['rec_pred_scores'] = rec_pred_scores
             return_dict['embedding_vectors'] = embedding_vectors

         return return_dict
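The two added `stop_gradient` lines are the heart of the export fix: beam-search outputs are inference-only, so marking them non-differentiable keeps dygraph-to-static export from tracing a backward graph through the search. A standalone sketch of the pattern, with stand-in tensors:

```
import paddle

rec_pred = paddle.randint(0, 97, shape=[4, 100])  # stand-in for beam-search ids
rec_pred_scores = paddle.rand([4, 100])           # stand-in for beam-search scores

# Mark both as non-differentiable so the exporter does not try to build a
# backward graph through the beam search.
rec_pred.stop_gradient = True
rec_pred_scores.stop_gradient = True
```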
@@ -114,37 +115,13 @@ class AttentionRecognitionHead(nn.Layer):
                 y_prev = paddle.full(
                     shape=[batch_size], fill_value=self.num_classes)
             else:
                 y_prev = targets[:, i - 1]
             output, state = self.decoder(x, state, y_prev)
             outputs.append(output)
         outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
         return outputs

-    # inference stage.
-    def sample(self, x):
-        x, _, _ = x
-        batch_size = x.size(0)
-        # Decoder
-        state = paddle.zeros([1, batch_size, self.sDim])
-
-        predicted_ids, predicted_scores = [], []
-        for i in range(self.max_len_labels):
-            if i == 0:
-                y_prev = paddle.full(
-                    shape=[batch_size], fill_value=self.num_classes)
-            else:
-                y_prev = predicted
-
-            output, state = self.decoder(x, state, y_prev)
-            output = F.softmax(output, axis=1)
-            score, predicted = output.max(1)
-            predicted_ids.append(predicted.unsqueeze(1))
-            predicted_scores.append(score.unsqueeze(1))
-        predicted_ids = paddle.concat([predicted_ids, 1])
-        predicted_scores = paddle.concat([predicted_scores, 1])
-        # return predicted_ids.squeeze(), predicted_scores.squeeze()
-        return predicted_ids, predicted_scores
-
     def beam_search(self, x, beam_width, eos, embed):
         def _inflate(tensor, times, dim):
             repeat_dims = [1] * tensor.dim()
@@ -153,7 +130,7 @@ class AttentionRecognitionHead(nn.Layer):
             return output

         # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
-        batch_size, l, d = x.shape
+        batch_size, l, d = paddle.shape(x)
         x = paddle.tile(
             paddle.transpose(
                 x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
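The swap from `x.shape` to `paddle.shape(x)` is an export requirement: under `paddle.jit.to_static`, `x.shape` is evaluated at trace time and frozen to Python ints, while `paddle.shape(x)` stays a tensor and keeps dimensions dynamic in the exported graph. A standalone illustration with a hypothetical tensor:

```
import paddle

x = paddle.rand([2, 25, 512])
traced_dims = x.shape           # Python ints; frozen into the graph by to_static
runtime_dims = paddle.shape(x)  # a 1-D int tensor; resolved at run time
```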
@@ -166,21 +143,22 @@ class AttentionRecognitionHead(nn.Layer):

         pos_index = paddle.reshape(
             paddle.arange(batch_size) * beam_width, shape=[-1, 1])

         # Initialize the scores
         sequence_scores = paddle.full(
-            shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
-        index = [i * beam_width for i in range(0, batch_size)]
-        sequence_scores[index] = 0.0
+            shape=[batch_size, beam_width], fill_value=-float('Inf'))
+        sequence_scores[:, 0] = 0.0
+        sequence_scores = paddle.reshape(
+            sequence_scores, shape=[batch_size * beam_width, 1])

         # Initialize the input vector
         y_prev = paddle.full(
             shape=[batch_size * beam_width], fill_value=self.num_classes)

         # Store decisions for backtracking
-        stored_scores = list()
-        stored_predecessors = list()
-        stored_emitted_symbols = list()
+        stored_scores = []
+        stored_predecessors = []
+        stored_emitted_symbols = []

         for i in range(self.max_len_labels):
             output, state = self.decoder(inflated_encoder_feats, state, y_prev)
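The reworked score initialization replaces indexing with a Python list (`sequence_scores[index] = 0.0`), which the static exporter cannot trace, by a slice assignment plus reshape. A standalone sketch with small hypothetical sizes:

```
import paddle

batch_size, beam_width = 2, 5
scores = paddle.full([batch_size, beam_width], -float('Inf'))
scores[:, 0] = 0.0  # only the first beam of each batch item starts alive
scores = paddle.reshape(scores, [batch_size * beam_width, 1])
```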
@@ -194,15 +172,16 @@ class AttentionRecognitionHead(nn.Layer):
                 paddle.reshape(sequence_scores, [batch_size, -1]),
                 beam_width,
                 axis=1)

             # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
             y_prev = paddle.reshape(
-                candidates % self.num_classes, shape=[batch_size * beam_width])
+                candidates % self.num_classes, shape=[batch_size, beam_width])
+            y_prev = paddle.reshape(y_prev, shape=[batch_size * beam_width])
             sequence_scores = paddle.reshape(
                 scores, shape=[batch_size * beam_width, 1])

             # Update fields for next timestep
-            pos_index = paddle.expand_as(pos_index, candidates)
+            pos_index = paddle.expand(pos_index, paddle.shape(candidates))

             predecessors = paddle.cast(
                 candidates / self.num_classes + pos_index, dtype='int64')
             predecessors = paddle.reshape(
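`expand_as` is replaced by `paddle.expand` with a runtime `paddle.shape(...)` target, which keeps the broadcast size dynamic in the exported graph. A standalone sketch with hypothetical sizes:

```
import paddle

batch_size, beam_width = 2, 5
pos_index = paddle.reshape(paddle.arange(batch_size) * beam_width, [-1, 1])
candidates = paddle.zeros([batch_size, beam_width], dtype='int64')
expanded = paddle.expand(pos_index, paddle.shape(candidates))  # shape [2, 5]
```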
@@ -213,13 +192,13 @@ class AttentionRecognitionHead(nn.Layer):
             # Update sequence scores and erase scores for <eos> symbol so that they aren't expanded
             stored_scores.append(sequence_scores.clone())
             y_prev = paddle.reshape(y_prev, shape=[-1, 1])
-            eos_prev = paddle.full_like(y_prev, fill_value=eos)
+            eos_prev = paddle.full(paddle.shape(y_prev), fill_value=eos)
             mask = eos_prev == y_prev
             mask = paddle.cast(mask, 'int64')
             mask = paddle.nonzero(mask)
-            if mask.dim() > 0:
-                sequence_scores = sequence_scores.numpy()
-                mask = mask.numpy()
-                sequence_scores[mask] = -float('inf')
+            sequence_scores = sequence_scores.numpy()
+            mask = mask.numpy()
+            if len(mask) > 0:
+                sequence_scores[mask] = -float('inf')
             sequence_scores = paddle.to_tensor(sequence_scores)

             # Cache results for backtracking
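Finished-beam masking now goes through a short numpy round-trip guarded by `len(mask)` instead of `mask.dim()`. A standalone sketch of the same pattern with hypothetical values:

```
import paddle

scores = paddle.rand([6, 1]).numpy()
finished = paddle.nonzero(paddle.to_tensor([0, 1, 0, 0, 1, 0])).numpy()
if len(finished) > 0:
    scores[finished] = -float('inf')  # erase scores of beams that emitted <eos>
scores = paddle.to_tensor(scores)
```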
@@ -228,11 +207,12 @@ class AttentionRecognitionHead(nn.Layer):
             stored_emitted_symbols.append(y_prev)

         # Do backtracking to return the optimal values
-        #====== backtrak ======#
+        # ====== backtrack ======#
         # Initialize return variables given different types
-        p = list()
-        l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
-             ]  # Placeholder for lengths of top-k sequences
+        p = []
+
+        # Placeholder for lengths of top-k sequences
+        l = paddle.full([batch_size, beam_width], self.max_len_labels)

         # the last step outputs of the beams are not sorted
         # thus they are sorted here
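Turning the nested Python list `l` into a tensor keeps the later backtracking writes (`l[b_idx][res_k_idx] = t_plus`) as plain tensor ops the exporter can handle. A standalone sketch:

```
import paddle

batch_size, beam_width, max_len = 2, 5, 100
l = paddle.full([batch_size, beam_width], max_len)
l[0, 3] = 7  # in-place tensor write, mirroring the backtracking update above
```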
@@ -244,14 +224,18 @@ class AttentionRecognitionHead(nn.Layer):
         # initialize the sequence scores with the sorted last step beam scores
         s = sorted_score.clone()

-        batch_eos_found = [0] * batch_size  # the number of EOS found
+        batch_eos_found = paddle.zeros(
+            [batch_size], dtype='int32')  # the number of EOS found
         # in the backward loop below for each batch
         t = self.max_len_labels - 1

         # initialize the back pointer with the sorted order of the last step beams.
         # add pos_index for indexing variable with b*k as the first dimension.
         t_predecessors = paddle.reshape(
-            sorted_idx + pos_index.expand_as(sorted_idx),
+            sorted_idx + pos_index.expand(paddle.shape(sorted_idx)),
             shape=[batch_size * beam_width])

+        tmp_beam_width = beam_width
         while t >= 0:
             # Re-order the variables with the back pointer
             current_symbol = paddle.index_select(
@@ -261,26 +245,32 @@ class AttentionRecognitionHead(nn.Layer):
             eos_indices = stored_emitted_symbols[t] == eos
             eos_indices = paddle.nonzero(eos_indices)

+            stored_predecessors_t = stored_predecessors[t]
+            stored_emitted_symbols_t = stored_emitted_symbols[t]
+            stored_scores_t = stored_scores[t]
+            t_plus = t + 1
+
             if eos_indices.dim() > 0:
-                for i in range(eos_indices.shape[0] - 1, -1, -1):
+                for j in range(eos_indices.shape[0] - 1, -1, -1):
                     # Indices of the EOS symbol for both variables
                     # with b*k as the first dimension, and b, k for
                     # the first two dimensions
-                    idx = eos_indices[i]
-                    b_idx = int(idx[0] / beam_width)
+                    idx = eos_indices[j]
+                    b_idx = int(idx[0] / tmp_beam_width)
                     # The indices of the replacing position
                     # according to the replacement strategy noted above
-                    res_k_idx = beam_width - (batch_eos_found[b_idx] %
-                                              beam_width) - 1
+                    res_k_idx = tmp_beam_width - (batch_eos_found[b_idx] %
+                                                  tmp_beam_width) - 1
                     batch_eos_found[b_idx] += 1
-                    res_idx = b_idx * beam_width + res_k_idx
+                    res_idx = b_idx * tmp_beam_width + res_k_idx

                     # Replace the old information in return variables
                     # with the new ended sequence information
-                    t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
-                    current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
-                    s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
-                    l[b_idx][res_k_idx] = t + 1
+                    t_predecessors[res_idx] = stored_predecessors_t[idx[0]]
+                    current_symbol[res_idx] = stored_emitted_symbols_t[idx[0]]
+                    s[b_idx, res_k_idx] = stored_scores_t[idx[0], 0]
+                    l[b_idx][res_k_idx] = t_plus

             # record the back tracked results
             p.append(current_symbol)
@@ -289,24 +279,30 @@ class AttentionRecognitionHead(nn.Layer):
             # Sort and re-order again as the added ended sequences may change
             # the order (very unlikely)
             s, re_sorted_idx = s.topk(beam_width)
+
             for b_idx in range(batch_size):
-                l[b_idx] = [
-                    l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
-                ]
+                tmp_tensor = paddle.full_like(l[b_idx], 0)
+                for k_idx in re_sorted_idx[b_idx]:
+                    tmp_tensor[k_idx] = l[b_idx][k_idx]
+                l[b_idx] = tmp_tensor

             re_sorted_idx = paddle.reshape(
-                re_sorted_idx + pos_index.expand_as(re_sorted_idx),
+                re_sorted_idx + pos_index.expand(paddle.shape(re_sorted_idx)),
                 [batch_size * beam_width])

             # Reverse the sequences and re-order at the same time
             # It is reversed because the backtracking happens in reverse time order
-            p = [
-                paddle.reshape(
-                    paddle.index_select(step, re_sorted_idx, 0),
-                    shape=[batch_size, beam_width, -1]) for step in reversed(p)
-            ]
-            p = paddle.concat(p, -1)[:, 0, :]
-            return p, paddle.ones_like(p)
+            reversed_p = p[::-1]
+
+            q = []
+            for step in reversed_p:
+                q.append(
+                    paddle.reshape(
+                        paddle.index_select(step, re_sorted_idx, 0),
+                        shape=[batch_size, beam_width, -1]))
+
+            q = paddle.concat(q, -1)[:, 0, :]
+            return q, paddle.ones_like(q)


 class AttentionUnit(nn.Layer):
@@ -385,8 +381,8 @@ class DecoderUnit(nn.Layer):
         yProj = self.tgt_embedding(yPrev)

         concat_context = paddle.concat([yProj, context], 1)
         concat_context = paddle.squeeze(concat_context, 1)
         sPrev = paddle.squeeze(sPrev, 0)

         output, state = self.gru(concat_context, sPrev)
         output = paddle.squeeze(output, axis=1)
         output = self.fc(output)
@@ -307,6 +307,11 @@ class SEEDLabelDecode(BaseRecLabelDecode):
             label = self.decode(label, is_remove_duplicate=False)
             return text, label
         """
+        tmp = {}
+        if isinstance(preds, list):
+            tmp["rec_pred"] = preds[1]
+            tmp["rec_pred_scores"] = preds[0]
+            preds = tmp
         preds_idx = preds["rec_pred"]
         if isinstance(preds_idx, paddle.Tensor):
             preds_idx = preds_idx.numpy()
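The added branch adapts the exported predictor's list output to the dict layout the decoder already expects. A standalone sketch with hypothetical arrays:

```
import numpy as np

# Hypothetical predictor outputs: the exported SEED model returns [scores, ids].
preds = [np.random.rand(1, 100), np.random.randint(0, 97, (1, 100))]
if isinstance(preds, list):
    preds = {"rec_pred": preds[1], "rec_pred_scores": preds[0]}
```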
@@ -97,7 +97,6 @@ def export_single_model(model,
             paddle.static.InputSpec(
                 shape=[None, 3, 32, 128], dtype="float32"),
         ]
-        # print([None, 3, 32, 128])
         model = to_static(model, input_spec=other_shape)
     elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
         other_shape = [
@@ -115,18 +114,20 @@ def export_single_model(model,
         max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 3, 48, 160], dtype="float32"),
-            [
+                shape=[None, 3, 48, 160], dtype="float32"), [
                 paddle.static.InputSpec(
-                    shape=[None, ],
-                    dtype="float32"),
+                    shape=[None, ], dtype="float32"),
                 paddle.static.InputSpec(
-                    shape=[None, max_text_length],
-                    dtype="int64")
+                    shape=[None, max_text_length], dtype="int64")
             ]
         ]
         model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "SEED":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 3, 64, 256], dtype="float32")
+        ]
+        model = to_static(model, input_spec=other_shape)
     elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
         input_spec = [
             paddle.static.InputSpec(
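The new SEED branch pins the export signature to the 3x64x256 input the model is trained with, leaving only the batch dimension dynamic. A standalone sketch of the same call (the model itself is built elsewhere):

```
import paddle
from paddle.jit import to_static

input_spec = [
    paddle.static.InputSpec(
        shape=[None, 3, 64, 256], dtype="float32")  # None keeps batch size dynamic
]
# model = to_static(model, input_spec=input_spec)
```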
@@ -100,6 +100,12 @@ class TextRecognizer(object):
                 "use_space_char": args.use_space_char,
                 "rm_symbol": True
             }
+        elif self.rec_algorithm == "SEED":
+            postprocess_params = {
+                'name': 'SEEDLabelDecode',
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -161,6 +167,7 @@ class TextRecognizer(object):
         if resized_w > self.rec_image_shape[2]:
             resized_w = self.rec_image_shape[2]
+            imgW = self.rec_image_shape[2]

         resized_image = cv2.resize(img, (resized_w, imgH))
         resized_image = resized_image.astype('float32')
         resized_image = resized_image.transpose((2, 0, 1)) / 255
@@ -398,6 +405,11 @@ class TextRecognizer(object):
                     img_list[indices[ino]], self.rec_image_shape)
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
+            elif self.rec_algorithm == "SEED":
+                norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                     self.rec_image_shape)
+                norm_img = norm_img[np.newaxis, :]
+                norm_img_batch.append(norm_img)
             elif self.rec_algorithm == "RobustScanner":
                 norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                     img_list[indices[ino]],
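SEED reuses the SVTR-style resize, which rescales to a fixed shape and normalizes to [-1, 1]. A hedged approximation of that preprocessing (the authoritative version is `resize_norm_img_svtr` in PaddleOCR; this sketch may differ in details):

```
import cv2
import numpy as np

def resize_norm_img_svtr_like(img, image_shape=(3, 64, 256)):
    _, h, w = image_shape
    resized = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    resized = resized.astype('float32').transpose((2, 0, 1)) / 255.0
    return (resized - 0.5) / 0.5  # normalize to [-1, 1]
```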
@@ -75,7 +75,6 @@ def main():
                 'out_channels_list'] = out_channels_list
     else:  # base rec model
         config['Architecture']["Head"]['out_channels'] = char_num

     model = build_model(config['Architecture'])
-
     load_model(config, model)
@@ -97,7 +96,8 @@ def main():
         elif config['Architecture']['algorithm'] == "SAR":
             op[op_name]['keep_keys'] = ['image', 'valid_ratio']
         elif config['Architecture']['algorithm'] == "RobustScanner":
-            op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
+            op[op_name][
+                'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
         else:
             op[op_name]['keep_keys'] = ['image']
         transforms.append(op)
@@ -136,7 +136,8 @@ def main():
     if config['Architecture']['algorithm'] == "RobustScanner":
         valid_ratio = np.expand_dims(batch[1], axis=0)
         word_positons = np.expand_dims(batch[2], axis=0)
-        img_metas = [paddle.to_tensor(valid_ratio),
+        img_metas = [
+            paddle.to_tensor(valid_ratio),
             paddle.to_tensor(word_positons),
         ]
     images = np.expand_dims(batch[0], axis=0)