[Fix ]Fix import codebase ()

* fix import;

* fix flake8

* wrap mmdeploy function

* fix wrong import

* fix ut

* re-add skip import error for ut

* fix info

* fix lint
pull/1192/head
hanrui1sensetime 2022-10-13 10:24:42 +08:00 committed by GitHub
parent 645eefae50
commit 3767b5d46e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 68 additions and 3 deletions
mmdeploy/codebase
mmcls/deploy
mmdet3d/deploy
mmedit/deploy
mmocr/deploy
mmpose/deploy
mmseg/deploy

View File

@ -4,6 +4,11 @@ from typing import List
from mmdeploy.utils import Codebase
from .base import BaseTask, MMCodebase, get_codebase_class
extra_dependent_library = {
Codebase.MMOCR: ['mmdet'],
Codebase.MMROTATE: ['mmdet']
}
def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
"""Import a codebase package in `mmdeploy.codebase`
@ -17,11 +22,16 @@ def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
codebase (Codebase): The codebase to import.
"""
import importlib
codebase_name = codebase_type.value
dependent_library = [codebase_name] + \
extra_dependent_library.get(codebase_type, [])
for lib in dependent_library + custom_module_list:
if not importlib.util.find_spec(lib):
raise ImportError(f'{lib} has not been installed. '
f'Import {lib} failed.')
if len(custom_module_list) > 0:
for custom_module in custom_module_list:
importlib.import_module(f'{custom_module}')
else:
importlib.import_module(f'mmdeploy.codebase.{codebase_type.value}')
codebase = get_codebase_class(codebase_type)
codebase.register_all_modules()

View File

@ -68,4 +68,13 @@ def get_codebase_class(codebase: Codebase) -> MMCodebase:
Returns:
type: The codebase class
"""
import importlib
try:
importlib.import_module(f'mmdeploy.codebase.{codebase.value}.deploy')
except ImportError as e:
from mmdeploy.utils import get_root_logger
logger = get_root_logger()
logger.warn(f'Import mmdeploy.codebase.{codebase.value}.deploy failed'
'Please check whether the module is the custom module.'
f'{e}')
return CODEBASE.build({'type': codebase.value})

View File

@ -23,9 +23,15 @@ class MMClassification(MMCodebase):
task_registry = MMCLS_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmcls.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmcls.utils.setup_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)

View File

@ -22,9 +22,17 @@ class MMDetection(MMCodebase):
task_registry = MMDET_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmdet.models # noqa: F401
import mmdeploy.codebase.mmdet.ops
import mmdeploy.codebase.mmdet.structures # noqa: F401
@classmethod
def register_all_modules(cls):
from mmdet.utils.setup_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)

View File

@ -44,9 +44,15 @@ class MMDetection3d(MMCodebase):
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmdet3d.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmdet3d.utils.set_env import register_all_modules
from mmdet3d.utils.setup_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)
@staticmethod

View File

@ -13,7 +13,13 @@ class MMEditing(MMCodebase):
task_registry = MMEDIT_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmedit.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmedit.utils.setup_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)

View File

@ -13,11 +13,19 @@ class MMOCR(MMCodebase):
task_registry = MMOCR_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmocr.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmdet.utils.setup_env import \
register_all_modules as register_all_modules_mmdet
from mmocr.utils.setup_env import \
register_all_modules as register_all_modules_mmocr
from mmdeploy.codebase.mmdet.deploy.object_detection import MMDetection
cls.register_deploy_modules()
MMDetection.register_deploy_modules()
register_all_modules_mmocr(True)
register_all_modules_mmdet(False)

View File

@ -117,9 +117,15 @@ class MMPose(MMCodebase):
"""mmpose codebase class."""
task_registry = MMPOSE_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmpose.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmpose.utils.setup_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)

View File

@ -104,9 +104,15 @@ class MMSegmentation(MMCodebase):
"""mmsegmentation codebase class."""
task_registry = MMSEG_TASK
@classmethod
def register_deploy_modules(cls):
import mmdeploy.codebase.mmseg.models # noqa: F401
@classmethod
def register_all_modules(cls):
from mmseg.utils.set_env import register_all_modules
cls.register_deploy_modules()
register_all_modules(True)