mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Keep E2E Inferencer output simple (#1559)
parent
79a4b2042c
commit
f4940de2a4
|
@ -244,7 +244,7 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
if not get_datasample:
|
||||
results = []
|
||||
for pred in preds:
|
||||
result = self._pred2dict(pred)
|
||||
result = self.pred2dict(pred)
|
||||
results.append(result)
|
||||
if not is_batch:
|
||||
results = results[0]
|
||||
|
@ -257,7 +257,7 @@ class BaseMMOCRInferencer(BaseInferencer):
|
|||
return results
|
||||
return results, imgs
|
||||
|
||||
def _pred2dict(self, data_sample: InstanceData) -> Dict:
|
||||
def pred2dict(self, data_sample: InstanceData) -> Dict:
|
||||
"""Extract elements necessary to represent a prediction into a
|
||||
dictionary.
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ class KIEInferencer(BaseMMOCRInferencer):
|
|||
|
||||
return results
|
||||
|
||||
def _pred2dict(self, data_sample: KIEDataSample) -> Dict:
|
||||
def pred2dict(self, data_sample: KIEDataSample) -> Dict:
|
||||
"""Extract elements necessary to represent a prediction into a
|
||||
dictionary. It's better to contain only basic data elements such as
|
||||
strings and numbers in order to guarantee it's json-serializable.
|
||||
|
|
|
@ -185,26 +185,27 @@ class MMOCRInferencer(BaseMMOCRInferencer):
|
|||
for i, rec_pred in enumerate(preds['rec']):
|
||||
result = dict(rec_texts=[], rec_scores=[])
|
||||
for rec_pred_instance in rec_pred:
|
||||
pred = rec_pred_instance.pred_text
|
||||
result['rec_texts'].append(pred.item)
|
||||
result['rec_scores'].append(pred.score)
|
||||
rec_dict_res = self.textrec_inferencer.pred2dict(
|
||||
rec_pred_instance)
|
||||
result['rec_texts'].append(rec_dict_res['text'])
|
||||
result['rec_scores'].append(rec_dict_res['scores'])
|
||||
results[i].update(result)
|
||||
if 'det' in self.mode:
|
||||
for i, det_pred in enumerate(preds['det']):
|
||||
det_pred_instances = det_pred.pred_instances
|
||||
det_dict_res = self.textdet_inferencer.pred2dict(det_pred)
|
||||
results[i].update(
|
||||
dict(
|
||||
det_polygons=det_pred_instances['polygons'],
|
||||
det_scores=det_pred_instances['scores']))
|
||||
det_polygons=det_dict_res['polygons'],
|
||||
det_scores=det_dict_res['scores']))
|
||||
if 'kie' in self.mode:
|
||||
for i, kie_pred in enumerate(preds['kie']):
|
||||
kie_pred_instances = kie_pred.pred_instances
|
||||
kie_dict_res = self.kie_inferencer.pred2dict(kie_pred)
|
||||
results[i].update(
|
||||
dict(
|
||||
kie_labels=kie_pred_instances['labels'],
|
||||
kie_scores=kie_pred_instances['scores']),
|
||||
kie_edge_scores=kie_pred_instances['edge_scores'],
|
||||
kie_edge_labels=kie_pred_instances['edge_labels'])
|
||||
kie_labels=kie_dict_res['labels'],
|
||||
kie_scores=kie_dict_res['scores']),
|
||||
kie_edge_scores=kie_dict_res['edge_scores'],
|
||||
kie_edge_labels=kie_dict_res['edge_labels'])
|
||||
|
||||
if not is_batch:
|
||||
results = results[0]
|
||||
|
|
|
@ -7,7 +7,7 @@ from .base_mmocr_inferencer import BaseMMOCRInferencer
|
|||
|
||||
class TextDetInferencer(BaseMMOCRInferencer):
|
||||
|
||||
def _pred2dict(self, data_sample: TextDetDataSample) -> Dict:
|
||||
def pred2dict(self, data_sample: TextDetDataSample) -> Dict:
|
||||
"""Extract elements necessary to represent a prediction into a
|
||||
dictionary. It's better to contain only basic data elements such as
|
||||
strings and numbers in order to guarantee it's json-serializable.
|
||||
|
|
|
@ -9,7 +9,7 @@ from .base_mmocr_inferencer import BaseMMOCRInferencer
|
|||
|
||||
class TextRecInferencer(BaseMMOCRInferencer):
|
||||
|
||||
def _pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
|
||||
def pred2dict(self, data_sample: TextRecogDataSample) -> Dict:
|
||||
"""Extract elements necessary to represent a prediction into a
|
||||
dictionary. It's better to contain only basic data elements such as
|
||||
strings and numbers in order to guarantee it's json-serializable.
|
||||
|
|
Loading…
Reference in New Issue