mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
* [Refactor] Use MMOCR's registry 1. Define MMOCR's registries as a child of MMDet's 2. Register all models to MMOCR's own registries 3. Modify some model configs so that some models in MMDet can be correctly located 4. Remove some outdated demo scripts * add detectors
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
|
|
import pytest
|
|
from mmcv.image import imread
|
|
|
|
from mmocr.apis.inference import init_detector, model_inference
|
|
from mmocr.datasets import build_dataset # noqa: F401
|
|
from mmocr.models import build_detector # noqa: F401
|
|
|
|
|
|
def build_model(config_file):
|
|
device = 'cpu'
|
|
model = init_detector(config_file, checkpoint=None, device=device)
|
|
|
|
if model.cfg.data.test['type'] == 'ConcatDataset':
|
|
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
|
|
0].pipeline
|
|
|
|
return model
|
|
|
|
|
|
@pytest.mark.parametrize('cfg_file', [
|
|
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
|
|
'../configs/textrecog/crnn/crnn_academic_dataset.py',
|
|
'../configs/textrecog/seg/seg_r31_1by16_fpnocr_academic.py',
|
|
'../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'
|
|
])
|
|
def test_model_inference(cfg_file):
|
|
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
|
config_file = os.path.join(tmp_dir, cfg_file)
|
|
model = build_model(config_file)
|
|
with pytest.raises(AssertionError):
|
|
model_inference(model, 1)
|
|
|
|
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
|
|
model_inference(model, sample_img_path)
|
|
|
|
# numpy inference
|
|
img = imread(sample_img_path)
|
|
|
|
model_inference(model, img)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
'cfg_file',
|
|
['../configs/textdet/psenet/psenet_r50_fpnf_600e_icdar2017.py'])
|
|
def test_model_batch_inference_det(cfg_file):
|
|
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
|
config_file = os.path.join(tmp_dir, cfg_file)
|
|
model = build_model(config_file)
|
|
|
|
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
|
|
results = model_inference(model, [sample_img_path], batch_mode=True)
|
|
|
|
assert len(results) == 1
|
|
|
|
# numpy inference
|
|
img = imread(sample_img_path)
|
|
results = model_inference(model, [img], batch_mode=True)
|
|
|
|
assert len(results) == 1
|
|
|
|
|
|
@pytest.mark.parametrize('cfg_file', [
|
|
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
|
|
])
|
|
def test_model_batch_inference_raises_exception_error_aug_test_recog(cfg_file):
|
|
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
|
config_file = os.path.join(tmp_dir, cfg_file)
|
|
model = build_model(config_file)
|
|
|
|
with pytest.raises(
|
|
Exception,
|
|
match='aug test does not support inference with batch size'):
|
|
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_det.jpg')
|
|
model_inference(model, [sample_img_path, sample_img_path])
|
|
|
|
with pytest.raises(
|
|
Exception,
|
|
match='aug test does not support inference with batch size'):
|
|
img = imread(sample_img_path)
|
|
model_inference(model, [img, img])
|
|
|
|
|
|
@pytest.mark.parametrize('cfg_file', [
|
|
'../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py',
|
|
])
|
|
def test_model_batch_inference_recog(cfg_file):
|
|
tmp_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
|
config_file = os.path.join(tmp_dir, cfg_file)
|
|
model = build_model(config_file)
|
|
|
|
sample_img_path = os.path.join(tmp_dir, '../demo/demo_text_recog.jpg')
|
|
results = model_inference(
|
|
model, [sample_img_path, sample_img_path], batch_mode=True)
|
|
|
|
assert len(results) == 2
|
|
|
|
# numpy inference
|
|
img = imread(sample_img_path)
|
|
results = model_inference(model, [img, img], batch_mode=True)
|
|
|
|
assert len(results) == 2
|