mirror of https://github.com/open-mmlab/mmcv.git
Add Deformable Conv CustomOp for onnxruntime (#1343)
* add onnx dcn * replace gemm with torch C++ api * small update * fix cpp clang format * pefer generic GEMM than torch ATen library * addressing comments * add ops path checkpull/1401/head
parent
48d5b095a2
commit
599163e645
|
@ -51,6 +51,12 @@
|
|||
- [Inputs](#inputs-7)
|
||||
- [Outputs](#outputs-7)
|
||||
- [Type Constraints](#type-constraints-7)
|
||||
- [MMCVDeformConv2d](#mmcvdeformconv2d)
|
||||
- [Description](#description-8)
|
||||
- [Parameters](#parameters-8)
|
||||
- [Inputs](#inputs-8)
|
||||
- [Outputs](#outputs-8)
|
||||
- [Type Constraints](#type-constraints-8)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
|
@ -331,3 +337,42 @@ Perform Modulated Deformable Convolution on input feature, read [Deformable Conv
|
|||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
## MMCVDeformConv2d
|
||||
|
||||
### Description
|
||||
|
||||
Perform Deformable Convolution on input feature, read [Deformable Convolutional Network](https://arxiv.org/abs/1703.06211) for detail.
|
||||
|
||||
### Parameters
|
||||
|
||||
| Type | Parameter | Description |
|
||||
| -------------- | ------------------ | --------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) |
|
||||
| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) |
|
||||
| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) |
|
||||
| `int` | `deformable_group` | Groups of deformable offset. |
|
||||
| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. |
|
||||
| `int` | `im2col_step` | DeformableConv2d use im2col to compute convolution. im2col_step is used to split input and offset, reduce memory usage of column. |
|
||||
|
||||
### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the numbers of channels, inH and inW are the height and width of the data.</dd>
|
||||
<dt><tt>inputs[1]</tt>: T</dt>
|
||||
<dd>Input offset; 4-D tensor of shape (N, deformable_group* 2* kH* kW, outH, outW), where kH and kW is the height and width of weight, outH and outW is the height and width of offset and output.</dd>
|
||||
<dt><tt>inputs[2]</tt>: T</dt>
|
||||
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
|
||||
</dl>
|
||||
|
||||
### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
|
||||
</dl>
|
||||
|
||||
### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "deform_conv.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "../ort_mmcv_utils.h"
|
||||
|
||||
void gemm_ref_fp32_deform(const float *A, const float *B, const float *V,
|
||||
const float *H, const int32_t trans_A,
|
||||
const int32_t trans_B, const int32_t M,
|
||||
const int32_t N, const int32_t K, const float alpha,
|
||||
const float beta, float *Y) {
|
||||
if (!trans_A && !trans_B) { // MK, KN; NN
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
for (int64_t n = 0; n < N; ++n) {
|
||||
float y = 0.0f;
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
y += A[m * K + k] * B[k * N + n];
|
||||
}
|
||||
y *= alpha;
|
||||
if (V) y += beta * V[n];
|
||||
if (H) y += beta * H[m * N + n];
|
||||
Y[m * N + n] = y;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (trans_A && !trans_B) { // KM, KN; TN
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
for (int64_t n = 0; n < N; ++n) {
|
||||
float y = 0.0f;
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
y += A[k * M + m] * B[k * N + n];
|
||||
}
|
||||
y *= alpha;
|
||||
if (V) y += beta * V[n];
|
||||
if (H) y += beta * H[m * N + n];
|
||||
Y[m * N + n] = y;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (trans_A && trans_B) { // KM, NK; TT
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
for (int64_t n = 0; n < N; ++n) {
|
||||
float y = 0.0f;
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
y += A[k * M + m] * B[n * K + k];
|
||||
}
|
||||
y *= alpha;
|
||||
if (V) y += beta * V[n];
|
||||
if (H) y += beta * H[m * N + n];
|
||||
Y[m * N + n] = y;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!trans_A && trans_B) { // MK, NK; NT
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
for (int64_t n = 0; n < N; ++n) {
|
||||
float y = 0.0f;
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
y += A[m * K + k] * B[n * K + k];
|
||||
}
|
||||
y *= alpha;
|
||||
if (V) y += beta * V[n];
|
||||
if (H) y += beta * H[m * N + n];
|
||||
Y[m * N + n] = y;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float bilinear_interpolate(const float *src, const int64_t src_h,
|
||||
const int64_t src_w, const float h, const float w) {
|
||||
if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t h_low = floor(h);
|
||||
int64_t w_low = floor(w);
|
||||
int64_t h_high = h_low + 1;
|
||||
int64_t w_high = w_low + 1;
|
||||
|
||||
float lh = h - h_low;
|
||||
float lw = w - w_low;
|
||||
float hh = 1 - lh;
|
||||
float hw = 1 - lw;
|
||||
|
||||
float v1 = 0;
|
||||
if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low];
|
||||
float v2 = 0;
|
||||
if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high];
|
||||
float v3 = 0;
|
||||
if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low];
|
||||
float v4 = 0;
|
||||
if (h_high <= src_h - 1 && w_high <= src_w - 1)
|
||||
v4 = src[h_high * src_w + w_high];
|
||||
|
||||
float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
||||
|
||||
float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
||||
return val;
|
||||
}
|
||||
|
||||
void deformable_im2col(const float *input, const float *offset,
|
||||
const int64_t src_h, const int64_t src_w,
|
||||
const int64_t kernel_h, const int64_t kernel_w,
|
||||
const int64_t pad_h, const int64_t pad_w,
|
||||
const int64_t stride_h, const int64_t stride_w,
|
||||
const int64_t dilation_h, const int64_t dilation_w,
|
||||
const int64_t channels, const int64_t offset_groups,
|
||||
const int64_t dst_h, const int64_t dst_w,
|
||||
float *columns) {
|
||||
const int64_t indices = channels * dst_h * dst_w;
|
||||
for (int64_t index = 0; index != indices; ++index) {
|
||||
const int64_t w_col = index % dst_w;
|
||||
const int64_t h_col = (index / dst_w) % dst_h;
|
||||
const int64_t c_im = index / (dst_w * dst_h);
|
||||
const int64_t c_col = c_im * kernel_h * kernel_w;
|
||||
|
||||
int64_t c_per_offset_grp = channels / offset_groups;
|
||||
const int64_t grp_idx = c_im / c_per_offset_grp;
|
||||
auto columns_ptr =
|
||||
columns + (c_col * (dst_h * dst_w) + h_col * dst_w + w_col);
|
||||
auto input_ptr = input + c_im * (src_h * src_w);
|
||||
auto offset_ptr =
|
||||
offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w;
|
||||
|
||||
for (int64_t kh = 0; kh < kernel_h; ++kh) {
|
||||
for (int64_t kw = 0; kw < kernel_w; ++kw) {
|
||||
const int data_offset_h_ptr =
|
||||
((2 * (kh * kernel_w + kw)) * dst_h + h_col) * dst_w + w_col;
|
||||
const int data_offset_w_ptr =
|
||||
((2 * (kh * kernel_w + kw) + 1) * dst_h + h_col) * dst_w + w_col;
|
||||
|
||||
const float offset_h = offset_ptr[data_offset_h_ptr];
|
||||
const float offset_w = offset_ptr[data_offset_w_ptr];
|
||||
const float ih =
|
||||
(h_col * stride_h - pad_h) + kh * dilation_h + offset_h;
|
||||
const float iw =
|
||||
(w_col * stride_w - pad_w) + kw * dilation_w + offset_w;
|
||||
*columns_ptr = bilinear_interpolate(input_ptr, src_h, src_w, ih, iw);
|
||||
columns_ptr += dst_h * dst_w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void deformable_conv_forward(
|
||||
const float *src, const float *offset, const float *filter,
|
||||
const int64_t batch, const int64_t src_c, const int64_t src_h,
|
||||
const int64_t src_w, const int64_t dst_c, const int64_t dst_h,
|
||||
const int64_t dst_w, const int64_t group, const int64_t offset_group,
|
||||
const int64_t channels, const int64_t num_output, const int64_t kernel_h,
|
||||
const int64_t kernel_w, const int64_t stride_h, const int64_t stride_w,
|
||||
const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h,
|
||||
const int64_t dilation_w, float *columns, float *dst) {
|
||||
const int64_t ic_per_gp = channels / group;
|
||||
const int64_t oc_per_gp = num_output / group;
|
||||
for (int64_t b = 0; b < batch; ++b) {
|
||||
for (int64_t g = 0; g < group; ++g) {
|
||||
deformable_im2col(
|
||||
src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
|
||||
offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
|
||||
src_h, src_w, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
||||
dilation_h, dilation_w, ic_per_gp, offset_group, dst_h, dst_w,
|
||||
columns);
|
||||
float *dst_ptr =
|
||||
dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
|
||||
|
||||
memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
|
||||
|
||||
gemm_ref_fp32_deform(
|
||||
filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns,
|
||||
nullptr, dst_ptr, 0, 0, oc_per_gp, dst_h * dst_w,
|
||||
ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MMCVDeformConvKernel::MMCVDeformConvKernel(OrtApi api,
|
||||
const OrtKernelInfo *info)
|
||||
: api_(api), ort_(api_), info_(info) {
|
||||
std::vector<int64_t> stride =
|
||||
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
|
||||
stride_height_ = stride[0];
|
||||
stride_width_ = stride[1];
|
||||
std::vector<int64_t> padding =
|
||||
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
|
||||
padding_height_ = padding[0];
|
||||
padding_width_ = padding[1];
|
||||
std::vector<int64_t> dilation =
|
||||
ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
|
||||
dilation_height_ = dilation[0];
|
||||
dilation_width_ = dilation[1];
|
||||
deformable_group_ =
|
||||
ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
|
||||
group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
|
||||
|
||||
// create allocator
|
||||
allocator_ = Ort::AllocatorWithDefaultOptions();
|
||||
}
|
||||
|
||||
void MMCVDeformConvKernel::Compute(OrtKernelContext *context) {
|
||||
const int64_t stride_height = stride_height_;
|
||||
const int64_t stride_width = stride_width_;
|
||||
const int64_t padding_height = padding_height_;
|
||||
const int64_t padding_width = padding_width_;
|
||||
const int64_t dilation_height = dilation_height_;
|
||||
const int64_t dilation_width = dilation_width_;
|
||||
const int64_t deformable_group = deformable_group_;
|
||||
const int64_t group = group_;
|
||||
|
||||
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
|
||||
const float *input_data =
|
||||
reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
|
||||
|
||||
const OrtValue *offset = ort_.KernelContext_GetInput(context, 1);
|
||||
const float *offset_data =
|
||||
reinterpret_cast<const float *>(ort_.GetTensorData<float>(offset));
|
||||
|
||||
const OrtValue *filter = ort_.KernelContext_GetInput(context, 2);
|
||||
const float *filter_data =
|
||||
reinterpret_cast<const float *>(ort_.GetTensorData<float>(filter));
|
||||
|
||||
OrtTensorDimensions input_dims(ort_, input);
|
||||
OrtTensorDimensions filter_dims(ort_, filter);
|
||||
|
||||
int64_t batch_size = input_dims[0];
|
||||
int64_t in_channels = input_dims[1];
|
||||
int64_t in_height = input_dims[2];
|
||||
int64_t in_width = input_dims[3];
|
||||
int64_t out_channels = filter_dims[0];
|
||||
int64_t kernel_height = filter_dims[2];
|
||||
int64_t kernel_width = filter_dims[3];
|
||||
|
||||
// get output memory
|
||||
int64_t out_height = floor((in_height + 2 * padding_height -
|
||||
dilation_height * (kernel_height - 1) - 1) /
|
||||
stride_height +
|
||||
1);
|
||||
int64_t out_width = floor(
|
||||
(in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) /
|
||||
stride_width +
|
||||
1);
|
||||
|
||||
std::vector<int64_t> output_dims = {batch_size, out_channels, out_height,
|
||||
out_width};
|
||||
|
||||
OrtValue *output = ort_.KernelContext_GetOutput(
|
||||
context, 0, output_dims.data(), output_dims.size());
|
||||
float *out_ptr = ort_.GetTensorMutableData<float>(output);
|
||||
|
||||
// allocate tmp memory
|
||||
int64_t column_len = (in_channels / group) * kernel_height * kernel_width *
|
||||
out_height * out_width;
|
||||
float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len);
|
||||
deformable_conv_forward(
|
||||
input_data, offset_data, filter_data, batch_size, in_channels, in_height,
|
||||
in_width, out_channels, out_height, out_width, group, deformable_group,
|
||||
in_channels, out_channels, kernel_height, kernel_width, stride_height,
|
||||
stride_width, padding_height, padding_width, dilation_height,
|
||||
dilation_width, columns, out_ptr);
|
||||
}
|
|
@ -2,6 +2,7 @@
|
|||
#include "onnxruntime_register.h"
|
||||
|
||||
#include "corner_pool.h"
|
||||
#include "deform_conv.h"
|
||||
#include "grid_sample.h"
|
||||
#include "modulated_deform_conv.h"
|
||||
#include "nms.h"
|
||||
|
@ -21,6 +22,7 @@ MMCVCumMaxCustomOp c_MMCVCumMaxCustomOp;
|
|||
MMCVCumMinCustomOp c_MMCVCumMinCustomOp;
|
||||
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;
|
||||
MMCVModulatedDeformConvOp c_MMCVModulatedDeformConvOp;
|
||||
MMCVDeformConvOp c_MMCVDeformConvOp;
|
||||
|
||||
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
|
||||
const OrtApiBase *api) {
|
||||
|
@ -71,5 +73,9 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
|
|||
return status;
|
||||
}
|
||||
|
||||
if (auto status = ortApi->CustomOpDomain_Add(domain, &c_MMCVDeformConvOp)) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return ortApi->AddCustomOpDomain(options, domain);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef ONNXRUNTIME_DEFORM_CONV_H
|
||||
#define ONNXRUNTIME_DEFORM_CONV_H
|
||||
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
struct MMCVDeformConvKernel {
|
||||
MMCVDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
|
||||
|
||||
void Compute(OrtKernelContext *context);
|
||||
|
||||
protected:
|
||||
OrtApi api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo *info_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
int64_t stride_height_;
|
||||
int64_t stride_width_;
|
||||
int64_t padding_height_;
|
||||
int64_t padding_width_;
|
||||
int64_t dilation_height_;
|
||||
int64_t dilation_width_;
|
||||
int64_t deformable_group_;
|
||||
int64_t group_;
|
||||
int64_t im2col_step_;
|
||||
};
|
||||
|
||||
struct MMCVDeformConvOp
|
||||
: Ort::CustomOpBase<MMCVDeformConvOp, MMCVDeformConvKernel> {
|
||||
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
|
||||
return new MMCVDeformConvKernel(api, info);
|
||||
}
|
||||
|
||||
const char *GetName() const { return "MMCVDeformConv2d"; };
|
||||
|
||||
size_t GetInputTypeCount() const { return 3; };
|
||||
ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(
|
||||
size_t index) const {
|
||||
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
|
||||
}
|
||||
|
||||
size_t GetOutputTypeCount() const { return 1; };
|
||||
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
};
|
||||
|
||||
// force cpu
|
||||
const char *GetExecutionProviderType() const {
|
||||
return "CPUExecutionProvider";
|
||||
};
|
||||
};
|
||||
#endif
|
|
@ -103,6 +103,7 @@ def test_grid_sample(mode, padding_mode, align_corners):
|
|||
@pytest.mark.parametrize('align_corners', [True, False])
|
||||
def test_bilinear_grid_sample(align_corners):
|
||||
from mmcv.ops.point_sample import bilinear_grid_sample
|
||||
|
||||
# only support pytorch >= 1.5.0
|
||||
if version.parse(torch.__version__) < version.parse('1.5.0'):
|
||||
pytest.skip('Only support PyTorch >= 1.5.0')
|
||||
|
@ -245,8 +246,7 @@ def test_roialign():
|
|||
if torch.__version__ == 'parrots':
|
||||
pytest.skip('onnx is not supported in parrots directly')
|
||||
try:
|
||||
from mmcv.ops import roi_align
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
from mmcv.ops import get_onnxruntime_op_path, roi_align
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('roi_align op is not successfully compiled')
|
||||
|
||||
|
@ -318,8 +318,7 @@ def test_roialign_rotated():
|
|||
if torch.__version__ == 'parrots':
|
||||
pytest.skip('onnx is not supported in parrots directly')
|
||||
try:
|
||||
from mmcv.ops import roi_align_rotated
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
from mmcv.ops import get_onnxruntime_op_path, roi_align_rotated
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('roi_align_aligned op is not successfully compiled')
|
||||
|
||||
|
@ -648,8 +647,7 @@ def test_roll(shifts_dims_pair):
|
|||
reason='modulated_deform_conv2d only supports in GPU')
|
||||
def test_modulated_deform_conv2d():
|
||||
try:
|
||||
from mmcv.ops import ModulatedDeformConv2d
|
||||
from mmcv.ops import get_onnxruntime_op_path
|
||||
from mmcv.ops import ModulatedDeformConv2d, get_onnxruntime_op_path
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('modulated_deform_conv op is not successfully compiled')
|
||||
|
||||
|
@ -730,3 +728,85 @@ def test_modulated_deform_conv2d():
|
|||
pytorch_output = model(input, offset, mask).cpu()
|
||||
# allclose
|
||||
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.__version__ == 'parrots',
|
||||
reason='onnx is not supported in parrots directly')
|
||||
def test_deform_conv2d(threshold=1e-3):
|
||||
try:
|
||||
from mmcv.ops import DeformConv2d, get_onnxruntime_op_path
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('deform_conv op is not successfully compiled')
|
||||
|
||||
ort_custom_op_path = get_onnxruntime_op_path()
|
||||
if not os.path.exists(ort_custom_op_path):
|
||||
pytest.skip('custom ops for onnxruntime are not compiled.')
|
||||
|
||||
# deform conv config
|
||||
# modulated deform conv config
|
||||
in_channels = 1
|
||||
out_channels = 64
|
||||
stride = 1
|
||||
padding = 0
|
||||
dilation = 1
|
||||
groups = 1
|
||||
deform_groups = 1
|
||||
kernel_size = 2
|
||||
input = [[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]]
|
||||
offset_weight = [[[0.1, 0.4, 0.6, 0.1]], [[0.3, 0.2, 0.1, 0.3]],
|
||||
[[0.5, 0.5, 0.2, 0.8]], [[0.8, 0.3, 0.9, 0.1]],
|
||||
[[0.3, 0.1, 0.2, 0.5]], [[0.3, 0.7, 0.5, 0.3]],
|
||||
[[0.6, 0.2, 0.5, 0.3]], [[0.4, 0.1, 0.8, 0.4]]]
|
||||
offset_bias = [0.7, 0.1, 0.8, 0.5, 0.6, 0.5, 0.4, 0.7]
|
||||
deform_weight = [[[0.4, 0.2, 0.1, 0.9]]]
|
||||
|
||||
x = torch.tensor(input)
|
||||
conv_offset = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=deform_groups * 2 * kernel_size * kernel_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=True)
|
||||
|
||||
conv_offset.weight.data = torch.nn.Parameter(
|
||||
torch.Tensor(offset_weight).reshape(8, 1, 2, 2))
|
||||
conv_offset.bias.data = torch.nn.Parameter(
|
||||
torch.Tensor(offset_bias).reshape(8))
|
||||
|
||||
offset = conv_offset(x)
|
||||
|
||||
model = DeformConv2d(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, deform_groups)
|
||||
|
||||
model.weight.data = torch.nn.Parameter(
|
||||
torch.Tensor(deform_weight).reshape(1, 1, 2, 2))
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
model, (x, offset),
|
||||
onnx_file,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=['input', 'offset'],
|
||||
opset_version=11)
|
||||
|
||||
session_options = rt.SessionOptions()
|
||||
if os.path.exists(ort_custom_op_path):
|
||||
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||
|
||||
# compute onnx_output
|
||||
sess = rt.InferenceSession(onnx_file, session_options)
|
||||
onnx_output = sess.run(
|
||||
None, {
|
||||
'input': x.cpu().detach().numpy(),
|
||||
'offset': offset.cpu().detach().numpy(),
|
||||
})[0]
|
||||
|
||||
# compute pytorch_output
|
||||
with torch.no_grad():
|
||||
pytorch_output = model(x, offset).cpu()
|
||||
# allclose
|
||||
assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
|
||||
|
|
Loading…
Reference in New Issue