support mmrazor (#1701)

* support mmrazor

* add make divisible

* update

* Pruning -> ModelCompress and add docstring

---------

Co-authored-by: liukai <your_email@abc.example>
This commit is contained in:
LKJacky 2023-03-17 15:46:09 +08:00 committed by GitHub
parent 637958a910
commit 847a906e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 150 additions and 2 deletions

View File

@ -41,7 +41,7 @@ def build_task_processor(model_cfg: mmengine.Config,
BaseTask: A task processor.
"""
check_backend_device(deploy_cfg=deploy_cfg, device=device)
codebase_type = get_codebase(deploy_cfg)
codebase_type = get_codebase(deploy_cfg, model_cfg=model_cfg)
custom_module_list = get_codebase_external_module(deploy_cfg)
import_codebase(codebase_type, custom_module_list)
codebase = get_codebase_class(codebase_type)

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmrazor import MMCodebase, MMRazor
__all__ = ['MMRazor', 'MMCodebase']

View File

@ -0,0 +1,135 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from mmengine.registry import Registry
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task
MMRAZOR_TASK = Registry('mmrazor_tasks')
@CODEBASE.register_module(Codebase.MMRAZOR.value)
class MMRazor(MMCodebase):
"""MMRazor codebase class."""
task_registry = MMRAZOR_TASK
@classmethod
def register_deploy_modules(cls):
"""Register all rewriters for mmrazor."""
pass
@classmethod
def register_all_modules(cls):
"""Register all related modules and rewriters for mmrazor."""
from mmrazor.utils import register_all_modules
register_all_modules(True)
@classmethod
def build_task_processor(cls, model_cfg: Config, deploy_cfg: Config,
device: str):
"""Build task processor for mmrazor.
Now we use ModelCompress by default.
"""
return ModelCompress(
model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device)
@MMRAZOR_TASK.register_module(Task.ModelCompress.value)
class ModelCompress(BaseTask):
"""General model compress task for mmrazor.
Args:
model_cfg (Config): Original PyTorch model config file
deploy_cfg (Config): Deployment config file or loaded Config
object.
device (str): A string represents device type.
experiment_name (str, optional): Name of current experiment.
If not specified, timestamp will be used as
``experiment_name``. Defaults to ``None``.
"""
def __init__(self,
model_cfg: Config,
deploy_cfg: Config,
device: str,
experiment_name: str = 'BaseTask'):
super().__init__(model_cfg, deploy_cfg, device, experiment_name)
self.origin_model_cfg = self.revert_model_cfg(model_cfg)
self.base_task = build_task_processor(self.origin_model_cfg,
deploy_cfg, device)
def revert_model_cfg(self, model_cfg: Config):
"""Restore the original model config from the model config of the
compressed model."""
origin_model_cfg = copy.deepcopy(model_cfg)
model = model_cfg['model']
if 'architecture' in model:
origin_model = model['architecture']
elif 'algorithm' in model:
origin_model = model['algorithm']['architecture']
else:
raise NotImplementedError()
origin_model_cfg['model'] = origin_model
if 'data_preprocessor' in origin_model:
origin_model_cfg['data_preprocessor'] = origin_model[
'data_preprocessor']
return origin_model_cfg
# abstract method
def build_backend_model(self,
model_files=None,
data_preprocessor_updater=None,
**kwargs) -> torch.nn.Module:
"""Build backend model for using base task."""
return self.base_task.build_backend_model(model_files,
data_preprocessor_updater,
**kwargs)
def create_input(self,
imgs: Union[str, np.ndarray],
input_shape=None,
data_preprocessor: Optional[BaseDataPreprocessor] = None,
**kwargs) -> Tuple[Dict, torch.Tensor]:
"""Create input using base task."""
return self.base_task.create_input(imgs, input_shape,
data_preprocessor, **kwargs)
def get_model_name(self, *args, **kwargs) -> str:
"""Get model name using base task."""
return self.base_task.get_model_name(*args, **kwargs)
def get_preprocess(self, *args, **kwargs) -> Dict:
"""Get data preprocess name using base task."""
return self.base_task.get_preprocess(*args, **kwargs)
def get_postprocess(self, *args, **kwargs) -> Dict:
"""Get data poseprocess name using base task."""
return self.base_task.get_postprocess(*args, **kwargs)
@staticmethod
def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
"""Get a certain partition config."""
raise NotImplementedError()
def build_pytorch_model(self,
model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module:
"""Build PyTorch model for mmrazor and execute post process for
mmdeploy."""
model = super().build_pytorch_model(model_checkpoint, cfg_options,
**kwargs)
if hasattr(model, 'post_process_for_mmdeploy'):
model.post_process_for_mmdeploy()
return model

View File

@ -83,7 +83,8 @@ def register_codebase(codebase: str) -> Codebase:
return Codebase.get(codebase)
def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
def get_codebase(deploy_cfg: Union[str, mmengine.Config],
model_cfg=None) -> Codebase:
"""Get the codebase from the config.
Args:
@ -92,6 +93,12 @@ def get_codebase(deploy_cfg: Union[str, mmengine.Config]) -> Codebase:
Returns:
Codebase : An enumeration denotes the codebase type.
"""
if model_cfg is not None:
# using mmrazor codebase if the model is a mmrazor model.
model_cfg: dict = model_cfg['model']
if model_cfg.get('_scope_', None) == 'mmrazor'\
or model_cfg['type'].startswith('mmrazor.'):
return register_codebase('mmrazor')
codebase_config = get_codebase_config(deploy_cfg)
assert 'type' in codebase_config, 'The codebase config of deploy config'\
'requires a "type" field'

View File

@ -28,6 +28,7 @@ class Task(AdvancedEnum):
POSE_DETECTION = 'PoseDetection'
ROTATED_DETECTION = 'RotatedDetection'
VIDEO_RECOGNITION = 'VideoRecognition'
ModelCompress = 'ModelCompress'
class Codebase(AdvancedEnum):
@ -41,6 +42,7 @@ class Codebase(AdvancedEnum):
MMPOSE = 'mmpose'
MMROTATE = 'mmrotate'
MMACTION = 'mmaction'
MMRAZOR = 'mmrazor'
class IR(AdvancedEnum):