[Feature] Save results to JSON file for KIE. ()

* Save results to JSON for KIE

* Update config

* Fix KIE inference bug

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
Hongbin Sun 2021-11-15 20:28:13 +08:00 committed by GitHub
parent f2b9ba93be
commit 98c5bff1e0
3 changed files with 49 additions and 7 deletions


@@ -25,7 +25,10 @@ test_pipeline = [
     dict(
         type='Collect',
         keys=['img', 'relations', 'texts', 'gt_bboxes'],
-        meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename'])
+        meta_keys=[
+            'img_norm_cfg', 'img_shape', 'ori_filename', 'filename',
+            'ori_texts'
+        ])
 ]
 dataset_type = 'KIEDataset'
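The two new meta keys make the source file path and the raw OCR texts available to downstream code through each image's img_metas dict; the save_results helper added below depends on both. A rough sketch of what one img_meta might contain after this change (all values are illustrative):

# Illustrative img_meta produced by Collect with the updated meta_keys.
# Only 'filename' and 'ori_texts' are new; the values below are made up.
img_meta = dict(
    img_norm_cfg=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    img_shape=(1024, 512, 3),
    ori_filename='receipt_001.jpeg',
    filename='/data/wildreceipt/image_files/receipt_001.jpeg',
    ori_texts=['TOTAL', '$12.50'])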


@@ -137,6 +137,8 @@ def model_inference(model,
                img_prefix=None,
                ann_info=ann,
                bbox_fields=[])
+        if ann is not None:
+            data.update(dict(**ann))
 
         # build the data pipeline
         data = test_pipeline(data)
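Merging ann into the top-level data dict is what fixes KIE inference: the KIE test pipeline reads the annotation fields from the results dict itself, not only from ann_info. A rough usage sketch, assuming the usual MMOCR entry points; the checkpoint path and annotation values are made up, and the exact annotation keys expected by the KIE pipeline are an assumption here:

# Sketch only: the config path is the one shipped with MMOCR, the
# checkpoint path and annotation values are illustrative, and the
# annotation keys may differ from what your pipeline expects.
from mmdet.apis import init_detector
from mmocr.apis import model_inference

model = init_detector('configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
                      'sdmgr_unet16_60e_wildreceipt.pth', device='cpu')
ann = dict(
    boxes=[[100, 50, 210, 50, 210, 80, 100, 80]],  # one quadrilateral box
    texts=['TOTAL'])
result = model_inference(model, 'demo/demo_kie.jpeg', ann=ann)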


@@ -17,6 +17,38 @@ from mmocr.datasets import build_dataloader, build_dataset
 from mmocr.models import build_detector
 
 
+def save_results(model, img_meta, gt_bboxes, result, out_dir):
+    assert 'filename' in img_meta, ('Please add "filename" '
+                                    'to "meta_keys" in config.')
+    assert 'ori_texts' in img_meta, ('Please add "ori_texts" '
+                                     'to "meta_keys" in config.')
+
+    out_json_file = osp.join(out_dir,
+                             osp.basename(img_meta['filename']) + '.json')
+
+    idx_to_cls = {}
+    if model.module.class_list is not None:
+        for line in mmcv.list_from_file(model.module.class_list):
+            class_idx, class_label = line.strip().split()
+            idx_to_cls[int(class_idx)] = class_label
+
+    json_result = [{
+        'text': text,
+        'box': box,
+        'pred': idx_to_cls.get(pred.argmax(-1).cpu().item(),
+                               pred.argmax(-1).cpu().item()),
+        'conf': pred.max(-1)[0].cpu().item()
+    } for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes,
+                                 result['nodes'])]
+
+    mmcv.dump(json_result, out_json_file)
+
+
 def test(model, data_loader, show=False, out_dir=None):
     model.eval()
     results = []
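For every test image, save_results writes <image basename>.json to out_dir, with one entry per annotated box. A made-up example of such a file's contents:

# Illustrative contents of e.g. receipt_001.jpeg.json (all values made up).
# 'pred' is the class label when model.module.class_list is set,
# otherwise the raw class index.
[
    {
        "text": "TOTAL",
        "box": [100, 50, 210, 50, 210, 80, 100, 80],
        "pred": "Total_key",
        "conf": 0.97
    },
    {
        "text": "$12.50",
        "box": [220, 50, 300, 50, 300, 80, 220, 80],
        "pred": "Total_value",
        "conf": 0.95
    }
]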
@@ -57,6 +89,10 @@ def test(model, data_loader, show=False, out_dir=None):
                     show=show,
                     out_file=out_file)
 
+                if out_dir:
+                    save_results(model, img_meta, gt_bboxes[i], result[i],
+                                 out_dir)
+
         for _ in range(batch_size):
             prog_bar.update()
 
     return results
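Because the files are written with mmcv.dump, they can be loaded back with mmcv.load for quick checks. A small sketch, with an illustrative output directory and file name:

import os.path as osp

import mmcv

out_dir = 'work_dirs/kie_vis'  # illustrative
preds = mmcv.load(osp.join(out_dir, 'receipt_001.jpeg.json'))
for p in preds:
    if p['conf'] > 0.9:  # keep only confident node predictions
        print(p['text'], '->', p['pred'])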
@@ -69,7 +105,8 @@ def parse_args():
     parser.add_argument('checkpoint', help='Checkpoint file.')
     parser.add_argument('--show', action='store_true', help='Show results.')
     parser.add_argument(
-        '--show-dir', help='Directory where the output images will be saved.')
+        '--out-dir',
+        help='Directory where the output images and results will be saved.')
     parser.add_argument('--local_rank', type=int, default=0)
     parser.add_argument(
         '--device',
@@ -84,10 +121,10 @@ def parse_args():
 def main():
     args = parse_args()
-    assert args.show or args.show_dir, ('Please specify at least one '
-                                        'operation (show the results / save '
-                                        'the results) with the argument '
-                                        '"--show" or "--show-dir".')
+    assert args.show or args.out_dir, ('Please specify at least one '
+                                       'operation (show the results / save '
+                                       'the results) with the argument '
+                                       '"--show" or "--out-dir".')
 
     device = args.device
     if device is not None:
         device = ast.literal_eval(f'[{device}]')
@@ -117,7 +154,7 @@ def main():
     load_checkpoint(model, args.checkpoint, map_location='cpu')
     model = MMDataParallel(model, device_ids=device)
 
-    test(model, data_loader, args.show, args.show_dir)
+    test(model, data_loader, args.show, args.out_dir)
 
 
 if __name__ == '__main__':
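Putting the pieces together, the KIE test script (assumed here to be tools/kie_test_imgs.py, matching the MMOCR repository layout; the checkpoint and output paths are illustrative) can now dump both visualizations and per-image JSON results:

python tools/kie_test_imgs.py \
    configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py \
    work_dirs/sdmgr/latest.pth \
    --out-dir work_dirs/kie_vis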