hanrui1sensetime d7cbb395da
[Enhancement 2.0] mmdeploy for mmyolo (#1088)
* support for external codebase like mmyolo

* support for external export

* fix missing flake8

* fix comments

* add aenum

* add missing files

* fix condition

* refactor import_codebase

* fix mmyolo support

* fix lint

* add base codebase

* fix a strange clang-format

* fix import_codebase

* fix dependent codebase register

* wrap custom_model

* fix comment

* add ut
2022-09-28 16:30:29 +08:00

375 lines
13 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from torch.utils.data import DataLoader, Dataset
from mmdeploy.utils import (get_backend_config, get_codebase,
get_codebase_config, get_root_logger)
from mmdeploy.utils.config_utils import get_codebase_external_module
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset
class BaseTask(metaclass=ABCMeta):
    """Wrap the processing functions of a Computer Vision task.

    Concrete codebases subclass this and implement the abstract hooks
    (``build_backend_model``, ``create_input``, ``get_partition_cfg``,
    ``get_preprocess``, ``get_postprocess`` and ``get_model_name``).

    Args:
        model_cfg (Config): Model config. It is queried with ``.get`` and
            attribute access, so it must already be a loaded config object.
        deploy_cfg (Config): Deployment config.
        device (str): A string specifying device type.
        experiment_name (str): Name under which the ``DefaultScope``
            instance is created or looked up. Defaults to ``'BaseTask'``.
    """

    def __init__(self,
                 model_cfg: Config,
                 deploy_cfg: Config,
                 device: str,
                 experiment_name: str = 'BaseTask'):
        self.model_cfg = model_cfg
        self.deploy_cfg = deploy_cfg
        self.device = device
        self.codebase = get_codebase(deploy_cfg)
        self.experiment_name = experiment_name

        # init scope: import the codebase (plus any external modules listed
        # in the deploy config) before touching the registries.
        from .. import import_codebase
        custom_module_list = get_codebase_external_module(deploy_cfg)
        import_codebase(self.codebase, custom_module_list)

        from mmengine.registry import DefaultScope
        if not DefaultScope.check_instance_created(self.experiment_name):
            self.scope = DefaultScope.get_instance(
                self.experiment_name,
                scope_name=self.model_cfg.get('default_scope'))
        else:
            # Reuse the scope that was already registered under this name.
            self.scope = DefaultScope.get_instance(self.experiment_name)

        # lazy build visualizer: only the config is stored here; the actual
        # object is created on demand by ``get_visualizer``.
        self.visualizer = self.model_cfg.get('visualizer', None)

    @abstractmethod
    def build_backend_model(self,
                            model_files: Optional[Sequence[str]] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        pass

    def build_data_preprocessor(self):
        """Build the data preprocessor described by the model config.

        Returns:
            The object produced by ``MODELS.build`` from a copy of
            ``model_cfg.model['data_preprocessor']``.
        """
        from mmengine.registry import MODELS

        # Copy only the sub-config that is needed so the stored model config
        # is never handed to the registry directly.
        preprocess_cfg = deepcopy(self.model_cfg.model['data_preprocessor'])
        return MODELS.build(preprocess_cfg)

    def build_pytorch_model(self,
                            model_checkpoint: Optional[str] = None,
                            cfg_options: Optional[Dict] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize torch model.

        Args:
            model_checkpoint (str): The checkpoint file of torch model,
                defaults to `None`.
            cfg_options (dict): Optional config key-pair parameters. Kept in
                the signature for callers; this implementation does not
                read it.

        Returns:
            nn.Module: An initialized torch model generated by other OpenMMLab
                codebases.
        """
        from mmengine.model import revert_sync_batchnorm
        from mmengine.registry import MODELS

        model = deepcopy(self.model_cfg.model)
        # Merge the top-level `preprocess_cfg` and `data_preprocessor`
        # entries; the merged result is only used when the model config does
        # not already carry its own `data_preprocessor`.
        preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
        preprocess_cfg.update(
            deepcopy(self.model_cfg.get('data_preprocessor', {})))
        model.setdefault('data_preprocessor', preprocess_cfg)
        model = MODELS.build(model)

        if model_checkpoint is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(model, model_checkpoint)

        model = revert_sync_batchnorm(model)
        model = model.to(self.device)
        model.eval()
        return model

    def build_dataset(self,
                      dataset_cfg: Union[str, Config],
                      is_sort_dataset: bool = True,
                      **kwargs) -> Dataset:
        """Build dataset for different codebase.

        Args:
            dataset_cfg (str | Config): Dataset config file or Config
                object.
            is_sort_dataset (bool): When 'True', the dataset will be sorted
                by image shape in ascending order if 'dataset_cfg'
                contains information about height and width.

        Returns:
            Dataset: The built dataset.
        """
        backend_cfg = get_backend_config(self.deploy_cfg)

        from mmdeploy.utils import load_config
        dataset_cfg = load_config(dataset_cfg)[0]
        # A pipeline given in the backend config overrides the dataset's own.
        if 'pipeline' in backend_cfg:
            dataset_cfg.pipeline = backend_cfg.pipeline

        from mmengine.registry import DATASETS
        dataset = DATASETS.build(dataset_cfg)

        logger = get_root_logger()
        if is_sort_dataset:
            if is_can_sort_dataset(dataset):
                sort_dataset(dataset)
            else:
                logger.info('Sorting the dataset by \'height\' and \'width\' '
                            'is not possible.')
        return dataset

    @staticmethod
    def build_dataloader(dataloader: Union[DataLoader, Dict],
                         seed: Optional[int] = None) -> DataLoader:
        """Build PyTorch dataloader. A wrap of Runner.build_dataloader.

        Args:
            dataloader (DataLoader or dict): A Dataloader object or a dict to
                build Dataloader object. If ``dataloader`` is a Dataloader
                object, just returns itself.
            seed (int, optional): Random seed. Defaults to None.

        Returns:
            Dataloader: DataLoader build from ``dataloader_cfg``.
        """
        from mmengine.runner import Runner
        return Runner.build_dataloader(dataloader, seed)

    def build_test_runner(self,
                          model: torch.nn.Module,
                          work_dir: str,
                          log_file: Optional[str] = None,
                          show: bool = False,
                          show_dir: Optional[str] = None,
                          wait_time: int = 0,
                          interval: int = 1,
                          dataloader: Optional[Union[DataLoader,
                                                     Dict]] = None):
        """Build a ``DeployTestRunner`` for evaluating ``model``.

        Args:
            model (nn.Module): The model handed to the runner.
            work_dir (str): Working directory of the runner. It is also used
                as both the name and the save directory of the visualizer.
            log_file (str, optional): Log file passed to the runner.
                Defaults to None.
            show (bool): Value written to the ``show`` field of the
                visualization hook. Defaults to False.
            show_dir (str, optional): Value written to the ``out_dir`` field
                of the visualization hook. Defaults to None.
            wait_time (int): Value written to the ``wait_time`` field of the
                visualization hook. Defaults to 0.
            interval (int): Value written to the ``interval`` field of the
                visualization hook. Defaults to 1.
            dataloader (DataLoader | dict | list, optional): Test dataloader,
                a config to build one, or a list of such configs. When None,
                ``model_cfg.test_dataloader`` is used.

        Returns:
            DeployTestRunner: The constructed runner.

        Raises:
            AssertionError: If visualization is requested (``show`` is true
                or ``show_dir`` is given) but ``default_hooks`` has no
                ``visualization`` entry.
        """

        def _merge_cfg(cfg):
            """Merge CLI arguments to config."""
            # -------------------- visualization --------------------
            if show or (show_dir is not None):
                assert 'visualization' in cfg.default_hooks, \
                    'VisualizationHook is not set in the `default_hooks`'\
                    ' field of config. Please set '\
                    '`visualization=dict(type="VisualizationHook")`'
                cfg.default_hooks.visualization.enable = True
                cfg.default_hooks.visualization.show = show
                cfg.default_hooks.visualization.wait_time = wait_time
                cfg.default_hooks.visualization.out_dir = show_dir
                cfg.default_hooks.visualization.interval = interval
            return cfg

        # Work on a copy so the CLI overrides never leak into self.model_cfg.
        model_cfg = deepcopy(self.model_cfg)
        if dataloader is None:
            dataloader = model_cfg.test_dataloader
        if not isinstance(dataloader, DataLoader):
            # isinstance (not an exact type comparison) so that list
            # subclasses are also treated as "several dataloader configs".
            if isinstance(dataloader, list):
                dataloader = [self.build_dataloader(dl) for dl in dataloader]
            else:
                dataloader = self.build_dataloader(dataloader)

        model_cfg = _merge_cfg(model_cfg)
        visualizer = self.get_visualizer(work_dir, work_dir)

        from .runner import DeployTestRunner
        runner = DeployTestRunner(
            model=model,
            work_dir=work_dir,
            log_file=log_file,
            device=self.device,
            visualizer=visualizer,
            default_hooks=model_cfg.default_hooks,
            test_dataloader=dataloader,
            test_cfg=model_cfg.test_cfg,
            test_evaluator=model_cfg.test_evaluator,
            default_scope=model_cfg.default_scope)
        return runner

    @abstractmethod
    def create_input(
        self,
        imgs: Union[str, np.ndarray],
        input_shape: Optional[Sequence[int]] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for model.

        Args:
            imgs (str | np.ndarray): Input image(s), accepted data types are
                `str`, `np.ndarray`.
            input_shape (list[int]): Input shape of image in (width, height)
                format, defaults to `None`.
            data_preprocessor (BaseDataPreprocessor, optional): Data
                preprocessor made available to the implementation,
                defaults to `None`.

        Returns:
            tuple: (data, img), meta information for the input image and input
                image tensor.
        """
        pass

    def get_visualizer(self, name: str, save_dir: str):
        """Get the visualizer instance.

        The visualizer config stored at construction time is copied, given
        ``name`` and ``save_dir``, and resolved against the registry of the
        current default scope (falling back to ``VISUALIZERS``). An instance
        already created under ``name`` is reused instead of being rebuilt.

        NOTE: this relies on ``model_cfg`` defining a ``visualizer`` entry;
        without one the stored config is ``None`` and the attribute
        assignment below fails.

        Args:
            name (str): The name of the visualizer.
            save_dir (str): The save directory of visualizer.

        Returns:
            The visualizer instance registered under ``name``.
        """
        cfg = deepcopy(self.visualizer)
        cfg.name = name
        cfg.save_dir = save_dir

        from mmengine.registry import VISUALIZERS, DefaultScope
        with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
            # get the global default scope
            default_scope = DefaultScope.get_current_instance()
            if default_scope is not None:
                scope_name = default_scope.scope_name
                root = VISUALIZERS._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    # No child registry for this scope: use the base one.
                    registry = VISUALIZERS
            else:
                registry = VISUALIZERS

            VisualizerClass = registry.get(cfg.type)
            if VisualizerClass.check_instance_created(cfg.name):
                return VisualizerClass.get_instance(cfg.name)
            else:
                return registry.build_func(cfg, registry=registry)

    def visualize(self,
                  image: Union[str, np.ndarray],
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False,
                  draw_gt: bool = False,
                  **kwargs):
        """Visualize predictions of a model.

        Args:
            image (str | np.ndarray): Input image to draw predictions on.
                A string is treated as a path and read in RGB order.
            result (list): A list of predictions.
            output_file (str): Output file to save drawn image.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
            draw_gt (bool): Whether to show ground truth in windows, defaults
                to `False`.
        """
        save_dir, save_name = osp.split(output_file)
        visualizer = self.get_visualizer(window_name, save_dir)

        # The sample is registered under the output file's stem.
        name = osp.splitext(save_name)[0]
        if isinstance(image, str):
            image = mmcv.imread(image, channel_order='rgb')
        assert isinstance(image, np.ndarray)

        visualizer.add_datasample(
            name,
            image,
            data_sample=result,
            draw_gt=draw_gt,
            show=show_result,
            out_file=output_file)

    @staticmethod
    @abstractmethod
    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
        """Get a certain partition config.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
        pass

    @staticmethod
    def get_tensor_from_input(input_data: Dict[str, Any],
                              **kwargs) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info and image
                tensor. The tensor is read from the ``'inputs'`` key.

        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        return input_data['inputs']

    @abstractmethod
    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            dict: Composed of the preprocess information.
        """
        pass

    @abstractmethod
    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Composed of the postprocess information.
        """
        pass

    @abstractmethod
    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Return:
            str: the name of the model.
        """
        pass

    @property
    def from_mmrazor(self) -> bool:
        """Whether the codebase from mmrazor.

        The flag is read from the ``from_mmrazor`` key of the codebase
        config and defaults to ``False`` when the key is absent.

        Returns:
            bool: From mmrazor or not.

        Raises:
            TypeError: An error when type of `from_mmrazor` is not boolean.
        """
        codebase_config = get_codebase_config(self.deploy_cfg)
        from_mmrazor = codebase_config.get('from_mmrazor', False)
        if not isinstance(from_mmrazor, bool):
            raise TypeError('`from_mmrazor` attribute must be boolean type! '
                            f'but got: {from_mmrazor}')
        return from_mmrazor