mmdeploy/tests/test_utils/test_util.py

391 lines
12 KiB
Python
Raw Normal View History

2021-11-30 15:00:37 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import mmcv
import pytest
import mmdeploy.utils as util
from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.export_info import dump_info
from mmdeploy.utils.test import get_random_name
# Reference model config (srgan.py) and the matching super-resolution deploy
# config from the test data directory, both as paths and as loaded Configs.
correct_model_path = 'tests/data/srgan.py'
correct_model_cfg = mmcv.Config.fromfile(correct_model_path)
correct_deploy_path = 'tests/data/super-resolution.py'
correct_deploy_cfg = mmcv.Config.fromfile(correct_deploy_path)

# Reserve a unique temporary ``.py`` path for an empty config file. The
# temporary file is closed (and therefore deleted) explicitly by the ``with``
# block instead of relying on the interpreter to finalize an unreferenced file
# object; the module-scoped ``create_empty_file`` fixture creates the file.
with tempfile.NamedTemporaryFile(suffix='.py') as _reserved_tmp_file:
    empty_file_path = _reserved_tmp_file.name

# Relative path that is expected not to exist; loading it must raise
# FileNotFoundError.
empty_path = './a.py'
@pytest.fixture(autouse=True, scope='module')
def create_empty_file():
    """Create an empty file at ``empty_file_path`` for the tests in this module.

    ``open(..., 'w')`` is used instead of ``os.mknod`` for two reasons:
    ``os.mknod`` is not available on Windows and may not support regular
    files everywhere, and it fails if the path already exists. The file is
    removed once every test in the module has run.
    """
    with open(empty_file_path, 'w'):
        pass
    yield
    try:
        os.remove(empty_file_path)
    except FileNotFoundError:
        # Cleanup is best-effort; the file being gone already is fine.
        pass
class TestLoadConfigError:
    """Invalid inputs that ``util.load_config`` must reject."""

    def test_load_config_none(self):
        # Calling without any config argument fails its assertion.
        with pytest.raises(AssertionError):
            util.load_config()

    def test_load_config_type_error(self):
        # An int is neither a path nor a Config object.
        bad_cfg = 1
        with pytest.raises(TypeError):
            util.load_config(bad_cfg)

    def test_load_config_file_error(self):
        # ``empty_path`` names a file that does not exist.
        with pytest.raises(FileNotFoundError):
            util.load_config(empty_path)
class TestLoadConfig:
    """``util.load_config`` yields one Config per argument, in order."""

    @pytest.mark.parametrize('args', [
        [empty_file_path],
        [correct_model_path],
        [correct_model_cfg],
        (correct_model_path, correct_deploy_path),
        (correct_model_path, correct_deploy_cfg),
        (correct_model_cfg, correct_deploy_cfg),
    ])
    def test_load_config(self, args):
        configs = list(util.load_config(*args))
        # ``zip`` stops at the shorter sequence, so without this check a
        # missing (or extra) config would go unnoticed.
        assert len(configs) == len(args)
        for loaded, source in zip(configs, args):
            # Paths are compared against a freshly parsed copy of the file;
            # Config objects are compared against themselves.
            if isinstance(source, str):
                expected = mmcv.Config.fromfile(source)
            else:
                expected = source
            assert loaded._cfg_dict == expected._cfg_dict
class TestGetCodebaseConfig:
    """Tests for ``util.get_codebase_config``."""

    def test_get_codebase_config_empty(self):
        # An empty config yields an empty codebase section.
        empty_cfg = mmcv.Config(dict())
        assert util.get_codebase_config(empty_cfg) == {}

    def test_get_codebase_config(self):
        section = util.get_codebase_config(correct_deploy_path)
        assert isinstance(section, dict)
        assert len(section) > 1
