mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Save results to json file for kie. (#589)
* save results json for kie * update config * Fix KIE inference bug Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>pull/593/head
parent
f2b9ba93be
commit
98c5bff1e0
|
@ -25,7 +25,10 @@ test_pipeline = [
|
|||
dict(
|
||||
type='Collect',
|
||||
keys=['img', 'relations', 'texts', 'gt_bboxes'],
|
||||
meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename'])
|
||||
meta_keys=[
|
||||
'img_norm_cfg', 'img_shape', 'ori_filename', 'filename',
|
||||
'ori_texts'
|
||||
])
|
||||
]
|
||||
|
||||
dataset_type = 'KIEDataset'
|
||||
|
|
|
@ -137,6 +137,8 @@ def model_inference(model,
|
|||
img_prefix=None,
|
||||
ann_info=ann,
|
||||
bbox_fields=[])
|
||||
if ann is not None:
|
||||
data.update(dict(**ann))
|
||||
|
||||
# build the data pipeline
|
||||
data = test_pipeline(data)
|
||||
|
|
|
@ -17,6 +17,38 @@ from mmocr.datasets import build_dataloader, build_dataset
|
|||
from mmocr.models import build_detector
|
||||
|
||||
|
||||
def save_results(model, img_meta, gt_bboxes, result, out_dir):
|
||||
assert 'filename' in img_meta, ('Please add "filename" '
|
||||
'to "meta_keys" in config.')
|
||||
assert 'ori_texts' in img_meta, ('Please add "ori_texts" '
|
||||
'to "meta_keys" in config.')
|
||||
|
||||
out_json_file = osp.join(out_dir,
|
||||
osp.basename(img_meta['filename']) + '.json')
|
||||
|
||||
idx_to_cls = {}
|
||||
if model.module.class_list is not None:
|
||||
for line in mmcv.list_from_file(model.module.class_list):
|
||||
class_idx, class_label = line.strip().split()
|
||||
idx_to_cls[int(class_idx)] = class_label
|
||||
|
||||
json_result = [{
|
||||
'text':
|
||||
text,
|
||||
'box':
|
||||
box,
|
||||
'pred':
|
||||
idx_to_cls.get(
|
||||
pred.argmax(-1).cpu().item(),
|
||||
pred.argmax(-1).cpu().item()),
|
||||
'conf':
|
||||
pred.max(-1)[0].cpu().item()
|
||||
} for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes,
|
||||
result['nodes'])]
|
||||
|
||||
mmcv.dump(json_result, out_json_file)
|
||||
|
||||
|
||||
def test(model, data_loader, show=False, out_dir=None):
|
||||
model.eval()
|
||||
results = []
|
||||
|
@ -57,6 +89,10 @@ def test(model, data_loader, show=False, out_dir=None):
|
|||
show=show,
|
||||
out_file=out_file)
|
||||
|
||||
if out_dir:
|
||||
save_results(model, img_meta, gt_bboxes[i], result[i],
|
||||
out_dir)
|
||||
|
||||
for _ in range(batch_size):
|
||||
prog_bar.update()
|
||||
return results
|
||||
|
@ -69,7 +105,8 @@ def parse_args():
|
|||
parser.add_argument('checkpoint', help='Checkpoint file.')
|
||||
parser.add_argument('--show', action='store_true', help='Show results.')
|
||||
parser.add_argument(
|
||||
'--show-dir', help='Directory where the output images will be saved.')
|
||||
'--out-dir',
|
||||
help='Directory where the output images and results will be saved.')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--device',
|
||||
|
@ -84,10 +121,10 @@ def parse_args():
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
assert args.show or args.show_dir, ('Please specify at least one '
|
||||
'operation (show the results / save )'
|
||||
'the results with the argument '
|
||||
'"--show" or "--show-dir".')
|
||||
assert args.show or args.out_dir, ('Please specify at least one '
|
||||
'operation (show the results / save )'
|
||||
'the results with the argument '
|
||||
'"--show" or "--out-dir".')
|
||||
device = args.device
|
||||
if device is not None:
|
||||
device = ast.literal_eval(f'[{device}]')
|
||||
|
@ -117,7 +154,7 @@ def main():
|
|||
load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||
|
||||
model = MMDataParallel(model, device_ids=device)
|
||||
test(model, data_loader, args.show, args.show_dir)
|
||||
test(model, data_loader, args.show, args.out_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue