diff --git a/csrc/backend_ops/onnxruntime/common/ort_utils.cpp b/csrc/backend_ops/onnxruntime/common/ort_utils.cpp index 30b7c46de..c604e4b65 100644 --- a/csrc/backend_ops/onnxruntime/common/ort_utils.cpp +++ b/csrc/backend_ops/onnxruntime/common/ort_utils.cpp @@ -3,8 +3,8 @@ namespace mmdeploy { -std::vector& get_mmdeploy_custom_ops() { - static std::vector _custom_ops; +CustomOpsTable& get_mmdeploy_custom_ops() { + static CustomOpsTable _custom_ops; return _custom_ops; } } // namespace mmdeploy diff --git a/csrc/backend_ops/onnxruntime/common/ort_utils.h b/csrc/backend_ops/onnxruntime/common/ort_utils.h index 3bb0ecaec..e19c984f8 100644 --- a/csrc/backend_ops/onnxruntime/common/ort_utils.h +++ b/csrc/backend_ops/onnxruntime/common/ort_utils.h @@ -3,10 +3,13 @@ #define ORT_MMCV_UTILS_H #include +#include #include namespace mmdeploy { +typedef std::unordered_map> CustomOpsTable; + struct OrtTensorDimensions : std::vector { OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); @@ -15,19 +18,20 @@ struct OrtTensorDimensions : std::vector { } }; -std::vector& get_mmdeploy_custom_ops(); +CustomOpsTable& get_mmdeploy_custom_ops(); -template -class OrtOpsRegistrar { +template +class OrtOpsRegistry { public: - OrtOpsRegistrar() { get_mmdeploy_custom_ops().push_back(&instance); } + OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); } private: T instance{}; }; -#define REGISTER_ONNXRUNTIME_OPS(name) \ - static OrtOpsRegistrar OrtOpsRegistrar##name {} +#define REGISTER_ONNXRUNTIME_OPS(domain, name) \ + static char __domain_##domain##name[] = #domain; \ + static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {} } // namespace mmdeploy #endif // ORT_MMCV_UTILS_H diff --git a/csrc/backend_ops/onnxruntime/grid_sample/grid_sample.cpp b/csrc/backend_ops/onnxruntime/grid_sample/grid_sample.cpp index d8d0cbd25..8850b0539 100644 --- a/csrc/backend_ops/onnxruntime/grid_sample/grid_sample.cpp +++ b/csrc/backend_ops/onnxruntime/grid_sample/grid_sample.cpp @@ -290,5 +290,5 @@ void GridSampleKernel::Compute(OrtKernelContext *context) { } } -REGISTER_ONNXRUNTIME_OPS(GridSampleOp); +REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp); } // namespace mmdeploy diff --git a/csrc/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp b/csrc/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp index 9d6f5af8e..5561752cd 100644 --- a/csrc/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp +++ b/csrc/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp @@ -266,5 +266,6 @@ void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) { kernel_width, stride_height, stride_width, padding_height, padding_width, dilation_height, dilation_width, columns, out_ptr); } -REGISTER_ONNXRUNTIME_OPS(MMCVModulatedDeformConvOp); +REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp); +REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp); } // namespace mmdeploy diff --git a/csrc/backend_ops/onnxruntime/onnxruntime_register.cpp b/csrc/backend_ops/onnxruntime/onnxruntime_register.cpp index 9de47333f..9f2ce2cc0 100644 --- a/csrc/backend_ops/onnxruntime/onnxruntime_register.cpp +++ b/csrc/backend_ops/onnxruntime/onnxruntime_register.cpp @@ -6,18 +6,23 @@ const char *c_MMDeployOpDomain = "mmdeploy"; OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { - OrtCustomOpDomain *domain = nullptr; const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION); - if (auto status = kOrtApi->CreateCustomOpDomain(c_MMDeployOpDomain, &domain)) { - return status; - } - - for (auto _op : mmdeploy::get_mmdeploy_custom_ops()) { - if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) { + OrtStatus *status = nullptr; + for (auto &_op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) { + OrtCustomOpDomain *domain = nullptr; + if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) { return status; } + auto &_op_list = _op_list_pair.second; + for (auto &_op : _op_list) { + if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) { + return status; + } + } + // TODO: figure out what will return if failed. + status = kOrtApi->AddCustomOpDomain(options, domain); } - return kOrtApi->AddCustomOpDomain(options, domain); + return status; } diff --git a/csrc/backend_ops/onnxruntime/roi_align/roi_align.cpp b/csrc/backend_ops/onnxruntime/roi_align/roi_align.cpp index a752bf98e..78cd13c92 100644 --- a/csrc/backend_ops/onnxruntime/roi_align/roi_align.cpp +++ b/csrc/backend_ops/onnxruntime/roi_align/roi_align.cpp @@ -251,5 +251,5 @@ void MMCVRoiAlignKernel::Compute(OrtKernelContext *context) { if (argmax_y) delete argmax_y; } -REGISTER_ONNXRUNTIME_OPS(MMCVRoiAlignCustomOp); +REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoiAlignCustomOp); } // namespace mmdeploy diff --git a/mmdeploy/mmcv/ops/__init__.py b/mmdeploy/mmcv/ops/__init__.py index 509c429da..f839e64b9 100644 --- a/mmdeploy/mmcv/ops/__init__.py +++ b/mmdeploy/mmcv/ops/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deform_conv import deform_conv_openvino +from .modulated_deform_conv import modulated_deform_conv_default from .nms import * # noqa: F401,F403 from .roi_align import roi_align_default -__all__ = ['roi_align_default', 'deform_conv_openvino'] +__all__ = [ + 'roi_align_default', 'modulated_deform_conv_default', + 'deform_conv_openvino' +] diff --git a/mmdeploy/mmcv/ops/modulated_deform_conv.py b/mmdeploy/mmcv/ops/modulated_deform_conv.py new file mode 100644 index 000000000..df3c338a8 --- /dev/null +++ b/mmdeploy/mmcv/ops/modulated_deform_conv.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdeploy.core import SYMBOLIC_REWRITER + + +@SYMBOLIC_REWRITER.register_symbolic( + 'mmcv.ops.modulated_deform_conv.ModulatedDeformConv2dFunction') +def modulated_deform_conv_default(ctx, g, input, offset, mask, weight, bias, + stride, padding, dilation, groups, + deform_groups): + """Rewrite mdcn symbolic function for all backend.""" + input_tensors = [input, offset, mask, weight] + if bias is not None: + input_tensors.append(bias) + return g.op( + 'mmdeploy::MMCVModulatedDeformConv2d', + *input_tensors, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + groups_i=groups, + deform_groups_i=deform_groups)