From 599163e6453f6d246cb74582a40440968b9a1734 Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Wed, 13 Oct 2021 13:28:24 +0100 Subject: [PATCH] Add Deformable Conv CustomOp for onnxruntime (#1343) * add onnx dcn * replace gemm with torch C++ api * small update * fix cpp clang format * pefer generic GEMM than torch ATen library * addressing comments * add ops path check --- docs/deployment/onnxruntime_custom_ops.md | 45 +++ mmcv/ops/csrc/onnxruntime/cpu/deform_conv.cpp | 263 ++++++++++++++++++ .../onnxruntime/cpu/onnxruntime_register.cpp | 6 + mmcv/ops/csrc/onnxruntime/deform_conv.h | 57 ++++ tests/test_ops/test_onnx.py | 92 +++++- 5 files changed, 457 insertions(+), 6 deletions(-) create mode 100644 mmcv/ops/csrc/onnxruntime/cpu/deform_conv.cpp create mode 100644 mmcv/ops/csrc/onnxruntime/deform_conv.h diff --git a/docs/deployment/onnxruntime_custom_ops.md b/docs/deployment/onnxruntime_custom_ops.md index ea1124cad..baaa576f6 100644 --- a/docs/deployment/onnxruntime_custom_ops.md +++ b/docs/deployment/onnxruntime_custom_ops.md @@ -51,6 +51,12 @@ - [Inputs](#inputs-7) - [Outputs](#outputs-7) - [Type Constraints](#type-constraints-7) + - [MMCVDeformConv2d](#mmcvdeformconv2d) + - [Description](#description-8) + - [Parameters](#parameters-8) + - [Inputs](#inputs-8) + - [Outputs](#outputs-8) + - [Type Constraints](#type-constraints-8) @@ -331,3 +337,42 @@ Perform Modulated Deformable Convolution on input feature, read [Deformable Conv #### Type Constraints - T:tensor(float32, Linear) + +## MMCVDeformConv2d + +### Description + +Perform Deformable Convolution on input feature, read [Deformable Convolutional Network](https://arxiv.org/abs/1703.06211) for detail. + +### Parameters + +| Type | Parameter | Description | +| -------------- | ------------------ | --------------------------------------------------------------------------------------------------------------------------------- | +| `list of ints` | `stride` | The stride of the convolving kernel. 
(sH, sW) | +| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) | +| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) | +| `int` | `deform_groups` | Groups of deformable offset. | +| `int` | `groups` | Split input into groups. `input_channel` should be divisible by the number of groups. | +| `int` | `im2col_step` | DeformableConv2d uses im2col to compute the convolution. `im2col_step` is used to split the input and offset, reducing the memory usage of the column buffer. | + +### Inputs + +
+
inputs[0]: T
+
Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, inH and inW are the height and width of the data.
+
inputs[1]: T
+
Input offset; 4-D tensor of shape (N, deformable_group* 2* kH* kW, outH, outW), where kH and kW are the height and width of weight, outH and outW are the height and width of offset and output.
+
inputs[2]: T
+
Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).
+
+ +### Outputs + +
+
outputs[0]: T
+
Output feature; 4-D tensor of shape (N, output_channel, outH, outW).
+
+ +### Type Constraints + +- T:tensor(float32, Linear) diff --git a/mmcv/ops/csrc/onnxruntime/cpu/deform_conv.cpp b/mmcv/ops/csrc/onnxruntime/cpu/deform_conv.cpp new file mode 100644 index 000000000..db1f08b51 --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/cpu/deform_conv.cpp @@ -0,0 +1,263 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "deform_conv.h" + +#include +#include + +#include "../ort_mmcv_utils.h" + +void gemm_ref_fp32_deform(const float *A, const float *B, const float *V, + const float *H, const int32_t trans_A, + const int32_t trans_B, const int32_t M, + const int32_t N, const int32_t K, const float alpha, + const float beta, float *Y) { + if (!trans_A && !trans_B) { // MK, KN; NN + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float y = 0.0f; + for (int64_t k = 0; k < K; ++k) { + y += A[m * K + k] * B[k * N + n]; + } + y *= alpha; + if (V) y += beta * V[n]; + if (H) y += beta * H[m * N + n]; + Y[m * N + n] = y; + } + } + } + if (trans_A && !trans_B) { // KM, KN; TN + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float y = 0.0f; + for (int64_t k = 0; k < K; ++k) { + y += A[k * M + m] * B[k * N + n]; + } + y *= alpha; + if (V) y += beta * V[n]; + if (H) y += beta * H[m * N + n]; + Y[m * N + n] = y; + } + } + } + if (trans_A && trans_B) { // KM, NK; TT + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float y = 0.0f; + for (int64_t k = 0; k < K; ++k) { + y += A[k * M + m] * B[n * K + k]; + } + y *= alpha; + if (V) y += beta * V[n]; + if (H) y += beta * H[m * N + n]; + Y[m * N + n] = y; + } + } + } + if (!trans_A && trans_B) { // MK, NK; NT + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float y = 0.0f; + for (int64_t k = 0; k < K; ++k) { + y += A[m * K + k] * B[n * K + k]; + } + y *= alpha; + if (V) y += beta * V[n]; + if (H) y += beta * H[m * N + n]; + Y[m * N + n] = y; + } + } + } +} + +float bilinear_interpolate(const float *src, const 
int64_t src_h, + const int64_t src_w, const float h, const float w) { + if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) { + return 0; + } + + int64_t h_low = floor(h); + int64_t w_low = floor(w); + int64_t h_high = h_low + 1; + int64_t w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh; + float hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low]; + float v2 = 0; + if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high]; + float v3 = 0; + if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low]; + float v4 = 0; + if (h_high <= src_h - 1 && w_high <= src_w - 1) + v4 = src[h_high * src_w + w_high]; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +void deformable_im2col(const float *input, const float *offset, + const int64_t src_h, const int64_t src_w, + const int64_t kernel_h, const int64_t kernel_w, + const int64_t pad_h, const int64_t pad_w, + const int64_t stride_h, const int64_t stride_w, + const int64_t dilation_h, const int64_t dilation_w, + const int64_t channels, const int64_t offset_groups, + const int64_t dst_h, const int64_t dst_w, + float *columns) { + const int64_t indices = channels * dst_h * dst_w; + for (int64_t index = 0; index != indices; ++index) { + const int64_t w_col = index % dst_w; + const int64_t h_col = (index / dst_w) % dst_h; + const int64_t c_im = index / (dst_w * dst_h); + const int64_t c_col = c_im * kernel_h * kernel_w; + + int64_t c_per_offset_grp = channels / offset_groups; + const int64_t grp_idx = c_im / c_per_offset_grp; + auto columns_ptr = + columns + (c_col * (dst_h * dst_w) + h_col * dst_w + w_col); + auto input_ptr = input + c_im * (src_h * src_w); + auto offset_ptr = + offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w; + + for (int64_t kh = 0; kh < kernel_h; ++kh) { + for (int64_t kw = 0; kw < kernel_w; ++kw) 
{ + const int data_offset_h_ptr = + ((2 * (kh * kernel_w + kw)) * dst_h + h_col) * dst_w + w_col; + const int data_offset_w_ptr = + ((2 * (kh * kernel_w + kw) + 1) * dst_h + h_col) * dst_w + w_col; + + const float offset_h = offset_ptr[data_offset_h_ptr]; + const float offset_w = offset_ptr[data_offset_w_ptr]; + const float ih = + (h_col * stride_h - pad_h) + kh * dilation_h + offset_h; + const float iw = + (w_col * stride_w - pad_w) + kw * dilation_w + offset_w; + *columns_ptr = bilinear_interpolate(input_ptr, src_h, src_w, ih, iw); + columns_ptr += dst_h * dst_w; + } + } + } +} + +void deformable_conv_forward( + const float *src, const float *offset, const float *filter, + const int64_t batch, const int64_t src_c, const int64_t src_h, + const int64_t src_w, const int64_t dst_c, const int64_t dst_h, + const int64_t dst_w, const int64_t group, const int64_t offset_group, + const int64_t channels, const int64_t num_output, const int64_t kernel_h, + const int64_t kernel_w, const int64_t stride_h, const int64_t stride_w, + const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h, + const int64_t dilation_w, float *columns, float *dst) { + const int64_t ic_per_gp = channels / group; + const int64_t oc_per_gp = num_output / group; + for (int64_t b = 0; b < batch; ++b) { + for (int64_t g = 0; g < group; ++g) { + deformable_im2col( + src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w, + offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w, + src_h, src_w, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, ic_per_gp, offset_group, dst_h, dst_w, + columns); + float *dst_ptr = + dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w; + + memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w); + + gemm_ref_fp32_deform( + filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, + nullptr, dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w, + ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, 
dst_ptr); + } + } +} + +MMCVDeformConvKernel::MMCVDeformConvKernel(OrtApi api, + const OrtKernelInfo *info) + : api_(api), ort_(api_), info_(info) { + std::vector stride = + ort_.KernelInfoGetAttribute>(info, "stride"); + stride_height_ = stride[0]; + stride_width_ = stride[1]; + std::vector padding = + ort_.KernelInfoGetAttribute>(info, "padding"); + padding_height_ = padding[0]; + padding_width_ = padding[1]; + std::vector dilation = + ort_.KernelInfoGetAttribute>(info, "dilation"); + dilation_height_ = dilation[0]; + dilation_width_ = dilation[1]; + deformable_group_ = + ort_.KernelInfoGetAttribute(info, "deform_groups"); + group_ = ort_.KernelInfoGetAttribute(info, "groups"); + + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); +} + +void MMCVDeformConvKernel::Compute(OrtKernelContext *context) { + const int64_t stride_height = stride_height_; + const int64_t stride_width = stride_width_; + const int64_t padding_height = padding_height_; + const int64_t padding_width = padding_width_; + const int64_t dilation_height = dilation_height_; + const int64_t dilation_width = dilation_width_; + const int64_t deformable_group = deformable_group_; + const int64_t group = group_; + + const OrtValue *input = ort_.KernelContext_GetInput(context, 0); + const float *input_data = + reinterpret_cast(ort_.GetTensorData(input)); + + const OrtValue *offset = ort_.KernelContext_GetInput(context, 1); + const float *offset_data = + reinterpret_cast(ort_.GetTensorData(offset)); + + const OrtValue *filter = ort_.KernelContext_GetInput(context, 2); + const float *filter_data = + reinterpret_cast(ort_.GetTensorData(filter)); + + OrtTensorDimensions input_dims(ort_, input); + OrtTensorDimensions filter_dims(ort_, filter); + + int64_t batch_size = input_dims[0]; + int64_t in_channels = input_dims[1]; + int64_t in_height = input_dims[2]; + int64_t in_width = input_dims[3]; + int64_t out_channels = filter_dims[0]; + int64_t kernel_height = filter_dims[2]; + int64_t 
kernel_width = filter_dims[3]; + + // get output memory + int64_t out_height = floor((in_height + 2 * padding_height - + dilation_height * (kernel_height - 1) - 1) / + stride_height + + 1); + int64_t out_width = floor( + (in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / + stride_width + + 1); + + std::vector output_dims = {batch_size, out_channels, out_height, + out_width}; + + OrtValue *output = ort_.KernelContext_GetOutput( + context, 0, output_dims.data(), output_dims.size()); + float *out_ptr = ort_.GetTensorMutableData(output); + + // allocate tmp memory + int64_t column_len = (in_channels / group) * kernel_height * kernel_width * + out_height * out_width; + float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len); + deformable_conv_forward( + input_data, offset_data, filter_data, batch_size, in_channels, in_height, + in_width, out_channels, out_height, out_width, group, deformable_group, + in_channels, out_channels, kernel_height, kernel_width, stride_height, + stride_width, padding_height, padding_width, dilation_height, + dilation_width, columns, out_ptr); +} diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp index b640f482f..ae7807223 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp @@ -2,6 +2,7 @@ #include "onnxruntime_register.h" #include "corner_pool.h" +#include "deform_conv.h" #include "grid_sample.h" #include "modulated_deform_conv.h" #include "nms.h" @@ -21,6 +22,7 @@ MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp; MMCVCumMinCustomOp c_MMCVCumMinCustomOp; MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp; MMCVModulatedDeformConvOp c_MMCVModulatedDeformConvOp; +MMCVDeformConvOp c_MMCVDeformConvOp; OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { @@ -71,5 +73,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions 
*options, return status; } + if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVDeformConvOp)) { + return status; + } + return ortApi->AddCustomOpDomain(options, domain); } diff --git a/mmcv/ops/csrc/onnxruntime/deform_conv.h b/mmcv/ops/csrc/onnxruntime/deform_conv.h new file mode 100644 index 000000000..05f324a7d --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/deform_conv.h @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ONNXRUNTIME_DEFORM_CONV_H +#define ONNXRUNTIME_DEFORM_CONV_H + +#include + +struct MMCVDeformConvKernel { + MMCVDeformConvKernel(OrtApi api, const OrtKernelInfo *info); + + void Compute(OrtKernelContext *context); + + protected: + OrtApi api_; + Ort::CustomOpApi ort_; + const OrtKernelInfo *info_; + Ort::AllocatorWithDefaultOptions allocator_; + + int64_t stride_height_; + int64_t stride_width_; + int64_t padding_height_; + int64_t padding_width_; + int64_t dilation_height_; + int64_t dilation_width_; + int64_t deformable_group_; + int64_t group_; + int64_t im2col_step_; +}; + +struct MMCVDeformConvOp + : Ort::CustomOpBase { + void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const { + return new MMCVDeformConvKernel(api, info); + } + + const char *GetName() const { return "MMCVDeformConv2d"; }; + + size_t GetInputTypeCount() const { return 3; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic( + size_t index) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + }; + + // force cpu + const char *GetExecutionProviderType() const { + return "CPUExecutionProvider"; + }; +}; +#endif diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 379f4a628..c2fc7ff3f 
100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -103,6 +103,7 @@ def test_grid_sample(mode, padding_mode, align_corners): @pytest.mark.parametrize('align_corners', [True, False]) def test_bilinear_grid_sample(align_corners): from mmcv.ops.point_sample import bilinear_grid_sample + # only support pytorch >= 1.5.0 if version.parse(torch.__version__) < version.parse('1.5.0'): pytest.skip('Only support PyTorch >= 1.5.0') @@ -245,8 +246,7 @@ def test_roialign(): if torch.__version__ == 'parrots': pytest.skip('onnx is not supported in parrots directly') try: - from mmcv.ops import roi_align - from mmcv.ops import get_onnxruntime_op_path + from mmcv.ops import get_onnxruntime_op_path, roi_align except (ImportError, ModuleNotFoundError): pytest.skip('roi_align op is not successfully compiled') @@ -318,8 +318,7 @@ def test_roialign_rotated(): if torch.__version__ == 'parrots': pytest.skip('onnx is not supported in parrots directly') try: - from mmcv.ops import roi_align_rotated - from mmcv.ops import get_onnxruntime_op_path + from mmcv.ops import get_onnxruntime_op_path, roi_align_rotated except (ImportError, ModuleNotFoundError): pytest.skip('roi_align_aligned op is not successfully compiled') @@ -648,8 +647,7 @@ def test_roll(shifts_dims_pair): reason='modulated_deform_conv2d only supports in GPU') def test_modulated_deform_conv2d(): try: - from mmcv.ops import ModulatedDeformConv2d - from mmcv.ops import get_onnxruntime_op_path + from mmcv.ops import ModulatedDeformConv2d, get_onnxruntime_op_path except (ImportError, ModuleNotFoundError): pytest.skip('modulated_deform_conv op is not successfully compiled') @@ -730,3 +728,85 @@ def test_modulated_deform_conv2d(): pytorch_output = model(input, offset, mask).cpu() # allclose assert np.allclose(pytorch_output, onnx_output, atol=1e-3) + + +@pytest.mark.skipif( + torch.__version__ == 'parrots', + reason='onnx is not supported in parrots directly') +def test_deform_conv2d(threshold=1e-3): + try: + 
from mmcv.ops import DeformConv2d, get_onnxruntime_op_path + except (ImportError, ModuleNotFoundError): + pytest.skip('deform_conv op is not successfully compiled') + + ort_custom_op_path = get_onnxruntime_op_path() + if not os.path.exists(ort_custom_op_path): + pytest.skip('custom ops for onnxruntime are not compiled.') + + # deform conv config + # modulated deform conv config + in_channels = 1 + out_channels = 64 + stride = 1 + padding = 0 + dilation = 1 + groups = 1 + deform_groups = 1 + kernel_size = 2 + input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]] + offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]], + [[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]], + [[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]], + [[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]] + offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7] + deform_weight = [[[0.4, 0.2, 0.1, 0.9]]] + + x = torch.tensor(input) + conv_offset = nn.Conv2d( + in_channels=in_channels, + out_channels=deform_groups * 2 * kernel_size * kernel_size, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True) + + conv_offset.weight.data = torch.nn.Parameter( + torch.Tensor(offset_weight).reshape(8, 1, 2, 2)) + conv_offset.bias.data = torch.nn.Parameter( + torch.Tensor(offset_bias).reshape(8)) + + offset = conv_offset(x) + + model = DeformConv2d(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, deform_groups) + + model.weight.data = torch.nn.Parameter( + torch.Tensor(deform_weight).reshape(1, 1, 2, 2)) + + with torch.no_grad(): + torch.onnx.export( + model, (x, offset), + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=['input', 'offset'], + opset_version=11) + + session_options = rt.SessionOptions() + if os.path.exists(ort_custom_op_path): + session_options.register_custom_ops_library(ort_custom_op_path) + + # compute onnx_output + sess = rt.InferenceSession(onnx_file, session_options) + onnx_output = 
sess.run( + None, { + 'input': x.cpu().detach().numpy(), + 'offset': offset.cpu().detach().numpy(), + })[0] + + # compute pytorch_output + with torch.no_grad(): + pytorch_output = model(x, offset).cpu() + # allclose + assert np.allclose(pytorch_output, onnx_output, atol=1e-3)