mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix roi_align npu bug (#2862)
parent
43c5c76f9b
commit
0a2f60ba01
|
@ -7,13 +7,14 @@ void roi_align_forward_npu(Tensor input, Tensor rois, Tensor output,
|
||||||
Tensor argmax_y, Tensor argmax_x, int aligned_height,
|
Tensor argmax_y, Tensor argmax_x, int aligned_height,
|
||||||
int aligned_width, float spatial_scale,
|
int aligned_width, float spatial_scale,
|
||||||
int sampling_ratio, int pool_mode, bool aligned) {
|
int sampling_ratio, int pool_mode, bool aligned) {
|
||||||
|
int64_t roi_end_mode = 2;
|
||||||
if (!aligned) {
|
if (!aligned) {
|
||||||
LOG(WARNING) << "The [aligned] attr in roi_align op is false";
|
LOG(WARNING) << "The [aligned] attr in roi_align op is false";
|
||||||
|
roi_end_mode = 0;
|
||||||
}
|
}
|
||||||
int64_t aligned_height_64 = aligned_height;
|
int64_t aligned_height_64 = aligned_height;
|
||||||
int64_t aligned_width_64 = aligned_width;
|
int64_t aligned_width_64 = aligned_width;
|
||||||
int64_t sampling_ratio_64 = sampling_ratio;
|
int64_t sampling_ratio_64 = sampling_ratio;
|
||||||
int64_t roi_end_mode = 0;
|
|
||||||
OpCommand cmd;
|
OpCommand cmd;
|
||||||
cmd.Name("ROIAlign")
|
cmd.Name("ROIAlign")
|
||||||
.Input(input)
|
.Input(input)
|
||||||
|
@ -35,7 +36,11 @@ void roi_align_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax_y,
|
||||||
int64_t aligned_height_64 = aligned_height;
|
int64_t aligned_height_64 = aligned_height;
|
||||||
int64_t aligned_width_64 = aligned_width;
|
int64_t aligned_width_64 = aligned_width;
|
||||||
int64_t sampling_ratio_64 = sampling_ratio;
|
int64_t sampling_ratio_64 = sampling_ratio;
|
||||||
int64_t roi_end_mode = 0;
|
int64_t roi_end_mode = 2;
|
||||||
|
if (!aligned) {
|
||||||
|
LOG(WARNING) << "The [aligned] attr in roi_align_grad op is false";
|
||||||
|
roi_end_mode = 0;
|
||||||
|
}
|
||||||
c10::SmallVector<int64_t, SIZE> xdiff_shape =
|
c10::SmallVector<int64_t, SIZE> xdiff_shape =
|
||||||
at_npu::native::array_to_small_vector(grad_input.sizes());
|
at_npu::native::array_to_small_vector(grad_input.sizes());
|
||||||
OpCommand cmd;
|
OpCommand cmd;
|
||||||
|
|
Loading…
Reference in New Issue