diff --git a/demo/README.md b/demo/README.md index 4b513123..7b87d905 100644 --- a/demo/README.md +++ b/demo/README.md @@ -26,5 +26,15 @@ Please refer to [End2End Demo](docs/ocr_demo.md) for the tutorial of Text Detect <img src="resources/demo_ocr_pred.jpg"/><br> </div> +<br> +<br> + +Please refer to [KIE End2End Demo](docs/kie_demo.md) for the tutorial of KIE end-to-end demo. + +<div align="center"> + <img src="resources/demo_kie_pred.jpeg"/><br> + +</div> + <br> diff --git a/demo/demo_kie.jpeg b/demo/demo_kie.jpeg new file mode 100755 index 00000000..51014d8e Binary files /dev/null and b/demo/demo_kie.jpeg differ diff --git a/demo/resources/demo_kie_pred.png b/demo/resources/demo_kie_pred.png new file mode 100644 index 00000000..4a84d0c9 Binary files /dev/null and b/demo/resources/demo_kie_pred.png differ diff --git a/docs/demo.md b/docs/demo.md index 057de082..26653df0 100644 --- a/docs/demo.md +++ b/docs/demo.md @@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch ```shell python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow ``` -*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.* +*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. * - Python interface: ```python @@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True) ``` --- +## Example 4: Text Detection + Recognition + Key Information Extraction + +<div align="center"> + <img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br> +</div> +<br> + +**Instruction:** Perform end-to-end ocr (det + recog) inference first with PS_CTW detection model and SAR recognition model, then run KIE inference with SDMGR model on the ocr result and show the visualization. + +- CL interface: +```shell +python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow +``` +*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. * + +- Python interface: +```python +from mmocr.utils.ocr import MMOCR + +# Load models into memory +ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR') + +# Inference +results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True) +``` +--- + ## API Arguments The API has an extensive list of arguments that you can use. The following tables are for the python interface. @@ -91,16 +118,21 @@ The API has an extensive list of arguments that you can use. 
The following table | Arguments | Type | Default | Description | | -------------- | --------------------- | ------------- | ----------------------------------------------------------- | -| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm | +| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm | | `recog` | see [models](#models) | SAR | Text recognition algorithm | +| `kie` [1] | see [models](#models) | None | Key information extraction algorithm | | `config_dir` | str | configs/ | Path to the config directory where all the config files are located | | `det_config` | str | None | Path to the custom config file of the selected det model | | `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model | -| `recog_config` | str | None | Path to the custom config file of the selected recog model model | -| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model | +| `recog_config` | str | None | Path to the custom config file of the selected recog model | +| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model | +| `kie_config` | str | None | Path to the custom config file of the selected kie model | +| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model | | `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' | -**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`. +[1]: `kie` is only effective when both text detection and recognition models are specified. + +**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`. ### readtext(): @@ -117,8 +149,8 @@ The API has an extensive list of arguments that you can use. The following table | `details` | bool | False | Whether include the text boxes coordinates and confidence values | | `imshow` | bool | False | Whether to show the result visualization on screen | | `print_result` | bool | False | Whether to show the result for each image | -| `merge` | bool | False | Whether to merge neighboring boxes [2] | -| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes | +| `merge` | bool | False | Whether to merge neighboring boxes [2] | +| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes | [1]: Make sure that the model is compatible with batch mode. 
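A minimal sketch of how these `readtext()` arguments combine in practice (detection + recognition only; the input image ships with the repo, while `out/` is a hypothetical, pre-existing folder for the exported result file):

```python
from mmocr.utils.ocr import MMOCR

# Load default pretrained detection + recognition models
ocr = MMOCR(det='PANet_IC15', recog='SAR')

results = ocr.readtext(
    'demo/demo_text_ocr.jpg',
    details=True,       # keep box coordinates and confidence values
    merge=True,         # stitch neighbouring boxes into text lines
    merge_xdist=20,     # maximum x-axis distance for merging
    export='out/',      # hypothetical folder that receives the result file
    print_result=True)
```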
@@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`) | SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: | | CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: | +**Key information extraction:** + +| Name | Reference | `batch_mode` support | +| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: | +| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: | --- ## Additional info diff --git a/docs_zh_CN/demo.md b/docs_zh_CN/demo.md index 057de082..26653df0 100644 --- a/docs_zh_CN/demo.md +++ b/docs_zh_CN/demo.md @@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch ```shell python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow ``` -*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.* +*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. * - Python interface: ```python @@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True) ``` --- +## Example 4: Text Detection + Recognition + Key Information Extraction + +<div align="center"> + <img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br> +</div> +<br> + +**Instruction:** Perform end-to-end ocr (det + recog) inference first with PS_CTW detection model and SAR recognition model, then run KIE inference with SDMGR model on the ocr result and show the visualization. + +- CL interface: +```shell +python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow +``` +*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. * + +- Python interface: +```python +from mmocr.utils.ocr import MMOCR + +# Load models into memory +ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR') + +# Inference +results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True) +``` +--- + ## API Arguments The API has an extensive list of arguments that you can use. The following tables are for the python interface. @@ -91,16 +118,21 @@ The API has an extensive list of arguments that you can use. 
The following table | Arguments | Type | Default | Description | | -------------- | --------------------- | ------------- | ----------------------------------------------------------- | -| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm | +| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm | | `recog` | see [models](#models) | SAR | Text recognition algorithm | +| `kie` [1] | see [models](#models) | None | Key information extraction algorithm | | `config_dir` | str | configs/ | Path to the config directory where all the config files are located | | `det_config` | str | None | Path to the custom config file of the selected det model | | `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model | -| `recog_config` | str | None | Path to the custom config file of the selected recog model model | -| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model | +| `recog_config` | str | None | Path to the custom config file of the selected recog model | +| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model | +| `kie_config` | str | None | Path to the custom config file of the selected kie model | +| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model | | `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' | -**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`. +[1]: `kie` is only effective when both text detection and recognition models are specified. + +**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`. ### readtext(): @@ -117,8 +149,8 @@ The API has an extensive list of arguments that you can use. The following table | `details` | bool | False | Whether include the text boxes coordinates and confidence values | | `imshow` | bool | False | Whether to show the result visualization on screen | | `print_result` | bool | False | Whether to show the result for each image | -| `merge` | bool | False | Whether to merge neighboring boxes [2] | -| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes | +| `merge` | bool | False | Whether to merge neighboring boxes [2] | +| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes | [1]: Make sure that the model is compatible with batch mode. 
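When a KIE model is loaded, the post-processing in `mmocr/utils/ocr.py` attaches a `label` and a `label_score` to every detected box. A short sketch of reading them back (requires `details=True`; otherwise only the plain text strings are returned):

```python
from mmocr.utils.ocr import MMOCR

# Detection + recognition + key information extraction
ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')

results = ocr.readtext('demo/demo_kie.jpeg', details=True)
for box in results[0]['result']:
    # Each entry holds the box, the recognized text and the predicted KIE class
    print(box['text'], box['label'], round(box['label_score'], 3))
```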
@@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`) | SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: | | CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: | +**Key information extraction:** + +| Name | Reference | `batch_mode` support | +| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: | +| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: | --- ## Additional info diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py index 00ef7272..be979343 100644 --- a/mmocr/apis/inference.py +++ b/mmocr/apis/inference.py @@ -30,7 +30,11 @@ def disable_text_recog_aug_test(cfg, set_types=None): return cfg -def model_inference(model, imgs, batch_mode=False): +def model_inference(model, + imgs, + ann=None, + batch_mode=False, + return_data=False): """Inference image(s) with the detector. Args: @@ -38,6 +42,8 @@ def model_inference(model, imgs, batch_mode=False): imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]): Either image files or loaded images. batch_mode (bool): If True, use batch mode for inference. + ann (dict): Annotation info for key information extraction. + return_data: Return postprocessed data. Returns: result (dict): Predicted results. """ @@ -75,10 +81,14 @@ def model_inference(model, imgs, batch_mode=False): # prepare data if is_ndarray: # directly add img - data = dict(img=img) + data = dict(img=img, ann_info=ann, bbox_fields=[]) else: # add information into dict - data = dict(img_info=dict(filename=img), img_prefix=None) + data = dict( + img_info=dict(filename=img), + img_prefix=None, + ann_info=ann, + bbox_fields=[]) # build the data pipeline data = test_pipeline(data) @@ -111,6 +121,14 @@ def model_inference(model, imgs, batch_mode=False): else: data['img'] = data['img'].data + # for KIE models + if ann is not None: + data['relations'] = data['relations'].data[0] + data['gt_bboxes'] = data['gt_bboxes'].data[0] + data['texts'] = data['texts'].data[0] + data['img'] = data['img'][0] + data['img_metas'] = data['img_metas'][0] + if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0] @@ -125,9 +143,13 @@ def model_inference(model, imgs, batch_mode=False): results = model(return_loss=False, rescale=True, **data) if not is_batch: - return results[0] + if not return_data: + return results[0] + return results[0], datas[0] else: - return results + if not return_data: + return results + return results, datas def text_model_inference(model, input_sentence): diff --git a/mmocr/datasets/kie_dataset.py b/mmocr/datasets/kie_dataset.py index 4c959900..566e4eaa 100644 --- a/mmocr/datasets/kie_dataset.py +++ b/mmocr/datasets/kie_dataset.py @@ -1,4 +1,5 @@ import copy +import warnings from os import path as osp import numpy as np @@ -28,22 +29,28 @@ class KIEDataset(BaseDataset): """ def __init__(self, - ann_file, - loader, - dict_file, + ann_file=None, + loader=None, + dict_file=None, img_prefix='', pipeline=None, norm=10., directed=False, test_mode=True, **kwargs): - super().__init__( - ann_file, - loader, - pipeline, - img_prefix=img_prefix, - test_mode=test_mode) - assert osp.exists(dict_file) + if ann_file is None and loader is 
None: + warnings.warn( + 'KIEDataset is only initialized as a downstream demo task ' + 'of text detection and recognition ' + 'without an annotation file.', UserWarning) + else: + super().__init__( + ann_file, + loader, + pipeline, + img_prefix=img_prefix, + test_mode=test_mode) + assert osp.exists(dict_file) self.norm = norm self.directed = directed diff --git a/mmocr/utils/ocr.py b/mmocr/utils/ocr.py index 3130bac9..85d9f9a9 100644 --- a/mmocr/utils/ocr.py +++ b/mmocr/utils/ocr.py @@ -1,3 +1,4 @@ +import copy import os import warnings from argparse import ArgumentParser, Namespace @@ -5,288 +6,19 @@ from pathlib import Path import mmcv import numpy as np +import torch +from mmcv.image.misc import tensor2imgs +from mmcv.runner import load_checkpoint +from mmcv.utils.config import Config from mmdet.apis import init_detector from mmocr.apis.inference import model_inference from mmocr.core.visualize import det_recog_show_result +from mmocr.datasets.kie_dataset import KIEDataset from mmocr.datasets.pipelines.crop import crop_img +from mmocr.models import build_detector from mmocr.utils.box_util import stitch_boxes_into_lines - -textdet_models = { - 'DB_r18': { - 'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', - 'ckpt': - 'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth' - }, - 'DB_r50': { - 'config': - 'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py', - 'ckpt': - 'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth' - }, - 'DRRG': { - 'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py', - 'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth' - }, - 'FCE_IC15': { - 'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py', - 'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth' - }, - 'FCE_CTW_DCNv2': { - 'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py', - 'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth' - }, - 'MaskRCNN_CTW': { - 'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py', - 'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth' - }, - 'MaskRCNN_IC15': { - 'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py', - 'ckpt': - 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth' - }, - 'MaskRCNN_IC17': { - 'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py', - 'ckpt': - 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth' - }, - 'PANet_CTW': { - 'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py', - 'ckpt': - 'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth' - }, - 'PANet_IC15': { - 'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py', - 'ckpt': - 'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth' - }, - 'PS_CTW': { - 'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py', - 'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth' - }, - 'PS_IC15': { - 'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py', - 'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth' - }, - 'TextSnake': { - 'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py', - 'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth' - } -} - -textrecog_models = { - 'CRNN': { - 'config': 'crnn/crnn_academic_dataset.py', - 'ckpt': 'crnn/crnn_academic-a723a1c5.pth' - }, - 'SAR': { - 'config': 'sar/sar_r31_parallel_decoder_academic.py', - 'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth' - }, - 'NRTR_1/16-1/8': { - 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py', - 'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth' - }, 
- 'NRTR_1/8-1/4': { - 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py', - 'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth' - }, - 'RobustScanner': { - 'config': 'robust_scanner/robustscanner_r31_academic.py', - 'ckpt': 'robust_scanner/robustscanner_r31_academic-5f05874f.pth' - }, - 'SEG': { - 'config': 'seg/seg_r31_1by16_fpnocr_academic.py', - 'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth' - }, - 'CRNN_TPS': { - 'config': 'tps/crnn_tps_academic_dataset.py', - 'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth' - } -} - - -# Post processing function for end2end ocr -def det_recog_pp(args, result): - final_results = [] - for arr, output, export, det_recog_result in zip(args.arrays, args.output, - args.export, result): - if output or args.imshow: - res_img = det_recog_show_result( - arr, det_recog_result, out_file=output) - if args.imshow: - mmcv.imshow(res_img, 'inference results') - if not args.details: - simple_res = {} - simple_res['filename'] = det_recog_result['filename'] - simple_res['text'] = [ - x['text'] for x in det_recog_result['result'] - ] - final_result = simple_res - else: - final_result = det_recog_result - if export: - mmcv.dump(final_result, export, indent=4) - if args.print_result: - print(final_result, end='\n\n') - final_results.append(final_result) - return final_results - - -# Post processing function for separate det/recog inference -def single_pp(args, result, model): - for arr, output, export, res in zip(args.arrays, args.output, args.export, - result): - if export: - mmcv.dump(res, export, indent=4) - if output or args.imshow: - res_img = model.show_result(arr, res, out_file=output) - if args.imshow: - mmcv.imshow(res_img, 'inference results') - if args.print_result: - print(res, end='\n\n') - return result - - -# End2end ocr inference pipeline -def det_and_recog_inference(args, det_model, recog_model): - end2end_res = [] - # Find bounding boxes in the images (text detection) - det_result = single_inference(det_model, args.arrays, args.batch_mode, - args.det_batch_size) - bboxes_list = [res['boundary_result'] for res in det_result] - - # For each bounding box, the image is cropped and sent to the recognition - # model either one by one or all together depending on the batch_mode - for filename, arr, bboxes in zip(args.filenames, args.arrays, bboxes_list): - img_e2e_res = {} - img_e2e_res['filename'] = filename - img_e2e_res['result'] = [] - box_imgs = [] - for bbox in bboxes: - box_res = {} - box_res['box'] = [round(x) for x in bbox[:-1]] - box_res['box_score'] = float(bbox[-1]) - box = bbox[:8] - if len(bbox) > 9: - min_x = min(bbox[0:-1:2]) - min_y = min(bbox[1:-1:2]) - max_x = max(bbox[0:-1:2]) - max_y = max(bbox[1:-1:2]) - box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y] - box_img = crop_img(arr, box) - if args.batch_mode: - box_imgs.append(box_img) - else: - recog_result = model_inference(recog_model, box_img) - text = recog_result['text'] - text_score = recog_result['score'] - if isinstance(text_score, list): - text_score = sum(text_score) / max(1, len(text)) - box_res['text'] = text - box_res['text_score'] = text_score - img_e2e_res['result'].append(box_res) - - if args.batch_mode: - recog_results = single_inference(recog_model, box_imgs, True, - args.recog_batch_size) - for i, recog_result in enumerate(recog_results): - text = recog_result['text'] - text_score = recog_result['score'] - if isinstance(text_score, (list, tuple)): - text_score = sum(text_score) / max(1, len(text)) - 
img_e2e_res['result'][i]['text'] = text - img_e2e_res['result'][i]['text_score'] = text_score - - if args.merge: - img_e2e_res['result'] = stitch_boxes_into_lines( - img_e2e_res['result'], args.merge_xdist, 0.5) - - end2end_res.append(img_e2e_res) - return end2end_res - - -# Separate det/recog inference pipeline -def single_inference(model, arrays, batch_mode, batch_size): - result = [] - if batch_mode: - if batch_size == 0: - result = model_inference(model, arrays, batch_mode=True) - else: - n = batch_size - arr_chunks = [arrays[i:i + n] for i in range(0, len(arrays), n)] - for chunk in arr_chunks: - result.extend(model_inference(model, chunk, batch_mode=True)) - else: - for arr in arrays: - result.append(model_inference(model, arr, batch_mode=False)) - return result - - -# Arguments pre-processing function -def args_processing(args): - # Check if the input is a list/tuple that - # contains only np arrays or strings - if isinstance(args.img, (list, tuple)): - img_list = args.img - if not all([isinstance(x, (np.ndarray, str)) for x in args.img]): - raise AssertionError('Images must be strings or numpy arrays') - - # Create a list of the images - if isinstance(args.img, str): - img_path = Path(args.img) - if img_path.is_dir(): - img_list = [str(x) for x in img_path.glob('*')] - else: - img_list = [str(img_path)] - elif isinstance(args.img, np.ndarray): - img_list = [args.img] - - # Read all image(s) in advance to reduce wasted time - # re-reading the images for vizualisation output - args.arrays = [mmcv.imread(x) for x in img_list] - - # Create a list of filenames (used for output images and result files) - if isinstance(img_list[0], str): - args.filenames = [str(Path(x).stem) for x in img_list] - else: - args.filenames = [str(x) for x in range(len(img_list))] - - # If given an output argument, create a list of output image filenames - num_res = len(img_list) - if args.output: - output_path = Path(args.output) - if output_path.is_dir(): - args.output = [ - str(output_path / f'out_{x}.png') for x in args.filenames - ] - else: - args.output = [str(args.output)] - if args.batch_mode: - raise AssertionError( - 'Output of multiple images inference must be a directory') - else: - args.output = [None] * num_res - - # If given an export argument, create a list of - # result filenames for each image - if args.export: - export_path = Path(args.export) - args.export = [ - str(export_path / f'out_{x}.{args.export_format}') - for x in args.filenames - ] - else: - args.export = [None] * num_res - - return args - - -# Create an inference pipeline with parsed arguments -def main(): - args = parse_args() - ocr = MMOCR(**vars(args)) - ocr.readtext(**vars(args)) +from mmocr.utils.fileio import list_from_file # Parse CLI arguments @@ -333,6 +65,23 @@ def parse_args(): default='', help='Path to the custom checkpoint file of the selected recog model. ' 'It overrides the settings in recog') + parser.add_argument( + '--kie', + type=str, + default='', + help='Pretrained key information extraction algorithm') + parser.add_argument( + '--kie-config', + type=str, + default='', + help='Path to the custom config file of the selected kie model. It' + 'overrides the settings in kie') + parser.add_argument( + '--kie-ckpt', + type=str, + default='', + help='Path to the custom checkpoint file of the selected kie model. 
' + 'It overrides the settings in kie') parser.add_argument( '--config-dir', type=str, @@ -418,11 +167,138 @@ class MMOCR: recog='SEG', recog_config='', recog_ckpt='', + kie='', + kie_config='', + kie_ckpt='', config_dir=os.path.join(str(Path.cwd()), 'configs/'), device='cuda:0', **kwargs): + + textdet_models = { + 'DB_r18': { + 'config': + 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth' + }, + 'DB_r50': { + 'config': + 'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py', + 'ckpt': + 'dbnet/' + 'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth' + }, + 'DRRG': { + 'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py', + 'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth' + }, + 'FCE_IC15': { + 'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py', + 'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth' + }, + 'FCE_CTW_DCNv2': { + 'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py', + 'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth' + }, + 'MaskRCNN_CTW': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth' + }, + 'MaskRCNN_IC15': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth' + }, + 'MaskRCNN_IC17': { + 'config': + 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py', + 'ckpt': + 'maskrcnn/' + 'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth' + }, + 'PANet_CTW': { + 'config': + 'panet/panet_r18_fpem_ffm_600e_ctw1500.py', + 'ckpt': + 'panet/' + 'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth' + }, + 'PANet_IC15': { + 'config': + 'panet/panet_r18_fpem_ffm_600e_icdar2015.py', + 'ckpt': + 'panet/' + 'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth' + }, + 'PS_CTW': { + 'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py', + 'ckpt': + 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth' + }, + 'PS_IC15': { + 'config': + 'psenet/psenet_r50_fpnf_600e_icdar2015.py', + 'ckpt': + 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth' + }, + 'TextSnake': { + 'config': + 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py', + 'ckpt': + 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth' + } + } + + textrecog_models = { + 'CRNN': { + 'config': 'crnn/crnn_academic_dataset.py', + 'ckpt': 'crnn/crnn_academic-a723a1c5.pth' + }, + 'SAR': { + 'config': 'sar/sar_r31_parallel_decoder_academic.py', + 'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth' + }, + 'NRTR_1/16-1/8': { + 'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py', + 'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth' + }, + 'NRTR_1/8-1/4': { + 'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py', + 'ckpt': + 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth' + }, + 'RobustScanner': { + 'config': 'robust_scanner/robustscanner_r31_academic.py', + 'ckpt': + 'robust_scanner/robustscanner_r31_academic-5f05874f.pth' + }, + 'SEG': { + 'config': 'seg/seg_r31_1by16_fpnocr_academic.py', + 'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth' + }, + 'CRNN_TPS': { + 'config': 'tps/crnn_tps_academic_dataset.py', + 'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth' + } + } + + kie_models = { + 'SDMGR': { + 'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py', + 'ckpt': + 'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth' + } + } + self.td = det self.tr = recog + self.kie = kie 
self.device = device # Check if the det/recog model choice is valid @@ -432,7 +308,12 @@ class MMOCR: elif self.tr and self.tr not in textrecog_models: raise ValueError(self.tr, 'is not a supported text recognition algorithm') + elif self.kie and self.kie not in kie_models: + raise ValueError( + self.kie, 'is not a supported key information extraction' + ' algorithm') + self.detect_model = None if self.td: # Build detection model if not det_config: @@ -444,9 +325,8 @@ class MMOCR: self.detect_model = init_detector( det_config, det_ckpt, device=self.device) - else: - self.detect_model = None + self.recog_model = None if self.tr: # Build recognition model if not recog_config: @@ -454,13 +334,27 @@ class MMOCR: config_dir, 'textrecog/', textrecog_models[self.tr]['config']) if not recog_ckpt: - recog_ckpt = 'https://download.openmmlab.com/mmocr/' - 'textrecog/' + textrecog_models[self.tr]['ckpt'] + recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \ + 'textrecog/' + textrecog_models[self.tr]['ckpt'] self.recog_model = init_detector( recog_config, recog_ckpt, device=self.device) - else: - self.recog_model = None + + self.kie_model = None + if self.kie: + # Build key information extraction model + if not kie_config: + kie_config = os.path.join(config_dir, 'kie/', + kie_models[self.kie]['config']) + if not kie_ckpt: + kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \ + 'kie/' + kie_models[self.kie]['ckpt'] + + kie_cfg = Config.fromfile(kie_config) + self.kie_model = build_detector( + kie_cfg.model, test_cfg=kie_cfg.get('test_cfg')) + self.kie_model.cfg = kie_cfg + load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device) # Attribute check for model in list(filter(None, [self.recog_model, self.detect_model])): @@ -490,25 +384,292 @@ class MMOCR: args = Namespace(**args) # Input and output arguments processing - args = args_processing(args) + self._args_processing(args) + self.args = args pp_result = None # Send args and models to the MMOCR model inference API # and call post-processing functions for the output if self.detect_model and self.recog_model: - det_recog_result = det_and_recog_inference(args, self.detect_model, - self.recog_model) - pp_result = det_recog_pp(args, det_recog_result) + det_recog_result = self.det_recog_kie_inference( + self.detect_model, self.recog_model, kie_model=self.kie_model) + pp_result = self.det_recog_pp(det_recog_result) else: for model in list( filter(None, [self.recog_model, self.detect_model])): - result = single_inference(model, args.arrays, args.batch_mode, - args.single_batch_size) - pp_result = single_pp(args, result, model) + result = self.single_inference(model, args.arrays, + args.batch_mode, + args.single_batch_size) + pp_result = self.single_pp(args, result, model) return pp_result + # Post processing function for end2end ocr + def det_recog_pp(self, result): + final_results = [] + args = self.args + for arr, output, export, det_recog_result in zip( + args.arrays, args.output, args.export, result): + if output or args.imshow: + if self.kie_model: + res_img = det_recog_show_result(arr, det_recog_result) + else: + res_img = det_recog_show_result( + arr, det_recog_result, out_file=output) + if args.imshow and not self.kie_model: + mmcv.imshow(res_img, 'inference results') + if not args.details: + simple_res = {} + simple_res['filename'] = det_recog_result['filename'] + simple_res['text'] = [ + x['text'] for x in det_recog_result['result'] + ] + final_result = simple_res + else: + final_result = det_recog_result + if export: + 
mmcv.dump(final_result, export, indent=4) + if args.print_result: + print(final_result, end='\n\n') + final_results.append(final_result) + return final_results + + # Post processing function for separate det/recog inference + def single_pp(self, result, model): + for arr, output, export, res in zip(self.args.arrays, self.args.output, + self.args.export, result): + if export: + mmcv.dump(res, export, indent=4) + if output or self.args.imshow: + res_img = model.show_result(arr, res, out_file=output) + if self.args.imshow: + mmcv.imshow(res_img, 'inference results') + if self.args.print_result: + print(res, end='\n\n') + return result + + def generate_kie_labels(self, result, boxes, class_list): + idx_to_cls = {} + if class_list is not None: + for line in list_from_file(class_list): + class_idx, class_label = line.strip().split() + idx_to_cls[class_idx] = class_label + + max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1) + node_pred_label = max_idx.numpy().tolist() + node_pred_score = max_value.numpy().tolist() + labels = [] + for i in range(len(boxes)): + pred_label = str(node_pred_label[i]) + if pred_label in idx_to_cls: + pred_label = idx_to_cls[pred_label] + pred_score = node_pred_score[i] + labels.append((pred_label, pred_score)) + return labels + + def visualize_kie_output(self, + model, + data, + result, + out_file=None, + show=False): + """Visualizes KIE output.""" + img_tensor = data['img'].data + img_meta = data['img_metas'].data + gt_bboxes = data['gt_bboxes'].data.numpy().tolist() + img = tensor2imgs(img_tensor.unsqueeze(0), + **img_meta['img_norm_cfg'])[0] + h, w, _ = img_meta['img_shape'] + img_show = img[:h, :w, :] + model.show_result( + img_show, result, gt_bboxes, show=show, out_file=out_file) + + # End2end ocr inference pipeline + def det_recog_kie_inference(self, det_model, recog_model, kie_model=None): + end2end_res = [] + # Find bounding boxes in the images (text detection) + det_result = self.single_inference(det_model, self.args.arrays, + self.args.batch_mode, + self.args.det_batch_size) + bboxes_list = [res['boundary_result'] for res in det_result] + + if kie_model: + kie_dataset = KIEDataset( + dict_file=kie_model.cfg.data.test.dict_file) + + # For each bounding box, the image is cropped and + # sent to the recognition model either one by one + # or all together depending on the batch_mode + for filename, arr, bboxes, out_file in zip(self.args.filenames, + self.args.arrays, + bboxes_list, + self.args.output): + img_e2e_res = {} + img_e2e_res['filename'] = filename + img_e2e_res['result'] = [] + box_imgs = [] + for bbox in bboxes: + box_res = {} + box_res['box'] = [round(x) for x in bbox[:-1]] + box_res['box_score'] = float(bbox[-1]) + box = bbox[:8] + if len(bbox) > 9: + min_x = min(bbox[0:-1:2]) + min_y = min(bbox[1:-1:2]) + max_x = max(bbox[0:-1:2]) + max_y = max(bbox[1:-1:2]) + box = [ + min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y + ] + box_img = crop_img(arr, box) + if self.args.batch_mode: + box_imgs.append(box_img) + else: + recog_result = model_inference(recog_model, box_img) + text = recog_result['text'] + text_score = recog_result['score'] + if isinstance(text_score, list): + text_score = sum(text_score) / max(1, len(text)) + box_res['text'] = text + box_res['text_score'] = text_score + img_e2e_res['result'].append(box_res) + + if self.args.batch_mode: + recog_results = self.single_inference( + recog_model, box_imgs, True, self.args.recog_batch_size) + for i, recog_result in enumerate(recog_results): + text = recog_result['text'] + 
text_score = recog_result['score'] + if isinstance(text_score, (list, tuple)): + text_score = sum(text_score) / max(1, len(text)) + img_e2e_res['result'][i]['text'] = text + img_e2e_res['result'][i]['text_score'] = text_score + + if self.args.merge: + img_e2e_res['result'] = stitch_boxes_into_lines( + img_e2e_res['result'], self.args.merge_xdist, 0.5) + + if kie_model: + annotations = copy.deepcopy(img_e2e_res['result']) + # Customized for kie_dataset, which + # assumes that boxes are represented by only 4 points + for i, ann in enumerate(annotations): + min_x = min(ann['box'][::2]) + min_y = min(ann['box'][1::2]) + max_x = max(ann['box'][::2]) + max_y = max(ann['box'][1::2]) + annotations[i]['box'] = [ + min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y + ] + ann_info = kie_dataset._parse_anno_info(annotations) + kie_result, data = model_inference( + kie_model, + arr, + ann=ann_info, + return_data=True, + batch_mode=self.args.batch_mode) + # visualize KIE results + self.visualize_kie_output( + kie_model, + data, + kie_result, + out_file=out_file, + show=self.args.imshow) + gt_bboxes = data['gt_bboxes'].data.numpy().tolist() + labels = self.generate_kie_labels(kie_result, gt_bboxes, + kie_model.class_list) + for i in range(len(gt_bboxes)): + img_e2e_res['result'][i]['label'] = labels[i][0] + img_e2e_res['result'][i]['label_score'] = labels[i][1] + + end2end_res.append(img_e2e_res) + return end2end_res + + # Separate det/recog inference pipeline + def single_inference(self, model, arrays, batch_mode, batch_size): + result = [] + if batch_mode: + if batch_size == 0: + result = model_inference(model, arrays, batch_mode=True) + else: + n = batch_size + arr_chunks = [ + arrays[i:i + n] for i in range(0, len(arrays), n) + ] + for chunk in arr_chunks: + result.extend( + model_inference(model, chunk, batch_mode=True)) + else: + for arr in arrays: + result.append(model_inference(model, arr, batch_mode=False)) + return result + + # Arguments pre-processing function + def _args_processing(self, args): + # Check if the input is a list/tuple that + # contains only np arrays or strings + if isinstance(args.img, (list, tuple)): + img_list = args.img + if not all([isinstance(x, (np.ndarray, str)) for x in args.img]): + raise AssertionError('Images must be strings or numpy arrays') + + # Create a list of the images + if isinstance(args.img, str): + img_path = Path(args.img) + if img_path.is_dir(): + img_list = [str(x) for x in img_path.glob('*')] + else: + img_list = [str(img_path)] + elif isinstance(args.img, np.ndarray): + img_list = [args.img] + + # Read all image(s) in advance to reduce wasted time + # re-reading the images for vizualisation output + args.arrays = [mmcv.imread(x) for x in img_list] + + # Create a list of filenames (used for output images and result files) + if isinstance(img_list[0], str): + args.filenames = [str(Path(x).stem) for x in img_list] + else: + args.filenames = [str(x) for x in range(len(img_list))] + + # If given an output argument, create a list of output image filenames + num_res = len(img_list) + if args.output: + output_path = Path(args.output) + if output_path.is_dir(): + args.output = [ + str(output_path / f'out_{x}.png') for x in args.filenames + ] + else: + args.output = [str(args.output)] + if args.batch_mode: + raise AssertionError('Output of multiple images inference' + ' must be a directory') + else: + args.output = [None] * num_res + + # If given an export argument, create a list of + # result filenames for each image + if args.export: + export_path = 
Path(args.export) + args.export = [ + str(export_path / f'out_{x}.{args.export_format}') + for x in args.filenames + ] + else: + args.export = [None] * num_res + + return args + + +# Create an inference pipeline with parsed arguments +def main(): + args = parse_args() + ocr = MMOCR(**vars(args)) + ocr.readtext(**vars(args)) + if __name__ == '__main__': main()
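For reference, a brief usage sketch of the refactored `MMOCR` class with the new `kie_config`/`kie_ckpt` overrides; the config path mirrors the bundled SDMGR entry, while the checkpoint path is a hypothetical, locally trained model:

```python
from mmocr.utils.ocr import MMOCR

# Explicit kie_config/kie_ckpt override the defaults implied by kie='SDMGR'
ocr = MMOCR(
    det='PS_CTW',
    recog='SAR',
    kie='SDMGR',
    kie_config='configs/kie/sdmgr/sdmgr_unet16_60e_wildreceipt.py',
    kie_ckpt='work_dirs/sdmgr/latest.pth',  # hypothetical local checkpoint
    device='cuda:0')

results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
```

The command-line equivalent would pass `--kie SDMGR --kie-config <config> --kie-ckpt <checkpoint>` to `mmocr/utils/ocr.py`.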