mirror of https://github.com/open-mmlab/mmocr.git
[Feature] Add Tesserocr Inference (#814)
* append tesserocr to requirements list, but may encounter build error at windows platform
* simply save
* 2022.3.4
* opencv-python==4.5.5 can cause cv2.error when print_result=True
* append MMOCR.tesseract_det_inference()
* argument check append
* fix lint error
* update commentary
* lint fix
* requirement remove opencv
* handle tessdata problem
* support tesseract recognition
* fix some bugs
* fix imshow bug
* support batch mode(fake)
* modify annotation
* refactor BaseRecognizer for show_result
* append pytest
* Mock tesseract
* Fix test
* remove \n from Tesseract
* normalize text score
* update docs
parent 7c9093684e
commit c79a62487d
@@ -3,6 +3,7 @@

We provide an easy-to-use API for demo and application purposes in the [ocr.py](https://github.com/open-mmlab/mmocr/blob/main/mmocr/utils/ocr.py) script.

The API can be called through the command line or from another Python script.
It exposes all the models in MMOCR as individual modules that can be called and chained together. [Tesseract](https://tesseract-ocr.github.io/) is integrated as a text detector and/or recognizer in the task pipeline.

---
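For instance, a minimal usage sketch (not part of this diff) that runs the new Tesseract pipeline end to end, assuming `tesserocr` and a system Tesseract installation are available:

```python
from mmocr.utils.ocr import MMOCR

# Use Tesseract as both the text detector and the recognizer.
ocr = MMOCR(det='Tesseract', recog='Tesseract')

# 'demo/demo_kie.jpeg' is the sample image used in the repository's tests.
results = ocr.readtext('demo/demo_kie.jpeg', print_result=True)
```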
@@ -138,7 +139,7 @@ The API has an extensive list of arguments that you can use. The following table

**MMOCR():**

| Arguments | Type | Default | Description |
| -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
| -------------- | --------------------- | ---------- | ---------------------------------------------------------------------------------------------------- |
| `det` | see [models](#models) | PANet_IC15 | Text detection algorithm |
| `recog` | see [models](#models) | SAR | Text recognition algorithm |
| `kie` [1] | see [models](#models) | None | Key information extraction algorithm |
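For example, a short sketch (not in the diff) that overrides the defaults above with other models from the tables below:

```python
from mmocr.utils.ocr import MMOCR

# Pick any detector/recognizer listed in the model tables; device='cpu' is the
# setting used in the repository's tests.
ocr = MMOCR(det='DB_r18', recog='CRNN_TPS', device='cpu')
results = ocr.readtext('demo/demo_kie.jpeg')
```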
@@ -195,7 +196,7 @@ means that `batch_mode` and `print_result` are set to `True`)

**Text detection:**

| Name | Reference | `batch_mode` inference support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------: |
| DB_r18 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DB_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DRRG | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: |
@@ -208,23 +209,25 @@ means that `batch_mode` and `print_result` are set to `True`)

| PANet_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#efficient-and-accurate-arbitrary-shaped-text-detection-with-pixel-aggregation-network) | :heavy_check_mark: |
| PS_CTW | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| PS_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| Tesseract | [link](https://tesseract-ocr.github.io/) | :x: |
| TextSnake | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#textsnake) | :heavy_check_mark: |

**Text recognition:**

| Name | Reference | `batch_mode` inference support |
| ------------- | :--------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| ------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------: |
| ABINet | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#read-like-humans-autonomous-bidirectional-and-iterative-language-modeling-for-scene-text-recognition) | :heavy_check_mark: |
| CRNN | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#an-end-to-end-trainable-neural-network-for-image-based-sequence-recognition-and-its-application-to-scene-text-recognition) | :x: |
| SAR | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| SAR_CN * | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
| NRTR_1/16-1/8 | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| NRTR_1/8-1/4 | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| RobustScanner | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#robustscanner-dynamically-enhancing-positional-clues-for-robust-text-recognition) | :heavy_check_mark: |
| SAR | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| SAR_CN * | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| SATRN | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| SATRN_sm | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
| Tesseract | [link](https://tesseract-ocr.github.io/) | :x: |

:::{warning}
@@ -236,9 +239,8 @@ a Chinese dictionary. Please download the dictionary from [here](https://mmocr.r

**Key information extraction:**

| Name | Reference | `batch_mode` support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| ----- | :---------------------------------------------------------------------------------------------------------------------------------: | :------------------: |
| SDMGR | [link](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |
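
As an illustration (not from the diff), a hedged sketch of chaining detection, recognition and SDMGR; the `det`/`recog` choices are just examples taken from the tables above:

```python
from mmocr.utils.ocr import MMOCR

# KIE runs on top of a text detector and a text recognizer.
kie_ocr = MMOCR(det='PS_CTW', recog='SAR', kie='SDMGR')
results = kie_ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)
```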

---

## Additional info
@@ -247,5 +249,6 @@ a Chinese dictionary. Please download the dictionary from [here](https://mmocr.r

- To perform only recognition, set the `det` argument to `None`.
- The `details` argument only works with end-to-end OCR.
- The `det_batch_size` and `recog_batch_size` arguments define the number of images forwarded to the model at the same time. For maximum speed, set them to the highest value you can; the maximum batch size is limited by the model complexity and the GPU VRAM size. See the sketch after this list.
- MMOCR calls Tesseract's API via [`tesserocr`](https://github.com/sirfz/tesserocr).
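
A minimal sketch (not part of the diff) of a recognition-only, batched call using these arguments; the image list below is only an assumed placeholder:

```python
from mmocr.utils.ocr import MMOCR

# Recognition only: detection is disabled by passing det=None.
recognizer = MMOCR(det=None, recog='SAR')

# batch_mode forwards recog_batch_size images at a time; raise it until you
# hit the GPU memory limit.
results = recognizer.readtext(
    ['demo/demo_text_recog.jpg'] * 8,  # assumed demo image path, repeated
    batch_mode=True,
    recog_batch_size=4,
    print_result=True)
```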

If you have any suggestions for new features, feel free to open a thread or even a PR :)
@@ -2,7 +2,7 @@

For demos and applications, MMOCR provides an easy-to-use API in the form of the [ocr.py](https://github.com/open-mmlab/mmocr/blob/main/mmocr/utils/ocr.py) script.

The API can be run from the command line or called from within a Python script.
The API can be run from the command line or called from within a Python script. Within this API, all the models in MMOCR can be invoked as individual modules or chained together. It also supports calling [Tesseract](https://tesseract-ocr.github.io/) as a text detection or recognition component.

---
@@ -137,7 +137,7 @@ results = ocr.readtext('demo/demo_kie.jpeg', print_result=True, imshow=True)

**MMOCR():**

| Arguments | Type | Default | Description |
| -------------- | --------------------- | ------------- | ----------------------------------------------------------- |
| -------------- | ------------------ | ---------- | ---------------------------------------------------------------------------------------- |
| `det` | see the **Models** section | PANet_IC15 | Text detection algorithm |
| `recog` | see the **Models** section | SAR | Text recognition algorithm |
| `kie` [1] | see the **Models** section | None | Key information extraction algorithm |
@@ -161,7 +161,7 @@ For ease of use, mmocr provides preset model configurations and the corresponding pretrained weights

### readtext()

| Arguments | Type | Default | Description |
| ------------------- | ----------------------- | ------------ | ---------------------------------------------------------------------- |
| ------------------- | ----------------------- | -------- | --------------------------------------------------------------------- |
| `img` | str/list/tuple/np.array | **required** | Image, folder path, np array, or list/tuple (of image paths or np arrays) |
| `output` | str | None | Path (image or folder) for the visualized output |
| `batch_mode` | bool | False | Whether to use batch-mode inference [1] |
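A hedged sketch (not part of the diff) exercising these `readtext()` arguments; the `outputs/` folder is just an assumed destination:

```python
from mmocr.utils.ocr import MMOCR

ocr = MMOCR()  # defaults: PANet_IC15 detection + SAR recognition

# `img` accepts a path, a folder, an np.ndarray, or a list/tuple of them;
# `output` stores the visualized results and `batch_mode` enables batching.
results = ocr.readtext(
    ['demo/demo_kie.jpeg', 'demo/demo_kie.jpeg'],
    output='outputs/',
    batch_mode=True)
```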
@@ -192,37 +192,39 @@ For ease of use, mmocr provides preset model configurations and the corresponding pretrained weights

**Text detection:**

| Name | `batch_mode` inference support |
| ------------- | :------------------: |
| [DB_r18](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| [DB_r50](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| [DRRG](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: |
| [FCE_IC15](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| [FCE_CTW_DCNv2](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| [MaskRCNN_CTW](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| [MaskRCNN_IC15](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| [MaskRCNN_IC17](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| [PANet_CTW](https://mmocr.readthedocs.io/en/latest/textdet_models.html#efficient-and-accurate-arbitrary-shaped-text-detection-with-pixel-aggregation-network) | :heavy_check_mark: |
| [PANet_IC15](https://mmocr.readthedocs.io/en/latest/textdet_models.html#efficient-and-accurate-arbitrary-shaped-text-detection-with-pixel-aggregation-network) | :heavy_check_mark: |
| [PS_CTW](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| [PS_IC15](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| [TextSnake](https://mmocr.readthedocs.io/en/latest/textdet_models.html#textsnake) | :heavy_check_mark: |
| Name | Reference | `batch_mode` inference support |
| ------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------: |
| DB_r18 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DB_r50 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#real-time-scene-text-detection-with-differentiable-binarization) | :x: |
| DRRG | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#drrg) | :x: |
| FCE_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| FCE_CTW_DCNv2 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#fourier-contour-embedding-for-arbitrary-shaped-text-detection) | :x: |
| MaskRCNN_CTW | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| MaskRCNN_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| MaskRCNN_IC17 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#mask-r-cnn) | :x: |
| PANet_CTW | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#efficient-and-accurate-arbitrary-shaped-text-detection-with-pixel-aggregation-network) | :heavy_check_mark: |
| PANet_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#efficient-and-accurate-arbitrary-shaped-text-detection-with-pixel-aggregation-network) | :heavy_check_mark: |
| PS_CTW | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| PS_IC15 | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#psenet) | :x: |
| Tesseract | [link](https://tesseract-ocr.github.io/) | :heavy_check_mark: |
| TextSnake | [link](https://mmocr.readthedocs.io/en/latest/textdet_models.html#textsnake) | :heavy_check_mark: |
**Text recognition:**

| Name | `batch_mode` inference support |
| ------------- |:------------------: |
| [ABINet](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#read-like-humans-autonomous-bidirectional-and-iterative-language-modeling-for-scene-text-recognition) | :heavy_check_mark: |
| [CRNN](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#an-end-to-end-trainable-neural-network-for-image-based-sequence-recognition-and-its-application-to-scene-text-recognition) | :x: |
| [SAR](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| [SAR_CN](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| [NRTR_1/16-1/8](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| [NRTR_1/8-1/4](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| [RobustScanner](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#robustscanner-dynamically-enhancing-positional-clues-for-robust-text-recognition) | :heavy_check_mark: |
| [SATRN](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| [SATRN_sm](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| [SEG](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| [CRNN_TPS](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
| Name | Reference | `batch_mode` inference support |
| ------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------: |
| ABINet | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#read-like-humans-autonomous-bidirectional-and-iterative-language-modeling-for-scene-text-recognition) | :heavy_check_mark: |
| CRNN | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#an-end-to-end-trainable-neural-network-for-image-based-sequence-recognition-and-its-application-to-scene-text-recognition) | :x: |
| CRNN_TPS | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#crnn-with-tps-based-stn) | :heavy_check_mark: |
| NRTR_1/16-1/8 | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| NRTR_1/8-1/4 | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#nrtr) | :heavy_check_mark: |
| RobustScanner | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#robustscanner-dynamically-enhancing-positional-clues-for-robust-text-recognition) | :heavy_check_mark: |
| SAR | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| SAR_CN * | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#show-attend-and-read-a-simple-and-strong-baseline-for-irregular-text-recognition) | :heavy_check_mark: |
| SATRN | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| SATRN_sm | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#satrn) | :heavy_check_mark: |
| SEG | [link](https://mmocr.readthedocs.io/en/latest/textrecog_models.html#segocr-simple-baseline) | :x: |
| Tesseract | [link](https://tesseract-ocr.github.io/) | :heavy_check_mark: |
:::{note}
@@ -233,9 +235,8 @@ SAR_CN is the only model that supports Chinese character recognition, and it requires a Chinese dictionary

**Key information extraction:**

| Name | `batch_mode` support |
| ------------- | :------------------: |
| ------------------------------------------------------------------------------------------------------------------------------------ | :----------------: |
| [SDMGR](https://mmocr.readthedocs.io/en/latest/kie_models.html#spatial-dual-modality-graph-reasoning-for-key-information-extraction) | :heavy_check_mark: |

---

## Additional notes
@@ -244,5 +245,6 @@ SAR_CN is the only model that supports Chinese character recognition, and it requires a Chinese dictionary

- To perform recognition only, set the `det` argument to `None`.
- The `details` argument only works with end-to-end OCR models.
- `det_batch_size` and `recog_batch_size` specify the number of images passed to the model at the same time. To speed up inference, set them as high as you can; the maximum batch size is limited by the model complexity and the GPU VRAM size.
- MMOCR currently calls Tesseract's API via [`tesserocr`](https://github.com/sirfz/tesserocr).

If you have any suggestions for new features, feel free to open an issue or even submit a PR :)
@@ -178,8 +178,8 @@ class BaseRecognizer(BaseModule, metaclass=ABCMeta):

        return outputs

    def show_result(self,
                    img,
    @staticmethod
    def show_result(img,
                    result,
                    gt_label='',
                    win_name='',
@@ -12,6 +12,12 @@ import torch
from mmcv.image.misc import tensor2imgs
from mmcv.runner import load_checkpoint
from mmcv.utils.config import Config
from PIL import Image

try:
    import tesserocr
except ImportError:
    tesserocr = None

from mmocr.apis import init_detector
from mmocr.apis.inference import model_inference
@@ -19,6 +25,9 @@ from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.kie_dataset import KIEDataset
from mmocr.datasets.pipelines.crop import crop_img
from mmocr.models import build_detector
from mmocr.models.textdet.detectors import TextDetectorMixin
from mmocr.models.textrecog.recognizer import BaseRecognizer
from mmocr.utils import is_type_list
from mmocr.utils.box_util import stitch_boxes_into_lines
from mmocr.utils.fileio import list_from_file
from mmocr.utils.model import revert_sync_batchnorm
@@ -262,7 +271,8 @@ class MMOCR:
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500.py',
                'ckpt':
                'textsnake/textsnake_r50_fpn_unet_1200e_ctw1500-27f65b64.pth'
            }
            },
            'Tesseract': {}
        }

        textrecog_models = {
@@ -313,7 +323,8 @@ class MMOCR:
            'CRNN_TPS': {
                'config': 'tps/crnn_tps_academic_dataset.py',
                'ckpt': 'tps/crnn_tps_academic_dataset_20210510-d221a905.pth'
            }
            },
            'Tesseract': {}
        }

        kie_models = {
@@ -350,7 +361,13 @@ class MMOCR:
                ' with text detection and recognition algorithms.')

        self.detect_model = None
        if self.td:
        if self.td and self.td == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.detect_model = 'Tesseract_det'
        elif self.td:
            # Build detection model
            if not det_config:
                det_config = os.path.join(config_dir, 'textdet/',
@@ -364,7 +381,13 @@ class MMOCR:
            self.detect_model = revert_sync_batchnorm(self.detect_model)

        self.recog_model = None
        if self.tr:
        if self.tr and self.tr == 'Tesseract':
            if tesserocr is None:
                raise ImportError('Please install tesserocr first. '
                                  'Check out the installation guide at '
                                  'https://github.com/sirfz/tesserocr')
            self.recog_model = 'Tesseract_recog'
        elif self.tr:
            # Build recognition model
            if not recog_config:
                recog_config = os.path.join(
@@ -400,6 +423,107 @@ class MMOCR:
        if hasattr(model, 'module'):
            model = model.module

    @staticmethod
    def get_tesserocr_api():
        """Get tesserocr api depending on different platform."""
        import subprocess
        import sys

        if sys.platform == 'linux':
            api = tesserocr.PyTessBaseAPI()
        elif sys.platform == 'win32':
            try:
                p = subprocess.Popen(
                    'where tesseract', stdout=subprocess.PIPE, shell=True)
                s = p.communicate()[0].decode('utf-8').split('\\')
                path = s[:-1] + ['tessdata']
                tessdata_path = '/'.join(path)
                api = tesserocr.PyTessBaseAPI(path=tessdata_path)
            except RuntimeError:
                raise RuntimeError(
                    'Please install tesseract first.\n Check out the'
                    ' installation guide at'
                    ' https://github.com/UB-Mannheim/tesseract/wiki')
        else:
            raise NotImplementedError
        return api

    def tesseract_det_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract detector.

        Args:
            imgs (ndarray or list[ndarray]): image(s) to inference.

        Returns:
            result (dict): Predicted results.
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        # Get detection result using tesseract
        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            boxes = api.GetComponentImages(tesserocr.RIL.TEXTLINE, True)
            boundaries = []
            for _, box, _, _ in boxes:
                min_x = box['x']
                min_y = box['y']
                max_x = box['x'] + box['w']
                max_y = box['y'] + box['h']
                boundary = [
                    min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y, 1.0
                ]
                boundaries.append(boundary)
            results.append({'boundary_result': boundaries})

        # close tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results

    def tesseract_recog_inference(self, imgs, **kwargs):
        """Inference image(s) with the tesseract recognizer.

        Args:
            imgs (ndarray or list[ndarray]): image(s) to inference.

        Returns:
            result (dict): Predicted results.
        """
        is_batch = True
        if isinstance(imgs, np.ndarray):
            is_batch = False
            imgs = [imgs]
        assert is_type_list(imgs, np.ndarray)
        api = self.get_tesserocr_api()

        results = []
        for img in imgs:
            image = Image.fromarray(img)
            api.SetImage(image)
            api.SetRectangle(0, 0, img.shape[1], img.shape[0])
            # Remove beginning and trailing spaces from Tesseract
            text = api.GetUTF8Text().strip()
            conf = api.MeanTextConf() / 100
            results.append({'text': text, 'score': conf})

        # close tesserocr api
        api.End()

        if not is_batch:
            return results[0]
        else:
            return results

    def readtext(self,
                 img,
                 output=None,
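For reference, a hedged sketch (not part of the diff) of calling the new helpers above directly, assuming `tesserocr`, a system Tesseract installation, and `mmcv` are available:

```python
import mmcv
from mmocr.utils.ocr import MMOCR

ocr = MMOCR(det='Tesseract', recog='Tesseract')
img = mmcv.imread('demo/demo_kie.jpeg')

# A dict with 'boundary_result': one 8-point polygon plus a constant 1.0 score
# per detected text line.
det_res = ocr.tesseract_det_inference(img)

# A dict with the recognized 'text' and a 'score' normalized to [0, 1].
recog_res = ocr.tesseract_recog_inference(img)
print(det_res['boundary_result'][:1], recog_res['text'], recog_res['score'])
```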
@@ -478,6 +602,13 @@ class MMOCR:
            if export:
                mmcv.dump(res, export, indent=4)
            if output or self.args.imshow:
                if model == 'Tesseract_det':
                    res_img = TextDetectorMixin(show_score=False).show_result(
                        arr, res, out_file=output)
                elif model == 'Tesseract_recog':
                    res_img = BaseRecognizer.show_result(
                        arr, res, out_file=output)
                else:
                    res_img = model.show_result(arr, res, out_file=output)
                if self.args.imshow:
                    mmcv.imshow(res_img, 'inference results')
@@ -566,6 +697,10 @@ class MMOCR:
                box_img = crop_img(arr, box)
                if self.args.batch_mode:
                    box_imgs.append(box_img)
                else:
                    if recog_model == 'Tesseract_recog':
                        recog_result = self.single_inference(
                            recog_model, box_img, batch_mode=True)
                    else:
                        recog_result = model_inference(recog_model, box_img)
                    text = recog_result['text']
@@ -633,21 +768,29 @@ class MMOCR:

    # Separate det/recog inference pipeline
    def single_inference(self, model, arrays, batch_mode, batch_size=0):

        def inference(m, a, **kwargs):
            if model == 'Tesseract_det':
                return self.tesseract_det_inference(a)
            elif model == 'Tesseract_recog':
                return self.tesseract_recog_inference(a)
            else:
                return model_inference(m, a, **kwargs)

        result = []
        if batch_mode:
            if batch_size == 0:
                result = model_inference(model, arrays, batch_mode=True)
                result = inference(model, arrays, batch_mode=True)
            else:
                n = batch_size
                arr_chunks = [
                    arrays[i:i + n] for i in range(0, len(arrays), n)
                ]
                for chunk in arr_chunks:
                    result.extend(
                        model_inference(model, chunk, batch_mode=True))
                    result.extend(inference(model, chunk, batch_mode=True))
        else:
            for arr in arrays:
                result.append(model_inference(model, arr, batch_mode=False))
                result.append(inference(model, arr, batch_mode=False))
        return result

    # Arguments pre-processing function
@@ -369,3 +369,52 @@ def test_readtext(mock_kiedataset):
    with mock.patch('mmocr.utils.ocr.stitch_boxes_into_lines') as mock_merge:
        mmocr_det_recog.readtext(toy_imgs, merge=True)
        assert mock_merge.call_count == len(toy_imgs)


@pytest.mark.parametrize('det, recog, target',
                         [('Tesseract', None, {
                             'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 1.0]]
                         }),
                          ('Tesseract', 'Tesseract', {
                              'result': [{
                                  'box': [0, 0, 1, 0, 1, 1, 0, 1],
                                  'box_score': 1.0,
                                  'text': 'text',
                                  'text_score': 0.5
                              }]
                          }),
                          (None, 'Tesseract', {
                              'text': 'text',
                              'score': 0.5
                          })])
@mock.patch('mmocr.utils.ocr.init_detector')
@mock.patch('mmocr.utils.ocr.tesserocr')
def test_tesseract_wrapper(mock_tesserocr, mock_init_detector, det, recog,
                           target):

    def init_detector_skip_ckpt(config, ckpt, device):
        return init_detector(config, device=device)

    mock_init_detector.side_effect = init_detector_skip_ckpt
    mock_tesseract = mock.Mock()
    mock_tesseract.GetComponentImages.return_value = [(None, {
        'x': 0,
        'y': 0,
        'w': 1,
        'h': 1
    }, 0, None)]
    mock_tesseract.GetUTF8Text.return_value = 'text'
    mock_tesseract.MeanTextConf.return_value = 50
    mock_tesserocr.PyTessBaseAPI.return_value = mock_tesseract

    mmocr = MMOCR(det=det, recog=recog, device='cpu')

    img_path = 'demo/demo_kie.jpeg'

    # Test imshow
    with mock.patch('mmocr.utils.ocr.mmcv.imshow') as mock_imshow:
        result = mmocr.readtext(img_path, imshow=True, details=True)
        for k, v in target.items():
            assert result[0][k] == v
        mock_imshow.assert_called_once()
        mock_imshow.reset_mock()