mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Save results to json file for kie. (#589)
* save results json for kie * update config * Fix KIE inference bug Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>pull/593/head
parent
f2b9ba93be
commit
98c5bff1e0
|
@ -25,7 +25,10 @@ test_pipeline = [
|
||||||
dict(
|
dict(
|
||||||
type='Collect',
|
type='Collect',
|
||||||
keys=['img', 'relations', 'texts', 'gt_bboxes'],
|
keys=['img', 'relations', 'texts', 'gt_bboxes'],
|
||||||
meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename'])
|
meta_keys=[
|
||||||
|
'img_norm_cfg', 'img_shape', 'ori_filename', 'filename',
|
||||||
|
'ori_texts'
|
||||||
|
])
|
||||||
]
|
]
|
||||||
|
|
||||||
dataset_type = 'KIEDataset'
|
dataset_type = 'KIEDataset'
|
||||||
|
|
|
@ -137,6 +137,8 @@ def model_inference(model,
|
||||||
img_prefix=None,
|
img_prefix=None,
|
||||||
ann_info=ann,
|
ann_info=ann,
|
||||||
bbox_fields=[])
|
bbox_fields=[])
|
||||||
|
if ann is not None:
|
||||||
|
data.update(dict(**ann))
|
||||||
|
|
||||||
# build the data pipeline
|
# build the data pipeline
|
||||||
data = test_pipeline(data)
|
data = test_pipeline(data)
|
||||||
|
|
|
@ -17,6 +17,38 @@ from mmocr.datasets import build_dataloader, build_dataset
|
||||||
from mmocr.models import build_detector
|
from mmocr.models import build_detector
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(model, img_meta, gt_bboxes, result, out_dir):
|
||||||
|
assert 'filename' in img_meta, ('Please add "filename" '
|
||||||
|
'to "meta_keys" in config.')
|
||||||
|
assert 'ori_texts' in img_meta, ('Please add "ori_texts" '
|
||||||
|
'to "meta_keys" in config.')
|
||||||
|
|
||||||
|
out_json_file = osp.join(out_dir,
|
||||||
|
osp.basename(img_meta['filename']) + '.json')
|
||||||
|
|
||||||
|
idx_to_cls = {}
|
||||||
|
if model.module.class_list is not None:
|
||||||
|
for line in mmcv.list_from_file(model.module.class_list):
|
||||||
|
class_idx, class_label = line.strip().split()
|
||||||
|
idx_to_cls[int(class_idx)] = class_label
|
||||||
|
|
||||||
|
json_result = [{
|
||||||
|
'text':
|
||||||
|
text,
|
||||||
|
'box':
|
||||||
|
box,
|
||||||
|
'pred':
|
||||||
|
idx_to_cls.get(
|
||||||
|
pred.argmax(-1).cpu().item(),
|
||||||
|
pred.argmax(-1).cpu().item()),
|
||||||
|
'conf':
|
||||||
|
pred.max(-1)[0].cpu().item()
|
||||||
|
} for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes,
|
||||||
|
result['nodes'])]
|
||||||
|
|
||||||
|
mmcv.dump(json_result, out_json_file)
|
||||||
|
|
||||||
|
|
||||||
def test(model, data_loader, show=False, out_dir=None):
|
def test(model, data_loader, show=False, out_dir=None):
|
||||||
model.eval()
|
model.eval()
|
||||||
results = []
|
results = []
|
||||||
|
@ -57,6 +89,10 @@ def test(model, data_loader, show=False, out_dir=None):
|
||||||
show=show,
|
show=show,
|
||||||
out_file=out_file)
|
out_file=out_file)
|
||||||
|
|
||||||
|
if out_dir:
|
||||||
|
save_results(model, img_meta, gt_bboxes[i], result[i],
|
||||||
|
out_dir)
|
||||||
|
|
||||||
for _ in range(batch_size):
|
for _ in range(batch_size):
|
||||||
prog_bar.update()
|
prog_bar.update()
|
||||||
return results
|
return results
|
||||||
|
@ -69,7 +105,8 @@ def parse_args():
|
||||||
parser.add_argument('checkpoint', help='Checkpoint file.')
|
parser.add_argument('checkpoint', help='Checkpoint file.')
|
||||||
parser.add_argument('--show', action='store_true', help='Show results.')
|
parser.add_argument('--show', action='store_true', help='Show results.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--show-dir', help='Directory where the output images will be saved.')
|
'--out-dir',
|
||||||
|
help='Directory where the output images and results will be saved.')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--device',
|
'--device',
|
||||||
|
@ -84,10 +121,10 @@ def parse_args():
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
assert args.show or args.show_dir, ('Please specify at least one '
|
assert args.show or args.out_dir, ('Please specify at least one '
|
||||||
'operation (show the results / save )'
|
'operation (show the results / save )'
|
||||||
'the results with the argument '
|
'the results with the argument '
|
||||||
'"--show" or "--show-dir".')
|
'"--show" or "--out-dir".')
|
||||||
device = args.device
|
device = args.device
|
||||||
if device is not None:
|
if device is not None:
|
||||||
device = ast.literal_eval(f'[{device}]')
|
device = ast.literal_eval(f'[{device}]')
|
||||||
|
@ -117,7 +154,7 @@ def main():
|
||||||
load_checkpoint(model, args.checkpoint, map_location='cpu')
|
load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||||
|
|
||||||
model = MMDataParallel(model, device_ids=device)
|
model = MMDataParallel(model, device_ids=device)
|
||||||
test(model, data_loader, args.show, args.show_dir)
|
test(model, data_loader, args.show, args.out_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue