添加RFL CNT分支infer支持
parent
3f8602c169
commit
c459b72565
|
@ -103,7 +103,6 @@ class RFLHead(nn.Layer):
|
|||
else:
|
||||
seq_outputs = self.seq_head(seq_inputs, None,
|
||||
self.batch_max_legnth)
|
||||
return cnt_outputs, seq_outputs
|
||||
else:
|
||||
seq_outputs = None
|
||||
|
||||
return cnt_outputs, seq_outputs
|
||||
return cnt_outputs
|
||||
|
|
|
@ -287,9 +287,8 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
|||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
cnt_pred, preds = preds
|
||||
if preds is not None:
|
||||
|
||||
if len(preds) == 2:
|
||||
cnt_pred, preds = preds
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
|
@ -302,9 +301,12 @@ class RFLLabelDecode(BaseRecLabelDecode):
|
|||
return text, label
|
||||
|
||||
else:
|
||||
cnt_pred = preds
|
||||
if isinstance(cnt_pred, paddle.Tensor):
|
||||
cnt_pred = cnt_pred.numpy()
|
||||
cnt_length = []
|
||||
for lens in cnt_pred:
|
||||
length = round(paddle.sum(lens).item())
|
||||
length = round(np.sum(lens))
|
||||
cnt_length.append(length)
|
||||
if label is None:
|
||||
return cnt_length
|
||||
|
|
|
@ -97,7 +97,8 @@ def main():
|
|||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
|
||||
elif config['Architecture']['algorithm'] == "RobustScanner":
|
||||
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
|
||||
op[op_name][
|
||||
'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
|
||||
else:
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
transforms.append(op)
|
||||
|
@ -136,9 +137,10 @@ def main():
|
|||
if config['Architecture']['algorithm'] == "RobustScanner":
|
||||
valid_ratio = np.expand_dims(batch[1], axis=0)
|
||||
word_positons = np.expand_dims(batch[2], axis=0)
|
||||
img_metas = [paddle.to_tensor(valid_ratio),
|
||||
paddle.to_tensor(word_positons),
|
||||
]
|
||||
img_metas = [
|
||||
paddle.to_tensor(valid_ratio),
|
||||
paddle.to_tensor(word_positons),
|
||||
]
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
|
@ -160,6 +162,10 @@ def main():
|
|||
"score": float(post_result[key][0][1]),
|
||||
}
|
||||
info = json.dumps(rec_info, ensure_ascii=False)
|
||||
elif isinstance(post_result, list) and isinstance(post_result[0],
|
||||
int):
|
||||
# for RFLearning CNT branch
|
||||
info = str(post_result[0])
|
||||
else:
|
||||
if len(post_result[0]) >= 2:
|
||||
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
||||
|
|
Loading…
Reference in New Issue