mmocr/tests/test_apis/test_model_inference.py

65 lines
1.8 KiB
Python
Raw Normal View History

2021-04-06 16:04:14 +08:00
import os
import shutil
import urllib
import pytest
from mmcv.image import imread
2021-04-06 16:04:14 +08:00
from mmdet.apis import init_detector
from mmocr.apis.inference import model_inference
@pytest.fixture
def project_dir():
    """Return the absolute path of the parent of this module's directory."""
    # First dirname: the directory holding this test module.
    module_dir = os.path.dirname(__file__)
    # Second dirname: one level above that directory.
    parent_dir = os.path.dirname(module_dir)
    return os.path.abspath(parent_dir)
2021-04-06 16:04:14 +08:00
@pytest.fixture
def sample_img_path(project_dir):
    """Return the path of the demo text-recognition image.

    Args:
        project_dir: Base directory the relative demo path is joined onto.

    Returns:
        ``project_dir`` joined with ``../demo/demo_text_recog.jpg``; the
        ``..`` component is left unnormalised.
    """
    relative_img = '../demo/demo_text_recog.jpg'
    return os.path.join(project_dir, relative_img)
@pytest.fixture
def sarnet_model(project_dir):
    """Build the SAR recognition model on CPU, fetching weights if needed.

    The config and checkpoint paths are resolved relative to
    ``project_dir``. When the checkpoint file does not exist locally it is
    downloaded to a temporary file and then moved into place.

    Args:
        project_dir: Base directory that the relative config/checkpoint
            paths are joined onto.

    Returns:
        The model object returned by ``init_detector``. If its test dataset
        config has type ``'ConcatDataset'``, ``cfg.data.test.pipeline`` is
        set to the pipeline of the first sub-dataset.
    """
    # A bare ``import urllib`` does not load the ``request`` submodule, so
    # ``urllib.request`` would only resolve if some other module had already
    # imported it. Import it explicitly to make the download path reliable.
    import urllib.request

    print(project_dir)
    config_file = os.path.join(
        project_dir,
        '../configs/textrecog/sar/sar_r31_parallel_decoder_academic.py')
    checkpoint_file = os.path.join(
        project_dir,
        '../checkpoints/sar_r31_parallel_decoder_academic-dba3a4a3.pth')

    if not os.path.exists(checkpoint_file):
        url = ('https://download.openmmlab.com/mmocr'
               '/textrecog/sar/'
               'sar_r31_parallel_decoder_academic-dba3a4a3.pth')
        print(f'Downloading {url} ...')
        # Download to a temporary location first and move afterwards, so an
        # interrupted download never leaves a partial file at the final path.
        local_filename, _ = urllib.request.urlretrieve(url)
        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
        shutil.move(local_filename, checkpoint_file)
        print(f'Saved as {checkpoint_file}')
    else:
        print(f'Using existing checkpoint {checkpoint_file}')

    device = 'cpu'
    model = init_detector(
        config_file, checkpoint=checkpoint_file, device=device)

    # A ConcatDataset config keeps its pipeline on each sub-dataset; copy the
    # first one up to the top-level test config.
    if model.cfg.data.test['type'] == 'ConcatDataset':
        model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
            0].pipeline

    return model
def test_model_inference_image_path(sample_img_path, sarnet_model):
    """Check inference rejects an int input and runs on an image path."""
    # An integer is not an accepted image input; an AssertionError is expected.
    invalid_input = 1
    with pytest.raises(AssertionError):
        model_inference(sarnet_model, invalid_input)

    # A path to an image file should run without raising.
    model_inference(sarnet_model, sample_img_path)
2021-04-06 16:04:14 +08:00
def test_model_inference_numpy_ndarray(sample_img_path, sarnet_model):
    """Check inference runs on an image already loaded with ``imread``."""
    # Load the image up front so the loaded image, not its path, is passed.
    model_inference(sarnet_model, imread(sample_img_path))