[Feature] Add Ascend support for RoIPoolGrad op (#2569)

* add roipoolgrad adapter

* amend
xinliang123 2023-02-06 18:29:59 +08:00 committed by GitHub
parent 04e8727adf
commit 615d28176a
2 changed files with 39 additions and 17 deletions


@@ -11,7 +11,6 @@ void roi_pool_forward_npu(Tensor input, Tensor rois, Tensor output,
   int64_t pooled_channel = 1;
   at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
       {}, rois.options().dtype(at::kInt), rois);
-
   OpCommand cmd;
   cmd.Name("RoiPoolingWithArgMax")
       .Input(input)
@@ -27,8 +26,38 @@
       .Run();
 }
 
+void roi_pool_backward_npu(Tensor grad_output, Tensor rois, Tensor argmax,
+                           Tensor grad_input, int pooled_height,
+                           int pooled_width, float spatial_scale) {
+  int64_t pooled_height_64 = pooled_height;
+  int64_t pooled_width_64 = pooled_width;
+  int64_t pooled_channel = 1;
+  at::Tensor roi_actual_num = at_npu::native::OpPreparation::ApplyTensor(
+      {}, rois.options().dtype(at::kInt), rois);
+  at::Tensor x = at::ones_like(grad_input);
+  OpCommand cmd;
+  cmd.Name("RoiPoolingGradWithArgMax")
+      .Input(grad_output)
+      .Input(x)
+      .Input(rois)
+      .Input(roi_actual_num)
+      .Input(argmax)
+      .Output(grad_input)
+      .Attr("pooled_h", pooled_height_64)
+      .Attr("pooled_w", pooled_width_64)
+      .Attr("spatial_scale_h", spatial_scale)
+      .Attr("spatial_scale_w", spatial_scale)
+      .Attr("pool_channel", pooled_channel)
+      .Run();
+}
+
 void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax, int pooled_height, int pooled_width,
                            float spatial_scale);
 
+void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax,
+                            Tensor grad_input, int pooled_height,
+                            int pooled_width, float spatial_scale);
+
 REGISTER_NPU_IMPL(roi_pool_forward_impl, roi_pool_forward_npu);
+REGISTER_NPU_IMPL(roi_pool_backward_impl, roi_pool_backward_npu);
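
For context, a minimal sketch of how the new backward kernel is reached from Python, assuming an Ascend environment with torch_npu installed (shapes and values are illustrative, not from this PR). Calling backward() on the output of mmcv.ops.roi_pool dispatches through roi_pool_backward_impl to roi_pool_backward_npu above; the ones-filled tensor x passed as the operator's feature-map input appears to act only as a shape/dtype stand-in for the original input.

import torch
import torch_npu  # noqa: F401
from mmcv.ops import roi_pool

# One 1x1x4x4 feature map and a single RoI covering it
# (RoI format: [batch_idx, x1, y1, x2, y2]).
x = torch.randn(1, 1, 4, 4, device='npu', requires_grad=True)
rois = torch.tensor([[0., 0., 0., 3., 3.]], device='npu')
output = roi_pool(x, rois, (2, 2), spatial_scale=1.0)
# backward() reaches roi_pool_backward_npu via the NPU registry above.
output.backward(torch.ones_like(output))
assert x.grad.shape == x.shape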


@@ -69,20 +69,13 @@ class TestRoiPool:
         np_output = np.array(output[0])
         np_grad = np.array(output[1])
-        if device == 'npu':
-            import torch_npu  # noqa: F401
-            x = torch.tensor(np_input, dtype=dtype).npu()
-            rois = torch.tensor(np_rois, dtype=dtype).npu()
-            output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
-            assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
-        else:
-            x = torch.tensor(
-                np_input, dtype=dtype, device=device, requires_grad=True)
-            rois = torch.tensor(np_rois, dtype=dtype, device=device)
-            output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
-            output.backward(torch.ones_like(output))
-            assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
-            assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
+        x = torch.tensor(
+            np_input, dtype=dtype, device=device, requires_grad=True)
+        rois = torch.tensor(np_rois, dtype=dtype, device=device)
+        output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale)
+        output.backward(torch.ones_like(output))
+        assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3)
+        assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
 
     @pytest.mark.parametrize('device', [
         pytest.param(
@@ -103,8 +96,8 @@
         pytest.param(
             torch.double,
             marks=pytest.mark.skipif(
-                IS_MLU_AVAILABLE,
-                reason='MLU does not support for 64-bit floating point')),
+                IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
+                reason='MLU and NPU do not support 64-bit floating point')),
         torch.half
     ])
     def test_roipool_allclose(self, device, dtype):
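
A side note on the dtype skip above: float64 is also what torch.autograd.gradcheck requires to verify a backward pass numerically, so on hardware that does support it the gradient path can be sanity-checked along these lines (an illustrative sketch for a CUDA device, not part of this PR; shapes and tolerances are made up).

import torch
from mmcv.ops import roi_pool

torch.manual_seed(0)
x = torch.randn(
    1, 1, 4, 4, dtype=torch.double, device='cuda', requires_grad=True)
rois = torch.tensor([[0., 0., 0., 3., 3.]],
                    dtype=torch.double, device='cuda')
# gradcheck compares analytic gradients against finite differences and
# needs float64 precision -- hence torch.double is skipped on MLU/NPU.
assert torch.autograd.gradcheck(
    lambda inp: roi_pool(inp, rois, (2, 2), 1.0), (x,),
    eps=1e-6, atol=1e-4)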