mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* support cascade (mask) rcnn * fix docstring * support SwinTransformer * move dense_head support to this branch * fix function names * fix part of uts of mmdet * fix for mmdet ut * fix det model cfg for ut * fix test_object_detection.py * fix mmdet object_detection_model.py * fix mmdet yolov3 ort ut * fix part of uts * fix cascade bbox head ut * fix cascade bbox head ut * remove useless ssd ncnn test * fix ncnn wrapper * fix openvino ut for reppoint head * fix openvino cascade mask rcnn * sync codes * support roll * remove unused pad * fix yolox * fix isort * fix lint * fix flake8 * reply for comments and fix failed ut * fix sdk_export in dump_info * fix temp hidden xlsx bugs * fix mmdet regression test * fix lint * fix timer * fix timecount side-effect * adapt profile.py for mmdet 2.0 * hardcode report.txt for T4 benchmark test: temp version * fix no-visualizer case * fix backend_model * fix android build * adapt new mmdet 2.0 0825 * fix new 2.0 * fix test_mmdet_structures * fix test_object_detection * fix codebase import * fix ut * fix all mmdet uts * fix det * fix mmdet trt * fix ncnn onnx optimize
120 lines
5.1 KiB
Python
120 lines
5.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from abc import ABCMeta
|
|
from typing import Optional, Sequence, Union
|
|
|
|
import mmengine
|
|
from mmengine.model import BaseModel
|
|
from torch import nn
|
|
|
|
from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config,
|
|
get_ir_config, get_task_type)
|
|
|
|
|
|
class BaseBackendModel(BaseModel, metaclass=ABCMeta):
    """A backend model wraps the details to initialize and run a backend
    engine.

    This base class resolves the model's input/output tensor names from the
    deploy config. It also provides :meth:`_build_wrapper`, a helper that
    creates the backend-specific inference wrapper for a given
    :class:`Backend`.
    """

    def __init__(self,
                 deploy_cfg: Optional[Union[str, mmengine.Config]] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 *args,
                 **kwargs):
        """The default for building the base class.

        Args:
            deploy_cfg (str | mmengine.Config | None): The deploy config.
                Its IR config is read for ``input_names`` and
                ``output_names``. When it is ``None``, or a name list is
                missing or empty, the defaults ``'input'`` and
                ``['output']`` are used.
            data_preprocessor (dict | nn.Module | None): Forwarded unchanged
                to :class:`mmengine.model.BaseModel`. Defaults to ``None``.
            *args: Ignored by this base class.
            **kwargs: Ignored by this base class.
        """
        # Initialize the nn.Module machinery before assigning any attributes
        # on the instance, as PyTorch recommends for Module subclasses.
        super().__init__(data_preprocessor=data_preprocessor)

        input_names = output_names = None
        if deploy_cfg is not None:
            ir_config = get_ir_config(deploy_cfg)
            output_names = ir_config.get('output_names', None)
            input_names = ir_config.get('input_names', None)
        # TODO use input_names instead in the future for multiple inputs
        self.input_name = input_names[0] if input_names else 'input'
        self.output_names = output_names if output_names else ['output']

    @staticmethod
    def _build_wrapper(backend: Backend,
                       backend_files: Sequence[str],
                       device: str,
                       input_names: Optional[Sequence[str]] = None,
                       output_names: Optional[Sequence[str]] = None,
                       deploy_cfg: Optional[mmengine.Config] = None,
                       *args,
                       **kwargs):
        """The default methods to build backend wrappers.

        Backend packages are imported lazily inside each branch. As a
        result, only the package for the requested backend has to be
        installed.

        Args:
            backend (Backend): The backend enum type.
            backend_files (Sequence[str]): Paths to all required backend
                files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin'
                for ncnn). The first entry is always used. ncnn also
                requires a second entry, and PPLNN uses one if present.
            device (str): A string specifying device type.
            input_names (Sequence[str] | None): Names of model inputs in
                order. Only forwarded to the TorchScript wrapper.
                Defaults to `None`.
            output_names (Sequence[str] | None): Names of model outputs in
                order. Defaults to `None` and the wrapper will load the
                output names from the model.
            deploy_cfg (mmengine.Config | None): Deployment config.
                It is required for the SDK backend. For ncnn, the
                ``use_vulkan`` entry of its backend config is read.
                Defaults to `None`.
            *args: Ignored.
            **kwargs: Extra options. Only ``uri`` is read, and only by the
                SNPE backend.

        Returns:
            The wrapper object constructed for ``backend``.

        Raises:
            AssertionError: If ``backend`` is the SDK backend and
                ``deploy_cfg`` is ``None``.
            NotImplementedError: If ``backend`` is not a supported backend.
        """
        if backend == Backend.ONNXRUNTIME:
            from mmdeploy.backend.onnxruntime import ORTWrapper
            return ORTWrapper(
                onnx_file=backend_files[0],
                device=device,
                output_names=output_names)
        elif backend == Backend.TENSORRT:
            from mmdeploy.backend.tensorrt import TRTWrapper
            return TRTWrapper(
                engine=backend_files[0], output_names=output_names)
        elif backend == Backend.PPLNN:
            from mmdeploy.backend.pplnn import PPLNNWrapper
            return PPLNNWrapper(
                onnx_file=backend_files[0],
                # The algorithm file is optional; pass None when absent.
                algo_file=backend_files[1] if len(backend_files) > 1 else None,
                device=device,
                output_names=output_names)
        elif backend == Backend.NCNN:
            from mmdeploy.backend.ncnn import NCNNWrapper

            # deploy_cfg may be omitted (e.g. in unit tests); in that case
            # use_vulkan defaults to False.
            use_vulkan = False
            if deploy_cfg:
                backend_config = get_backend_config(deploy_cfg)
                use_vulkan = backend_config.get('use_vulkan', False)
            return NCNNWrapper(
                param_file=backend_files[0],
                bin_file=backend_files[1],
                output_names=output_names,
                use_vulkan=use_vulkan)
        elif backend == Backend.OPENVINO:
            from mmdeploy.backend.openvino import OpenVINOWrapper
            return OpenVINOWrapper(
                ir_model_file=backend_files[0], output_names=output_names)
        elif backend == Backend.SDK:
            assert deploy_cfg is not None, \
                'Building SDKWrapper requires deploy_cfg'
            from mmdeploy.backend.sdk.wrapper import SDKWrapper
            task_name = SDK_TASK_MAP[get_task_type(deploy_cfg)]['cls_name']
            return SDKWrapper(
                model_file=backend_files[0],
                task_name=task_name,
                device=device)
        elif backend == Backend.TORCHSCRIPT:
            from mmdeploy.backend.torchscript import TorchscriptWrapper
            return TorchscriptWrapper(
                model=backend_files[0],
                input_names=input_names,
                output_names=output_names)
        elif backend == Backend.SNPE:
            from mmdeploy.backend.snpe import SNPEWrapper
            return SNPEWrapper(
                dlc_file=backend_files[0],
                uri=kwargs.get('uri'),
                output_names=output_names)
        else:
            # `backend` may not be a `Backend` member (e.g. a plain string,
            # which never compares equal to an enum member). Don't let an
            # AttributeError on `.value` mask the intended error.
            backend_name = getattr(backend, 'value', backend)
            raise NotImplementedError(f'Unknown backend type: {backend_name}')
|