[Fix] Keep E2E Inferencer output simple (#1559)

pull/1567/head
Tong Gao 2022-12-06 16:47:31 +08:00 committed by GitHub
parent 79a4b2042c
commit f4940de2a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 17 additions and 16 deletions

View File

@ -244,7 +244,7 @@ class BaseMMOCRInferencer(BaseInferencer):
if not get_datasample:
results = []
for pred in preds:
result = self._pred2dict(pred)
result = self.pred2dict(pred)
results.append(result)
if not is_batch:
results = results[0]
@ -257,7 +257,7 @@ class BaseMMOCRInferencer(BaseInferencer):
return results
return results, imgs
def _pred2dict(self, data_sample: InstanceData) -> Dict:
def pred2dict(self, data_sample: InstanceData) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary.

View File

@ -174,7 +174,7 @@ class KIEInferencer(BaseMMOCRInferencer):
return results
def _pred2dict(self, data_sample: KIEDataSample) -> Dict:
def pred2dict(self, data_sample: KIEDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.

View File

@ -185,26 +185,27 @@ class MMOCRInferencer(BaseMMOCRInferencer):
for i, rec_pred in enumerate(preds['rec']):
result = dict(rec_texts=[], rec_scores=[])
for rec_pred_instance in rec_pred:
pred = rec_pred_instance.pred_text
result['rec_texts'].append(pred.item)
result['rec_scores'].append(pred.score)
rec_dict_res = self.textrec_inferencer.pred2dict(
rec_pred_instance)
result['rec_texts'].append(rec_dict_res['text'])
result['rec_scores'].append(rec_dict_res['scores'])
results[i].update(result)
if 'det' in self.mode:
for i, det_pred in enumerate(preds['det']):
det_pred_instances = det_pred.pred_instances
det_dict_res = self.textdet_inferencer.pred2dict(det_pred)
results[i].update(
dict(
det_polygons=det_pred_instances['polygons'],
det_scores=det_pred_instances['scores']))
det_polygons=det_dict_res['polygons'],
det_scores=det_dict_res['scores']))
if 'kie' in self.mode:
for i, kie_pred in enumerate(preds['kie']):
kie_pred_instances = kie_pred.pred_instances
kie_dict_res = self.kie_inferencer.pred2dict(kie_pred)
results[i].update(
dict(
kie_labels=kie_pred_instances['labels'],
kie_scores=kie_pred_instances['scores']),
kie_edge_scores=kie_pred_instances['edge_scores'],
kie_edge_labels=kie_pred_instances['edge_labels'])
kie_labels=kie_dict_res['labels'],
kie_scores=kie_dict_res['scores']),
kie_edge_scores=kie_dict_res['edge_scores'],
kie_edge_labels=kie_dict_res['edge_labels'])
if not is_batch:
results = results[0]

View File

@ -7,7 +7,7 @@ from .base_mmocr_inferencer import BaseMMOCRInferencer
class TextDetInferencer(BaseMMOCRInferencer):
def _pred2dict(self, data_sample: TextDetDataSample) -> Dict:
def pred2dict(self, data_sample: TextDetDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.

View File

@ -9,7 +9,7 @@ from .base_mmocr_inferencer import BaseMMOCRInferencer
class TextRecInferencer(BaseMMOCRInferencer):
def _pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
def pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
"""Extract elements necessary to represent a prediction into a
dictionary. It's better to contain only basic data elements such as
strings and numbers in order to guarantee it's json-serializable.