Mirror of https://github.com/open-mmlab/mmdeploy.git, synced 2025-01-14 08:09:43 +08:00
* support for external codebase like mmyolo
* support for external export
* fix missing flake8
* fix comments
* add aenum
* add missing files
* fix condition
* refactor import_codebase
* fix mmyolo support
* fix lint
* add base codebase
* fix a strange clang-format
* fix import_codebase
* fix dependent codebase register
* wrap custom_model
* fix comment
* add ut
375 lines
13 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine import Config
from mmengine.model import BaseDataPreprocessor
from torch.utils.data import DataLoader, Dataset

from mmdeploy.utils import (get_backend_config, get_codebase,
                            get_codebase_config, get_root_logger)
from mmdeploy.utils.config_utils import get_codebase_external_module
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset


class BaseTask(metaclass=ABCMeta):
    """Wrap the processing functions of a Computer Vision task.

    Args:
        model_cfg (str | Config): Model config file.
        deploy_cfg (str | Config): Deployment config file.
        device (str): A string specifying device type.
    """

    def __init__(self,
                 model_cfg: Config,
                 deploy_cfg: Config,
                 device: str,
                 experiment_name: str = 'BaseTask'):

        self.model_cfg = model_cfg
        self.deploy_cfg = deploy_cfg
        self.device = device

        self.codebase = get_codebase(deploy_cfg)
        self.experiment_name = experiment_name

        # init scope
        from .. import import_codebase
        custom_module_list = get_codebase_external_module(deploy_cfg)
        import_codebase(self.codebase, custom_module_list)

        from mmengine.registry import DefaultScope
        if not DefaultScope.check_instance_created(self.experiment_name):
            self.scope = DefaultScope.get_instance(
                self.experiment_name,
                scope_name=self.model_cfg.get('default_scope'))
        else:
            self.scope = DefaultScope.get_instance(self.experiment_name)

        # lazy build visualizer
        self.visualizer = self.model_cfg.get('visualizer', None)

    @abstractmethod
    def build_backend_model(self,
                            model_files: Optional[Sequence[str]] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        pass

    def build_data_preprocessor(self):
        """Build the data preprocessor defined in the model config.

        Returns:
            BaseDataPreprocessor: The built data preprocessor.
        """
        model = deepcopy(self.model_cfg.model)
        preprocess_cfg = model['data_preprocessor']

        from mmengine.registry import MODELS
        data_preprocessor = MODELS.build(preprocess_cfg)

        return data_preprocessor

    def build_pytorch_model(self,
                            model_checkpoint: Optional[str] = None,
                            cfg_options: Optional[Dict] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize torch model.

        Args:
            model_checkpoint (str): The checkpoint file of torch model,
                defaults to `None`.
            cfg_options (dict): Optional config key-pair parameters.

        Returns:
            nn.Module: An initialized torch model generated by other OpenMMLab
                codebases.
        """
        from mmengine.model import revert_sync_batchnorm
        from mmengine.registry import MODELS

        model = deepcopy(self.model_cfg.model)
        preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
        preprocess_cfg.update(
            deepcopy(self.model_cfg.get('data_preprocessor', {})))
        model.setdefault('data_preprocessor', preprocess_cfg)
        model = MODELS.build(model)
        if model_checkpoint is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(model, model_checkpoint)

        model = revert_sync_batchnorm(model)
        model = model.to(self.device)
        model.eval()
        return model
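
    # A minimal usage sketch (not part of the original file): assuming `task`
    # is a concrete BaseTask subclass built from a model config and a deploy
    # config, the PyTorch reference model could be restored like this. The
    # checkpoint path is hypothetical.
    #
    #     torch_model = task.build_pytorch_model('end2end.pth')
    #     assert not torch_model.training  # returned in eval mode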

    def build_dataset(self,
                      dataset_cfg: Union[str, Config],
                      is_sort_dataset: bool = True,
                      **kwargs) -> Dataset:
        """Build dataset for different codebases.

        Args:
            dataset_cfg (str | Config): Dataset config file or Config
                object.
            is_sort_dataset (bool): When 'True', the dataset will be sorted
                by image shape in ascending order if 'dataset_cfg'
                contains information about height and width.

        Returns:
            Dataset: The built dataset.
        """
        backend_cfg = get_backend_config(self.deploy_cfg)
        from mmdeploy.utils import load_config
        dataset_cfg = load_config(dataset_cfg)[0]
        if 'pipeline' in backend_cfg:
            dataset_cfg.pipeline = backend_cfg.pipeline
        from mmengine.registry import DATASETS
        dataset = DATASETS.build(dataset_cfg)
        logger = get_root_logger()
        if is_sort_dataset:
            if is_can_sort_dataset(dataset):
                sort_dataset(dataset)
            else:
                logger.info('Sorting the dataset by \'height\' and \'width\' '
                            'is not possible.')
        return dataset

    @staticmethod
    def build_dataloader(dataloader: Union[DataLoader, Dict],
                         seed: Optional[int] = None) -> DataLoader:
        """Build a PyTorch dataloader. A wrapper around
        ``Runner.build_dataloader``.

        Args:
            dataloader (DataLoader | dict): A DataLoader object or a dict
                used to build one. If ``dataloader`` is already a DataLoader
                object, it is returned as-is.
            seed (int, optional): Random seed. Defaults to None.

        Returns:
            DataLoader: The DataLoader built from ``dataloader``.
        """
        from mmengine.runner import Runner
        return Runner.build_dataloader(dataloader, seed)
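
    # A minimal sketch (illustrative only): building the test dataset and a
    # dataloader. The dataset config path is hypothetical; `model_cfg` is
    # assumed to follow the usual mmengine layout with a `test_dataloader`
    # field, which is the same source build_test_runner falls back to.
    #
    #     dataset = task.build_dataset('configs/_base_/datasets/coco.py')
    #     dataloader = task.build_dataloader(model_cfg.test_dataloader)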

    def build_test_runner(self,
                          model: torch.nn.Module,
                          work_dir: str,
                          log_file: Optional[str] = None,
                          show: bool = False,
                          show_dir: Optional[str] = None,
                          wait_time: int = 0,
                          interval: int = 1,
                          dataloader: Optional[Union[DataLoader,
                                                     Dict]] = None):
        """Build a test runner for evaluating ``model``.

        Args:
            model (torch.nn.Module): The model to test.
            work_dir (str): Working directory to save logs and results.
            log_file (str, optional): Path of the log file. Defaults to None.
            show (bool): Whether to show visualized results. Defaults to
                False.
            show_dir (str, optional): Directory to save visualized results.
                Defaults to None.
            wait_time (int): Wait time of the visualization, in seconds.
                Defaults to 0.
            interval (int): Visualization interval. Defaults to 1.
            dataloader (DataLoader | dict, optional): The test dataloader or
                a dict to build one. Defaults to ``test_dataloader`` in the
                model config.

        Returns:
            DeployTestRunner: The runner that executes the test loop.
        """

        def _merge_cfg(cfg):
            """Merge CLI arguments to config."""
            # -------------------- visualization --------------------
            if show or (show_dir is not None):
                assert 'visualization' in cfg.default_hooks, \
                    'VisualizationHook is not set in the `default_hooks`' \
                    ' field of config. Please set ' \
                    '`visualization=dict(type="VisualizationHook")`'

                cfg.default_hooks.visualization.enable = True
                cfg.default_hooks.visualization.show = show
                cfg.default_hooks.visualization.wait_time = wait_time
                cfg.default_hooks.visualization.out_dir = show_dir
                cfg.default_hooks.visualization.interval = interval

            return cfg

        model_cfg = deepcopy(self.model_cfg)
        if dataloader is None:
            dataloader = model_cfg.test_dataloader
        if not isinstance(dataloader, DataLoader):
            if isinstance(dataloader, list):
                dataloader = [self.build_dataloader(dl) for dl in dataloader]
            else:
                dataloader = self.build_dataloader(dataloader)

        model_cfg = _merge_cfg(model_cfg)

        visualizer = self.get_visualizer(work_dir, work_dir)

        from .runner import DeployTestRunner
        runner = DeployTestRunner(
            model=model,
            work_dir=work_dir,
            log_file=log_file,
            device=self.device,
            visualizer=visualizer,
            default_hooks=model_cfg.default_hooks,
            test_dataloader=dataloader,
            test_cfg=model_cfg.test_cfg,
            test_evaluator=model_cfg.test_evaluator,
            default_scope=model_cfg.default_scope)
        return runner
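
    # A minimal sketch (illustrative only): evaluating a converted model with
    # the runner built above. `backend_files` and `work_dir` are hypothetical
    # placeholders.
    #
    #     backend_model = task.build_backend_model(backend_files)
    #     runner = task.build_test_runner(backend_model, work_dir)
    #     runner.test()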

    @abstractmethod
    def create_input(
        self,
        imgs: Union[str, np.ndarray],
        input_shape: Optional[Sequence[int]] = None,
        data_preprocessor: Optional[BaseDataPreprocessor] = None
    ) -> Tuple[Dict, torch.Tensor]:
        """Create input for model.

        Args:
            imgs (str | np.ndarray): Input image(s), accepted data types are
                `str`, `np.ndarray`.
            input_shape (list[int]): Input shape of image in (width, height)
                format, defaults to `None`.
            data_preprocessor (BaseDataPreprocessor, optional): The data
                preprocessor used to process the input. Defaults to `None`.

        Returns:
            tuple: (data, img), meta information for the input image and input
                image tensor.
        """
        pass

    def get_visualizer(self, name: str, save_dir: str):
        """Get the visualizer instance.

        Args:
            name (str): The name of the visualizer.
            save_dir (str): The save directory of visualizer.

        Returns:
            Visualizer: The visualizer instance built from the model config.
        """
        cfg = deepcopy(self.visualizer)
        cfg.name = name
        cfg.save_dir = save_dir
        from mmengine.registry import VISUALIZERS, DefaultScope
        with DefaultScope.overwrite_default_scope(cfg.pop('_scope_', None)):
            # get the global default scope
            default_scope = DefaultScope.get_current_instance()
            if default_scope is not None:
                scope_name = default_scope.scope_name
                root = VISUALIZERS._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    registry = VISUALIZERS
            else:
                registry = VISUALIZERS
            VisualizerClass = registry.get(cfg.type)
            if VisualizerClass.check_instance_created(cfg.name):
                return VisualizerClass.get_instance(cfg.name)
            else:
                return registry.build_func(cfg, registry=registry)

    def visualize(self,
                  image: Union[str, np.ndarray],
                  result: list,
                  output_file: str,
                  window_name: str = '',
                  show_result: bool = False,
                  draw_gt: bool = False,
                  **kwargs):
        """Visualize predictions of a model.

        Args:
            image (str | np.ndarray): Input image to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save the drawn image.
            window_name (str): The name of the visualization window. Defaults
                to an empty string.
            show_result (bool): Whether to show the result in a window,
                defaults to `False`.
            draw_gt (bool): Whether to draw ground truth, defaults to
                `False`.
        """
        save_dir, save_name = osp.split(output_file)
        visualizer = self.get_visualizer(window_name, save_dir)

        name = osp.splitext(save_name)[0]
        if isinstance(image, str):
            image = mmcv.imread(image, channel_order='rgb')
        assert isinstance(image, np.ndarray)

        visualizer.add_datasample(
            name,
            image,
            data_sample=result,
            draw_gt=draw_gt,
            show=show_result,
            out_file=output_file)
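
    # A minimal sketch (illustrative only): drawing predictions returned by a
    # backend model onto the input image. The image path and prediction object
    # below are hypothetical placeholders.
    #
    #     preds = backend_model.test_step(model_inputs)
    #     task.visualize(
    #         image='demo.jpg',
    #         result=preds[0],
    #         output_file='output.jpg',
    #         show_result=False)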

    @staticmethod
    @abstractmethod
    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
        """Get a certain partition config.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
        pass

    @staticmethod
    def get_tensor_from_input(input_data: Dict[str, Any],
                              **kwargs) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info and image
                tensor.

        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        return input_data['inputs']

    @abstractmethod
    def get_preprocess(self, *args, **kwargs) -> Dict:
        """Get the preprocess information for SDK.

        Returns:
            dict: Composed of the preprocess information.
        """
        pass

    @abstractmethod
    def get_postprocess(self, *args, **kwargs) -> Dict:
        """Get the postprocess information for SDK.

        Returns:
            dict: Composed of the postprocess information.
        """
        pass

    @abstractmethod
    def get_model_name(self, *args, **kwargs) -> str:
        """Get the model name.

        Returns:
            str: The name of the model.
        """
        pass

    @property
    def from_mmrazor(self) -> bool:
        """Whether the codebase is from mmrazor.

        Returns:
            bool: From mmrazor or not.

        Raises:
            TypeError: An error when the type of `from_mmrazor` is not
                boolean.
        """
        codebase_config = get_codebase_config(self.deploy_cfg)
        from_mmrazor = codebase_config.get('from_mmrazor', False)
        if not isinstance(from_mmrazor, bool):
            raise TypeError('`from_mmrazor` attribute must be boolean type! '
                            f'but got: {from_mmrazor}')

        return from_mmrazor
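

# A hedged end-to-end sketch (not part of the original file): how a concrete
# BaseTask subclass is typically obtained and used. The helper name
# `build_task_processor` and the config and file paths below are assumptions
# based on common mmdeploy usage, not guaranteed by this file.
#
#     from mmdeploy.apis.utils import build_task_processor
#     from mmdeploy.utils import load_config
#
#     deploy_cfg, model_cfg = load_config('detection_onnxruntime.py',
#                                         'faster_rcnn_config.py')
#     task = build_task_processor(model_cfg, deploy_cfg, device='cpu')
#
#     model_inputs, _ = task.create_input('demo.jpg', input_shape=[640, 640])
#     backend_model = task.build_backend_model(['end2end.onnx'])
#     preds = backend_model.test_step(model_inputs)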