[Fix ]Fix import codebase (#1132)
* fix import; * fix flake8 * wrap mmdeploy function * fix wrong import * fix ut * re-add skip import error for ut * fix info * fix lintpull/1192/head
parent
645eefae50
commit
3767b5d46e
mmdeploy/codebase
base
mmcls/deploy
mmdet/deploy
mmdet3d/deploy
mmedit/deploy
mmocr/deploy
mmpose/deploy
mmseg/deploy
|
@ -4,6 +4,11 @@ from typing import List
|
|||
from mmdeploy.utils import Codebase
|
||||
from .base import BaseTask, MMCodebase, get_codebase_class
|
||||
|
||||
extra_dependent_library = {
|
||||
Codebase.MMOCR: ['mmdet'],
|
||||
Codebase.MMROTATE: ['mmdet']
|
||||
}
|
||||
|
||||
|
||||
def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
|
||||
"""Import a codebase package in `mmdeploy.codebase`
|
||||
|
@ -17,11 +22,16 @@ def import_codebase(codebase_type: Codebase, custom_module_list: List = []):
|
|||
codebase (Codebase): The codebase to import.
|
||||
"""
|
||||
import importlib
|
||||
codebase_name = codebase_type.value
|
||||
dependent_library = [codebase_name] + \
|
||||
extra_dependent_library.get(codebase_type, [])
|
||||
for lib in dependent_library + custom_module_list:
|
||||
if not importlib.util.find_spec(lib):
|
||||
raise ImportError(f'{lib} has not been installed. '
|
||||
f'Import {lib} failed.')
|
||||
if len(custom_module_list) > 0:
|
||||
for custom_module in custom_module_list:
|
||||
importlib.import_module(f'{custom_module}')
|
||||
else:
|
||||
importlib.import_module(f'mmdeploy.codebase.{codebase_type.value}')
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
codebase.register_all_modules()
|
||||
|
||||
|
|
|
@ -68,4 +68,13 @@ def get_codebase_class(codebase: Codebase) -> MMCodebase:
|
|||
Returns:
|
||||
type: The codebase class
|
||||
"""
|
||||
import importlib
|
||||
try:
|
||||
importlib.import_module(f'mmdeploy.codebase.{codebase.value}.deploy')
|
||||
except ImportError as e:
|
||||
from mmdeploy.utils import get_root_logger
|
||||
logger = get_root_logger()
|
||||
logger.warn(f'Import mmdeploy.codebase.{codebase.value}.deploy failed'
|
||||
'Please check whether the module is the custom module.'
|
||||
f'{e}')
|
||||
return CODEBASE.build({'type': codebase.value})
|
||||
|
|
|
@ -23,9 +23,15 @@ class MMClassification(MMCodebase):
|
|||
|
||||
task_registry = MMCLS_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmcls.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmcls.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
|
|
|
@ -22,9 +22,17 @@ class MMDetection(MMCodebase):
|
|||
|
||||
task_registry = MMDET_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmdet.models # noqa: F401
|
||||
import mmdeploy.codebase.mmdet.ops
|
||||
import mmdeploy.codebase.mmdet.structures # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
|
|
|
@ -44,9 +44,15 @@ class MMDetection3d(MMCodebase):
|
|||
|
||||
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmdet3d.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet3d.utils.set_env import register_all_modules
|
||||
from mmdet3d.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -13,7 +13,13 @@ class MMEditing(MMCodebase):
|
|||
|
||||
task_registry = MMEDIT_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmedit.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmedit.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
|
|
@ -13,11 +13,19 @@ class MMOCR(MMCodebase):
|
|||
|
||||
task_registry = MMOCR_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmocr.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmdet.utils.setup_env import \
|
||||
register_all_modules as register_all_modules_mmdet
|
||||
from mmocr.utils.setup_env import \
|
||||
register_all_modules as register_all_modules_mmocr
|
||||
|
||||
from mmdeploy.codebase.mmdet.deploy.object_detection import MMDetection
|
||||
cls.register_deploy_modules()
|
||||
MMDetection.register_deploy_modules()
|
||||
register_all_modules_mmocr(True)
|
||||
register_all_modules_mmdet(False)
|
||||
|
|
|
@ -117,9 +117,15 @@ class MMPose(MMCodebase):
|
|||
"""mmpose codebase class."""
|
||||
task_registry = MMPOSE_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmpose.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmpose.utils.setup_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
|
|
|
@ -104,9 +104,15 @@ class MMSegmentation(MMCodebase):
|
|||
"""mmsegmentation codebase class."""
|
||||
task_registry = MMSEG_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
import mmdeploy.codebase.mmseg.models # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
from mmseg.utils.set_env import register_all_modules
|
||||
|
||||
cls.register_deploy_modules()
|
||||
register_all_modules(True)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue