mmdeploy/mmdeploy/codebase/base/backend_model.py
hanrui1sensetime 308e28fcb0
[Enhancement] Support Object Detection and Instance Segmentation for ort trt ncnn and openvino in mmdet 2.0 (#786)
* support cascade (mask) rcnn

* fix docstring

* support SwinTransformer

* move dense_head support to this branch

* fix function names

* fix some of the mmdet unit tests

* fix for mmdet ut

* fix det model cfg for ut

* fix test_object_detection.py

* fix mmdet object_detection_model.py

* fix mmdet yolov3 ort ut

* fix some unit tests

* fix cascade bbox head ut

* fix cascade bbox head ut

* remove useless ssd ncnn test

* fix ncnn wrapper

* fix openvino ut for reppoint head

* fix openvino cascade mask rcnn

* sync code

* support roll

* remove unused pad

* fix yolox

* fix isort

* fix lint

* fix flake8

* address review comments and fix failing unit tests

* fix sdk_export in dump_info

* fix temp hidden xlsx bugs

* fix mmdet regression test

* fix lint

* fix timer

* fix timecount side-effect

* adapt profile.py for mmdet 2.0

* hardcode report.txt for T4 benchmark test: temp version

* fix no-visualizer case

* fix backend_model

* fix android build

* adapt new mmdet 2.0 0825

* fix new 2.0

* fix test_mmdet_structures

* fix test_object_detection

* fix codebase import

* fix ut

* fix all mmdet uts

* fix det

* fix mmdet trt

* fix ncnn onnx optimize
2022-09-01 11:35:57 +08:00

120 lines
5.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta
from typing import Optional, Sequence, Union
import mmengine
from mmengine.model import BaseModel
from torch import nn
from mmdeploy.utils import (SDK_TASK_MAP, Backend, get_backend_config,
get_ir_config, get_task_type)
class BaseBackendModel(BaseModel, metaclass=ABCMeta):
    """A backend model wraps the details to initialize and run a backend
    engine."""

    def __init__(self,
                 deploy_cfg: Optional[Union[str, mmengine.Config]] = None,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 *args,
                 **kwargs):
        """The default for building the base class.

        Args:
            deploy_cfg (str | mmengine.Config | None): The deploy config.
                When given, ``input_names`` and ``output_names`` are read
                from its IR config. Defaults to ``None``.
            data_preprocessor (dict | nn.Module | None): Forwarded unchanged
                to ``BaseModel.__init__``. Defaults to ``None``.
            *args: Accepted for subclass compatibility; ignored here.
            **kwargs: Accepted for subclass compatibility; ignored here.
        """
        # Set up the nn.Module / BaseModel internals before assigning any
        # attributes on ``self``.
        super().__init__(data_preprocessor=data_preprocessor)

        input_names = output_names = None
        if deploy_cfg is not None:
            ir_config = get_ir_config(deploy_cfg)
            output_names = ir_config.get('output_names', None)
            input_names = ir_config.get('input_names', None)
        # TODO use input_names instead in the future for multiple inputs
        # Only the first input name is kept; fall back to 'input' when the
        # IR config provides none.
        self.input_name = input_names[0] if input_names else 'input'
        # Fall back to a single 'output' when no output names are configured.
        self.output_names = output_names if output_names else ['output']

    @staticmethod
    def _build_wrapper(backend: Backend,
                       backend_files: Sequence[str],
                       device: str,
                       input_names: Optional[Sequence[str]] = None,
                       output_names: Optional[Sequence[str]] = None,
                       deploy_cfg: Optional[mmengine.Config] = None,
                       *args,
                       **kwargs):
        """The default methods to build backend wrappers.

        Args:
            backend (Backend): The backend enum type.
            backend_files (Sequence[str]): Paths to all required backend files
                (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
            device (str): A string specifying device type. Only forwarded to
                the ONNX Runtime, PPLNN and SDK wrappers.
            input_names (Sequence[str] | None): Names of model inputs in
                order. Only used by the TorchScript wrapper. Defaults to
                `None`.
            output_names (Sequence[str] | None): Names of model outputs in
                order. Defaults to `None` and the wrapper will load the output
                names from the model.
            deploy_cfg (mmengine.Config | None): Deployment config. Required
                for the SDK backend (used to look up the task type); for ncnn
                it supplies the ``use_vulkan`` flag. Defaults to `None`.
            *args: Accepted for subclass compatibility; ignored here.
            **kwargs: Extra backend options. Only ``uri`` is read, by the
                SNPE backend.

        Returns:
            The wrapper object constructed for ``backend``.

        Raises:
            AssertionError: If ``backend`` is ``Backend.SDK`` and
                ``deploy_cfg`` is ``None``.
            NotImplementedError: If ``backend`` is not handled here.
        """
        # Backend packages are imported lazily so only the selected backend
        # needs to be installed.
        if backend == Backend.ONNXRUNTIME:
            from mmdeploy.backend.onnxruntime import ORTWrapper
            return ORTWrapper(
                onnx_file=backend_files[0],
                device=device,
                output_names=output_names)
        elif backend == Backend.TENSORRT:
            from mmdeploy.backend.tensorrt import TRTWrapper
            return TRTWrapper(
                engine=backend_files[0], output_names=output_names)
        elif backend == Backend.PPLNN:
            from mmdeploy.backend.pplnn import PPLNNWrapper
            return PPLNNWrapper(
                onnx_file=backend_files[0],
                # The algorithm file is optional for PPLNN.
                algo_file=backend_files[1] if len(backend_files) > 1 else None,
                device=device,
                output_names=output_names)
        elif backend == Backend.NCNN:
            from mmdeploy.backend.ncnn import NCNNWrapper
            # For unittest deploy_config will not pass into _build_wrapper
            # function, so Vulkan stays disabled when no config is given.
            if deploy_cfg:
                backend_config = get_backend_config(deploy_cfg)
                use_vulkan = backend_config.get('use_vulkan', False)
            else:
                use_vulkan = False
            return NCNNWrapper(
                param_file=backend_files[0],
                bin_file=backend_files[1],
                output_names=output_names,
                use_vulkan=use_vulkan)
        elif backend == Backend.OPENVINO:
            from mmdeploy.backend.openvino import OpenVINOWrapper
            return OpenVINOWrapper(
                ir_model_file=backend_files[0], output_names=output_names)
        elif backend == Backend.SDK:
            assert deploy_cfg is not None, \
                'Building SDKWrapper requires deploy_cfg'
            from mmdeploy.backend.sdk.wrapper import SDKWrapper
            task_name = SDK_TASK_MAP[get_task_type(deploy_cfg)]['cls_name']
            return SDKWrapper(
                model_file=backend_files[0],
                task_name=task_name,
                device=device)
        elif backend == Backend.TORCHSCRIPT:
            from mmdeploy.backend.torchscript import TorchscriptWrapper
            return TorchscriptWrapper(
                model=backend_files[0],
                input_names=input_names,
                output_names=output_names)
        elif backend == Backend.SNPE:
            from mmdeploy.backend.snpe import SNPEWrapper
            return SNPEWrapper(
                dlc_file=backend_files[0],
                uri=kwargs.get('uri'),
                output_names=output_names)
        # ``backend`` is expected to be a ``Backend`` member. Fall back to the
        # raw object so an unexpected type (e.g. a plain string) still raises
        # this error instead of an AttributeError on ``.value``.
        backend_name = getattr(backend, 'value', backend)
        raise NotImplementedError(f'Unknown backend type: {backend_name}')