Yue Zhou 42dc5bc316
Support single stage rotated detector in MMRotate (#428)
* fix lint

* fix lint

* add mmrotate part

* update

* update

* fix

* remove init_detector

* successfully run with bs=1

* nms_rotated supports batched input

* support [batch_id, class_id, box_id]

* fix

* fix

* Create test_mmrotate_core.py

* add ut

* add ut

* Update nms_rotated.py

* fix

* Revert "fix"

This reverts commit f792387fb449ba091c1d932f29d28214805fb6e3.

* add mmrotate into requirements

* add ut

* update doc

* update

* skip test when mmcv version < 1.4.6

* update

* Update rotated-detection_static.py

* Update rotated-detection_static.py

* Update rotated-detection_static.py

* fix memory leak bug.

* Update rotated_detection_model.py
2022-05-07 16:11:43 +08:00

50 lines
1.3 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#ifndef ONNXRUNTIME_NMS_ROTATED_H
#define ONNXRUNTIME_NMS_ROTATED_H
#include <assert.h>
#include <onnxruntime_cxx_api.h>
#include <cmath>
#include <mutex>
#include <string>
#include <vector>
namespace mmdeploy {
/// ONNX Runtime kernel object for the custom "NMSRotated" operator.
///
/// Only the interface is declared in this header; the constructor and
/// Compute() are defined out of line.
struct NMSRotatedKernel {
  /// @param api   ONNX Runtime C API function table (stored by value in api_).
  /// @param info  Kernel-info handle passed in when the kernel is created.
  NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info);

  /// Executes one invocation of the operator; inputs and outputs are reached
  /// through @p context.
  void Compute(OrtKernelContext* context);

 private:
  // Members are initialized in declaration order. ort_ is declared after api_,
  // so a constructor that builds ort_ from api_ sees an already-initialized
  // api_. Do not reorder these two.
  // NOTE(review): assumes the out-of-line ctor initializes ort_ from api_ —
  // confirm in the .cpp.
  OrtApi api_;
  Ort::CustomOpApi ort_;
  const OrtKernelInfo* info_{nullptr};
  Ort::AllocatorWithDefaultOptions allocator_;
  // Default-initialized so that no constructor path can leave them
  // indeterminate. Presumably overwritten from node attributes in the ctor —
  // TODO confirm.
  float iou_threshold_{0.0f};
  float score_threshold_{0.0f};
};
/// Custom-op descriptor for "NMSRotated", consumed through Ort::CustomOpBase.
///
/// Declared schema: two float tensor inputs, one int64 tensor output, and a
/// fixed CPU execution provider.
struct NMSRotatedOp : Ort::CustomOpBase<NMSRotatedOp, NMSRotatedKernel> {
  /// Heap-allocates a kernel for a node of this op and returns it as an
  /// opaque pointer.
  /// NOTE(review): the kernel is presumably released by CustomOpBase's
  /// KernelDestroy hook — confirm against the onnxruntime_cxx_api.h in use.
  void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
    auto* kernel = new NMSRotatedKernel(api, info);
    return kernel;
  }

  /// Operator type name used to match graph nodes.
  const char* GetName() const {
    return "NMSRotated";
  }

  /// Number of inputs.
  size_t GetInputTypeCount() const {
    return 2;
  }

  /// Element type of every input, regardless of position.
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  /// Number of outputs.
  size_t GetOutputTypeCount() const {
    return 1;
  }

  /// Element type of the single output.
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  /// Always requests the CPU execution provider for this op.
  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }
};
} // namespace mmdeploy
#endif // ONNXRUNTIME_NMS_ROTATED_H