mirror of https://github.com/open-mmlab/mmocr.git
* fix #173: support aug test
* fix pytest
* support batch inference for both text det and recog
* update unittest
* use one img for batch infer test

pull/181/head
parent 8e65bea9c3
commit 97e2f27017
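For orientation, here is a minimal usage sketch of the batch_mode flag this commit adds to model_inference. It is a sketch only, not part of the commit: the config and checkpoint names are taken from the diff, while the import locations (model_inference from mmocr.apis, init_detector from mmdet.apis, as the demo scripts use) are assumptions.

# Sketch, not part of the commit: batch text recognition with the new batch_mode flag.
from mmdet.apis import init_detector
from mmocr.apis import model_inference

config = 'configs/textrecog/sar/sar_r31_parallel_decoder_academic.py'
checkpoint = 'sar_r31_parallel_decoder_academic-dba3a4a3.pth'  # local path or URL
recog_model = init_detector(config, checkpoint, device='cuda:0')

# A list of images is collated into a single forward pass; in batch mode the
# aug-test wrapper (MultiRotateAugOCR) is stripped from the test pipeline first.
results = model_inference(
    recog_model,
    ['demo/demo_text_recog.jpg', 'demo/demo_text_recog.jpg'],
    batch_mode=True)
for res in results:
    print(res['text'], res['score'])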
@@ -18,9 +18,7 @@ def main():
        '--images',
        nargs='+',
        help='Image files to be predicted with batch mode, '
        'separated by space, like "image_1.jpg image2.jpg". '
        'If the algorithm uses augmentation test, only one '
        'image file can be given.')
        'separated by space, like "image_1.jpg image2.jpg".')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    parser.add_argument(

@@ -36,7 +34,7 @@ def main():
        0].pipeline

    # test multiple images
    results = model_inference(model, args.images)
    results = model_inference(model, args.images, batch_mode=True)
    print(f'results: {results}')

    save_path = Path(args.save_path)

@@ -17,6 +17,7 @@ def det_and_recog_inference(args, det_model, recog_model):
    det_result = model_inference(det_model, image)
    bboxes = det_result['boundary_result']

    box_imgs = []
    for bbox in bboxes:
        box_res = {}
        box_res['box'] = [round(x) for x in bbox[:-1]]

@@ -29,17 +30,37 @@ def det_and_recog_inference(args, det_model, recog_model):
        max_y = max(bbox[1:-1:2])
        box = [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]
        box_img = crop_img(image, box)
        if args.batch_mode:
            box_imgs.append(box_img)
        else:
            recog_result = model_inference(recog_model, box_img)
            text = recog_result['text']
            text_score = recog_result['score']
            if isinstance(text_score, list):
                text_score = sum(text_score) / max(1, len(text))
            box_res['text'] = text
            box_res['text_score'] = text_score

        recog_result = model_inference(recog_model, box_img)

        text = recog_result['text']
        text_score = recog_result['score']
        if isinstance(text_score, list):
            text_score = sum(text_score) / max(1, len(text))
        box_res['text'] = text
        box_res['text_score'] = text_score
        end2end_res['result'].append(box_res)

    if args.batch_mode:
        batch_size = args.batch_size
        for chunk_idx in range(len(box_imgs) // batch_size + 1):
            start_idx = chunk_idx * batch_size
            end_idx = (chunk_idx + 1) * batch_size
            chunk_box_imgs = box_imgs[start_idx:end_idx]
            if len(chunk_box_imgs) == 0:
                continue
            recog_results = model_inference(
                recog_model, chunk_box_imgs, batch_mode=True)
            for i, recog_result in enumerate(recog_results):
                text = recog_result['text']
                text_score = recog_result['score']
                if isinstance(text_score, list):
                    text_score = sum(text_score) / max(1, len(text))
                end2end_res['result'][start_idx + i]['text'] = text
                end2end_res['result'][start_idx + i]['text_score'] = text_score

    return end2end_res

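The chunking loop above walks box_imgs in slices of batch_size and skips the trailing empty slice. The same idiom in isolation, as a standalone sketch (helper name and sample data are hypothetical, not from the commit):

# Standalone sketch of the batching idiom used above.
def chunks(items, batch_size):
    """Yield successive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

box_imgs = ['crop0', 'crop1', 'crop2', 'crop3', 'crop4']
for batch in chunks(box_imgs, batch_size=4):
    print(len(batch))  # prints 4, then 1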
@@ -74,6 +95,16 @@ def main():
        'mmocr/textrecog/sar/'
        'sar_r31_parallel_decoder_academic-dba3a4a3.pth',
        help='Text recognition checkpoint file (local or url).')
    parser.add_argument(
        '--batch-mode',
        action='store_true',
        help='Whether to use batch mode for text recognition.')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Batch size for text recognition inference '
        'if batch_mode is True above.')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    parser.add_argument(

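The two new flags work together: --batch-mode switches the recognition loop from one model_inference call per cropped box to chunked calls, and --batch-size sets the chunk size. A toy sketch of that branching, with argparse.Namespace standing in for the parsed CLI arguments (not part of the commit):

# Sketch only: how the parsed flags steer recognition in the end-to-end demo.
from argparse import Namespace

args = Namespace(batch_mode=True, batch_size=4)
if args.batch_mode:
    print(f'recognise cropped boxes in chunks of {args.batch_size}')
else:
    print('recognise each cropped box with its own model_inference call')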
@@ -7,16 +7,33 @@ from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose


def model_inference(model, imgs):
def disable_text_recog_aug_test(cfg):
    """Remove aug_test from test pipeline of text recognition.
    Args:
        cfg (mmcv.Config): Input config.

    Returns:
        cfg (mmcv.Config): Output config removing
            `MultiRotateAugOCR` in test pipeline.
    """
    if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
        cfg.data.test.pipeline = [
            cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
        ]

    return cfg


def model_inference(model, imgs, batch_mode=False):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

        batch_mode (bool): If True, use batch mode for inference.
    Returns:
        result (dict): Detection results.
        result (dict): Predicted results.
    """

    if isinstance(imgs, (list, tuple)):

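To make the pipeline surgery in disable_text_recog_aug_test concrete, here is a small sketch with an illustrative recognition test pipeline; the transform dicts are simplified placeholders, not taken from a real config, and only the reassignment mirrors the function above.

# Sketch, not part of the commit: what stripping MultiRotateAugOCR does.
from mmcv import Config

cfg = Config(
    dict(data=dict(test=dict(pipeline=[
        dict(type='LoadImageFromFile'),
        dict(type='MultiRotateAugOCR',
             transforms=[
                 dict(type='ResizeOCR', height=32, max_width=160),
                 dict(type='ToTensorOCR'),
             ]),
    ]))))

if cfg.data.test.pipeline[1].type == 'MultiRotateAugOCR':
    cfg.data.test.pipeline = [
        cfg.data.test.pipeline[0], *cfg.data.test.pipeline[1].transforms
    ]

# The wrapper is gone and its transforms are inlined:
print([step['type'] for step in cfg.data.test.pipeline])
# ['LoadImageFromFile', 'ResizeOCR', 'ToTensorOCR']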
@@ -31,7 +48,18 @@ def model_inference(model, imgs):
        raise AssertionError('imgs must be strings or numpy arrays')

    is_ndarray = isinstance(imgs[0], np.ndarray)

    cfg = model.cfg

    if batch_mode:
        if cfg.data.test.pipeline[1].type == 'ResizeOCR':
            if cfg.data.test.pipeline[1].max_width is None:
                raise Exception('Free resize does not support batch mode '
                                'since the image width is not fixed: '
                                'the resize keeps the aspect ratio and '
                                'max_width is not given.')
        cfg = disable_text_recog_aug_test(cfg)

    device = next(model.parameters()).device  # model device

    if is_ndarray:

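The reason for the free-resize check above: with max_width unset, each crop keeps its own width after resizing, and tensors of different widths cannot be stacked into one batch. A small illustration (not part of the commit; the shapes are made up):

# Sketch: why variable-width crops break batching.
import torch

crops = [torch.zeros(3, 32, 100), torch.zeros(3, 32, 187)]  # widths differ
try:
    torch.stack(crops)
except RuntimeError as err:
    print('cannot stack variable-width images:', err)

# With ResizeOCR bounded by max_width, every crop shares one shape:
crops = [torch.zeros(3, 32, 160), torch.zeros(3, 32, 160)]
print(torch.stack(crops).shape)  # torch.Size([2, 3, 32, 160])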
@@ -54,8 +82,18 @@ def model_inference(model, imgs):

        # build the data pipeline
        data = test_pipeline(data)
        # get tensor from list to stack for batch mode (text detection)
        if batch_mode:
            if cfg.data.test.pipeline[1].type == 'MultiScaleFlipAug':
                for key, value in data.items():
                    data[key] = value[0]
        datas.append(data)

    if isinstance(datas[0]['img'], list) and len(datas) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(datas)}')

    data = collate(datas, samples_per_gpu=len(imgs))

    # process img_metas

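The `data[key] = value[0]` step exists because MultiScaleFlipAug wraps every field in a list with one entry per test scale; for single-scale batch detection the wrapper is unwrapped so mmcv's collate can stack plain tensors. A toy illustration (not part of the commit; the field values are made up):

# Sketch: unwrapping the single-scale aug wrapper before collation.
import torch

data = {
    'img': [torch.zeros(3, 224, 224)],      # list with one scale
    'img_metas': [{'scale_factor': 1.0}],
}
for key, value in data.items():
    data[key] = value[0]

print(data['img'].shape)   # torch.Size([3, 224, 224])
print(data['img_metas'])   # {'scale_factor': 1.0}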
@@ -67,7 +105,9 @@ def model_inference(model, imgs):
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img for img in data['img'][0].data]
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data

@@ -75,10 +75,11 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
                The outer list indicates images in a batch.
        """
        if isinstance(imgs, list):
            assert len(imgs) == len(img_metas), ('aug test does not support '
                                                 f'inference with batch size '
                                                 f'{len(imgs)}')
            assert len(imgs) > 0
            assert imgs[0].size(0) == 1, ('aug test does not support '
                                          f'inference with batch size '
                                          f'{imgs[0].size(0)}')
            assert len(imgs) == len(img_metas)
            return self.aug_test(imgs, img_metas, **kwargs)

        return self.simple_test(imgs, img_metas, **kwargs)

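The reworked assertions encode the aug-test input convention: imgs arrives as a list of tensors, each of which must carry a batch dimension of 1, with one img_metas entry per tensor. A sketch of inputs that satisfy them (not part of the commit; shapes and metadata are illustrative):

# Sketch: inputs accepted by the new assertions in forward_test.
import torch

imgs = [torch.zeros(1, 3, 32, 160), torch.zeros(1, 3, 32, 160)]
img_metas = [[{'filename': 'a.jpg'}], [{'filename': 'a.jpg'}]]

assert len(imgs) > 0
assert imgs[0].size(0) == 1      # aug test only supports batch size 1
assert len(imgs) == len(img_metas)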
@@ -20,20 +20,9 @@ def build_model(config_file):
    return model


def disable_aug_test(model):
    model.cfg.data.test.pipeline = [
        model.cfg.data.test.pipeline[0],
        *model.cfg.data.test.pipeline[1].transforms
    ]

    return model


@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
    '../configs/textrecog/crnn/crnn_academic_dataset.py',
    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
    '../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
])

@@ -53,45 +42,42 @@ def test_model_inference(cfg_file):
    model_inference(model, img)


@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/crnn/crnn_academic_dataset.py',
    '../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py'
])
def test_model_batch_inference(cfg_file):
@pytest.mark.parametrize(
    'cfg_file',
    ['../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'])
def test_model_batch_inference_det(cfg_file):
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    model = build_model(config_file)

    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
    results = model_inference(model, [sample_img_path, sample_img_path])
    results = model_inference(model, [sample_img_path], batch_mode=True)

    assert len(results) == 2
    assert len(results) == 1

    # numpy inference
    img = imread(sample_img_path)
    results = model_inference(model, [img, img])
    results = model_inference(model, [img], batch_mode=True)

    assert len(results) == 2
    assert len(results) == 1


@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
])
def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):
def test_model_batch_inference_raises_exception_error_aug_test_recog(cfg_file):
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    model = build_model(config_file)

    with pytest.raises(
            AssertionError,
            Exception,
            match='aug test does not support inference with batch size'):
        sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
        model_inference(model, [sample_img_path, sample_img_path])

    with pytest.raises(
            AssertionError,
            Exception,
            match='aug test does not support inference with batch size'):
        img = imread(sample_img_path)
        model_inference(model, [img, img])

@@ -99,22 +85,48 @@ def test_model_batch_inference_raises_assertion_error_if_unsupported(cfg_file):

@pytest.mark.parametrize('cfg_file', [
    '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
    '../configs/textrecog/nrtr/nrtr_r31_1by16_1by8_academic.py',
    '../configs/textrecog/robust_scanner/robustscanner_r31_academic.py',
])
def test_model_batch_inference_recog(cfg_file):
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    model = build_model(config_file)
    model = disable_aug_test(model)

    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
    results = model_inference(model, [sample_img_path, sample_img_path])
    sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
    results = model_inference(
        model, [sample_img_path, sample_img_path], batch_mode=True)

    assert len(results) == 2

    # numpy inference
    img = imread(sample_img_path)
    results = model_inference(model, [img, img])
    results = model_inference(model, [img, img], batch_mode=True)

    assert len(results) == 2


@pytest.mark.parametrize(
    'cfg_file', ['../configs/textrecog/crnn/crnn_academic_dataset.py'])
def test_model_batch_inference_raises_exception_error_free_resize_recog(
        cfg_file):
    tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    config_file = os.path.join(tmp_dir, cfg_file)
    model = build_model(config_file)

    with pytest.raises(
            Exception,
            match='Free resize does not support batch mode '
            'since the image width is not fixed: '
            'the resize keeps the aspect ratio and '
            'max_width is not given.'):
        sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
        model_inference(
            model, [sample_img_path, sample_img_path], batch_mode=True)

    with pytest.raises(
            Exception,
            match='Free resize does not support batch mode '
            'since the image width is not fixed: '
            'the resize keeps the aspect ratio and '
            'max_width is not given.'):
        img = imread(sample_img_path)
        model_inference(model, [img, img], batch_mode=True)