[Enhancement 2.0] mmdeploy for mmyolo ()

* support for external codebase like mmyolo

* support for external export

* fix missing flake8

* fix comments

* add aenum

* add missing files

* fix condition

* refactor import_codebase

* fix mmyolo support

* fix lint

* add base codebase

* fix a strange clang-format

* fix import_codebase

* fix dependent codebase register

* wrap custom_model

* fix comment

* add ut
pull/1091/head
hanrui1sensetime 2022-09-28 16:30:29 +08:00 committed by RunningLeon
parent b8c19b35d2
commit 7f70d7fe56
15 changed files with 109 additions and 31 deletions
csrc/mmdeploy/codebase/mmocr
mmdeploy
apis/utils
codebase
mmcls/deploy
mmdet3d/deploy
mmedit/deploy
mmocr/deploy
mmpose/deploy
mmseg/deploy
requirements
tests/test_utils

View File

@ -78,9 +78,8 @@ class ShortScaleAspectJitterImpl : public Module {
auto dst_width = static_cast<int>(std::round(scale * img_shape[2]));
dst_height = static_cast<int>(std::ceil(1.0 * dst_height / scale_divisor) * scale_divisor);
dst_width = static_cast<int>(std::ceil(1.0 * dst_width / scale_divisor) * scale_divisor);
std::vector<float> scale_factor = {(float)(1.0 * dst_width / img_shape[2]),
(float)(1.0 * dst_height / img_shape[1])};
std::vector<float> scale_factor = {(float)1.0 * dst_width / img_shape[2],
(float)1.0 * dst_height / img_shape[1]};
img_resize = ResizeImage(img, dst_height, dst_width);
Value output = input;

View File

@ -4,6 +4,7 @@ import mmengine
from mmdeploy.codebase import BaseTask, get_codebase_class, import_codebase
from mmdeploy.utils import (get_backend, get_codebase, get_task_type,
parse_device_id)
from mmdeploy.utils.config_utils import get_codebase_external_module
def check_backend_device(deploy_cfg: mmengine.Config, device: str):
@ -37,7 +38,8 @@ def build_task_processor(model_cfg: mmengine.Config,
"""
check_backend_device(deploy_cfg=deploy_cfg, device=device)
codebase_type = get_codebase(deploy_cfg)
import_codebase(codebase_type)
custom_module_list = get_codebase_external_module(deploy_cfg)
import_codebase(codebase_type, custom_module_list)
codebase = get_codebase_class(codebase_type)
return codebase.build_task_processor(model_cfg, deploy_cfg, device)
@ -58,7 +60,8 @@ def get_predefined_partition_cfg(deploy_cfg: mmengine.Config,
dict: A dictionary of partition config.
"""
codebase_type = get_codebase(deploy_cfg)
import_codebase(codebase_type)
custom_module_list = get_codebase_external_module(deploy_cfg)
import_codebase(codebase_type, custom_module_list)
task = get_task_type(deploy_cfg)
codebase = get_codebase_class(codebase_type)
task_processor_class = codebase.get_task_class(task)

View File

@ -1,16 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import importlib
from typing import List
from mmdeploy.utils import Codebase
from .base import BaseTask, MMCodebase, get_codebase_class
extra_dependent_library = {
Codebase.MMOCR: ['mmdet'],
Codebase.MMROTATE: ['mmdet']
}
def import_codebase(codebase: Codebase):
def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
"""Import a codebase package in `mmdeploy.codebase`
The function will check if all dependent libraries are installed.
@ -21,21 +16,14 @@ def import_codebase(codebase: Codebase):
Args:
codebase (Codebase): The codebase to import.
"""
codebase_name = codebase.value
dependent_library = [codebase_name] + \
extra_dependent_library.get(codebase, [])
for lib in dependent_library:
if not importlib.util.find_spec(lib):
raise ImportError(
f'{lib} has not been installed. '
f'Import mmdeploy.codebase.{codebase_name} failed.')
importlib.import_module(f'mmdeploy.codebase.{lib}')
importlib.import_module(f'{lib}.models')
importlib.import_module(f'{lib}.datasets')
importlib.import_module(f'{lib}.structures')
importlib.import_module(f'{lib}.visualization')
importlib.import_module(f'{lib}.engine')
import importlib
if len(custom_module_list) > 0:
for custom_module in custom_module_list:
importlib.import_module(f'{custom_module}')
else:
importlib.import_module(f'mmdeploy.codebase.{codebase_type.value}')
codebase = get_codebase_class(codebase_type)
codebase.register_all_modules()
__all__ = ['MMCodebase', 'BaseTask', 'get_codebase_class']

View File

@ -49,6 +49,10 @@ class MMCodebase(metaclass=ABCMeta):
deploy_cfg=deploy_cfg,
device=device))
@classmethod
def register_all_modules(cls):
pass
# Note that the build function returns the class instead of its instance.

View File

@ -13,6 +13,7 @@ from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import (get_backend_config, get_codebase,
get_codebase_config, get_root_logger)
from mmdeploy.utils.config_utils import get_codebase_external_module
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
@ -40,7 +41,8 @@ class BaseTask(metaclass=ABCMeta):
# init scope
from .. import import_codebase
import_codebase(self.codebase)
custom_module_list = get_codebase_external_module(deploy_cfg)
import_codebase(self.codebase, custom_module_list)
from mmengine.registry import DefaultScope
if not DefaultScope.check_instance_created(self.experiment_name):

View File

