Change the NPU code for RoI Align Rotated (#3238)

pull/2993/merge
huangyuan64 2025-02-18 11:48:26 +08:00 committed by GitHub
parent a4a884dd88
commit 6ba6b7e679
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 23 additions and 7 deletions

View File

@ -11,11 +11,16 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio;
at::Tensor input_trans = input.permute({0, 2, 3, 1}).contiguous();
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
at::Tensor output_trans = output.permute({0, 2, 3, 1}).contiguous();
OpCommand cmd;
cmd.Name("RoiAlignRotated")
.Input(input)
.Input(rois)
.Output(output)
.Input(input_trans)
.Input(rois_trans)
.Output(output_trans)
.Attr("pooled_h", aligned_height_64)
.Attr("pooled_w", aligned_width_64)
.Attr("spatial_scale", spatial_scale)
@ -23,6 +28,9 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
.Attr("aligned", aligned)
.Attr("clockwise", clockwise)
.Run();
output_trans = output_trans.permute({0, 3, 1, 2}).contiguous();
output.copy_(output_trans);
}
void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
@ -33,16 +41,21 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
int64_t aligned_height_64 = aligned_height;
int64_t aligned_width_64 = aligned_width;
int64_t sampling_ratio_64 = sampling_ratio;
at::Tensor top_grad_trans = top_grad.permute({0, 2, 3, 1}).contiguous();
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
at::Tensor bottom_grad_trans = bottom_grad.permute({0, 2, 3, 1}).contiguous();
c10::SmallVector<int64_t, 8> y_grad_shape;
auto shape = bottom_grad.sizes();
auto shape = bottom_grad_trans.sizes();
for (uint64_t i = 0; i < shape.size(); i++) {
y_grad_shape.emplace_back(shape[i]);
}
OpCommand cmd;
cmd.Name("RoiAlignRotatedGrad")
.Input(top_grad)
.Input(rois)
.Output(bottom_grad)
.Input(top_grad_trans)
.Input(rois_trans)
.Output(bottom_grad_trans)
.Attr("y_grad_shape", y_grad_shape)
.Attr("pooled_h", aligned_width_64)
.Attr("pooled_w", aligned_height_64)
@ -51,6 +64,9 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
.Attr("aligned", aligned)
.Attr("clockwise", clockwise)
.Run();
bottom_grad_trans = bottom_grad_trans.permute({0, 3, 1, 2}).contiguous();
bottom_grad.copy_(bottom_grad_trans);
}
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,