[Fix] Fix onnx multiple domain registry (#270)
* Fix onnx multiple domain registry * recover test args * remove wrong status * replace map with unordered_map * add symbolic rewriterpull/12/head
parent
fd4297a2a3
commit
0897139744
|
@ -3,8 +3,8 @@
|
|||
|
||||
namespace mmdeploy {
|
||||
|
||||
std::vector<OrtCustomOp*>& get_mmdeploy_custom_ops() {
|
||||
static std::vector<OrtCustomOp*> _custom_ops;
|
||||
CustomOpsTable& get_mmdeploy_custom_ops() {
|
||||
static CustomOpsTable _custom_ops;
|
||||
return _custom_ops;
|
||||
}
|
||||
} // namespace mmdeploy
|
||||
|
|
|
@ -3,10 +3,13 @@
|
|||
#define ORT_MMCV_UTILS_H
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
typedef std::unordered_map<std::string, std::vector<OrtCustomOp*>> CustomOpsTable;
|
||||
|
||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||
|
@ -15,19 +18,20 @@ struct OrtTensorDimensions : std::vector<int64_t> {
|
|||
}
|
||||
};
|
||||
|
||||
std::vector<OrtCustomOp*>& get_mmdeploy_custom_ops();
|
||||
CustomOpsTable& get_mmdeploy_custom_ops();
|
||||
|
||||
template <typename T>
|
||||
class OrtOpsRegistrar {
|
||||
template <char const* domain, typename T>
|
||||
class OrtOpsRegistry {
|
||||
public:
|
||||
OrtOpsRegistrar() { get_mmdeploy_custom_ops().push_back(&instance); }
|
||||
OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); }
|
||||
|
||||
private:
|
||||
T instance{};
|
||||
};
|
||||
|
||||
#define REGISTER_ONNXRUNTIME_OPS(name) \
|
||||
static OrtOpsRegistrar<name> OrtOpsRegistrar##name {}
|
||||
#define REGISTER_ONNXRUNTIME_OPS(domain, name) \
|
||||
static char __domain_##domain##name[] = #domain; \
|
||||
static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {}
|
||||
|
||||
} // namespace mmdeploy
|
||||
#endif // ORT_MMCV_UTILS_H
|
||||
|
|
|
@ -290,5 +290,5 @@ void GridSampleKernel::Compute(OrtKernelContext *context) {
|
|||
}
|
||||
}
|
||||
|
||||
REGISTER_ONNXRUNTIME_OPS(GridSampleOp);
|
||||
REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp);
|
||||
} // namespace mmdeploy
|
||||
|
|
|
@ -266,5 +266,6 @@ void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) {
|
|||
kernel_width, stride_height, stride_width, padding_height,
|
||||
padding_width, dilation_height, dilation_width, columns, out_ptr);
|
||||
}
|
||||
REGISTER_ONNXRUNTIME_OPS(MMCVModulatedDeformConvOp);
|
||||
REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp);
|
||||
REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp);
|
||||
} // namespace mmdeploy
|
||||
|
|
|
@ -6,18 +6,23 @@
|
|||
const char *c_MMDeployOpDomain = "mmdeploy";
|
||||
|
||||
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) {
|
||||
OrtCustomOpDomain *domain = nullptr;
|
||||
const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION);
|
||||
|
||||
if (auto status = kOrtApi->CreateCustomOpDomain(c_MMDeployOpDomain, &domain)) {
|
||||
return status;
|
||||
}
|
||||
|
||||
for (auto _op : mmdeploy::get_mmdeploy_custom_ops()) {
|
||||
if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) {
|
||||
OrtStatus *status = nullptr;
|
||||
for (auto &_op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) {
|
||||
OrtCustomOpDomain *domain = nullptr;
|
||||
if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) {
|
||||
return status;
|
||||
}
|
||||
auto &_op_list = _op_list_pair.second;
|
||||
for (auto &_op : _op_list) {
|
||||
if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
// TODO: figure out what will return if failed.
|
||||
status = kOrtApi->AddCustomOpDomain(options, domain);
|
||||
}
|
||||
|
||||
return kOrtApi->AddCustomOpDomain(options, domain);
|
||||
return status;
|
||||
}
|
||||
|
|
|
@ -251,5 +251,5 @@ void MMCVRoiAlignKernel::Compute(OrtKernelContext *context) {
|
|||
if (argmax_y) delete argmax_y;
|
||||
}
|
||||
|
||||
REGISTER_ONNXRUNTIME_OPS(MMCVRoiAlignCustomOp);
|
||||
REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoiAlignCustomOp);
|
||||
} // namespace mmdeploy
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .deform_conv import deform_conv_openvino
|
||||
from .modulated_deform_conv import modulated_deform_conv_default
|
||||
from .nms import * # noqa: F401,F403
|
||||
from .roi_align import roi_align_default
|
||||
|
||||
__all__ = ['roi_align_default', 'deform_conv_openvino']
|
||||
__all__ = [
|
||||
'roi_align_default', 'modulated_deform_conv_default',
|
||||
'deform_conv_openvino'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import SYMBOLIC_REWRITER
|
||||
|
||||
|
||||
@SYMBOLIC_REWRITER.register_symbolic(
|
||||
'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction')
|
||||
def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias,
|
||||
stride, padding, dilation, groups,
|
||||
deform_groups):
|
||||
"""Rewrite mdcn symbolic function for all backend."""
|
||||
input_tensors = [input, offset, mask, weight]
|
||||
if bias is not None:
|
||||
input_tensors.append(bias)
|
||||
return g.op(
|
||||
'mmdeploy::MMCVModulatedDeformConv2d',
|
||||
*input_tensors,
|
||||
stride_i=stride,
|
||||
padding_i=padding,
|
||||
dilation_i=dilation,
|
||||
groups_i=groups,
|
||||
deform_groups_i=deform_groups)
|
Loading…
Reference in New Issue