diff --git a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp index d28d50fd4..aed43c2af 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_align_rotated_npu.cpp @@ -11,11 +11,16 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output, int64_t aligned_height_64 = aligned_height; int64_t aligned_width_64 = aligned_width; int64_t sampling_ratio_64 = sampling_ratio; + + at::Tensor input_trans = input.permute({0, 2, 3, 1}).contiguous(); + at::Tensor rois_trans = rois.permute({1, 0}).contiguous(); + at::Tensor output_trans = output.permute({0, 2, 3, 1}).contiguous(); + OpCommand cmd; cmd.Name("RoiAlignRotated") - .Input(input) - .Input(rois) - .Output(output) + .Input(input_trans) + .Input(rois_trans) + .Output(output_trans) .Attr("pooled_h", aligned_height_64) .Attr("pooled_w", aligned_width_64) .Attr("spatial_scale", spatial_scale) @@ -23,6 +28,9 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output, .Attr("aligned", aligned) .Attr("clockwise", clockwise) .Run(); + + output_trans = output_trans.permute({0, 3, 1, 2}).contiguous(); + output.copy_(output_trans); } void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois, @@ -33,16 +41,21 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois, int64_t aligned_height_64 = aligned_height; int64_t aligned_width_64 = aligned_width; int64_t sampling_ratio_64 = sampling_ratio; + + at::Tensor top_grad_trans = top_grad.permute({0, 2, 3, 1}).contiguous(); + at::Tensor rois_trans = rois.permute({1, 0}).contiguous(); + at::Tensor bottom_grad_trans = bottom_grad.permute({0, 2, 3, 1}).contiguous(); + c10::SmallVector y_grad_shape; - auto shape = bottom_grad.sizes(); + auto shape = bottom_grad_trans.sizes(); for (uint64_t i = 0; i < shape.size(); i++) { y_grad_shape.emplace_back(shape[i]); } OpCommand cmd; cmd.Name("RoiAlignRotatedGrad") - .Input(top_grad) - .Input(rois) - .Output(bottom_grad) + .Input(top_grad_trans) + .Input(rois_trans) + .Output(bottom_grad_trans) .Attr("y_grad_shape", y_grad_shape) .Attr("pooled_h", aligned_width_64) .Attr("pooled_w", aligned_height_64) @@ -51,6 +64,9 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois, .Attr("aligned", aligned) .Attr("clockwise", clockwise) .Run(); + + bottom_grad_trans = bottom_grad_trans.permute({0, 3, 1, 2}).contiguous(); + bottom_grad.copy_(bottom_grad_trans); } void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,