mirror of https://github.com/open-mmlab/mmcv.git
change the npu code for roi align rotated (#3238)
parent
a4a884dd88
commit
6ba6b7e679
|
@ -11,11 +11,16 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
|
|||
int64_t aligned_height_64 = aligned_height;
|
||||
int64_t aligned_width_64 = aligned_width;
|
||||
int64_t sampling_ratio_64 = sampling_ratio;
|
||||
|
||||
at::Tensor input_trans = input.permute({0, 2, 3, 1}).contiguous();
|
||||
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
|
||||
at::Tensor output_trans = output.permute({0, 2, 3, 1}).contiguous();
|
||||
|
||||
OpCommand cmd;
|
||||
cmd.Name("RoiAlignRotated")
|
||||
.Input(input)
|
||||
.Input(rois)
|
||||
.Output(output)
|
||||
.Input(input_trans)
|
||||
.Input(rois_trans)
|
||||
.Output(output_trans)
|
||||
.Attr("pooled_h", aligned_height_64)
|
||||
.Attr("pooled_w", aligned_width_64)
|
||||
.Attr("spatial_scale", spatial_scale)
|
||||
|
@ -23,6 +28,9 @@ void roi_align_rotated_forward_npu(Tensor input, Tensor rois, Tensor output,
|
|||
.Attr("aligned", aligned)
|
||||
.Attr("clockwise", clockwise)
|
||||
.Run();
|
||||
|
||||
output_trans = output_trans.permute({0, 3, 1, 2}).contiguous();
|
||||
output.copy_(output_trans);
|
||||
}
|
||||
|
||||
void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
|
||||
|
@ -33,16 +41,21 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
|
|||
int64_t aligned_height_64 = aligned_height;
|
||||
int64_t aligned_width_64 = aligned_width;
|
||||
int64_t sampling_ratio_64 = sampling_ratio;
|
||||
|
||||
at::Tensor top_grad_trans = top_grad.permute({0, 2, 3, 1}).contiguous();
|
||||
at::Tensor rois_trans = rois.permute({1, 0}).contiguous();
|
||||
at::Tensor bottom_grad_trans = bottom_grad.permute({0, 2, 3, 1}).contiguous();
|
||||
|
||||
c10::SmallVector<int64_t, 8> y_grad_shape;
|
||||
auto shape = bottom_grad.sizes();
|
||||
auto shape = bottom_grad_trans.sizes();
|
||||
for (uint64_t i = 0; i < shape.size(); i++) {
|
||||
y_grad_shape.emplace_back(shape[i]);
|
||||
}
|
||||
OpCommand cmd;
|
||||
cmd.Name("RoiAlignRotatedGrad")
|
||||
.Input(top_grad)
|
||||
.Input(rois)
|
||||
.Output(bottom_grad)
|
||||
.Input(top_grad_trans)
|
||||
.Input(rois_trans)
|
||||
.Output(bottom_grad_trans)
|
||||
.Attr("y_grad_shape", y_grad_shape)
|
||||
.Attr("pooled_h", aligned_width_64)
|
||||
.Attr("pooled_w", aligned_height_64)
|
||||
|
@ -51,6 +64,9 @@ void roi_align_rotated_backward_npu(Tensor top_grad, Tensor rois,
|
|||
.Attr("aligned", aligned)
|
||||
.Attr("clockwise", clockwise)
|
||||
.Run();
|
||||
|
||||
bottom_grad_trans = bottom_grad_trans.permute({0, 3, 1, 2}).contiguous();
|
||||
bottom_grad.copy_(bottom_grad_trans);
|
||||
}
|
||||
|
||||
void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output,
|
||||
|
|
Loading…
Reference in New Issue