mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* fix lint * fix lint * add mmrotate part * update * update * fix * remove init_detector * success run with bs=1 * nms_rotated support batch * support [batch_id, class_id, box_id] * fix * fix * Create test_mmrotate_core.py * add ut * add ut * Update nms_rotated.py * fix * Revert "fix" This reverts commit f792387fb449ba091c1d932f29d28214805fb6e3. * add mmrotate into requirements * add ut * update doc * update * skip test because mmcv version < 1.4.6 * update * Update rotated-detection_static.py * Update rotated-detection_static.py * Update rotated-detection_static.py * fix bug of memory leak. * Update rotated_detection_model.py
50 lines
1.3 KiB
C++
50 lines
1.3 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
#ifndef ONNXRUNTIME_NMS_ROTATED_H
|
|
#define ONNXRUNTIME_NMS_ROTATED_H
|
|
|
|
#include <assert.h>
|
|
#include <onnxruntime_cxx_api.h>
|
|
|
|
#include <cmath>
|
|
#include <mutex>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace mmdeploy {
|
|
struct NMSRotatedKernel {
|
|
NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);
|
|
|
|
void Compute(OrtKernelContext* context);
|
|
|
|
private:
|
|
OrtApi api_;
|
|
Ort::CustomOpApi ort_;
|
|
const OrtKernelInfo* info_;
|
|
Ort::AllocatorWithDefaultOptions allocator_;
|
|
float iou_threshold_;
|
|
float score_threshold_;
|
|
};
|
|
|
|
/// Registration descriptor for the "NMSRotated" custom op.
///
/// Ort::CustomOpBase (CRTP over this type and NMSRotatedKernel) adapts these
/// member functions into the OrtCustomOp callbacks that ONNX Runtime calls.
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
  /// Creates one NMSRotatedKernel for a node bound to this op.
  ///
  /// The API table is taken by const reference so it is not copied here; the
  /// kernel constructor still keeps its own copy. The raw pointer is handed to
  /// ORT, which presumably frees it through CustomOpBase's kernel-destroy
  /// hook — confirm against the onnxruntime version in use.
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new NMSRotatedKernel(api, info);
  }

  /// Op type name under which this kernel is registered.
  const char* GetName() const { return "NMSRotated"; }

  /// Two inputs, both float32 (presumably boxes and scores — confirm in the
  /// kernel implementation).
  size_t GetInputTypeCount() const { return 2; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  /// One int64 output (presumably selected indices; the change log mentions
  /// [batch_id, class_id, box_id] rows — confirm in the kernel implementation).
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  /// Forces this op onto the CPU execution provider.
  const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
};
|
|
} // namespace mmdeploy
|
|
|
|
#endif // ONNXRUNTIME_NMS_ROTATED_H
|