# Copyright (c) OpenMMLab. All rights reserved.
"""Unit tests for the config helpers and misc utilities in mmdeploy.utils."""
import os
import tempfile

import mmcv
import pytest

import mmdeploy.utils as util
from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.export_info import dump_info
from mmdeploy.utils.test import get_random_name

correct_model_path = 'tests/data/srgan.py'
correct_model_cfg = mmcv.Config.fromfile(correct_model_path)
correct_deploy_path = 'tests/data/super-resolution.py'
correct_deploy_cfg = mmcv.Config.fromfile(correct_deploy_path)
# Only the generated name is used; the NamedTemporaryFile object is not kept
# alive, so the file itself is (re)created by the ``create_empty_file``
# fixture below.
empty_file_path = tempfile.NamedTemporaryFile(suffix='.py').name
# Path expected NOT to exist; used to trigger FileNotFoundError.
empty_path = './a.py'


@pytest.fixture(autouse=True, scope='module')
def create_empty_file():
    """Create an empty ``.py`` file for this module's tests, then remove it.

    ``open`` is used instead of ``os.mknod``: the latter is Unix-only and
    needs super-user privileges for regular files on macOS.
    """
    with open(empty_file_path, 'w'):
        pass
    yield
    # Clean up so repeated runs do not leak files into the temp directory.
    if os.path.exists(empty_file_path):
        os.remove(empty_file_path)


class TestLoadConfigError:
    """Invalid inputs to ``util.load_config``."""

    def test_load_config_none(self):
        with pytest.raises(AssertionError):
            util.load_config()

    def test_load_config_type_error(self):
        with pytest.raises(TypeError):
            util.load_config(1)

    def test_load_config_file_error(self):
        with pytest.raises(FileNotFoundError):
            util.load_config(empty_path)


class TestLoadConfig:
    """``util.load_config`` accepts paths and/or ``mmcv.Config`` objects."""

    @pytest.mark.parametrize('args', [
        [empty_file_path],
        [correct_model_path],
        [correct_model_cfg],
        (correct_model_path, correct_deploy_path),
        (correct_model_path, correct_deploy_cfg),
        (correct_model_cfg, correct_deploy_cfg),
    ])
    def test_load_config(self, args):
        configs = util.load_config(*args)
        # zip() would silently truncate, so check the count explicitly.
        assert len(configs) == len(args)
        for loaded, source in zip(configs, args):
            if isinstance(source, str):
                expected = mmcv.Config.fromfile(source)
            else:
                expected = source
            assert loaded._cfg_dict == expected._cfg_dict


class TestGetCodebaseConfig:

    def test_get_codebase_config_empty(self):
        assert util.get_codebase_config(mmcv.Config(dict())) == {}

    def test_get_codebase_config(self):
        codebase_config = util.get_codebase_config(correct_deploy_path)
        assert isinstance(codebase_config, dict) and len(codebase_config) > 1


class TestGetTaskType:

    def test_get_task_type_none(self):
        with pytest.raises(AssertionError):
            util.get_task_type(mmcv.Config(dict()))

    def test_get_task_type(self):
        assert util.get_task_type(correct_deploy_path) == \
            Task.SUPER_RESOLUTION


class TestGetCodebase:

    def test_get_codebase_none(self):
        with pytest.raises(AssertionError):
            util.get_codebase(mmcv.Config(dict()))

    def test_get_codebase(self):
        assert util.get_codebase(correct_deploy_path) == Codebase.MMEDIT


class TestGetBackendConfig:

    def test_get_backend_config_empty(self):
        assert util.get_backend_config(mmcv.Config(dict())) == {}

    def test_get_backend_config(self):
        backend_config = util.get_backend_config(correct_deploy_path)
        assert isinstance(backend_config, dict) and len(backend_config) == 1


class TestGetBackend:

    def test_get_backend_none(self):
        with pytest.raises(AssertionError):
            util.get_backend(mmcv.Config(dict()))

    def test_get_backend(self):
        assert util.get_backend(correct_deploy_path) == Backend.ONNXRUNTIME


class TestGetOnnxConfig:

    def test_get_onnx_config_empty(self):
        assert util.get_onnx_config(mmcv.Config(dict())) == {}

    def test_get_onnx_config(self):
        onnx_config = dict(
            dynamic_axes={
                'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                },
                'output': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }
            },
            type='onnx',
            export_params=True,
            keep_initializers_as_inputs=False,
            opset_version=11,
            save_file='end2end.onnx',
            input_names=['input'],
            output_names=['output'],
            input_shape=None)
        assert util.get_onnx_config(correct_deploy_path) == onnx_config


