diff --git a/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py index 8e20c8fd..f073064a 100644 --- a/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py +++ b/configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py @@ -25,7 +25,10 @@ test_pipeline = [ dict( type='Collect', keys=['img', 'relations', 'texts', 'gt_bboxes'], - meta_keys=['img_norm_cfg', 'img_shape', 'ori_filename']) + meta_keys=[ + 'img_norm_cfg', 'img_shape', 'ori_filename', 'filename', + 'ori_texts' + ]) ] dataset_type = 'KIEDataset' diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index 5fc4733e..2cd9c654 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -137,6 +137,8 @@ def model_inference(model, img_prefix=None, ann_info=ann, bbox_fields=[]) + if ann is not None: + data.update(dict(**ann)) # build the data pipeline data = test_pipeline(data) diff --git a/tools/kie_test_imgs.py b/tools/kie_test_imgs.py index effb2bd0..caabc5d5 100755 --- a/tools/kie_test_imgs.py +++ b/tools/kie_test_imgs.py @@ -17,6 +17,38 @@ from mmocr.datasets import build_dataloader, build_dataset from mmocr.models import build_detector +def save_results(model, img_meta, gt_bboxes, result, out_dir): + assert 'filename' in img_meta, ('Please add "filename" ' + 'to "meta_keys" in config.') + assert 'ori_texts' in img_meta, ('Please add "ori_texts" ' + 'to "meta_keys" in config.') + + out_json_file = osp.join(out_dir, + osp.basename(img_meta['filename']) + '.json') + + idx_to_cls = {} + if model.module.class_list is not None: + for line in mmcv.list_from_file(model.module.class_list): + class_idx, class_label = line.strip().split() + idx_to_cls[int(class_idx)] = class_label + + json_result = [{ + 'text': + text, + 'box': + box, + 'pred': + idx_to_cls.get( + pred.argmax(-1).cpu().item(), + pred.argmax(-1).cpu().item()), + 'conf': + pred.max(-1)[0].cpu().item() + } for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes, + result['nodes'])] + + mmcv.dump(json_result, out_json_file) + + def test(model, data_loader, show=False, out_dir=None): model.eval() results = [] @@ -57,6 +89,10 @@ def test(model, data_loader, show=False, out_dir=None): show=show, out_file=out_file) + if out_dir: + save_results(model, img_meta, gt_bboxes[i], result[i], + out_dir) + for _ in range(batch_size): prog_bar.update() return results @@ -69,7 +105,8 @@ def parse_args(): parser.add_argument('checkpoint', help='Checkpoint file.') parser.add_argument('--show', action='store_true', help='Show results.') parser.add_argument( - '--show-dir', help='Directory where the output images will be saved.') + '--out-dir', + help='Directory where the output images and results will be saved.') parser.add_argument('--local_rank', type=int, default=0) parser.add_argument( '--device', @@ -84,10 +121,10 @@ def parse_args(): def main(): args = parse_args() - assert args.show or args.show_dir, ('Please specify at least one ' - 'operation (show the results / save )' - 'the results with the argument ' - '"--show" or "--show-dir".') + assert args.show or args.out_dir, ('Please specify at least one ' + 'operation (show the results / save )' + 'the results with the argument ' + '"--show" or "--out-dir".') device = args.device if device is not None: device = ast.literal_eval(f'[{device}]') @@ -117,7 +154,7 @@ def main(): load_checkpoint(model, args.checkpoint, map_location='cpu') model = MMDataParallel(model, device_ids=device) - test(model, data_loader, args.show, args.show_dir) + test(model, data_loader, args.show, args.out_dir) if __name__ == '__main__':