mmrotate sdk module (#450)

* support mmrotate

* fix name

* on Windows, the default is to link cudart_static.lib, which is not compatible with the static build && python_api

* python api

* fix ci

* fix type & remove unused meta info

* fix doxygen, add [out] to @param

* fix mmrotate-c-api

* refactor naming

* refactor naming

* fix lint

* fix lint

* move replace_RResize -> get_preprocess

* Update cuda.cmake

On Windows, make the static lib and Python API build successfully.

* fix ptr

* Use unique ptr to prevent memory leaks

* move unique_ptr

* remove deleter

Co-authored-by: chenxin2 <chenxin2@sensetime.com>
Co-authored-by: cx <cx@ubuntu20.04>
Chen Xin 2022-05-17 23:37:32 +08:00 committed by GitHub
parent 1a8d7aceaf
commit 0ce7c83c63
18 changed files with 631 additions and 6 deletions


@@ -6,6 +6,11 @@ if (${CMAKE_VERSION} VERSION_GREATER_EQUAL "3.18.0")
cmake_policy(SET CMP0104 OLD)
endif ()
if (MSVC)
# use the shared CUDA runtime; on Windows, the Python API can't be built with the static lib.
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
endif ()
# nvcc compiler settings
find_package(CUDA REQUIRED)
#message(STATUS "CUDA VERSION: ${CUDA_VERSION_STRING}")


@@ -0,0 +1,8 @@
_base_ = ['./rotated-detection_static.py', '../_base_/backends/sdk.py']
codebase_config = dict(model_type='sdk')
backend_config = dict(pipeline=[
dict(type='LoadImageFromFile'),
dict(type='Collect', keys=['img'], meta_keys=['filename', 'ori_shape'])
])


@@ -5,7 +5,8 @@ project(capis)
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
if ("all" IN_LIST MMDEPLOY_CODEBASES)
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;pose_detector;restorer;model")
set(TASK_LIST "classifier;detector;segmentor;text_detector;text_recognizer;"
"pose_detector;restorer;model;rotated_detector")
else ()
set(TASK_LIST "model")
if ("mmcls" IN_LIST MMDEPLOY_CODEBASES)
@@ -27,6 +28,9 @@ else ()
if ("mmpose" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND TASK_LIST "pose_detector")
endif ()
if ("mmrotate" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND TASK_LIST "rotated_detector")
endif()
endif ()
foreach (TASK ${TASK_LIST})


@@ -57,7 +57,7 @@ MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const
* @param[in] mat_count number of images in the batch
* @param[out] results a linear buffer to save detection results of each image. It must be released
* by \ref mmdeploy_detector_release_result
- * @param result_count a linear buffer with length being \p mat_count to save the number of
+ * @param[out] result_count a linear buffer with length being \p mat_count to save the number of
* detection results of each image. And it must be released by \ref
* mmdeploy_detector_release_result
* @return status of inference


@@ -0,0 +1,142 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "rotated_detector.h"
#include <numeric>
#include "codebase/mmrotate/mmrotate.h"
#include "core/device.h"
#include "core/graph.h"
#include "core/mat.h"
#include "core/utils/formatter.h"
#include "handle.h"
using namespace std;
using namespace mmdeploy;
namespace {
Value& config_template() {
// clang-format off
static Value v{
{
"pipeline", {
{"input", {"image"}},
{"output", {"det"}},
{
"tasks",{
{
{"name", "mmrotate"},
{"type", "Inference"},
{"params", {{"model", "TBD"}}},
{"input", {"image"}},
{"output", {"det"}}
}
}
}
}
}
};
// clang-format on
return v;
}
template <class ModelType>
int mmdeploy_rotated_detector_create_impl(ModelType&& m, const char* device_name, int device_id,
mm_handle_t* handle) {
try {
auto value = config_template();
value["pipeline"]["tasks"][0]["params"]["model"] = std::forward<ModelType>(m);
auto rotated_detector = std::make_unique<Handle>(device_name, device_id, std::move(value));
*handle = rotated_detector.release();
return MM_SUCCESS;
} catch (const std::exception& e) {
MMDEPLOY_ERROR("exception caught: {}", e.what());
} catch (...) {
MMDEPLOY_ERROR("unknown exception caught");
}
return MM_E_FAIL;
}
} // namespace
int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name, int device_id,
mm_handle_t* handle) {
return mmdeploy_rotated_detector_create_impl(*static_cast<Model*>(model), device_name, device_id,
handle);
}
int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name,
int device_id, mm_handle_t* handle) {
return mmdeploy_rotated_detector_create_impl(model_path, device_name, device_id, handle);
}
int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats, int mat_count,
mm_rotated_detect_t** results, int** result_count) {
if (handle == nullptr || mats == nullptr || mat_count == 0 || results == nullptr ||
result_count == nullptr) {
return MM_E_INVALID_ARG;
}
try {
auto detector = static_cast<Handle*>(handle);
Value input{Value::kArray};
for (int i = 0; i < mat_count; ++i) {
mmdeploy::Mat _mat{mats[i].height, mats[i].width, PixelFormat(mats[i].format),
DataType(mats[i].type), mats[i].data, Device{"cpu"}};
input.front().push_back({{"ori_img", _mat}});
}
auto output = detector->Run(std::move(input)).value().front();
auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(output);
vector<int> _result_count;
_result_count.reserve(mat_count);
for (const auto& det_output : detector_outputs) {
_result_count.push_back((int)det_output.detections.size());
}
auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
std::unique_ptr<mm_rotated_detect_t[]> result_data(new mm_rotated_detect_t[total]{});
auto result_ptr = result_data.get();
for (const auto& det_output : detector_outputs) {
for (const auto& detection : det_output.detections) {
result_ptr->label_id = detection.label_id;
result_ptr->score = detection.score;
const auto& rbbox = detection.rbbox;
for (int i = 0; i < 5; i++) {
result_ptr->rbbox[i] = rbbox[i];
}
++result_ptr;
}
}
*result_count = result_count_data.release();
*results = result_data.release();
return MM_SUCCESS;
} catch (const std::exception& e) {
MMDEPLOY_ERROR("exception caught: {}", e.what());
} catch (...) {
MMDEPLOY_ERROR("unknown exception caught");
}
return MM_E_FAIL;
}
void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
const int* result_count) {
delete[] results;
delete[] result_count;
}
void mmdeploy_rotated_detector_destroy(mm_handle_t handle) { delete static_cast<Handle*>(handle); }


@@ -0,0 +1,82 @@
// Copyright (c) OpenMMLab. All rights reserved.
/**
* @file rotated_detector.h
* @brief Interface to MMRotate task
*/
#ifndef MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
#define MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_
#include "common.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct mm_rotated_detect_t {
int label_id;
float score;
float rbbox[5]; // cx, cy, w, h, angle
} mm_rotated_detect_t;
/**
* @brief Create rotated detector's handle
* @param[in] model an instance of mmrotate sdk model created by
* \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
* @param[in] device_id id of device.
* @param[out] handle instance of a rotated detector
* @return status of creating rotated detector's handle
*/
MMDEPLOY_API int mmdeploy_rotated_detector_create(mm_model_t model, const char* device_name,
int device_id, mm_handle_t* handle);
/**
* @brief Create rotated detector's handle
* @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter
* @param[in] device_name name of device, such as "cpu", "cuda", etc.
* @param[in] device_id id of device.
* @param[out] handle instance of a rotated detector
* @return status of creating rotated detector's handle
*/
MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path,
const char* device_name, int device_id,
mm_handle_t* handle);
/**
* @brief Apply rotated detector to batch images and get their inference results
* @param[in] handle rotated detector's handle created by \ref
* mmdeploy_rotated_detector_create_by_path
* @param[in] mats a batch of images
* @param[in] mat_count number of images in the batch
* @param[out] results a linear buffer to save detection results of each image. It must be released
* by \ref mmdeploy_rotated_detector_release_result
* @param[out] result_count a linear buffer with length being \p mat_count to save the number of
* detection results of each image. And it must be released by \ref
* mmdeploy_rotated_detector_release_result
* @return status of inference
*/
MMDEPLOY_API int mmdeploy_rotated_detector_apply(mm_handle_t handle, const mm_mat_t* mats,
int mat_count, mm_rotated_detect_t** results,
int** result_count);
/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply
* @param[in] results rotated detection results buffer
* @param[in] result_count \p results size buffer
*/
MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mm_rotated_detect_t* results,
const int* result_count);
/**
* @brief Destroy rotated detector's handle
* @param[in] handle rotated detector's handle created by \ref
* mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create
*/
MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mm_handle_t handle);
#ifdef __cplusplus
}
#endif
#endif // MMDEPLOY_SRC_APIS_C_ROTATED_DETECTOR_H_


@@ -25,6 +25,7 @@ mmdeploy_python_add_module(text_detector)
mmdeploy_python_add_module(text_recognizer)
mmdeploy_python_add_module(restorer)
mmdeploy_python_add_module(pose_detector)
mmdeploy_python_add_module(rotated_detector)
pybind11_add_module(${PROJECT_NAME} ${MMDEPLOY_PYTHON_SRCS})


@@ -0,0 +1,83 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "rotated_detector.h"
#include "common.h"
#include "core/logger.h"
namespace mmdeploy {
class PyRotatedDetector {
public:
PyRotatedDetector(const char *model_path, const char *device_name, int device_id) {
MMDEPLOY_INFO("{}, {}, {}", model_path, device_name, device_id);
auto status =
mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &handle_);
if (status != MM_SUCCESS) {
throw std::runtime_error("failed to create rotated detector");
}
}
py::list Apply(const std::vector<PyImage> &imgs) {
std::vector<mm_mat_t> mats;
mats.reserve(imgs.size());
for (const auto &img : imgs) {
auto mat = GetMat(img);
mats.push_back(mat);
}
mm_rotated_detect_t *rbboxes{};
int *res_count{};
auto status = mmdeploy_rotated_detector_apply(handle_, mats.data(), (int)mats.size(), &rbboxes,
&res_count);
if (status != MM_SUCCESS) {
throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status));
}
auto output = py::list{};
auto result = rbboxes;
auto counts = res_count;
for (int i = 0; i < mats.size(); i++) {
auto _dets = py::array_t<float>({*counts, 6});
auto _labels = py::array_t<int>({*counts});
auto dets = _dets.mutable_data();
auto labels = _labels.mutable_data();
for (int j = 0; j < *counts; j++) {
for (int k = 0; k < 5; k++) {
*dets++ = result->rbbox[k];
}
*dets++ = result->score;
*labels++ = result->label_id;
result++;
}
counts++;
output.append(py::make_tuple(std::move(_dets), std::move(_labels)));
}
mmdeploy_rotated_detector_release_result(rbboxes, res_count);
return output;
}
~PyRotatedDetector() {
mmdeploy_rotated_detector_destroy(handle_);
handle_ = {};
}
private:
mm_handle_t handle_{};
};
static void register_python_rotated_detector(py::module &m) {
py::class_<PyRotatedDetector>(m, "RotatedDetector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
}))
.def("__call__", &PyRotatedDetector::Apply);
}
class PythonRotatedDetectorRegisterer {
public:
PythonRotatedDetectorRegisterer() {
gPythonBindings().emplace("rotated_detector", register_python_rotated_detector);
}
};
static PythonRotatedDetectorRegisterer python_rotated_detector_registerer;
} // namespace mmdeploy
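The binding above only registers the RotatedDetector class; the PR ships no Python demo. A minimal usage sketch, assuming the extension module is importable as mmdeploy_python and that ./rotated_sdk_model and demo.jpg exist (all three names are illustrative, not taken from this PR):

import cv2
from mmdeploy_python import RotatedDetector  # module name assumed

img = cv2.imread('demo.jpg')  # BGR uint8 array, converted by GetMat internally
detector = RotatedDetector('./rotated_sdk_model', 'cpu', 0)  # model_path, device_name, device_id

# __call__ takes a list of images and returns one (dets, labels) pair per image;
# dets is an (N, 6) float array laid out as cx, cy, w, h, angle, score.
dets, labels = detector([img])[0]
for (cx, cy, w, h, angle, score), label in zip(dets, labels):
    if score < 0.3:
        continue
    print(f'label={label}, score={score:.2f}, rbbox=({cx:.1f}, {cy:.1f}, {w:.1f}, {h:.1f}, {angle:.2f})')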


@@ -10,6 +10,7 @@ if ("all" IN_LIST MMDEPLOY_CODEBASES)
list(APPEND CODEBASES "mmocr")
list(APPEND CODEBASES "mmedit")
list(APPEND CODEBASES "mmpose")
list(APPEND CODEBASES "mmrotate")
else ()
set(CODEBASES ${MMDEPLOY_CODEBASES})
endif ()


@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.14)
project(mmdeploy_mmrotate)
include(${CMAKE_SOURCE_DIR}/cmake/opencv.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
add_library(mmdeploy::mmrotate ALIAS ${PROJECT_NAME})


@@ -0,0 +1,15 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "codebase/mmrotate/mmrotate.h"
using namespace std;
namespace mmdeploy {
namespace mmrotate {
REGISTER_CODEBASE(MMRotate);
} // namespace mmrotate
MMDEPLOY_DEFINE_REGISTRY(mmrotate::MMRotate);
} // namespace mmdeploy


@@ -0,0 +1,32 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_MMROTATE_H
#define MMDEPLOY_MMROTATE_H
#include <array>
#include "codebase/common.h"
#include "core/device.h"
#include "core/module.h"
namespace mmdeploy {
namespace mmrotate {
struct RotatedDetectorOutput {
struct Detection {
int label_id;
float score;
std::array<float, 5> rbbox; // cx,cy,w,h,ag
MMDEPLOY_ARCHIVE_MEMBERS(label_id, score, rbbox);
};
std::vector<Detection> detections;
MMDEPLOY_ARCHIVE_MEMBERS(detections);
};
DECLARE_CODEBASE(MMRotate, mmrotate);
} // namespace mmrotate
MMDEPLOY_DECLARE_REGISTRY(mmrotate::MMRotate);
} // namespace mmdeploy
#endif // MMDEPLOY_MMROTATE_H


@@ -0,0 +1,114 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include "core/device.h"
#include "core/registry.h"
#include "core/serialization.h"
#include "core/tensor.h"
#include "core/utils/device_utils.h"
#include "core/utils/formatter.h"
#include "core/value.h"
#include "mmrotate.h"
#include "opencv_utils.h"
namespace mmdeploy::mmrotate {
using std::vector;
class ResizeRBBox : public MMRotate {
public:
explicit ResizeRBBox(const Value& cfg) : MMRotate(cfg) {
if (cfg.contains("params")) {
score_thr_ = cfg["params"].value("score_thr", 0.05f);
}
}
Result<Value> operator()(const Value& prep_res, const Value& infer_res) {
MMDEPLOY_DEBUG("prep_res: {}", prep_res);
MMDEPLOY_DEBUG("infer_res: {}", infer_res);
Device cpu_device{"cpu"};
OUTCOME_TRY(auto dets,
MakeAvailableOnDevice(infer_res["dets"].get<Tensor>(), cpu_device, stream_));
OUTCOME_TRY(auto labels,
MakeAvailableOnDevice(infer_res["labels"].get<Tensor>(), cpu_device, stream_));
OUTCOME_TRY(stream_.Wait());
if (!(dets.shape().size() == 3 && dets.shape(2) == 6 && dets.data_type() == DataType::kFLOAT)) {
MMDEPLOY_ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
(int)dets.data_type());
return Status(eNotSupported);
}
if (labels.shape().size() != 2) {
MMDEPLOY_ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
(int)labels.data_type());
return Status(eNotSupported);
}
OUTCOME_TRY(auto result, DispatchGetBBoxes(prep_res["img_metas"], dets, labels));
return to_value(result);
}
Result<RotatedDetectorOutput> DispatchGetBBoxes(const Value& prep_res, const Tensor& dets,
const Tensor& labels) {
auto data_type = labels.data_type();
switch (data_type) {
case DataType::kFLOAT:
return GetRBBoxes<float>(prep_res, dets, labels);
case DataType::kINT32:
return GetRBBoxes<int32_t>(prep_res, dets, labels);
case DataType::kINT64:
return GetRBBoxes<int64_t>(prep_res, dets, labels);
default:
return Status(eNotSupported);
}
}
template <typename T>
Result<RotatedDetectorOutput> GetRBBoxes(const Value& prep_res, const Tensor& dets,
const Tensor& labels) {
RotatedDetectorOutput objs;
auto* dets_ptr = dets.data<float>();
auto* labels_ptr = labels.data<T>();
vector<float> scale_factor;
if (prep_res.contains("scale_factor")) {
from_value(prep_res["scale_factor"], scale_factor);
} else {
scale_factor = {1.f, 1.f, 1.f, 1.f};
}
int ori_width = prep_res["ori_shape"][2].get<int>();
int ori_height = prep_res["ori_shape"][1].get<int>();
auto bboxes_number = dets.shape()[1];
auto channels = dets.shape()[2];
for (auto i = 0; i < bboxes_number; ++i, dets_ptr += channels, ++labels_ptr) {
float score = dets_ptr[channels - 1];
if (score <= score_thr_) {
continue;
}
auto cx = dets_ptr[0] / scale_factor[0];
auto cy = dets_ptr[1] / scale_factor[1];
auto width = dets_ptr[2] / scale_factor[0];
auto height = dets_ptr[3] / scale_factor[1];
auto angle = dets_ptr[4];
RotatedDetectorOutput::Detection det{};
det.label_id = static_cast<int>(*labels_ptr);
det.score = score;
det.rbbox = {cx, cy, width, height, angle};
objs.detections.push_back(std::move(det));
}
return objs;
}
private:
float score_thr_{0.05f};  // default matches the "score_thr" fallback above
};
REGISTER_CODEBASE_COMPONENT(MMRotate, ResizeRBBox);
} // namespace mmdeploy::mmrotate
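ResizeRBBox rescales each detection from the network-input scale back to the original image by dividing the center and size by scale_factor and dropping boxes with score at or below score_thr; the angle is left untouched. A minimal numpy sketch of the same post-processing (function and argument names are illustrative):

import numpy as np

def rescale_rbboxes(dets, scale_factor=(1.0, 1.0, 1.0, 1.0), score_thr=0.05):
    """dets: (N, 6) array of cx, cy, w, h, angle, score at network-input scale."""
    dets = np.asarray(dets, dtype=np.float32)
    out = dets[dets[:, 5] > score_thr].copy()  # same filter as `if (score <= score_thr_) continue;`
    out[:, 0] /= scale_factor[0]  # cx
    out[:, 1] /= scale_factor[1]  # cy
    out[:, 2] /= scale_factor[0]  # w
    out[:, 3] /= scale_factor[1]  # h
    # angle (column 4) is scale-invariant and left unchanged
    return out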


@@ -21,4 +21,5 @@ add_example(object_detection)
add_example(image_restorer)
add_example(image_segmentation)
add_example(pose_detection)
add_example(rotated_object_detection)
add_example(ocr)


@@ -0,0 +1,69 @@
#include <fstream>
#include <iostream>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <string>
#include "rotated_detector.h"
int main(int argc, char *argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n oriented_object_detection device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
}
mm_handle_t detector{};
int status{};
status = mmdeploy_rotated_detector_create_by_path(model_path, device_name, 0, &detector);
if (status != MM_SUCCESS) {
fprintf(stderr, "failed to create rotated detector, code: %d\n", (int)status);
return 1;
}
mm_mat_t mat{img.data, img.rows, img.cols, 3, MM_BGR, MM_INT8};
mm_rotated_detect_t *rbboxes{};
int *res_count{};
status = mmdeploy_rotated_detector_apply(detector, &mat, 1, &rbboxes, &res_count);
if (status != MM_SUCCESS) {
fprintf(stderr, "failed to apply rotated detector, code: %d\n", (int)status);
return 1;
}
for (int i = 0; i < *res_count; ++i) {
// skip low score
if (rbboxes[i].score < 0.1) {
continue;
}
const auto &rbbox = rbboxes[i].rbbox;
float xc = rbbox[0];
float yc = rbbox[1];
float w = rbbox[2];
float h = rbbox[3];
float ag = rbbox[4];
float wx = w / 2 * std::cos(ag);
float wy = w / 2 * std::sin(ag);
float hx = -h / 2 * std::sin(ag);
float hy = h / 2 * std::cos(ag);
cv::Point p1 = {int(xc - wx - hx), int(yc - wy - hy)};
cv::Point p2 = {int(xc + wx - hx), int(yc + wy - hy)};
cv::Point p3 = {int(xc + wx + hx), int(yc + wy + hy)};
cv::Point p4 = {int(xc - wx + hx), int(yc - wy + hy)};
cv::drawContours(img, std::vector<std::vector<cv::Point>>{{p1, p2, p3, p4}}, -1, {0, 255, 0},
2);
}
cv::imwrite("output_rotated_detection.png", img);
mmdeploy_rotated_detector_release_result(rbboxes, res_count);
mmdeploy_rotated_detector_destroy(detector);
return 0;
}
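For users of the Python API, the same corner math as the drawing loop above, as a small numpy helper (the name rbbox_to_corners is illustrative); cv2.polylines with the int-cast corners roughly reproduces the drawContours call:

import numpy as np

def rbbox_to_corners(cx, cy, w, h, angle):
    """Return the 4 corner points of a rotated box given as (cx, cy, w, h, angle)."""
    wx, wy = w / 2 * np.cos(angle), w / 2 * np.sin(angle)
    hx, hy = -h / 2 * np.sin(angle), h / 2 * np.cos(angle)
    return np.array([[cx - wx - hx, cy - wy - hy],
                     [cx + wx - hx, cy + wy - hy],
                     [cx + wx + hx, cy + wy + hy],
                     [cx - wx + hx, cy - wy + hy]])

# e.g. cv2.polylines(img, [rbbox_to_corners(*rbbox).astype(np.int32)], True, (0, 255, 0), 2)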


@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import mmcv
@@ -13,6 +14,28 @@ from mmdeploy.utils import Task, get_input_shape
from .mmrotate import MMROTATE_TASK
def replace_RResize(pipelines):
"""Rename RResize to Resize.
Args:
pipelines (list[dict]): Data pipeline configs.
Returns:
list: The new pipeline list with all RResize renamed to
Resize.
"""
pipelines = copy.deepcopy(pipelines)
for i, pipeline in enumerate(pipelines):
if pipeline['type'] == 'MultiScaleFlipAug':
assert 'transforms' in pipeline
pipeline['transforms'] = replace_RResize(pipeline['transforms'])
elif pipeline.type == 'RResize':
pipelines[i].type = 'Resize'
if 'keep_ratio' not in pipelines[i]:
pipelines[i]['keep_ratio'] = True # default value
return pipelines
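A quick illustration of replace_RResize on a made-up mmrotate-style test pipeline. Real pipelines come from model_cfg.data.test.pipeline as mmcv ConfigDict objects, which is why attribute access such as pipeline.type works; ConfigDict is used below for the same reason:

from mmcv import ConfigDict

pipelines = [
    ConfigDict(type='LoadImageFromFile'),
    ConfigDict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 1024),
        flip=False,
        transforms=[
            dict(type='RResize', img_scale=(1024, 1024)),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img'])
        ])
]
new_pipelines = replace_RResize(pipelines)
print(new_pipelines[1]['transforms'][0])
# {'type': 'Resize', 'img_scale': (1024, 1024), 'keep_ratio': True}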
def process_model_config(model_cfg: mmcv.Config,
imgs: Union[Sequence[str], Sequence[np.ndarray]],
input_shape: Optional[Sequence[int]] = None):
@@ -30,10 +53,9 @@ def process_model_config(model_cfg: mmcv.Config,
"""
from mmdet.datasets import replace_ImageToTensor
- cfg = model_cfg.copy()
+ cfg = copy.deepcopy(model_cfg)
if isinstance(imgs[0], np.ndarray):
- cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
# for static exporting
@@ -320,6 +342,9 @@ class RotatedDetection(BaseTask):
"""
input_shape = get_input_shape(self.deploy_cfg)
model_cfg = process_model_config(self.model_cfg, [''], input_shape)
# rename sdk RResize -> Resize
model_cfg.data.test.pipeline = replace_RResize(
model_cfg.data.test.pipeline)
preprocess = model_cfg.data.test.pipeline
return preprocess
@@ -329,7 +354,7 @@
Return:
dict: Composed of the postprocess information.
"""
- postprocess = self.model_cfg.model.bbox_head
+ postprocess = self.model_cfg.model.test_cfg
return postprocess
def get_model_name(self) -> str:


@@ -161,6 +161,36 @@ class End2EndModel(BaseBackendModel):
out_file=out_file)
@__BACKEND_MODEL.register_module('sdk')
class SDKEnd2EndModel(End2EndModel):
"""SDK inference class, converts SDK output to mmrotate format."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list:
"""Run forward inference.
Args:
img (List[torch.Tensor]): A list contains input image(s)
in [N x C x H x W] format.
img_metas (Sequence[Sequence[dict]]): A list of meta info for
image(s).
*args: Other arguments.
**kwargs: Other key-pair arguments.
Returns:
list: A list contains predictions.
"""
results = []
dets, labels = self.wrapper.invoke(
[img[0].contiguous().detach().cpu().numpy()])[0]
dets_results = [dets[labels == i, :] for i in range(len(self.CLASSES))]
results.append(dets_results)
return results
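The list comprehension in forward reshapes the flat SDK output into mmrotate's per-class result format, one (N_i, 6) array per class. A tiny numpy illustration with made-up numbers:

import numpy as np

dets = np.array([[10., 10., 4., 2., 0.1, 0.9],   # cx, cy, w, h, angle, score
                 [30., 40., 8., 3., 1.2, 0.8],
                 [55., 60., 6., 6., 0.0, 0.7]])
labels = np.array([0, 1, 0])
num_classes = 2  # stands in for len(self.CLASSES)

dets_results = [dets[labels == i, :] for i in range(num_classes)]
print([d.shape for d in dets_results])  # [(2, 6), (1, 6)]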
def build_rotated_detection_model(model_files: Sequence[str],
model_cfg: Union[str, mmcv.Config],
deploy_cfg: Union[str, mmcv.Config],


@@ -77,5 +77,7 @@ SDK_TASK_MAP = {
Task.TEXT_RECOGNITION:
dict(component='CTCConvertor', cls_name='TextRecognizer'),
Task.POSE_DETECTION:
- dict(component='Detector', cls_name='PoseDetector')
+ dict(component='Detector', cls_name='PoseDetector'),
+ Task.ROTATED_DETECTION:
+ dict(component='ResizeRBBox', cls_name='RotatedDetector')
}