class TestIsDynamic:
    """``util.is_dynamic_batch`` / ``util.is_dynamic_shape``.

    The class-level configs are shared by every test in this class, so tests
    must not mutate them.
    """

    config_with_onnx_config = mmcv.Config(dict(onnx_config=dict()))

    config_with_dynamic_axes = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx',
                dynamic_axes={'input': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }})))

    config_with_dynamic_axes_and_input_names = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx',
                input_names=['image'],
                dynamic_axes={'image': {
                    0: 'batch',
                    2: 'height',
                    3: 'width'
                }})))

    config_with_dynamic_axes_list = mmcv.Config(
        dict(
            onnx_config=dict(
                type='onnx', input_names=['image'],
                dynamic_axes=[[0, 2, 3]])))

    def test_is_dynamic_batch_none(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_onnx_config) is False

    def test_is_dynamic_batch_error_name(self):
        assert util.is_dynamic_batch(TestIsDynamic.config_with_dynamic_axes,
                                     'output') is False

    def test_is_dynamic_batch(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_dynamic_axes) is True

    def test_is_dynamic_batch_axes_list(self):
        assert util.is_dynamic_batch(
            TestIsDynamic.config_with_dynamic_axes_list) is True

    def test_is_dynamic_shape_none(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_onnx_config) is False

    def test_is_dynamic_shape_error_name(self):
        assert util.is_dynamic_shape(TestIsDynamic.config_with_dynamic_axes,
                                     'output') is False

    def test_is_dynamic_shape(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes) is True

    def test_is_dynamic_shape_input_names(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes_and_input_names) is True

    def test_is_dynamic_shape_different_names(self):
        # Build a private config: mutating the shared class attribute would
        # leak into test_is_dynamic_shape_input_names depending on run order.
        config_with_different_names = mmcv.Config(
            dict(
                onnx_config=dict(
                    type='onnx',
                    input_names=['image'],
                    dynamic_axes={'image': {
                        0: 'batch',
                        2: 'height',
                        3: 'width'
                    }})))
        util.get_ir_config(
            config_with_different_names).input_names = 'another_name'
        assert util.is_dynamic_shape(config_with_different_names) is False

    def test_is_dynamic_shape_axes_list(self):
        assert util.is_dynamic_shape(
            TestIsDynamic.config_with_dynamic_axes_list) is True


class TestGetInputShape:
    config_without_input_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=None)))
    config_with_input_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=[1, 1])))
    # Three values: expected to be rejected by util.get_input_shape.
    config_with_error_shape = mmcv.Config(
        dict(onnx_config=dict(input_shape=[1, 1, 1])))

    def test_get_input_shape_none(self):
        assert util.get_input_shape(
            TestGetInputShape.config_without_input_shape) is None

    def test_get_input_shape_error(self):
        with pytest.raises(Exception):
            util.get_input_shape(TestGetInputShape.config_with_error_shape)

    def test_get_input_shape(self):
        assert util.get_input_shape(
            TestGetInputShape.config_with_input_shape) == [1, 1]


class TestCfgApplyMark:
    config_with_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=True)))

    def test_cfg_apply_marks_none(self):
        assert util.cfg_apply_marks(mmcv.Config(dict())) is None

    def test_cfg_apply_marks(self):
        assert util.cfg_apply_marks(TestCfgApplyMark.config_with_mask) is True


class TestGetPartitionConfig:
    config_with_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=True)))
    config_without_mask = mmcv.Config(
        dict(partition_config=dict(apply_marks=False)))

    def test_get_partition_config_none(self):
        assert util.get_partition_config(mmcv.Config(dict())) is None

    def test_get_partition_config_without_mask(self):
        assert util.get_partition_config(
            TestGetPartitionConfig.config_without_mask) is None

    def test_get_partition_config(self):
        assert util.get_partition_config(
            TestGetPartitionConfig.config_with_mask) == dict(apply_marks=True)


