Add Deformable Conv CustomOp for onnxruntime (#1343)

* add onnx dcn

* replace gemm with torch C++ api

* small update

* fix cpp clang format

* pefer generic GEMM than torch ATen library

* addressing comments

* add ops path check
pull/1401/head
Eugene Liu 2021-10-13 13:28:24 +01:00 committed by GitHub
parent 48d5b095a2
commit 599163e645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 457 additions and 6 deletions

View File

@ -51,6 +51,12 @@
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
- [MMCVDeformConv2d](#mmcvdeformconv2d)
- [Description](#description-8)
- [Parameters](#parameters-8)
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)
<!-- TOC -->
@ -331,3 +337,42 @@ Perform Modulated Deformable Convolution on input feature, read [Deformable Conv
#### Type Constraints
- T:tensor(float32, Linear)
## MMCVDeformConv2d
### Description
Perform Deformable Convolution on input feature, read [Deformable Convolutional Network](https://arxiv.org/abs/1703.06211) for detail.
### Parameters
| Type | Parameter | Description |
| -------------- | ------------------ | --------------------------------------------------------------------------------------------------------------------------------- |
| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) |
| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) |
| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) |
| `int` | `deformable_group` | Groups of deformable offset. |
| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. |
| `int` | `im2col_step` | DeformableConv2d use im2col to compute convolution. im2col_step is used to split input and offset, reduce memory usage of column. |
### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the numbers of channels, inH and inW are the height and width of the data.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>Input offset; 4-D tensor of shape (N, deformable_group* 2* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
</dl>
### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
</dl>
### Type Constraints
- T:tensor(float32, Linear)

View File

@ -0,0 +1,263 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "deform_conv.h"
#include <cmath>
#include <vector>
#include "../ort_mmcv_utils.h"
void gemm_ref_fp32_deform(const float *A, const float *B, const float *V,
const float *H, const int32_t trans_A,
const int32_t trans_B, const int32_t M,
const int32_t N, const int32_t K, const float alpha,
const float beta, float *Y) {
if (!trans_A && !trans_B) { // MK, KN; NN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && !trans_B) { // KM, KN; TN
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[k * N + n];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (trans_A && trans_B) { // KM, NK; TT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[k * M + m] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
if (!trans_A && trans_B) { // MK, NK; NT
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float y = 0.0f;
for (int64_t k = 0; k < K; ++k) {
y += A[m * K + k] * B[n * K + k];
}
y *= alpha;
if (V) y += beta * V[n];
if (H) y += beta * H[m * N + n];
Y[m * N + n] = y;
}
}
}
}
float bilinear_interpolate(const float *src, const int64_t src_h,
const int64_t src_w, const float h, const float w) {
if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) {
return 0;
}
int64_t h_low = floor(h);
int64_t w_low = floor(w);
int64_t h_high = h_low + 1;
int64_t w_high = w_low + 1;
float lh = h - h_low;
float lw = w - w_low;
float hh = 1 - lh;
float hw = 1 - lw;
float v1 = 0;
if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low];
float v2 = 0;
if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high];
float v3 = 0;
if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low];
float v4 = 0;
if (h_high <= src_h - 1 && w_high <= src_w - 1)
v4 = src[h_high * src_w + w_high];
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
void deformable_im2col(const float *input, const float *offset,
const int64_t src_h, const int64_t src_w,
const int64_t kernel_h, const int64_t kernel_w,
const int64_t pad_h, const int64_t pad_w,
const int64_t stride_h, const int64_t stride_w,
const int64_t dilation_h, const int64_t dilation_w,
const int64_t channels, const int64_t offset_groups,
const int64_t dst_h, const int64_t dst_w,
float *columns) {
const int64_t indices = channels * dst_h * dst_w;
for (int64_t index = 0; index != indices; ++index) {
const int64_t w_col = index % dst_w;
const int64_t h_col = (index / dst_w) % dst_h;
const int64_t c_im = index / (dst_w * dst_h);
const int64_t c_col = c_im * kernel_h * kernel_w;
int64_t c_per_offset_grp = channels / offset_groups;
const int64_t grp_idx = c_im / c_per_offset_grp;
auto columns_ptr =
columns + (c_col * (dst_h * dst_w) + h_col * dst_w + w_col);
auto input_ptr = input + c_im * (src_h * src_w);
auto offset_ptr =
offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w;
for (int64_t kh = 0; kh < kernel_h; ++kh) {
for (int64_t kw = 0; kw < kernel_w; ++kw) {
const int data_offset_h_ptr =
((2 * (kh * kernel_w + kw)) * dst_h + h_col) * dst_w + w_col;
const int data_offset_w_ptr =
((2 * (kh * kernel_w + kw) + 1) * dst_h + h_col) * dst_w + w_col;
const float offset_h = offset_ptr[data_offset_h_ptr];
const float offset_w = offset_ptr[data_offset_w_ptr];
const float ih =
(h_col * stride_h - pad_h) + kh * dilation_h + offset_h;
const float iw =
(w_col * stride_w - pad_w) + kw * dilation_w + offset_w;
*columns_ptr = bilinear_interpolate(input_ptr, src_h, src_w, ih, iw);
columns_ptr += dst_h * dst_w;
}
}
}
}
void deformable_conv_forward(
const float *src, const float *offset, const float *filter,
const int64_t batch, const int64_t src_c, const int64_t src_h,
const int64_t src_w, const int64_t dst_c, const int64_t dst_h,
const int64_t dst_w, const int64_t group, const int64_t offset_group,
const int64_t channels, const int64_t num_output, const int64_t kernel_h,
const int64_t kernel_w, const int64_t stride_h, const int64_t stride_w,
const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h,
const int64_t dilation_w, float *columns, float *dst) {
const int64_t ic_per_gp = channels / group;
const int64_t oc_per_gp = num_output / group;
for (int64_t b = 0; b < batch; ++b) {
for (int64_t g = 0; g < group; ++g) {
deformable_im2col(
src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
src_h, src_w, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, ic_per_gp, offset_group, dst_h, dst_w,
columns);
float *dst_ptr =
dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
gemm_ref_fp32_deform(
filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns,
nullptr, dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w,
ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr);
}
}
}
MMCVDeformConvKernel::MMCVDeformConvKernel(OrtApi api,
const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
std::vector<int64_t> stride =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
stride_height_ = stride[0];
stride_width_ = stride[1];
std::vector<int64_t> padding =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
padding_height_ = padding[0];
padding_width_ = padding[1];
std::vector<int64_t> dilation =
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
dilation_height_ = dilation[0];
dilation_width_ = dilation[1];
deformable_group_ =
ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
// create allocator
allocator_ = Ort::AllocatorWithDefaultOptions();
}
void MMCVDeformConvKernel::Compute(OrtKernelContext *context) {
const int64_t stride_height = stride_height_;
const int64_t stride_width = stride_width_;
const int64_t padding_height = padding_height_;
const int64_t padding_width = padding_width_;
const int64_t dilation_height = dilation_height_;
const int64_t dilation_width = dilation_width_;
const int64_t deformable_group = deformable_group_;
const int64_t group = group_;
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
const float *input_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
const OrtValue *offset = ort_.KernelContext_GetInput(context, 1);
const float *offset_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(offset));
const OrtValue *filter = ort_.KernelContext_GetInput(context, 2);
const float *filter_data =
reinterpret_cast<const float *>(ort_.GetTensorData<float>(filter));
OrtTensorDimensions input_dims(ort_, input);
OrtTensorDimensions filter_dims(ort_, filter);
int64_t batch_size = input_dims[0];
int64_t in_channels = input_dims[1];
int64_t in_height = input_dims[2];
int64_t in_width = input_dims[3];
int64_t out_channels = filter_dims[0];
int64_t kernel_height = filter_dims[2];
int64_t kernel_width = filter_dims[3];
// get output memory
int64_t out_height = floor((in_height + 2 * padding_height -
dilation_height * (kernel_height - 1) - 1) /
stride_height +
1);
int64_t out_width = floor(
(in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) /
stride_width +
1);
std::vector<int64_t> output_dims = {batch_size, out_channels, out_height,
out_width};
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, output_dims.data(), output_dims.size());
float *out_ptr = ort_.GetTensorMutableData<float>(output);
// allocate tmp memory
int64_t column_len = (in_channels / group) * kernel_height * kernel_width *
out_height * out_width;
float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len);
deformable_conv_forward(
input_data, offset_data, filter_data, batch_size, in_channels, in_height,
in_width, out_channels, out_height, out_width, group, deformable_group,
in_channels, out_channels, kernel_height, kernel_width, stride_height,
stride_width, padding_height, padding_width, dilation_height,
dilation_width, columns, out_ptr);
}

