[Enhancement 2.0] mmdeploy for mmyolo (#1088)
* support for external codebase like mmyolo * support for external export * fix missing flake8 * fix comments * add aenum * add missing files * fix condition * refactor import_codebase * fix mmyolo support * fix lint * add base codebase * fix a strange clang-format * fix import_codebase * fix dependent codebase register * wrap custom_model * fix comment * add utpull/1091/head
parent
b8c19b35d2
commit
7f70d7fe56
csrc/mmdeploy/codebase/mmocr
mmdeploy
apis/utils
codebase
mmcls/deploy
mmdet/deploy
mmdet3d/deploy
mmedit/deploy
mmocr/deploy
mmpose/deploy
mmseg/deploy
utils
requirements
tests/test_utils
|
@ -78,9 +78,8 @@ class ShortScaleAspectJitterImpl : public Module {
|
|||
auto dst_width = static_cast<int>(std::round(scale * img_shape[2]));
|
||||
dst_height = static_cast<int>(std::ceil(1.0 * dst_height / scale_divisor) * scale_divisor);
|
||||
dst_width = static_cast<int>(std::ceil(1.0 * dst_width / scale_divisor) * scale_divisor);
|
||||
|
||||
std::vector<float> scale_factor = {(float)(1.0 * dst_width / img_shape[2]),
|
||||
(float)(1.0 * dst_height / img_shape[1])};
|
||||
std::vector<float> scale_factor = {(float)1.0 * dst_width / img_shape[2],
|
||||
(float)1.0 * dst_height / img_shape[1]};
|
||||
|
||||
img_resize = ResizeImage(img, dst_height, dst_width);
|
||||
Value output = input;
|
||||
|
|
|
@ -4,6 +4,7 @@ import mmengine
|
|||
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
|
||||
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
|
||||
parse_device_id)
|
||||
from mmdeploy.utils.config_utils import get_codebase_external_module
|
||||
|
||||
|
||||
def check_backend_device(deploy_cfg: mmengine.Config, device: str):
|
||||
|
@ -37,7 +38,8 @@ def build_task_processor(model_cfg: mmengine.Config,
|
|||
"""
|
||||
check_backend_device(deploy_cfg=deploy_cfg, device=device)
|
||||
codebase_type = get_codebase(deploy_cfg)
|
||||
import_codebase(codebase_type)
|
||||
custom_module_list = get_codebase_external_module(deploy_cfg)
|
||||
import_codebase(codebase_type, custom_module_list)
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
return codebase.build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
|
@ -58,7 +60,8 @@ def get_predefined_partition_cfg(deploy_cfg: mmengine.Config,
|
|||
dict: A dictionary of partition config.
|
||||
"""
|
||||
codebase_type = get_codebase(deploy_cfg)
|
||||
import_codebase(codebase_type)
|
||||
custom_module_list = get_codebase_external_module(deploy_cfg)
|
||||
import_codebase(codebase_type, custom_module_list)
|
||||
task = get_task_type(deploy_cfg)
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
task_processor_class = codebase.get_task_class(task)
|
||||
|
|
|
@ -1,16 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import importlib
|
||||
from typing import List
|
||||
|
||||
from mmdeploy.utils import Codebase
|
||||
from .base import BaseTask, MMCodebase, get_codebase_class
|
||||
|
||||
extra_dependent_library = {
|
||||
Codebase.MMOCR: ['mmdet'],
|
||||
Codebase.MMROTATE: ['mmdet']
|
||||
}
|
||||
|
||||
|
||||
def import_codebase(codebase: Codebase):
|
||||
def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
|
||||
"""Import a codebase package in `mmdeploy.codebase`
|
||||
|
||||
The function will check if all dependent libraries are installed.
|
||||
|
@ -21,21 +16,14 @@ def import_codebase(codebase: Codebase):
|
|||
Args:
|
||||
codebase (Codebase): The codebase to import.
|
||||
"""
|
||||
codebase_name = codebase.value
|
||||
dependent_library = [codebase_name] + \
|
||||
extra_dependent_library.get(codebase, [])
|
||||
|
||||
for lib in dependent_library:
|
||||
if not importlib.util.find_spec(lib):
|
||||
raise ImportError(
|
||||
f'{lib} has not been installed. '
|
||||
f'Import mmdeploy.codebase.{codebase_name} failed.')
|
||||
importlib.import_module(f'mmdeploy.codebase.{lib}')
|
||||
importlib.import_module(f'{lib}.models')
|
||||
importlib.import_module(f'{lib}.datasets')
|
||||
importlib.import_module(f'{lib}.structures')
|
||||
importlib.import_module(f'{lib}.visualization')
|
||||
importlib.import_module(f'{lib}.engine')
|
||||
import importlib
|
||||
if len(custom_module_list) > 0:
|
||||
for custom_module in custom_module_list:
|
||||
importlib.import_module(f'{custom_module}')
|
||||
else:
|
||||
importlib.import_module(f'mmdeploy.codebase.{codebase_type.value}')
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
codebase.register_all_modules()
|
||||
|
||||
|
||||
__all__ = ['MMCodebase', 'BaseTask', 'get_codebase_class']
|
||||
|
|
|
@ -49,6 +49,10 @@ class MMCodebase(metaclass=ABCMeta):
|
|||
deploy_cfg=deploy_cfg,
|
||||
device=device))
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
pass
|
||||
|
||||
|
||||
# Note that the build function returns the class instead of its instance.
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset
|
|||
|
||||
from mmdeploy.utils import (get_backend_config, get_codebase,
|
||||
get_codebase_config, get_root_logger)
|
||||
from mmdeploy.utils.config_utils import get_codebase_external_module
|
||||
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
|
||||
|
||||
|
||||
|
@ -40,7 +41,8 @@ class BaseTask(metaclass=ABCMeta):
|
|||
|
||||
# init scope
|
||||
from .. import import_codebase
|
||||
import_codebase(self.codebase)
|
||||
custom_module_list = get_codebase_external_module(deploy_cfg)
|
||||
import_codebase(self.codebase, custom_module_list)
|
||||
|
||||
from mmengine.registry import DefaultScope
|
||||
if not DefaultScope.check_instance_created(self.experiment_name):
|
||||
|
|
|
@ -23,6 +23,11 @@ class MMClassification(MMCodebase):
|
|||
|
||||
task_registry = MMCLS_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmcls.utils.setup_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
def process_model_config(model_cfg: Config,
|
||||
imgs: Union[str, np.ndarray],
|
||||
|
|
|
@ -22,6 +22,11 @@ class MMDetection(MMCodebase):
|
|||
|
||||
task_registry = MMDET_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet.utils.setup_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
def process_model_config(model_cfg: Config,
|
||||
imgs: Union[Sequence[str], Sequence[np.ndarray]],
|
||||
|
|
|
@ -41,8 +41,14 @@ class MMDetection3d(MMCodebase):
|
|||
Returns:
|
||||
BaseTask: A task processor.
|
||||
"""
|
||||
|
||||
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet3d.utils.set_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
@staticmethod
|
||||
def build_dataset(dataset_cfg: Union[str, mmengine.Config], *args,
|
||||
**kwargs) -> Dataset:
|
||||
|
|
|
@ -12,3 +12,8 @@ class MMEditing(MMCodebase):
|
|||
"""mmediting codebase class."""
|
||||
|
||||
task_registry = MMEDIT_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmedit.utils.setup_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
|
|
@ -12,3 +12,12 @@ class MMOCR(MMCodebase):
|
|||
"""MMOCR codebase class."""
|
||||
|
||||
task_registry = MMOCR_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet.utils.setup_env import \
|
||||
register_all_modules as register_all_modules_mmdet
|
||||
from mmocr.utils.setup_env import \
|
||||
register_all_modules as register_all_modules_mmocr
|
||||
register_all_modules_mmocr(False)
|
||||
register_all_modules_mmdet(True)
|
||||
|
|
|
@ -117,6 +117,11 @@ class MMPose(MMCodebase):
|
|||
"""mmpose codebase class."""
|
||||
task_registry = MMPOSE_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmpose.utils.setup_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
@MMPOSE_TASK.register_module(Task.POSE_DETECTION.value)
|
||||
class PoseDetection(BaseTask):
|
||||
|
|
|
@ -104,6 +104,11 @@ class MMSegmentation(MMCodebase):
|
|||
"""mmsegmentation codebase class."""
|
||||
task_registry = MMSEG_TASK
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmseg.utils.set_env import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
@MMSEG_TASK.register_module(Task.SEGMENTATION.value)
|
||||
class Segmentation(BaseTask):
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
|
|||
import mmengine
|
||||
|
||||
from .constants import Backend, Codebase, Task
|
||||
from .utils import deprecate
|
||||
from .utils import deprecate, get_root_logger
|
||||
|
||||
|
||||
def load_config(*args) -> List[mmengine.Config]:
|
||||
|
@ -62,6 +62,27 @@ def get_task_type(deploy_cfg: Union[str, mmengine.Config]) -> Task:
|
|||
return Task.get(task)
|
||||
|
||||
|
||||
def register_codebase(codebase: str) -> Codebase:
|
||||
"""Register a new codebase which is not included in Codebase.
|
||||
|
||||
Args:
|
||||
codebase (str): The codebase name.
|
||||
|
||||
Returns:
|
||||
Codebase : An enumeration denotes the codebase type.
|
||||
"""
|
||||
from aenum import extend_enum
|
||||
try:
|
||||
Codebase.get(codebase)
|
||||
except Exception as e:
|
||||
logger = get_root_logger()
|
||||
extend_enum(Codebase, codebase.upper(), codebase)
|
||||
logger.warn(f'Failed to get codebase, got: {e}. Then export '
|
||||
f'a new codebase in Codebase {codebase.upper()}: '
|
||||
f'{codebase}')
|
||||
return Codebase.get(codebase)
|
||||
|
||||
|
||||
def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
|
||||
"""Get the codebase from the config.
|
||||
|
||||
|
@ -71,12 +92,11 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
|
|||
Returns:
|
||||
Codebase : An enumeration denotes the codebase type.
|
||||
"""
|
||||
|
||||
codebase_config = get_codebase_config(deploy_cfg)
|
||||
assert 'type' in codebase_config, 'The codebase config of deploy config'\
|
||||
'requires a "type" field'
|
||||
codebase = codebase_config['type']
|
||||
return Codebase.get(codebase)
|
||||
return register_codebase(codebase)
|
||||
|
||||
|
||||
def get_backend_config(deploy_cfg: Union[str, mmengine.Config]) -> Dict:
|
||||
|
@ -417,3 +437,8 @@ def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str:
|
|||
if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']:
|
||||
precision = deploy_cfg['backend_config']['precision']
|
||||
return precision
|
||||
|
||||
|
||||
def get_codebase_external_module(
|
||||
deploy_cfg: Union[str, mmengine.Config]) -> List:
|
||||
return get_codebase_config(deploy_cfg).get('module', [])
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
aenum
|
||||
grpcio
|
||||
h5py
|
||||
matplotlib
|
||||
|
|
|
@ -12,6 +12,7 @@ from mmengine import Config
|
|||
import mmdeploy.utils as util
|
||||
from mmdeploy.backend.sdk.export_info import export2SDK
|
||||
from mmdeploy.utils import target_wrapper
|
||||
from mmdeploy.utils.config_utils import get_codebase_external_module
|
||||
from mmdeploy.utils.constants import Backend, Codebase, Task
|
||||
from mmdeploy.utils.test import get_random_name
|
||||
|
||||
|
@ -104,6 +105,21 @@ class TestGetBackendConfig:
|
|||
assert isinstance(backend_config, dict) and len(backend_config) == 1
|
||||
|
||||
|
||||
class TestGetCodebaseExternalModule:
|
||||
|
||||
def test_get_codebase_external_module_empty(self):
|
||||
assert get_codebase_external_module(Config(dict())) == []
|
||||
|
||||
def test_get_codebase_external_module(self):
|
||||
external_deploy_cfg = dict(
|
||||
onnx_config=dict(),
|
||||
codebase_config=dict(module=['mmyolo.deploy.mmyolo']),
|
||||
backend_config=dict(type='onnxruntime'))
|
||||
custom_module_list = get_codebase_external_module(external_deploy_cfg)
|
||||
assert isinstance(custom_module_list, list) \
|
||||
and len(custom_module_list) == 1
|
||||
|
||||
|
||||
class TestGetBackend:
|
||||
|
||||
def test_get_backend_none(self):
|
||||
|
|
Loading…
Reference in New Issue