From 43dcb32d4ff43adb46360e6cd7b9ade87aba6cd3 Mon Sep 17 00:00:00 2001
From: Hongbin Sun
Date: Thu, 22 Apr 2021 20:42:42 +0800
Subject: [PATCH] Hbsun/end2end demo (#105)

* add end2end demo

* fix typo

* pad box

* fix bug of crnn

* fix polygon

* update docstring

* fix bug of polygon

* update demo api

* fix except

* rename

* fix with comments
---
 demo/ocr_image_demo.py                  | 121 ++++++++++++++++++++++++
 docs/getting_started.md                 |   8 ++
 mmocr/apis/inference.py                 |  16 ++--
 mmocr/core/visualize.py                 | 110 +++++++++++++++++++++
 mmocr/datasets/pipelines/__init__.py    |   4 +-
 mmocr/datasets/pipelines/crop.py        |  35 +++++--
 mmocr/datasets/pipelines/loading.py     |  40 +++++++-
 tests/test_apis/test_model_inference.py |  90 ++++--------------
 tests/test_core/test_end2end_vis.py     |  20 ++++
 tests/test_dataset/test_crop.py         |  11 ++-
 tests/test_dataset/test_loading.py      |  12 ++-
 11 files changed, 372 insertions(+), 95 deletions(-)
 create mode 100644 demo/ocr_image_demo.py
 create mode 100644 tests/test_core/test_end2end_vis.py

diff --git a/demo/ocr_image_demo.py b/demo/ocr_image_demo.py
new file mode 100644
index 00000000..4cbc3886
--- /dev/null
+++ b/demo/ocr_image_demo.py
@@ -0,0 +1,121 @@
+import json
+from argparse import ArgumentParser
+
+import mmcv
+
+from mmdet.apis import init_detector
+from mmocr.apis.inference import model_inference
+from mmocr.core.visualize import det_recog_show_result
+from mmocr.datasets.pipelines.crop import crop_img
+
+
+def write_json(obj, fpath):
+    """Write json object to file."""
+    with open(fpath, 'w') as f:
+        json.dump(obj, f, indent=4, separators=(',', ': '), ensure_ascii=False)
+
+
+def det_and_recog_inference(args, det_model, recog_model):
+    image_path = args.img
+    end2end_res = {'filename': image_path}
+    end2end_res['result'] = []
+
+    image = mmcv.imread(image_path)
+    det_result = model_inference(det_model, image)
+    bboxes = det_result['boundary_result']
+
+    for bbox in bboxes:
+        box_res = {}
+        box_res['box'] = [round(x) for x in bbox[:-1]]
+        box_res['box_score'] = float(bbox[-1])
+        box = bbox[:8]
+        if len(bbox) > 9:
+            min_x = min(bbox[0:-1:2])
+            min_y = min(bbox[1:-1:2])
+            max_x = max(bbox[0:-1:2])
+            max_y = max(bbox[1:-1:2])
+            box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
+        box_img = crop_img(image, box)
+
+        recog_result = model_inference(recog_model, box_img)
+
+        text = recog_result['text']
+        text_score = recog_result['score']
+        if isinstance(text_score, list):
+            text_score = sum(text_score) / max(1, len(text))
+        box_res['text'] = text
+        box_res['text_score'] = text_score
+        end2end_res['result'].append(box_res)
+
+    return end2end_res
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument('img', type=str, help='Input image file.')
+    parser.add_argument(
+        'out_file', type=str, help='Output file name of the visualized image.')
+    parser.add_argument(
+        '--det-config',
+        type=str,
+        default='./configs/textdet/psenet/'
+        'psenet_r50_fpnf_600e_icdar2015.py',
+        help='Text detection config file.')
+    parser.add_argument(
+        '--det-ckpt',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmocr/textdet/psenet/'
+        'psenet_r50_fpnf_600e_icdar2015_pretrain-eefd8fe6.pth',
+        help='Text detection checkpoint file (local or url).')
+    parser.add_argument(
+        '--recog-config',
+        type=str,
+        default='./configs/textrecog/sar/'
+        'sar_r31_parallel_decoder_academic.py',
+        help='Text recognition config file.')
+    parser.add_argument(
+        '--recog-ckpt',
+        type=str,
+        default='https://download.openmmlab.com/'
+        'mmocr/textrecog/sar/'
+        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
+        help='Text recognition checkpoint file (local or url).')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference.')
+    parser.add_argument(
+        '--imshow',
+        action='store_true',
+        help='Whether to show the image with OpenCV.')
+    args = parser.parse_args()
+
+    # build detection model
+    detect_model = init_detector(
+        args.det_config, args.det_ckpt, device=args.device)
+    if hasattr(detect_model, 'module'):
+        detect_model = detect_model.module
+    if detect_model.cfg.data.test['type'] == 'ConcatDataset':
+        detect_model.cfg.data.test.pipeline = \
+            detect_model.cfg.data.test['datasets'][0].pipeline
+
+    # build recognition model
+    recog_model = init_detector(
+        args.recog_config, args.recog_ckpt, device=args.device)
+    if hasattr(recog_model, 'module'):
+        recog_model = recog_model.module
+    if recog_model.cfg.data.test['type'] == 'ConcatDataset':
+        recog_model.cfg.data.test.pipeline = \
+            recog_model.cfg.data.test['datasets'][0].pipeline
+
+    det_recog_result = det_and_recog_inference(args, detect_model, recog_model)
+    print(f'result: {det_recog_result}')
+    write_json(det_recog_result, args.out_file + '.json')
+
+    img = det_recog_show_result(args.img, det_recog_result)
+    mmcv.imwrite(img, args.out_file)
+    if args.imshow:
+        mmcv.imshow(img, 'predicted results')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/docs/getting_started.md b/docs/getting_started.md
index c281a2db..6c93b3f4 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -24,6 +24,14 @@ python demo/image_demo.py demo/demo_text_det.jpg configs/xxx.py xxx.pth demo/dem
 
 The predicted result will be saved as `demo/demo_text_det_pred.jpg`.
 
+To test a single image end to end with both text detection and recognition, run:
+
+```shell
+python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
+```
+
+The predicted result will be saved as `demo/output.jpg`.
+
 ### Test Multiple Images
 
 ```shell
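
The demo also writes `<out_file>.json` via `write_json`. A minimal sketch of the expected structure, with fabricated box, score and text values that mirror the dict assembled by `det_and_recog_inference`:

```python
import json

# Hypothetical result mirroring det_and_recog_inference: one entry per detected
# box, each with its corner points, detection score, recognized text and score.
end2end_res = {
    'filename': 'demo/demo_text_det.jpg',
    'result': [{
        'box': [51, 88, 51, 62, 85, 62, 85, 88],  # 4 (x, y) corner points
        'box_score': 0.94,
        'text': 'hello',
        'text_score': 0.88,
    }],
}

# Same dump arguments as write_json in demo/ocr_image_demo.py; the demo names
# the file args.out_file + '.json'.
with open('output.jpg.json', 'w') as f:
    json.dump(end2end_res, f, indent=4, separators=(',', ': '), ensure_ascii=False)
```
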
diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
index 458281f1..34bf28b3 100644
--- a/mmocr/apis/inference.py
+++ b/mmocr/apis/inference.py
@@ -26,9 +26,9 @@ def model_inference(model, img):
     if isinstance(img, np.ndarray):
         cfg = cfg.copy()
         # set loading pipeline type
-        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
-        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
+    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
 
     test_pipeline = Compose(cfg.data.test.pipeline)
 
     if isinstance(img, np.ndarray):
@@ -51,14 +51,16 @@ def model_inference(model, img):
             img_metas.data[0] for img_metas in data['img_metas']
         ]
     else:
-        data['img_metas'] = data['img_metas'].data[0]
+        data['img_metas'] = data['img_metas'].data
 
     # process img
-    if isinstance(img, np.ndarray):
+    if isinstance(data['img'], list):
         data['img'] = [img.data[0] for img in data['img']]
-    for idx, img in enumerate(data['img']):
-        if img.dim() == 3:
-            data['img'][idx] = img.unsqueeze(0)
+        for idx, img in enumerate(data['img']):
+            if img.dim() == 3:
+                data['img'][idx] = img.unsqueeze(0)
+    else:
+        data['img_metas'] = data['img_metas'][0]
 
     if next(model.parameters()).is_cuda:
         # scatter to specified GPU
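
After this change `model_inference` accepts either an image path or an already-decoded BGR array; for arrays the first pipeline stage is swapped to `LoadImageFromNdarray`. A rough usage sketch (CPU, `checkpoint=None` as in the unit tests, so the outputs are meaningless; the `ConcatDataset` handling copies what the demo script does):

```python
import mmcv
from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference

# Build a detector from a config only; pass a real checkpoint path or URL
# for meaningful predictions.
config = 'configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2015.py'
model = init_detector(config, checkpoint=None, device='cpu')
if model.cfg.data.test['type'] == 'ConcatDataset':
    model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][0].pipeline

# Path input: the test pipeline keeps its LoadImageFromFile head.
result_from_path = model_inference(model, 'demo/demo_text_det.jpg')

# Array input: model_inference swaps the head to LoadImageFromNdarray.
img = mmcv.imread('demo/demo_text_det.jpg')
result_from_array = model_inference(model, img)
```
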
diff --git a/mmocr/core/visualize.py b/mmocr/core/visualize.py
index f1965970..4f607d8b 100644
--- a/mmocr/core/visualize.py
+++ b/mmocr/core/visualize.py
@@ -417,3 +417,113 @@ def imshow_edge_node(img,
         mmcv.imwrite(vis_img, out_file)
 
     return vis_img
+
+
+def gen_color():
+    """Generate BGR color schemes."""
+    color_list = [(101, 67, 254), (154, 157, 252), (173, 205, 249),
+                  (123, 151, 138), (187, 200, 178), (148, 137, 69),
+                  (169, 200, 200), (155, 175, 131), (154, 194, 182),
+                  (178, 190, 137), (140, 211, 222), (83, 156, 222)]
+    return color_list
+
+
+def draw_polygons(img, polys):
+    """Draw polygons on image.
+
+    Args:
+        img (np.ndarray): The original image.
+        polys (list[list[float]]): Detected polygons.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    dst_img = img.copy()
+    color_list = gen_color()
+    out_img = dst_img
+    for idx, poly in enumerate(polys):
+        poly = np.array(poly).reshape((-1, 1, 2)).astype(np.int32)
+        cv2.drawContours(
+            img,
+            np.array([poly]),
+            -1,
+            color_list[idx % len(color_list)],
+            thickness=cv2.FILLED)
+        out_img = cv2.addWeighted(dst_img, 0.5, img, 0.5, 0)
+    return out_img
+
+
+def get_optimal_font_scale(text, width):
+    """Get optimal font scale for cv2.putText.
+
+    Args:
+        text (str): Text in one box.
+        width (int): The box width.
+    """
+    for scale in reversed(range(0, 60, 1)):
+        textSize = cv2.getTextSize(
+            text,
+            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
+            fontScale=scale / 10,
+            thickness=1)
+        new_width = textSize[0][0]
+        if (new_width <= width):
+            return scale / 10
+    return 1
+
+
+def draw_texts(img, boxes, texts):
+    """Draw boxes and texts on an empty image.
+
+    Args:
+        img (np.ndarray): The original image.
+        boxes (list[list[float]]): Detected bounding boxes.
+        texts (list[str]): Recognized texts.
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    color_list = gen_color()
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w, 3), dtype=np.uint8) * 255
+    for idx, (box, text) in enumerate(zip(boxes, texts)):
+        new_box = [[x, y] for x, y in zip(box[0::2], box[1::2])]
+        Pts = np.array([new_box], np.int32)
+        cv2.polylines(
+            out_img, [Pts.reshape((-1, 1, 2))],
+            True,
+            color=color_list[idx % len(color_list)],
+            thickness=1)
+        min_x = int(min(box[0::2]))
+        max_y = int(
+            np.mean(np.array(box[1::2])) + 0.2 *
+            (max(box[1::2]) - min(box[1::2])))
+        font_scale = get_optimal_font_scale(
+            text, int(max(box[0::2]) - min(box[0::2])))
+        cv2.putText(out_img, text, (min_x, max_y), cv2.FONT_HERSHEY_SIMPLEX,
+                    font_scale, (0, 0, 0), 1)
+
+    return out_img
+
+
+def det_recog_show_result(img, end2end_res):
+    """Draw `result` (boxes and texts) on `img`.
+
+    Args:
+        img (str or np.ndarray): The image to be displayed.
+        end2end_res (dict): Text detection and recognition results.
+
+    Return:
+        out_img (np.ndarray): Visualized image.
+    """
+    img = mmcv.imread(img)
+    boxes, texts = [], []
+    for res in end2end_res['result']:
+        boxes.append(res['box'])
+        texts.append(res['text'])
+    box_vis_img = draw_polygons(img, boxes)
+    text_vis_img = draw_texts(img, boxes, texts)
+
+    h, w = img.shape[:2]
+    out_img = np.ones((h, w * 2, 3), dtype=np.uint8)
+    out_img[:, :w, :] = box_vis_img
+    out_img[:, w:, :] = text_vis_img
+
+    return out_img
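
None of the new helpers need a model; `det_recog_show_result` just stitches the outputs of `draw_polygons` and `draw_texts` side by side. A small sketch with one fabricated box and text (note that `draw_polygons` draws on the array it receives, so pass a copy if the original matters):

```python
import numpy as np

from mmocr.core.visualize import draw_polygons, draw_texts

img = np.full((100, 200, 3), 255, dtype=np.uint8)  # white BGR canvas
boxes = [[10, 10, 120, 10, 120, 40, 10, 40]]       # one quadrangle as (x, y) pairs
texts = ['hello']

text_vis = draw_texts(img, boxes, texts)  # outlines plus text on a fresh white image
poly_vis = draw_polygons(img, boxes)      # semi-transparent filled polygon over img
assert text_vis.shape == poly_vis.shape == img.shape
```
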
diff --git a/mmocr/datasets/pipelines/__init__.py b/mmocr/datasets/pipelines/__init__.py
index 4eca60a3..9ebe9336 100644
--- a/mmocr/datasets/pipelines/__init__.py
+++ b/mmocr/datasets/pipelines/__init__.py
@@ -2,7 +2,7 @@ from .box_utils import sort_vertex
 from .custom_format_bundle import CustomFormatBundle
 from .dbnet_transforms import EastRandomCrop, ImgAug
 from .kie_transforms import KIEFormatBundle
-from .loading import LoadTextAnnotations
+from .loading import LoadImageFromNdarray, LoadTextAnnotations
 from .ocr_seg_targets import OCRSegTargets
 from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
                              OpencvToPil, PilToOpencv, RandomPaddingOCR,
@@ -22,5 +22,5 @@ __all__ = [
     'RandomCropPolyInstances', 'RandomRotatePolyInstances', 'RandomPaddingOCR',
     'ImgAug', 'EastRandomCrop', 'RandomRotateImageBox', 'OpencvToPil',
     'PilToOpencv', 'KIEFormatBundle', 'SquareResizePad', 'TextSnakeTargets',
-    'sort_vertex'
+    'sort_vertex', 'LoadImageFromNdarray'
 ]
diff --git a/mmocr/datasets/pipelines/crop.py b/mmocr/datasets/pipelines/crop.py
index eea0ffbb..90127578 100644
--- a/mmocr/datasets/pipelines/crop.py
+++ b/mmocr/datasets/pipelines/crop.py
@@ -83,24 +83,41 @@ def warp_img(src_img,
     return dst_img
 
 
-def crop_img(src_img, box):
-    """Crop box area to rectangle.
+def crop_img(src_img, box, long_edge_pad_ratio=0.4, short_edge_pad_ratio=0.2):
+    """Crop text region given its bounding box.
 
     Args:
-        src_img (np.array): Image before crop.
+        src_img (np.array): The original image.
         box (list[float | int]): Points of quadrangle.
+        long_edge_pad_ratio (float): Box pad ratio for long edge
+            corresponding to font size.
+        short_edge_pad_ratio (float): Box pad ratio for short edge
+            corresponding to font size.
     """
     assert utils.is_type_list(box, float) or utils.is_type_list(box, int)
     assert len(box) == 8
+    assert 0. <= long_edge_pad_ratio < 1.0
+    assert 0. <= short_edge_pad_ratio < 1.0
 
     h, w = src_img.shape[:2]
-    points_x = [min(max(x, 0), w) for x in box[0:8:2]]
-    points_y = [min(max(y, 0), h) for y in box[1:9:2]]
+    points_x = np.clip(np.array(box[0::2]), 0, w)
+    points_y = np.clip(np.array(box[1::2]), 0, h)
 
-    left = int(min(points_x))
-    top = int(min(points_y))
-    right = int(max(points_x))
-    bottom = int(max(points_y))
+    box_width = np.max(points_x) - np.min(points_x)
+    box_height = np.max(points_y) - np.min(points_y)
+    font_size = min(box_height, box_width)
+
+    if box_height < box_width:
+        horizontal_pad = long_edge_pad_ratio * font_size
+        vertical_pad = short_edge_pad_ratio * font_size
+    else:
+        horizontal_pad = short_edge_pad_ratio * font_size
+        vertical_pad = long_edge_pad_ratio * font_size
+
+    left = np.clip(int(np.min(points_x) - horizontal_pad), 0, w)
+    top = np.clip(int(np.min(points_y) - vertical_pad), 0, h)
+    right = np.clip(int(np.max(points_x) + horizontal_pad), 0, w)
+    bottom = np.clip(int(np.max(points_y) + vertical_pad), 0, h)
 
     dst_img = src_img[top:bottom, left:right]
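
To make the new padding behaviour concrete: for an axis-aligned 100x20 box the font size is min(20, 100) = 20, so with the default ratios the long (horizontal) edges get 0.4 * 20 = 8 px of padding and the short (vertical) edges get 0.2 * 20 = 4 px. A quick sketch, assuming the padded box stays inside the image so the clipping is a no-op:

```python
import numpy as np

from mmocr.datasets.pipelines.crop import crop_img

src_img = np.zeros((600, 600, 3), dtype=np.uint8)
box = [20., 20., 120., 20., 120., 40., 20., 40.]  # 100 x 20 quadrangle

patch = crop_img(src_img, box)  # default ratios: 0.4 (long edge), 0.2 (short edge)
print(patch.shape)  # (28, 116, 3): 20 + 2 * 4 rows, 100 + 2 * 8 columns
```
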
diff --git a/mmocr/datasets/pipelines/loading.py b/mmocr/datasets/pipelines/loading.py
index 5c3cda6e..d931edd1 100644
--- a/mmocr/datasets/pipelines/loading.py
+++ b/mmocr/datasets/pipelines/loading.py
@@ -1,8 +1,9 @@
+import mmcv
 import numpy as np
 
 from mmdet.core import BitmapMasks, PolygonMasks
 from mmdet.datasets.builder import PIPELINES
-from mmdet.datasets.pipelines.loading import LoadAnnotations
+from mmdet.datasets.pipelines.loading import LoadAnnotations, LoadImageFromFile
 
 
 @PIPELINES.register_module()
@@ -66,3 +67,40 @@ class LoadTextAnnotations(LoadAnnotations):
             results['gt_masks'] = gt_masks
             results['mask_fields'].append('gt_masks')
         return results
+
+
+@PIPELINES.register_module()
+class LoadImageFromNdarray(LoadImageFromFile):
+    """Load an image from np.ndarray.
+
+    Similar to :obj:`LoadImageFromFile`, but the image is read from
+    ``results['img']``, which is an np.ndarray.
+    """
+
+    def __call__(self, results):
+        """Call functions to add image meta information.
+
+        Args:
+            results (dict): Result dict with an ndarray image in
+                ``results['img']``.
+
+        Returns:
+            dict: The dict contains the loaded image and meta information.
+        """
+        assert results['img'].dtype == 'uint8'
+
+        img = results['img']
+        if self.color_type == 'grayscale' and img.shape[2] == 3:
+            img = mmcv.bgr2gray(img, keepdim=True)
+        if self.color_type == 'color' and img.shape[2] == 1:
+            img = mmcv.gray2bgr(img)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['filename'] = None
+        results['ori_filename'] = None
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        results['img_fields'] = ['img']
+        return results
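
`LoadImageFromNdarray` keeps the `LoadImageFromFile` options (`to_float32`, `color_type`) but takes the already-decoded array from `results['img']` and fills in the usual meta keys. A minimal sketch of calling it directly, outside a config:

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray

# A decoded uint8 BGR frame, e.g. from mmcv.imread or a camera.
results = {'img': np.ones((32, 100, 3), dtype=np.uint8)}

load = LoadImageFromNdarray(color_type='color')
out = load(results)
print(out['img'].shape, out['img_shape'], out['filename'])  # (32, 100, 3) (32, 100, 3) None
```
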
diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py
index a96fff56..a08541e2 100644
--- a/tests/test_apis/test_model_inference.py
+++ b/tests/test_apis/test_model_inference.py
@@ -1,6 +1,4 @@
 import os
-import shutil
-import urllib
 
 import pytest
 from mmcv.image import imread
@@ -9,85 +7,29 @@ from mmdet.apis import init_detector
 from mmocr.apis.inference import model_inference
 
 
-@pytest.fixture
-def project_dir():
-    return os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
-
-
-@pytest.fixture
-def sample_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_recog.jpg')
-
-
-@pytest.fixture
-def sample_det_img_path(project_dir):
-    return os.path.join(project_dir, '../demo/demo_text_det.jpg')
-
-
-@pytest.fixture
-def sarnet_model(project_dir):
-    print(project_dir)
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
-    checkpoint_file = os.path.join(
-        project_dir,
-        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-
-    if not os.path.exists(checkpoint_file):
-        url = ('https://download.openmmlab.com/mmocr'
-               '/textrecog/sar/'
-               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
-        print(f'Downloading {url} ...')
-        local_filename, _ = urllib.request.urlretrieve(url)
-        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
-        shutil.move(local_filename, checkpoint_file)
-        print(f'Saved as {checkpoint_file}')
-    else:
-        print(f'Using existing checkpoint {checkpoint_file}')
-
-    device = 'cpu'
-    model = init_detector(
-        config_file, checkpoint=checkpoint_file, device=device)
-    if model.cfg.data.test['type'] == 'ConcatDataset':
-        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
-            0].pipeline
-
-    return model
-
-
-@pytest.fixture
-def psenet_model(project_dir):
-    config_file = os.path.join(
-        project_dir,
-        '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py')
-
+@pytest.mark.parametrize('cfg_file', [
+    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
+    '../configs/textrecog/crnn/crnn_academic_dataset.py',
+    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
+    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
+    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
+    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
+])
+def test_model_inference(cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
     device = 'cpu'
     model = init_detector(config_file, checkpoint=None, device=device)
     if model.cfg.data.test['type'] == 'ConcatDataset':
         model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
             0].pipeline
-
-    return model
-
-
-def test_model_inference_image_path(sample_img_path, sarnet_model):
 
     with pytest.raises(AssertionError):
-        model_inference(sarnet_model, 1)
+        model_inference(model, 1)
 
-    model_inference(sarnet_model, sample_img_path)
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
+    model_inference(model, sample_img_path)
 
-
-def test_model_inference_image_path_det(sample_det_img_path, psenet_model):
-    model_inference(psenet_model, sample_det_img_path)
-
-
-def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
+    # numpy inference
     img = imread(sample_img_path)
-    model_inference(sarnet_model, img)
-
-
-def test_model_inference_numpy_ndarray_det(sample_det_img_path, psenet_model):
-    det_img = imread(sample_det_img_path)
-    model_inference(psenet_model, det_img)
+    model_inference(model, img)
diff --git a/tests/test_core/test_end2end_vis.py b/tests/test_core/test_end2end_vis.py
new file mode 100644
index 00000000..7d22af3e
--- /dev/null
+++ b/tests/test_core/test_end2end_vis.py
@@ -0,0 +1,20 @@
+import numpy as np
+
+from mmocr.core import det_recog_show_result
+
+
+def test_det_recog_show_result():
+    img = np.ones((100, 100, 3), dtype=np.uint8) * 255
+    det_recog_res = {
+        'result': [{
+            'box': [51, 88, 51, 62, 85, 62, 85, 88],
+            'box_score': 0.9417,
+            'text': 'hell',
+            'text_score': 0.8834
+        }]
+    }
+
+    vis_img = det_recog_show_result(img, det_recog_res)
+    assert vis_img.shape[0] == 100
+    assert vis_img.shape[1] == 200
+    assert vis_img.shape[2] == 3
diff --git a/tests/test_dataset/test_crop.py b/tests/test_dataset/test_crop.py
index a50ea54a..1db731a1 100644
--- a/tests/test_dataset/test_crop.py
+++ b/tests/test_dataset/test_crop.py
@@ -85,12 +85,21 @@ def test_min_rect_crop():
     dummy_img = np.ones((600, 600, 3), dtype=np.uint8)
     dummy_box = [20, 20, 120, 20, 120, 40, 20, 40]
 
-    cropped_img = crop_img(dummy_img, dummy_box)
+    cropped_img = crop_img(
+        dummy_img,
+        dummy_box,
+        0.,
+        0.,
+    )
 
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [])
     with pytest.raises(AssertionError):
         crop_img(dummy_img, [20, 40, 40, 20])
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 4, 0.2)
+    with pytest.raises(AssertionError):
+        crop_img(dummy_img, dummy_box, 0.4, 1.2)
 
     assert math.isclose(cropped_img.shape[0], 20)
     assert math.isclose(cropped_img.shape[1], 100)
diff --git a/tests/test_dataset/test_loading.py b/tests/test_dataset/test_loading.py
index 04e593ab..4b45de06 100644
--- a/tests/test_dataset/test_loading.py
+++ b/tests/test_dataset/test_loading.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from mmocr.datasets.pipelines import LoadTextAnnotations
+from mmocr.datasets.pipelines import LoadImageFromNdarray, LoadTextAnnotations
 
 
 def _create_dummy_ann():
@@ -36,3 +36,13 @@ def test_loadtextannotation():
     assert len(output['gt_masks_ignore']) == 4
     assert np.allclose(output['gt_masks_ignore'].masks[0],
                        [[499, 94, 531, 94, 531, 124, 499, 124]])
+
+
+def test_load_img_from_numpy():
+    result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}
+
+    load = LoadImageFromNdarray(color_type='color')
+    output = load(result)
+
+    assert output['img'].shape[2] == 3
+    assert len(output['img'].shape) == 3
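
The new loading test only exercises the color path; the grayscale branch of `LoadImageFromNdarray` could be checked the same way. A sketch of such a test (not part of this patch), assuming `mmcv.bgr2gray(..., keepdim=True)` keeps the channel axis:

```python
import numpy as np

from mmocr.datasets.pipelines import LoadImageFromNdarray


def test_load_img_from_numpy_grayscale():
    result = {'img': np.ones((32, 100, 3), dtype=np.uint8)}

    output = LoadImageFromNdarray(color_type='grayscale')(result)

    # bgr2gray(..., keepdim=True) returns a single-channel image with an explicit axis.
    assert output['img'].shape == (32, 100, 1)
```
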