[Fix] Fix onnx multiple domain registry (#270)

* Fix onnx multiple domain registry

* recover test args

* remove wrong status

* replace map with unordered_map

* add symbolic rewriter
pull/12/head
q.yao 2021-12-09 17:35:28 +08:00 committed by GitHub
parent fd4297a2a3
commit 0897139744
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 55 additions and 20 deletions

View File

@ -3,8 +3,8 @@
namespace mmdeploy {
std::vector<OrtCustomOp*>& get_mmdeploy_custom_ops() {
static std::vector<OrtCustomOp*> _custom_ops;
CustomOpsTable& get_mmdeploy_custom_ops() {
static CustomOpsTable _custom_ops;
return _custom_ops;
}
} // namespace mmdeploy

View File

@ -3,10 +3,13 @@
#define ORT_MMCV_UTILS_H
#include <onnxruntime_cxx_api.h>
#include <unordered_map>
#include <vector>
namespace mmdeploy {
typedef std::unordered_map<std::string, std::vector<OrtCustomOp*>> CustomOpsTable;
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
@ -15,19 +18,20 @@ struct OrtTensorDimensions : std::vector<int64_t> {
}
};
std::vector<OrtCustomOp*>& get_mmdeploy_custom_ops();
CustomOpsTable& get_mmdeploy_custom_ops();
template <typename T>
class OrtOpsRegistrar {
template <char const* domain, typename T>
class OrtOpsRegistry {
public:
OrtOpsRegistrar() { get_mmdeploy_custom_ops().push_back(&instance); }
OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); }
private:
T instance{};
};
#define REGISTER_ONNXRUNTIME_OPS(name) \
static OrtOpsRegistrar<name> OrtOpsRegistrar##name {}
#define REGISTER_ONNXRUNTIME_OPS(domain, name) \
static char __domain_##domain##name[] = #domain; \
static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {}
} // namespace mmdeploy
#endif // ORT_MMCV_UTILS_H

View File

@ -290,5 +290,5 @@ void GridSampleKernel::Compute(OrtKernelContext *context) {
}
}
REGISTER_ONNXRUNTIME_OPS(GridSampleOp);
REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp);
} // namespace mmdeploy

View File

@ -266,5 +266,6 @@ void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) {
kernel_width, stride_height, stride_width, padding_height,
padding_width, dilation_height, dilation_width, columns, out_ptr);
}
REGISTER_ONNXRUNTIME_OPS(MMCVModulatedDeformConvOp);
REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp);
REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp);
} // namespace mmdeploy

View File

@ -6,18 +6,23 @@
const char *c_MMDeployOpDomain = "mmdeploy";
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) {
OrtCustomOpDomain *domain = nullptr;
const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION);
if (auto status = kOrtApi->CreateCustomOpDomain(c_MMDeployOpDomain, &domain)) {
return status;
}
for (auto _op : mmdeploy::get_mmdeploy_custom_ops()) {
if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) {
OrtStatus *status = nullptr;
for (auto &_op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) {
OrtCustomOpDomain *domain = nullptr;
if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) {
return status;
}
auto &_op_list = _op_list_pair.second;
for (auto &_op : _op_list) {
if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) {
return status;
}
}
// TODO: figure out what will return if failed.
status = kOrtApi->AddCustomOpDomain(options, domain);
}
return kOrtApi->AddCustomOpDomain(options, domain);
return status;
}

View File

@ -251,5 +251,5 @@ void MMCVRoiAlignKernel::Compute(OrtKernelContext *context) {
if (argmax_y) delete argmax_y;
}
REGISTER_ONNXRUNTIME_OPS(MMCVRoiAlignCustomOp);
REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoiAlignCustomOp);
} // namespace mmdeploy

View File

@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deform_conv import deform_conv_openvino
from .modulated_deform_conv import modulated_deform_conv_default
from .nms import * # noqa: F401,F403
from .roi_align import roi_align_default
__all__ = ['roi_align_default', 'deform_conv_openvino']
__all__ = [
'roi_align_default', 'modulated_deform_conv_default',
'deform_conv_openvino'
]

View File

@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction')
def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias,
stride, padding, dilation, groups,
deform_groups):
"""Rewrite mdcn symbolic function for all backend."""
input_tensors = [input, offset, mask, weight]
if bias is not None:
input_tensors.append(bias)
return g.op(
'mmdeploy::MMCVModulatedDeformConv2d',
*input_tensors,
stride_i=stride,
padding_i=padding,
dilation_i=dilation,
groups_i=groups,
deform_groups_i=deform_groups)