[Fix] Support onnxruntime-1.13 (#1407)

* support onnxruntime-1.13

* fix lint
(cherry picked from commit b5b0dcfcffad426dff9dcd548cdb3a9950e1146e)
This commit is contained in:
Li Zhang 2022-11-22 20:25:44 +08:00 committed by lvhan028
parent 070036f964
commit a1ca93ea1c
7 changed files with 21 additions and 15 deletions

View File

@ -13,8 +13,8 @@ namespace mmdeploy {
#define MAX(a, b) (((a) < (b)) ? (b) : (a)) #define MAX(a, b) (((a) < (b)) ? (b) : (a))
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0)) #define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
GridSampleKernel::GridSampleKernel(OrtApi api, const OrtKernelInfo *info) GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) { : ort_(api), info_(info) {
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners"); align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode"); interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode"); padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");

View File

@ -7,12 +7,11 @@
namespace mmdeploy { namespace mmdeploy {
struct GridSampleKernel { struct GridSampleKernel {
GridSampleKernel(OrtApi api, const OrtKernelInfo *info); GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context); void Compute(OrtKernelContext *context);
protected: protected:
OrtApi api_;
Ort::CustomOpApi ort_; Ort::CustomOpApi ort_;
const OrtKernelInfo *info_; const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_; Ort::AllocatorWithDefaultOptions allocator_;
@ -23,7 +22,7 @@ struct GridSampleKernel {
}; };
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> { struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const { void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
return new GridSampleKernel(api, info); return new GridSampleKernel(api, info);
}; };

View File

@ -109,8 +109,9 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
} }
} }
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info) MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
: api_(api), ort_(api_), info_(info) { const OrtKernelInfo *info)
: ort_(api), info_(info) {
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride"); std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
stride_height_ = stride[0]; stride_height_ = stride[0];
stride_width_ = stride[1]; stride_width_ = stride[1];

View File

@ -7,12 +7,11 @@
namespace mmdeploy { namespace mmdeploy {
struct MMCVModulatedDeformConvKernel { struct MMCVModulatedDeformConvKernel {
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info); MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context); void Compute(OrtKernelContext *context);
protected: protected:
OrtApi api_;
Ort::CustomOpApi ort_; Ort::CustomOpApi ort_;
const OrtKernelInfo *info_; const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_; Ort::AllocatorWithDefaultOptions allocator_;
@ -29,7 +28,7 @@ struct MMCVModulatedDeformConvKernel {
struct MMCVModulatedDeformConvOp struct MMCVModulatedDeformConvOp
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> { : Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const { void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
return new MMCVModulatedDeformConvKernel(api, info); return new MMCVModulatedDeformConvKernel(api, info);
} }

View File

@ -261,8 +261,8 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
return polygon_area(orderedPts, num_convex); return polygon_area(orderedPts, num_convex);
} }
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info) NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
: api_(api), ort_(api_), info_(info) { : ort_(api), info_(info) {
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold"); iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold"); score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");

View File

@ -12,12 +12,11 @@
namespace mmdeploy { namespace mmdeploy {
struct NMSRotatedKernel { struct NMSRotatedKernel {
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info); NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context); void Compute(OrtKernelContext* context);
private: private:
OrtApi api_;
Ort::CustomOpApi ort_; Ort::CustomOpApi ort_;
const OrtKernelInfo* info_; const OrtKernelInfo* info_;
Ort::AllocatorWithDefaultOptions allocator_; Ort::AllocatorWithDefaultOptions allocator_;
@ -26,7 +25,7 @@ struct NMSRotatedKernel {
}; };
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> { struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new NMSRotatedKernel(api, info); return new NMSRotatedKernel(api, info);
} }
const char* GetName() const { return "NMSRotated"; } const char* GetName() const { return "NMSRotated"; }

View File

@ -74,7 +74,11 @@ Result<void> OrtNet::Init(const Value& args) {
}; };
for (int i = 0; i < n_inputs; ++i) { for (int i = 0; i < n_inputs; ++i) {
#if ORT_API_VERSION >= 13
auto input_name = session_.GetInputNameAllocated(i, allocator).release();
#else
auto input_name = session_.GetInputName(i, allocator); auto input_name = session_.GetInputName(i, allocator);
#endif
auto type_info = session_.GetInputTypeInfo(i); auto type_info = session_.GetInputTypeInfo(i);
auto shape = to_shape(type_info); auto shape = to_shape(type_info);
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape); MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
@ -88,7 +92,11 @@ Result<void> OrtNet::Init(const Value& args) {
auto n_outputs = session_.GetOutputCount(); auto n_outputs = session_.GetOutputCount();
for (int i = 0; i < n_outputs; ++i) { for (int i = 0; i < n_outputs; ++i) {
#if ORT_API_VERSION >= 13
auto output_name = session_.GetOutputNameAllocated(i, allocator).release();
#else
auto output_name = session_.GetOutputName(i, allocator); auto output_name = session_.GetOutputName(i, allocator);
#endif
auto type_info = session_.GetOutputTypeInfo(i); auto type_info = session_.GetOutputTypeInfo(i);
auto shape = to_shape(type_info); auto shape = to_shape(type_info);
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape); MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);