Hbsun/end2end demo (#105)

* add end2end demo

* fix typo

* pad box

* fix bug of crnn

* fix polygon

* update docstring

* fix bug of polygon

* update demo api

* fix except

* rename

* fix with comments
Hongbin Sun 2021-04-22 20:42:42 +08:00 committed by GitHub
parent ce98d23ba2
commit 43dcb32d4f
11 changed files with 372 additions and 95 deletions

View File

@@ -0,0 +1,121 @@
import json
from argparse import ArgumentParser
import mmcv
from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
from mmocr.core.visualize import det_recog_show_result
from mmocr.datasets.pipelines.crop import crop_img
def write_json(obj, fpath):
"""Write json object to file."""
with open(fpath, 'w') as f:
json.dump(obj, f, indent=4, separators=(',', ': '), ensure_ascii=False)
def det_and_recog_inference(args, det_model, recog_model):
image_path = args.img
end2end_res = {'filename': image_path}
end2end_res['result'] = []
image = mmcv.imread(image_path)
det_result = model_inference(det_model, image)
bboxes = det_result['boundary_result']
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
box_res['box_score'] = float(bbox[-1])
box = bbox[:8]
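        # a boundary with more than 4 corner points is a polygon; fall back
        # to its axis-aligned bounding rectangle before cropping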
if len(bbox) > 9:
min_x = min(bbox[0:-1:2])
min_y = min(bbox[1:-1:2])
max_x = max(bbox[0:-1:2])
max_y = max(bbox[1:-1:2])
box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
box_img = crop_img(image, box)
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
end2end_res['result'].append(box_res)
return end2end_res
def main():
parser = ArgumentParser()
    parser.add_argument('img', type=str, help='Input image file.')
parser.add_argument(
'out_file', type=str, help='Output file name of the visualized image.')
parser.add_argument(
'--det-config',
type=str,
default='./configs/textdet/psenet/'
'psenet_r50_fpnf_600e_icdar2015.py',
help='Text detection config file.')
parser.add_argument(
'--det-ckpt',
type=str,
default='https://download.openmmlab.com/'
'mmocr/textdet/psenet/'
'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
        help='Text detection checkpoint file (local or URL).')
parser.add_argument(
'--recog-config',
type=str,
default='./configs/textrecog/sar/'
'sar_r31_parallel_decoder_academic.py',
help='Text recognition config file.')
parser.add_argument(
'--recog-ckpt',
type=str,
default='https://download.openmmlab.com/'
'mmocr/textrecog/sar/'
'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
        help='Text recognition checkpoint file (local or URL).')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
parser.add_argument(
'--imshow',
action='store_true',
        help='Whether to show the image with OpenCV.')
args = parser.parse_args()
# build detect model
detect_model = init_detector(
args.det_config, args.det_ckpt, device=args.device)
if hasattr(detect_model, 'module'):
detect_model = detect_model.module
if detect_model.cfg.data.test['type'] == 'ConcatDataset':
detect_model.cfg.data.test.pipeline = \
detect_model.cfg.data.test['datasets'][0].pipeline
# build recog model
recog_model = init_detector(
args.recog_config, args.recog_ckpt, device=args.device)
if hasattr(recog_model, 'module'):
recog_model = recog_model.module
if recog_model.cfg.data.test['type'] == 'ConcatDataset':
recog_model.cfg.data.test.pipeline = \
recog_model.cfg.data.test['datasets'][0].pipeline
det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
print(f'result: {det_recog_result}')
write_json(det_recog_result, args.out_file + '.json')
img = det_recog_show_result(args.img, det_recog_result)
mmcv.imwrite(img, args.out_file)
if args.imshow:
mmcv.imshow(img, 'predicted results')
if __name__ == '__main__':
main()

View File

@@ -24,6 +24,14 @@ python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/dem
The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
To test a single image end to end, with both text detection and recognition:
```shell
python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
```
The predicted result will be saved as `demo/output.jpg`.
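Besides the visualization, the script writes the raw predictions to `demo/output.jpg.json`. A sketch of its structure (the field values here are illustrative, mirroring the unit test added in this commit):

```python
{
    'filename': 'demo/demo_text_det.jpg',
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],
        'box_score': 0.9417,
        'text': 'hell',
        'text_score': 0.8834
    }]
}
```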
### Test Multiple Images
```shell

View File

@@ -26,9 +26,9 @@ def model_inference(model, img):
     if isinstance(img, np.ndarray):
         cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
-        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
+    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)

     test_pipeline = Compose(cfg.data.test.pipeline)

     if isinstance(img, np.ndarray):
@@ -51,14 +51,16 @@ def model_inference(model, img):
             img_metas.data[0] for img_metas in data['img_metas']
         ]
     else:
-        data['img_metas'] = data['img_metas'].data[0]
+        data['img_metas'] = data['img_metas'].data

-    if isinstance(img, np.ndarray):
-        if isinstance(data['img'], list):
-            data['img'] = [img.data[0] for img in data['img']]
-            for idx, img in enumerate(data['img']):
-                if img.dim() == 3:
-                    data['img'][idx] = img.unsqueeze(0)
+    # process img
+    if isinstance(data['img'], list):
+        data['img'] = [img.data[0] for img in data['img']]
+        for idx, img in enumerate(data['img']):
+            if img.dim() == 3:
+                data['img'][idx] = img.unsqueeze(0)
+    else:
+        data['img_metas'] = data['img_metas'][0]

     if next(model.parameters()).is_cuda:
         # scatter to specified GPU
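With this change, `model_inference` accepts either an image path or an already decoded `np.ndarray`. A minimal sketch (the config and checkpoint paths are placeholders):

```python
import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

# placeholder config/checkpoint; any MMOCR det/recog model is used the same way
model = init_detector('configs/xxx.py', 'xxx.pth', device='cpu')

result = model_inference(model, 'demo/demo_text_det.jpg')  # from a file path
result = model_inference(model, mmcv.imread('demo/demo_text_det.jpg'))  # from an ndarray
```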

View File

@@ -417,3 +417,113 @@ def imshow_edge_node(img,
mmcv.imwrite(vis_img, out_file)
return vis_img
def gen_color():
"""Generate BGR color schemes."""
color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
(123, 151, 138), (187, 200, 178), (148, 137, 69),
(169, 200, 200), (155, 175, 131), (154, 194, 182),
(178, 190, 137), (140, 211, 222), (83, 156, 222)]
return color_list
def draw_polygons(img, polys):
"""Draw polygons on image.
Args:
img (np.ndarray): The original image.
polys (list[list[float]]): Detected polygons.
Return:
out_img (np.ndarray): Visualized image.
"""
dst_img = img.copy()
color_list = gen_color()
out_img = dst_img
for idx, poly in enumerate(polys):
poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
cv2.drawContours(
img,
np.array([poly]),
-1,
color_list[idx % len(color_list)],
thickness=cv2.FILLED)
out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
return out_img
def get_optimal_font_scale(text, width):
"""Get optimal font scale for cv2.putText.
Args:
text (str): Text in one box.
width (int): The box width.
"""
for scale in reversed(range(0, 60, 1)):
textSize = cv2.getTextSize(
text,
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=scale / 10,
thickness=1)
new_width = textSize[0][0]
if (new_width <= width):
return scale / 10
return 1
def draw_texts(img, boxes, texts):
"""Draw boxes and texts on empty img.
Args:
img (np.ndarray): The original image.
boxes (list[list[float]]): Detected bounding boxes.
texts (list[str]): Recognized texts.
Return:
out_img (np.ndarray): Visualized image.
"""
color_list = gen_color()
h, w = img.shape[:2]
out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
for idx, (box, text) in enumerate(zip(boxes, texts)):
new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
Pts = np.array([new_box], np.int32)
cv2.polylines(
out_img, [Pts.reshape((-1, 1, 2))],
True,
color=color_list[idx % len(color_list)],
thickness=1)
min_x = int(min(box[0::2]))
max_y = int(
np.mean(np.array(box[1::2])) + 0.2 *
(max(box[1::2]) - min(box[1::2])))
font_scale = get_optimal_font_scale(
text, int(max(box[0::2]) - min(box[0::2])))
cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
font_scale, (0, 0, 0), 1)
return out_img
def det_recog_show_result(img, end2end_res):
"""Draw `result`(boxes and texts) on `img`.
Args:
img (str or np.ndarray): The image to be displayed.
end2end_res (dict): Text detect and recognize results.
Return:
out_img (np.ndarray): Visualized image.
"""
img = mmcv.imread(img)
boxes, texts = [], []
for res in end2end_res['result']:
boxes.append(res['box'])
texts.append(res['text'])
box_vis_img = draw_polygons(img, boxes)
text_vis_img = draw_texts(img, boxes, texts)
h, w = img.shape[:2]
out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
out_img[:, :w, :] = box_vis_img
out_img[:, w:, :] = text_vis_img
return out_img
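A minimal sketch of how these helpers compose (mirroring the `test_det_recog_show_result` unit test added in this commit):

```python
import numpy as np

from mmocr.core import det_recog_show_result

end2end_res = {
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],
        'text': 'hell'
    }]
}
img = np.ones((100, 100, 3), dtype=np.uint8) * 255
vis = det_recog_show_result(img, end2end_res)
# vis has shape (100, 200, 3): boxes on the left half, texts on the right
```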

View File

@@ -2,7 +2,7 @@ from .box_utils import sort_vertex
 from .custom_format_bundle import CustomFormatBundle
 from .dbnet_transforms import EastRandomCrop, ImgAug
 from .kie_transforms import KIEFormatBundle
-from .loading import LoadTextAnnotations
+from .loading import LoadImageFromNdarray, LoadTextAnnotations
 from .ocr_seg_targets import OCRSegTargets
 from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -22,5 +22,5 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex'
+    'sort_vertex', 'LoadImageFromNdarray'
 ]

View File

@@ -83,24 +83,41 @@ def warp_img(src_img,
     return dst_img


-def crop_img(src_img, box):
-    """Crop box area to rectangle.
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+    """Crop a text region given its bounding box.

     Args:
-        src_img (np.array): Image before crop.
+        src_img (np.array): The original image.
         box (list[float | int]): Points of quadrangle.
+        long_edge_pad_ratio (float): Box pad ratio for the long edge,
+            relative to font size.
+        short_edge_pad_ratio (float): Box pad ratio for the short edge,
+            relative to font size.
     """
     assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
     assert len(box) == 8
+    assert 0. <= long_edge_pad_ratio < 1.0
+    assert 0. <= short_edge_pad_ratio < 1.0

     h, w = src_img.shape[:2]
-    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
-    points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+    points_x = np.clip(np.array(box[0::2]), 0, w)
+    points_y = np.clip(np.array(box[1::2]), 0, h)

-    left = int(min(points_x))
-    top = int(min(points_y))
-    right = int(max(points_x))
-    bottom = int(max(points_y))
+    box_width = np.max(points_x) - np.min(points_x)
+    box_height = np.max(points_y) - np.min(points_y)
+    font_size = min(box_height, box_width)
+
+    if box_height < box_width:
+        horizontal_pad = long_edge_pad_ratio * font_size
+        vertical_pad = short_edge_pad_ratio * font_size
+    else:
+        horizontal_pad = short_edge_pad_ratio * font_size
+        vertical_pad = long_edge_pad_ratio * font_size
+
+    left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+    top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+    right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+    bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)

     dst_img = src_img[top:bottom, left:right]
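A worked example of the new padding (a sketch; the numbers follow directly from the code above and match the updated crop test below):

```python
# box = [20, 20, 120, 20, 120, 40, 20, 40] -> box_width 100, box_height 20
# font_size = min(100, 20) = 20
# box_height < box_width, so:
#   horizontal_pad = 0.4 * 20 = 8 -> left = 12, right = 128
#   vertical_pad   = 0.2 * 20 = 4 -> top  = 16, bottom = 44
# crop_img(img, box) returns a (28, 116) patch, while
# crop_img(img, box, 0., 0.) returns the unpadded (20, 100) patch.
```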

View File

@@ -1,8 +1,9 @@
 import mmcv
 import numpy as np
 from mmdet.core import BitmapMasks, PolygonMasks
 from mmdet.datasets.builder import PIPELINES
-from mmdet.datasets.pipelines.loading import LoadAnnotations
+from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile


 @PIPELINES.register_module()
@@ -66,3 +67,40 @@ class LoadTextAnnotations(LoadAnnotations):
results['gt_masks'] = gt_masks
results['mask_fields'].append('gt_masks')
return results
@PIPELINES.register_module()
class LoadImageFromNdarray(LoadImageFromFile):
"""Load an image from np.ndarray.
    Similar to :obj:`LoadImageFromFile`, but the image is read from
    ``results['img']``, which is an np.ndarray.
"""
def __call__(self, results):
"""Call functions to add image meta information.
Args:
            results (dict): Result dict with an image ndarray in
                ``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
assert results['img'].dtype == 'uint8'
img = results['img']
if self.color_type == 'grayscale' and img.shape[2] == 3:
img = mmcv.bgr2gray(img, keepdim=True)
if self.color_type == 'color' and img.shape[2] == 1:
img = mmcv.gray2bgr(img)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = None
results['ori_filename'] = None
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results
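A minimal usage sketch (mirroring the loading test added in this commit):

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray

load = LoadImageFromNdarray(color_type='color')
results = load({'img': np.ones((32, 100, 3), dtype=np.uint8)})
assert results['img_shape'] == (32, 100, 3)  # meta fields are filled in
```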

View File

@@ -1,6 +1,4 @@
 import os
-import shutil
-import urllib

 import pytest
 from mmcv.image import imread
@@ -9,85 +7,29 @@ from mmdet.apis import init_detector
 from mmocr.apis.inference import model_inference


-@pytest.fixture
-def project_dir():
-    return os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
-
-
-@pytest.fixture
-def sample_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_recog.jpg')
-
-
-@pytest.fixture
-def sample_det_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_det.jpg')
-
-
-@pytest.fixture
-def sarnet_model(project_dir):
-    print(project_dir)
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
-    checkpoint_file = os.path.join(
-        project_dir,
-        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-    if not os.path.exists(checkpoint_file):
-        url = ('https://download.openmmlab.com/mmocr'
-               '/textrecog/sar/'
-               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-        print(f'Downloading {url} ...')
-        local_filename, _ = urllib.request.urlretrieve(url)
-        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
-        shutil.move(local_filename, checkpoint_file)
-        print(f'Saved as {checkpoint_file}')
-    else:
-        print(f'Using existing checkpoint {checkpoint_file}')
-    device = 'cpu'
-    model = init_detector(
-        config_file, checkpoint=checkpoint_file, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-    return model
-
-
-@pytest.fixture
-def psenet_model(project_dir):
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py')
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    '../configs/textrecog/crnn/crnn_academic_dataset.py',
+    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
+    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
+    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
+])
+def test_model_inference(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
     device = 'cpu'
     model = init_detector(config_file, checkpoint=None, device=device)
     if model.cfg.data.test['type'] == 'ConcatDataset':
         model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
             0].pipeline
-    return model
-
-
-def test_model_inference_image_path(sample_img_path, sarnet_model):

     with pytest.raises(AssertionError):
-        model_inference(sarnet_model, 1)
+        model_inference(model, 1)

-    model_inference(sarnet_model, sample_img_path)
-
-
-def test_model_inference_image_path_det(sample_det_img_path, psenet_model):
-    model_inference(psenet_model, sample_det_img_path)
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
+    model_inference(model, sample_img_path)


-def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
     # numpy inference
     img = imread(sample_img_path)
-    model_inference(sarnet_model, img)
-
-
-def test_model_inference_numpy_ndarray_det(sample_det_img_path, psenet_model):
-    det_img = imread(sample_det_img_path)
-    model_inference(psenet_model, det_img)
+    model_inference(model, img)

View File

@@ -0,0 +1,20 @@
import numpy as np
from mmocr.core import det_recog_show_result
def test_det_recog_show_result():
img = np.ones((100, 100, 3), dtype=np.uint8) * 255
det_recog_res = {
'result': [{
'box': [51, 88, 51, 62, 85, 62, 85, 88],
'box_score': 0.9417,
'text': 'hell',
'text_score': 0.8834
}]
}
vis_img = det_recog_show_result(img, det_recog_res)
assert vis_img.shape[0] == 100
assert vis_img.shape[1] == 200
assert vis_img.shape[2] == 3

View File

@@ -85,12 +85,21 @@ def test_min_rect_crop():
     dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
     dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]

-    cropped_img = crop_img(dummy_img, dummy_box)
+    cropped_img = crop_img(
+        dummy_img,
+        dummy_box,
+        0.,
+        0.,
+    )

     with pytest.raises(AssertionError):
         crop_img(dummy_img, [])
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [20, 40, 40, 20])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 4, 0.2)
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 0.4, 1.2)

     assert math.isclose(cropped_img.shape[0], 20)
     assert math.isclose(cropped_img.shape[1], 100)

View File

@@ -1,6 +1,6 @@
 import numpy as np

-from mmocr.datasets.pipelines import LoadTextAnnotations
+from mmocr.datasets.pipelines import LoadImageFromNdarray, LoadTextAnnotations


 def _create_dummy_ann():
@@ -36,3 +36,13 @@ def test_loadtextannotation():
assert len(output['gt_masks_ignore']) == 4
assert np.allclose(output['gt_masks_ignore'].masks[0],
[[499, 94, 531, 94, 531, 124, 499, 124]])
def test_load_img_from_numpy():
result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}
load = LoadImageFromNdarray(color_type='color')
output = load(result)
assert output['img'].shape[2] == 3
assert len(output['img'].shape) == 3