class TestGetTaskType:
    """Tests for ``util.get_task_type``."""

    def test_get_task_type_none(self):
        # An empty config carries no task and is rejected.
        empty_cfg = mmcv.Config(dict())
        with pytest.raises(AssertionError):
            util.get_task_type(empty_cfg)

    def test_get_task_type(self):
        task = util.get_task_type(correct_deploy_path)
        assert task == Task.SUPER_RESOLUTION
class TestGetCodebase:
    """Tests for ``util.get_codebase``."""

    def test_get_codebase_none(self):
        # An empty config carries no codebase and is rejected.
        empty_cfg = mmcv.Config(dict())
        with pytest.raises(AssertionError):
            util.get_codebase(empty_cfg)

    def test_get_codebase(self):
        codebase = util.get_codebase(correct_deploy_path)
        assert codebase == Codebase.MMEDIT
class TestGetBackendConfig:
    """Tests for ``util.get_backend_config``."""

    def test_get_backend_config_empty(self):
        # An empty config yields an empty backend section.
        empty_cfg = mmcv.Config(dict())
        assert util.get_backend_config(empty_cfg) == {}

    def test_get_backend_config(self):
        section = util.get_backend_config(correct_deploy_path)
        assert isinstance(section, dict)
        assert len(section) == 1
class TestGetBackend:
    """Tests for ``util.get_backend``."""

    def test_get_backend_none(self):
        # An empty config names no backend and is rejected.
        empty_cfg = mmcv.Config(dict())
        with pytest.raises(AssertionError):
            util.get_backend(empty_cfg)

    def test_get_backend(self):
        backend = util.get_backend(correct_deploy_path)
        assert backend == Backend.ONNXRUNTIME
class TestGetOnnxConfig:
    """Tests for ``util.get_onnx_config``."""

    def test_get_onnx_config_empty(self):
        empty_cfg = mmcv.Config(dict())
        assert util.get_onnx_config(empty_cfg) == {}

    def test_get_onnx_config(self):
        # Input and output share the same dynamic batch/height/width layout.
        dynamic_dims = {0: 'batch', 2: 'height', 3: 'width'}
        expected = dict(
            type='onnx',
            export_params=True,
            keep_initializers_as_inputs=False,
            opset_version=11,
            save_file='end2end.onnx',
            input_names=['input'],
            output_names=['output'],
            input_shape=None,
            dynamic_axes={
                'input': dict(dynamic_dims),
                'output': dict(dynamic_dims),
            },
        )
        assert util.get_onnx_config(correct_deploy_path) == expected
class TestIsDynamic:
    """Tests for ``util.is_dynamic_batch`` and ``util.is_dynamic_shape``."""

    # ONNX config without any dynamic axes.
    config_with_onnx_config = mmcv.Config(dict(onnx_config=dict()))
    # Dynamic batch/height/width keyed by ``'input'``; no ``input_names``.
    config_with_dynamic_axes = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx',
                dynamic_axes={'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }})))
    # Dynamic axes keyed by the explicitly declared input name ``'image'``.
    config_with_dynamic_axes_and_input_names = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx',
                input_names=['image'],
                dynamic_axes={'image': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }})))
    # Dynamic axes given positionally as a list rather than a name mapping.
    config_with_dynamic_axes_list = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx', input_names=['image'], dynamic_axes=[[0, 2, 3]])))

    def test_is_dynamic_batch_none(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_onnx_config) is False

    def test_is_dynamic_batch_error_name(self):
        assert util.is_dynamic_batch(TestIsDynamic.config_with_dynamic_axes,
                                     'output') is False

    def test_is_dynamic_batch(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_dynamic_axes) is True

    def test_is_dynamic_batch_axes_list(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_dynamic_axes_list) is True

    def test_is_dynamic_shape_none(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_onnx_config) is False

    def test_is_dynamic_shape_error_name(self):
        assert util.is_dynamic_shape(TestIsDynamic.config_with_dynamic_axes,
                                     'output') is False

    def test_is_dynamic_shape(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes) is True

    def test_is_dynamic_shape_input_names(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes_and_input_names) is True

    def test_is_dynamic_shape_different_names(self):
        # Use a dedicated config instead of aliasing and mutating the shared
        # ``config_with_dynamic_axes_and_input_names`` class attribute. That
        # mutation would leak into ``test_is_dynamic_shape_input_names`` and
        # make its outcome depend on test execution order.
        config_with_different_names = mmcv.Config(
            dict(
                onnx_config=dict(
                    type='onnx',
                    input_names=['another_name'],
                    dynamic_axes={'image': {
                        0: 'batch',
                        2: 'height',
                        3: 'width'
                    }})))
        assert util.is_dynamic_shape(config_with_different_names) is False

    def test_is_dynamic_shape_axes_list(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes_list) is True
