[Feature]: Add custom operators support for onnxruntime in mmcv (#612)

* add onnx support to roi_align and roi_pool

* add softnms ort support

* fix for lint

* format cpp code with clang-format:google

* add new empty line to the end of header files in onnxruntime

* update to pytorch1.7

* add test of softnms to onnxruntime

* fix for lint

* remove print in ops/info.py

* change import order, fix for flake8

* fix include

* add assert torch>=1.7.0

* [doc]: add document for onnxruntime custom operator

* update onnxruntime version to v1.5.1 for softnms

* remove doc menu

* Resolve lint for markdown

* resolve naming style in onnxruntime_op.md

* Use old cpp apis, optimize test_onnx.py

* Fixing strings in tests/test_ops/test_onnx.py

* code format with yapf

* fix soft_nms for parrots

* add import in onnxruntime setup, avoid conflict

* fix doc and add assert

* change cpp guard

Co-authored-by: maningsheng <maningsheng@sensetime.com>
q.yao 2020-12-23 11:03:55 +08:00 committed by GitHub
parent 8b4e5de43d
commit 94810f2297
13 changed files with 607 additions and 13 deletions


@@ -168,6 +168,10 @@ Another way is to compile locally by running
pip install mmcv-full
```
c. Install full version with custom operators for onnxruntime
- Check [here](docs/onnxruntime_op.md) for detailed instructions.
Note that the local compiling may take up to 10 mins.
If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/build.html).


@@ -0,0 +1,119 @@
# Custom operators for ONNX Runtime in MMCV
## Introduction to ONNX Runtime
**ONNX Runtime** is a cross-platform inference and training accelerator compatible with many popular ML/DNN frameworks. Check its [GitHub repository](https://github.com/microsoft/onnxruntime) for more information.
## Introduction to ONNX
**ONNX** stands for **Open Neural Network Exchange**, an open format that acts as an *Intermediate Representation (IR)* for ML/DNN models from many frameworks. Check its [GitHub repository](https://github.com/onnx/onnx) for more information.
## Why include custom operators for ONNX Runtime in MMCV
- To verify the correctness of exported ONNX models in ONNX Runtime.
- To ease the deployment of ONNX models with custom operators from `mmcv.ops` in ONNX Runtime.
## List of operators for ONNX Runtime supported in MMCV
| Operator | CPU | GPU | Note |
| :------: | :---: | :---: | :---: |
| SoftNMS | Y | N | None |
## How to build custom operators for ONNX Runtime
*Please note that only the CPU version of **onnxruntime>=1.5.1** on Linux has been tested so far.*
### Prerequisites
- Clone the repository
```bash
git clone https://github.com/open-mmlab/mmcv.git
```
- Download `onnxruntime-linux-x64-1.5.1.tgz` from ONNX Runtime [releases](https://github.com/microsoft/onnxruntime/releases/tag/v1.5.1), extract it, export `ONNXRUNTIME_DIR` and finally add the lib path to `LD_LIBRARY_PATH` as below:
```bash
wget https://github.com/microsoft/onnxruntime/releases/download/v1.5.1/onnxruntime-linux-x64-1.5.1.tgz
tar -zxvf onnxruntime-linux-x64-1.5.1.tgz
cd onnxruntime-linux-x64-1.5.1
export ONNXRUNTIME_DIR=$(pwd)
export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
```
### Build on Linux
```bash
cd mmcv # to MMCV root directory
MMCV_WITH_OPS=1 MMCV_WITH_ORT=1 pip install -e .
```
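After building, you can quickly verify that the custom-op library was produced. A minimal sketch using the helper added in this PR:
```python
from mmcv.ops import get_onnxruntime_op_path

# prints the path of the compiled _ext_ort.*.so, or an empty string if it is missing
print(get_onnxruntime_op_path())
```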
## How to run inference on exported ONNX models with custom operators in ONNX Runtime in Python
Install ONNX Runtime with `pip`
```bash
pip install onnxruntime==1.5.1
```
Inference Demo
```python
import os
import numpy as np
import onnxruntime as ort
from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
assert os.path.exists(ort_custom_op_path)
session_options = ort.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
# exported ONNX model with custom operators
onnx_file = 'sample.onnx'
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
sess = ort.InferenceSession(onnx_file, session_options)
onnx_results = sess.run(None, {'input' : input_data})
```
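For reference, an ONNX model containing a custom operator from `mmcv.ops` can be produced with `torch.onnx.export`. The sketch below is illustrative only: the wrapper module and the file name `softnms_sample.onnx` are hypothetical, it assumes `torch>=1.7.0`, and it mirrors the unit test added in this PR.
```python
import torch
from mmcv.ops import soft_nms


class SoftNMSWrapper(torch.nn.Module):
    # thin wrapper so torch.onnx.export can trace soft_nms with fixed attributes
    def forward(self, boxes, scores):
        return soft_nms(boxes, scores, iou_threshold=0.3, sigma=0.5,
                        min_score=0.01, method='linear')


boxes = torch.tensor([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0]])
scores = torch.tensor([0.6, 0.9])
torch.onnx.export(
    SoftNMSWrapper(), (boxes, scores), 'softnms_sample.onnx',
    input_names=['boxes', 'scores'], opset_version=11)
```
The exported file can then be loaded with the `SessionOptions` shown above, using `boxes` and `scores` as input names.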
## How to add a new custom operator for ONNX Runtime in MMCV
### Reminder
- The custom operator is not included in [supported operator list](https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md) in ONNX Runtime.
- The custom operator should be able to be exported to ONNX.
### Main procedures
Take the custom operator `soft_nms` as an example.
1. Add the header `soft_nms.h` to the ONNX Runtime include directory `mmcv/ops/csrc/onnxruntime/`
2. Add the source `soft_nms.cpp` to the ONNX Runtime source directory `mmcv/ops/csrc/onnxruntime/cpu/`
3. Register the `soft_nms` operator in [onnxruntime_register.cpp](../mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp)
```c++
#include "soft_nms.h"
SoftNmsOp c_SoftNmsOp;
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) {
return status;
}
```
4. Add a unit test in `tests/test_ops/test_onnx.py`
Check [here](../tests/test_ops/test_onnx.py) for examples.
**Finally, we welcome PRs that add custom operators for ONNX Runtime in MMCV.** :nerd_face:
## Known Issues
- None
## References
- [How to export a PyTorch model with a custom op to ONNX and run it in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md)
- [How to add a custom operator/kernel in ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md)


@@ -12,7 +12,8 @@ from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .info import get_compiler_version, get_compiling_cuda_version
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
@@ -33,8 +34,9 @@ __all__ = [
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
'masked_conv2d', 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',


@@ -0,0 +1,23 @@
#include "onnxruntime_register.h"
#include "ort_mmcv_utils.h"
#include "soft_nms.h"
const char *c_MMCVOpDomain = "mmcv";
SoftNmsOp c_SoftNmsOp;
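// Entry point called by ONNX Runtime when this library is registered through
// SessionOptions (e.g. register_custom_ops_library in Python): it creates the
// "mmcv" custom op domain and adds the custom ops to it.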
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
OrtCustomOpDomain *domain = nullptr;
const OrtApi *ortApi = api->GetApi(ORT_API_VERSION);
if (auto status = ortApi->CreateCustomOpDomain(c_MMCVOpDomain, &domain)) {
return status;
}
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) {
return status;
}
return ortApi->AddCustomOpDomain(options, domain);
}


@@ -0,0 +1,155 @@
#include "soft_nms.h"
#include <assert.h>
#include <algorithm>
#include <cmath>
#include "../ort_mmcv_utils.h"
SoftNmsKernel::SoftNmsKernel(OrtApi api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
sigma_ = ort_.KernelInfoGetAttribute<float>(info, "sigma");
min_score_ = ort_.KernelInfoGetAttribute<float>(info, "min_score");
method_ = ort_.KernelInfoGetAttribute<int64_t>(info, "method");
offset_ = ort_.KernelInfoGetAttribute<int64_t>(info, "offset");
// create allocator
allocator_ = Ort::AllocatorWithDefaultOptions();
}
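// Runs Soft-NMS on the CPU: inputs are boxes [N, 4] and scores [N];
// outputs are dets [M, 5] (x1, y1, x2, y2, score) and kept indices inds [M].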
void SoftNmsKernel::Compute(OrtKernelContext *context) {
typedef float T;
const T iou_threshold = T(iou_threshold_);
const T sigma = T(sigma_);
const T min_score = T(min_score_);
const int method = int(method_);
const T offset = T(offset_);
const OrtValue *boxes = ort_.KernelContext_GetInput(context, 0);
const T *boxes_data =
reinterpret_cast<const float *>(ort_.GetTensorData<T>(boxes));
const OrtValue *scores = ort_.KernelContext_GetInput(context, 1);
const T *scores_data =
reinterpret_cast<const float *>(ort_.GetTensorData<T>(scores));
OrtTensorDimensions boxes_dim(ort_, boxes);
OrtTensorDimensions scores_dim(ort_, scores);
int64_t nboxes = boxes_dim[0];
assert(boxes_dim[1] == 4);
// allocate tmp memory
T *tmp_boxes = (T *)allocator_.Alloc(sizeof(T) * nboxes * 4);
T *x1 = tmp_boxes;
T *y1 = tmp_boxes + 1;
T *x2 = tmp_boxes + 2;
T *y2 = tmp_boxes + 3;
T *sc = (T *)allocator_.Alloc(sizeof(T) * nboxes);
T *areas = (T *)allocator_.Alloc(sizeof(T) * nboxes);
T *de = (T *)allocator_.Alloc(sizeof(T) * nboxes * 5);
int64_t *inds = (int64_t *)allocator_.Alloc(sizeof(int64_t) * nboxes);
memcpy(tmp_boxes, boxes_data, sizeof(T) * nboxes * 4);
memcpy(sc, scores_data, sizeof(T) * nboxes);
// init inds as arange(nboxes)
std::generate(inds, inds + nboxes, [n = 0]() mutable { return n++; });
// area = (x2-x1+offset)*(y2-y1+offset)
for (int64_t i = 0; i < nboxes; i++) {
areas[i] =
(x2[i * 4] - x1[i * 4] + offset) * (y2[i * 4] - y1[i * 4] + offset);
}
int64_t pos = 0;
for (int64_t i = 0; i < nboxes; i++) {
auto max_score = sc[i];
auto max_pos = i;
pos = i + 1;
// get max box
while (pos < nboxes) {
if (max_score < sc[pos]) {
max_score = sc[pos];
max_pos = pos;
}
pos = pos + 1;
}
// swap
auto ix1 = de[i * 5 + 0] = x1[max_pos * 4];
auto iy1 = de[i * 5 + 1] = y1[max_pos * 4];
auto ix2 = de[i * 5 + 2] = x2[max_pos * 4];
auto iy2 = de[i * 5 + 3] = y2[max_pos * 4];
auto iscore = de[i * 5 + 4] = sc[max_pos];
auto iarea = areas[max_pos];
auto iind = inds[max_pos];
x1[max_pos * 4] = x1[i * 4];
y1[max_pos * 4] = y1[i * 4];
x2[max_pos * 4] = x2[i * 4];
y2[max_pos * 4] = y2[i * 4];
sc[max_pos] = sc[i];
areas[max_pos] = areas[i];
inds[max_pos] = inds[i];
x1[i * 4] = ix1;
y1[i * 4] = iy1;
x2[i * 4] = ix2;
y2[i * 4] = iy2;
sc[i] = iscore;
areas[i] = iarea;
inds[i] = iind;
pos = i + 1;
while (pos < nboxes) {
auto xx1 = std::max(ix1, x1[pos * 4]);
auto yy1 = std::max(iy1, y1[pos * 4]);
auto xx2 = std::min(ix2, x2[pos * 4]);
auto yy2 = std::min(iy2, y2[pos * 4]);
auto w = std::max(0.f, xx2 - xx1 + offset);
auto h = std::max(0.f, yy2 - yy1 + offset);
auto inter = w * h;
auto ovr = inter / (iarea + areas[pos] - inter);
float weight = 1.;
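// score decay: method 0 = naive (hard) NMS, 1 = linear decay, 2 = gaussian decay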
if (method == 0) {
if (ovr >= iou_threshold) weight = 0;
} else if (method == 1) {
if (ovr >= iou_threshold) weight = 1 - ovr;
} else if (method == 2) {
weight = std::exp(-(ovr * ovr) / sigma);
}
sc[pos] *= weight;
// if the box score falls below the threshold, discard the box by
// swapping it with the last box, and update N
if (sc[pos] < min_score) {
x1[pos * 4] = x1[(nboxes - 1) * 4];
y1[pos * 4] = y1[(nboxes - 1) * 4];
x2[pos * 4] = x2[(nboxes - 1) * 4];
y2[pos * 4] = y2[(nboxes - 1) * 4];
sc[pos] = sc[nboxes - 1];
areas[pos] = areas[nboxes - 1];
inds[pos] = inds[nboxes - 1];
nboxes = nboxes - 1;
pos = pos - 1;
}
pos = pos + 1;
}
}
std::vector<int64_t> dets_dim({nboxes, 5});
OrtValue *dets = ort_.KernelContext_GetOutput(context, 0, dets_dim.data(),
dets_dim.size());
T *dets_data = ort_.GetTensorMutableData<T>(dets);
std::vector<int64_t> inds_dim({nboxes});
OrtValue *inds_ov = ort_.KernelContext_GetOutput(context, 1, inds_dim.data(),
inds_dim.size());
int64_t *inds_data = ort_.GetTensorMutableData<int64_t>(inds_ov);
memcpy(dets_data, de, sizeof(T) * nboxes * 5);
memcpy(inds_data, inds, sizeof(int64_t) * nboxes);
}


@@ -0,0 +1,15 @@
#ifndef ONNXRUNTIME_REGISTER_H
#define ONNXRUNTIME_REGISTER_H
#include <onnxruntime_c_api.h>
#ifdef __cplusplus
extern "C" {
#endif
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api);
#ifdef __cplusplus
}
#endif
#endif // ONNXRUNTIME_REGISTER_H


@@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
#define ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
/*
* This file defines SessionOptions Config Keys and format of the Config Values.
*
* The Naming Convention for a SessionOptions Config Key,
* "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
* Such as "ep.cuda.use_arena"
* The Config Key cannot be empty
* The maximum length of the Config Key is 128
*
* The string format of a SessionOptions Config Value is defined individually
* for each Config. The maximum length of the Config Value is 1024
*/
// Key for disable PrePacking,
// If the config value is set to "1" then the prepacking is disabled, otherwise
// prepacking is enabled (default value)
static const char* const kOrtSessionOptionsConfigDisablePrepacking =
"session.disable_prepacking";
// A value of "1" means allocators registered in the env will be used. "0" means
// the allocators created in the session will be used. Use this to override the
// usage of env allocators on a per session level.
static const char* const kOrtSessionOptionsConfigUseEnvAllocators =
"session.use_env_allocators";
// Set to 'ORT' (case sensitive) to load an ORT format model.
// If unset, model type will default to ONNX unless inferred from filename
// ('.ort' == ORT format) or bytes to be ORT
static const char* const kOrtSessionOptionsConfigLoadModelFormat =
"session.load_model_format";
// Set to 'ORT' (case sensitive) to save optimized model in ORT format when
// SessionOptions.optimized_model_path is set. If unset, format will default to
// ONNX unless optimized_model_filepath ends in '.ort'.
static const char* const kOrtSessionOptionsConfigSaveModelFormat =
"session.save_model_format";
#endif // ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H


@@ -0,0 +1,14 @@
#ifndef ORT_MMCV_UTILS_H
#define ORT_MMCV_UTILS_H
#include <onnxruntime_cxx_api.h>
#include <vector>
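// Small helper that copies the shape of an OrtValue tensor into a std::vector<int64_t>.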
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
ort.ReleaseTensorTypeAndShapeInfo(info);
}
};
#endif // ORT_MMCV_UTILS_H


@@ -0,0 +1,48 @@
#ifndef ONNXRUNTIME_SOFT_NMS_H
#define ONNXRUNTIME_SOFT_NMS_H
#include <onnxruntime_cxx_api.h>
struct SoftNmsKernel {
SoftNmsKernel(OrtApi api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
float iou_threshold_;
float sigma_;
float min_score_;
int64_t method_;
int64_t offset_;
};
struct SoftNmsOp : Ort::CustomOpBase<SoftNmsOp, SoftNmsKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) {
return new SoftNmsKernel(api, info);
};
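// The op type below must match the symbolic name registered in mmcv/ops/nms.py
// ('mmcv::SoftNonMaxSuppression', where "mmcv" is the custom op domain).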
const char *GetName() const { return "SoftNonMaxSuppression"; };
size_t GetInputTypeCount() const { return 2; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
size_t GetOutputTypeCount() const { return 2; };
ONNXTensorElementDataType GetOutputType(size_t index) const {
if (index == 1) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
}
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
// force cpu
const char *GetExecutionProviderType() const {
return "CPUExecutionProvider";
};
};
#endif // ONNXRUNTIME_SOFT_NMS_H


@@ -1,3 +1,6 @@
import glob
import os
import torch
if torch.__version__ == 'parrots':
@@ -18,3 +21,15 @@ else:
    def get_compiling_cuda_version():
        return ext_module.get_compiling_cuda_version()


def get_onnxruntime_op_path():
    wildcard = os.path.join(
        os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
        '_ext_ort.*.so')
    paths = glob.glob(wildcard)
    if len(paths) > 0:
        return paths[0]
    else:
        return ''


@@ -39,6 +39,41 @@ class NMSop(torch.autograd.Function):
1)
class SoftNMSop(torch.autograd.Function):

    @staticmethod
    def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
                offset):
        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
        inds = ext_module.softnms(
            boxes.cpu(),
            scores.cpu(),
            dets.cpu(),
            iou_threshold=float(iou_threshold),
            sigma=float(sigma),
            min_score=float(min_score),
            method=int(method),
            offset=int(offset))
        return dets, inds

    @staticmethod
    def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
                 offset):
        from packaging import version
        assert version.parse(torch.__version__) >= version.parse('1.7.0')
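        # attribute suffixes mark ONNX attribute types (_f: float, _i: int);
        # outputs=2 makes the traced op return two values (dets, inds)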
        nms_out = g.op(
            'mmcv::SoftNonMaxSuppression',
            boxes,
            scores,
            iou_threshold_f=float(iou_threshold),
            sigma_f=float(sigma),
            min_score_f=float(min_score),
            method_i=int(method),
            offset_i=int(offset),
            outputs=2)
        return nms_out
@deprecated_api_warning({'iou_thr': 'iou_threshold'})
def nms(boxes, scores, iou_threshold, offset=0):
"""Dispatch to either CPU or GPU NMS implementations.
@@ -191,17 +226,12 @@ def soft_nms(boxes,
dets, inds, num_out = ext_module.softnms(*indata_list, **indata_dict)
inds = inds[:num_out]
else:
dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
inds = ext_module.softnms(
boxes.cpu(),
scores.cpu(),
dets.cpu(),
iou_threshold=float(iou_threshold),
sigma=float(sigma),
min_score=float(min_score),
method=method_dict[method],
offset=int(offset))
dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
float(iou_threshold), float(sigma),
float(min_score), method_dict[method],
int(offset))
dets = dets[:inds.size(0)]
if is_numpy:
dets = dets.cpu().numpy()
inds = inds.cpu().numpy()


@@ -182,6 +182,52 @@ def get_extensions():
define_macros=define_macros,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)
if EXT_TYPE == 'pytorch' and os.getenv('MMCV_WITH_ORT', '0') != '0':
ext_name = 'mmcv._ext_ort'
from torch.utils.cpp_extension import library_paths, include_paths
import onnxruntime
library_dirs = []
libraries = []
include_dirs = []
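# ONNXRUNTIME_DIR should point to the extracted onnxruntime release
# (see docs/onnxruntime_op.md)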
ort_path = os.getenv('ONNXRUNTIME_DIR', '0')
library_dirs += [os.path.join(ort_path, 'lib')]
libraries.append('onnxruntime')
kwargs = {}
define_macros = []
extra_compile_args = {'cxx': []}
include_path = os.path.abspath('./mmcv/ops/csrc/onnxruntime')
include_dirs.append(include_path)
include_dirs.append(os.path.join(ort_path, 'include'))
include_dirs += include_paths(cuda=True)
op_files = glob.glob('./mmcv/ops/csrc/onnxruntime/cpu/*')
if onnxruntime.get_device() == 'GPU' or os.getenv('FORCE_CUDA',
'0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files += glob.glob('./mmcv/ops/csrc/onnxruntime/gpu/*')
library_dirs += library_paths(cuda=True)
else:
library_dirs += library_paths(cuda=False)
kwargs['library_dirs'] = library_dirs
kwargs['libraries'] = libraries
from setuptools import Extension
ext_ops = Extension(
name=ext_name,
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
language='c++',
library_dirs=library_dirs,
libraries=libraries)
extensions.append(ext_ops)
return extensions


@@ -1,9 +1,11 @@
import os
import warnings
from functools import partial
import numpy as np
import onnx
import onnxruntime as rt
import pytest
import torch
import torch.nn as nn
@@ -58,6 +60,83 @@ def test_nms():
assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason='CUDA is unavailable for test_softnms')
def test_softnms():
from mmcv.ops import get_onnxruntime_op_path, soft_nms
from packaging import version
# only support pytorch >= 1.7.0
if version.parse(torch.__version__) < version.parse('1.7.0'):
warnings.warn('test_softnms should be run with pytorch >= 1.7.0')
return
# only support onnxruntime >= 1.5.1
assert version.parse(rt.__version__) >= version.parse(
'1.5.1'), 'test_softnms should be run with onnxruntime >= 1.5.1'
ort_custom_op_path = get_onnxruntime_op_path()
assert os.path.exists(ort_custom_op_path)
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
configs = [[0.3, 0.5, 0.01, 'linear'], [0.3, 0.5, 0.01, 'gaussian'],
[0.3, 0.5, 0.01, 'naive']]
session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
for _iou_threshold, _sigma, _min_score, _method in configs:
pytorch_dets, pytorch_inds = soft_nms(
boxes,
scores,
iou_threshold=_iou_threshold,
sigma=_sigma,
min_score=_min_score,
method=_method)
nms = partial(
soft_nms,
iou_threshold=_iou_threshold,
sigma=_sigma,
min_score=_min_score,
method=_method)
wrapped_model = WrapFunction(nms)
wrapped_model.cpu().eval()
with torch.no_grad():
torch.onnx.export(
wrapped_model, (boxes, scores),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['boxes', 'scores'],
opset_version=11)
onnx_model = onnx.load(onnx_file)
# get onnx output
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [
node.name for node in onnx_model.graph.initializer
]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 2)
sess = rt.InferenceSession(onnx_file, session_options)
onnx_dets, onnx_inds = sess.run(None, {
'scores': scores.detach().numpy(),
'boxes': boxes.detach().numpy()
})
os.remove(onnx_file)
assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3)
assert np.allclose(pytorch_inds, onnx_inds, atol=1e-3)
def test_roialign():
from mmcv.ops import roi_align