Hbsun/end2end demo (#105)

* add end2end demo

* fix typo

* pad box

* fix bug of crnn

* fix polygon

* update docstring

* fix bug of polygon

* update demo api

* fix except

* rename

* fix with comments
Hongbin Sun 2021-04-22 20:42:42 +08:00 committed by GitHub
parent ce98d23ba2
commit 43dcb32d4f
11 changed files with 372 additions and 95 deletions

demo/ocr_image_demo.py (new file)

@@ -0,0 +1,121 @@
import json
from argparse import ArgumentParser

import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.pipelines.crop import crop_img


def write_json(obj, fpath):
"""Write json object to file."""
with open(fpath, 'w') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '), ensure_ascii=False)


def det_and_recog_inference(args, det_model, recog_model):
image_path = args.img
end2end_res = {'filename': image_path}
end2end_res['result'] = []
image = mmcv.imread(image_path)
det_result = model_inference(det_model, image)
bboxes = det_result['boundary_result']
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
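        # boundaries with more than 4 points: fall back to the axis-aligned bbox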
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
box_img = crop_img(image, box)
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
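        # some recognizers return per-character scores; average them over the text length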
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
end2end_res['result'].append(box_res)
    return end2end_res


def main():
parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input image file.')
parser.add_argument(
'out_file', type=str, help='Output file name of the visualized image.')
parser.add_argument(
'--det-config',
type=str,
default='./configs/textdet/psenet/'
'psenet_r50_fpnf_600e_icdar2015.py',
help='Text detection config file.')
parser.add_argument(
'--det-ckpt',
type=str,
default='https://download.openmmlab.com/'
'mmocr/textdet/psenet/'
'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
        help='Text detection checkpoint file (local or URL).')
parser.add_argument(
'--recog-config',
type=str,
default='./configs/textrecog/sar/'
'sar_r31_parallel_decoder_academic.py',
help='Text recognition config file.')
parser.add_argument(
'--recog-ckpt',
type=str,
default='https://download.openmmlab.com/'
'mmocr/textrecog/sar/'
'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
        help='Text recognition checkpoint file (local or URL).')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
parser.add_argument(
'--imshow',
action='store_true',
help='Whether show image with OpenCV.')
args = parser.parse_args()
# build detect model
detect_model = init_detector(
args.det_config, args.det_ckpt, device=args.device)
if hasattr(detect_model, 'module'):
detect_model = detect_model.module
if detect_model.cfg.data.test['type'] == 'ConcatDataset':
detect_model.cfg.data.test.pipeline = \
detect_model.cfg.data.test['datasets'][0].pipeline
# build recog model
recog_model = init_detector(
args.recog_config, args.recog_ckpt, device=args.device)
if hasattr(recog_model, 'module'):
recog_model = recog_model.module
if recog_model.cfg.data.test['type'] == 'ConcatDataset':
recog_model.cfg.data.test.pipeline = \
recog_model.cfg.data.test['datasets'][0].pipeline
det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
print(f'result: {det_recog_result}')
write_json(det_recog_result, args.out_file + '.json')
img = det_recog_show_result(args.img, det_recog_result)
mmcv.imwrite(img, args.out_file)
if args.imshow:
        mmcv.imshow(img, 'predicted results')


if __name__ == '__main__':
main()
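For reference, the demo writes its predictions to `<out_file>.json` via `write_json`. Based on `det_and_recog_inference` above, the file has roughly the shape below (a sketch; the values are illustrative, not real output):

```python
# Sketch of the saved JSON, with made-up values:
{
    'filename': 'demo/demo_text_det.jpg',
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],  # rounded boundary points
        'box_score': 0.9417,                      # detection confidence
        'text': 'hello',                          # recognized string
        'text_score': 0.8834                      # mean recognition confidence
    }]
}
```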

docs/getting_started.md

@@ -24,6 +24,14 @@ python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/dem
 The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
 
+To run end-to-end detection and recognition on a single image:
+
+```shell
+python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
+```
+
+The predicted result will be saved as `demo/output.jpg`.
+
 ### Test Multiple Images
 
 ```shell

mmocr/apis/inference.py

@@ -26,9 +26,9 @@ def model_inference(model, img):
     if isinstance(img, np.ndarray):
         cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
 
     cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
     test_pipeline = Compose(cfg.data.test.pipeline)
 
     if isinstance(img, np.ndarray):
@@ -51,14 +51,16 @@ def model_inference(model, img):
             img_metas.data[0] for img_metas in data['img_metas']
         ]
     else:
-        data['img_metas'] = data['img_metas'].data[0]
+        data['img_metas'] = data['img_metas'].data
 
     # process img
-    if isinstance(img, np.ndarray):
+    if isinstance(data['img'], list):
         data['img'] = [img.data[0] for img in data['img']]
         for idx, img in enumerate(data['img']):
             if img.dim() == 3:
                 data['img'][idx] = img.unsqueeze(0)
+    else:
+        data['img_metas'] = data['img_metas'][0]
 
     if next(model.parameters()).is_cuda:
         # scatter to specified GPU
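With `LoadImageFromNdarray` registered, `model_inference` now accepts either an image path or an `np.ndarray`. A minimal sketch, assuming a CPU run from the repo root with one of the configs used in this PR:

```python
import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

model = init_detector(
    'configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py',
    checkpoint=None,  # or a local/URL checkpoint, as in the demo above
    device='cpu')
if model.cfg.data.test['type'] == 'ConcatDataset':
    model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][0].pipeline

result = model_inference(model, 'demo/demo_text_det.jpg')               # image path
result = model_inference(model, mmcv.imread('demo/demo_text_det.jpg'))  # ndarray
```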

mmocr/core/visualize.py

@@ -417,3 +417,113 @@ def imshow_edge_node(img,
        mmcv.imwrite(vis_img, out_file)
    return vis_img


def gen_color():
    """Generate BGR color schemes."""
    color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
                  (123, 151, 138), (187, 200, 178), (148, 137, 69),
                  (169, 200, 200), (155, 175, 131), (154, 194, 182),
                  (178, 190, 137), (140, 211, 222), (83, 156, 222)]
    return color_list


def draw_polygons(img, polys):
    """Draw polygons on an image.

    Args:
        img (np.ndarray): The original image.
        polys (list[list[float]]): Detected polygons.

    Returns:
        out_img (np.ndarray): Visualized image.
    """
    dst_img = img.copy()
    color_list = gen_color()
    out_img = dst_img
    for idx, poly in enumerate(polys):
        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
        cv2.drawContours(
            img,
            np.array([poly]),
            -1,
            color_list[idx % len(color_list)],
            thickness=cv2.FILLED)
    out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
    return out_img


def get_optimal_font_scale(text, width):
    """Get the optimal font scale for cv2.putText.

    Args:
        text (str): Text in one box.
        width (int): The box width.
    """
    for scale in reversed(range(0, 60, 1)):
        text_size = cv2.getTextSize(
            text,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=scale / 10,
            thickness=1)
        new_width = text_size[0][0]
        if new_width <= width:
            return scale / 10
    return 1
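For instance, `get_optimal_font_scale` walks `fontScale` down from 5.9 in steps of 0.1 and returns the first scale whose rendered text fits the given width, falling back to 1. A small sketch:

```python
import cv2

from mmocr.core.visualize import get_optimal_font_scale

# Largest Hershey-simplex scale at which 'hello' fits into 100 px.
scale = get_optimal_font_scale('hello', 100)
width = cv2.getTextSize('hello', cv2.FONT_HERSHEY_SIMPLEX, scale, 1)[0][0]
assert width <= 100
```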
def draw_texts(img, boxes, texts):
    """Draw boxes and texts on an empty image.

    Args:
        img (np.ndarray): The original image.
        boxes (list[list[float]]): Detected bounding boxes.
        texts (list[str]): Recognized texts.

    Returns:
        out_img (np.ndarray): Visualized image.
    """
    color_list = gen_color()
    h, w = img.shape[:2]
    out_img = np.ones((h, w, 3), dtype=np.uint8) * 255

    for idx, (box, text) in enumerate(zip(boxes, texts)):
        new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
        pts = np.array([new_box], np.int32)
        cv2.polylines(
            out_img, [pts.reshape((-1, 1, 2))],
            True,
            color=color_list[idx % len(color_list)],
            thickness=1)

        min_x = int(min(box[0::2]))
        max_y = int(
            np.mean(np.array(box[1::2])) + 0.2 *
            (max(box[1::2]) - min(box[1::2])))
        font_scale = get_optimal_font_scale(
            text, int(max(box[0::2]) - min(box[0::2])))
        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
                    font_scale, (0, 0, 0), 1)

    return out_img


def det_recog_show_result(img, end2end_res):
    """Draw `result` (boxes and texts) on `img`.

    Args:
        img (str or np.ndarray): The image to be displayed.
        end2end_res (dict): Text detection and recognition results.

    Returns:
        out_img (np.ndarray): Visualized image.
    """
    img = mmcv.imread(img)
    boxes, texts = [], []
    for res in end2end_res['result']:
        boxes.append(res['box'])
        texts.append(res['text'])

    box_vis_img = draw_polygons(img, boxes)
    text_vis_img = draw_texts(img, boxes, texts)

    h, w = img.shape[:2]
    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
    out_img[:, :w, :] = box_vis_img
    out_img[:, w:, :] = text_vis_img
    return out_img
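A quick usage sketch of `det_recog_show_result`: the left half shows the detection polygons blended over the image, the right half the recognized strings, so the output is twice as wide as the input (the box and text below are made up):

```python
import numpy as np

from mmocr.core.visualize import det_recog_show_result

img = np.full((100, 100, 3), 255, dtype=np.uint8)  # blank white image
res = {'result': [{'box': [10, 40, 10, 10, 90, 10, 90, 40], 'text': 'demo'}]}

vis = det_recog_show_result(img, res)
assert vis.shape == (100, 200, 3)  # original width doubled
```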

mmocr/datasets/pipelines/__init__.py

@@ -2,7 +2,7 @@ from .box_utils import sort_vertex
 from .custom_format_bundle import CustomFormatBundle
 from .dbnet_transforms import EastRandomCrop, ImgAug
 from .kie_transforms import KIEFormatBundle
-from .loading import LoadTextAnnotations
+from .loading import LoadImageFromNdarray, LoadTextAnnotations
 from .ocr_seg_targets import OCRSegTargets
 from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -22,5 +22,5 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex'
+    'sort_vertex', 'LoadImageFromNdarray'
 ]

mmocr/datasets/pipelines/crop.py

@@ -83,24 +83,41 @@ def warp_img(src_img,
     return dst_img
 
 
-def crop_img(src_img, box):
-    """Crop box area to rectangle.
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+    """Crop a text region given its bounding box.
 
     Args:
-        src_img (np.array): Image before crop.
+        src_img (np.array): The original image.
         box (list[float | int]): Points of quadrangle.
+        long_edge_pad_ratio (float): Box pad ratio for the long edge,
+            relative to the font size.
+        short_edge_pad_ratio (float): Box pad ratio for the short edge,
+            relative to the font size.
     """
     assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
     assert len(box) == 8
+    assert 0. <= long_edge_pad_ratio < 1.0
+    assert 0. <= short_edge_pad_ratio < 1.0
 
     h, w = src_img.shape[:2]
-    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
-    points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+    points_x = np.clip(np.array(box[0::2]), 0, w)
+    points_y = np.clip(np.array(box[1::2]), 0, h)
 
-    left = int(min(points_x))
-    top = int(min(points_y))
-    right = int(max(points_x))
-    bottom = int(max(points_y))
+    box_width = np.max(points_x) - np.min(points_x)
+    box_height = np.max(points_y) - np.min(points_y)
+    font_size = min(box_height, box_width)
+
+    if box_height < box_width:
+        horizontal_pad = long_edge_pad_ratio * font_size
+        vertical_pad = short_edge_pad_ratio * font_size
+    else:
+        horizontal_pad = short_edge_pad_ratio * font_size
+        vertical_pad = long_edge_pad_ratio * font_size
+
+    left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+    top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+    right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+    bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
 
     dst_img = src_img[top:bottom, left:right]
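To make the new padding concrete: for a horizontal 100×20 box, `font_size = min(20, 100) = 20`, so the defaults pad `0.4 * 20 = 8` px on the left and right and `0.2 * 20 = 4` px on the top and bottom. A sketch:

```python
import numpy as np

from mmocr.datasets.pipelines.crop import crop_img

img = np.zeros((600, 600, 3), dtype=np.uint8)
box = [20, 20, 120, 20, 120, 40, 20, 40]  # 100 px wide, 20 px high

patch = crop_img(img, box)  # default ratios: 0.4 long edge, 0.2 short edge
assert patch.shape[:2] == (28, 116)  # (20 + 2 * 4, 100 + 2 * 8)
```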

mmocr/datasets/pipelines/loading.py

@@ -1,8 +1,9 @@
+import mmcv
 import numpy as np
 from mmdet.core import BitmapMasks, PolygonMasks
 from mmdet.datasets.builder import PIPELINES
-from mmdet.datasets.pipelines.loading import LoadAnnotations
+from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
 
 
 @PIPELINES.register_module()
@@ -66,3 +67,40 @@ class LoadTextAnnotations(LoadAnnotations):
         results['gt_masks'] = gt_masks
         results['mask_fields'].append('gt_masks')
         return results


@PIPELINES.register_module()
class LoadImageFromNdarray(LoadImageFromFile):
    """Load an image from np.ndarray.

    Similar to :obj:`LoadImageFromFile`, but the image is read from
    ``results['img']``, which is an np.ndarray.
    """

    def __call__(self, results):
        """Call function to add image meta information.

        Args:
            results (dict): Result dict with an in-memory image in
                ``results['img']``.

        Returns:
            dict: The dict contains the loaded image and meta information.
        """
        assert results['img'].dtype == 'uint8'

        img = results['img']
        if self.color_type == 'grayscale' and img.shape[2] == 3:
            img = mmcv.bgr2gray(img, keepdim=True)
        if self.color_type == 'color' and img.shape[2] == 1:
            img = mmcv.gray2bgr(img)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = None
        results['ori_filename'] = None
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['img_fields'] = ['img']
        return results
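A minimal sketch of the new transform on its own; `model_inference` gets the same effect by swapping the first pipeline step, as shown in the `inference.py` diff above:

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray

load = LoadImageFromNdarray(color_type='color')
results = load({'img': np.zeros((32, 100, 3), dtype=np.uint8)})

assert results['filename'] is None          # no file behind the array
assert results['img_shape'] == (32, 100, 3)
```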

tests/test_apis/test_model_inference.py

@@ -1,6 +1,4 @@
 import os
-import shutil
-import urllib
 
 import pytest
 from mmcv.image import imread
@@ -9,85 +7,29 @@ from mmdet.apis import init_detector
 from mmocr.apis.inference import model_inference
 
 
-@pytest.fixture
-def project_dir():
-    return os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
-
-
-@pytest.fixture
-def sample_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_recog.jpg')
-
-
-@pytest.fixture
-def sample_det_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_det.jpg')
-
-
-@pytest.fixture
-def sarnet_model(project_dir):
-    print(project_dir)
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
-    checkpoint_file = os.path.join(
-        project_dir,
-        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-    if not os.path.exists(checkpoint_file):
-        url = ('https://download.openmmlab.com/mmocr'
-               '/textrecog/sar/'
-               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-        print(f'Downloading {url} ...')
-        local_filename, _ = urllib.request.urlretrieve(url)
-        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
-        shutil.move(local_filename, checkpoint_file)
-        print(f'Saved as {checkpoint_file}')
-    else:
-        print(f'Using existing checkpoint {checkpoint_file}')
-    device = 'cpu'
-    model = init_detector(
-        config_file, checkpoint=checkpoint_file, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-    return model
-
-
-@pytest.fixture
-def psenet_model(project_dir):
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py')
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    '../configs/textrecog/crnn/crnn_academic_dataset.py',
+    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
+    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
+    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
+])
+def test_model_inference(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
     device = 'cpu'
     model = init_detector(config_file, checkpoint=None, device=device)
     if model.cfg.data.test['type'] == 'ConcatDataset':
         model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
             0].pipeline
-    return model
-
-
-def test_model_inference_image_path(sample_img_path, sarnet_model):
+
     with pytest.raises(AssertionError):
-        model_inference(sarnet_model, 1)
-    model_inference(sarnet_model, sample_img_path)
-
-
-def test_model_inference_image_path_det(sample_det_img_path, psenet_model):
-    model_inference(psenet_model, sample_det_img_path)
-
-
-def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
+        model_inference(model, 1)
+
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
+    model_inference(model, sample_img_path)
+
+    # numpy inference
     img = imread(sample_img_path)
-    model_inference(sarnet_model, img)
-
-
-def test_model_inference_numpy_ndarray_det(sample_det_img_path, psenet_model):
-    det_img = imread(sample_det_img_path)
-    model_inference(psenet_model, det_img)
+    model_inference(model, img)

tests/test_core/test_end2end_vis.py (new file)

@@ -0,0 +1,20 @@
import numpy as np

from mmocr.core import det_recog_show_result


def test_det_recog_show_result():
img = np.ones((100, 100, 3), dtype=np.uint8) * 255
det_recog_res = {
'result': [{
'box': [51, 88, 51, 62, 85, 62, 85, 88],
'box_score': 0.9417,
'text': 'hell',
'text_score': 0.8834
}]
}
vis_img = det_recog_show_result(img, det_recog_res)
assert vis_img.shape[0] == 100
assert vis_img.shape[1] == 200
assert vis_img.shape[2] == 3

tests/test_dataset/test_crop.py

@@ -85,12 +85,21 @@ def test_min_rect_crop():
     dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
     dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
 
-    cropped_img = crop_img(dummy_img, dummy_box)
+    cropped_img = crop_img(
+        dummy_img,
+        dummy_box,
+        0.,
+        0.,
+    )
 
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [])
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [20, 40, 40, 20])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 4, 0.2)
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 0.4, 1.2)
 
     assert math.isclose(cropped_img.shape[0], 20)
     assert math.isclose(cropped_img.shape[1], 100)

tests/test_dataset/test_loading.py

@@ -1,6 +1,6 @@
 import numpy as np
 
-from mmocr.datasets.pipelines import LoadTextAnnotations
+from mmocr.datasets.pipelines import LoadImageFromNdarray, LoadTextAnnotations
 
 
 def _create_dummy_ann():
@@ -36,3 +36,13 @@ def test_loadtextannotation():
     assert len(output['gt_masks_ignore']) == 4
     assert np.allclose(output['gt_masks_ignore'].masks[0],
                        [[499, 94, 531, 94, 531, 124, 499, 124]])
+
+
+def test_load_img_from_numpy():
+    result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}
+
+    load = LoadImageFromNdarray(color_type='color')
+    output = load(result)
+
+    assert output['img'].shape[2] == 3
+    assert len(output['img'].shape) == 3