Get Started
MMDeploy provides useful tools for deploying OpenMMLab models to various platforms and devices.
With their help, you can not only deploy models using our pre-defined pipelines but also customize your own deployment pipeline.
Introduction
In MMDeploy, the deployment pipeline can be illustrated as a sequence of modules, i.e., Model Converter, MMDeploy Model and Inference SDK.
Model Converter
Model Converter aims at converting training models from OpenMMLab into backend models that can run on target devices. It first transforms the PyTorch model into an IR model, i.e., ONNX or TorchScript, and then converts the IR model into a backend model. Combining the two steps gives one-click, end-to-end model deployment.
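To make the two steps concrete, below is a minimal sketch of the first step (PyTorch to ONNX) using plain torch.onnx.export on a stand-in torchvision model. It is for illustration only; MMDeploy's tools/deploy.py wraps this step with model rewriters and backend-specific passes, and a backend tool (e.g. trtexec for TensorRT) would then turn the ONNX file into a backend model.
import torch
import torchvision

# stand-in model; MMDeploy builds the model from an OpenMMLab config instead
model = torchvision.models.resnet18().eval()
dummy_input = torch.rand(1, 3, 224, 224)

# step 1: PyTorch -> IR (ONNX)
torch.onnx.export(
    model,
    dummy_input,
    'end2end.onnx',
    input_names=['input'],
    output_names=['output'],
    opset_version=11)
# step 2 (not shown): convert end2end.onnx into a backend model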
MMDeploy Model
MMDeploy Model is the result package exported by Model Converter. Besides the backend model, it also includes the model meta info, which will be used by Inference SDK.
Inference SDK
Inference SDK is developed in C/C++, wrapping the preprocessing, model forward and postprocessing modules of model inference. It provides FFI bindings for C, C++, Python, C#, Java and so on.
Prerequisites
In order to do an end-to-end model deployment, MMDeploy requires Python 3.6+ and PyTorch 1.5+.
Step 0. Download and install Miniconda from the official website.
Step 1. Create a conda environment and activate it.
conda create --name mmdeploy python=3.8 -y
conda activate mmdeploy
Step 2. Install PyTorch following official instructions, e.g.
On GPU platforms:
conda install pytorch=={pytorch_version} torchvision=={torchvision_version} cudatoolkit={cudatoolkit_version} -c pytorch -c conda-forge
On CPU platforms:
conda install pytorch=={pytorch_version} torchvision=={torchvision_version} cpuonly -c pytorch
On GPU platforms, please ensure that {cudatoolkit_version} matches your host CUDA toolkit version. Otherwise, it may cause conflicts when deploying models with TensorRT.
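Before moving on, you can quickly check that the two versions match (assuming nvcc is on your PATH):
# CUDA toolkit version on the host
nvcc --version
# CUDA version PyTorch was built with
python -c "import torch; print(torch.version.cuda)"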
Installation
We recommend that users follow our best practices to install MMDeploy.
Step 0. Install MMCV.
pip install -U openmim
mim install mmcv-full
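You can optionally verify the installation by importing mmcv and printing its version (a quick sanity check, not an official step; the second command applies to CUDA builds of mmcv-full):
python -c "import mmcv; print(mmcv.__version__)"
python -c "from mmcv.ops import get_compiling_cuda_version; print(get_compiling_cuda_version())"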
Step 1. Install MMDeploy and inference engine
We recommend using the MMDeploy precompiled package as our best practice. You can download it from here according to your target platform and device.
The supported platform and device matrix is presented as follows:
OS-Arch | Device | ONNX Runtime | TensorRT
---|---|---|---
Linux-x86_64 | CPU | Y | N/A
Linux-x86_64 | CUDA | N | Y
Windows-x86_64 | CPU | Y | N/A
Windows-x86_64 | CUDA | N | Y
Note: if the MMDeploy prebuilt package doesn't cover your target platform or device, please build MMDeploy from source.
Take the latest precompiled package as an example. You can install it as follows:
Linux-x86_64, CPU, ONNX Runtime 1.8.1
# install MMDeploy
wget https://github.com/open-mmlab/mmdeploy/releases/download/v0.7.0/mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1.tar.gz
tar -zxvf mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1.tar.gz
cd mmdeploy-0.7.0-linux-x86_64-onnxruntime1.8.1
pip install dist/mmdeploy-0.7.0-py3-none-linux_x86_64.whl
pip install sdk/python/mmdeploy_python-0.7.0-cp38-none-linux_x86_64.whl
cd ..
# install inference engine: ONNX Runtime
pip install onnxruntime==1.8.1
wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
tar -zxvf onnxruntime-linux-x64-1.8.1.tgz
export ONNXRUNTIME_DIR=$(pwd)/onnxruntime-linux-x64-1.8.1
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
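You can verify the ONNX Runtime installation with a quick check; if a GPU package is installed, CUDAExecutionProvider should appear in the provider list:
python -c "import onnxruntime as ort; print(ort.__version__, ort.get_available_providers())"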
Linux-x86_64, CUDA 11.x, TensorRT 8.2.3.0
# install MMDeploy
wget https://github.com/open-mmlab/mmdeploy/releases/download/v0.7.0/mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0.tar.gz
tar -zxvf mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0.tar.gz
cd mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0
pip install dist/mmdeploy-0.7.0-py3-none-linux_x86_64.whl
pip install sdk/python/mmdeploy_python-0.7.0-cp38-none-linux_x86_64.whl
cd ..
# install inference engine: TensorRT
# !!! Download TensorRT-8.2.3.0 CUDA 11.x tar package from NVIDIA, and extract it to the current directory
pip install TensorRT-8.2.3.0/python/tensorrt-8.2.3.0-cp38-none-linux_x86_64.whl
pip install pycuda
export TENSORRT_DIR=$(pwd)/TensorRT-8.2.3.0
export LD_LIBRARY_PATH=${TENSORRT_DIR}/lib:$LD_LIBRARY_PATH
# !!! Download cuDNN 8.2.1 CUDA 11.x tar package from NVIDIA, and extract it to the current directory
export CUDNN_DIR=$(pwd)/cuda
export LD_LIBRARY_PATH=$CUDNN_DIR/lib64:$LD_LIBRARY_PATH
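Likewise, a quick check that the TensorRT Python package and pycuda were installed correctly:
python -c "import tensorrt; print(tensorrt.__version__)"
python -c "import pycuda.driver"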
Windows-x86_64
Please refer to this guide to learn how to use the prebuilt package on Windows.
Convert Model
After the installation, you can enjoy the model deployment journey starting from converting a PyTorch model to a backend model by running tools/deploy.py.
Based on the above settings, we provide an example below that converts the Faster R-CNN in MMDetection to TensorRT:
# clone mmdeploy to get the deployment config. `--recursive` is not necessary
git clone https://github.com/open-mmlab/mmdeploy.git
# clone mmdetection repo. We have to use the config file to build PyTorch nn module
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection
pip install -v -e .
cd ..
# download Faster R-CNN checkpoint
wget -P checkpoints https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth
# run the command to start model conversion
python mmdeploy/tools/deploy.py \
mmdeploy/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \
mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
mmdetection/demo/demo.jpg \
--work-dir mmdeploy_model/faster-rcnn \
--device cuda \
--dump-info
The converted model and its meta info will be found in the path specified by --work-dir. Together they make up the MMDeploy Model, which can be fed to MMDeploy SDK to do model inference.
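With --dump-info enabled, the work directory typically looks like the listing below; exact file names may vary across versions:
mmdeploy_model/faster-rcnn
├── deploy.json     # SDK meta info
├── detail.json     # conversion details
├── pipeline.json   # SDK inference pipeline
├── end2end.onnx    # intermediate IR model
└── end2end.engine  # TensorRT backend model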
For more details about model conversion, you can read how_to_convert_model. If you want to customize the conversion pipeline, you can edit the config file by following this tutorial.
If the MMDeploy-ONNXRuntime prebuilt package is installed, you can convert the above model to an ONNX model and perform ONNX Runtime inference simply by changing detection_tensorrt_dynamic-320x320-1344x1344.py to detection_onnxruntime_dynamic.py and setting --device to cpu.
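For example, the corresponding ONNX Runtime conversion command would be the following; here a separate work directory such as mmdeploy_model/faster-rcnn-ort is used so that the TensorRT artifacts are not overwritten:
python mmdeploy/tools/deploy.py \
    mmdeploy/configs/mmdet/detection/detection_onnxruntime_dynamic.py \
    mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth \
    mmdetection/demo/demo.jpg \
    --work-dir mmdeploy_model/faster-rcnn-ort \
    --device cpu \
    --dump-info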
Inference Model
After model conversion, we can perform inference not only by Model Converter but also by Inference SDK.
Inference by Model Converter
Model Converter provides a unified API named inference_model to do the job, making the APIs of all inference backends transparent to users. Take the previously converted Faster R-CNN TensorRT model as an example:
from mmdeploy.apis import inference_model
result = inference_model(
model_cfg='mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py',
deploy_cfg='mmdeploy/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py',
backend_files=['mmdeploy_model/faster-rcnn/end2end.engine'],
img='mmdetection/demo/demo.jpg',
device='cuda:0')
backend_files in this API refers to the backend engine file paths, which MUST be given as a list, since some inference engines like OpenVINO and ncnn separate the network structure and its weights into two files.
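For instance, an OpenVINO model would be passed as both of its files (a sketch following the usual OpenVINO naming convention; the path is hypothetical):
backend_files=[
    'mmdeploy_model/faster-rcnn-openvino/end2end.xml',  # network structure
    'mmdeploy_model/faster-rcnn-openvino/end2end.bin',  # weights
]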
Inference by SDK
You can directly run MMDeploy demo programs in the precompiled package to get inference results.
cd mmdeploy-0.7.0-linux-x86_64-cuda11.1-tensorrt8.2.3.0
# run python demo
python sdk/example/python/object_detection.py cuda ../mmdeploy_model/faster-rcnn ../mmdetection/demo/demo.jpg
# run C/C++ demo
export LD_LIBRARY_PATH=$(pwd)/sdk/lib:$LD_LIBRARY_PATH
./sdk/bin/object_detection cuda ../mmdeploy_model/faster-rcnn ../mmdetection/demo/demo.jpg
In the above commands, the input model is the SDK Model path. It is NOT the engine file path, but the path passed to --work-dir. It includes not only the engine files but also meta information such as deploy.json and pipeline.json.
In the next section, we will provide examples of deploying the above converted Faster R-CNN model with the SDK's different FFIs (Foreign Function Interfaces).
Python API
from mmdeploy_python import Detector
import cv2
img = cv2.imread('mmdetection/demo/demo.jpg')
# create a detector
detector = Detector(model_path='mmdeploy_model/faster-rcnn', device_name='cuda', device_id=0)
# run the inference
bboxes, labels, _ = detector(img)
# filter detections by score threshold and visualize the rest
for bbox, label_id in zip(bboxes, labels):
    [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
    if score < 0.3:
        continue
    cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0))
cv2.imwrite('output_detection.png', img)
You can find more examples from here.
C++ API
Using the SDK C++ API follows a common pattern: load the model, create a predictor, read an image, run inference and then process the result. Now let's apply this procedure to the above Faster R-CNN model.
#include <cstdlib>
#include <opencv2/opencv.hpp>
#include "mmdeploy/detector.hpp"
int main() {
const char* device_name = "cuda";
int device_id = 0;
std::string model_path = "mmdeploy_model/faster-rcnn";
std::string image_path = "mmdetection/demo/demo.jpg";
// 1. load model
mmdeploy::Model model(model_path);
// 2. create predictor
mmdeploy::Detector detector(model, mmdeploy::Device{device_name, device_id});
// 3. read image
cv::Mat img = cv::imread(image_path);
// 4. inference
auto dets = detector.Apply(img);
// 5. deal with the result. Here we choose to visualize it
for (int i = 0; i < dets.size(); ++i) {
const auto& box = dets[i].bbox;
fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n",
i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score);
if (dets[i].score < 0.3) {
continue;
}
cv::rectangle(img, cv::Point{(int)box.left, (int)box.top},
cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0});
}
cv::imwrite("output_detection.png", img);
return 0;
}
When you build this example, add the MMDeploy package to your CMake project as shown below, and pass -DMMDeploy_DIR to cmake to indicate the path where MMDeployConfig.cmake is located. You can find it in the prebuilt package.
find_package(MMDeploy REQUIRED)
target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS})
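A minimal, complete CMakeLists.txt for this example might look as follows, assuming the source file is named object_detection.cpp; adjust the target name and paths to your project:
cmake_minimum_required(VERSION 3.14)
project(object_detection_demo)

find_package(OpenCV REQUIRED)
find_package(MMDeploy REQUIRED)

add_executable(object_detection object_detection.cpp)
target_link_libraries(object_detection PRIVATE mmdeploy ${OpenCV_LIBS})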
For more SDK C++ API usages, please read these samples.
For the remaining C, C# and Java API usages, please read C demos, C# demos and Java demos respectively. We'll talk about them more in our next release.
Evaluate Model
You can test the performance of the deployed model using tools/test.py. For example,
python ${MMDEPLOY_DIR}/tools/test.py \
${MMDEPLOY_DIR}/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \
${MMDET_DIR}/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
--model ${BACKEND_MODEL_FILES} \
--metrics ${METRICS} \
--device cuda:0
Regarding the --model option: it refers to the converted engine file path when using Model Converter to run the performance test, but when testing metrics with the Inference SDK, it refers to the directory path of the MMDeploy Model.
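For the Faster R-CNN TensorRT model converted above, a concrete invocation might look like this (bbox is the COCO detection metric name in MMDetection; adjust the paths to your setup):
python mmdeploy/tools/test.py \
    mmdeploy/configs/mmdet/detection/detection_tensorrt_dynamic-320x320-1344x1344.py \
    mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    --model mmdeploy_model/faster-rcnn/end2end.engine \
    --metrics bbox \
    --device cuda:0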
You can read how to evaluate a model for more details.