View File

@ -2,6 +2,7 @@
#include "onnxruntime_register.h"
#include "corner_pool.h"
#include "deform_conv.h"
#include "grid_sample.h"
#include "modulated_deform_conv.h"
#include "nms.h"
@ -21,6 +22,7 @@ MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;
MMCVModulatedDeformConvOp c_MMCVModulatedDeformConvOp;
MMCVDeformConvOp c_MMCVDeformConvOp;
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
@ -71,5 +73,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
}
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVDeformConvOp)) {
return status;
}
return ortApi->AddCustomOpDomain(options, domain);
}

View File

@ -0,0 +1,57 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef ONNXRUNTIME_DEFORM_CONV_H
#define ONNXRUNTIME_DEFORM_CONV_H
#include <onnxruntime_cxx_api.h>
struct MMCVDeformConvKernel {
MMCVDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
int64_t stride_height_;
int64_t stride_width_;
int64_t padding_height_;
int64_t padding_width_;
int64_t dilation_height_;
int64_t dilation_width_;
int64_t deformable_group_;
int64_t group_;
int64_t im2col_step_;
};
struct MMCVDeformConvOp
: Ort::CustomOpBase<MMCVDeformConvOp, MMCVDeformConvKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
return new MMCVDeformConvKernel(api, info);
}
const char *GetName() const { return "MMCVDeformConv2d"; };
size_t GetInputTypeCount() const { return 3; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(
size_t index) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}
size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
};
// force cpu
const char *GetExecutionProviderType() const {
return "CPUExecutionProvider";
};
};
#endif

