diff --git a/csrc/mmdeploy/codebase/mmocr/short_scale_aspect_jitter.cpp b/csrc/mmdeploy/codebase/mmocr/short_scale_aspect_jitter.cpp index 866e2239e..53e22337c 100644 --- a/csrc/mmdeploy/codebase/mmocr/short_scale_aspect_jitter.cpp +++ b/csrc/mmdeploy/codebase/mmocr/short_scale_aspect_jitter.cpp @@ -78,9 +78,8 @@ class ShortScaleAspectJitterImpl : public Module { auto dst_width = static_cast(std::round(scale * img_shape[2])); dst_height = static_cast(std::ceil(1.0 * dst_height / scale_divisor) * scale_divisor); dst_width = static_cast(std::ceil(1.0 * dst_width / scale_divisor) * scale_divisor); - - std::vector scale_factor = {(float) (1.0 * dst_width / img_shape[2]), - (float) (1.0 * dst_height / img_shape[1])}; + std::vector scale_factor = {(float)1.0 * dst_width / img_shape[2], + (float)1.0 * dst_height / img_shape[1]}; img_resize = ResizeImage(img, dst_height, dst_width); Value output = input; diff --git a/mmdeploy/apis/utils/utils.py b/mmdeploy/apis/utils/utils.py index ed8528b95..e3ca64690 100644 --- a/mmdeploy/apis/utils/utils.py +++ b/mmdeploy/apis/utils/utils.py @@ -3,6 +3,7 @@ import mmengine from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase from mmdeploy.utils import get_codebase, get_task_type +from mmdeploy.utils.config_utils import get_codebase_external_module def build_task_processor(model_cfg: mmengine.Config, @@ -18,7 +19,8 @@ def build_task_processor(model_cfg: mmengine.Config, BaseTask: A task processor. """ codebase_type = get_codebase(deploy_cfg) - import_codebase(codebase_type) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(codebase_type, custom_module_list) codebase = get_codebase_class(codebase_type) return codebase.build_task_processor(model_cfg, deploy_cfg, device) @@ -39,7 +41,8 @@ def get_predefined_partition_cfg(deploy_cfg: mmengine.Config, dict: A dictionary of partition config. """ codebase_type = get_codebase(deploy_cfg) - import_codebase(codebase_type) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(codebase_type, custom_module_list) task = get_task_type(deploy_cfg) codebase = get_codebase_class(codebase_type) task_processor_class = codebase.get_task_class(task) diff --git a/mmdeploy/codebase/__init__.py b/mmdeploy/codebase/__init__.py index 6aa263c71..4f3268ce2 100644 --- a/mmdeploy/codebase/__init__.py +++ b/mmdeploy/codebase/__init__.py @@ -1,16 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import importlib +from typing import List from mmdeploy.utils import Codebase from .base import BaseTask, MMCodebase, get_codebase_class -extra_dependent_library = { - Codebase.MMOCR: ['mmdet'], - Codebase.MMROTATE: ['mmdet'] -} - -def import_codebase(codebase: Codebase): +def import_codebase(codebase_type: Codebase, custom_module_list: List = []): """Import a codebase package in `mmdeploy.codebase` The function will check if all dependent libraries are installed. @@ -21,21 +16,14 @@ def import_codebase(codebase: Codebase): Args: codebase (Codebase): The codebase to import. """ - codebase_name = codebase.value - dependent_library = [codebase_name] + \ - extra_dependent_library.get(codebase, []) - - for lib in dependent_library: - if not importlib.util.find_spec(lib): - raise ImportError( - f'{lib} has not been installed. ' - f'Import mmdeploy.codebase.{codebase_name} failed.') - importlib.import_module(f'mmdeploy.codebase.{lib}') - importlib.import_module(f'{lib}.models') - importlib.import_module(f'{lib}.datasets') - importlib.import_module(f'{lib}.structures') - importlib.import_module(f'{lib}.visualization') - importlib.import_module(f'{lib}.engine') + import importlib + if len(custom_module_list) > 0: + for custom_module in custom_module_list: + importlib.import_module(f'{custom_module}') + else: + importlib.import_module(f'mmdeploy.codebase.{codebase_type.value}') + codebase = get_codebase_class(codebase_type) + codebase.register_all_modules() __all__ = ['MMCodebase', 'BaseTask', 'get_codebase_class'] diff --git a/mmdeploy/codebase/base/mmcodebase.py b/mmdeploy/codebase/base/mmcodebase.py index e5330af97..c2d8c0098 100644 --- a/mmdeploy/codebase/base/mmcodebase.py +++ b/mmdeploy/codebase/base/mmcodebase.py @@ -49,6 +49,10 @@ class MMCodebase(metaclass=ABCMeta): deploy_cfg=deploy_cfg, device=device)) + @classmethod + def register_all_modules(cls): + pass + # Note that the build function returns the class instead of its instance. diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 8587040f3..601284380 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -13,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset from mmdeploy.utils import (get_backend_config, get_codebase, get_codebase_config, get_root_logger) +from mmdeploy.utils.config_utils import get_codebase_external_module from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset @@ -40,7 +41,8 @@ class BaseTask(metaclass=ABCMeta): # init scope from .. import import_codebase - import_codebase(self.codebase) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(self.codebase, custom_module_list) from mmengine.registry import DefaultScope if not DefaultScope.check_instance_created(self.experiment_name): diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 6e3ef7038..ef1d21a8b 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -23,6 +23,11 @@ class MMClassification(MMCodebase): task_registry = MMCLS_TASK + @classmethod + def register_all_modules(cls): + from mmcls.utils.setup_env import register_all_modules + register_all_modules(True) + def process_model_config(model_cfg: Config, imgs: Union[str, np.ndarray], diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index 7c0cf2227..3c0c833f3 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -22,6 +22,11 @@ class MMDetection(MMCodebase): task_registry = MMDET_TASK + @classmethod + def register_all_modules(cls): + from mmdet.utils.setup_env import register_all_modules + register_all_modules(True) + def process_model_config(model_cfg: Config, imgs: Union[Sequence[str], Sequence[np.ndarray]], diff --git a/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py b/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py index 234325c91..8af6c647d 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py +++ b/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py @@ -41,8 +41,14 @@ class MMDetection3d(MMCodebase): Returns: BaseTask: A task processor. """ + return MMDET3D_TASK.build(model_cfg, deploy_cfg, device) + @classmethod + def register_all_modules(cls): + from mmdet3d.utils.set_env import register_all_modules + register_all_modules(True) + @staticmethod def build_dataset(dataset_cfg: Union[str, mmengine.Config], *args, **kwargs) -> Dataset: diff --git a/mmdeploy/codebase/mmedit/deploy/mmediting.py b/mmdeploy/codebase/mmedit/deploy/mmediting.py index b8bd4db75..59d54ac1d 100644 --- a/mmdeploy/codebase/mmedit/deploy/mmediting.py +++ b/mmdeploy/codebase/mmedit/deploy/mmediting.py @@ -12,3 +12,8 @@ class MMEditing(MMCodebase): """mmediting codebase class.""" task_registry = MMEDIT_TASK + + @classmethod + def register_all_modules(cls): + from mmedit.utils.setup_env import register_all_modules + register_all_modules(True) diff --git a/mmdeploy/codebase/mmocr/deploy/mmocr.py b/mmdeploy/codebase/mmocr/deploy/mmocr.py index d8855d247..23e62d44c 100644 --- a/mmdeploy/codebase/mmocr/deploy/mmocr.py +++ b/mmdeploy/codebase/mmocr/deploy/mmocr.py @@ -18,6 +18,15 @@ class MMOCR(MMCodebase): task_registry = MMOCR_TASK + @classmethod + def register_all_modules(cls): + from mmdet.utils.setup_env import \ + register_all_modules as register_all_modules_mmdet + from mmocr.utils.setup_env import \ + register_all_modules as register_all_modules_mmocr + register_all_modules_mmocr(False) + register_all_modules_mmdet(True) + @staticmethod def single_gpu_test(model: torch.nn.Module, data_loader: DataLoader, diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 45bd2b15f..12ce96d83 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -114,6 +114,11 @@ class MMPose(MMCodebase): """mmpose codebase class.""" task_registry = MMPOSE_TASK + @classmethod + def register_all_modules(cls): + from mmpose.utils.setup_env import register_all_modules + register_all_modules(True) + @MMPOSE_TASK.register_module(Task.POSE_DETECTION.value) class PoseDetection(BaseTask): diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index 8206a8df4..316d711d9 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -105,6 +105,11 @@ class MMSegmentation(MMCodebase): """mmsegmentation codebase class.""" task_registry = MMSEG_TASK + @classmethod + def register_all_modules(cls): + from mmseg.utils.set_env import register_all_modules + register_all_modules(True) + @MMSEG_TASK.register_module(Task.SEGMENTATION.value) class Segmentation(BaseTask): diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index 510e22f08..d61215dc5 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union import mmengine from .constants import Backend, Codebase, Task -from .utils import deprecate +from .utils import deprecate, get_root_logger def load_config(*args) -> List[mmengine.Config]: @@ -62,6 +62,27 @@ def get_task_type(deploy_cfg: Union[str, mmengine.Config]) -> Task: return Task.get(task) +def register_codebase(codebase: str) -> Codebase: + """Register a new codebase which is not included in Codebase. + + Args: + codebase (str): The codebase name. + + Returns: + Codebase : An enumeration denotes the codebase type. + """ + from aenum import extend_enum + try: + Codebase.get(codebase) + except Exception as e: + logger = get_root_logger() + extend_enum(Codebase, codebase.upper(), codebase) + logger.warn(f'Failed to get codebase, got: {e}. Then export ' + f'a new codebase in Codebase {codebase.upper()}: ' + f'{codebase}') + return Codebase.get(codebase) + + def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase: """Get the codebase from the config. @@ -71,12 +92,11 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase: Returns: Codebase : An enumeration denotes the codebase type. """ - codebase_config = get_codebase_config(deploy_cfg) assert 'type' in codebase_config, 'The codebase config of deploy config'\ 'requires a "type" field' codebase = codebase_config['type'] - return Codebase.get(codebase) + return register_codebase(codebase) def get_backend_config(deploy_cfg: Union[str, mmengine.Config]) -> Dict: @@ -403,3 +423,8 @@ def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str: if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']: precision = deploy_cfg['backend_config']['precision'] return precision + + +def get_codebase_external_module( + deploy_cfg: Union[str, mmengine.Config]) -> List: + return get_codebase_config(deploy_cfg).get('module', []) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 706ce39a8..229a775ee 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,3 +1,4 @@ +aenum grpcio h5py matplotlib diff --git a/tests/test_utils/test_util.py b/tests/test_utils/test_util.py index c0bd82803..b98657e95 100644 --- a/tests/test_utils/test_util.py +++ b/tests/test_utils/test_util.py @@ -12,6 +12,7 @@ from mmengine import Config import mmdeploy.utils as util from mmdeploy.backend.sdk.export_info import export2SDK from mmdeploy.utils import target_wrapper +from mmdeploy.utils.config_utils import get_codebase_external_module from mmdeploy.utils.constants import Backend, Codebase, Task from mmdeploy.utils.test import get_random_name @@ -104,6 +105,21 @@ class TestGetBackendConfig: assert isinstance(backend_config, dict) and len(backend_config) == 1 +class TestGetCodebaseExternalModule: + + def test_get_codebase_external_module_empty(self): + assert get_codebase_external_module(Config(dict())) == [] + + def test_get_codebase_external_module(self): + external_deploy_cfg = dict( + onnx_config=dict(), + codebase_config=dict(module=['mmyolo.deploy.mmyolo']), + backend_config=dict(type='onnxruntime')) + custom_module_list = get_codebase_external_module(external_deploy_cfg) + assert isinstance(custom_module_list, list) \ + and len(custom_module_list) == 1 + + class TestGetBackend: def test_get_backend_none(self):