@ -23,6 +23,11 @@ class MMClassification(MMCodebase):
task_registry = MMCLS_TASK
@classmethod
def register_all_modules(cls):
from mmcls.utils.setup_env import register_all_modules
register_all_modules(True)
def process_model_config(model_cfg: Config,
imgs: Union[str, np.ndarray],

View File

@ -22,6 +22,11 @@ class MMDetection(MMCodebase):
task_registry = MMDET_TASK
@classmethod
def register_all_modules(cls):
from mmdet.utils.setup_env import register_all_modules
register_all_modules(True)
def process_model_config(model_cfg: Config,
imgs: Union[Sequence[str], Sequence[np.ndarray]],

View File

@ -41,8 +41,14 @@ class MMDetection3d(MMCodebase):
Returns:
BaseTask: A task processor.
"""
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)
@classmethod
def register_all_modules(cls):
from mmdet3d.utils.set_env import register_all_modules
register_all_modules(True)
@staticmethod
def build_dataset(dataset_cfg: Union[str, mmengine.Config], *args,
**kwargs) -> Dataset:

View File

@ -12,3 +12,8 @@ class MMEditing(MMCodebase):
"""mmediting codebase class."""
task_registry = MMEDIT_TASK
@classmethod
def register_all_modules(cls):
from mmedit.utils.setup_env import register_all_modules
register_all_modules(True)

View File

@ -12,3 +12,12 @@ class MMOCR(MMCodebase):
"""MMOCR codebase class."""
task_registry = MMOCR_TASK
@classmethod
def register_all_modules(cls):
from mmdet.utils.setup_env import \
register_all_modules as register_all_modules_mmdet
from mmocr.utils.setup_env import \
register_all_modules as register_all_modules_mmocr
register_all_modules_mmocr(False)
register_all_modules_mmdet(True)

View File

@ -117,6 +117,11 @@ class MMPose(MMCodebase):
"""mmpose codebase class."""
task_registry = MMPOSE_TASK
@classmethod
def register_all_modules(cls):
from mmpose.utils.setup_env import register_all_modules
register_all_modules(True)
@MMPOSE_TASK.register_module(Task.POSE_DETECTION.value)
class PoseDetection(BaseTask):

View File

@ -104,6 +104,11 @@ class MMSegmentation(MMCodebase):
"""mmsegmentation codebase class."""
task_registry = MMSEG_TASK
@classmethod
def register_all_modules(cls):
from mmseg.utils.set_env import register_all_modules
register_all_modules(True)
@MMSEG_TASK.register_module(Task.SEGMENTATION.value)
class Segmentation(BaseTask):

View File

@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
import mmengine
from .constants import Backend, Codebase, Task
from .utils import deprecate
from .utils import deprecate, get_root_logger
def load_config(*args) -> List[mmengine.Config]:
@ -62,6 +62,27 @@ def get_task_type(deploy_cfg: Union[str, mmengine.Config]) -> Task:
return Task.get(task)
def register_codebase(codebase: str) -> Codebase:
"""Register a new codebase which is not included in Codebase.
Args:
codebase (str): The codebase name.
Returns:
Codebase : An enumeration denotes the codebase type.
"""
from aenum import extend_enum
try:
Codebase.get(codebase)
except Exception as e:
logger = get_root_logger()
extend_enum(Codebase, codebase.upper(), codebase)
logger.warn(f'Failed to get codebase, got: {e}. Then export '
f'a new codebase in Codebase {codebase.upper()}: '
f'{codebase}')
return Codebase.get(codebase)
def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
"""Get the codebase from the config.
@ -71,12 +92,11 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
Returns:
Codebase : An enumeration denotes the codebase type.
"""
codebase_config = get_codebase_config(deploy_cfg)
assert 'type' in codebase_config, 'The codebase config of deploy config'\
'requires a "type" field'
codebase = codebase_config['type']
return Codebase.get(codebase)
return register_codebase(codebase)
def get_backend_config(deploy_cfg: Union[str, mmengine.Config]) -> Dict:
@ -417,3 +437,8 @@ def get_precision(deploy_cfg: Union[str, mmengine.Config]) -> str:
if backend == Backend.NCNN and 'precision' in deploy_cfg['backend_config']:
precision = deploy_cfg['backend_config']['precision']
return precision
def get_codebase_external_module(
deploy_cfg: Union[str, mmengine.Config]) -> List:
return get_codebase_config(deploy_cfg).get('module', [])

View File

@ -1,3 +1,4 @@
aenum
grpcio
h5py
matplotlib

View File

@ -12,6 +12,7 @@ from mmengine import Config
import mmdeploy.utils as util
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import target_wrapper
from mmdeploy.utils.config_utils import get_codebase_external_module
from mmdeploy.utils.constants import Backend, Codebase, Task
from mmdeploy.utils.test import get_random_name
@ -104,6 +105,21 @@ class TestGetBackendConfig:
assert isinstance(backend_config, dict) and len(backend_config) == 1
class TestGetCodebaseExternalModule:
def test_get_codebase_external_module_empty(self):
assert get_codebase_external_module(Config(dict())) == []
def test_get_codebase_external_module(self):
external_deploy_cfg = dict(
onnx_config=dict(),
codebase_config=dict(module=['mmyolo.deploy.mmyolo']),
backend_config=dict(type='onnxruntime'))
custom_module_list = get_codebase_external_module(external_deploy_cfg)
assert isinstance(custom_module_list, list) \
and len(custom_module_list) == 1
class TestGetBackend:
def test_get_backend_none(self):