mmdeploy/docs/zh_cn/get_started.md
RunningLeon 4d8ea40f55
Sync v0.7.0 to dev-1.x (#907)
* make -install -> make install (#621)

change `make -install` to `make install`

https://github.com/open-mmlab/mmdeploy/issues/618

* [Fix] fix csharp api detector release result (#620)

* fix csharp api detector release result

* fix wrong count arg of xxx_release_result in c# api

* [Enhancement] Support two-stage rotated detector TensorRT. (#530)

* upload

* add fake_multiclass_nms_rotated

* delete unused code

* align with pytorch

* Update delta_midpointoffset_rbbox_coder.py

* add trt rotated roi align

* add index feature in nms

* not good

* fix index

* add ut

* add benchmark

* move to csrc/mmdeploy

* update unit test

Co-authored-by: zytx121 <592267829@qq.com>

* Reduce mmcls version dependency (#635)

* fix shufflenetv2 with trt (#645)

* fix shufflenetv2 and pspnet

* fix ci

* remove print

* ' -> " (#654)

If there is a variable in the string, single quotes will ignored it, while double quotes will bring the variable into the string after parsing

* ' -> " (#655)

same with https://github.com/open-mmlab/mmdeploy/pull/654

* Support deployment of Segmenter (#587)

* support segmentor with ncnn

* update regression yml

* replace chunk with split to support ts

* update regression yml

* update docs

* fix segmenter ncnn inference failure brought by #477

* add test

* fix test for ncnn and trt

* fix lint

* export nn.linear to Gemm op in onnx for ncnn

* fix ci

* simplify `Expand` (#617)

* Fix typo (#625)

* Add make install in en docs

* Add make install in zh docs

* Fix typo

* Merge and add windows build

Co-authored-by: tripleMu <865626@163.com>

* [Enhancement] Fix ncnn unittest (#626)

* optmize-csp-darknet

* replace floordiv to torch.div

* update csp_darknet default implement

* fix test

* [Enhancement] TensorRT Anchor generator plugin (#646)

* custom trt anchor generator

* add ut

* add docstring, update doc

* Add partition doc and sample code (#599)

* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* add to index.rst

* resolve comment

* resolve comments

* fix lint

* change caption level in docs

* update docs (#624)

* Add java apis and demos (#563)

* add java classifier detector

* add segmentor

* fix lint

* add ImageRestorer java apis and demo

* remove useless count parameter for Segmentor and Restorer, add PoseDetector

* add RotatedDetection java api and demo

* add Ocr java demo and apis

* remove mmrotate ncnn java api and demo

* fix lint

* sync java api folder after rebase to master

* fix include

* remove record

* fix java apis dir path in cmake

* add java demo readme

* fix lint mdformat

* add test javaapi ci

* fix lint

* fix flake8

* fix test javaapi ci

* refactor readme.md

* fix install opencv for ci

* fix install opencv : add permission

* add all codebases and mmcv install

* add torch

* install mmdeploy

* fix image path

* fix picture path

* fix import ncnn

* fix import ncnn

* add submodule of pybind

* fix pybind submodule

* change download to git clone for submodule

* fix ncnn dir

* fix README error

* simplify the github ci

* fix ci

* fix yapf

* add JNI as required

* fix Capitalize

* fix Capitalize

* fix copyright

* ignore .class changed

* add OpenJDK installation docs

* install target of javaapi

* simplify ci

* add jar

* fix ci

* fix ci

* fix test java command

* debugging what failed

* debugging what failed

* debugging what failed

* add java version info

* install openjdk

* add java env var

* fix export

* fix export

* fix export

* fix export

* fix picture path

* fix picture path

* fix file name

* fix file name

* fix README

* remove java_api strategy

* fix python version

* format task name

* move args position

* extract common utils code

* show image class result

* add detector result

* segmentation result format

* add ImageRestorer result

* add PoseDetection java result format

* fix ci

* stage ocr

* add visualize

* move utils

* fix lint

* fix ocr bugs

* fix ci demo

* fix java classpath for ci

* fix popd

* fix ocr demo text garbled

* fix ci

* fix ci

* fix ci

* fix path of utils ci

* update the circleci config file by adding workflows both for linux, windows and linux-gpu (#368)

* update circleci by adding more workflows

* fix test workflow failure on windows platform

* fix docker exec command for SDK unittests

* Fixed tensorrt plugin not found in Windows (#672)

* update introduction.png (#674)

* [Enhancement] Add fuse select assign pass (#589)

* Add fuse select assign pass

* move code to csrc

* add config flag

* remove bool cast

* fix export sdk info of input shape (#667)

* Update get_started.md (#675)

Fix backend model assignment

* Update get_started.md (#676)

Fix backend model assignment

* [Fix] fix clang build (#677)

* fix clang build

* fix ndk build

* fix ndk build

* switch to `std::filesystem` for clang-7 and later

* Deploy the Swin Transformer on TensorRT. (#652)

* resolve conflicts

* update ut and docs

* fix ut

* refine docstring

* add comments and refine UT

* resolve comments

* resolve comments

* update doc

* add roll export

* check backend

* update regression test

* bump version to 0.6.0 (#680)

* bump vertion to 0.6.0

* update version

* pass img_metas while exporting to onnx (#681)

* pass img_metas while exporting to onnx

* remove try-catch in tools for beter debugging

* use get

* fix typo

* [Fix] fix ssd ncnn ut (#692)

* fix ssd ncnn ut

* fix yapf

* fix passing img_metas to pytorch2onnx for mmedit (#700)

* fix passing img_metas for mmdet3d (#707)

* [Fix] Fix android build (#698)

* fix android build

* fix cmake

* fix url link

* fix wrong exit code in pipeline_manager (#715)

* fix exit

* change to general exit errorcode=1

* fix passing wrong backend type (#719)

* Rename onnx2ncnn to mmdeploy_onnx2ncnn (#694)

* improvement(tools/onnx2ncnn.py): rename to mmdeploy_onnx2ncnn

* format(tools/deploy.py): clean code

* fix(init_plugins.py): improve if condition

* fix(CI): update target

* fix(test_onnx2ncnn.py): update desc

* Update init_plugins.py

* [Fix] Fix mmdet ort static shape bug (#687)

* fix shape

* add device

* fix yapf

* fix rewriter for transforms

* reverse image shape

* fix ut of distance2bbox

* fix rewriter name

* fix c4 for torchscript (#724)

* [Enhancement] Standardize C API (#634)

* unify C API naming

* fix demo and move apis/c/* -> apis/c/mmdeploy/*

* fix lint

* fix C# project

* fix Java API

* [Enhancement] Support Slide Vertex TRT (#650)

* reorgnize mmrotate

* fix

* add hbb2obb

* add ut

* fix rotated nms

* update docs

* update benchmark

* update test

* remove ort regression test, remove comment

* Fix get-started rendering issues in readthedocs (#740)

* fix mermaid markdown rendering issue in readthedocs

* fix error in C++ example

* fix error in c++ example in zh_cn get_started doc

* [Fix] set default topk for dump info (#702)

* set default topk for dump info

* remove redundant docstrings

* add ci densenet

* fix classification warnings

* fix mmcls version

* fix logger.warnings

* add version control (#754)

* fix satrn for ORT (#753)

* fix satrn for ORT

* move rewrite into pytorch

* Add inference latency test tool (#665)

* add profile tool

* remove print envs in profile tool

* set cudnn_benchmark to True

* add doc

* update tests

* fix typo

* support test with images from a directory

* update doc

* resolve comments

* [Enhancement] Add CSE ONNX pass (#647)

* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Test robot

Just test robot

* Update README.md

Revert

* [Fix] fix yolox point_generator (#758)

* fix yolox point_generator

* add a UT

* resolve comments

* fix comment lines

* limit markdown version (#773)

* [Enhancement] Better index put ONNX export. (#704)

* Add rewriter for tensor setitem

* add version check

* Upgrade Dockerfile to use TensorRT==8.2.4.2 (#706)

* Upgrade TensorRT to 8.2.4.2

* upgrade pytorch&mmcv in CPU Dockerfile

* Delete redundant port example in Docker

* change 160x160-608x608 to 64x64-608x608 for yolov3

* [Fix] reduce log verbosity & improve error reporting (#755)

* reduce log verbosity & improve error reporting

* improve error reporting

* [Enhancement] Support latest ppl.nn & ppl.cv (#564)

* support latest ppl.nn

* fix pplnn for model convertor

* fix lint

* update memory policy

* import algo from buffer

* update ppl.cv

* use `ppl.cv==0.7.0`

* document supported ppl.nn version

* skip pplnn dependency when building shared libs

* [Fix][P0] Fix for torch1.12 (#751)

* fix for torch1.12

* add comment

* fix check env (#785)

* [Fix] fix cascade mask rcnn (#787)

* fix cascade mask rcnn

* fix lint

* add regression

* [Feature] Support RoITransRoIHead (#713)

* [Feature] Support RoITransRoIHead

* Add docs

* Add mmrotate models regression test

* Add a draft for test code

* change the argument name

* fix test code

* fix minor change for not class agnostic case

* fix sample for test code

* fix sample for test code

* Add mmrotate in requirements

* Revert "Add mmrotate in requirements"

This reverts commit 043490075e6dbe4a8fb98e94b2b583b91fc5038d.

* [Fix] fix triu (#792)

* fix triu

* triu -> triu_default

* [Enhancement] Install Optimizer by setuptools (#690)

* Add fuse select assign pass

* move code to csrc

* add config flag

* Add fuse select assign pass

* Add CSE for ONNX

* remove useless code

* Install optimizer by setup tools

* fix comment

* [Feature] support MMRotate model with le135 (#788)

* support MMRotate model with le135

* cse before fuse select assign

* remove unused import

* [Fix] Support macOS build (#762)

* fix macOS build

* fix missing

* add option to build & install examples (#822)

* [Fix] Fix setup on non-linux-x64 (#811)

* fix setup

* replace long to int64_t

* [Feature] support build single sdk library (#806)

* build single lib for c api

* update csharp doc & project

* update test build

* fix test build

* fix

* update document for building android sdk (#817)

Co-authored-by: dwSun <dwsunny@icloud.com>

* [Enhancement] support kwargs in SDK python bindings (#794)

* support-kwargs

* make '__call__' as single image inference and add 'batch' API to deal with batch images inference

* fix linting error and typo

* fix lint

* improvement(sdk): add sdk code coverage (#808)

* feat(doc): add CI

* CI(sdk): add sdk coverage

* style(test): code format

* fix(CI): update coverage.info path

* improvement(CI): use internal image

* improvement(CI): push coverage info once

* [Feature] Add C++ API for SDK (#831)

* add C++ API

* unify result type & add examples

* minor fix

* install cxx API headers

* fix Mat, add more examples

* fix monolithic build & fix lint

* install examples correctly

* fix lint

* feat(tools/deploy.py): support snpe (#789)

* fix(tools/deploy.py): support snpe

* improvement(backend/snpe): review advices

* docs(backend/snpe): update build

* docs(backend/snpe): server support specify port

* docs(backend/snpe): update path

* fix(backend/snpe): time counter missing argument

* docs(backend/snpe): add missing argument

* docs(backend/snpe): update download and using

* improvement(snpe_net.cpp): load model with modeldata

* Support setup on environment with no PyTorch (#843)

* support test with multi batch (#829)

* support test with multi batch

* resolve comment

* import algorithm from buffer (#793)

* [Enhancement] build sdk python api in standard-alone manner (#810)

* build sdk python api in standard-alone manner

* enable MMDEPLOY_BUILD_SDK_MONOLITHIC and MMDEPLOY_BUILD_EXAMPLES in prebuild config

* link mmdeploy to python target when monolithic option is on

* checkin README to describe precompiled package build procedure

* use packaging.version.parse(python_version) instead of list(python_version)

* fix according to review results

* rebase master

* rollback cmake.in and apis/python/CMakeLists.txt

* reorganize files in install/example

* let cmake detect visual studio instead of specifying 2019

* rename whl name of precompiled package

* fix according to review results

* Fix SDK backend (#844)

* fix mmpose python api (#852)

* add prebuild package usage docs on windows (#816)

* add prebuild package usage docs on windows

* fix lint

* update

* try fix lint

* add en docs

* update

* update

* udpate faq

* fix typo (#862)

* [Enhancement] Improve get_started documents and bump version to 0.7.0 (#813)

* simplify commands in get_started

* add installation commands for Windows

* fix typo

* limit markdown and sphinx_markdown_tables version

* adopt html <details open> tag

* bump mmdeploy version

* bump mmdeploy version

* update get_started

* update get_started

* use python3.8 instead of python3.7

* remove duplicate section

* resolve issue #856

* update according to review results

* add reference to prebuilt_package_windows.md

* fix error when build sdk demos

* fix mmcls

Co-authored-by: Ryan_Huang <44900829+DrRyanHuang@users.noreply.github.com>
Co-authored-by: Chen Xin <xinchen.tju@gmail.com>
Co-authored-by: q.yao <yaoqian@sensetime.com>
Co-authored-by: zytx121 <592267829@qq.com>
Co-authored-by: Li Zhang <lzhang329@gmail.com>
Co-authored-by: tripleMu <gpu@163.com>
Co-authored-by: tripleMu <865626@163.com>
Co-authored-by: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com>
Co-authored-by: lvhan028 <lvhan_028@163.com>
Co-authored-by: Bryan Glen Suello <11388006+bgsuello@users.noreply.github.com>
Co-authored-by: zambranohally <63218980+zambranohally@users.noreply.github.com>
Co-authored-by: AllentDan <41138331+AllentDan@users.noreply.github.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: Hakjin Lee <nijkah@gmail.com>
Co-authored-by: 孙德伟 <5899962+dwSun@users.noreply.github.com>
Co-authored-by: dwSun <dwsunny@icloud.com>
Co-authored-by: Chen Xin <irexyc@gmail.com>
2022-08-19 09:30:13 +08:00

13 KiB
Raw Blame History

操作概述

MMDeploy 提供了一系列工具,帮助您更轻松的将 OpenMMLab 下的算法部署到各种设备与平台上。

您可以使用我们设计的流程一“部”到位,也可以定制您自己的转换流程。

流程简介

MMDeploy 定义的模型部署流程,如下图所示: deploy-pipeline

模型转换Model Converter

模型转换的主要功能是把输入的模型格式,转换为目标设备的推理引擎所要求的模型格式。

目前MMDeploy 可以把 PyTorch 模型转换为 ONNX、TorchScript 等和设备无关的 IR 模型。也可以将 ONNX 模型转换为推理后端模型。两者相结合,可实现端到端的模型转换,也就是从训练端到生产端的一键式部署。

MMDeploy 模型MMDeploy Model

也称 SDK Model。它是模型转换结果的集合。不仅包括后端模型还包括模型的元信息。这些信息将用于推理 SDK 中。

推理 SDKInference SDK

封装了模型的前处理、网络推理和后处理过程。对外提供多语言的模型推理接口。

准备工作

对于端到端的模型转换和推理MMDeploy 依赖 Python 3.6+ 以及 PyTorch 1.8+。

第一步:从官网下载并安装 Miniconda

第二步:创建并激活 conda 环境

conda create --name mmdeploy python=3.8 -y
conda activate mmdeploy

第三步: 参考官方文档并安装 PyTorch

在 GPU 环境下:

conda install pytorch=={pytorch_version} torchvision=={torchvision_version} cudatoolkit={cudatoolkit_version} -c pytorch -c conda-forge

在 CPU 环境下:

conda install pytorch=={pytorch_version} torchvision=={torchvision_version} cpuonly -c pytorch
在 GPU 环境下,请务必保证 {cudatoolkit_version} 和主机的 CUDA Toolkit 版本一致,避免在使用 TensorRT 时,可能引起的版本冲突问题。

安装 MMDeploy

第一步:通过 MIM 安装 MMCV

pip install -U openmim
mim install mmcv-full

第二步: 安装 MMDeploy 和 推理引擎

我们推荐用户使用预编译包安装和体验 MMDeploy 功能。请根据目标软硬件平台,从这里 选择最新版本下载并安装。

目前MMDeploy 的预编译包支持的平台和设备矩阵如下:

OS-Arch Device ONNX Runtime TensorRT
Linux-x86_64 CPU Y N/A
CUDA N Y
Windows-x86_64 CPU Y N/A
CUDA N Y

注:对于不在上述表格中的软硬件平台,请参考源码安装文档,正确安装和配置 MMDeploy。

以最新的预编译包为例,你可以参考以下命令安装:

Linux-x86_64, CPU, ONNX Runtime 1.8.1
# 安装 MMDeploy ONNX Runtime 自定义算子库和推理 SDK
wget https://github.com/open-mmlab/mmdeploy/releases/download/v0.7.0/mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1.tar.gz
tar -zxvf mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1.tar.gz
cd mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1
pip install dist/mmdeploy-0.7.0-py3-none-linux_x86_64.whl
pip install sdk/python/mmdeploy_python-0.7.0-cp38-none-linux_x86_64.whl
cd ..
# 安装推理引擎 ONNX Runtime
pip install onnxruntime==1.8.1
wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
tar -zxvf onnxruntime-linux-x64-1.8.1.tgz
export ONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.8.1
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
Linux-x86_64, CUDA 11.x, TensorRT 8.2.3.0
# 安装 MMDeploy TensorRT 自定义算子库和推理 SDK
wget https://github.com/open-mmlab/mmdeploy/releases/download/v0.7.0/mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0.tar.gz
tar -zxvf mmdeploy-v0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0.tar.gz
cd mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0
pip install dist/mmdeploy-0.7.0-py3-none-linux_x86_64.whl
pip install sdk/python/mmdeploy_python-0.7.0-cp38-none-linux_x86_64.whl
cd ..
# 安装推理引擎 TensorRT
# !!! 从 NVIDIA 官网下载 TensorRT-8.2.3.0 CUDA 11.x 安装包并解压到当前目录
pip install TensorRT-8.2.3.0/python/tensorrt-8.2.3.0-cp38-none-linux_x86_64.whl
pip install pycuda
export TENSORRT_DIR=$(pwd)/TensorRT-8.2.3.0
export LD_LIBRARY_PATH=${TENSORRT_DIR}/lib:$LD_LIBRARY_PATH
# !!! 从 NVIDIA 官网下载 cuDNN 8.2.1 CUDA 11.x 安装包并解压到当前目录
export CUDNN_DIR=$(pwd)/cuda
export LD_LIBRARY_PATH=$CUDNN_DIR/lib64:$LD_LIBRARY_PATH
Windows-x86_64

请阅读 这里,了解 MMDeploy 预编译包在 Windows 平台下的使用方法。

模型转换

在准备工作就绪后,我们可以使用 MMDeploy 中的工具 tools/deploy.py,将 OpenMMLab 的 PyTorch 模型转换成推理后端支持的格式。 对于tools/deploy.py 的使用细节,请参考 如何转换模型

MMDetection 中的 Faster R-CNN 为例,我们可以使用如下命令,将 PyTorch 模型转换为 TenorRT 模型,从而部署到 NVIDIA GPU 上.

# 克隆 mmdeploy 仓库。转换时,需要使用 mmdeploy 仓库中的配置文件,建立转换流水线
git clone --recursive https://github.com/open-mmlab/mmdeploy.git

# 安装 mmdetection。转换时需要使用 mmdetection 仓库中的模型配置文件,构建 PyTorch nn module
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .
cd ..

# 下载 Faster R-CNN 模型权重
wget -P checkpoints https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth

# 执行转换命令,实现端到端的转换
python mmdeploy/tools/deploy.py \
    mmdeploy/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \
    mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
    mmdetection/demo/demo.jpg \
    --work-dir mmdeploy_model/faster-rcnn \
    --device cuda \
    --dump-info

转换结果被保存在 --work-dir 指向的文件夹中。该文件夹中不仅包含推理后端模型,还包括推理元信息。这些内容的整体被定义为 SDK Model。推理 SDK 将用它进行模型推理。

在安装了 MMDeploy-ONNXRuntime 预编译包后把上述转换命令中的detection_tensorrt_dynamic-320x320-1344x1344.py 换成 detection_onnxruntime_dynamic.py并修改 --device 为 cpu
即可以转出 onnx 模型,并用 ONNXRuntime 进行推理

模型推理

在转换完成后,你既可以使用 Model Converter 进行推理,也可以使用 Inference SDK。

使用 Model Converter 的推理 API

Model Converter 屏蔽了推理后端接口的差异,对其推理 API 进行了统一封装,接口名称为 inference_model

以上文中 Faster R-CNN 的 TensorRT 模型为例,你可以使用如下方式进行模型推理工作:

from mmdeploy.apis import inference_model
result = inference_model(
  model_cfg='mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
  deploy_cfg='mmdeploy/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py',
  backend_files=['mmdeploy_model/faster-rcnn/end2end.engine'],
  img='mmdetection/demo/demo.jpg',
  device='cuda:0')
接口中的 model_path 指的是推理引擎文件的路径比如例子当中end2end.engine文件的路径。路径必须放在 list 中,因为有的推理引擎模型结构和权重是分开存储的。

使用推理 SDK

你可以直接运行预编译包中的 demo 程序,输入 SDK Model 和图像,进行推理,并查看推理结果。

cd mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0
# 运行 python demo
python sdk/example/python/object_detection.py cuda ../mmdeploy_model/faster-rcnn ../mmdetection/demo/demo.jpg
# 运行 C/C++ demo
export LD_LIBRARY_PATH=$(pwd)/sdk/lib:$LD_LIBRARY_PATH
./sdk/bin/object_detection cuda ../mmdeploy_model/faster-rcnn ../mmdetection/demo/demo.jpg
以上述命令中,输入模型是 SDK Model 的路径(也就是 Model Converter 中 --work-dir 参数),而不是推理引擎文件的路径。
因为 SDK 不仅要获取推理引擎文件还需要推理元信息deploy.json, pipeline.json。它们合在一起构成 SDK Model存储在 --work-dir 下

除了 demo 程序,预编译包还提供了 SDK 多语言接口。你可以根据自己的项目需求,选择合适的语言接口, 把 MMDeploy SDK 集成到自己的项目中,进行二次开发。

Python API

对于检测功能,你也可以参考如下代码,集成 MMDeploy SDK Python API 到自己的项目中:

from mmdeploy_python import Detector
import cv2

# 读取图片
img = cv2.imread('mmdetection/demo/demo.jpg')

# 创建检测器
detector = Detector(model_path='mmdeploy_models/faster-rcnn', device_name='cuda', device_id=0)
# 执行推理
bboxes, labels, _ = detector(img)
# 使用阈值过滤推理结果,并绘制到原图中
indices = [i for i in range(len(bboxes))]
for index, bbox, label_id in zip(indices, bboxes, labels):
  [left, top, right, bottom], score = bbox[0:4].astype(int),  bbox[4]
  if score < 0.3:
      continue
  cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0))

cv2.imwrite('output_detection.png', img)

更多示例,请查阅这里

C++ API

使用 C++ API 进行模型推理的流程符合下面的模式: image

以下是这个流程的具体应用过程:

#include <cstdlib>
#include <opencv2/opencv.hpp>
#include "mmdeploy/detector.hpp"

int main() {
  const char* device_name = "cuda";
  int device_id = 0;

  // mmdeploy SDK model以上文中转出的 faster r-cnn 模型为例
  std::string model_path = "mmdeploy_model/faster-rcnn";
  std::string image_path = "mmdetection/demo/demo.jpg";

  // 1. 读取模型
  mmdeploy::Model model(model_path);
  // 2. 创建预测器
  mmdeploy::Detector detector(model, mmdeploy::Device{device_name, device_id});
  // 3. 读取图像
  cv::Mat img = cv::imread(image_path);
  // 4. 应用预测器推理
  auto dets = detector.Apply(img);
  // 5. 处理推理结果: 此处我们选择可视化推理结果
  for (int i = 0; i < dets.size(); ++i) {
    const auto& box = dets[i].bbox;
    fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n",
            i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score);
    if (bboxes[i].score < 0.3) {
      continue;
    }
    cv::rectangle(img, cv::Point{(int)box.left, (int)box.top},
                  cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0});
  }
  cv::imwrite("output_detection.png", img);
  return 0;
}

在您的项目CMakeLists中增加

find_package(MMDeploy REQUIRED)
target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS})

编译时,使用 -DMMDeploy_DIR传入MMDeloyConfig.cmake所在的路径。它在预编译包中的sdk/lib/cmake/MMDeloy下。 更多示例,请查阅此处

对于 C API、C# API、Java API 的使用方法,请分别阅读代码C demos C# demosJava demos。 我们将在后续版本中详细讲述它们的用法。

模型精度评估

为了测试部署模型的精度,推理效率,我们提供了 tools/test.py 来帮助完成相关工作。以上文中的部署模型为例:

python mmdeploy/tools/test.py \
    mmdeploy/configs/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \
    mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    --model mmdeploy_model/faster-rcnn/end2end.engine \
    --metrics ${METRICS} \
    --device cuda:0
关于 --model 选项,当使用 Model Converter 进行推理时,它代表转换后的推理后端模型的文件路径。而当使用 SDK 测试模型精度时,该选项表示 MMDeploy Model 的路径.

请阅读 如何进行模型评估 了解关于 tools/test.py 的使用细节。