diff --git a/README.md b/README.md
index 0a4bede26..f4ab70e19 100644
--- a/README.md
+++ b/README.md
@@ -168,6 +168,10 @@ Another way is to compile locally by running
 pip install mmcv-full
 ```
 
+c. Install the full version with custom operators for ONNX Runtime
+
+- Check [here](docs/onnxruntime_op.md) for detailed instructions.
+
 Note that the local compiling may take up to 10 mins.
 
 If you would like to build MMCV from source, please refer to the [guide](https://mmcv.readthedocs.io/en/latest/build.html).
diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md
new file mode 100644
index 000000000..854b1cd37
--- /dev/null
+++ b/docs/onnxruntime_op.md
@@ -0,0 +1,119 @@
+# Custom operators for ONNX Runtime in MMCV
+
+## Introduction to ONNX Runtime
+
+**ONNX Runtime** is a cross-platform inference and training accelerator compatible with many popular ML/DNN frameworks. Check its [GitHub repository](https://github.com/microsoft/onnxruntime) for more information.
+
+## Introduction to ONNX
+
+**ONNX** stands for **Open Neural Network Exchange**, which acts as an *Intermediate Representation (IR)* for ML/DNN models from many frameworks. Check its [GitHub repository](https://github.com/onnx/onnx) for more information.
+
+## Why include custom operators for ONNX Runtime in MMCV
+
+- To verify the correctness of exported ONNX models in ONNX Runtime.
+- To ease the deployment of ONNX models with custom operators from `mmcv.ops` in ONNX Runtime.
+
+## List of operators for ONNX Runtime supported in MMCV
+
+| Operator |  CPU  |  GPU  | Note  |
+| :------: | :---: | :---: | :---: |
+| SoftNMS  |   Y   |   N   | None  |
+
+## How to build custom operators for ONNX Runtime
+
+*Please note that only the CPU version of **onnxruntime>=1.5.1** on Linux has been tested so far.*
+
+### Prerequisites
+
+- Clone the repository
+
+```bash
+git clone https://github.com/open-mmlab/mmcv.git
+```
+
+- Download `onnxruntime-linux-x64-1.5.1.tgz` from ONNX Runtime [releases](https://github.com/microsoft/onnxruntime/releases/tag/v1.5.1), extract it, set `ONNXRUNTIME_DIR`, and add the lib path to `LD_LIBRARY_PATH` as below:
+
+```bash
+wget https://github.com/microsoft/onnxruntime/releases/download/v1.5.1/onnxruntime-linux-x64-1.5.1.tgz
+
+tar -zxvf onnxruntime-linux-x64-1.5.1.tgz
+cd onnxruntime-linux-x64-1.5.1
+export ONNXRUNTIME_DIR=$(pwd)
+export LD_LIBRARY_PATH=$ONNXRUNTIME_DIR/lib:$LD_LIBRARY_PATH
+```
+
+### Build on Linux
+
+```bash
+cd mmcv  # to MMCV root directory
+MMCV_WITH_OPS=1 MMCV_WITH_ORT=1 pip install -e .
+```
+
+## How to do inference using exported ONNX models with custom operators in ONNX Runtime in Python
+
+Install ONNX Runtime with `pip`
+
+```bash
+pip install onnxruntime==1.5.1
+```
+
+Inference Demo
+
+```python
+import os
+
+import numpy as np
+import onnxruntime as ort
+
+from mmcv.ops import get_onnxruntime_op_path
+
+ort_custom_op_path = get_onnxruntime_op_path()
+assert os.path.exists(ort_custom_op_path)
+session_options = ort.SessionOptions()
+session_options.register_custom_ops_library(ort_custom_op_path)
+# exported ONNX model with custom operators
+onnx_file = 'sample.onnx'
+input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
+sess = ort.InferenceSession(onnx_file, session_options)
+onnx_results = sess.run(None, {'input': input_data})
+```
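+
+The demo above assumes that `sample.onnx` already contains MMCV custom operators. One way to produce such a file (PyTorch >= 1.7.0 is required) is to export a function that calls an operator from `mmcv.ops` with `torch.onnx.export`; the symbolic registered in `mmcv/ops/nms.py` then emits the `mmcv::SoftNonMaxSuppression` node. The sketch below is modeled on the soft_nms unit test added in this PR (`tests/test_ops/test_onnx.py`); the wrapper module and the box values are only illustrative:
+
+```python
+import torch
+
+from mmcv.ops import soft_nms
+
+
+class SoftNMSWrapper(torch.nn.Module):
+    """Tiny stand-in model so that soft_nms can be traced and exported."""
+
+    def forward(self, boxes, scores):
+        return soft_nms(
+            boxes, scores, iou_threshold=0.3, sigma=0.5, min_score=0.01,
+            method='linear')
+
+
+boxes = torch.tensor([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
+                      [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]])
+scores = torch.tensor([0.6, 0.9, 0.7, 0.2])
+
+with torch.no_grad():
+    torch.onnx.export(
+        SoftNMSWrapper().cpu().eval(), (boxes, scores),
+        'sample.onnx',
+        input_names=['boxes', 'scores'],
+        opset_version=11)
+```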
+
+## How to add a new custom operator for ONNX Runtime in MMCV
+
+### Reminder
+
+- The custom operator is not included in the [supported operator list](https://github.com/microsoft/onnxruntime/blob/master/docs/OperatorKernels.md) of ONNX Runtime.
+- The custom operator should be exportable to ONNX.
+
+### Main procedures
+
+Take the custom operator `soft_nms` as an example; the PyTorch-side counterpart needed for export is sketched after the list below.
+
+1. Add the header `soft_nms.h` to the ONNX Runtime include directory `mmcv/ops/csrc/onnxruntime/`
+2. Add the source `soft_nms.cpp` to the ONNX Runtime source directory `mmcv/ops/csrc/onnxruntime/cpu/`
+3. Register the `soft_nms` operator in [onnxruntime_register.cpp](../mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp)
+
+   ```c++
+   #include "soft_nms.h"
+
+   SoftNmsOp c_SoftNmsOp;
+
+   if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) {
+     return status;
+   }
+   ```
+
+4. Add a unit test in `tests/test_ops/test_onnx.py`.
+   Check [here](../tests/test_ops/test_onnx.py) for examples.
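+
+Besides the C++ kernel, the operator usually needs a PyTorch-side `torch.autograd.Function` with a `symbolic` method so that `torch.onnx.export` emits the custom node under the `mmcv` domain; in this PR that role is played by `SoftNMSop` in `mmcv/ops/nms.py`. The sketch below is only an illustration of the pattern (the operator name, attributes, and the toy computation are placeholders):
+
+```python
+import torch
+
+
+class MyOpFunction(torch.autograd.Function):
+    """Illustrative pattern; see SoftNMSop in mmcv/ops/nms.py for the real example."""
+
+    @staticmethod
+    def forward(ctx, input, scale):
+        # in MMCV this would call the compiled extension (ext_module.*);
+        # a trivial computation stands in for it here
+        return input * scale
+
+    @staticmethod
+    def symbolic(g, input, scale):
+        # the '_f' suffix marks a float ONNX attribute ('_i' marks an int one);
+        # 'mmcv::MyOp' must match GetName() of the C++ op and the "mmcv" domain
+        # registered in onnxruntime_register.cpp
+        return g.op('mmcv::MyOp', input, scale_f=float(scale))
+```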
+
+**Finally, you are welcome to send us a PR adding custom operators for ONNX Runtime in MMCV.** :nerd_face:
+
+## Known Issues
+
+- None
+
+## References
+
+- [How to export Pytorch model with custom op to ONNX and run it in ONNX Runtime](https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md)
+- [How to add a custom operator/kernel in ONNX Runtime](https://github.com/microsoft/onnxruntime/blob/master/docs/AddingCustomOp.md)
diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py
index 12030d856..b36315611 100644
--- a/mmcv/ops/__init__.py
+++ b/mmcv/ops/__init__.py
@@ -12,7 +12,8 @@ from .deprecated_wrappers import Linear_deprecated as Linear
 from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
 from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
                          sigmoid_focal_loss, softmax_focal_loss)
-from .info import get_compiler_version, get_compiling_cuda_version
+from .info import (get_compiler_version, get_compiling_cuda_version,
+                   get_onnxruntime_op_path)
 from .masked_conv import MaskedConv2d, masked_conv2d
 from .modulated_deform_conv import (ModulatedDeformConv2d,
                                     ModulatedDeformConv2dPack,
@@ -33,8 +34,9 @@ __all__ = [
     'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
     'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
     'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
-    'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
-    'masked_conv2d', 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+    'get_compiler_version', 'get_compiling_cuda_version',
+    'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+    'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
     'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
     'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
     'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
new file mode 100644
index 000000000..a01be6995
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -0,0 +1,23 @@
+#include "onnxruntime_register.h"
+
+#include "ort_mmcv_utils.h"
+#include "soft_nms.h"
+
+const char *c_MMCVOpDomain = "mmcv";
+SoftNmsOp c_SoftNmsOp;
+
+OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
+                                          const OrtApiBase *api) {
+  OrtCustomOpDomain *domain = nullptr;
+  const OrtApi *ortApi = api->GetApi(ORT_API_VERSION);
+
+  if (auto status = ortApi->CreateCustomOpDomain(c_MMCVOpDomain, &domain)) {
+    return status;
+  }
+
+  if (auto status = ortApi->CustomOpDomain_Add(domain, &c_SoftNmsOp)) {
+    return status;
+  }
+
+  return ortApi->AddCustomOpDomain(options, domain);
+}
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/soft_nms.cpp b/mmcv/ops/csrc/onnxruntime/cpu/soft_nms.cpp
new file mode 100644
index 000000000..efab504f5
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/cpu/soft_nms.cpp
@@ -0,0 +1,155 @@
+#include "soft_nms.h"
+
+#include <assert.h>
+
+#include <algorithm>
+#include <cmath>
+
+#include "../ort_mmcv_utils.h"
+
+SoftNmsKernel::SoftNmsKernel(OrtApi api, const OrtKernelInfo *info)
+    : api_(api), ort_(api_), info_(info) {
+  iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
+  sigma_ = ort_.KernelInfoGetAttribute<float>(info, "sigma");
+  min_score_ = ort_.KernelInfoGetAttribute<float>(info, "min_score");
+  method_ = ort_.KernelInfoGetAttribute<int64_t>(info, "method");
+  offset_ = ort_.KernelInfoGetAttribute<int64_t>(info, "offset");
+
+  // create allocator
+  allocator_ = Ort::AllocatorWithDefaultOptions();
+}
+
+void SoftNmsKernel::Compute(OrtKernelContext *context) {
+  typedef float T;
+
+  const T iou_threshold = T(iou_threshold_);
+  const T sigma = T(sigma_);
+  const T min_score = T(min_score_);
+  const int method = int(method_);
+  const T offset = T(offset_);
+
+  const OrtValue *boxes = ort_.KernelContext_GetInput(context, 0);
+  const T *boxes_data =
+      reinterpret_cast<const T *>(ort_.GetTensorData<T>(boxes));
+  const OrtValue *scores = ort_.KernelContext_GetInput(context, 1);
+  const T *scores_data =
+      reinterpret_cast<const T *>(ort_.GetTensorData<T>(scores));
+
+  OrtTensorDimensions boxes_dim(ort_, boxes);
+  OrtTensorDimensions scores_dim(ort_, scores);
+
+  int64_t nboxes = boxes_dim[0];
+  assert(boxes_dim[1] == 4);
+
+  // allocate tmp memory
+  T *tmp_boxes = (T *)allocator_.Alloc(sizeof(T) * nboxes * 4);
+  T *x1 = tmp_boxes;
+  T *y1 = tmp_boxes + 1;
+  T *x2 = tmp_boxes + 2;
+  T *y2 = tmp_boxes + 3;
+  T *sc = (T *)allocator_.Alloc(sizeof(T) * nboxes);
+  T *areas = (T *)allocator_.Alloc(sizeof(T) * nboxes);
+  T *de = (T *)allocator_.Alloc(sizeof(T) * nboxes * 5);
+  int64_t *inds = (int64_t *)allocator_.Alloc(sizeof(int64_t) * nboxes);
+
+  memcpy(tmp_boxes, boxes_data, sizeof(T) * nboxes * 4);
+  memcpy(sc, scores_data, sizeof(T) * nboxes);
+
+  // init inds as arange(nboxes)
+  std::generate(inds, inds + nboxes, [n = 0]() mutable { return n++; });
+
+  // area = (x2-x1+offset)*(y2-y1+offset)
+  for (int64_t i = 0; i < nboxes; i++) {
+    areas[i] =
+        (x2[i * 4] - x1[i * 4] + offset) * (y2[i * 4] - y1[i * 4] + offset);
+  }
+
+  int64_t pos = 0;
+
+  for (int64_t i = 0; i < nboxes; i++) {
+    auto max_score = sc[i];
+    auto max_pos = i;
+
+    pos = i + 1;
+    // get max box
+    while (pos < nboxes) {
+      if (max_score < sc[pos]) {
+        max_score = sc[pos];
+        max_pos = pos;
+      }
+      pos = pos + 1;
+    }
+    // swap
+    auto ix1 = de[i * 5 + 0] = x1[max_pos * 4];
+    auto iy1 = de[i * 5 + 1] = y1[max_pos * 4];
+    auto ix2 = de[i * 5 + 2] = x2[max_pos * 4];
+    auto iy2 = de[i * 5 + 3] = y2[max_pos * 4];
+    auto iscore = de[i * 5 + 4] = sc[max_pos];
+    auto iarea = areas[max_pos];
+    auto iind = inds[max_pos];
+    x1[max_pos * 4] = x1[i * 4];
+    y1[max_pos * 4] = y1[i * 4];
+    x2[max_pos * 4] = x2[i * 4];
+    y2[max_pos * 4] = y2[i * 4];
+    sc[max_pos] = sc[i];
+    areas[max_pos] = areas[i];
+    inds[max_pos] = inds[i];
+    x1[i * 4] = ix1;
+    y1[i * 4] = iy1;
+    x2[i * 4] = ix2;
+    y2[i * 4] = iy2;
+    sc[i] = iscore;
+    areas[i] = iarea;
+    inds[i] = iind;
+
+    pos = i + 1;
+    while (pos < nboxes) {
+      auto xx1 = std::max(ix1, x1[pos * 4]);
+      auto yy1 = std::max(iy1, y1[pos * 4]);
+      auto xx2 = std::min(ix2, x2[pos * 4]);
+      auto yy2 = std::min(iy2, y2[pos * 4]);
+
+      auto w = std::max(0.f, xx2 - xx1 + offset);
+      auto h = std::max(0.f, yy2 - yy1 + offset);
+      auto inter = w * h;
+      auto ovr = inter / (iarea + areas[pos] - inter);
+
+      float weight = 1.;
+      if (method == 0) {
+        if (ovr >= iou_threshold) weight = 0;
+      } else if (method == 1) {
+        if (ovr >= iou_threshold) weight = 1 - ovr;
+      } else if (method == 2) {
+        weight = std::exp(-(ovr * ovr) / sigma);
+      }
+      sc[pos] *= weight;
+      // if box score falls below threshold, discard the box by
+      // swapping with last box update N
+      if (sc[pos] < min_score) {
+        x1[pos * 4] = x1[(nboxes - 1) * 4];
+        y1[pos * 4] = y1[(nboxes - 1) * 4];
+        x2[pos * 4] = x2[(nboxes - 1) * 4];
+        y2[pos * 4] = y2[(nboxes - 1) * 4];
+        sc[pos] = sc[nboxes - 1];
+        areas[pos] = areas[nboxes - 1];
+        inds[pos] = inds[nboxes - 1];
+        nboxes = nboxes - 1;
+        pos = pos - 1;
+      }
+      pos = pos + 1;
+    }
+  }
+
+  std::vector<int64_t> dets_dim({nboxes, 5});
+  OrtValue *dets = ort_.KernelContext_GetOutput(context, 0, dets_dim.data(),
+                                                dets_dim.size());
+  T *dets_data = ort_.GetTensorMutableData<T>(dets);
+
+  std::vector<int64_t> inds_dim({nboxes});
+  OrtValue *inds_ov = ort_.KernelContext_GetOutput(context, 1, inds_dim.data(),
+                                                   inds_dim.size());
+  int64_t *inds_data = ort_.GetTensorMutableData<int64_t>(inds_ov);
+
+  memcpy(dets_data, de, sizeof(T) * nboxes * 5);
+  memcpy(inds_data, inds, sizeof(int64_t) * nboxes);
+}
diff --git a/mmcv/ops/csrc/onnxruntime/onnxruntime_register.h b/mmcv/ops/csrc/onnxruntime/onnxruntime_register.h
new file mode 100644
index 000000000..175071a71
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/onnxruntime_register.h
@@ -0,0 +1,15 @@
+#ifndef ONNXRUNTIME_REGISTER_H
+#define ONNXRUNTIME_REGISTER_H
+#include <onnxruntime_c_api.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
+                                          const OrtApiBase *api);
+
+#ifdef __cplusplus
+}
+#endif
+#endif  // ONNXRUNTIME_REGISTER_H
diff --git a/mmcv/ops/csrc/onnxruntime/onnxruntime_session_options_config_keys.h b/mmcv/ops/csrc/onnxruntime/onnxruntime_session_options_config_keys.h
new file mode 100644
index 000000000..8e8dbf4bd
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/onnxruntime_session_options_config_keys.h
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#ifndef ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
+#define ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
+
+/*
+ * This file defines SessionOptions Config Keys and format of the Config Values.
+ *
+ * The Naming Convention for a SessionOptions Config Key,
+ * "[Area][.[SubArea1].[SubArea2]...].[Keyname]"
+ * Such as "ep.cuda.use_arena"
+ * The Config Key cannot be empty
+ * The maximum length of the Config Key is 128
+ *
+ * The string format of a SessionOptions Config Value is defined individually
+ * for each Config. The maximum length of the Config Value is 1024
+ */
+
+// Key for disable PrePacking,
+// If the config value is set to "1" then the prepacking is disabled, otherwise
+// prepacking is enabled (default value)
+static const char* const kOrtSessionOptionsConfigDisablePrepacking =
+    "session.disable_prepacking";
+
+// A value of "1" means allocators registered in the env will be used. "0" means
+// the allocators created in the session will be used. Use this to override the
+// usage of env allocators on a per session level.
+static const char* const kOrtSessionOptionsConfigUseEnvAllocators =
+    "session.use_env_allocators";
+
+// Set to 'ORT' (case sensitive) to load an ORT format model.
+// If unset, model type will default to ONNX unless inferred from filename
+// ('.ort' == ORT format) or bytes to be ORT
+static const char* const kOrtSessionOptionsConfigLoadModelFormat =
+    "session.load_model_format";
+
+// Set to 'ORT' (case sensitive) to save optimized model in ORT format when
+// SessionOptions.optimized_model_path is set. If unset, format will default to
+// ONNX unless optimized_model_filepath ends in '.ort'.
+static const char* const kOrtSessionOptionsConfigSaveModelFormat =
+    "session.save_model_format";
+
+#endif  // ONNXRUNTIME_SESSION_OPTIONS_CONFIG_KEYS_H
diff --git a/mmcv/ops/csrc/onnxruntime/ort_mmcv_utils.h b/mmcv/ops/csrc/onnxruntime/ort_mmcv_utils.h
new file mode 100644
index 000000000..3bab9d637
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/ort_mmcv_utils.h
@@ -0,0 +1,14 @@
+#ifndef ORT_MMCV_UTILS_H
+#define ORT_MMCV_UTILS_H
+#include <onnxruntime_cxx_api.h>
+
+#include <vector>
+
+struct OrtTensorDimensions : std::vector<int64_t> {
+  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
+    OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
+    std::vector<int64_t>::operator=(ort.GetTensorShape(info));
+    ort.ReleaseTensorTypeAndShapeInfo(info);
+  }
+};
+#endif  // ORT_MMCV_UTILS_H
diff --git a/mmcv/ops/csrc/onnxruntime/soft_nms.h b/mmcv/ops/csrc/onnxruntime/soft_nms.h
new file mode 100644
index 000000000..3b1b3c02a
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/soft_nms.h
@@ -0,0 +1,48 @@
+#ifndef ONNXRUNTIME_SOFT_NMS_H
+#define ONNXRUNTIME_SOFT_NMS_H
+#include <onnxruntime_cxx_api.h>
+
+struct SoftNmsKernel {
+  SoftNmsKernel(OrtApi api, const OrtKernelInfo *info);
+
+  void Compute(OrtKernelContext *context);
+
+ protected:
+  OrtApi api_;
+  Ort::CustomOpApi ort_;
+  const OrtKernelInfo *info_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  float iou_threshold_;
+  float sigma_;
+  float min_score_;
+  int64_t method_;
+  int64_t offset_;
+};
+
+struct SoftNmsOp : Ort::CustomOpBase<SoftNmsOp, SoftNmsKernel> {
+  void *CreateKernel(OrtApi api, const OrtKernelInfo *info) {
+    return new SoftNmsKernel(api, info);
+  };
+
+  const char *GetName() const { return "SoftNonMaxSuppression"; };
+
+  size_t GetInputTypeCount() const { return 2; };
+  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  size_t GetOutputTypeCount() const { return 2; };
+  ONNXTensorElementDataType GetOutputType(size_t index) const {
+    if (index == 1) {
+      return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+    }
+    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+  };
+
+  // force cpu
+  const char *GetExecutionProviderType() const {
+    return "CPUExecutionProvider";
+  };
+};
+#endif  // ONNXRUNTIME_SOFT_NMS_H
diff --git a/mmcv/ops/info.py b/mmcv/ops/info.py
index 912acf46e..01d9a6fdd 100644
--- a/mmcv/ops/info.py
+++ b/mmcv/ops/info.py
@@ -1,3 +1,6 @@
+import glob
+import os
+
 import torch
 
 if torch.__version__ == 'parrots':
@@ -18,3 +21,15 @@ else:
 
     def get_compiling_cuda_version():
         return ext_module.get_compiling_cuda_version()
+
+
+def get_onnxruntime_op_path():
+    wildcard = os.path.join(
+        os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
+        '_ext_ort.*.so')
+
+    paths = glob.glob(wildcard)
+    if len(paths) > 0:
+        return paths[0]
+    else:
+        return ''
diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py
index 1b4bd2bef..69da580eb 100644
--- a/mmcv/ops/nms.py
+++ b/mmcv/ops/nms.py
@@ -39,6 +39,41 @@ class NMSop(torch.autograd.Function):
                                 1)
 
 
+class SoftNMSop(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+                offset):
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        inds = ext_module.softnms(
+            boxes.cpu(),
+            scores.cpu(),
+            dets.cpu(),
+            iou_threshold=float(iou_threshold),
+            sigma=float(sigma),
+            min_score=float(min_score),
+            method=int(method),
+            offset=int(offset))
+        return dets, inds
+
+    @staticmethod
+    def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+                 offset):
+        from packaging import version
+        assert version.parse(torch.__version__) >= version.parse('1.7.0')
+        nms_out = g.op(
+            'mmcv::SoftNonMaxSuppression',
+            boxes,
+            scores,
+            iou_threshold_f=float(iou_threshold),
+            sigma_f=float(sigma),
+            min_score_f=float(min_score),
+            method_i=int(method),
+            offset_i=int(offset),
+            outputs=2)
+        return nms_out
+
+
 @deprecated_api_warning({'iou_thr': 'iou_threshold'})
 def nms(boxes, scores, iou_threshold, offset=0):
     """Dispatch to either CPU or GPU NMS implementations.
@@ -191,17 +226,12 @@ def soft_nms(boxes,
         dets, inds, num_out = ext_module.softnms(*indata_list, **indata_dict)
         inds = inds[:num_out]
     else:
-        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
-        inds = ext_module.softnms(
-            boxes.cpu(),
-            scores.cpu(),
-            dets.cpu(),
-            iou_threshold=float(iou_threshold),
-            sigma=float(sigma),
-            min_score=float(min_score),
-            method=method_dict[method],
-            offset=int(offset))
+        dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+                                     float(iou_threshold), float(sigma),
+                                     float(min_score), method_dict[method],
+                                     int(offset))
     dets = dets[:inds.size(0)]
+
     if is_numpy:
         dets = dets.cpu().numpy()
         inds = inds.cpu().numpy()
diff --git a/setup.py b/setup.py
index 957c902ec..30e8472df 100644
--- a/setup.py
+++ b/setup.py
@@ -182,6 +182,52 @@ def get_extensions():
             define_macros=define_macros,
             extra_compile_args=extra_compile_args)
         extensions.append(ext_ops)
+
+    if EXT_TYPE == 'pytorch' and os.getenv('MMCV_WITH_ORT', '0') != '0':
+        ext_name = 'mmcv._ext_ort'
+        from torch.utils.cpp_extension import library_paths, include_paths
+        import onnxruntime
+        library_dirs = []
+        libraries = []
+        include_dirs = []
+        ort_path = os.getenv('ONNXRUNTIME_DIR', '0')
+        library_dirs += [os.path.join(ort_path, 'lib')]
+        libraries.append('onnxruntime')
+        kwargs = {}
+        define_macros = []
+        extra_compile_args = {'cxx': []}
+
+        include_path = os.path.abspath('./mmcv/ops/csrc/onnxruntime')
+        include_dirs.append(include_path)
+        include_dirs.append(os.path.join(ort_path, 'include'))
+        include_dirs += include_paths(cuda=True)
+
+        op_files = glob.glob('./mmcv/ops/csrc/onnxruntime/cpu/*')
+        if onnxruntime.get_device() == 'GPU' or os.getenv('FORCE_CUDA',
+                                                          '0') == '1':
+            define_macros += [('MMCV_WITH_CUDA', None)]
+            cuda_args = os.getenv('MMCV_CUDA_ARGS')
+            extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
+            op_files += glob.glob('./mmcv/ops/csrc/onnxruntime/gpu/*')
+            library_dirs += library_paths(cuda=True)
+        else:
+            library_dirs += library_paths(cuda=False)
+
+        kwargs['library_dirs'] = library_dirs
+        kwargs['libraries'] = libraries
+
+        from setuptools import Extension
+        ext_ops = Extension(
+            name=ext_name,
+            sources=op_files,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+            language='c++',
+            library_dirs=library_dirs,
+            libraries=libraries)
+        extensions.append(ext_ops)
+
     return extensions
diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py
index 7f43a06c8..7e250106b 100644
--- a/tests/test_ops/test_onnx.py
+++ b/tests/test_ops/test_onnx.py
@@ -1,9 +1,11 @@
 import os
+import warnings
 from functools import partial
 
 import numpy as np
 import onnx
 import onnxruntime as rt
+import pytest
 import torch
 import torch.nn as nn
 
@@ -58,6 +60,83 @@ def test_nms():
     assert np.allclose(pytorch_score, onnx_score, atol=1e-3)
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason='CUDA is unavailable for test_softnms')
+def test_softnms():
+    from mmcv.ops import get_onnxruntime_op_path, soft_nms
+    from packaging import version
+
+    # only support pytorch >= 1.7.0
+    if version.parse(torch.__version__) < version.parse('1.7.0'):
+        warnings.warn('test_softnms should be run with pytorch >= 1.7.0')
+        return
+
+    # only support onnxruntime >= 1.5.1
+    assert version.parse(rt.__version__) >= version.parse(
+        '1.5.1'), 'test_softnms should be run with onnxruntime >= 1.5.1'
+
+    ort_custom_op_path = get_onnxruntime_op_path()
+    assert os.path.exists(ort_custom_op_path)
+
+    np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
+                         [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
+                        dtype=np.float32)
+    np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
+
+    boxes = torch.from_numpy(np_boxes)
+    scores = torch.from_numpy(np_scores)
+
+    configs = [[0.3, 0.5, 0.01, 'linear'], [0.3, 0.5, 0.01, 'gaussian'],
+               [0.3, 0.5, 0.01, 'naive']]
+
+    session_options = rt.SessionOptions()
+    session_options.register_custom_ops_library(ort_custom_op_path)
+
+    for _iou_threshold, _sigma, _min_score, _method in configs:
+        pytorch_dets, pytorch_inds = soft_nms(
+            boxes,
+            scores,
+            iou_threshold=_iou_threshold,
+            sigma=_sigma,
+            min_score=_min_score,
+            method=_method)
+        nms = partial(
+            soft_nms,
+            iou_threshold=_iou_threshold,
+            sigma=_sigma,
+            min_score=_min_score,
+            method=_method)
+
+        wrapped_model = WrapFunction(nms)
+        wrapped_model.cpu().eval()
+        with torch.no_grad():
+            torch.onnx.export(
+                wrapped_model, (boxes, scores),
+                onnx_file,
+                export_params=True,
+                keep_initializers_as_inputs=True,
+                input_names=['boxes', 'scores'],
+                opset_version=11)
+        onnx_model = onnx.load(onnx_file)
+
+        # get onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert (len(net_feed_input) == 2)
+        sess = rt.InferenceSession(onnx_file, session_options)
+        onnx_dets, onnx_inds = sess.run(None, {
+            'scores': scores.detach().numpy(),
+            'boxes': boxes.detach().numpy()
+        })
+        os.remove(onnx_file)
+        assert np.allclose(pytorch_dets, onnx_dets, atol=1e-3)
+        assert np.allclose(pytorch_inds, onnx_inds, atol=1e-3)
+
+
 def test_roialign():
     from mmcv.ops import roi_align