mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
support mmrazor (#1701)
* support mmrazor * add make divisible * update * Pruning -> ModelCompress and add docstring --------- Co-authored-by: liukai <your_email@abc.example>
This commit is contained in:
parent
637958a910
commit
847a906e6f
@ -41,7 +41,7 @@ def build_task_processor(model_cfg: mmengine.Config,
|
||||
BaseTask: A task processor.
|
||||
"""
|
||||
check_backend_device(deploy_cfg=deploy_cfg, device=device)
|
||||
codebase_type = get_codebase(deploy_cfg)
|
||||
codebase_type = get_codebase(deploy_cfg, model_cfg=model_cfg)
|
||||
custom_module_list = get_codebase_external_module(deploy_cfg)
|
||||
import_codebase(codebase_type, custom_module_list)
|
||||
codebase = get_codebase_class(codebase_type)
|
||||
|
4
mmdeploy/codebase/mmrazor/deploy/__init__.py
Normal file
4
mmdeploy/codebase/mmrazor/deploy/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .mmrazor import MMCodebase, MMRazor
|
||||
|
||||
__all__ = ['MMRazor', 'MMCodebase']
|
135
mmdeploy/codebase/mmrazor/deploy/mmrazor.py
Normal file
135
mmdeploy/codebase/mmrazor/deploy/mmrazor.py
Normal file
@ -0,0 +1,135 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.model import BaseDataPreprocessor
|
||||
from mmengine.registry import Registry
|
||||
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
|
||||
from mmdeploy.utils import Codebase, Task
|
||||
|
||||
MMRAZOR_TASK = Registry('mmrazor_tasks')
|
||||
|
||||
|
||||
@CODEBASE.register_module(Codebase.MMRAZOR.value)
|
||||
class MMRazor(MMCodebase):
|
||||
"""MMRazor codebase class."""
|
||||
task_registry = MMRAZOR_TASK
|
||||
|
||||
@classmethod
|
||||
def register_deploy_modules(cls):
|
||||
"""Register all rewriters for mmrazor."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def register_all_modules(cls):
|
||||
"""Register all related modules and rewriters for mmrazor."""
|
||||
from mmrazor.utils import register_all_modules
|
||||
register_all_modules(True)
|
||||
|
||||
@classmethod
|
||||
def build_task_processor(cls, model_cfg: Config, deploy_cfg: Config,
|
||||
device: str):
|
||||
"""Build task processor for mmrazor.
|
||||
|
||||
Now we use ModelCompress by default.
|
||||
"""
|
||||
return ModelCompress(
|
||||
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
|
||||
|
||||
|
||||
@MMRAZOR_TASK.register_module(Task.ModelCompress.value)
|
||||
class ModelCompress(BaseTask):
|
||||
"""General model compress task for mmrazor.
|
||||
|
||||
Args:
|
||||
model_cfg (Config): Original PyTorch model config file
|
||||
deploy_cfg (Config): Deployment config file or loaded Config
|
||||
object.
|
||||
device (str): A string represents device type.
|
||||
experiment_name (str, optional): Name of current experiment.
|
||||
If not specified, timestamp will be used as
|
||||
``experiment_name``. Defaults to ``None``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_cfg: Config,
|
||||
deploy_cfg: Config,
|
||||
device: str,
|
||||
experiment_name: str = 'BaseTask'):
|
||||
|
||||
super().__init__(model_cfg, deploy_cfg, device, experiment_name)
|
||||
self.origin_model_cfg = self.revert_model_cfg(model_cfg)
|
||||
self.base_task = build_task_processor(self.origin_model_cfg,
|
||||
deploy_cfg, device)
|
||||
|
||||
def revert_model_cfg(self, model_cfg: Config):
|
||||
"""Restore the original model config from the model config of the
|
||||
compressed model."""
|
||||
origin_model_cfg = copy.deepcopy(model_cfg)
|
||||
model = model_cfg['model']
|
||||
if 'architecture' in model:
|
||||
origin_model = model['architecture']
|
||||
elif 'algorithm' in model:
|
||||
origin_model = model['algorithm']['architecture']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
origin_model_cfg['model'] = origin_model
|
||||
if 'data_preprocessor' in origin_model:
|
||||
origin_model_cfg['data_preprocessor'] = origin_model[
|
||||
'data_preprocessor']
|
||||
return origin_model_cfg
|
||||
|
||||
# abstract method
|
||||
|
||||
def build_backend_model(self,
|
||||
model_files=None,
|
||||
data_preprocessor_updater=None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Build backend model for using base task."""
|
||||
return self.base_task.build_backend_model(model_files,
|
||||
data_preprocessor_updater,
|
||||
**kwargs)
|
||||
|
||||
def create_input(self,
|
||||
imgs: Union[str, np.ndarray],
|
||||
input_shape=None,
|
||||
data_preprocessor: Optional[BaseDataPreprocessor] = None,
|
||||
**kwargs) -> Tuple[Dict, torch.Tensor]:
|
||||
"""Create input using base task."""
|
||||
return self.base_task.create_input(imgs, input_shape,
|
||||
data_preprocessor, **kwargs)
|
||||
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
"""Get model name using base task."""
|
||||
return self.base_task.get_model_name(*args, **kwargs)
|
||||
|
||||
def get_preprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get data preprocess name using base task."""
|
||||
return self.base_task.get_preprocess(*args, **kwargs)
|
||||
|
||||
def get_postprocess(self, *args, **kwargs) -> Dict:
|
||||
"""Get data poseprocess name using base task."""
|
||||
return self.base_task.get_postprocess(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
|
||||
"""Get a certain partition config."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Build PyTorch model for mmrazor and execute post process for
|
||||
mmdeploy."""
|
||||
model = super().build_pytorch_model(model_checkpoint, cfg_options,
|
||||
**kwargs)
|
||||
if hasattr(model, 'post_process_for_mmdeploy'):
|
||||
model.post_process_for_mmdeploy()
|
||||
|
||||
return model
|
@ -83,7 +83,8 @@ def register_codebase(codebase: str) -> Codebase:
|
||||
return Codebase.get(codebase)
|
||||
|
||||
|
||||
def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
|
||||
def get_codebase(deploy_cfg: Union[str, mmengine.Config],
|
||||
model_cfg=None) -> Codebase:
|
||||
"""Get the codebase from the config.
|
||||
|
||||
Args:
|
||||
@ -92,6 +93,12 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
|
||||
Returns:
|
||||
Codebase : An enumeration denotes the codebase type.
|
||||
"""
|
||||
if model_cfg is not None:
|
||||
# using mmrazor codebase if the model is a mmrazor model.
|
||||
model_cfg: dict = model_cfg['model']
|
||||
if model_cfg.get('_scope_', None) == 'mmrazor'\
|
||||
or model_cfg['type'].startswith('mmrazor.'):
|
||||
return register_codebase('mmrazor')
|
||||
codebase_config = get_codebase_config(deploy_cfg)
|
||||
assert 'type' in codebase_config, 'The codebase config of deploy config'\
|
||||
'requires a "type" field'
|
||||
|
@ -28,6 +28,7 @@ class Task(AdvancedEnum):
|
||||
POSE_DETECTION = 'PoseDetection'
|
||||
ROTATED_DETECTION = 'RotatedDetection'
|
||||
VIDEO_RECOGNITION = 'VideoRecognition'
|
||||
ModelCompress = 'ModelCompress'
|
||||
|
||||
|
||||
class Codebase(AdvancedEnum):
|
||||
@ -41,6 +42,7 @@ class Codebase(AdvancedEnum):
|
||||
MMPOSE = 'mmpose'
|
||||
MMROTATE = 'mmrotate'
|
||||
MMACTION = 'mmaction'
|
||||
MMRAZOR = 'mmrazor'
|
||||
|
||||
|
||||
class IR(AdvancedEnum):
|
||||
|
Loading…
x
Reference in New Issue
Block a user