mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* add unittests * lint * modify .gitignore, remove useless files * remove empty.py and generate it when use * Update according to comments 1. Use tempfile 2. Delete inference test (which will be tested in each codebase) 3. Refine calibration test * update annotation * Add export_info * Reduce data scale, fix assert * update json blank line * add backend check
135 lines
3.9 KiB
Python
135 lines
3.9 KiB
Python
import os
|
|
import tempfile
|
|
|
|
import mmcv
|
|
import pytest
|
|
|
|
from mmdeploy.utils import get_onnx_config, get_task_type, load_config
|
|
from mmdeploy.utils.constants import Task
|
|
from mmdeploy.utils.export_info import dump_info
|
|
|
|
# Reference model/deploy configs, loaded once at import time. The paths are
# relative, so they resolve against the working directory (presumably the
# repository root when the test suite runs — confirm).
correct_model_path = 'tests/data/srgan.py'
correct_model_cfg = mmcv.Config.fromfile(correct_model_path)
correct_deploy_path = 'tests/data/super-resolution.py'
correct_deploy_cfg = mmcv.Config.fromfile(correct_deploy_path)

# Reserve a unique ``*.py`` path in the temp directory. The placeholder file is
# closed and deleted when the ``with`` block exits. That makes the path free
# for the module fixture to create, without relying on garbage collection to
# delete the file.
with tempfile.NamedTemporaryFile(suffix='.py') as _tmp_file:
    empty_file_path = _tmp_file.name
del _tmp_file

# Path used by the missing-file test; it is expected not to exist.
empty_path = './a.py'
|
|
|
|
|
|
@pytest.fixture(autouse=True, scope='module')
def create_empty_file():
    """Create an empty config file at ``empty_file_path`` for this module.

    The file exists for every test in the module and is removed again once the
    module's tests have finished.
    """
    # open(..., 'w') creates (or truncates) a regular file on every platform.
    # os.mknod is Unix-only, requires superuser privileges on macOS, and
    # raises if the path already exists.
    with open(empty_file_path, 'w'):
        pass
    yield
    # Best-effort teardown: don't fail the run if the file is already gone.
    try:
        os.remove(empty_file_path)
    except FileNotFoundError:
        pass
|
|
|
|
|
|
def test_load_config_none():
    """Calling load_config without any config argument fails its assertion."""
    no_configs = ()
    with pytest.raises(AssertionError):
        load_config(*no_configs)
|
|
|
|
|
|
def test_load_config_type_error():
    """An argument that is neither a path nor a Config raises TypeError."""
    not_a_config = 1
    with pytest.raises(TypeError):
        load_config(not_a_config)
|
|
|
|
|
|
def test_load_config_file_error():
    """Loading a config path that does not exist raises FileNotFoundError."""
    missing_path = empty_path
    with pytest.raises(FileNotFoundError):
        load_config(missing_path)
|
|
|
|
|
|
@pytest.mark.parametrize('args', [
    [empty_file_path],
    [correct_model_path],
    [correct_model_cfg],
    (correct_model_path, correct_deploy_path),
    (correct_model_path, correct_deploy_cfg),
    (correct_model_cfg, correct_deploy_cfg),
])
def test_load_config(args):
    """load_config returns one config per argument, each matching its input.

    A string argument is compared with a fresh ``mmcv.Config.fromfile`` of the
    same path. A Config argument is compared with itself.
    """
    configs = list(load_config(*args))
    # zip() stops at the shorter sequence, so check the count first;
    # otherwise a missing or extra config would go unnoticed.
    assert len(configs) == len(args)
    for loaded, expected in zip(configs, args):
        if isinstance(expected, str):
            expected = mmcv.Config.fromfile(expected)
        assert loaded._cfg_dict == expected._cfg_dict
|
|
|
|
|
|
@pytest.mark.parametrize(
    'deploy_cfg, default',
    [
        (empty_file_path, None),
        (empty_file_path, Task.SUPER_RESOLUTION),
    ],
)
def test_get_task_type_default(deploy_cfg, default):
    """For the empty deploy config, get_task_type returns the given default.

    When ``default`` is None, the default argument is omitted from the call
    and the expected result is None.
    """
    call_args = (deploy_cfg,) if default is None else (deploy_cfg, default)
    assert get_task_type(*call_args) == default
|
|
|
|
|
|
@pytest.mark.parametrize(
    'deploy_cfg, default',
    [
        (correct_deploy_path, None),
        (correct_deploy_path, Task.TEXT_DETECTION),
        (correct_deploy_cfg, None),
    ],
)
def test_get_task_type(deploy_cfg, default):
    """The super-resolution deploy config reports Task.SUPER_RESOLUTION.

    This holds whether the config is passed as a path or as a loaded Config,
    and also when a different default task is supplied.
    """
    call_args = [deploy_cfg]
    if default is not None:
        call_args.append(default)
    assert get_task_type(*call_args) == Task.SUPER_RESOLUTION
|
|
|
|
|
|
def test_get_onnx_config_error():
    """Requesting the ONNX config from the empty deploy config file fails."""
    # NOTE(review): any exception is accepted here; narrow the expected type
    # once get_onnx_config's failure mode for an empty config is confirmed.
    empty_deploy_cfg = empty_file_path
    with pytest.raises(Exception):
        get_onnx_config(empty_deploy_cfg)
|
|
|
|
|
|
@pytest.mark.parametrize(
    'deploy_cfg',
    [correct_deploy_path, correct_deploy_cfg],
)
def test_get_onnx_config(deploy_cfg):
    """The super-resolution deploy config yields the expected ONNX settings.

    The same config is checked both as a path and as a loaded Config.
    """
    # Input and output share the same dynamic axis layout.
    axis_names = {0: 'batch', 2: 'height', 3: 'width'}
    expected = dict(
        type='onnx',
        export_params=True,
        keep_initializers_as_inputs=False,
        opset_version=11,
        save_file='end2end.onnx',
        input_names=['input'],
        output_names=['output'],
        input_shape=None,
        dynamic_axes={
            'input': dict(axis_names),
            'output': dict(axis_names),
        },
    )
    assert get_onnx_config(deploy_cfg) == expected
|
|
|
|
|
|
def test_AdvancedEnum():
    """Task members round-trip through their string values via Task.get."""
    member_to_value = {
        Task.TEXT_DETECTION: 'TextDetection',
        Task.TEXT_RECOGNITION: 'TextRecognition',
        Task.SEGMENTATION: 'Segmentation',
        Task.SUPER_RESOLUTION: 'SuperResolution',
        Task.CLASSIFICATION: 'Classification',
        Task.OBJECT_DETECTION: 'ObjectDetection',
    }
    for member, value in member_to_value.items():
        assert Task.get(value, None) == member
        assert member.value == value
    # A string matching no member falls back to the supplied default.
    assert Task.get('a', Task.TEXT_DETECTION) == Task.TEXT_DETECTION
|
|
|
|
|
|
def test_export_info():
    """dump_info writes preprocess.json and deploy_cfg.json to the target dir."""
    # ``tmp_dir`` instead of ``dir`` avoids shadowing the builtin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dump_info(correct_deploy_cfg, correct_model_cfg, tmp_dir)
        for file_name in ('preprocess.json', 'deploy_cfg.json'):
            exported = os.path.join(tmp_dir, file_name)
            assert os.path.isfile(exported), f'missing {file_name}'
|