mirror of https://github.com/open-mmlab/mmcv.git
[NPU] 更新3个 box_iou 相关 NPU 算子适配层 (#3276)
* Update boxes_overlap_bev_npu.cpp 更换 boxes_overlap_bev_npu 的底层 aclnn 调用算子 * Update box_iou_rotated.py 更新 box_iou_rotated python 的 npu 相关适配层 * Update box_iou_rotated_npu.cpp Update box_iou_rotated_npu.cpp * Create diff_iou_rotated_npu.cpp diff_iou_rotated 算子添加 npu 适配 * Update box_iou_rotated_npu.cpp * Update diff_iou_rotated_npu.cpp * Update boxes_overlap_bev_npu.cpp * Update boxes_overlap_bev_npu.cpp * Update diff_iou_rotated_npu.cpp * Update boxes_overlap_bev_npu.cppmain
parent
f19d3e771c
commit
90d83c94cf
|
@ -142,11 +142,6 @@ def box_iou_rotated(bboxes1: torch.Tensor,
|
|||
flip_mat[-1] = -1
|
||||
bboxes1 = bboxes1 * flip_mat
|
||||
bboxes2 = bboxes2 * flip_mat
|
||||
if bboxes1.device.type == 'npu':
|
||||
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
|
||||
scale_mat[-1] = 1.0 / 0.01745329252
|
||||
bboxes1 = bboxes1 * scale_mat
|
||||
bboxes2 = bboxes2 * scale_mat
|
||||
bboxes1 = bboxes1.contiguous()
|
||||
bboxes2 = bboxes2.contiguous()
|
||||
ext_module.box_iou_rotated(
|
||||
|
|
|
@ -8,40 +8,10 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
|||
|
||||
void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
|
||||
const int mode_flag, const bool aligned) {
|
||||
at::Tensor boxes = at::ones_like(boxes1);
|
||||
at::Tensor query_boxes = at::ones_like(boxes2);
|
||||
boxes = boxes1.transpose(0, 1).unsqueeze(0);
|
||||
query_boxes = boxes2.transpose(0, 1).unsqueeze(0);
|
||||
|
||||
bool is_trans = false;
|
||||
string modeStr = "iou";
|
||||
if (mode_flag == 1) {
|
||||
modeStr = "iof";
|
||||
}
|
||||
bool is_cross = true;
|
||||
if (aligned) {
|
||||
is_cross = false;
|
||||
}
|
||||
float v_threshold = 0;
|
||||
float e_threshold = 0;
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("RotatedIou")
|
||||
.Input(boxes)
|
||||
.Input(query_boxes)
|
||||
.Output(ious)
|
||||
.Attr("trans", is_trans)
|
||||
.Attr("mode", modeStr)
|
||||
.Attr("is_cross", is_cross)
|
||||
.Attr("v_threshold", v_threshold)
|
||||
.Attr("e_threshold", e_threshold)
|
||||
.Run();
|
||||
|
||||
if (is_cross) {
|
||||
ious = ious.view({boxes1.size(0), boxes2.size(0)});
|
||||
} else {
|
||||
ious = ious.view({boxes1.size(0), 1});
|
||||
}
|
||||
TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)");
|
||||
TORCH_CHECK(boxes1.size(1) == 5, "boxes1 must be 2D tensor (N, 5)");
|
||||
EXEC_NPU_CMD(aclnnBoxIou, boxes1, boxes2, mode_flag, aligned, ious);
|
||||
return;
|
||||
}
|
||||
|
||||
REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
|
||||
|
|
|
@ -3,6 +3,11 @@
|
|||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
constexpr int32_t MODE_FLAG_OVERLAP = 0;
|
||||
constexpr int32_t FORMAT_FLAG_XYZWHDR = 3;
|
||||
}; // namespace
|
||||
|
||||
void iou3d_boxes_overlap_bev_forward_impl(const int num_a, const Tensor boxes_a,
|
||||
const int num_b, const Tensor boxes_b,
|
||||
Tensor ans_overlap);
|
||||
|
@ -13,13 +18,13 @@ void iou3d_boxes_overlap_bev_forward_npu(const int num_a, const Tensor boxes_a,
|
|||
TORCH_CHECK(boxes_a.size(1) == 7, "boxes_a must be 2D tensor (N, 7)");
|
||||
TORCH_CHECK(boxes_b.size(1) == 7, "boxes_b must be 2D tensor (N, 7)");
|
||||
|
||||
auto format_flag = 3;
|
||||
auto clockwise = true;
|
||||
auto mode_flag = 0;
|
||||
auto aligned = false;
|
||||
auto margin = 1e-2;
|
||||
bool aligned = false;
|
||||
double margin = 1e-5;
|
||||
int32_t mode_flag = MODE_FLAG_OVERLAP;
|
||||
int32_t format_flag = FORMAT_FLAG_XYZWHDR;
|
||||
|
||||
EXEC_NPU_CMD(aclnnBoxesOverlapBev, boxes_a, boxes_b, format_flag, clockwise,
|
||||
EXEC_NPU_CMD(aclnnBoxesOverlapBevV1, boxes_a, boxes_b, format_flag, clockwise,
|
||||
mode_flag, aligned, margin, ans_overlap);
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
#include "pytorch_npu_helper.hpp"
|
||||
using namespace NPU_NAME_SPACE;
|
||||
using namespace std;
|
||||
|
||||
Tensor diff_iou_rotated_sort_vertices_npu(Tensor vertices, Tensor mask,
|
||||
Tensor num_valid) {
|
||||
TORCH_CHECK(vertices.dim() == 4,
|
||||
"vertices must be a 4D Tensor, but got: ", vertices.dim());
|
||||
TORCH_CHECK(mask.dim() == 3,
|
||||
"mask must be a 3D Tensor, but got: ", mask.dim());
|
||||
TORCH_CHECK(num_valid.dim() == 2,
|
||||
"num_valid must be a 2D Tensor, but got: ", num_valid.dim());
|
||||
|
||||
uint32_t B = vertices.size(0);
|
||||
uint32_t N = vertices.size(1);
|
||||
|
||||
at::Tensor sortedIdx = at::empty({B, N, 9}, num_valid.options());
|
||||
at::Tensor mask_fp = mask.to(at::kFloat);
|
||||
|
||||
EXEC_NPU_CMD(aclnnDiffIouRotatedSortVertices, vertices, mask_fp, num_valid,
|
||||
sortedIdx);
|
||||
|
||||
return sortedIdx;
|
||||
}
|
||||
|
||||
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
|
||||
Tensor num_valid);
|
||||
|
||||
REGISTER_NPU_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
|
||||
diff_iou_rotated_sort_vertices_npu);
|
Loading…
Reference in New Issue