mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Support onnxruntime-1.13 (#1407)
* support onnxruntime-1.13 * fix lint (cherry picked from commit b5b0dcfcffad426dff9dcd548cdb3a9950e1146e)
This commit is contained in:
parent
070036f964
commit
a1ca93ea1c
@ -13,8 +13,8 @@ namespace mmdeploy {
|
|||||||
#define MAX(a, b) (((a) < (b)) ? (b) : (a))
|
#define MAX(a, b) (((a) < (b)) ? (b) : (a))
|
||||||
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
|
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
|
||||||
|
|
||||||
GridSampleKernel::GridSampleKernel(OrtApi api, const OrtKernelInfo *info)
|
GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
|
||||||
: api_(api), ort_(api_), info_(info) {
|
: ort_(api), info_(info) {
|
||||||
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
|
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
|
||||||
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
|
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
|
||||||
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");
|
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");
|
||||||
|
@ -7,12 +7,11 @@
|
|||||||
namespace mmdeploy {
|
namespace mmdeploy {
|
||||||
|
|
||||||
struct GridSampleKernel {
|
struct GridSampleKernel {
|
||||||
GridSampleKernel(OrtApi api, const OrtKernelInfo *info);
|
GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
|
||||||
|
|
||||||
void Compute(OrtKernelContext *context);
|
void Compute(OrtKernelContext *context);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
OrtApi api_;
|
|
||||||
Ort::CustomOpApi ort_;
|
Ort::CustomOpApi ort_;
|
||||||
const OrtKernelInfo *info_;
|
const OrtKernelInfo *info_;
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
@ -23,7 +22,7 @@ struct GridSampleKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
|
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
|
||||||
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
|
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
|
||||||
return new GridSampleKernel(api, info);
|
return new GridSampleKernel(api, info);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -109,8 +109,9 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info)
|
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
|
||||||
: api_(api), ort_(api_), info_(info) {
|
const OrtKernelInfo *info)
|
||||||
|
: ort_(api), info_(info) {
|
||||||
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
|
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
|
||||||
stride_height_ = stride[0];
|
stride_height_ = stride[0];
|
||||||
stride_width_ = stride[1];
|
stride_width_ = stride[1];
|
||||||
|
@ -7,12 +7,11 @@
|
|||||||
namespace mmdeploy {
|
namespace mmdeploy {
|
||||||
|
|
||||||
struct MMCVModulatedDeformConvKernel {
|
struct MMCVModulatedDeformConvKernel {
|
||||||
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
|
MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
|
||||||
|
|
||||||
void Compute(OrtKernelContext *context);
|
void Compute(OrtKernelContext *context);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
OrtApi api_;
|
|
||||||
Ort::CustomOpApi ort_;
|
Ort::CustomOpApi ort_;
|
||||||
const OrtKernelInfo *info_;
|
const OrtKernelInfo *info_;
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
@ -29,7 +28,7 @@ struct MMCVModulatedDeformConvKernel {
|
|||||||
|
|
||||||
struct MMCVModulatedDeformConvOp
|
struct MMCVModulatedDeformConvOp
|
||||||
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
|
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
|
||||||
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
|
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
|
||||||
return new MMCVModulatedDeformConvKernel(api, info);
|
return new MMCVModulatedDeformConvKernel(api, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,8 +261,8 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
|
|||||||
return polygon_area(orderedPts, num_convex);
|
return polygon_area(orderedPts, num_convex);
|
||||||
}
|
}
|
||||||
|
|
||||||
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
|
NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
|
||||||
: api_(api), ort_(api_), info_(info) {
|
: ort_(api), info_(info) {
|
||||||
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
|
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
|
||||||
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
|
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
|
||||||
|
|
||||||
|
@ -12,12 +12,11 @@
|
|||||||
|
|
||||||
namespace mmdeploy {
|
namespace mmdeploy {
|
||||||
struct NMSRotatedKernel {
|
struct NMSRotatedKernel {
|
||||||
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);
|
NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
|
||||||
|
|
||||||
void Compute(OrtKernelContext* context);
|
void Compute(OrtKernelContext* context);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
OrtApi api_;
|
|
||||||
Ort::CustomOpApi ort_;
|
Ort::CustomOpApi ort_;
|
||||||
const OrtKernelInfo* info_;
|
const OrtKernelInfo* info_;
|
||||||
Ort::AllocatorWithDefaultOptions allocator_;
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
@ -26,7 +25,7 @@ struct NMSRotatedKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
|
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
|
||||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||||
return new NMSRotatedKernel(api, info);
|
return new NMSRotatedKernel(api, info);
|
||||||
}
|
}
|
||||||
const char* GetName() const { return "NMSRotated"; }
|
const char* GetName() const { return "NMSRotated"; }
|
||||||
|
@ -74,7 +74,11 @@ Result<void> OrtNet::Init(const Value& args) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
for (int i = 0; i < n_inputs; ++i) {
|
for (int i = 0; i < n_inputs; ++i) {
|
||||||
|
#if ORT_API_VERSION >= 13
|
||||||
|
auto input_name = session_.GetInputNameAllocated(i, allocator).release();
|
||||||
|
#else
|
||||||
auto input_name = session_.GetInputName(i, allocator);
|
auto input_name = session_.GetInputName(i, allocator);
|
||||||
|
#endif
|
||||||
auto type_info = session_.GetInputTypeInfo(i);
|
auto type_info = session_.GetInputTypeInfo(i);
|
||||||
auto shape = to_shape(type_info);
|
auto shape = to_shape(type_info);
|
||||||
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
|
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
|
||||||
@ -88,7 +92,11 @@ Result<void> OrtNet::Init(const Value& args) {
|
|||||||
auto n_outputs = session_.GetOutputCount();
|
auto n_outputs = session_.GetOutputCount();
|
||||||
|
|
||||||
for (int i = 0; i < n_outputs; ++i) {
|
for (int i = 0; i < n_outputs; ++i) {
|
||||||
|
#if ORT_API_VERSION >= 13
|
||||||
|
auto output_name = session_.GetOutputNameAllocated(i, allocator).release();
|
||||||
|
#else
|
||||||
auto output_name = session_.GetOutputName(i, allocator);
|
auto output_name = session_.GetOutputName(i, allocator);
|
||||||
|
#endif
|
||||||
auto type_info = session_.GetOutputTypeInfo(i);
|
auto type_info = session_.GetOutputTypeInfo(i);
|
||||||
auto shape = to_shape(type_info);
|
auto shape = to_shape(type_info);
|
||||||
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);
|
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user