class TestGetCalib:
    config_with_calib = mmcv.Config(
        dict(calib_config=dict(create_calib=True, calib_file='calib_data.h5')))

    config_without_calib = mmcv.Config(
        dict(
            calib_config=dict(create_calib=False, calib_file='calib_data.h5')))

    def test_get_calib_config(self):
        assert util.get_calib_config(TestGetCalib.config_with_calib) == dict(
            create_calib=True, calib_file='calib_data.h5')

    def test_get_calib_filename_none(self):
        assert util.get_calib_filename(mmcv.Config(dict())) is None

    def test_get_calib_filename_false(self):
        assert util.get_calib_filename(
            TestGetCalib.config_without_calib) is None

    def test_get_calib_filename(self):
        assert util.get_calib_filename(
            TestGetCalib.config_with_calib) == 'calib_data.h5'


class TestGetCommonConfig:
    config_with_common_config = mmcv.Config(
        dict(
            backend_config=dict(
                type='tensorrt', common_config=dict(fp16_mode=False))))

    def test_get_common_config(self):
        assert util.get_common_config(
            TestGetCommonConfig.config_with_common_config) == dict(
                fp16_mode=False)


class TestGetModelInputs:
    config_with_model_inputs = mmcv.Config(
        dict(backend_config=dict(model_inputs=[dict(input_shapes=None)])))

    def test_model_inputs(self):
        assert util.get_model_inputs(
            TestGetModelInputs.config_with_model_inputs) == [
                dict(input_shapes=None)
            ]


class TestGetDynamicAxes:
    """``util.get_dynamic_axes`` with dict- and list-style ``dynamic_axes``."""

    input_name = get_random_name()

    def test_with_empty_cfg(self):
        deploy_cfg = mmcv.Config()
        with pytest.raises(KeyError):
            util.get_dynamic_axes(deploy_cfg)

    def test_can_get_axes_from_dict(self):
        expected_dynamic_axes = {
            self.input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
        deploy_cfg = mmcv.Config(
            dict(onnx_config=dict(dynamic_axes=expected_dynamic_axes)))

        dynamic_axes = util.get_dynamic_axes(deploy_cfg)
        assert expected_dynamic_axes == dynamic_axes

    def test_can_not_get_axes_from_list_without_names(self):
        axes = [[0, 2, 3]]
        deploy_cfg = mmcv.Config(dict(onnx_config=dict(dynamic_axes=axes)))
        with pytest.raises(KeyError):
            util.get_dynamic_axes(deploy_cfg)

    def test_can_get_axes_from_list_with_args(self):
        axes = [[0, 2, 3]]
        expected_dynamic_axes = {self.input_name: axes[0]}
        axes_names = [self.input_name]
        deploy_cfg = mmcv.Config(dict(onnx_config=dict(dynamic_axes=axes)))

        dynamic_axes = util.get_dynamic_axes(deploy_cfg, axes_names)
        assert expected_dynamic_axes == dynamic_axes

    def test_can_get_axes_from_list_with_cfg(self):
        output_name = get_random_name()
        axes = [[0, 2, 3], [0]]
        expected_dynamic_axes = {
            self.input_name: axes[0],
            output_name: axes[1]
        }
        deploy_cfg = mmcv.Config(
            dict(
                onnx_config=dict(
                    input_names=[self.input_name],
                    output_names=[output_name],
                    dynamic_axes=axes)))

        dynamic_axes = util.get_dynamic_axes(deploy_cfg)
        assert expected_dynamic_axes == dynamic_axes


def test_AdvancedEnum():
    """``Task.get`` maps each string value back to its enum member."""
    keys = [
        Task.TEXT_DETECTION, Task.TEXT_RECOGNITION, Task.SEGMENTATION,
        Task.SUPER_RESOLUTION, Task.CLASSIFICATION, Task.OBJECT_DETECTION
    ]
    vals = [
        'TextDetection', 'TextRecognition', 'Segmentation', 'SuperResolution',
        'Classification', 'ObjectDetection'
    ]
    for k, v in zip(keys, vals):
        assert Task.get(v) == k
        assert k.value == v


def test_export_info():
    """``dump_info`` writes deploy/pipeline/detail JSON files to the dir."""
    # ``tmp_dir`` rather than ``dir`` to avoid shadowing the builtin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        dump_info(correct_deploy_cfg, correct_model_cfg, tmp_dir, '')
        for json_name in ('deploy.json', 'pipeline.json', 'detail.json'):
            assert os.path.exists(os.path.join(tmp_dir, json_name))


def test_get_root_logger():
    from mmdeploy.utils import get_root_logger
    logger = get_root_logger()
    logger.info('This is a test message')