Add kie image demo and docs. (#374)

* Add kie_image_demo.py

* Add KIE demo docs

* Add brief instructions for KIE image demo

* Add ann file field and return data for KIE image demo

* Follow lint and import rules

* Fix bugs, reuse functions in KIEDataset, and use a new demo pic

* Add config-dir and fix indexing bug in ocr script

* [Feature] Improve ocr.py

1. Add box stitching back to ocr.py
2. Add config_dir which allows users to specify the default config path
3. Warn users when overriding parameters are set
4. Allow users to use customized checkpoint files
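For reference, a minimal sketch of how these options combine through the Python interface (the checkpoint and config paths below are hypothetical):

```python
from mmocr.utils.ocr import MMOCR

# config_dir points the API at a config tree, so the script no longer has
# to be launched from the repository root; det_ckpt loads a user-supplied
# checkpoint instead of the default pretrained one.
ocr = MMOCR(
    det='PS_IC15',
    det_ckpt='work_dirs/psenet/latest.pth',  # hypothetical local checkpoint
    recog='SAR',
    config_dir='/path/to/mmocr/configs/')    # hypothetical config directory

# merge=True enables the box stitching added back in this PR.
results = ocr.readtext('demo/demo_text_ocr.jpg', merge=True)
```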

* Add docs for new ocr.py

* Add docs for merge

* Support KIE in ocr.py

* Merge KIE into ocr.py

* Update docs, remove unsupported non-visual SDMGR

* Update mmocr/apis/inference.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

* Apply suggestions from code review

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

* Fix linting

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
Omkar Manjrekar committed via GitHub on 2021-08-04 11:50:13 +05:30 (commit 68ec3f5519, parent f24be6c614)
8 changed files with 592 additions and 318 deletions

README.md

@@ -26,5 +26,15 @@ Please refer to [End2End Demo](docs/ocr_demo.md) for the tutorial of Text Detect
<img src="resources/demo_ocr_pred.jpg"/><br>
</div>
<br>
<br>
Please refer to [KIE End2End Demo](docs/kie_demo.md) for the tutorial of KIE end-to-end demo.
<div align="center">
  <img src="resources/demo_kie_pred.jpeg"/><br>
</div>
<br>

demo/demo_kie.jpeg: binary file added (120 KiB)

Binary file added (634 KiB)


@@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch
```shell
python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow
```
*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.*
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. Users can customize the directory by specifying the value of `config_dir`.*
- Python interface:
```python
@@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True)
```
---
## Example 4: Text Detection + Recognition + Key Information Extraction
<div align="center">
<img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br>
</div>
<br>
**Instruction:** Perform end-to-end OCR (det + recog) inference first with the PS_CTW detection model and the SAR recognition model, then run KIE inference with the SDMGR model on the OCR result and show the visualization.
- CLI interface:
```shell
python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow
```
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. Users can customize the directory by specifying the value of `config_dir`.*
- Python interface:
```python
from mmocr.utils.ocr import MMOCR
# Load models into memory
ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
# Inference
results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
```
---
## API Arguments
The API has an extensive list of arguments that you can use. The following tables are for the Python interface.
@@ -91,16 +118,21 @@ The API has an extensive list of arguments that you can use. The following table
| Arguments | Type | Default | Description |
| -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `recog` | see [models](#models) | SAR | Text recognition algorithm |
| `kie` [1] | see [models](#models) | None | Key information extraction algorithm |
| `config_dir` | str | configs/ | Path to the config directory where all the config files are located |
| `det_config` | str | None | Path to the custom config file of the selected det model |
| `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model |
| `kie_config` | str | None | Path to the custom config file of the selected kie model |
| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model |
| `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' |
**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`.
[1]: `kie` is only effective when both text detection and recognition models are specified.
**Note:** Users can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`.
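As a minimal illustration of these override rules (the local checkpoint path is hypothetical):

```python
from mmocr.utils.ocr import MMOCR

# kie='SDMGR' selects the default SDMGR config and checkpoint; the explicit
# kie_ckpt below then overrides the default checkpoint with a local file.
ocr = MMOCR(
    det='PS_CTW',
    recog='SAR',
    kie='SDMGR',
    kie_ckpt='work_dirs/sdmgr/epoch_60.pth')  # hypothetical path
```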
### readtext():
@@ -117,8 +149,8 @@ The API has an extensive list of arguments that you can use. The following table
| `details` | bool | False | Whether to include the text box coordinates and confidence values |
| `imshow` | bool | False | Whether to show the result visualization on screen |
| `print_result` | bool | False | Whether to show the result for each image |
| `merge` | bool | False | Whether to merge neighboring boxes [2] |
| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes |
[1]: Make sure that the model is compatible with batch mode.
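For instance, a small sketch of the box-merging options (the values are illustrative):

```python
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='PANet_IC15', recog='SAR')
# Boxes whose horizontal gap is at most 40 pixels are stitched into one
# line; details=True keeps box coordinates and confidence values.
results = ocr.readtext(
    'demo/demo_text_ocr.jpg', merge=True, merge_xdist=40, details=True)
```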
@@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`)
| SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
**Key information extraction:**
| Name | Reference | `batch_mode` support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |
---
## Additional info


mmocr/apis/inference.py

@@ -30,7 +30,11 @@ def disable_text_recog_aug_test(cfg, set_types=None):
    return cfg

def model_inference(model, imgs, batch_mode=False):
def model_inference(model,
                    imgs,
                    ann=None,
                    batch_mode=False,
                    return_data=False):
"""Inference image(s) with the detector. """Inference image(s) with the detector.
Args: Args:
@@ -38,6 +42,8 @@ def model_inference(model, imgs, batch_mode=False):
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.
        batch_mode (bool): If True, use batch mode for inference.
        ann (dict): Annotation info for key information extraction.
        return_data (bool): If True, also return the postprocessed data.
    Returns:
        result (dict): Predicted results.
    """
@@ -75,10 +81,14 @@ def model_inference(model, imgs, batch_mode=False):
        # prepare data
        if is_ndarray:
            # directly add img
            data = dict(img=img)
            data = dict(img=img, ann_info=ann, bbox_fields=[])
        else:
            # add information into dict
            data = dict(img_info=dict(filename=img), img_prefix=None)
            data = dict(
                img_info=dict(filename=img),
                img_prefix=None,
                ann_info=ann,
                bbox_fields=[])

        # build the data pipeline
        data = test_pipeline(data)
@@ -111,6 +121,14 @@ def model_inference(model, imgs, batch_mode=False):
        else:
            data['img'] = data['img'].data

        # for KIE models
        if ann is not None:
            data['relations'] = data['relations'].data[0]
            data['gt_bboxes'] = data['gt_bboxes'].data[0]
            data['texts'] = data['texts'].data[0]
            data['img'] = data['img'][0]
            data['img_metas'] = data['img_metas'][0]
        if next(model.parameters()).is_cuda:
            # scatter to specified GPU
            data = scatter(data, [device])[0]
@@ -125,9 +143,13 @@ def model_inference(model, imgs, batch_mode=False):
        results = model(return_loss=False, rescale=True, **data)

    if not is_batch:
        return results[0]
        if not return_data:
            return results[0]
        return results[0], datas[0]
    else:
        return results
        if not return_data:
            return results
        return results, datas
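A minimal sketch of the extended call, assuming `kie_model`, `img` and `ann_info` are prepared as in the `ocr.py` changes below:

```python
# ann_info carries the boxes/texts parsed by KIEDataset._parse_anno_info;
# return_data=True also hands back the pipeline-processed inputs so the
# caller can visualize them alongside the predictions.
kie_result, data = model_inference(
    kie_model, img, ann=ann_info, return_data=True, batch_mode=False)
```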
def text_model_inference(model, input_sentence):

mmocr/datasets/kie_dataset.py

@@ -1,4 +1,5 @@
import copy
import warnings
from os import path as osp
import numpy as np
@@ -28,22 +29,28 @@ class KIEDataset(BaseDataset):
    """
    def __init__(self,
                 ann_file,
                 loader,
                 dict_file,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        super().__init__(
            ann_file,
            loader,
            pipeline,
            img_prefix=img_prefix,
            test_mode=test_mode)
        assert osp.exists(dict_file)

    def __init__(self,
                 ann_file=None,
                 loader=None,
                 dict_file=None,
                 img_prefix='',
                 pipeline=None,
                 norm=10.,
                 directed=False,
                 test_mode=True,
                 **kwargs):
        if ann_file is None and loader is None:
            warnings.warn(
                'KIEDataset is only initialized as a downstream demo task '
                'of text detection and recognition '
                'without an annotation file.', UserWarning)
        else:
            super().__init__(
                ann_file,
                loader,
                pipeline,
                img_prefix=img_prefix,
                test_mode=test_mode)
            assert osp.exists(dict_file)
        self.norm = norm
        self.directed = directed
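The annotation-free construction this enables is what `det_recog_kie_inference` in `ocr.py` relies on; a sketch, assuming `kie_model` and `annotations` are prepared as in that method:

```python
from mmocr.datasets.kie_dataset import KIEDataset

# With ann_file and loader omitted, KIEDataset only emits the warning and
# skips the BaseDataset initialization; it is then used purely to parse
# annotations produced by the upstream det + recog stages.
kie_dataset = KIEDataset(dict_file=kie_model.cfg.data.test.dict_file)
ann_info = kie_dataset._parse_anno_info(annotations)
```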

mmocr/utils/ocr.py

@@ -1,3 +1,4 @@
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
@@ -5,288 +6,19 @@ from pathlib import Path
import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.utils.box_util import stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
textdet_models = {
'DB_r18': {
'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
},
'DB_r50': {
'config':
'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
},
'DRRG': {
'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
},
'FCE_IC15': {
'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
},
'FCE_CTW_DCNv2': {
'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
},
'MaskRCNN_CTW': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
},
'MaskRCNN_IC15': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
'ckpt':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
},
'MaskRCNN_IC17': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
'ckpt':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
},
'PANet_CTW': {
'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
'ckpt':
'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
},
'PANet_IC15': {
'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
},
'PS_IC15': {
'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
},
'TextSnake': {
'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
}
}
textrecog_models = {
'CRNN': {
'config': 'crnn/crnn_academic_dataset.py',
'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
},
'SAR': {
'config': 'sar/sar_r31_parallel_decoder_academic.py',
'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',
'ckpt': 'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
}
}
# Post processing function for end2end ocr
def det_recog_pp(args, result):
final_results = []
for arr, output, export, det_recog_result in zip(args.arrays, args.output,
args.export, result):
if output or args.imshow:
res_img = det_recog_show_result(
arr, det_recog_result, out_file=output)
if args.imshow:
mmcv.imshow(res_img, 'inference results')
if not args.details:
simple_res = {}
simple_res['filename'] = det_recog_result['filename']
simple_res['text'] = [
x['text'] for x in det_recog_result['result']
]
final_result = simple_res
else:
final_result = det_recog_result
if export:
mmcv.dump(final_result, export, indent=4)
if args.print_result:
print(final_result, end='\n\n')
final_results.append(final_result)
return final_results
# Post processing function for separate det/recog inference
def single_pp(args, result, model):
for arr, output, export, res in zip(args.arrays, args.output, args.export,
result):
if export:
mmcv.dump(res, export, indent=4)
if output or args.imshow:
res_img = model.show_result(arr, res, out_file=output)
if args.imshow:
mmcv.imshow(res_img, 'inference results')
if args.print_result:
print(res, end='\n\n')
return result
# End2end ocr inference pipeline
def det_and_recog_inference(args, det_model, recog_model):
end2end_res = []
# Find bounding boxes in the images (text detection)
det_result = single_inference(det_model, args.arrays, args.batch_mode,
args.det_batch_size)
bboxes_list = [res['boundary_result'] for res in det_result]
# For each bounding box, the image is cropped and sent to the recognition
# model either one by one or all together depending on the batch_mode
for filename, arr, bboxes in zip(args.filenames, args.arrays, bboxes_list):
img_e2e_res = {}
img_e2e_res['filename'] = filename
img_e2e_res['result'] = []
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
box_img = crop_img(arr, box)
if args.batch_mode:
box_imgs.append(box_img)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
img_e2e_res['result'].append(box_res)
if args.batch_mode:
recog_results = single_inference(recog_model, box_imgs, True,
args.recog_batch_size)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, (list, tuple)):
text_score = sum(text_score) / max(1, len(text))
img_e2e_res['result'][i]['text'] = text
img_e2e_res['result'][i]['text_score'] = text_score
if args.merge:
img_e2e_res['result'] = stitch_boxes_into_lines(
img_e2e_res['result'], args.merge_xdist, 0.5)
end2end_res.append(img_e2e_res)
return end2end_res
# Separate det/recog inference pipeline
def single_inference(model, arrays, batch_mode, batch_size):
result = []
if batch_mode:
if batch_size == 0:
result = model_inference(model, arrays, batch_mode=True)
else:
n = batch_size
arr_chunks = [arrays[i:i + n] for i in range(0, len(arrays), n)]
for chunk in arr_chunks:
result.extend(model_inference(model, chunk, batch_mode=True))
else:
for arr in arrays:
result.append(model_inference(model, arr, batch_mode=False))
return result
# Arguments pre-processing function
def args_processing(args):
# Check if the input is a list/tuple that
# contains only np arrays or strings
if isinstance(args.img, (list, tuple)):
img_list = args.img
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
raise AssertionError('Images must be strings or numpy arrays')
# Create a list of the images
if isinstance(args.img, str):
img_path = Path(args.img)
if img_path.is_dir():
img_list = [str(x) for x in img_path.glob('*')]
else:
img_list = [str(img_path)]
elif isinstance(args.img, np.ndarray):
img_list = [args.img]
# Read all image(s) in advance to reduce wasted time
# re-reading the images for visualization output
args.arrays = [mmcv.imread(x) for x in img_list]
# Create a list of filenames (used for output images and result files)
if isinstance(img_list[0], str):
args.filenames = [str(Path(x).stem) for x in img_list]
else:
args.filenames = [str(x) for x in range(len(img_list))]
# If given an output argument, create a list of output image filenames
num_res = len(img_list)
if args.output:
output_path = Path(args.output)
if output_path.is_dir():
args.output = [
str(output_path / f'out_{x}.png') for x in args.filenames
]
else:
args.output = [str(args.output)]
if args.batch_mode:
raise AssertionError(
'Output of multiple images inference must be a directory')
else:
args.output = [None] * num_res
# If given an export argument, create a list of
# result filenames for each image
if args.export:
export_path = Path(args.export)
args.export = [
str(export_path / f'out_{x}.{args.export_format}')
for x in args.filenames
]
else:
args.export = [None] * num_res
return args
# Create an inference pipeline with parsed arguments
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))
# Parse CLI arguments
@@ -333,6 +65,23 @@ def parse_args():
default='',
help='Path to the custom checkpoint file of the selected recog model. '
'It overrides the settings in recog')
parser.add_argument(
'--kie',
type=str,
default='',
help='Pretrained key information extraction algorithm')
parser.add_argument(
'--kie-config',
type=str,
default='',
help='Path to the custom config file of the selected kie model. It '
'overrides the settings in kie')
parser.add_argument(
'--kie-ckpt',
type=str,
default='',
help='Path to the custom checkpoint file of the selected kie model. '
'It overrides the settings in kie')
parser.add_argument(
'--config-dir',
type=str,
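For example, a hypothetical invocation that combines the new flags (the local checkpoint path is made up):

```shell
python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR \
    --kie SDMGR --kie-ckpt work_dirs/sdmgr/epoch_60.pth --print-result
```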
@@ -418,11 +167,138 @@ class MMOCR:
recog='SEG',
recog_config='',
recog_ckpt='',
kie='',
kie_config='',
kie_ckpt='',
config_dir=os.path.join(str(Path.cwd()), 'configs/'),
device='cuda:0',
**kwargs):
textdet_models = {
'DB_r18': {
'config':
'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
},
'DB_r50': {
'config':
'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
},
'DRRG': {
'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
},
'FCE_IC15': {
'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
},
'FCE_CTW_DCNv2': {
'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
},
'MaskRCNN_CTW': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
},
'MaskRCNN_IC15': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
},
'MaskRCNN_IC17': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
},
'PANet_CTW': {
'config':
'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
},
'PANet_IC15': {
'config':
'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
},
'PS_IC15': {
'config':
'psenet/psenet_r50_fpnf_600e_icdar2015.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
},
'TextSnake': {
'config':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
'ckpt':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
}
}
textrecog_models = {
'CRNN': {
'config': 'crnn/crnn_academic_dataset.py',
'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
},
'SAR': {
'config': 'sar/sar_r31_parallel_decoder_academic.py',
'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt':
'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',
'ckpt':
'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
}
}
kie_models = {
'SDMGR': {
'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
'ckpt':
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
}
}
self.td = det
self.tr = recog
self.kie = kie
self.device = device
# Check if the det/recog model choice is valid
@@ -432,7 +308,12 @@ class MMOCR:
elif self.tr and self.tr not in textrecog_models:
raise ValueError(self.tr,
'is not a supported text recognition algorithm')
elif self.kie and self.kie not in kie_models:
raise ValueError(
self.kie, 'is not a supported key information extraction'
' algorithm')
self.detect_model = None
if self.td:
# Build detection model
if not det_config:
@@ -444,9 +325,8 @@ class MMOCR:
self.detect_model = init_detector(
det_config, det_ckpt, device=self.device)
else:
self.detect_model = None
self.recog_model = None
if self.tr:
# Build recognition model
if not recog_config:
@@ -454,13 +334,27 @@ class MMOCR:
config_dir, 'textrecog/',
textrecog_models[self.tr]['config'])
if not recog_ckpt:
recog_ckpt = 'https://download.openmmlab.com/mmocr/'
'textrecog/' + textrecog_models[self.tr]['ckpt']
recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'textrecog/' + textrecog_models[self.tr]['ckpt']
self.recog_model = init_detector(
recog_config, recog_ckpt, device=self.device)
else:
self.recog_model = None
self.kie_model = None
if self.kie:
# Build key information extraction model
if not kie_config:
kie_config = os.path.join(config_dir, 'kie/',
kie_models[self.kie]['config'])
if not kie_ckpt:
kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'kie/' + kie_models[self.kie]['ckpt']
kie_cfg = Config.fromfile(kie_config)
self.kie_model = build_detector(
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
self.kie_model.cfg = kie_cfg
load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
# Attribute check
for model in list(filter(None, [self.recog_model, self.detect_model])):
@@ -490,25 +384,292 @@ class MMOCR:
args = Namespace(**args)
# Input and output arguments processing
args = args_processing(args)
self._args_processing(args)
self.args = args
pp_result = None
# Send args and models to the MMOCR model inference API
# and call post-processing functions for the output
if self.detect_model and self.recog_model:
det_recog_result = det_and_recog_inference(args, self.detect_model,
self.recog_model)
det_recog_result = self.det_recog_kie_inference(
self.detect_model, self.recog_model, kie_model=self.kie_model)
pp_result = det_recog_pp(args, det_recog_result)
pp_result = self.det_recog_pp(det_recog_result)
else:
for model in list(
filter(None, [self.recog_model, self.detect_model])):
result = single_inference(model, args.arrays, args.batch_mode,
args.single_batch_size)
result = self.single_inference(model, args.arrays,
args.batch_mode,
args.single_batch_size)
pp_result = single_pp(args, result, model)
pp_result = self.single_pp(result, model)
return pp_result
# Post processing function for end2end ocr
def det_recog_pp(self, result):
final_results = []
args = self.args
for arr, output, export, det_recog_result in zip(
args.arrays, args.output, args.export, result):
if output or args.imshow:
if self.kie_model:
res_img = det_recog_show_result(arr, det_recog_result)
else:
res_img = det_recog_show_result(
arr, det_recog_result, out_file=output)
if args.imshow and not self.kie_model:
mmcv.imshow(res_img, 'inference results')
if not args.details:
simple_res = {}
simple_res['filename'] = det_recog_result['filename']
simple_res['text'] = [
x['text'] for x in det_recog_result['result']
]
final_result = simple_res
else:
final_result = det_recog_result
if export:
mmcv.dump(final_result, export, indent=4)
if args.print_result:
print(final_result, end='\n\n')
final_results.append(final_result)
return final_results
# Post processing function for separate det/recog inference
def single_pp(self, result, model):
for arr, output, export, res in zip(self.args.arrays, self.args.output,
self.args.export, result):
if export:
mmcv.dump(res, export, indent=4)
if output or self.args.imshow:
res_img = model.show_result(arr, res, out_file=output)
if self.args.imshow:
mmcv.imshow(res_img, 'inference results')
if self.args.print_result:
print(res, end='\n\n')
return result
def generate_kie_labels(self, result, boxes, class_list):
idx_to_cls = {}
if class_list is not None:
for line in list_from_file(class_list):
class_idx, class_label = line.strip().split()
idx_to_cls[class_idx] = class_label
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
labels = []
for i in range(len(boxes)):
pred_label = str(node_pred_label[i])
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = node_pred_score[i]
labels.append((pred_label, pred_score))
return labels
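generate_kie_labels expects `class_list` to be a plain-text file with one `index label` pair per line, as implied by the parsing above; a hypothetical excerpt in the wildreceipt style:

```text
0 Ignore
1 Store_name_value
2 Store_name_key
```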
def visualize_kie_output(self,
model,
data,
result,
out_file=None,
show=False):
"""Visualizes KIE output."""
img_tensor = data['img'].data
img_meta = data['img_metas'].data
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
img = tensor2imgs(img_tensor.unsqueeze(0),
**img_meta['img_norm_cfg'])[0]
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
model.show_result(
img_show, result, gt_bboxes, show=show, out_file=out_file)
# End2end ocr inference pipeline
def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
end2end_res = []
# Find bounding boxes in the images (text detection)
det_result = self.single_inference(det_model, self.args.arrays,
self.args.batch_mode,
self.args.det_batch_size)
bboxes_list = [res['boundary_result'] for res in det_result]
if kie_model:
kie_dataset = KIEDataset(
dict_file=kie_model.cfg.data.test.dict_file)
# For each bounding box, the image is cropped and
# sent to the recognition model either one by one
# or all together depending on the batch_mode
for filename, arr, bboxes, out_file in zip(self.args.filenames,
self.args.arrays,
bboxes_list,
self.args.output):
img_e2e_res = {}
img_e2e_res['filename'] = filename
img_e2e_res['result'] = []
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
box_img = crop_img(arr, box)
if self.args.batch_mode:
box_imgs.append(box_img)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
img_e2e_res['result'].append(box_res)
if self.args.batch_mode:
recog_results = self.single_inference(
recog_model, box_imgs, True, self.args.recog_batch_size)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, (list, tuple)):
text_score = sum(text_score) / max(1, len(text))
img_e2e_res['result'][i]['text'] = text
img_e2e_res['result'][i]['text_score'] = text_score
if self.args.merge:
img_e2e_res['result'] = stitch_boxes_into_lines(
img_e2e_res['result'], self.args.merge_xdist, 0.5)
if kie_model:
annotations = copy.deepcopy(img_e2e_res['result'])
# Convert polygons to axis-aligned boxes, since kie_dataset
# assumes each box is represented by its 4 corner points
for i, ann in enumerate(annotations):
min_x = min(ann['box'][::2])
min_y = min(ann['box'][1::2])
max_x = max(ann['box'][::2])
max_y = max(ann['box'][1::2])
annotations[i]['box'] = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
ann_info = kie_dataset._parse_anno_info(annotations)
kie_result, data = model_inference(
kie_model,
arr,
ann=ann_info,
return_data=True,
batch_mode=self.args.batch_mode)
# visualize KIE results
self.visualize_kie_output(
kie_model,
data,
kie_result,
out_file=out_file,
show=self.args.imshow)
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
labels = self.generate_kie_labels(kie_result, gt_bboxes,
kie_model.class_list)
for i in range(len(gt_bboxes)):
img_e2e_res['result'][i]['label'] = labels[i][0]
img_e2e_res['result'][i]['label_score'] = labels[i][1]
end2end_res.append(img_e2e_res)
return end2end_res
# Separate det/recog inference pipeline
def single_inference(self, model, arrays, batch_mode, batch_size):
result = []
if batch_mode:
if batch_size == 0:
result = model_inference(model, arrays, batch_mode=True)
else:
n = batch_size
arr_chunks = [
arrays[i:i + n] for i in range(0, len(arrays), n)
]
for chunk in arr_chunks:
result.extend(
model_inference(model, chunk, batch_mode=True))
else:
for arr in arrays:
result.append(model_inference(model, arr, batch_mode=False))
return result
# Arguments pre-processing function
def _args_processing(self, args):
# Check if the input is a list/tuple that
# contains only np arrays or strings
if isinstance(args.img, (list, tuple)):
img_list = args.img
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
raise AssertionError('Images must be strings or numpy arrays')
# Create a list of the images
if isinstance(args.img, str):
img_path = Path(args.img)
if img_path.is_dir():
img_list = [str(x) for x in img_path.glob('*')]
else:
img_list = [str(img_path)]
elif isinstance(args.img, np.ndarray):
img_list = [args.img]
# Read all image(s) in advance to reduce wasted time
# re-reading the images for visualization output
args.arrays = [mmcv.imread(x) for x in img_list]
# Create a list of filenames (used for output images and result files)
if isinstance(img_list[0], str):
args.filenames = [str(Path(x).stem) for x in img_list]
else:
args.filenames = [str(x) for x in range(len(img_list))]
# If given an output argument, create a list of output image filenames
num_res = len(img_list)
if args.output:
output_path = Path(args.output)
if output_path.is_dir():
args.output = [
str(output_path / f'out_{x}.png') for x in args.filenames
]
else:
args.output = [str(args.output)]
if args.batch_mode:
raise AssertionError('Output of multiple images inference'
' must be a directory')
else:
args.output = [None] * num_res
# If given an export argument, create a list of
# result filenames for each image
if args.export:
export_path = Path(args.export)
args.export = [
str(export_path / f'out_{x}.{args.export_format}')
for x in args.filenames
]
else:
args.export = [None] * num_res
return args
# Create an inference pipeline with parsed arguments
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))
if __name__ == '__main__':
main()