[Fix] Support onnxruntime-1.13 (#1407)

* support onnxruntime-1.13

* fix lint
pull/1420/head
Li Zhang 2022-11-22 20:25:44 +08:00 committed by GitHub
parent 4dd4d4851b
commit b5b0dcfcff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 21 additions and 15 deletions

View File

@ -13,8 +13,8 @@ namespace mmdeploy {
#define MAX(a, b) (((a) < (b)) ? (b) : (a))
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
GridSampleKernel::GridSampleKernel(OrtApi api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
: ort_(api), info_(info) {
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");

View File

@ -7,12 +7,11 @@
namespace mmdeploy {
struct GridSampleKernel {
GridSampleKernel(OrtApi api, const OrtKernelInfo *info);
GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
@ -23,7 +22,7 @@ struct GridSampleKernel {
};
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
return new GridSampleKernel(api, info);
};

View File

@ -109,8 +109,9 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
}
}
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info)
: api_(api), ort_(api_), info_(info) {
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
const OrtKernelInfo *info)
: ort_(api), info_(info) {
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
stride_height_ = stride[0];
stride_width_ = stride[1];

View File

@ -7,12 +7,11 @@
namespace mmdeploy {
struct MMCVModulatedDeformConvKernel {
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
void Compute(OrtKernelContext *context);
protected:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo *info_;
Ort::AllocatorWithDefaultOptions allocator_;
@ -29,7 +28,7 @@ struct MMCVModulatedDeformConvKernel {
struct MMCVModulatedDeformConvOp
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
return new MMCVModulatedDeformConvKernel(api, info);
}

View File

@ -261,8 +261,8 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
return polygon_area(orderedPts, num_convex);
}
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
: api_(api), ort_(api_), info_(info) {
NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
: ort_(api), info_(info) {
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");

View File

@ -12,12 +12,11 @@
namespace mmdeploy {
struct NMSRotatedKernel {
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);
NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
void Compute(OrtKernelContext* context);
private:
OrtApi api_;
Ort::CustomOpApi ort_;
const OrtKernelInfo* info_;
Ort::AllocatorWithDefaultOptions allocator_;
@ -26,7 +25,7 @@ struct NMSRotatedKernel {
};
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
return new NMSRotatedKernel(api, info);
}
const char* GetName() const { return "NMSRotated"; }

View File

@ -74,7 +74,11 @@ Result<void> OrtNet::Init(const Value& args) {
};
for (int i = 0; i < n_inputs; ++i) {
#if ORT_API_VERSION >= 13
auto input_name = session_.GetInputNameAllocated(i, allocator).release();
#else
auto input_name = session_.GetInputName(i, allocator);
#endif
auto type_info = session_.GetInputTypeInfo(i);
auto shape = to_shape(type_info);
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
@ -88,7 +92,11 @@ Result<void> OrtNet::Init(const Value& args) {
auto n_outputs = session_.GetOutputCount();
for (int i = 0; i < n_outputs; ++i) {
#if ORT_API_VERSION >= 13
auto output_name = session_.GetOutputNameAllocated(i, allocator).release();
#else
auto output_name = session_.GetOutputName(i, allocator);
#endif
auto type_info = session_.GetOutputTypeInfo(i);
auto shape = to_shape(type_info);
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);