class TestGetInputShape:
    """Tests for ``util.get_input_shape``."""

    # ``input_shape`` explicitly unset.
    config_without_input_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=None)))
    # A two-element shape, accepted as-is.
    config_with_input_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=[1, 1])))
    # A three-element shape, which must be rejected.
    config_with_error_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=[1, 1, 1])))

    def test_get_input_shape_none(self):
        shape = util.get_input_shape(self.config_without_input_shape)
        assert shape is None

    def test_get_input_shape_error(self):
        # The exact exception type is deliberately not pinned here.
        with pytest.raises(Exception):
            util.get_input_shape(self.config_with_error_shape)

    def test_get_input_shape(self):
        shape = util.get_input_shape(self.config_with_input_shape)
        assert shape == [1, 1]
class TestCfgApplyMark:
    """Tests for ``util.cfg_apply_marks``."""

    # Despite the attribute name, this config enables ``apply_marks``.
    config_with_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=True)))

    def test_cfg_apply_marks_none(self):
        # Without a partition section the result is None, not False.
        empty_cfg = mmcv.Config(dict())
        assert util.cfg_apply_marks(empty_cfg) is None

    def test_cfg_apply_marks(self):
        marks_enabled = util.cfg_apply_marks(self.config_with_mask)
        assert marks_enabled is True
class TestGetPartitionConfig:
    """Tests for ``util.get_partition_config``."""

    # ``apply_marks`` enabled / disabled respectively.
    config_with_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=True)))
    config_without_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=False)))

    def test_get_partition_config_none(self):
        empty_cfg = mmcv.Config(dict())
        assert util.get_partition_config(empty_cfg) is None

    def test_get_partition_config_without_mask(self):
        # With ``apply_marks=False`` no partition config is returned.
        result = util.get_partition_config(self.config_without_mask)
        assert result is None

    def test_get_partition_config(self):
        result = util.get_partition_config(self.config_with_mask)
        assert result == dict(apply_marks=True)
class TestGetCalib:
    """Tests for ``util.get_calib_config`` and ``util.get_calib_filename``."""

    config_with_calib = mmcv.Config(
        dict(calib_config=dict(create_calib=True, calib_file='calib_data.h5')))
    config_without_calib = mmcv.Config(
        dict(
            calib_config=dict(create_calib=False, calib_file='calib_data.h5')))

    def test_get_calib_config(self):
        expected = dict(create_calib=True, calib_file='calib_data.h5')
        assert util.get_calib_config(self.config_with_calib) == expected

    def test_get_calib_filename_none(self):
        empty_cfg = mmcv.Config(dict())
        assert util.get_calib_filename(empty_cfg) is None

    def test_get_calib_filename_false(self):
        # ``create_calib=False`` yields no filename even though
        # ``calib_file`` is set.
        filename = util.get_calib_filename(self.config_without_calib)
        assert filename is None

    def test_get_calib_filename(self):
        filename = util.get_calib_filename(self.config_with_calib)
        assert filename == 'calib_data.h5'
class TestGetCommonConfig:
    """Tests for ``util.get_common_config``."""

    config_with_common_config = mmcv.Config(
        dict(
            backend_config=dict(
                type='tensorrt', common_config=dict(fp16_mode=False))))

    def test_get_common_config(self):
        # ``common_config`` lives inside the ``backend_config`` section.
        common = util.get_common_config(self.config_with_common_config)
        assert common == dict(fp16_mode=False)
