From 97e2f27017c829dc719e48cbda7da40d03349d05 Mon Sep 17 00:00:00 2001
From: Hongbin Sun
Date: Thu, 13 May 2021 15:18:00 +0800
Subject: [PATCH] fix #173: support aug test (#178)

* fix #173: support aug test

* fix pytest

* support batch inference for both text det and recog

* update unittest

* use one img for batch infer test
---
 demo/batch_image_demo.py                  |  6 +-
 demo/ocr_image_demo.py                    | 47 +++++++++++---
 mmocr/apis/inference.py                   | 48 +++++++++++++--
 mmocr/models/textrecog/recognizer/base.py |  7 ++-
 tests/test_apis/test_model_inference.py   | 74 +++++++++++++----------
 5 files changed, 132 insertions(+), 50 deletions(-)

diff --git a/demo/batch_image_demo.py b/demo/batch_image_demo.py
index 0a6c9bc7..245405fc 100644
--- a/demo/batch_image_demo.py
+++ b/demo/batch_image_demo.py
@@ -18,9 +18,7 @@ def main():
         '--images',
         nargs='+',
         help='Image files to be predicted with batch mode, '
-        'separated by space, like "image_1.jpg image2.jpg". '
-        'If algorithm use augmentation test, only one '
-        'image file can be given.')
+        'separated by space, like "image_1.jpg image_2.jpg".')
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference.')
     parser.add_argument(
@@ -36,7 +34,7 @@ def main():
             0].pipeline
 
     # test multiple images
-    results = model_inference(model, args.images)
+    results = model_inference(model, args.images, batch_mode=True)
     print(f'results: {results}')
 
     save_path = Path(args.save_path)
diff --git a/demo/ocr_image_demo.py b/demo/ocr_image_demo.py
index 494869d2..27b83d39 100644
--- a/demo/ocr_image_demo.py
+++ b/demo/ocr_image_demo.py
@@ -17,6 +17,7 @@ def det_and_recog_inference(args, det_model, recog_model):
     det_result = model_inference(det_model, image)
     bboxes = det_result['boundary_result']
 
+    box_imgs = []
     for bbox in bboxes:
         box_res = {}
         box_res['box'] = [round(x) for x in bbox[:-1]]
@@ -29,17 +30,37 @@ def det_and_recog_inference(args, det_model, recog_model):
         max_y = max(bbox[1:-1:2])
         box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
         box_img = crop_img(image, box)
+        if args.batch_mode:
+            box_imgs.append(box_img)
+        else:
+            recog_result = model_inference(recog_model, box_img)
+            text = recog_result['text']
+            text_score = recog_result['score']
+            if isinstance(text_score, list):
+                text_score = sum(text_score) / max(1, len(text))
+            box_res['text'] = text
+            box_res['text_score'] = text_score
 
-        recog_result = model_inference(recog_model, box_img)
-
-        text = recog_result['text']
-        text_score = recog_result['score']
-        if isinstance(text_score, list):
-            text_score = sum(text_score) / max(1, len(text))
-        box_res['text'] = text
-        box_res['text_score'] = text_score
         end2end_res['result'].append(box_res)
 
+    if args.batch_mode:
+        batch_size = args.batch_size
+        for chunk_idx in range(len(box_imgs) // batch_size + 1):
+            start_idx = chunk_idx * batch_size
+            end_idx = (chunk_idx + 1) * batch_size
+            chunk_box_imgs = box_imgs[start_idx:end_idx]
+            if len(chunk_box_imgs) == 0:
+                continue
+            recog_results = model_inference(
+                recog_model, chunk_box_imgs, batch_mode=True)
+            for i, recog_result in enumerate(recog_results):
+                text = recog_result['text']
+                text_score = recog_result['score']
+                if isinstance(text_score, list):
+                    text_score = sum(text_score) / max(1, len(text))
+                end2end_res['result'][start_idx + i]['text'] = text
+                end2end_res['result'][start_idx + i]['text_score'] = text_score
+
     return end2end_res
@@ -74,6 +95,16 @@ def main():
         'mmocr/textrecog/sar/'
         'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
         help='Text recognition checkpoint file (local or url).')
+    parser.add_argument(
+        '--batch-mode',
+        action='store_true',
+        help='Whether to use batch mode for text recognition.')
+    parser.add_argument(
+        '--batch-size',
+        type=int,
+        default=4,
+        help='Batch size for text recognition inference '
+        'if batch_mode is True.')
     parser.add_argument(
         '--device', default='cuda:0', help='Device used for inference.')
     parser.add_argument(
diff --git a/mmocr/apis/inference.py b/mmocr/apis/inference.py
index be955156..670ab1cf 100644
--- a/mmocr/apis/inference.py
+++ b/mmocr/apis/inference.py
@@ -7,16 +7,33 @@ from mmdet.datasets import replace_ImageToTensor
 from mmdet.datasets.pipelines import Compose
 
 
-def model_inference(model, imgs):
+def disable_text_recog_aug_test(cfg):
+    """Remove aug_test from the test pipeline of text recognition.
+    Args:
+        cfg (mmcv.Config): Input config.
+
+    Returns:
+        cfg (mmcv.Config): Output config with `MultiRotateAugOCR`
+            removed from the test pipeline.
+    """
+    if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
+        cfg.data.test.pipeline = [
+            cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
+        ]
+
+    return cfg
+
+
+def model_inference(model, imgs, batch_mode=False):
     """Inference image(s) with the detector.
 
     Args:
         model (nn.Module): The loaded detector.
         imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
             Either image files or loaded images.
-
+        batch_mode (bool): If True, use batch mode for inference.
     Returns:
-        result (dict): Detection results.
+        result (dict): Predicted results.
     """
 
     if isinstance(imgs, (list, tuple)):
@@ -31,7 +48,18 @@ def model_inference(model, imgs):
         raise AssertionError('imgs must be strings or numpy arrays')
 
     is_ndarray = isinstance(imgs[0], np.ndarray)
+
     cfg = model.cfg
+
+    if batch_mode:
+        if cfg.data.test.pipeline[1].type == 'ResizeOCR':
+            if cfg.data.test.pipeline[1].max_width is None:
+                raise Exception('Free resize does not support batch mode '
+                                'since the image width is not fixed: '
+                                'the resize keeps the aspect ratio and '
+                                'max_width is not given.')
+        cfg = disable_text_recog_aug_test(cfg)
+
     device = next(model.parameters()).device  # model device
 
     if is_ndarray:
@@ -54,8 +82,18 @@ def model_inference(model, imgs):
 
         # build the data pipeline
         data = test_pipeline(data)
+        # get tensor from list to stack for batch mode (text detection)
+        if batch_mode:
+            if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug':
+                for key, value in data.items():
+                    data[key] = value[0]
         datas.append(data)
 
+    if isinstance(datas[0]['img'], list) and len(datas) > 1:
+        raise Exception('aug test does not support '
+                        'inference with batch size '
+                        f'{len(datas)}')
+
     data = collate(datas, samples_per_gpu=len(imgs))
 
     # process img_metas
@@ -67,7 +105,9 @@ def model_inference(model, imgs):
         data['img_metas'] = data['img_metas'].data
 
     if isinstance(data['img'], list):
-        data['img'] = [img for img in data['img'][0].data]
+        data['img'] = [img.data for img in data['img']]
+        if isinstance(data['img'][0], list):
+            data['img'] = [img[0] for img in data['img']]
     else:
         data['img'] = data['img'].data
diff --git a/mmocr/models/textrecog/recognizer/base.py b/mmocr/models/textrecog/recognizer/base.py
index 8e39feba..3870210b 100644
--- a/mmocr/models/textrecog/recognizer/base.py
+++ b/mmocr/models/textrecog/recognizer/base.py
@@ -75,10 +75,11 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
                 The outer list indicates images in a batch.
""" if isinstance(imgs, list): - assert len(imgs) == len(img_metas), ('aug test does not support ' - f'inference with batch size ' - f'{len(imgs)}') assert len(imgs) > 0 + assert imgs[0].size(0) == 1, ('aug test does not support ' + f'inference with batch size ' + f'{imgs[0].size(0)}') + assert len(imgs) == len(img_metas) return self.aug_test(imgs, img_metas, **kwargs) return self.simple_test(imgs, img_metas, **kwargs) diff --git a/tests/test_apis/test_model_inference.py b/tests/test_apis/test_model_inference.py index dbf7c363..d964b5ff 100644 --- a/tests/test_apis/test_model_inference.py +++ b/tests/test_apis/test_model_inference.py @@ -20,20 +20,9 @@ def build_model(config_file): return model -def disable_aug_test(model): - model.cfg.data.test.pipeline = [ - model.cfg.data.test.pipeline[0], - *model.cfg.data.test.pipeline[1].transforms - ] - - return model - - @pytest.mark.parametrize('cfg_file', [ '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py', '../configs/textrecog/crnn/crnn_academic_dataset.py', - '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py', - '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py', '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py', '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py' ]) @@ -53,45 +42,42 @@ def test_model_inference(cfg_file): model_inference(model, img) -@pytest.mark.parametrize('cfg_file', [ - '../configs/textrecog/crnn/crnn_academic_dataset.py', - '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py' -]) -def test_model_batch_inference(cfg_file): +@pytest.mark.parametrize( + 'cfg_file', + ['../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py']) +def test_model_batch_inference_det(cfg_file): tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) config_file = os.path.join(tmp_dir, cfg_file) model = build_model(config_file) sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg') - results = model_inference(model, [sample_img_path, sample_img_path]) + results = model_inference(model, [sample_img_path], batch_mode=True) - assert len(results) == 2 + assert len(results) == 1 # numpy inference img = imread(sample_img_path) - results = model_inference(model, [img, img]) + results = model_inference(model, [img], batch_mode=True) - assert len(results) == 2 + assert len(results) == 1 @pytest.mark.parametrize('cfg_file', [ '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py', - '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py', - '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py', ]) -def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file): +def test_model_batch_inference_raises_exception_error_aug_test_recog(cfg_file): tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) config_file = os.path.join(tmp_dir, cfg_file) model = build_model(config_file) with pytest.raises( - AssertionError, + Exception, match='aug test does not support inference with batch size'): sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg') model_inference(model, [sample_img_path, sample_img_path]) with pytest.raises( - AssertionError, + Exception, match='aug test does not support inference with batch size'): img = imread(sample_img_path) model_inference(model, [img, img]) @@ -99,22 +85,48 @@ def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file): @pytest.mark.parametrize('cfg_file', [ '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py', - 
-    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
-    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
 ])
 def test_model_batch_inference_recog(cfg_file):
     tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
     config_file = os.path.join(tmp_dir, cfg_file)
     model = build_model(config_file)
 
-    model = disable_aug_test(model)
-
-    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
-    results = model_inference(model, [sample_img_path, sample_img_path])
+    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
+    results = model_inference(
+        model, [sample_img_path, sample_img_path], batch_mode=True)
 
     assert len(results) == 2
 
     # numpy inference
     img = imread(sample_img_path)
-    results = model_inference(model, [img, img])
+    results = model_inference(model, [img, img], batch_mode=True)
 
     assert len(results) == 2
+
+
+@pytest.mark.parametrize(
+    'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
+def test_model_batch_inference_raises_exception_error_free_resize_recog(
+        cfg_file):
+    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
+    config_file = os.path.join(tmp_dir, cfg_file)
+    model = build_model(config_file)
+
+    with pytest.raises(
+            Exception,
+            match='Free resize does not support batch mode '
+            'since the image width is not fixed: '
+            'the resize keeps the aspect ratio and '
+            'max_width is not given.'):
+        sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
+        model_inference(
+            model, [sample_img_path, sample_img_path], batch_mode=True)
+
+    with pytest.raises(
+            Exception,
+            match='Free resize does not support batch mode '
+            'since the image width is not fixed: '
+            'the resize keeps the aspect ratio and '
+            'max_width is not given.'):
+        img = imread(sample_img_path)
+        model_inference(model, [img, img], batch_mode=True)
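
---

Reviewer note (not part of the patch): below is a minimal sketch of how the new batch_mode flag is meant to be used for text recognition, assuming a SAR recognizer built the same way batch_image_demo.py builds its model. The crop file names, the local checkpoint path, and the device string are illustrative assumptions, not values taken from this change.

    from mmdet.apis import init_detector

    from mmocr.apis.inference import model_inference

    # Build the recognizer as the demo scripts do.
    config = 'configs/textrecog/sar/sar_r31_parallel_decoder_academic.py'
    checkpoint = 'sar_r31_parallel_decoder_academic-dba3a4a3.pth'  # local file
    model = init_detector(config, checkpoint, device='cuda:0')
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    # Recognize several cropped text instances in one forward pass. In batch
    # mode, model_inference() removes MultiRotateAugOCR from the test pipeline
    # and raises if ResizeOCR uses free resize (max_width=None), as the new
    # tests above verify.
    results = model_inference(
        model, ['crop_1.jpg', 'crop_2.jpg'], batch_mode=True)
    for res in results:
        print(res['text'], res['score'])

Aug test and batch inference stay mutually exclusive: MultiRotateAugOCR turns each sample into a list of rotated variants, so BaseRecognizer.forward now asserts a batch size of 1 whenever a list of augmented tensors is passed.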