Add kie image demo and docs. (#374)

* Add kie_image_demo.py

* Add kie demo docs

* Add brief instructions for kie image demo

* Add ann file field and return data for kie image demo

* Follow lint and import rules

* Fix bugs, reuse functions in KIEDataset, and use a new demo pic

* Add config-dir and fix indexing bug in ocr script

* [Feature] Improve ocr.py

1. Add box stitching back to ocr.py
2. Add config_dir which allows users to specify the default config path
3. Warn users when overriding parameters are set
4. Allow users to use customized checkpoint files

* Add docs for new ocr.py

* Add docs for merge

* Support kie in ocr.py

* Merged kie to ocr.py

* update docs, remove unsupported unvisual sdmgr

* Update mmocr/apis/inference.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

* Apply suggestions from code review

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>

* fix linting

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
pull/395/head^2
Omkar Manjrekar 2021-08-04 11:50:13 +05:30 committed by GitHub
parent f24be6c614
commit 68ec3f5519
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 592 additions and 318 deletions

View File

@ -26,5 +26,15 @@ Please refer to [End2End Demo](docs/ocr_demo.md) for the tutorial of Text Detect
<img src="resources/demo_ocr_pred.jpg"/><br>
</div>
<br>
<br>
Please refer to [KIE End2End Demo](docs/kie_demo.md) for the tutorial of KIE end-to-end demo.
<div align="center">
<img src="resources/demo_kie_pred.jpeg"/><br>
</div>
<br>

BIN
demo/demo_kie.jpeg 100755

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 634 KiB

View File

@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch
```shell
python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow
```
*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.*
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. *
- Python interface:
```python
@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True)
```
---
## Example 4: Text Detection + Recognition + Key Information Extraction
<div align="center">
<img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br>
</div>
<br>
**Instruction:** Perform end-to-end ocr (det + recog) inference first with PS_CTW detection model and SAR recognition model, then run KIE inference with SDMGR model on the ocr result and show the visualization.
- CL interface:
```shell
python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow
```
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. *
- Python interface:
```python
from mmocr.utils.ocr import MMOCR
# Load models into memory
ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
# Inference
results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
```
---
## API Arguments
The API has an extensive list of arguments that you can use. The following tables are for the python interface.
@ -91,16 +118,21 @@ The API has an extensive list of arguments that you can use. The following table
| Arguments | Type | Default | Description |
| -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `recog` | see [models](#models) | SAR | Text recognition algorithm |
| `kie` [1] | see [models](#models) | None | Key information extraction algorithm |
| `config_dir` | str | configs/ | Path to the config directory where all the config files are located |
| `det_config` | str | None | Path to the custom config file of the selected det model |
| `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model |
| `kie_config` | str | None | Path to the custom config file of the selected kie model |
| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model |
| `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' |
**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`.
[1]: `kie` is only effective when both text detection and recognition models are specified.
**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`.
### readtext():
@ -117,8 +149,8 @@ The API has an extensive list of arguments that you can use. The following table
| `details` | bool | False | Whether include the text boxes coordinates and confidence values |
| `imshow` | bool | False | Whether to show the result visualization on screen |
| `print_result` | bool | False | Whether to show the result for each image |
| `merge` | bool | False | Whether to merge neighboring boxes [2] |
| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes |
| `merge` | bool | False | Whether to merge neighboring boxes [2] |
| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes |
[1]: Make sure that the model is compatible with batch mode.
@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`)
| SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
**Key information extraction:**
| Name | Reference | `batch_mode` support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |
---
## Additional info

View File

@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch
```shell
python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow
```
*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.*
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. *
- Python interface:
```python
@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True)
```
---
## Example 4: Text Detection + Recognition + Key Information Extraction
<div align="center">
<img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br>
</div>
<br>
**Instruction:** Perform end-to-end ocr (det + recog) inference first with PS_CTW detection model and SAR recognition model, then run KIE inference with SDMGR model on the ocr result and show the visualization.
- CL interface:
```shell
python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow
```
*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. User can customize the directory by specifying the value of `config_dir`. *
- Python interface:
```python
from mmocr.utils.ocr import MMOCR
# Load models into memory
ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
# Inference
results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
```
---
## API Arguments
The API has an extensive list of arguments that you can use. The following tables are for the python interface.
@ -91,16 +118,21 @@ The API has an extensive list of arguments that you can use. The following table
| Arguments | Type | Default | Description |
| -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `recog` | see [models](#models) | SAR | Text recognition algorithm |
| `kie` [1] | see [models](#models) | None | Key information extraction algorithm |
| `config_dir` | str | configs/ | Path to the config directory where all the config files are located |
| `det_config` | str | None | Path to the custom config file of the selected det model |
| `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model |
| `recog_config` | str | None | Path to the custom config file of the selected recog model |
| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model |
| `kie_config` | str | None | Path to the custom config file of the selected kie model |
| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model |
| `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' |
**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`.
[1]: `kie` is only effective when both text detection and recognition models are specified.
**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`.
### readtext():
@ -117,8 +149,8 @@ The API has an extensive list of arguments that you can use. The following table
| `details` | bool | False | Whether include the text boxes coordinates and confidence values |
| `imshow` | bool | False | Whether to show the result visualization on screen |
| `print_result` | bool | False | Whether to show the result for each image |
| `merge` | bool | False | Whether to merge neighboring boxes [2] |
| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes |
| `merge` | bool | False | Whether to merge neighboring boxes [2] |
| `merge_xdist` | float | 20 | The maximum x-axis distance to merge boxes |
[1]: Make sure that the model is compatible with batch mode.
@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`)
| SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
**Key information extraction:**
| Name | Reference | `batch_mode` support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |
---
## Additional info

View File

@ -30,7 +30,11 @@ def disable_text_recog_aug_test(cfg, set_types=None):
return cfg
def model_inference(model, imgs, batch_mode=False):
def model_inference(model,
imgs,
ann=None,
batch_mode=False,
return_data=False):
"""Inference image(s) with the detector.
Args:
@ -38,6 +42,8 @@ def model_inference(model, imgs, batch_mode=False):
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
batch_mode (bool): If True, use batch mode for inference.
ann (dict): Annotation info for key information extraction.
return_data: Return postprocessed data.
Returns:
result (dict): Predicted results.
"""
@ -75,10 +81,14 @@ def model_inference(model, imgs, batch_mode=False):
# prepare data
if is_ndarray:
# directly add img
data = dict(img=img)
data = dict(img=img, ann_info=ann, bbox_fields=[])
else:
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
data = dict(
img_info=dict(filename=img),
img_prefix=None,
ann_info=ann,
bbox_fields=[])
# build the data pipeline
data = test_pipeline(data)
@ -111,6 +121,14 @@ def model_inference(model, imgs, batch_mode=False):
else:
data['img'] = data['img'].data
# for KIE models
if ann is not None:
data['relations'] = data['relations'].data[0]
data['gt_bboxes'] = data['gt_bboxes'].data[0]
data['texts'] = data['texts'].data[0]
data['img'] = data['img'][0]
data['img_metas'] = data['img_metas'][0]
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
@ -125,9 +143,13 @@ def model_inference(model, imgs, batch_mode=False):
results = model(return_loss=False, rescale=True, **data)
if not is_batch:
return results[0]
if not return_data:
return results[0]
return results[0], datas[0]
else:
return results
if not return_data:
return results
return results, datas
def text_model_inference(model, input_sentence):

View File

@ -1,4 +1,5 @@
import copy
import warnings
from os import path as osp
import numpy as np
@ -28,22 +29,28 @@ class KIEDataset(BaseDataset):
"""
def __init__(self,
ann_file,
loader,
dict_file,
ann_file=None,
loader=None,
dict_file=None,
img_prefix='',
pipeline=None,
norm=10.,
directed=False,
test_mode=True,
**kwargs):
super().__init__(
ann_file,
loader,
pipeline,
img_prefix=img_prefix,
test_mode=test_mode)
assert osp.exists(dict_file)
if ann_file is None and loader is None:
warnings.warn(
'KIEDataset is only initialized as a downstream demo task '
'of text detection and recognition '
'without an annotation file.', UserWarning)
else:
super().__init__(
ann_file,
loader,
pipeline,
img_prefix=img_prefix,
test_mode=test_mode)
assert osp.exists(dict_file)
self.norm = norm
self.directed = directed

View File

@ -1,3 +1,4 @@
import copy
import os
import warnings
from argparse import ArgumentParser, Namespace
@ -5,288 +6,19 @@ from pathlib import Path
import mmcv
import numpy as np
import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.utils.box_util import stitch_boxes_into_lines
textdet_models = {
'DB_r18': {
'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
},
'DB_r50': {
'config':
'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
},
'DRRG': {
'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
},
'FCE_IC15': {
'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
},
'FCE_CTW_DCNv2': {
'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
},
'MaskRCNN_CTW': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
},
'MaskRCNN_IC15': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
'ckpt':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
},
'MaskRCNN_IC17': {
'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
'ckpt':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
},
'PANet_CTW': {
'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
'ckpt':
'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
},
'PANet_IC15': {
'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
},
'PS_IC15': {
'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
},
'TextSnake': {
'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
}
}
textrecog_models = {
'CRNN': {
'config': 'crnn/crnn_academic_dataset.py',
'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
},
'SAR': {
'config': 'sar/sar_r31_parallel_decoder_academic.py',
'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',
'ckpt': 'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
}
}
# Post processing function for end2end ocr
def det_recog_pp(args, result):
final_results = []
for arr, output, export, det_recog_result in zip(args.arrays, args.output,
args.export, result):
if output or args.imshow:
res_img = det_recog_show_result(
arr, det_recog_result, out_file=output)
if args.imshow:
mmcv.imshow(res_img, 'inference results')
if not args.details:
simple_res = {}
simple_res['filename'] = det_recog_result['filename']
simple_res['text'] = [
x['text'] for x in det_recog_result['result']
]
final_result = simple_res
else:
final_result = det_recog_result
if export:
mmcv.dump(final_result, export, indent=4)
if args.print_result:
print(final_result, end='\n\n')
final_results.append(final_result)
return final_results
# Post processing function for separate det/recog inference
def single_pp(args, result, model):
for arr, output, export, res in zip(args.arrays, args.output, args.export,
result):
if export:
mmcv.dump(res, export, indent=4)
if output or args.imshow:
res_img = model.show_result(arr, res, out_file=output)
if args.imshow:
mmcv.imshow(res_img, 'inference results')
if args.print_result:
print(res, end='\n\n')
return result
# End2end ocr inference pipeline
def det_and_recog_inference(args, det_model, recog_model):
end2end_res = []
# Find bounding boxes in the images (text detection)
det_result = single_inference(det_model, args.arrays, args.batch_mode,
args.det_batch_size)
bboxes_list = [res['boundary_result'] for res in det_result]
# For each bounding box, the image is cropped and sent to the recognition
# model either one by one or all together depending on the batch_mode
for filename, arr, bboxes in zip(args.filenames, args.arrays, bboxes_list):
img_e2e_res = {}
img_e2e_res['filename'] = filename
img_e2e_res['result'] = []
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
box_img = crop_img(arr, box)
if args.batch_mode:
box_imgs.append(box_img)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
img_e2e_res['result'].append(box_res)
if args.batch_mode:
recog_results = single_inference(recog_model, box_imgs, True,
args.recog_batch_size)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, (list, tuple)):
text_score = sum(text_score) / max(1, len(text))
img_e2e_res['result'][i]['text'] = text
img_e2e_res['result'][i]['text_score'] = text_score
if args.merge:
img_e2e_res['result'] = stitch_boxes_into_lines(
img_e2e_res['result'], args.merge_xdist, 0.5)
end2end_res.append(img_e2e_res)
return end2end_res
# Separate det/recog inference pipeline
def single_inference(model, arrays, batch_mode, batch_size):
result = []
if batch_mode:
if batch_size == 0:
result = model_inference(model, arrays, batch_mode=True)
else:
n = batch_size
arr_chunks = [arrays[i:i + n] for i in range(0, len(arrays), n)]
for chunk in arr_chunks:
result.extend(model_inference(model, chunk, batch_mode=True))
else:
for arr in arrays:
result.append(model_inference(model, arr, batch_mode=False))
return result
# Arguments pre-processing function
def args_processing(args):
# Check if the input is a list/tuple that
# contains only np arrays or strings
if isinstance(args.img, (list, tuple)):
img_list = args.img
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
raise AssertionError('Images must be strings or numpy arrays')
# Create a list of the images
if isinstance(args.img, str):
img_path = Path(args.img)
if img_path.is_dir():
img_list = [str(x) for x in img_path.glob('*')]
else:
img_list = [str(img_path)]
elif isinstance(args.img, np.ndarray):
img_list = [args.img]
# Read all image(s) in advance to reduce wasted time
# re-reading the images for vizualisation output
args.arrays = [mmcv.imread(x) for x in img_list]
# Create a list of filenames (used for output images and result files)
if isinstance(img_list[0], str):
args.filenames = [str(Path(x).stem) for x in img_list]
else:
args.filenames = [str(x) for x in range(len(img_list))]
# If given an output argument, create a list of output image filenames
num_res = len(img_list)
if args.output:
output_path = Path(args.output)
if output_path.is_dir():
args.output = [
str(output_path / f'out_{x}.png') for x in args.filenames
]
else:
args.output = [str(args.output)]
if args.batch_mode:
raise AssertionError(
'Output of multiple images inference must be a directory')
else:
args.output = [None] * num_res
# If given an export argument, create a list of
# result filenames for each image
if args.export:
export_path = Path(args.export)
args.export = [
str(export_path / f'out_{x}.{args.export_format}')
for x in args.filenames
]
else:
args.export = [None] * num_res
return args
# Create an inference pipeline with parsed arguments
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))
from mmocr.utils.fileio import list_from_file
# Parse CLI arguments
@ -333,6 +65,23 @@ def parse_args():
default='',
help='Path to the custom checkpoint file of the selected recog model. '
'It overrides the settings in recog')
parser.add_argument(
'--kie',
type=str,
default='',
help='Pretrained key information extraction algorithm')
parser.add_argument(
'--kie-config',
type=str,
default='',
help='Path to the custom config file of the selected kie model. It'
'overrides the settings in kie')
parser.add_argument(
'--kie-ckpt',
type=str,
default='',
help='Path to the custom checkpoint file of the selected kie model. '
'It overrides the settings in kie')
parser.add_argument(
'--config-dir',
type=str,
@ -418,11 +167,138 @@ class MMOCR:
recog='SEG',
recog_config='',
recog_ckpt='',
kie='',
kie_config='',
kie_ckpt='',
config_dir=os.path.join(str(Path.cwd()), 'configs/'),
device='cuda:0',
**kwargs):
textdet_models = {
'DB_r18': {
'config':
'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
},
'DB_r50': {
'config':
'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
'ckpt':
'dbnet/'
'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
},
'DRRG': {
'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
},
'FCE_IC15': {
'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
},
'FCE_CTW_DCNv2': {
'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
},
'MaskRCNN_CTW': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
},
'MaskRCNN_IC15': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
},
'MaskRCNN_IC17': {
'config':
'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
'ckpt':
'maskrcnn/'
'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
},
'PANet_CTW': {
'config':
'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
},
'PANet_IC15': {
'config':
'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
'ckpt':
'panet/'
'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
},
'PS_CTW': {
'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
},
'PS_IC15': {
'config':
'psenet/psenet_r50_fpnf_600e_icdar2015.py',
'ckpt':
'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
},
'TextSnake': {
'config':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
'ckpt':
'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
}
}
textrecog_models = {
'CRNN': {
'config': 'crnn/crnn_academic_dataset.py',
'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
},
'SAR': {
'config': 'sar/sar_r31_parallel_decoder_academic.py',
'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
},
'NRTR_1/16-1/8': {
'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
},
'NRTR_1/8-1/4': {
'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
'ckpt':
'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
},
'RobustScanner': {
'config': 'robust_scanner/robustscanner_r31_academic.py',
'ckpt':
'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
},
'SEG': {
'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
},
'CRNN_TPS': {
'config': 'tps/crnn_tps_academic_dataset.py',
'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
}
}
kie_models = {
'SDMGR': {
'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
'ckpt':
'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
}
}
self.td = det
self.tr = recog
self.kie = kie
self.device = device
# Check if the det/recog model choice is valid
@ -432,7 +308,12 @@ class MMOCR:
elif self.tr and self.tr not in textrecog_models:
raise ValueError(self.tr,
'is not a supported text recognition algorithm')
elif self.kie and self.kie not in kie_models:
raise ValueError(
self.kie, 'is not a supported key information extraction'
' algorithm')
self.detect_model = None
if self.td:
# Build detection model
if not det_config:
@ -444,9 +325,8 @@ class MMOCR:
self.detect_model = init_detector(
det_config, det_ckpt, device=self.device)
else:
self.detect_model = None
self.recog_model = None
if self.tr:
# Build recognition model
if not recog_config:
@ -454,13 +334,27 @@ class MMOCR:
config_dir, 'textrecog/',
textrecog_models[self.tr]['config'])
if not recog_ckpt:
recog_ckpt = 'https://download.openmmlab.com/mmocr/'
'textrecog/' + textrecog_models[self.tr]['ckpt']
recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'textrecog/' + textrecog_models[self.tr]['ckpt']
self.recog_model = init_detector(
recog_config, recog_ckpt, device=self.device)
else:
self.recog_model = None
self.kie_model = None
if self.kie:
# Build key information extraction model
if not kie_config:
kie_config = os.path.join(config_dir, 'kie/',
kie_models[self.kie]['config'])
if not kie_ckpt:
kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
'kie/' + kie_models[self.kie]['ckpt']
kie_cfg = Config.fromfile(kie_config)
self.kie_model = build_detector(
kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
self.kie_model.cfg = kie_cfg
load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)
# Attribute check
for model in list(filter(None, [self.recog_model, self.detect_model])):
@ -490,25 +384,292 @@ class MMOCR:
args = Namespace(**args)
# Input and output arguments processing
args = args_processing(args)
self._args_processing(args)
self.args = args
pp_result = None
# Send args and models to the MMOCR model inference API
# and call post-processing functions for the output
if self.detect_model and self.recog_model:
det_recog_result = det_and_recog_inference(args, self.detect_model,
self.recog_model)
pp_result = det_recog_pp(args, det_recog_result)
det_recog_result = self.det_recog_kie_inference(
self.detect_model, self.recog_model, kie_model=self.kie_model)
pp_result = self.det_recog_pp(det_recog_result)
else:
for model in list(
filter(None, [self.recog_model, self.detect_model])):
result = single_inference(model, args.arrays, args.batch_mode,
args.single_batch_size)
pp_result = single_pp(args, result, model)
result = self.single_inference(model, args.arrays,
args.batch_mode,
args.single_batch_size)
pp_result = self.single_pp(args, result, model)
return pp_result
# Post processing function for end2end ocr
def det_recog_pp(self, result):
final_results = []
args = self.args
for arr, output, export, det_recog_result in zip(
args.arrays, args.output, args.export, result):
if output or args.imshow:
if self.kie_model:
res_img = det_recog_show_result(arr, det_recog_result)
else:
res_img = det_recog_show_result(
arr, det_recog_result, out_file=output)
if args.imshow and not self.kie_model:
mmcv.imshow(res_img, 'inference results')
if not args.details:
simple_res = {}
simple_res['filename'] = det_recog_result['filename']
simple_res['text'] = [
x['text'] for x in det_recog_result['result']
]
final_result = simple_res
else:
final_result = det_recog_result
if export:
mmcv.dump(final_result, export, indent=4)
if args.print_result:
print(final_result, end='\n\n')
final_results.append(final_result)
return final_results
# Post processing function for separate det/recog inference
def single_pp(self, result, model):
for arr, output, export, res in zip(self.args.arrays, self.args.output,
self.args.export, result):
if export:
mmcv.dump(res, export, indent=4)
if output or self.args.imshow:
res_img = model.show_result(arr, res, out_file=output)
if self.args.imshow:
mmcv.imshow(res_img, 'inference results')
if self.args.print_result:
print(res, end='\n\n')
return result
def generate_kie_labels(self, result, boxes, class_list):
idx_to_cls = {}
if class_list is not None:
for line in list_from_file(class_list):
class_idx, class_label = line.strip().split()
idx_to_cls[class_idx] = class_label
max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
labels = []
for i in range(len(boxes)):
pred_label = str(node_pred_label[i])
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = node_pred_score[i]
labels.append((pred_label, pred_score))
return labels
def visualize_kie_output(self,
model,
data,
result,
out_file=None,
show=False):
"""Visualizes KIE output."""
img_tensor = data['img'].data
img_meta = data['img_metas'].data
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
img = tensor2imgs(img_tensor.unsqueeze(0),
**img_meta['img_norm_cfg'])[0]
h, w, _ = img_meta['img_shape']
img_show = img[:h, :w, :]
model.show_result(
img_show, result, gt_bboxes, show=show, out_file=out_file)
# End2end ocr inference pipeline
def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
end2end_res = []
# Find bounding boxes in the images (text detection)
det_result = self.single_inference(det_model, self.args.arrays,
self.args.batch_mode,
self.args.det_batch_size)
bboxes_list = [res['boundary_result'] for res in det_result]
if kie_model:
kie_dataset = KIEDataset(
dict_file=kie_model.cfg.data.test.dict_file)
# For each bounding box, the image is cropped and
# sent to the recognition model either one by one
# or all together depending on the batch_mode
for filename, arr, bboxes, out_file in zip(self.args.filenames,
self.args.arrays,
bboxes_list,
self.args.output):
img_e2e_res = {}
img_e2e_res['filename'] = filename
img_e2e_res['result'] = []
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
box_img = crop_img(arr, box)
if self.args.batch_mode:
box_imgs.append(box_img)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
img_e2e_res['result'].append(box_res)
if self.args.batch_mode:
recog_results = self.single_inference(
recog_model, box_imgs, True, self.args.recog_batch_size)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, (list, tuple)):
text_score = sum(text_score) / max(1, len(text))
img_e2e_res['result'][i]['text'] = text
img_e2e_res['result'][i]['text_score'] = text_score
if self.args.merge:
img_e2e_res['result'] = stitch_boxes_into_lines(
img_e2e_res['result'], self.args.merge_xdist, 0.5)
if kie_model:
annotations = copy.deepcopy(img_e2e_res['result'])
# Customized for kie_dataset, which
# assumes that boxes are represented by only 4 points
for i, ann in enumerate(annotations):
min_x = min(ann['box'][::2])
min_y = min(ann['box'][1::2])
max_x = max(ann['box'][::2])
max_y = max(ann['box'][1::2])
annotations[i]['box'] = [
min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
]
ann_info = kie_dataset._parse_anno_info(annotations)
kie_result, data = model_inference(
kie_model,
arr,
ann=ann_info,
return_data=True,
batch_mode=self.args.batch_mode)
# visualize KIE results
self.visualize_kie_output(
kie_model,
data,
kie_result,
out_file=out_file,
show=self.args.imshow)
gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
labels = self.generate_kie_labels(kie_result, gt_bboxes,
kie_model.class_list)
for i in range(len(gt_bboxes)):
img_e2e_res['result'][i]['label'] = labels[i][0]
img_e2e_res['result'][i]['label_score'] = labels[i][1]
end2end_res.append(img_e2e_res)
return end2end_res
# Separate det/recog inference pipeline
def single_inference(self, model, arrays, batch_mode, batch_size):
result = []
if batch_mode:
if batch_size == 0:
result = model_inference(model, arrays, batch_mode=True)
else:
n = batch_size
arr_chunks = [
arrays[i:i + n] for i in range(0, len(arrays), n)
]
for chunk in arr_chunks:
result.extend(
model_inference(model, chunk, batch_mode=True))
else:
for arr in arrays:
result.append(model_inference(model, arr, batch_mode=False))
return result
# Arguments pre-processing function
def _args_processing(self, args):
# Check if the input is a list/tuple that
# contains only np arrays or strings
if isinstance(args.img, (list, tuple)):
img_list = args.img
if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
raise AssertionError('Images must be strings or numpy arrays')
# Create a list of the images
if isinstance(args.img, str):
img_path = Path(args.img)
if img_path.is_dir():
img_list = [str(x) for x in img_path.glob('*')]
else:
img_list = [str(img_path)]
elif isinstance(args.img, np.ndarray):
img_list = [args.img]
# Read all image(s) in advance to reduce wasted time
# re-reading the images for vizualisation output
args.arrays = [mmcv.imread(x) for x in img_list]
# Create a list of filenames (used for output images and result files)
if isinstance(img_list[0], str):
args.filenames = [str(Path(x).stem) for x in img_list]
else:
args.filenames = [str(x) for x in range(len(img_list))]
# If given an output argument, create a list of output image filenames
num_res = len(img_list)
if args.output:
output_path = Path(args.output)
if output_path.is_dir():
args.output = [
str(output_path / f'out_{x}.png') for x in args.filenames
]
else:
args.output = [str(args.output)]
if args.batch_mode:
raise AssertionError('Output of multiple images inference'
' must be a directory')
else:
args.output = [None] * num_res
# If given an export argument, create a list of
# result filenames for each image
if args.export:
export_path = Path(args.export)
args.export = [
str(export_path / f'out_{x}.{args.export_format}')
for x in args.filenames
]
else:
args.export = [None] * num_res
return args
# Create an inference pipeline with parsed arguments
def main():
args = parse_args()
ocr = MMOCR(**vars(args))
ocr.readtext(**vars(args))
if __name__ == '__main__':
main()