diff --git a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp index 36bd9c7a8..f428311fe 100644 --- a/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/roi_pool_npu.cpp @@ -11,7 +11,6 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, int64_t pooled_channel = 1; at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor( {}, rois.options().dtype(at::kInt), rois); - OpCommand cmd; cmd.Name("RoiPoolingWithArgMax") .Input(input) @@ -27,8 +26,38 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output, .Run(); } +void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + int64_t pooled_height_64 = pooled_height; + int64_t pooled_width_64 = pooled_width; + int64_t pooled_channel = 1; + at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor( + {}, rois.options().dtype(at::kInt), rois); + at::Tensor x = at::ones_like(grad_input); + OpCommand cmd; + cmd.Name("RoiPoolingGradWithArgMax") + .Input(grad_output) + .Input(x) + .Input(rois) + .Input(roi_actual_num) + .Input(argmax) + .Output(grad_input) + .Attr("pooled_h", pooled_height_64) + .Attr("pooled_w", pooled_width_64) + .Attr("spatial_scale_h", spatial_scale) + .Attr("spatial_scale_w", spatial_scale) + .Attr("pool_channel", pooled_channel) + .Run(); +} + void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, Tensor argmax, int pooled_height, int pooled_width, float spatial_scale); +void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); + REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu); +REGISTER_NPU_IMPL(roi_pool_backward_impl, roi_pool_backward_npu); diff --git a/tests/test_ops/test_roi_pool.py b/tests/test_ops/test_roi_pool.py index be5ab9296..5ab04bce2 100644 --- a/tests/test_ops/test_roi_pool.py +++ b/tests/test_ops/test_roi_pool.py @@ -69,20 +69,13 @@ class TestRoiPool: np_output = np.array(output[0]) np_grad = np.array(output[1]) - if device == 'npu': - import torch_npu # noqa: F401 - x = torch.tensor(np_input, dtype=dtype).npu() - rois = torch.tensor(np_rois, dtype=dtype).npu() - output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) - assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) - else: - x = torch.tensor( - np_input, dtype=dtype, device=device, requires_grad=True) - rois = torch.tensor(np_rois, dtype=dtype, device=device) - output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) - output.backward(torch.ones_like(output)) - assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) - assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3) + x = torch.tensor( + np_input, dtype=dtype, device=device, requires_grad=True) + rois = torch.tensor(np_rois, dtype=dtype, device=device) + output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) + output.backward(torch.ones_like(output)) + assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) + assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3) @pytest.mark.parametrize('device', [ pytest.param( @@ -103,8 +96,8 @@ class TestRoiPool: pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, - reason='MLU does not support for 64-bit floating point')), + IS_MLU_AVAILABLE or IS_NPU_AVAILABLE, + reason='MLU, NPU does not support for 64-bit floating point')), torch.half ]) def test_roipool_allclose(self, device, dtype):