mirror of https://github.com/open-mmlab/mmcv.git
[Feature]: Support corner_pool related custom operators for onnxruntime in mmcv (#997)
* supports for onnxruntime custom op `mmcv::MMCVTopPool` * supports for onnxruntime custom op `mmcv::MMCVCornerPool`, involving TopPool, BottomPool, LeftPool and RightPool * add unittest for corner_pool * supports mmcv::CornerPool without memcpy * add docs for mmcv::CornerPool * re-add docs for mmcv::CornerPool * fix output dtype doc * reformat * format with pre-commit * format * fix lint error, by using google clang-format style for c/c++pull/914/head
parent
3f8e985b02
commit
db6b0542c7
|
@ -27,6 +27,12 @@
|
|||
- [Inputs](#inputs-3)
|
||||
- [Outputs](#outputs-3)
|
||||
- [Type Constraints](#type-constraints-3)
|
||||
- [CornerPool](#cornerpool)
|
||||
- [Description](#description-4)
|
||||
- [Parameters](#parameters-4)
|
||||
- [Inputs](#inputs-4)
|
||||
- [Outputs](#outputs-4)
|
||||
- [Type Constraints](#type-constraints-4)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -171,3 +177,33 @@ Perform sample from `input` with pixel locations from `grid`.
|
|||
### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
## CornerPool
|
||||
|
||||
### Description
|
||||
|
||||
Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Type | Parameter | Description |
|
||||
| ------- | --------------- | ---------------------------------------------------------------- |
|
||||
| `int` | `mode` | corner pool mode, (0: `top`, 1: `bottom`, 2: `left`, 3: `right`) |
|
||||
|
||||
### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>input</tt>: T</dt>
|
||||
<dd>Input features. 4-D tensor of shape (N, C, H, W). N is the batch size.</dd>
|
||||
</dl>
|
||||
|
||||
### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>output</tt>: T</dt>
|
||||
<dd>Output the pooled features. 4-D tensor of shape (N, C, H, W).</dd>
|
||||
</dl>
|
||||
|
||||
### Type Constraints
|
||||
|
||||
- T:tensor(float32)
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
|
||||
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
|
||||
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
|
||||
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
|
||||
|
||||
## How to build custom operators for ONNX Runtime
|
||||
|
||||
|
|
|
@ -10,9 +10,17 @@ ext_module = ext_loader.load_ext('_ext', [
|
|||
'right_pool_forward', 'right_pool_backward'
|
||||
])
|
||||
|
||||
_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
|
||||
|
||||
|
||||
class TopPoolFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input):
|
||||
output = g.op(
|
||||
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.top_pool_forward(input)
|
||||
|
@ -28,6 +36,12 @@ class TopPoolFunction(Function):
|
|||
|
||||
class BottomPoolFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input):
|
||||
output = g.op(
|
||||
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.bottom_pool_forward(input)
|
||||
|
@ -43,6 +57,12 @@ class BottomPoolFunction(Function):
|
|||
|
||||
class LeftPoolFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input):
|
||||
output = g.op(
|
||||
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.left_pool_forward(input)
|
||||
|
@ -58,6 +78,12 @@ class LeftPoolFunction(Function):
|
|||
|
||||
class RightPoolFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, input):
|
||||
output = g.op(
|
||||
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
output = ext_module.right_pool_forward(input)
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
#ifndef ONNXRUNTIME_CORNER_POOL_H
|
||||
#define ONNXRUNTIME_CORNER_POOL_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
struct MMCVCornerPoolKernel {
|
||||
public:
|
||||
MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
|
||||
: ort_(ort) {
|
||||
mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "mode");
|
||||
}
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
Ort::CustomOpApi ort_;
|
||||
|
||||
int64_t mode_;
|
||||
};
|
||||
|
||||
struct MMCVCornerPoolCustomOp
|
||||
: Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
|
||||
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
|
||||
return new MMCVCornerPoolKernel(api, info);
|
||||
}
|
||||
|
||||
const char* GetName() const { return "MMCVCornerPool"; }
|
||||
|
||||
size_t GetInputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetInputType(size_t) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; }
|
||||
ONNXTensorElementDataType GetOutputType(size_t) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
// force cpu
|
||||
const char* GetExecutionProviderType() const {
|
||||
return "CPUExecutionProvider";
|
||||
}
|
||||
};
|
||||
#endif // ONNXRUNTIME_CORNER_POOL_H
|
|
@ -0,0 +1,122 @@
|
|||
#include "corner_pool.h"
|
||||
|
||||
#include "../ort_mmcv_utils.h"
|
||||
|
||||
void TopPoolForwardCPU(const float *input, float *output, const int batch_size,
|
||||
const int channels, const int height, const int width) {
|
||||
for (int n = 0; n < batch_size; n++) {
|
||||
int index_n = n * channels * width * height;
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int index_n_c = index_n + c * width * height;
|
||||
for (int w = 0; w < width; w++) {
|
||||
// directly copy the most bottom value from input to output
|
||||
output[index_n_c + (height - 1) * width + w] =
|
||||
input[index_n_c + (height - 1) * width + w];
|
||||
// do top_pool
|
||||
for (int h = height - 2; h >= 0; h--) {
|
||||
output[index_n_c + h * width + w] =
|
||||
std::max(output[index_n_c + (h + 1) * width + w],
|
||||
input[index_n_c + h * width + w]);
|
||||
} // for h
|
||||
} // for w
|
||||
} // for c
|
||||
} // for n
|
||||
}
|
||||
|
||||
void BottomPoolForwardCPU(const float *input, float *output,
|
||||
const int batch_size, const int channels,
|
||||
const int height, const int width) {
|
||||
for (int n = 0; n < batch_size; n++) {
|
||||
int index_n = n * channels * width * height;
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int index_n_c = index_n + c * width * height;
|
||||
for (int w = 0; w < width; w++) {
|
||||
// directly copy the most top value from input to output
|
||||
output[index_n_c + w] = input[index_n_c + w];
|
||||
// do top_pool
|
||||
for (int h = 1; h < height; h++) {
|
||||
output[index_n_c + h * width + w] =
|
||||
std::max(output[index_n_c + (h - 1) * width + w],
|
||||
input[index_n_c + h * width + w]);
|
||||
} // for h
|
||||
} // for w
|
||||
} // for c
|
||||
} // for n
|
||||
}
|
||||
|
||||
void LeftPoolForwardCPU(const float *input, float *output, const int batch_size,
|
||||
const int channels, const int height, const int width) {
|
||||
for (int n = 0; n < batch_size; n++) {
|
||||
int index_n = n * channels * width * height;
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int index_n_c = index_n + c * width * height;
|
||||
for (int h = 0; h < height; h++) {
|
||||
// directly copy the most right value from input to output
|
||||
output[index_n_c + h * width + width - 1] =
|
||||
input[index_n_c + h * width + width - 1];
|
||||
// do left_pool
|
||||
for (int w = width - 2; w >= 0; w--) {
|
||||
output[index_n_c + h * width + w] =
|
||||
std::max(output[index_n_c + h * width + w + 1],
|
||||
input[index_n_c + h * width + w]);
|
||||
} // for w
|
||||
} // for h
|
||||
} // for c
|
||||
} // for n
|
||||
}
|
||||
|
||||
void RightPoolForwardCPU(const float *input, float *output,
|
||||
const int batch_size, const int channels,
|
||||
const int height, const int width) {
|
||||
for (int n = 0; n < batch_size; n++) {
|
||||
int index_n = n * channels * width * height;
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int index_n_c = index_n + c * width * height;
|
||||
for (int h = 0; h < height; h++) {
|
||||
// directly copy the most left value from input to output
|
||||
output[index_n_c + h * width] = input[index_n_c + h * width];
|
||||
// do right_pool
|
||||
for (int w = 1; w < width; w++) {
|
||||
output[index_n_c + h * width + w] =
|
||||
std::max(output[index_n_c + h * width + w - 1],
|
||||
input[index_n_c + h * width + w]);
|
||||
} // for w
|
||||
} // for h
|
||||
} // for c
|
||||
} // for n
|
||||
}
|
||||
|
||||
void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) {
|
||||
const int mode = int(mode_);
|
||||
typedef float T;
|
||||
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
|
||||
const T *input_data =
|
||||
reinterpret_cast<const float *>(ort_.GetTensorData<T>(input));
|
||||
|
||||
// get output memory
|
||||
OrtTensorDimensions out_dimensions(ort_, input);
|
||||
OrtValue *output = ort_.KernelContext_GetOutput(
|
||||
context, 0, out_dimensions.data(), out_dimensions.size());
|
||||
T *output_data = ort_.GetTensorMutableData<T>(output);
|
||||
|
||||
// 'top': 0, 'bottom': 1, 'left': 2, 'right':3
|
||||
assert(mode == 0 || mode == 1 || mode == 2 || mode == 3);
|
||||
|
||||
// do corner_pool
|
||||
int batch_size = out_dimensions.data()[0];
|
||||
int input_channels = out_dimensions.data()[1];
|
||||
int input_height = out_dimensions.data()[2];
|
||||
int input_width = out_dimensions.data()[3];
|
||||
if (mode == 0)
|
||||
TopPoolForwardCPU(input_data, output_data, batch_size, input_channels,
|
||||
input_height, input_width);
|
||||
else if (mode == 1)
|
||||
BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels,
|
||||
input_height, input_width);
|
||||
else if (mode == 2)
|
||||
LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels,
|
||||
input_height, input_width);
|
||||
else
|
||||
RightPoolForwardCPU(input_data, output_data, batch_size, input_channels,
|
||||
input_height, input_width);
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
#include "onnxruntime_register.h"
|
||||
|
||||
#include "corner_pool.h"
|
||||
#include "grid_sample.h"
|
||||
#include "nms.h"
|
||||
#include "ort_mmcv_utils.h"
|
||||
|
@ -13,6 +14,7 @@ NmsOp c_NmsOp;
|
|||
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
|
||||
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
|
||||
GridSampleOp c_GridSampleOp;
|
||||
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;
|
||||
|
||||
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
|
||||
const OrtApiBase *api) {
|
||||
|
@ -45,5 +47,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
|
|||
return status;
|
||||
}
|
||||
|
||||
if (auto status =
|
||||
ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return ortApi->AddCustomOpDomain(options, domain);
|
||||
}
|
||||
|
|
|
@ -448,3 +448,49 @@ def test_interpolate():
|
|||
if os.path.exists(onnx_file):
|
||||
os.remove(onnx_file)
|
||||
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
|
||||
def test_corner_pool(mode, opset=11):
|
||||
if torch.__version__ == 'parrots':
|
||||
pytest.skip('onnx is not supported in parrots directly')
|
||||
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
if not os.path.exists(ort_custom_op_path):
|
||||
pytest.skip('custom ops for onnxruntime are not compiled.')
|
||||
|
||||
from mmcv.ops.corner_pool import CornerPool
|
||||
|
||||
def corner_pool_func(input):
|
||||
corner_pool_module = CornerPool(mode)
|
||||
return corner_pool_module.corner_pool.apply(input)
|
||||
|
||||
wrapped_model = WrapFunction(corner_pool_func).eval()
|
||||
|
||||
input = torch.rand((2, 3, 9, 12)) # (n,c,h,w)
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
wrapped_model,
|
||||
input,
|
||||
onnx_file,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
opset_version=opset)
|
||||
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
input_all = [node.name for node in onnx_model.graph.input]
|
||||
input_initializer = [node.name for node in onnx_model.graph.initializer]
|
||||
net_feed_input = list(set(input_all) - set(input_initializer))
|
||||
assert (len(net_feed_input) == 1)
|
||||
|
||||
session_options = rt.SessionOptions()
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
sess = rt.InferenceSession(onnx_file, session_options)
|
||||
ort_result = sess.run(None, {'input': input.detach().numpy()})
|
||||
pytorch_results = wrapped_model(input.clone())
|
||||
os.remove(onnx_file)
|
||||
assert np.allclose(pytorch_results, ort_result, atol=1e-5)
|
||||
|
|
Loading…
Reference in New Issue