添加RFL CNT分支infer支持
parent
3f8602c169
commit
c459b72565
|
@ -103,7 +103,6 @@ class RFLHead(nn.Layer):
|
||||||
else:
|
else:
|
||||||
seq_outputs = self.seq_head(seq_inputs, None,
|
seq_outputs = self.seq_head(seq_inputs, None,
|
||||||
self.batch_max_legnth)
|
self.batch_max_legnth)
|
||||||
|
return cnt_outputs, seq_outputs
|
||||||
else:
|
else:
|
||||||
seq_outputs = None
|
return cnt_outputs
|
||||||
|
|
||||||
return cnt_outputs, seq_outputs
|
|
||||||
|
|
|
@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
return result_list
|
return result_list
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
cnt_pred, preds = preds
|
if len(preds) == 2:
|
||||||
if preds is not None:
|
cnt_pred, preds = preds
|
||||||
|
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
preds_idx = preds.argmax(axis=2)
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
||||||
return text, label
|
return text, label
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
cnt_pred = preds
|
||||||
|
if isinstance(cnt_pred, paddle.Tensor):
|
||||||
|
cnt_pred = cnt_pred.numpy()
|
||||||
cnt_length = []
|
cnt_length = []
|
||||||
for lens in cnt_pred:
|
for lens in cnt_pred:
|
||||||
length = round(paddle.sum(lens).item())
|
length = round(np.sum(lens))
|
||||||
cnt_length.append(length)
|
cnt_length.append(length)
|
||||||
if label is None:
|
if label is None:
|
||||||
return cnt_length
|
return cnt_length
|
||||||
|
|
|
@ -97,7 +97,8 @@ def main():
|
||||||
elif config['Architecture']['algorithm'] == "SAR":
|
elif config['Architecture']['algorithm'] == "SAR":
|
||||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
||||||
elif config['Architecture']['algorithm'] == "RobustScanner":
|
elif config['Architecture']['algorithm'] == "RobustScanner":
|
||||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
|
op[op_name][
|
||||||
|
'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
|
||||||
else:
|
else:
|
||||||
op[op_name]['keep_keys'] = ['image']
|
op[op_name]['keep_keys'] = ['image']
|
||||||
transforms.append(op)
|
transforms.append(op)
|
||||||
|
@ -136,9 +137,10 @@ def main():
|
||||||
if config['Architecture']['algorithm'] == "RobustScanner":
|
if config['Architecture']['algorithm'] == "RobustScanner":
|
||||||
valid_ratio = np.expand_dims(batch[1], axis=0)
|
valid_ratio = np.expand_dims(batch[1], axis=0)
|
||||||
word_positons = np.expand_dims(batch[2], axis=0)
|
word_positons = np.expand_dims(batch[2], axis=0)
|
||||||
img_metas = [paddle.to_tensor(valid_ratio),
|
img_metas = [
|
||||||
paddle.to_tensor(word_positons),
|
paddle.to_tensor(valid_ratio),
|
||||||
]
|
paddle.to_tensor(word_positons),
|
||||||
|
]
|
||||||
images = np.expand_dims(batch[0], axis=0)
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
if config['Architecture']['algorithm'] == "SRN":
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
@ -160,6 +162,10 @@ def main():
|
||||||
"score": float(post_result[key][0][1]),
|
"score": float(post_result[key][0][1]),
|
||||||
}
|
}
|
||||||
info = json.dumps(rec_info, ensure_ascii=False)
|
info = json.dumps(rec_info, ensure_ascii=False)
|
||||||
|
elif isinstance(post_result, list) and isinstance(post_result[0],
|
||||||
|
int):
|
||||||
|
# for RFLearning CNT branch
|
||||||
|
info = str(post_result[0])
|
||||||
else:
|
else:
|
||||||
if len(post_result[0]) >= 2:
|
if len(post_result[0]) >= 2:
|
||||||
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
||||||
|
|
Loading…
Reference in New Issue