mirror of https://github.com/open-mmlab/mmocr.git
Add kie image demo and docs. (#374)
* Add kie_image_demo.py
* Add kie demo docs
* Add brief instructions for kie image demo
* Add ann file field and return data for kie image demo
* Follow lint and import rules
* Fix bugs, reuse functions in KIEDataset, and use a new demo pic
* Add config-dir and fix indexing bug in ocr script
* [Feature] Improve ocr.py
  1. Add box stitching back to ocr.py
  2. Add config_dir which allows users to specify the default config path
  3. Warn users when overriding parameters are set
  4. Allow users to use customized checkpoint files
* Add docs for new ocr.py
* Add docs for merge
* Support kie in ocr.py
* Merge kie into ocr.py
* Update docs; remove unsupported non-visual SDMGR
* Update mmocr/apis/inference.py
* Apply suggestions from code review
* Fix linting

Co-authored-by: gaotongxiao <gaotongxiao@gmail.com>
Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
parent f24be6c614
commit 68ec3f5519
@@ -26,5 +26,15 @@ Please refer to [End2End Demo](docs/ocr_demo.md) for the tutorial of Text Detect
 <img src="resources/demo_ocr_pred.jpg"/><br>

 </div>
 <br>
+<br>

+Please refer to [KIE End2End Demo](docs/kie_demo.md) for the tutorial of KIE end-to-end demo.
+
+<div align="center">
+<img src="resources/demo_kie_pred.jpeg"/><br>
+
+</div>
+
+
+<br>
Binary file not shown. After: Width | Height | Size: 120 KiB
Binary file not shown. After: Width | Height | Size: 634 KiB
45 docs/demo.md
@@ -70,7 +70,7 @@ results = ocr.readtext(%INPUT_FOLDER_PATH%, output = %OUTPUT_FOLDER_PATH%, batch
 ```shell
 python mmocr/utils/ocr.py demo/demo_text_ocr.jpg --print-result --imshow
 ```
-*Note: When calling the script from the command line, the `configs` folder must be in the current working directory.*
+*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. Users can customize the directory by specifying the value of `config_dir`.*

 - Python interface:
 ```python
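As a quick illustration of the note in the hunk above, the same override is available from the Python interface through the `config_dir` argument documented later in this file (a minimal sketch; the directory path is a hypothetical example):

```python
from mmocr.utils.ocr import MMOCR

# Point MMOCR at a config directory outside the current working directory
ocr = MMOCR(det='PANet_IC15', recog='SAR',
            config_dir='/path/to/mmocr/configs/')
results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True)
```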
@@ -84,6 +84,33 @@ results = ocr.readtext('demo/demo_text_ocr.jpg', print_result=True, imshow=True)
 ```
+---
+
+## Example 4: Text Detection + Recognition + Key Information Extraction
+
+<div align="center">
+<img src="https://raw.githubusercontent.com/open-mmlab/mmocr/main/demo/resources/demo_kie_pred.png"/><br>
+</div>
+<br>
+
+**Instruction:** First perform end-to-end OCR (det + recog) inference with the PS_CTW detection model and the SAR recognition model, then run KIE inference with the SDMGR model on the OCR result and show the visualization.
+
+- CL interface:
+```shell
+python mmocr/utils/ocr.py demo/demo_kie.jpeg --det PS_CTW --recog SAR --kie SDMGR --print-result --imshow
+```
+*Note: When calling the script from the command line, the script assumes configs are saved in the `configs/` folder. Users can customize the directory by specifying the value of `config_dir`.*
+
+- Python interface:
+```python
+from mmocr.utils.ocr import MMOCR
+
+# Load models into memory
+ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
+
+# Inference
+results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
+```
 ---

 ## API Arguments
 The API has an extensive list of arguments that you can use. The following tables are for the python interface.

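The KIE pipeline returns the same per-box result structure as the plain det + recog pipeline, extended with a label per text region. A minimal sketch of consuming that output, based on the result dict assembled in `ocr.py` later in this diff (it assumes the `details` flag exposed by `readtext()`; `label`/`label_score` are only present when a kie model is loaded):

```python
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
results = ocr.readtext('demo/demo_kie.jpeg', details=True)

for img_res in results:
    print(img_res['filename'])
    for region in img_res['result']:
        # Each region carries its polygon, recognized text, and KIE label
        print(region['box'], region['text'], region['label'],
              region['label_score'])
```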
@@ -93,14 +120,19 @@ The API has an extensive list of arguments that you can use. The following table
 | -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
 | `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
 | `recog` | see [models](#models) | SAR | Text recognition algorithm |
+| `kie` [1] | see [models](#models) | None | Key information extraction algorithm |
+| `config_dir` | str | configs/ | Path to the config directory where all the config files are located |
 | `det_config` | str | None | Path to the custom config file of the selected det model |
 | `det_ckpt` | str | None | Path to the custom checkpoint file of the selected det model |
-| `recog_config` | str | None | Path to the custom config file of the selected recog model model |
-| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model model |
+| `recog_config` | str | None | Path to the custom config file of the selected recog model |
+| `recog_ckpt` | str | None | Path to the custom checkpoint file of the selected recog model |
+| `kie_config` | str | None | Path to the custom config file of the selected kie model |
+| `kie_ckpt` | str | None | Path to the custom checkpoint file of the selected kie model |
 | `device` | str | cuda:0 | Device used for inference: 'cuda:0' or 'cpu' |

-**Note:** User can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to setting `*_config` and `*_ckpt` as default values. However, manually specifying `*_config` and `*_ckpt` will always override default values set by `det` and/or `recog`.
+[1]: `kie` is only effective when both text detection and recognition models are specified.
+
+**Note:** Users can use default pretrained models by specifying `det` and/or `recog`, which is equivalent to specifying their corresponding `*_config` and `*_ckpt`. However, manually specifying `*_config` and `*_ckpt` will always override the values set by `det` and/or `recog`. Similar rules also apply to `kie`, `kie_config` and `kie_ckpt`.

 ### readtext():

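A minimal sketch of the override rule described in the note above (the config and checkpoint paths are hypothetical examples):

```python
from mmocr.utils.ocr import MMOCR

# det='PANet_IC15' alone would load the default config/checkpoint pair;
# explicit det_config/det_ckpt take precedence over those defaults
ocr = MMOCR(
    det='PANet_IC15',
    det_config='/path/to/custom_panet_config.py',
    det_ckpt='/path/to/custom_panet_checkpoint.pth',
    recog='SAR')
```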
@@ -165,6 +197,11 @@ means that `batch_mode` and `print_result` are set to `True`)
 | SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
 | CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |

+**Key information extraction:**
+
+| Name | Reference | `batch_mode` support |
+| ----- | :-------: | :------------------: |
+| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |
 ---

 ## Additional info
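Since SAR and SDMGR both support `batch_mode` per the tables above, the whole KIE pipeline can run batched. A minimal sketch (the folder path and batch size are illustrative; `recog_batch_size` matches the argument name used by `ocr.py` in this diff):

```python
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
# Crops from each detected box are recognized in chunks of 8
results = ocr.readtext('demo/', batch_mode=True, recog_batch_size=8)
```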
mmocr/apis/inference.py

@@ -30,7 +30,11 @@ def disable_text_recog_aug_test(cfg, set_types=None):
     return cfg


-def model_inference(model, imgs, batch_mode=False):
+def model_inference(model,
+                    imgs,
+                    ann=None,
+                    batch_mode=False,
+                    return_data=False):
     """Inference image(s) with the detector.

     Args:
@@ -38,6 +42,8 @@ def model_inference(model, imgs, batch_mode=False):
         imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
             Either image files or loaded images.
         batch_mode (bool): If True, use batch mode for inference.
+        ann (dict): Annotation info for key information extraction.
+        return_data (bool): If True, also return the postprocessed data.
     Returns:
         result (dict): Predicted results.
     """
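The extended signature is exercised by the KIE branch of `ocr.py` later in this diff. A minimal sketch of such a call (`kie_model`, `arr` and `ann_info` are assumed to be prepared as in `ocr.py`):

```python
# ann_info: parsed box/text annotations from KIEDataset._parse_anno_info
# return_data=True also yields the pipeline-processed data for visualization
kie_result, data = model_inference(
    kie_model, arr, ann=ann_info, return_data=True, batch_mode=False)
```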
@@ -75,10 +81,14 @@
     # prepare data
     if is_ndarray:
         # directly add img
-        data = dict(img=img)
+        data = dict(img=img, ann_info=ann, bbox_fields=[])
     else:
         # add information into dict
-        data = dict(img_info=dict(filename=img), img_prefix=None)
+        data = dict(
+            img_info=dict(filename=img),
+            img_prefix=None,
+            ann_info=ann,
+            bbox_fields=[])

     # build the data pipeline
     data = test_pipeline(data)
@@ -111,6 +121,14 @@
     else:
         data['img'] = data['img'].data

+    # for KIE models
+    if ann is not None:
+        data['relations'] = data['relations'].data[0]
+        data['gt_bboxes'] = data['gt_bboxes'].data[0]
+        data['texts'] = data['texts'].data[0]
+        data['img'] = data['img'][0]
+        data['img_metas'] = data['img_metas'][0]
+
     if next(model.parameters()).is_cuda:
         # scatter to specified GPU
         data = scatter(data, [device])[0]
@@ -125,9 +143,13 @@
     results = model(return_loss=False, rescale=True, **data)

     if not is_batch:
-        return results[0]
+        if not return_data:
+            return results[0]
+        return results[0], datas[0]
     else:
-        return results
+        if not return_data:
+            return results
+        return results, datas


 def text_model_inference(model, input_sentence):
mmocr/datasets/kie_dataset.py

@@ -1,4 +1,5 @@
 import copy
+import warnings
 from os import path as osp

 import numpy as np
@@ -28,15 +29,21 @@ class KIEDataset(BaseDataset):
     """

     def __init__(self,
-                 ann_file,
-                 loader,
-                 dict_file,
+                 ann_file=None,
+                 loader=None,
+                 dict_file=None,
                  img_prefix='',
                  pipeline=None,
                  norm=10.,
                  directed=False,
                  test_mode=True,
                  **kwargs):
+        if ann_file is None and loader is None:
+            warnings.warn(
+                'KIEDataset is only initialized as a downstream demo task '
+                'of text detection and recognition '
+                'without an annotation file.', UserWarning)
-        super().__init__(
-            ann_file,
-            loader,
+        else:
+            super().__init__(
+                ann_file,
+                loader,
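This relaxed constructor is what lets `ocr.py` build a bare KIEDataset for the demo, with only a dictionary file and no annotation file. A minimal sketch following the call made in `ocr.py` later in this diff (`kie_model` and `annotations` are assumed to exist):

```python
from mmocr.datasets.kie_dataset import KIEDataset

# Only dict_file is required; a UserWarning flags the demo-only mode
kie_dataset = KIEDataset(dict_file=kie_model.cfg.data.test.dict_file)
ann_info = kie_dataset._parse_anno_info(annotations)
```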
mmocr/utils/ocr.py

@@ -1,3 +1,4 @@
+import copy
 import os
 import warnings
 from argparse import ArgumentParser, Namespace
@@ -5,288 +6,19 @@ from pathlib import Path

 import mmcv
 import numpy as np
+import torch
+from mmcv.image.misc import tensor2imgs
+from mmcv.runner import load_checkpoint
+from mmcv.utils.config import Config
 from mmdet.apis import init_detector

 from mmocr.apis.inference import model_inference
 from mmocr.core.visualize import det_recog_show_result
+from mmocr.datasets.kie_dataset import KIEDataset
 from mmocr.datasets.pipelines.crop import crop_img
+from mmocr.models import build_detector
 from mmocr.utils.box_util import stitch_boxes_into_lines
+from mmocr.utils.fileio import list_from_file

-textdet_models = {
-    'DB_r18': {
-        'config': 'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
-        'ckpt':
-        'dbnet/dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
-    },
-    'DB_r50': {
-        'config':
-        'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
-        'ckpt':
-        'dbnet/dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
-    },
-    'DRRG': {
-        'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
-        'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
-    },
-    'FCE_IC15': {
-        'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
-        'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
-    },
-    'FCE_CTW_DCNv2': {
-        'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
-        'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
-    },
-    'MaskRCNN_CTW': {
-        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
-        'ckpt': 'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
-    },
-    'MaskRCNN_IC15': {
-        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
-        'ckpt':
-        'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
-    },
-    'MaskRCNN_IC17': {
-        'config': 'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
-        'ckpt':
-        'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
-    },
-    'PANet_CTW': {
-        'config': 'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
-        'ckpt':
-        'panet/panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
-    },
-    'PANet_IC15': {
-        'config': 'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
-        'ckpt':
-        'panet/panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
-    },
-    'PS_CTW': {
-        'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
-        'ckpt': 'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
-    },
-    'PS_IC15': {
-        'config': 'psenet/psenet_r50_fpnf_600e_icdar2015.py',
-        'ckpt': 'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
-    },
-    'TextSnake': {
-        'config': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
-        'ckpt': 'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
-    }
-}
-
-textrecog_models = {
-    'CRNN': {
-        'config': 'crnn/crnn_academic_dataset.py',
-        'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
-    },
-    'SAR': {
-        'config': 'sar/sar_r31_parallel_decoder_academic.py',
-        'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
-    },
-    'NRTR_1/16-1/8': {
-        'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
-        'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
-    },
-    'NRTR_1/8-1/4': {
-        'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
-        'ckpt': 'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
-    },
-    'RobustScanner': {
-        'config': 'robust_scanner/robustscanner_r31_academic.py',
-        'ckpt': 'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
-    },
-    'SEG': {
-        'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
-        'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
-    },
-    'CRNN_TPS': {
-        'config': 'tps/crnn_tps_academic_dataset.py',
-        'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
-    }
-}
-
-
-# Post processing function for end2end ocr
-def det_recog_pp(args, result):
-    final_results = []
-    for arr, output, export, det_recog_result in zip(args.arrays, args.output,
-                                                     args.export, result):
-        if output or args.imshow:
-            res_img = det_recog_show_result(
-                arr, det_recog_result, out_file=output)
-            if args.imshow:
-                mmcv.imshow(res_img, 'inference results')
-        if not args.details:
-            simple_res = {}
-            simple_res['filename'] = det_recog_result['filename']
-            simple_res['text'] = [
-                x['text'] for x in det_recog_result['result']
-            ]
-            final_result = simple_res
-        else:
-            final_result = det_recog_result
-        if export:
-            mmcv.dump(final_result, export, indent=4)
-        if args.print_result:
-            print(final_result, end='\n\n')
-        final_results.append(final_result)
-    return final_results
-
-
-# Post processing function for separate det/recog inference
-def single_pp(args, result, model):
-    for arr, output, export, res in zip(args.arrays, args.output, args.export,
-                                        result):
-        if export:
-            mmcv.dump(res, export, indent=4)
-        if output or args.imshow:
-            res_img = model.show_result(arr, res, out_file=output)
-            if args.imshow:
-                mmcv.imshow(res_img, 'inference results')
-        if args.print_result:
-            print(res, end='\n\n')
-    return result
-
-
-# End2end ocr inference pipeline
-def det_and_recog_inference(args, det_model, recog_model):
-    end2end_res = []
-    # Find bounding boxes in the images (text detection)
-    det_result = single_inference(det_model, args.arrays, args.batch_mode,
-                                  args.det_batch_size)
-    bboxes_list = [res['boundary_result'] for res in det_result]
-
-    # For each bounding box, the image is cropped and sent to the recognition
-    # model either one by one or all together depending on the batch_mode
-    for filename, arr, bboxes in zip(args.filenames, args.arrays, bboxes_list):
-        img_e2e_res = {}
-        img_e2e_res['filename'] = filename
-        img_e2e_res['result'] = []
-        box_imgs = []
-        for bbox in bboxes:
-            box_res = {}
-            box_res['box'] = [round(x) for x in bbox[:-1]]
-            box_res['box_score'] = float(bbox[-1])
-            box = bbox[:8]
-            if len(bbox) > 9:
-                min_x = min(bbox[0:-1:2])
-                min_y = min(bbox[1:-1:2])
-                max_x = max(bbox[0:-1:2])
-                max_y = max(bbox[1:-1:2])
-                box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
-            box_img = crop_img(arr, box)
-            if args.batch_mode:
-                box_imgs.append(box_img)
-            else:
-                recog_result = model_inference(recog_model, box_img)
-                text = recog_result['text']
-                text_score = recog_result['score']
-                if isinstance(text_score, list):
-                    text_score = sum(text_score) / max(1, len(text))
-                box_res['text'] = text
-                box_res['text_score'] = text_score
-            img_e2e_res['result'].append(box_res)
-
-        if args.batch_mode:
-            recog_results = single_inference(recog_model, box_imgs, True,
-                                             args.recog_batch_size)
-            for i, recog_result in enumerate(recog_results):
-                text = recog_result['text']
-                text_score = recog_result['score']
-                if isinstance(text_score, (list, tuple)):
-                    text_score = sum(text_score) / max(1, len(text))
-                img_e2e_res['result'][i]['text'] = text
-                img_e2e_res['result'][i]['text_score'] = text_score
-
-        if args.merge:
-            img_e2e_res['result'] = stitch_boxes_into_lines(
-                img_e2e_res['result'], args.merge_xdist, 0.5)
-
-        end2end_res.append(img_e2e_res)
-    return end2end_res
-
-
-# Separate det/recog inference pipeline
-def single_inference(model, arrays, batch_mode, batch_size):
-    result = []
-    if batch_mode:
-        if batch_size == 0:
-            result = model_inference(model, arrays, batch_mode=True)
-        else:
-            n = batch_size
-            arr_chunks = [arrays[i:i + n] for i in range(0, len(arrays), n)]
-            for chunk in arr_chunks:
-                result.extend(model_inference(model, chunk, batch_mode=True))
-    else:
-        for arr in arrays:
-            result.append(model_inference(model, arr, batch_mode=False))
-    return result
-
-
-# Arguments pre-processing function
-def args_processing(args):
-    # Check if the input is a list/tuple that
-    # contains only np arrays or strings
-    if isinstance(args.img, (list, tuple)):
-        img_list = args.img
-        if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
-            raise AssertionError('Images must be strings or numpy arrays')
-
-    # Create a list of the images
-    if isinstance(args.img, str):
-        img_path = Path(args.img)
-        if img_path.is_dir():
-            img_list = [str(x) for x in img_path.glob('*')]
-        else:
-            img_list = [str(img_path)]
-    elif isinstance(args.img, np.ndarray):
-        img_list = [args.img]
-
-    # Read all image(s) in advance to reduce wasted time
-    # re-reading the images for visualisation output
-    args.arrays = [mmcv.imread(x) for x in img_list]
-
-    # Create a list of filenames (used for output images and result files)
-    if isinstance(img_list[0], str):
-        args.filenames = [str(Path(x).stem) for x in img_list]
-    else:
-        args.filenames = [str(x) for x in range(len(img_list))]
-
-    # If given an output argument, create a list of output image filenames
-    num_res = len(img_list)
-    if args.output:
-        output_path = Path(args.output)
-        if output_path.is_dir():
-            args.output = [
-                str(output_path / f'out_{x}.png') for x in args.filenames
-            ]
-        else:
-            args.output = [str(args.output)]
-            if args.batch_mode:
-                raise AssertionError(
-                    'Output of multiple images inference must be a directory')
-    else:
-        args.output = [None] * num_res
-
-    # If given an export argument, create a list of
-    # result filenames for each image
-    if args.export:
-        export_path = Path(args.export)
-        args.export = [
-            str(export_path / f'out_{x}.{args.export_format}')
-            for x in args.filenames
-        ]
-    else:
-        args.export = [None] * num_res
-
-    return args
-
-
-# Create an inference pipeline with parsed arguments
-def main():
-    args = parse_args()
-    ocr = MMOCR(**vars(args))
-    ocr.readtext(**vars(args))

 # Parse CLI arguments
@@ -333,6 +65,23 @@ def parse_args():
         default='',
         help='Path to the custom checkpoint file of the selected recog model. '
         'It overrides the settings in recog')
+    parser.add_argument(
+        '--kie',
+        type=str,
+        default='',
+        help='Pretrained key information extraction algorithm')
+    parser.add_argument(
+        '--kie-config',
+        type=str,
+        default='',
+        help='Path to the custom config file of the selected kie model. '
+        'It overrides the settings in kie')
+    parser.add_argument(
+        '--kie-ckpt',
+        type=str,
+        default='',
+        help='Path to the custom checkpoint file of the selected kie model. '
+        'It overrides the settings in kie')
     parser.add_argument(
         '--config-dir',
         type=str,
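A sketch of how the new flags combine, mirroring Example 4 in the docs above (the checkpoint path is a hypothetical example):

```python
# CLI:  python mmocr/utils/ocr.py demo/demo_kie.jpeg \
#           --det PS_CTW --recog SAR --kie SDMGR --kie-ckpt /path/to/sdmgr.pth
# Equivalent Python call:
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR',
            kie_ckpt='/path/to/sdmgr.pth')
results = ocr.readtext('demo/demo_kie.jpeg')
```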
@@ -418,11 +167,138 @@ class MMOCR:
                  recog='SEG',
                  recog_config='',
                  recog_ckpt='',
+                 kie='',
+                 kie_config='',
+                 kie_ckpt='',
                  config_dir=os.path.join(str(Path.cwd()), 'configs/'),
                  device='cuda:0',
                  **kwargs):

+        textdet_models = {
+            'DB_r18': {
+                'config':
+                'dbnet/dbnet_r18_fpnc_1200e_icdar2015.py',
+                'ckpt':
+                'dbnet/'
+                'dbnet_r18_fpnc_sbn_1200e_icdar2015_20210329-ba3ab597.pth'
+            },
+            'DB_r50': {
+                'config':
+                'dbnet/dbnet_r50dcnv2_fpnc_1200e_icdar2015.py',
+                'ckpt':
+                'dbnet/'
+                'dbnet_r50dcnv2_fpnc_sbn_1200e_icdar2015_20210325-91cef9af.pth'
+            },
+            'DRRG': {
+                'config': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500.py',
+                'ckpt': 'drrg/drrg_r50_fpn_unet_1200e_ctw1500-1abf4f67.pth'
+            },
+            'FCE_IC15': {
+                'config': 'fcenet/fcenet_r50_fpn_1500e_icdar2015.py',
+                'ckpt': 'fcenet/fcenet_r50_fpn_1500e_icdar2015-d435c061.pth'
+            },
+            'FCE_CTW_DCNv2': {
+                'config': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500.py',
+                'ckpt': 'fcenet/fcenet_r50dcnv2_fpn_1500e_ctw1500-05d740bb.pth'
+            },
+            'MaskRCNN_CTW': {
+                'config':
+                'maskrcnn/mask_rcnn_r50_fpn_160e_ctw1500.py',
+                'ckpt':
+                'maskrcnn/'
+                'mask_rcnn_r50_fpn_160e_ctw1500_20210219-96497a76.pth'
+            },
+            'MaskRCNN_IC15': {
+                'config':
+                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2015.py',
+                'ckpt':
+                'maskrcnn/'
+                'mask_rcnn_r50_fpn_160e_icdar2015_20210219-8eb340a3.pth'
+            },
+            'MaskRCNN_IC17': {
+                'config':
+                'maskrcnn/mask_rcnn_r50_fpn_160e_icdar2017.py',
+                'ckpt':
+                'maskrcnn/'
+                'mask_rcnn_r50_fpn_160e_icdar2017_20210218-c6ec3ebb.pth'
+            },
+            'PANet_CTW': {
+                'config':
+                'panet/panet_r18_fpem_ffm_600e_ctw1500.py',
+                'ckpt':
+                'panet/'
+                'panet_r18_fpem_ffm_sbn_600e_ctw1500_20210219-3b3a9aa3.pth'
+            },
+            'PANet_IC15': {
+                'config':
+                'panet/panet_r18_fpem_ffm_600e_icdar2015.py',
+                'ckpt':
+                'panet/'
+                'panet_r18_fpem_ffm_sbn_600e_icdar2015_20210219-42dbe46a.pth'
+            },
+            'PS_CTW': {
+                'config': 'psenet/psenet_r50_fpnf_600e_ctw1500.py',
+                'ckpt':
+                'psenet/psenet_r50_fpnf_600e_ctw1500_20210401-216fed50.pth'
+            },
+            'PS_IC15': {
+                'config':
+                'psenet/psenet_r50_fpnf_600e_icdar2015.py',
+                'ckpt':
+                'psenet/psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth'
+            },
+            'TextSnake': {
+                'config':
+                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
+                'ckpt':
+                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
+            }
+        }
+
+        textrecog_models = {
+            'CRNN': {
+                'config': 'crnn/crnn_academic_dataset.py',
+                'ckpt': 'crnn/crnn_academic-a723a1c5.pth'
+            },
+            'SAR': {
+                'config': 'sar/sar_r31_parallel_decoder_academic.py',
+                'ckpt': 'sar/sar_r31_parallel_decoder_academic-dba3a4a3.pth'
+            },
+            'NRTR_1/16-1/8': {
+                'config': 'nrtr/nrtr_r31_1by16_1by8_academic.py',
+                'ckpt': 'nrtr/nrtr_r31_academic_20210406-954db95e.pth'
+            },
+            'NRTR_1/8-1/4': {
+                'config': 'nrtr/nrtr_r31_1by8_1by4_academic.py',
+                'ckpt':
+                'nrtr/nrtr_r31_1by8_1by4_academic_20210406-ce16e7cc.pth'
+            },
+            'RobustScanner': {
+                'config': 'robust_scanner/robustscanner_r31_academic.py',
+                'ckpt':
+                'robust_scanner/robustscanner_r31_academic-5f05874f.pth'
+            },
+            'SEG': {
+                'config': 'seg/seg_r31_1by16_fpnocr_academic.py',
+                'ckpt': 'seg/seg_r31_1by16_fpnocr_academic-72235b11.pth'
+            },
+            'CRNN_TPS': {
+                'config': 'tps/crnn_tps_academic_dataset.py',
+                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
+            }
+        }
+
+        kie_models = {
+            'SDMGR': {
+                'config': 'sdmgr/sdmgr_unet16_60e_wildreceipt.py',
+                'ckpt':
+                'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth'
+            }
+        }
+
         self.td = det
         self.tr = recog
+        self.kie = kie
         self.device = device

         # Check if the det/recog model choice is valid
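These lookup tables drive both validation (next hunk) and default config/checkpoint resolution (following hunks); a trivial sketch of the lookup:

```python
# The keys are the only names accepted by det/recog/kie
assert 'PS_CTW' in textdet_models
assert list(kie_models) == ['SDMGR']
# An unknown name such as kie='FOO' raises ValueError in the check below
```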
@@ -432,7 +308,12 @@ class MMOCR:
         elif self.tr and self.tr not in textrecog_models:
             raise ValueError(self.tr,
                              'is not a supported text recognition algorithm')
+        elif self.kie and self.kie not in kie_models:
+            raise ValueError(
+                self.kie, 'is not a supported key information extraction'
+                ' algorithm')

+        self.detect_model = None
         if self.td:
             # Build detection model
             if not det_config:
@@ -444,9 +325,8 @@ class MMOCR:

             self.detect_model = init_detector(
                 det_config, det_ckpt, device=self.device)
-        else:
-            self.detect_model = None

+        self.recog_model = None
         if self.tr:
             # Build recognition model
             if not recog_config:
@@ -454,13 +334,27 @@ class MMOCR:
                     config_dir, 'textrecog/',
                     textrecog_models[self.tr]['config'])
             if not recog_ckpt:
-                recog_ckpt = 'https://download.openmmlab.com/mmocr/'
+                recog_ckpt = 'https://download.openmmlab.com/mmocr/' + \
+                    'textrecog/' + textrecog_models[self.tr]['ckpt']

             self.recog_model = init_detector(
                 recog_config, recog_ckpt, device=self.device)
-        else:
-            self.recog_model = None

+        self.kie_model = None
+        if self.kie:
+            # Build key information extraction model
+            if not kie_config:
+                kie_config = os.path.join(config_dir, 'kie/',
+                                          kie_models[self.kie]['config'])
+            if not kie_ckpt:
+                kie_ckpt = 'https://download.openmmlab.com/mmocr/' + \
+                    'kie/' + kie_models[self.kie]['ckpt']
+
+            kie_cfg = Config.fromfile(kie_config)
+            self.kie_model = build_detector(
+                kie_cfg.model, test_cfg=kie_cfg.get('test_cfg'))
+            self.kie_model.cfg = kie_cfg
+            load_checkpoint(self.kie_model, kie_ckpt, map_location=self.device)

         # Attribute check
         for model in list(filter(None, [self.recog_model, self.detect_model])):
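Putting the resolution rules from these hunks together: the config path is assembled locally under `config_dir`, while the default checkpoint is fetched from the model zoo. A sketch for the KIE branch, with the values from the `kie_models` dict above substituted in:

```python
import os

config_dir = 'configs/'
kie_config = os.path.join(config_dir, 'kie/',
                          'sdmgr/sdmgr_unet16_60e_wildreceipt.py')
kie_ckpt = ('https://download.openmmlab.com/mmocr/' + 'kie/' +
            'sdmgr/sdmgr_unet16_60e_wildreceipt_20210520-7489e6de.pth')
```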
@@ -490,25 +384,292 @@ class MMOCR:
         args = Namespace(**args)

         # Input and output arguments processing
-        args = args_processing(args)
+        self._args_processing(args)
+        self.args = args

         pp_result = None

         # Send args and models to the MMOCR model inference API
         # and call post-processing functions for the output
         if self.detect_model and self.recog_model:
-            det_recog_result = det_and_recog_inference(args, self.detect_model,
-                                                       self.recog_model)
-            pp_result = det_recog_pp(args, det_recog_result)
+            det_recog_result = self.det_recog_kie_inference(
+                self.detect_model, self.recog_model, kie_model=self.kie_model)
+            pp_result = self.det_recog_pp(det_recog_result)
         else:
             for model in list(
                     filter(None, [self.recog_model, self.detect_model])):
-                result = single_inference(model, args.arrays, args.batch_mode,
-                                          args.single_batch_size)
-                pp_result = single_pp(args, result, model)
+                result = self.single_inference(model, args.arrays,
+                                               args.batch_mode,
+                                               args.single_batch_size)
+                pp_result = self.single_pp(result, model)

         return pp_result

+    # Post processing function for end2end ocr
+    def det_recog_pp(self, result):
+        final_results = []
+        args = self.args
+        for arr, output, export, det_recog_result in zip(
+                args.arrays, args.output, args.export, result):
+            if output or args.imshow:
+                if self.kie_model:
+                    res_img = det_recog_show_result(arr, det_recog_result)
+                else:
+                    res_img = det_recog_show_result(
+                        arr, det_recog_result, out_file=output)
+                if args.imshow and not self.kie_model:
+                    mmcv.imshow(res_img, 'inference results')
+            if not args.details:
+                simple_res = {}
+                simple_res['filename'] = det_recog_result['filename']
+                simple_res['text'] = [
+                    x['text'] for x in det_recog_result['result']
+                ]
+                final_result = simple_res
+            else:
+                final_result = det_recog_result
+            if export:
+                mmcv.dump(final_result, export, indent=4)
+            if args.print_result:
+                print(final_result, end='\n\n')
+            final_results.append(final_result)
+        return final_results
+
+    # Post processing function for separate det/recog inference
+    def single_pp(self, result, model):
+        for arr, output, export, res in zip(self.args.arrays, self.args.output,
+                                            self.args.export, result):
+            if export:
+                mmcv.dump(res, export, indent=4)
+            if output or self.args.imshow:
+                res_img = model.show_result(arr, res, out_file=output)
+                if self.args.imshow:
+                    mmcv.imshow(res_img, 'inference results')
+            if self.args.print_result:
+                print(res, end='\n\n')
+        return result
+
+    def generate_kie_labels(self, result, boxes, class_list):
+        idx_to_cls = {}
+        if class_list is not None:
+            for line in list_from_file(class_list):
+                class_idx, class_label = line.strip().split()
+                idx_to_cls[class_idx] = class_label
+
+        max_value, max_idx = torch.max(result['nodes'].detach().cpu(), -1)
+        node_pred_label = max_idx.numpy().tolist()
+        node_pred_score = max_value.numpy().tolist()
+        labels = []
+        for i in range(len(boxes)):
+            pred_label = str(node_pred_label[i])
+            if pred_label in idx_to_cls:
+                pred_label = idx_to_cls[pred_label]
+            pred_score = node_pred_score[i]
+            labels.append((pred_label, pred_score))
+        return labels
+
+    def visualize_kie_output(self,
+                             model,
+                             data,
+                             result,
+                             out_file=None,
+                             show=False):
+        """Visualizes KIE output."""
+        img_tensor = data['img'].data
+        img_meta = data['img_metas'].data
+        gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
+        img = tensor2imgs(img_tensor.unsqueeze(0),
+                          **img_meta['img_norm_cfg'])[0]
+        h, w, _ = img_meta['img_shape']
+        img_show = img[:h, :w, :]
+        model.show_result(
+            img_show, result, gt_bboxes, show=show, out_file=out_file)
+
+    # End2end ocr inference pipeline
+    def det_recog_kie_inference(self, det_model, recog_model, kie_model=None):
+        end2end_res = []
+        # Find bounding boxes in the images (text detection)
+        det_result = self.single_inference(det_model, self.args.arrays,
+                                           self.args.batch_mode,
+                                           self.args.det_batch_size)
+        bboxes_list = [res['boundary_result'] for res in det_result]
+
+        if kie_model:
+            kie_dataset = KIEDataset(
+                dict_file=kie_model.cfg.data.test.dict_file)
+
+        # For each bounding box, the image is cropped and
+        # sent to the recognition model either one by one
+        # or all together depending on the batch_mode
+        for filename, arr, bboxes, out_file in zip(self.args.filenames,
+                                                   self.args.arrays,
+                                                   bboxes_list,
+                                                   self.args.output):
+            img_e2e_res = {}
+            img_e2e_res['filename'] = filename
+            img_e2e_res['result'] = []
+            box_imgs = []
+            for bbox in bboxes:
+                box_res = {}
+                box_res['box'] = [round(x) for x in bbox[:-1]]
+                box_res['box_score'] = float(bbox[-1])
+                box = bbox[:8]
+                if len(bbox) > 9:
+                    min_x = min(bbox[0:-1:2])
+                    min_y = min(bbox[1:-1:2])
+                    max_x = max(bbox[0:-1:2])
+                    max_y = max(bbox[1:-1:2])
+                    box = [
+                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
+                    ]
+                box_img = crop_img(arr, box)
+                if self.args.batch_mode:
+                    box_imgs.append(box_img)
+                else:
+                    recog_result = model_inference(recog_model, box_img)
+                    text = recog_result['text']
+                    text_score = recog_result['score']
+                    if isinstance(text_score, list):
+                        text_score = sum(text_score) / max(1, len(text))
+                    box_res['text'] = text
+                    box_res['text_score'] = text_score
+                img_e2e_res['result'].append(box_res)
+
+            if self.args.batch_mode:
+                recog_results = self.single_inference(
+                    recog_model, box_imgs, True, self.args.recog_batch_size)
+                for i, recog_result in enumerate(recog_results):
+                    text = recog_result['text']
+                    text_score = recog_result['score']
+                    if isinstance(text_score, (list, tuple)):
+                        text_score = sum(text_score) / max(1, len(text))
+                    img_e2e_res['result'][i]['text'] = text
+                    img_e2e_res['result'][i]['text_score'] = text_score
+
+            if self.args.merge:
+                img_e2e_res['result'] = stitch_boxes_into_lines(
+                    img_e2e_res['result'], self.args.merge_xdist, 0.5)
+
+            if kie_model:
+                annotations = copy.deepcopy(img_e2e_res['result'])
+                # Customized for kie_dataset, which
+                # assumes that boxes are represented by only 4 points
+                for i, ann in enumerate(annotations):
+                    min_x = min(ann['box'][::2])
+                    min_y = min(ann['box'][1::2])
+                    max_x = max(ann['box'][::2])
+                    max_y = max(ann['box'][1::2])
+                    annotations[i]['box'] = [
+                        min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y
+                    ]
+                ann_info = kie_dataset._parse_anno_info(annotations)
+                kie_result, data = model_inference(
+                    kie_model,
+                    arr,
+                    ann=ann_info,
+                    return_data=True,
+                    batch_mode=self.args.batch_mode)
+                # visualize KIE results
+                self.visualize_kie_output(
+                    kie_model,
+                    data,
+                    kie_result,
+                    out_file=out_file,
+                    show=self.args.imshow)
+                gt_bboxes = data['gt_bboxes'].data.numpy().tolist()
+                labels = self.generate_kie_labels(kie_result, gt_bboxes,
+                                                  kie_model.class_list)
+                for i in range(len(gt_bboxes)):
+                    img_e2e_res['result'][i]['label'] = labels[i][0]
+                    img_e2e_res['result'][i]['label_score'] = labels[i][1]
+
+            end2end_res.append(img_e2e_res)
+        return end2end_res
+
+    # Separate det/recog inference pipeline
+    def single_inference(self, model, arrays, batch_mode, batch_size):
+        result = []
+        if batch_mode:
+            if batch_size == 0:
+                result = model_inference(model, arrays, batch_mode=True)
+            else:
+                n = batch_size
+                arr_chunks = [
+                    arrays[i:i + n] for i in range(0, len(arrays), n)
+                ]
+                for chunk in arr_chunks:
+                    result.extend(
+                        model_inference(model, chunk, batch_mode=True))
+        else:
+            for arr in arrays:
+                result.append(model_inference(model, arr, batch_mode=False))
+        return result
+
+    # Arguments pre-processing function
+    def _args_processing(self, args):
+        # Check if the input is a list/tuple that
+        # contains only np arrays or strings
+        if isinstance(args.img, (list, tuple)):
+            img_list = args.img
+            if not all([isinstance(x, (np.ndarray, str)) for x in args.img]):
+                raise AssertionError('Images must be strings or numpy arrays')
+
+        # Create a list of the images
+        if isinstance(args.img, str):
+            img_path = Path(args.img)
+            if img_path.is_dir():
+                img_list = [str(x) for x in img_path.glob('*')]
+            else:
+                img_list = [str(img_path)]
+        elif isinstance(args.img, np.ndarray):
+            img_list = [args.img]
+
+        # Read all image(s) in advance to reduce wasted time
+        # re-reading the images for visualisation output
+        args.arrays = [mmcv.imread(x) for x in img_list]
+
+        # Create a list of filenames (used for output images and result files)
+        if isinstance(img_list[0], str):
+            args.filenames = [str(Path(x).stem) for x in img_list]
+        else:
+            args.filenames = [str(x) for x in range(len(img_list))]
+
+        # If given an output argument, create a list of output image filenames
+        num_res = len(img_list)
+        if args.output:
+            output_path = Path(args.output)
+            if output_path.is_dir():
+                args.output = [
+                    str(output_path / f'out_{x}.png') for x in args.filenames
+                ]
+            else:
+                args.output = [str(args.output)]
+                if args.batch_mode:
+                    raise AssertionError('Output of multiple images inference'
+                                         ' must be a directory')
+        else:
+            args.output = [None] * num_res
+
+        # If given an export argument, create a list of
+        # result filenames for each image
+        if args.export:
+            export_path = Path(args.export)
+            args.export = [
+                str(export_path / f'out_{x}.{args.export_format}')
+                for x in args.filenames
+            ]
+        else:
+            args.export = [None] * num_res
+
+        return args
+
+
+# Create an inference pipeline with parsed arguments
+def main():
+    args = parse_args()
+    ocr = MMOCR(**vars(args))
+    ocr.readtext(**vars(args))


 if __name__ == '__main__':
     main()
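One detail worth noting from `generate_kie_labels` above: `class_list` is a plain text file mapping class indices to label names, parsed with `line.strip().split()`. A minimal sketch of the expected format (the label names are illustrative):

```python
# Reimplementation sketch of the class_list parsing in generate_kie_labels
lines = ['0 Ignore', '1 Store_name_value']  # illustrative file contents
idx_to_cls = {}
for line in lines:
    class_idx, class_label = line.strip().split()
    idx_to_cls[class_idx] = class_label
print(idx_to_cls['1'])  # -> 'Store_name_value'
```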