fix #173: support aug test (#178)

* fix #173: support aug test

* fix pytest

* support batch inference for both text det and recog

* update unittest

* use one img for batch infer test
Hongbin Sun 2021-05-13 15:18:00 +08:00 committed by GitHub
parent 8e65bea9c3
commit 97e2f27017
5 changed files with 132 additions and 50 deletions
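
For orientation, the batch-mode entry point added by this commit can be exercised roughly as in the sketch below. The config and demo image are taken from the tests in this commit; the checkpoint path is hypothetical and the import layout is an assumption that mirrors the existing demo scripts.

from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.datasets import build_dataset  # noqa: F401 (register datasets)
from mmocr.models import build_detector  # noqa: F401 (register models)

cfg_path = 'configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
ckpt_path = 'psenet_r50_fpnf_600e_icdar2017.pth'  # hypothetical local file
model = init_detector(cfg_path, ckpt_path, device='cuda:0')

# With batch_mode=True a list of images is collated into a single forward
# pass and one result dict is returned per input image.
results = model_inference(
    model, ['demo/demo_text_det.jpg', 'demo/demo_text_det.jpg'],
    batch_mode=True)
print(len(results))  # 2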


@@ -18,9 +18,7 @@ def main():
'--images',
nargs='+',
help='Image files to be predicted with batch mode, '
'separated by space, like "image_1.jpg image2.jpg". '
'If algorithm use augmentation test, only one '
'image file can be given.')
'separated by space, like "image_1.jpg image2.jpg".')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
parser.add_argument(
@@ -36,7 +34,7 @@ def main():
0].pipeline
# test multiple images
results = model_inference(model, args.images)
results = model_inference(model, args.images, batch_mode=True)
print(f'results: {results}')
save_path = Path(args.save_path)


@@ -17,6 +17,7 @@ def det_and_recog_inference(args, det_model, recog_model):
det_result = model_inference(det_model, image)
bboxes = det_result['boundary_result']
box_imgs = []
for bbox in bboxes:
box_res = {}
box_res['box'] = [round(x) for x in bbox[:-1]]
@@ -29,17 +30,37 @@ def det_and_recog_inference(args, det_model, recog_model):
max_y = max(bbox[1:-1:2])
box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
box_img = crop_img(image, box)
if args.batch_mode:
box_imgs.append(box_img)
else:
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
recog_result = model_inference(recog_model, box_img)
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
box_res['text'] = text
box_res['text_score'] = text_score
end2end_res['result'].append(box_res)
if args.batch_mode:
batch_size = args.batch_size
for chunk_idx in range(len(box_imgs) // batch_size + 1):
start_idx = chunk_idx * batch_size
end_idx = (chunk_idx + 1) * batch_size
chunk_box_imgs = box_imgs[start_idx:end_idx]
if len(chunk_box_imgs) == 0:
continue
recog_results = model_inference(
recog_model, chunk_box_imgs, batch_mode=True)
for i, recog_result in enumerate(recog_results):
text = recog_result['text']
text_score = recog_result['score']
if isinstance(text_score, list):
text_score = sum(text_score) / max(1, len(text))
end2end_res['result'][start_idx + i]['text'] = text
end2end_res['result'][start_idx + i]['text_score'] = text_score
return end2end_res
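
The recognition branch above walks box_imgs in fixed-size chunks; the slicing arithmetic is equivalent to the standalone sketch below (names are illustrative).

def chunks(items, batch_size):
    # Yield consecutive slices of at most batch_size elements each.
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

assert list(chunks([0, 1, 2, 3, 4], 2)) == [[0, 1], [2, 3], [4]]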
@@ -74,6 +95,16 @@ def main():
'mmocr/textrecog/sar/'
'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
help='Text recognition checkpoint file (local or url).')
parser.add_argument(
'--batch-mode',
action='store_true',
help='Whether to use batch mode for text recognition.')
parser.add_argument(
'--batch-size',
type=int,
default=4,
help='Batch size for text recognition inference '
'if --batch-mode is set.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference.')
parser.add_argument(


@@ -7,16 +7,33 @@ from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
def model_inference(model, imgs):
def disable_text_recog_aug_test(cfg):
"""Remove aug_test from test pipeline of text recognition.
Args:
cfg (mmcv.Config): Input config.
Returns:
cfg (mmcv.Config): Output config removing
`MultiRotateAugOCR` in test pipeline.
"""
if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
cfg.data.test.pipeline = [
cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
]
return cfg
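
A small self-contained illustration of the unwrapping performed by disable_text_recog_aug_test; the toy config below only mimics the relevant structure (transform arguments are illustrative), and the import path assumes the helper lives next to model_inference in mmocr.apis.inference.

from mmcv import Config

from mmocr.apis.inference import disable_text_recog_aug_test

cfg = Config(
    dict(
        data=dict(
            test=dict(pipeline=[
                dict(type='LoadImageFromFile'),
                dict(
                    type='MultiRotateAugOCR',
                    transforms=[
                        dict(type='ResizeOCR', height=32, max_width=160),
                        dict(type='Collect', keys=['img']),
                    ]),
            ]))))

cfg = disable_text_recog_aug_test(cfg)
# The MultiRotateAugOCR wrapper is unrolled into the main pipeline:
print([step['type'] for step in cfg.data.test.pipeline])
# ['LoadImageFromFile', 'ResizeOCR', 'Collect']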
def model_inference(model, imgs, batch_mode=False):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
batch_mode (bool): If True, use batch mode for inference.
Returns:
result (dict): Detection results.
result (dict): Predicted results.
"""
if isinstance(imgs, (list, tuple)):
@@ -31,7 +48,18 @@ def model_inference(model, imgs):
raise AssertionError('imgs must be strings or numpy arrays')
is_ndarray = isinstance(imgs[0], np.ndarray)
cfg = model.cfg
if batch_mode:
if cfg.data.test.pipeline[1].type == 'ResizeOCR':
if cfg.data.test.pipeline[1].max_width is None:
raise Exception('Free resize does not support batch mode '
'since the image width is not fixed: '
'the resize keeps the aspect ratio and '
'max_width is not given.')
cfg = disable_text_recog_aug_test(cfg)
device = next(model.parameters()).device # model device
if is_ndarray:
@@ -54,8 +82,18 @@ def model_inference(model, imgs):
# build the data pipeline
data = test_pipeline(data)
# get tensor from list to stack for batch mode (text detection)
if batch_mode:
if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug':
for key, value in data.items():
data[key] = value[0]
datas.append(data)
if isinstance(datas[0]['img'], list) and len(datas) > 1:
raise Exception('aug test does not support '
f'inference with batch size '
f'{len(datas)}')
data = collate(datas, samples_per_gpu=len(imgs))
# process img_metas
@@ -67,7 +105,9 @@ def model_inference(model, imgs):
data['img_metas'] = data['img_metas'].data
if isinstance(data['img'], list):
data['img'] = [img for img in data['img'][0].data]
data['img'] = [img.data for img in data['img']]
if isinstance(data['img'][0], list):
data['img'] = [img[0] for img in data['img']]
else:
data['img'] = data['img'].data
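
The two guards above decide which recognition configs can be batched: a ResizeOCR step needs a fixed max_width, and MultiRotateAugOCR is stripped before collation. A hedged sketch of both outcomes, reusing the configs and demo image from the tests further below (models are built without checkpoints, so this only exercises the code path; the import layout is assumed as before):

from mmdet.apis import init_detector

from mmocr.apis.inference import model_inference
from mmocr.models import build_detector  # noqa: F401 (register models)

imgs = ['demo/demo_text_recog.jpg', 'demo/demo_text_recog.jpg']

# SAR resizes to a fixed max_width, so batch mode works once the
# MultiRotateAugOCR wrapper has been removed internally.
sar = init_detector(
    'configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
    device='cpu')
assert len(model_inference(sar, imgs, batch_mode=True)) == 2

# CRNN keeps the aspect ratio with no max_width ('free resize'), so batch
# mode is rejected with the exception raised above.
crnn = init_detector(
    'configs/textrecog/crnn/crnn_academic_dataset.py', device='cpu')
try:
    model_inference(crnn, imgs, batch_mode=True)
except Exception as err:
    print(err)  # Free resize does not support batch mode ...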


@@ -75,10 +75,11 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
The outer list indicates images in a batch.
"""
if isinstance(imgs, list):
assert len(imgs) == len(img_metas), ('aug test does not support '
f'inference with batch size '
f'{len(imgs)}')
assert len(imgs) > 0
assert imgs[0].size(0) == 1, ('aug test does not support '
f'inference with batch size '
f'{imgs[0].size(0)}')
assert len(imgs) == len(img_metas)
return self.aug_test(imgs, img_metas, **kwargs)
return self.simple_test(imgs, img_metas, **kwargs)
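
The replaced assertions keep aug_test restricted to a single sample while allowing simple_test to receive a real batch; a minimal illustration of the shapes each path accepts (tensor sizes and metadata keys are made up for the sketch):

import torch

# aug_test: one entry per augmented view, each with batch size 1.
aug_imgs = [torch.zeros(1, 3, 32, 160) for _ in range(4)]
aug_metas = [[dict(filename='demo.jpg')] for _ in range(4)]
assert len(aug_imgs) == len(aug_metas) and aug_imgs[0].size(0) == 1

# simple_test (batch mode): one stacked tensor with batch size N.
batch_imgs = torch.zeros(2, 3, 32, 160)
batch_metas = [dict(filename='a.jpg'), dict(filename='b.jpg')]
assert batch_imgs.size(0) == len(batch_metas)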


@@ -20,20 +20,9 @@ def build_model(config_file):
return model
def disable_aug_test(model):
model.cfg.data.test.pipeline = [
model.cfg.data.test.pipeline[0],
*model.cfg.data.test.pipeline[1].transforms
]
return model
@pytest.mark.parametrize('cfg_file', [
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
'../configs/textrecog/crnn/crnn_academic_dataset.py',
'../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
'../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
'../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
'../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
])
@@ -53,45 +42,42 @@ def test_model_inference(cfg_file):
model_inference(model, img)
@pytest.mark.parametrize('cfg_file', [
'../configs/textrecog/crnn/crnn_academic_dataset.py',
'../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py'
])
def test_model_batch_inference(cfg_file):
@pytest.mark.parametrize(
'cfg_file',
['../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'])
def test_model_batch_inference_det(cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
results = model_inference(model, [sample_img_path, sample_img_path])
results = model_inference(model, [sample_img_path], batch_mode=True)
assert len(results) == 2
assert len(results) == 1
# numpy inference
img = imread(sample_img_path)
results = model_inference(model, [img, img])
results = model_inference(model, [img], batch_mode=True)
assert len(results) == 2
assert len(results) == 1
@pytest.mark.parametrize('cfg_file', [
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
'../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
'../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
])
def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):
def test_model_batch_inference_raises_exception_error_aug_test_recog(cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)
with pytest.raises(
AssertionError,
Exception,
match='aug test does not support inference with batch size'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
model_inference(model, [sample_img_path, sample_img_path])
with pytest.raises(
AssertionError,
Exception,
match='aug test does not support inference with batch size'):
img = imread(sample_img_path)
model_inference(model, [img, img])
@@ -99,22 +85,48 @@ def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):
@pytest.mark.parametrize('cfg_file', [
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
'../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
'../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
])
def test_model_batch_inference_recog(cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)
model = disable_aug_test(model)
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
results = model_inference(model, [sample_img_path, sample_img_path])
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
results = model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)
assert len(results) == 2
# numpy inference
img = imread(sample_img_path)
results = model_inference(model, [img, img])
results = model_inference(model, [img, img], batch_mode=True)
assert len(results) == 2
@pytest.mark.parametrize(
'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
def test_model_batch_inference_raises_exception_error_free_resize_recog(
cfg_file):
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
config_file = os.path.join(tmp_dir, cfg_file)
model = build_model(config_file)
with pytest.raises(
Exception,
match='Free resize does not support batch mode '
'since the image width is not fixed: '
'the resize keeps the aspect ratio and '
'max_width is not given.'):
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
model_inference(
model, [sample_img_path, sample_img_path], batch_mode=True)
with pytest.raises(
Exception,
match='Free resize does not support batch mode '
'since the image width is not fixed: '
'the resize keeps the aspect ratio and '
'max_width is not given.'):
img = imread(sample_img_path)
model_inference(model, [img, img], batch_mode=True)