View File

@ -103,6 +103,7 @@ def test_grid_sample(mode, padding_mode, align_corners):
@pytest.mark.parametrize('align_corners', [True, False])
def test_bilinear_grid_sample(align_corners):
from mmcv.ops.point_sample import bilinear_grid_sample
# only support pytorch >= 1.5.0
if version.parse(torch.__version__) < version.parse('1.5.0'):
pytest.skip('Only support PyTorch >= 1.5.0')
@ -245,8 +246,7 @@ def test_roialign():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try:
from mmcv.ops import roi_align
from mmcv.ops import get_onnxruntime_op_path
from mmcv.ops import get_onnxruntime_op_path, roi_align
except (ImportError, ModuleNotFoundError):
pytest.skip('roi_align op is not successfully compiled')
@ -318,8 +318,7 @@ def test_roialign_rotated():
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')
try:
from mmcv.ops import roi_align_rotated
from mmcv.ops import get_onnxruntime_op_path
from mmcv.ops import get_onnxruntime_op_path, roi_align_rotated
except (ImportError, ModuleNotFoundError):
pytest.skip('roi_align_aligned op is not successfully compiled')
@ -648,8 +647,7 @@ def test_roll(shifts_dims_pair):
reason='modulated_deform_conv2d only supports in GPU')
def test_modulated_deform_conv2d():
try:
from mmcv.ops import ModulatedDeformConv2d
from mmcv.ops import get_onnxruntime_op_path
from mmcv.ops import ModulatedDeformConv2d, get_onnxruntime_op_path
except (ImportError, ModuleNotFoundError):
pytest.skip('modulated_deform_conv op is not successfully compiled')
@ -730,3 +728,85 @@ def test_modulated_deform_conv2d():
pytorch_output = model(input, offset, mask).cpu()
# allclose
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
@pytest.mark.skipif(
torch.__version__ == 'parrots',
reason='onnx is not supported in parrots directly')
def test_deform_conv2d(threshold=1e-3):
try:
from mmcv.ops import DeformConv2d, get_onnxruntime_op_path
except (ImportError, ModuleNotFoundError):
pytest.skip('deform_conv op is not successfully compiled')
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')
# deform conv config
# modulated deform conv config
in_channels = 1
out_channels = 64
stride = 1
padding = 0
dilation = 1
groups = 1
deform_groups = 1
kernel_size = 2
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
[[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
[[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
[[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]
x = torch.tensor(input)
conv_offset = nn.Conv2d(
in_channels=in_channels,
out_channels=deform_groups * 2 * kernel_size * kernel_size,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=True)
conv_offset.weight.data = torch.nn.Parameter(
torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
conv_offset.bias.data = torch.nn.Parameter(
torch.Tensor(offset_bias).reshape(8))
offset = conv_offset(x)
model = DeformConv2d(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, deform_groups)
model.weight.data = torch.nn.Parameter(
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
with torch.no_grad():
torch.onnx.export(
model, (x, offset),
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input', 'offset'],
opset_version=11)
session_options = rt.SessionOptions()
if os.path.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
# compute onnx_output
sess = rt.InferenceSession(onnx_file, session_options)
onnx_output = sess.run(
None, {
'input': x.cpu().detach().numpy(),
'offset': offset.cpu().detach().numpy(),
})[0]
# compute pytorch_output
with torch.no_grad():
pytorch_output = model(x, offset).cpu()
# allclose
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)