#ifndef ONNXRUNTIME_ROI_ALIGN_H #define ONNXRUNTIME_ROI_ALIGN_H #include #include #include #include #include #include struct MMCVRoiAlignKernel { public: MMCVRoiAlignKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) : ort_(ort) { aligned_ = ort_.KernelInfoGetAttribute(info, "aligned"); aligned_height_ = ort_.KernelInfoGetAttribute(info, "output_height"); aligned_width_ = ort_.KernelInfoGetAttribute(info, "output_width"); pool_mode_ = ort_.KernelInfoGetAttribute(info, "mode"); sampling_ratio_ = ort_.KernelInfoGetAttribute(info, "sampling_ratio"); spatial_scale_ = ort_.KernelInfoGetAttribute(info, "spatial_scale"); } void Compute(OrtKernelContext* context); private: Ort::CustomOpApi ort_; int aligned_height_; int aligned_width_; float spatial_scale_; int sampling_ratio_; std::string pool_mode_; int aligned_; }; struct MMCVRoiAlignCustomOp : Ort::CustomOpBase { void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { return new MMCVRoiAlignKernel(api, info); } const char* GetName() const { return "MMCVRoiAlign"; } size_t GetInputTypeCount() const { return 2; } ONNXTensorElementDataType GetInputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } size_t GetOutputTypeCount() const { return 1; } ONNXTensorElementDataType GetOutputType(size_t) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; } // force cpu const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } }; #endif // ONNXRUNTIME_ROI_ALIGN_H