parent
4dd4d4851b
commit
b5b0dcfcff
|
@ -13,8 +13,8 @@ namespace mmdeploy {
|
|||
#define MAX(a, b) (((a) < (b)) ? (b) : (a))
|
||||
#define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))
|
||||
|
||||
GridSampleKernel::GridSampleKernel(OrtApi api, const OrtKernelInfo *info)
|
||||
: api_(api), ort_(api_), info_(info) {
|
||||
GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
|
||||
: ort_(api), info_(info) {
|
||||
align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
|
||||
interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
|
||||
padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");
|
||||
|
|
|
@ -7,12 +7,11 @@
|
|||
namespace mmdeploy {
|
||||
|
||||
struct GridSampleKernel {
|
||||
GridSampleKernel(OrtApi api, const OrtKernelInfo *info);
|
||||
GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
|
||||
|
||||
void Compute(OrtKernelContext *context);
|
||||
|
||||
protected:
|
||||
OrtApi api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo *info_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
@ -23,7 +22,7 @@ struct GridSampleKernel {
|
|||
};
|
||||
|
||||
struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
|
||||
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
|
||||
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
|
||||
return new GridSampleKernel(api, info);
|
||||
};
|
||||
|
||||
|
|
|
@ -109,8 +109,9 @@ void deformable_conv2d_ref_fp32(const float *src, const float *offset, const flo
|
|||
}
|
||||
}
|
||||
|
||||
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info)
|
||||
: api_(api), ort_(api_), info_(info) {
|
||||
MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
|
||||
const OrtKernelInfo *info)
|
||||
: ort_(api), info_(info) {
|
||||
std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
|
||||
stride_height_ = stride[0];
|
||||
stride_width_ = stride[1];
|
||||
|
|
|
@ -7,12 +7,11 @@
|
|||
namespace mmdeploy {
|
||||
|
||||
struct MMCVModulatedDeformConvKernel {
|
||||
MMCVModulatedDeformConvKernel(OrtApi api, const OrtKernelInfo *info);
|
||||
MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
|
||||
|
||||
void Compute(OrtKernelContext *context);
|
||||
|
||||
protected:
|
||||
OrtApi api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo *info_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
@ -29,7 +28,7 @@ struct MMCVModulatedDeformConvKernel {
|
|||
|
||||
struct MMCVModulatedDeformConvOp
|
||||
: Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
|
||||
void *CreateKernel(OrtApi api, const OrtKernelInfo *info) const {
|
||||
void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
|
||||
return new MMCVModulatedDeformConvKernel(api, info);
|
||||
}
|
||||
|
||||
|
|
|
@ -261,8 +261,8 @@ float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2)
|
|||
return polygon_area(orderedPts, num_convex);
|
||||
}
|
||||
|
||||
NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info)
|
||||
: api_(api), ort_(api_), info_(info) {
|
||||
NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info)
|
||||
: ort_(api), info_(info) {
|
||||
iou_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "iou_threshold");
|
||||
score_threshold_ = ort_.KernelInfoGetAttribute<float>(info, "score_threshold");
|
||||
|
||||
|
|
|
@ -12,12 +12,11 @@
|
|||
|
||||
namespace mmdeploy {
|
||||
struct NMSRotatedKernel {
|
||||
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);
|
||||
NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info);
|
||||
|
||||
void Compute(OrtKernelContext* context);
|
||||
|
||||
private:
|
||||
OrtApi api_;
|
||||
Ort::CustomOpApi ort_;
|
||||
const OrtKernelInfo* info_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
@ -26,7 +25,7 @@ struct NMSRotatedKernel {
|
|||
};
|
||||
|
||||
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
|
||||
void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
|
||||
void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
|
||||
return new NMSRotatedKernel(api, info);
|
||||
}
|
||||
const char* GetName() const { return "NMSRotated"; }
|
||||
|
|
|
@ -74,7 +74,11 @@ Result<void> OrtNet::Init(const Value& args) {
|
|||
};
|
||||
|
||||
for (int i = 0; i < n_inputs; ++i) {
|
||||
#if ORT_API_VERSION >= 13
|
||||
auto input_name = session_.GetInputNameAllocated(i, allocator).release();
|
||||
#else
|
||||
auto input_name = session_.GetInputName(i, allocator);
|
||||
#endif
|
||||
auto type_info = session_.GetInputTypeInfo(i);
|
||||
auto shape = to_shape(type_info);
|
||||
MMDEPLOY_DEBUG("input {}, shape = {}", i, shape);
|
||||
|
@ -88,7 +92,11 @@ Result<void> OrtNet::Init(const Value& args) {
|
|||
auto n_outputs = session_.GetOutputCount();
|
||||
|
||||
for (int i = 0; i < n_outputs; ++i) {
|
||||
#if ORT_API_VERSION >= 13
|
||||
auto output_name = session_.GetOutputNameAllocated(i, allocator).release();
|
||||
#else
|
||||
auto output_name = session_.GetOutputName(i, allocator);
|
||||
#endif
|
||||
auto type_info = session_.GetOutputTypeInfo(i);
|
||||
auto shape = to_shape(type_info);
|
||||
MMDEPLOY_DEBUG("output {}, shape = {}", i, shape);
|
||||
|
|
Loading…
Reference in New Issue