class TestGetModelInputs:
    """Tests for ``util.get_model_inputs``."""

    config_with_model_inputs = mmcv.Config(
        dict(backend_config=dict(model_inputs=[dict(input_shapes=None)])))

    def test_model_inputs(self):
        # ``model_inputs`` is read from the ``backend_config`` section.
        model_inputs = util.get_model_inputs(self.config_with_model_inputs)
        assert model_inputs == [dict(input_shapes=None)]
class TestGetDynamicAxes:
    """Tests for ``util.get_dynamic_axes``."""

    # A random input name, so no test relies on a particular default name.
    input_name = get_random_name()

    def test_with_empty_cfg(self):
        # A config without ``onnx_config`` cannot provide dynamic axes.
        with pytest.raises(KeyError):
            util.get_dynamic_axes(mmcv.Config())

    def test_can_get_axes_from_dict(self):
        # Axes given as a name -> {axis: label} mapping come back unchanged.
        axes_by_name = {
            self.input_name: {0: 'batch', 2: 'height', 3: 'width'}
        }
        cfg = mmcv.Config(dict(onnx_config=dict(dynamic_axes=axes_by_name)))
        assert axes_by_name == util.get_dynamic_axes(cfg)

    def test_can_not_get_axes_from_list_without_names(self):
        # A bare list of axes with no names anywhere cannot be resolved.
        cfg = mmcv.Config(dict(onnx_config=dict(dynamic_axes=[[0, 2, 3]])))
        with pytest.raises(KeyError):
            util.get_dynamic_axes(cfg)

    def test_can_get_axes_from_list_with_args(self):
        # Names passed as an argument are paired with the positional axes.
        axes_list = [[0, 2, 3]]
        cfg = mmcv.Config(dict(onnx_config=dict(dynamic_axes=axes_list)))
        result = util.get_dynamic_axes(cfg, [self.input_name])
        assert {self.input_name: axes_list[0]} == result

    def test_can_get_axes_from_list_with_cfg(self):
        # Without explicit names, the config's input_names followed by its
        # output_names are paired with the positional axes.
        output_name = get_random_name()
        axes_list = [[0, 2, 3], [0]]
        cfg = mmcv.Config(
            dict(
                onnx_config=dict(
                    input_names=[self.input_name],
                    output_names=[output_name],
                    dynamic_axes=axes_list)))
        expected = {self.input_name: axes_list[0], output_name: axes_list[1]}
        assert expected == util.get_dynamic_axes(cfg)
def test_AdvancedEnum():
    """Each ``Task`` member round-trips through its string value."""
    cases = (
        (Task.TEXT_DETECTION, 'TextDetection'),
        (Task.TEXT_RECOGNITION, 'TextRecognition'),
        (Task.SEGMENTATION, 'Segmentation'),
        (Task.SUPER_RESOLUTION, 'SuperResolution'),
        (Task.CLASSIFICATION, 'Classification'),
        (Task.OBJECT_DETECTION, 'ObjectDetection'),
    )
    for member, value in cases:
        # ``Task.get`` looks the member up by its string value...
        assert Task.get(value) == member
        # ...and the member's ``value`` is that same string.
        assert member.value == value
def test_export_info():
    """``dump_info`` writes pipeline.json, detail.json and deploy.json.

    The files are written into a throwaway temporary directory. The last
    argument to ``dump_info`` is passed as an empty string here (presumably a
    checkpoint path — confirm against ``dump_info``'s signature).
    """
    # ``work_dir`` rather than ``dir`` avoids shadowing the builtin.
    with tempfile.TemporaryDirectory() as work_dir:
        dump_info(correct_deploy_cfg, correct_model_cfg, work_dir, '')
        for file_name in ('pipeline.json', 'detail.json', 'deploy.json'):
            assert os.path.exists(os.path.join(work_dir, file_name))