[Enhancement] Support RTMDet-Ins (#1867)
* support RTMDet-Ins * optimization * avoid out of boundary --------- Co-authored-by: lvhan028 <lvhan_028@163.com>pull/1922/head
parent
43383e83ff
commit
9f9b3a81b2
|
@ -3,6 +3,8 @@
|
|||
#include "mmdeploy/core/registry.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/experimental/module_adapter.h"
|
||||
#include "mmdeploy/operation/managed.h"
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
#include "object_detection.h"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include "opencv_utils.h"
|
||||
|
@ -14,7 +16,10 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
explicit ResizeInstanceMask(const Value& cfg) : ResizeBBox(cfg) {
|
||||
if (cfg.contains("params")) {
|
||||
mask_thr_binary_ = cfg["params"].value("mask_thr_binary", mask_thr_binary_);
|
||||
is_rcnn_ = cfg["params"].contains("rcnn");
|
||||
}
|
||||
operation::Context ctx(device_, stream_);
|
||||
warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
|
||||
}
|
||||
|
||||
// TODO: remove duplication
|
||||
|
@ -53,15 +58,17 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
|
||||
OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
|
||||
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
|
||||
OUTCOME_TRY(auto _masks, MakeAvailableOnDevice(masks, kHost, stream()));
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
// Note: `masks` are kept on device to avoid data copy overhead from device to host.
|
||||
// refer to https://github.com/open-mmlab/mmdeploy/issues/1849
|
||||
// OUTCOME_TRY(auto _masks, MakeAvailableOnDevice(masks, kHost, stream()));
|
||||
// OUTCOME_TRY(stream().Wait());
|
||||
|
||||
OUTCOME_TRY(auto result, DispatchGetBBoxes(prep_res["img_metas"], _dets, _labels));
|
||||
|
||||
auto ori_w = prep_res["img_metas"]["ori_shape"][2].get<int>();
|
||||
auto ori_h = prep_res["img_metas"]["ori_shape"][1].get<int>();
|
||||
|
||||
ProcessMasks(result, _masks, ori_w, ori_h);
|
||||
ProcessMasks(result, masks, _dets, ori_w, ori_h);
|
||||
|
||||
return to_value(result);
|
||||
} catch (const std::exception& e) {
|
||||
|
@ -71,14 +78,23 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
}
|
||||
|
||||
protected:
|
||||
void ProcessMasks(Detections& result, Tensor cpu_masks, int img_w, int img_h) const {
|
||||
auto shape = TensorShape{cpu_masks.shape(1), cpu_masks.shape(2), cpu_masks.shape(3)};
|
||||
cpu_masks.Reshape(shape);
|
||||
MMDEPLOY_DEBUG("{}, {}", cpu_masks.shape(), cpu_masks.data_type());
|
||||
Result<void> ProcessMasks(Detections& result, Tensor d_mask, Tensor cpu_dets, int img_w,
|
||||
int img_h) {
|
||||
d_mask.Squeeze(0);
|
||||
cpu_dets.Squeeze(0);
|
||||
|
||||
::mmdeploy::operation::Context ctx(device_, stream_);
|
||||
|
||||
std::vector<Tensor> warped_masks;
|
||||
warped_masks.reserve(result.size());
|
||||
|
||||
std::vector<Tensor> h_warped_masks;
|
||||
h_warped_masks.reserve(result.size());
|
||||
|
||||
for (auto& det : result) {
|
||||
auto mask = cpu_masks.Slice(det.index);
|
||||
cv::Mat mask_mat((int)mask.shape(1), (int)mask.shape(2), CV_32F, mask.data<float>());
|
||||
cv::Mat warped_mask;
|
||||
auto mask = d_mask.Slice(det.index);
|
||||
auto mask_height = (int)mask.shape(1);
|
||||
auto mask_width = (int)mask.shape(2);
|
||||
auto& bbox = det.bbox;
|
||||
// same as mmdet with skip_empty = True
|
||||
auto x0 = std::max(std::floor(bbox[0]) - 1, 0.f);
|
||||
|
@ -88,22 +104,67 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
auto width = static_cast<int>(x1 - x0);
|
||||
auto height = static_cast<int>(y1 - y0);
|
||||
// params align_corners = False
|
||||
auto fx = (float)mask_mat.cols / (bbox[2] - bbox[0]);
|
||||
auto fy = (float)mask_mat.rows / (bbox[3] - bbox[1]);
|
||||
auto tx = (x0 + .5f - bbox[0]) * fx - .5f;
|
||||
auto ty = (y0 + .5f - bbox[1]) * fy - .5f;
|
||||
|
||||
cv::Mat m = (cv::Mat_<float>(2, 3) << fx, 0, tx, 0, fy, ty);
|
||||
cv::warpAffine(mask_mat, warped_mask, m, cv::Size{width, height},
|
||||
cv::INTER_LINEAR | cv::WARP_INVERSE_MAP);
|
||||
warped_mask = warped_mask > mask_thr_binary_;
|
||||
|
||||
det.mask = Mat(height, width, PixelFormat::kGRAYSCALE, DataType::kINT8,
|
||||
std::shared_ptr<void>(warped_mask.data, [mat = warped_mask](void*) {}));
|
||||
}
|
||||
float fx;
|
||||
float fy;
|
||||
float tx;
|
||||
float ty;
|
||||
if (is_rcnn_) { // mask r-cnn
|
||||
fx = (float)mask_width / (bbox[2] - bbox[0]);
|
||||
fy = (float)mask_height / (bbox[3] - bbox[1]);
|
||||
tx = (x0 + .5f - bbox[0]) * fx - .5f;
|
||||
ty = (y0 + .5f - bbox[1]) * fy - .5f;
|
||||
} else { // rtmdet-ins
|
||||
auto raw_bbox = cpu_dets.Slice(det.index);
|
||||
auto raw_bbox_data = raw_bbox.data<float>();
|
||||
fx = (raw_bbox_data[2] - raw_bbox_data[0]) / (bbox[2] - bbox[0]);
|
||||
fy = (raw_bbox_data[3] - raw_bbox_data[1]) / (bbox[3] - bbox[1]);
|
||||
tx = (x0 + .5f - bbox[0]) * fx - .5f + raw_bbox_data[0];
|
||||
ty = (y0 + .5f - bbox[1]) * fy - .5f + raw_bbox_data[1];
|
||||
}
|
||||
|
||||
float affine_matrix[] = {fx, 0, tx, 0, fy, ty};
|
||||
|
||||
cv::Mat_<float> m(2, 3, affine_matrix);
|
||||
cv::invertAffineTransform(m, m);
|
||||
|
||||
mask.Reshape({1, mask_height, mask_width, 1});
|
||||
|
||||
Tensor& warped_mask = warped_masks.emplace_back();
|
||||
OUTCOME_TRY(warp_affine_.Apply(mask, warped_mask, affine_matrix, height, width));
|
||||
|
||||
OUTCOME_TRY(CopyToHost(warped_mask, h_warped_masks.emplace_back()));
|
||||
}
|
||||
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
|
||||
for (size_t i = 0; i < h_warped_masks.size(); ++i) {
|
||||
result[i].mask = ThresholdMask(h_warped_masks[i]);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<void> CopyToHost(const Tensor& src, Tensor& dst) {
|
||||
if (src.device() == kHost) {
|
||||
dst = src;
|
||||
return success();
|
||||
}
|
||||
dst = TensorDesc{kHost, src.data_type(), src.shape()};
|
||||
OUTCOME_TRY(stream_.Copy(src.buffer(), dst.buffer(), dst.byte_size()));
|
||||
return success();
|
||||
}
|
||||
|
||||
Mat ThresholdMask(const Tensor& h_mask) const {
|
||||
cv::Mat warped_mat = cpu::Tensor2CVMat(h_mask);
|
||||
warped_mat = warped_mat > mask_thr_binary_;
|
||||
return {warped_mat.rows, warped_mat.cols, PixelFormat::kGRAYSCALE, DataType::kINT8,
|
||||
std::shared_ptr<void>(warped_mat.data, [mat = warped_mat](void*) {})};
|
||||
}
|
||||
|
||||
private:
|
||||
operation::Managed<operation::WarpAffine> warp_affine_;
|
||||
float mask_thr_binary_{.5f};
|
||||
bool is_rcnn_{true};
|
||||
};
|
||||
|
||||
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, ResizeInstanceMask);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import cv2
|
||||
from mmdeploy_python import Detector
|
||||
|
@ -36,7 +37,10 @@ def main():
|
|||
if masks[index].size:
|
||||
mask = masks[index]
|
||||
blue, green, red = cv2.split(img)
|
||||
mask_img = blue[top:top + mask.shape[0], left:left + mask.shape[1]]
|
||||
|
||||
x0 = int(max(math.floor(bbox[0]) - 1, 0))
|
||||
y0 = int(max(math.floor(bbox[1]) - 1, 0))
|
||||
mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
|
||||
cv2.bitwise_or(mask, mask_img, mask_img)
|
||||
img = cv2.merge([blue, green, red])
|
||||
|
||||
|
|
|
@ -311,6 +311,7 @@ class ObjectDetection(BaseTask):
|
|||
params['score_thr'] = params['rcnn']['score_thr']
|
||||
if 'mask_thr_binary' in params['rcnn']:
|
||||
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
|
||||
if 'mask_thr_binary' in params:
|
||||
type = 'ResizeInstanceMask' # for instance-seg
|
||||
if get_backend(self.deploy_cfg) == Backend.RKNN:
|
||||
if 'YOLO' in self.model_cfg.model.type or \
|
||||
|
|
Loading…
Reference in New Issue