2021-11-30 15:00:37 +08:00
|
|
|
// Copyright (c) OpenMMLab. All rights reserved.
|
2021-06-16 10:36:22 +08:00
|
|
|
#ifndef ORT_MMCV_UTILS_H
|
|
|
|
#define ORT_MMCV_UTILS_H
|
|
|
|
#include <onnxruntime_cxx_api.h>

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
|
|
|
|
|
2021-11-01 10:48:21 +08:00
|
|
|
namespace mmdeploy {
|
2021-07-23 17:16:07 +08:00
|
|
|
|
2021-12-09 17:35:28 +08:00
|
|
|
/// Table of registered custom ops, keyed by the custom-op domain name.
/// Each entry lists the OrtCustomOp pointers registered under that domain;
/// OrtOpsRegistry appends to it (see get_mmdeploy_custom_ops()). The table
/// does not own the pointed-to ops.
using CustomOpsTable = std::unordered_map<std::string, std::vector<OrtCustomOp*>>;
|
|
|
|
|
2021-06-16 10:36:22 +08:00
|
|
|
struct OrtTensorDimensions : std::vector<int64_t> {
|
|
|
|
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
|
|
|
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
|
|
|
std::vector<int64_t>::operator=(ort.GetTensorShape(info));
|
|
|
|
ort.ReleaseTensorTypeAndShapeInfo(info);
|
|
|
|
}
|
|
|
|
};
|
2021-07-23 17:16:07 +08:00
|
|
|
|
2021-12-09 17:35:28 +08:00
|
|
|
CustomOpsTable& get_mmdeploy_custom_ops();
|
2021-07-23 17:16:07 +08:00
|
|
|
|
2021-12-09 17:35:28 +08:00
|
|
|
template <char const* domain, typename T>
|
|
|
|
class OrtOpsRegistry {
|
2021-07-23 17:16:07 +08:00
|
|
|
public:
|
2021-12-09 17:35:28 +08:00
|
|
|
OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); }
|
2021-07-23 17:16:07 +08:00
|
|
|
|
|
|
|
private:
|
|
|
|
T instance{};
|
|
|
|
};
|
|
|
|
|
2021-12-09 17:35:28 +08:00
|
|
|
// Registers custom-op type `name` under the ONNX Runtime domain `domain`.
//
// Expands to two static definitions:
//   - a const char array holding the stringized domain, which serves as the
//     OrtOpsRegistry non-type template argument and table key;
//   - the registry object, whose constructor performs the registration.
//
// Use it at namespace scope: the array must have linkage to be a template
// argument. The expansion has no trailing semicolon, so supply one:
//   REGISTER_ONNXRUNTIME_OPS(mmdeploy, SomeOp);
//
// The array name no longer starts with "__"; identifiers containing a double
// underscore are reserved to the implementation. The registry template is
// fully qualified so the macro also works outside namespace mmdeploy.
#define REGISTER_ONNXRUNTIME_OPS(domain, name)                                   \
  static const char mmdeploy_ort_domain_##domain##name[] = #domain;              \
  static ::mmdeploy::OrtOpsRegistry<mmdeploy_ort_domain_##domain##name, name>    \
      ort_ops_registry_##domain##name {}
|
2021-07-23 17:16:07 +08:00
|
|
|
|
2021-11-01 10:48:21 +08:00
|
|
|
} // namespace mmdeploy
|
2021-06-16 10:36:22 +08:00
|
|
|
#endif // ORT_MMCV_UTILS_H
|