[third-party] Fix the issue of inference errors with KE models in ONNX format (#14138)
* fix inference KIE model using onnx model * fix code style * fix onnx inputs compatiblility with det and rec * fix code stylepull/14147/head
parent
d3d7e85883
commit
58e876d38d
|
@ -40,6 +40,7 @@ logger = get_logger()
|
||||||
|
|
||||||
class SerPredictor(object):
|
class SerPredictor(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
|
self.args = args
|
||||||
self.ocr_engine = PaddleOCR(
|
self.ocr_engine = PaddleOCR(
|
||||||
use_angle_cls=args.use_angle_cls,
|
use_angle_cls=args.use_angle_cls,
|
||||||
det_model_dir=args.det_model_dir,
|
det_model_dir=args.det_model_dir,
|
||||||
|
@ -113,7 +114,12 @@ class SerPredictor(object):
|
||||||
data[idx] = np.expand_dims(data[idx], axis=0)
|
data[idx] = np.expand_dims(data[idx], axis=0)
|
||||||
else:
|
else:
|
||||||
data[idx] = [data[idx]]
|
data[idx] = [data[idx]]
|
||||||
|
if self.args.use_onnx:
|
||||||
|
input_tensor = {
|
||||||
|
name: data[idx] for idx, name in enumerate(self.input_tensor)
|
||||||
|
}
|
||||||
|
self.output_tensors = self.predictor.run(None, input_tensor)
|
||||||
|
else:
|
||||||
for idx in range(len(self.input_tensor)):
|
for idx in range(len(self.input_tensor)):
|
||||||
self.input_tensor[idx].copy_from_cpu(data[idx])
|
self.input_tensor[idx].copy_from_cpu(data[idx])
|
||||||
|
|
||||||
|
@ -121,7 +127,9 @@ class SerPredictor(object):
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for output_tensor in self.output_tensors:
|
for output_tensor in self.output_tensors:
|
||||||
output = output_tensor.copy_to_cpu()
|
output = (
|
||||||
|
output_tensor if self.args.use_onnx else output_tensor.copy_to_cpu()
|
||||||
|
)
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
preds = outputs[0]
|
preds = outputs[0]
|
||||||
|
|
||||||
|
|
|
@ -221,7 +221,13 @@ def create_predictor(args, mode, logger):
|
||||||
providers=["CPUExecutionProvider"],
|
providers=["CPUExecutionProvider"],
|
||||||
sess_options=sess_options,
|
sess_options=sess_options,
|
||||||
)
|
)
|
||||||
return sess, sess.get_inputs()[0], None, None
|
inputs = sess.get_inputs()
|
||||||
|
return (
|
||||||
|
sess,
|
||||||
|
inputs[0] if len(inputs) == 1 else [vo.name for vo in inputs],
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
file_names = ["model", "inference"]
|
file_names = ["model", "inference"]
|
||||||
|
|
Loading…
Reference in New Issue