* * partition rtmdet * * add rtmdet deploy config * * add rtmdet deploy config * * modify rtmdet pipline anchor_generator's info dump * support rtmdet infer in sdk * fix a bug * * fix a bug in csrc/mmdeploy/preprocess/transform/normalize.cpp * * fix a bug * * update docs * * fix lint * * update several urls in docspull/1623/head
parent
db8de7ec1f
commit
71fc8e3241
|
@ -0,0 +1,19 @@
|
|||
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']
|
||||
|
||||
onnx_config = dict(input_shape=[640, 640])
|
||||
|
||||
codebase_config = dict(model_type='rknn')
|
||||
|
||||
backend_config = dict(input_size_list=[[3, 640, 640]])
|
||||
|
||||
# rtmdet for rknn-toolkit and rknn-toolkit2
|
||||
# partition_config = dict(
|
||||
# type='rknn', # the partition policy name
|
||||
# apply_marks=True, # should always be set to True
|
||||
# partition_cfg=[
|
||||
# dict(
|
||||
# save_file='model.onnx', # name to save the partitioned onnx
|
||||
# start=['detector_forward:input'], # [mark_name:input, ...]
|
||||
# end=['rtmdet_head:output'], # [mark_name:output, ...]
|
||||
# output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
|
||||
# ])
|
|
@ -0,0 +1,194 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "rtmdet_head.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "mmdeploy/core/model.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
|
||||
RTMDetSepBNHead::RTMDetSepBNHead(const Value& cfg) : MMDetection(cfg) {
|
||||
auto init = [&]() -> Result<void> {
|
||||
auto model = cfg["context"]["model"].get<Model>();
|
||||
if (cfg.contains("params")) {
|
||||
nms_pre_ = cfg["params"].value("nms_pre", -1);
|
||||
score_thr_ = cfg["params"].value("score_thr", 0.02f);
|
||||
min_bbox_size_ = cfg["params"].value("min_bbox_size", 0);
|
||||
max_per_img_ = cfg["params"].value("max_per_img", 100);
|
||||
iou_threshold_ = cfg["params"].contains("nms")
|
||||
? cfg["params"]["nms"].value("iou_threshold", 0.45f)
|
||||
: 0.45f;
|
||||
if (cfg["params"].contains("anchor_generator")) {
|
||||
offset_ = cfg["params"]["anchor_generator"].value("offset", 0);
|
||||
from_value(cfg["params"]["anchor_generator"]["strides"], strides_);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
};
|
||||
init().value();
|
||||
}
|
||||
|
||||
Result<Value> RTMDetSepBNHead::operator()(const Value& prep_res, const Value& infer_res) {
|
||||
MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
|
||||
try {
|
||||
std::vector<Tensor> cls_scores;
|
||||
std::vector<Tensor> bbox_preds;
|
||||
const Device kHost{0, 0};
|
||||
int i = 0;
|
||||
int divisor = infer_res.size() / 2;
|
||||
for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) {
|
||||
auto pred_map = iter->get<Tensor>();
|
||||
OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
|
||||
if (i < divisor)
|
||||
cls_scores.push_back(_pred_map);
|
||||
else
|
||||
bbox_preds.push_back(_pred_map);
|
||||
i++;
|
||||
}
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], bbox_preds, cls_scores));
|
||||
return to_value(result);
|
||||
} catch (...) {
|
||||
return Status(eFail);
|
||||
}
|
||||
}
|
||||
|
||||
static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); }
|
||||
|
||||
Result<Detections> RTMDetSepBNHead::GetBBoxes(const Value& prep_res,
|
||||
const std::vector<Tensor>& bbox_preds,
|
||||
const std::vector<Tensor>& cls_scores) const {
|
||||
MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), dets[0].data_type());
|
||||
MMDEPLOY_DEBUG("cls_score: {}, {}", scores[0].shape(), scores[0].data_type());
|
||||
|
||||
std::vector<float> filter_boxes;
|
||||
std::vector<float> obj_probs;
|
||||
std::vector<int> class_ids;
|
||||
|
||||
for (int i = 0; i < bbox_preds.size(); i++) {
|
||||
RTMDetFeatDeocde(bbox_preds[i], cls_scores[i], strides_[i], offset_, filter_boxes, obj_probs,
|
||||
class_ids);
|
||||
}
|
||||
|
||||
std::vector<int> indexArray;
|
||||
for (int i = 0; i < obj_probs.size(); ++i) {
|
||||
indexArray.push_back(i);
|
||||
}
|
||||
Sort(obj_probs, class_ids, indexArray);
|
||||
|
||||
Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
|
||||
TensorShape{int(filter_boxes.size() / 4), 4}, "dets"});
|
||||
std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
|
||||
NMS(dets, iou_threshold_, indexArray);
|
||||
|
||||
Detections objs;
|
||||
std::vector<float> scale_factor;
|
||||
if (prep_res.contains("scale_factor")) {
|
||||
from_value(prep_res["scale_factor"], scale_factor);
|
||||
} else {
|
||||
scale_factor = {1.f, 1.f, 1.f, 1.f};
|
||||
}
|
||||
int ori_width = prep_res["ori_shape"][2].get<int>();
|
||||
int ori_height = prep_res["ori_shape"][1].get<int>();
|
||||
auto det_ptr = dets.data<float>();
|
||||
for (int i = 0; i < indexArray.size(); ++i) {
|
||||
if (indexArray[i] == -1) {
|
||||
continue;
|
||||
}
|
||||
int j = indexArray[i];
|
||||
auto x1 = det_ptr[j * 4 + 0];
|
||||
auto y1 = det_ptr[j * 4 + 1];
|
||||
auto x2 = det_ptr[j * 4 + 2];
|
||||
auto y2 = det_ptr[j * 4 + 3];
|
||||
int label_id = class_ids[i];
|
||||
float score = obj_probs[i];
|
||||
|
||||
MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);
|
||||
|
||||
auto rect =
|
||||
MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height, 0, 0);
|
||||
if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
|
||||
MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
|
||||
rect[3] - rect[1]);
|
||||
continue;
|
||||
}
|
||||
Detection det{};
|
||||
det.index = i;
|
||||
det.label_id = label_id;
|
||||
det.score = score;
|
||||
det.bbox = rect;
|
||||
objs.push_back(std::move(det));
|
||||
}
|
||||
|
||||
return objs;
|
||||
}
|
||||
|
||||
int RTMDetSepBNHead::RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score,
|
||||
const float stride, const float offset,
|
||||
std::vector<float>& filter_boxes,
|
||||
std::vector<float>& obj_probs,
|
||||
std::vector<int>& class_ids) const {
|
||||
int cls_param_num = cls_score.shape(1);
|
||||
int feat_h = bbox_pred.shape(2);
|
||||
int feat_w = bbox_pred.shape(3);
|
||||
int feat_size = feat_h * feat_w;
|
||||
auto bbox_ptr = bbox_pred.data<float>();
|
||||
auto score_ptr = cls_score.data<float>(); // (b, c, h, w)
|
||||
int valid_count = 0;
|
||||
for (int i = 0; i < feat_h; i++) {
|
||||
for (int j = 0; j < feat_w; j++) {
|
||||
float max_score = score_ptr[i * feat_w + j];
|
||||
int class_id = 0;
|
||||
for (int k = 0; k < cls_param_num; k++) {
|
||||
float score = score_ptr[k * feat_size + i * feat_w + j];
|
||||
if (score > max_score) {
|
||||
max_score = score;
|
||||
class_id = k;
|
||||
}
|
||||
}
|
||||
max_score = sigmoid(max_score);
|
||||
if (max_score < score_thr_) continue;
|
||||
|
||||
obj_probs.push_back(max_score);
|
||||
class_ids.push_back(class_id);
|
||||
|
||||
float tl_x = bbox_ptr[0 * feat_size + i * feat_w + j];
|
||||
float tl_y = bbox_ptr[1 * feat_size + i * feat_w + j];
|
||||
float br_x = bbox_ptr[2 * feat_size + i * feat_w + j];
|
||||
float br_y = bbox_ptr[3 * feat_size + i * feat_w + j];
|
||||
|
||||
auto box = RTMDetdecode(tl_x, tl_y, br_x, br_y, stride, offset, j, i);
|
||||
|
||||
tl_x = box[0];
|
||||
tl_y = box[1];
|
||||
br_x = box[2];
|
||||
br_y = box[3];
|
||||
|
||||
filter_boxes.push_back(tl_x);
|
||||
filter_boxes.push_back(tl_y);
|
||||
filter_boxes.push_back(br_x);
|
||||
filter_boxes.push_back(br_y);
|
||||
valid_count++;
|
||||
}
|
||||
}
|
||||
return valid_count;
|
||||
}
|
||||
|
||||
std::array<float, 4> RTMDetSepBNHead::RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y,
|
||||
float stride, float offset, int j, int i) const {
|
||||
tl_x = (offset + j) * stride - tl_x;
|
||||
tl_y = (offset + i) * stride - tl_y;
|
||||
br_x = (offset + j) * stride + br_x;
|
||||
br_y = (offset + i) * stride + br_y;
|
||||
return std::array<float, 4>{tl_x, tl_y, br_x, br_y};
|
||||
}
|
||||
|
||||
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, RTMDetSepBNHead);
|
||||
|
||||
} // namespace mmdeploy::mmdet
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
|
||||
#define MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
|
||||
|
||||
#include "mmdeploy/codebase/mmdet/mmdet.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
|
||||
namespace mmdeploy::mmdet {
|
||||
|
||||
class RTMDetSepBNHead : public MMDetection {
|
||||
public:
|
||||
explicit RTMDetSepBNHead(const Value& cfg);
|
||||
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
|
||||
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& bbox_preds,
|
||||
const std::vector<Tensor>& cls_scores) const;
|
||||
int RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score, const float stride,
|
||||
const float offset, std::vector<float>& filter_boxes,
|
||||
std::vector<float>& obj_probs, std::vector<int>& class_ids) const;
|
||||
std::array<float, 4> RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y, float stride,
|
||||
float offset, int j, int i) const;
|
||||
|
||||
private:
|
||||
float score_thr_{0.4f};
|
||||
int nms_pre_{1000};
|
||||
float iou_threshold_{0.45f};
|
||||
int min_bbox_size_{0};
|
||||
int max_per_img_{100};
|
||||
float offset_{0.0f};
|
||||
std::vector<float> strides_;
|
||||
};
|
||||
|
||||
} // namespace mmdeploy::mmdet
|
||||
|
||||
#endif // MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
|
|
@ -99,13 +99,14 @@ class Normalize : public Transform {
|
|||
Tensor dst;
|
||||
if (to_float_) {
|
||||
OUTCOME_TRY(normalize_.Apply(tensor, dst));
|
||||
data[key] = std::move(dst);
|
||||
} else if (to_rgb_) {
|
||||
auto src_mat = to_mat(tensor, PixelFormat::kBGR);
|
||||
Mat dst_mat;
|
||||
OUTCOME_TRY(cvt_color_.Apply(src_mat, dst_mat, PixelFormat::kBGR));
|
||||
dst = to_tensor(src_mat);
|
||||
data[key] = std::move(dst);
|
||||
}
|
||||
data[key] = std::move(dst);
|
||||
|
||||
for (auto& v : mean_) {
|
||||
data["img_norm_cfg"]["mean"].push_back(v);
|
||||
|
|
|
@ -156,6 +156,22 @@ label: 65, score: 0.95
|
|||
])
|
||||
```
|
||||
|
||||
RTMDet: you may paste the following partition configuration into [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py):
|
||||
|
||||
```python
|
||||
# rtmdet for rknn-toolkit and rknn-toolkit2
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True, # should always be set to True
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx', # name to save the partitioned onnx
|
||||
start=['detector_forward:input'], # [mark_name:input, ...]
|
||||
end=['rtmdet_head:output'], # [mark_name:output, ...]
|
||||
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
|
||||
])
|
||||
```
|
||||
|
||||
RetinaNet & SSD & FSAF with rknn-toolkit2, you may paste the following partition configuration into [detection_rknn_static-320x320.py](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py). Users with rknn-toolkit can directly use default config.
|
||||
|
||||
```python
|
||||
|
|
|
@ -11,7 +11,7 @@ Notes:
|
|||
### Prerequisite
|
||||
|
||||
1. Install and build your target backend. You could refer to [ONNXRuntime-install](../05-supported-backends/onnxruntime.md), [TensorRT-install](../05-supported-backends/tensorrt.md), [ncnn-install](../05-supported-backends/ncnn.md), [PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md) for more information.
|
||||
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation).
|
||||
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md).
|
||||
|
||||
### Usage
|
||||
|
||||
|
|
|
@ -177,4 +177,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py
|
|||
|
||||
## 6. How to write model config
|
||||
|
||||
According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md).
|
||||
According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md).
|
||||
|
|
|
@ -134,6 +134,24 @@ python tools/deploy.py \
|
|||
|
||||
```
|
||||
|
||||
- RTMDet
|
||||
|
||||
将下面的模型拆分配置写入到 [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py)
|
||||
|
||||
```python
|
||||
# rtmdet for rknn-toolkit and rknn-toolkit2
|
||||
partition_config = dict(
|
||||
type='rknn', # the partition policy name
|
||||
apply_marks=True, # should always be set to True
|
||||
partition_cfg=[
|
||||
dict(
|
||||
save_file='model.onnx', # name to save the partitioned onnx
|
||||
start=['detector_forward:input'], # [mark_name:input, ...]
|
||||
end=['rtmdet_head:output'], # [mark_name:output, ...]
|
||||
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
|
||||
])
|
||||
```
|
||||
|
||||
- RetinaNet & SSD & FSAF with rknn-toolkit2
|
||||
|
||||
将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py)。使用 rknn-toolkit 的用户则不用。
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
### 准备工作
|
||||
|
||||
1. 安装您的目标后端。 您可以参考 [ONNXRuntime-install](../05-supported-backends/onnxruntime.md) ,[TensorRT-install](../05-supported-backends/tensorrt.md) ,[ncnn-install](../05-supported-backends/ncnn.md) ,[PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md)。
|
||||
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation)。
|
||||
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md)。
|
||||
|
||||
### 使用方法
|
||||
|
||||
|
|
|
@ -187,4 +187,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py
|
|||
|
||||
## 6. 如何编写模型配置文件
|
||||
|
||||
请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md),[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md),[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md)。
|
||||
请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md),[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md),[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md)。
|
||||
|
|
|
@ -313,7 +313,8 @@ class ObjectDetection(BaseTask):
|
|||
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
|
||||
type = 'ResizeInstanceMask' # for instance-seg
|
||||
if get_backend(self.deploy_cfg) == Backend.RKNN:
|
||||
if 'YOLO' in self.model_cfg.model.type:
|
||||
if 'YOLO' in self.model_cfg.model.type or \
|
||||
'RTMDet' in self.model_cfg.model.type:
|
||||
bbox_head = self.model_cfg.model.bbox_head
|
||||
type = bbox_head.type
|
||||
params['anchor_generator'] = bbox_head.get(
|
||||
|
|
|
@ -755,6 +755,14 @@ class RKNNModel(End2EndModel):
|
|||
batch_img_metas=metainfos,
|
||||
cfg=self.model_cfg._cfg_dict.model.test_cfg,
|
||||
rescale=True)
|
||||
elif head_cfg.type == 'RTMDetSepBNHead':
|
||||
divisor = round(len(outputs) / 2)
|
||||
ret = head.predict_by_feat(
|
||||
outputs[:divisor],
|
||||
outputs[divisor:],
|
||||
batch_img_metas=metainfos,
|
||||
cfg=self.model_cfg._cfg_dict.model.test_cfg,
|
||||
rescale=True)
|
||||
elif head_cfg.type in ('RetinaHead', 'SSDHead', 'FSAFHead'):
|
||||
partition_cfgs = get_partition_config(self.deploy_cfg)
|
||||
if partition_cfgs is None: # bbox decoding done in rknn model
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmengine.structures import InstanceData
|
|||
from torch import Tensor
|
||||
|
||||
from mmdeploy.codebase.mmdet import get_post_processing_params
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.mmcv.ops import multiclass_nms
|
||||
|
||||
|
||||
|
@ -51,6 +51,12 @@ def rtmdet_head__predict_by_feat(self,
|
|||
tensor in the tuple is (N, num_box), and each element
|
||||
represents the class label of the corresponding box.
|
||||
"""
|
||||
|
||||
@mark('rtmdet_head', inputs=['cls_scores', 'bbox_preds'])
|
||||
def __mark_pred_maps(cls_scores, bbox_preds):
|
||||
return cls_scores, bbox_preds
|
||||
|
||||
cls_scores, bbox_preds = __mark_pred_maps(cls_scores, bbox_preds)
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
assert len(cls_scores) == len(bbox_preds)
|
||||
device = cls_scores[0].device
|
||||
|
|
Loading…
Reference in New Issue