mirror of https://github.com/open-mmlab/mmcv.git
Add multi_scale_deform_attn_grad op adapter for NPU (#3042)
parent cd05d71254
commit 780ffed9f3
@@ -55,7 +55,7 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
   c10::SmallVector<int64_t, 3> output_size = {
       value.size(0), sampling_locations.size(1), value.size(2) * value.size(3)};
-  at::Tensor output = at::empty(output_size, value_fp32.options());
+  at::Tensor output = at::zeros(output_size, value_fp32.options());
 
   OpCommand cmd;
   cmd.Name("MultiScaleDeformableAttnFunction")
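The only functional change to the forward adapter is how the output buffer is allocated: at::empty returns uninitialized memory, while at::zeros guarantees a defined all-zero buffer before the NPU op writes its result. A minimal PyTorch sketch of that distinction (illustration only; the adapter itself uses the equivalent ATen C++ calls):

import torch

# torch.empty allocates without initializing, so its contents are arbitrary
# until a kernel overwrites them; torch.zeros guarantees an all-zero buffer.
uninit = torch.empty(2, 3)   # values are whatever was left in memory
zeroed = torch.zeros(2, 3)   # values are deterministically 0.0
print(zeroed.sum().item())   # always 0.0; uninit.sum() could be anything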
@@ -75,3 +75,60 @@ Tensor ms_deform_attn_forward_npu(const Tensor &value,
 }
 
 REGISTER_NPU_IMPL(ms_deform_attn_impl_forward, ms_deform_attn_forward_npu);
+
+void ms_deform_attn_impl_backward(
+    const Tensor &value, const Tensor &spatial_shapes,
+    const Tensor &level_start_index, const Tensor &sampling_loc,
+    const Tensor &attn_weight, const Tensor &grad_output, Tensor &grad_value,
+    Tensor &grad_sampling_loc, Tensor &grad_attn_weight,
+    const int im2col_step);
+
+void ms_deform_attn_backward_npu(const Tensor &value, const Tensor &spatial_shapes,
+                                 const Tensor &level_start_index,
+                                 const Tensor &sampling_loc,
+                                 const Tensor &attn_weight,
+                                 const Tensor &grad_output, Tensor &grad_value,
+                                 Tensor &grad_sampling_loc,
+                                 Tensor &grad_attn_weight, const int im2col_step) {
+  check_support(value, attn_weight);
+  at::Tensor value_fp32 = value;
+  at::Tensor spatial_shapes_int32 = spatial_shapes;
+  at::Tensor level_start_index_int32 = level_start_index;
+  at::Tensor sampling_loc_fp32 = sampling_loc.transpose(4, 5).contiguous();
+  at::Tensor attn_weight_fp32 = attn_weight;
+  at::Tensor grad_output_fp32 = grad_output;
+  if (value.scalar_type() != at::kFloat) {
+    value_fp32 = value.to(at::kFloat);
+  }
+  if (spatial_shapes.scalar_type() != at::kInt) {
+    spatial_shapes_int32 = spatial_shapes.to(at::kInt);
+  }
+  if (level_start_index.scalar_type() != at::kInt) {
+    level_start_index_int32 = level_start_index.to(at::kInt);
+  }
+  if (sampling_loc.scalar_type() != at::kFloat) {
+    sampling_loc_fp32 = sampling_loc_fp32.to(at::kFloat);
+  }
+  if (attn_weight.scalar_type() != at::kFloat) {
+    attn_weight_fp32 = attn_weight.to(at::kFloat);
+  }
+  if (grad_output.scalar_type() != at::kFloat) {
+    grad_output_fp32 = grad_output.to(at::kFloat);
+  }
+
+  OpCommand cmd;
+  cmd.Name("MultiScaleDeformableAttentionGrad")
+      .Input(value_fp32)
+      .Input(spatial_shapes_int32)
+      .Input(level_start_index_int32)
+      .Input(sampling_loc_fp32)
+      .Input(attn_weight_fp32)
+      .Input(grad_output_fp32)
+      .Output(grad_value)
+      .Output(grad_sampling_loc)
+      .Output(grad_attn_weight)
+      .Run();
+  grad_sampling_loc = grad_sampling_loc.transpose(4, 5).contiguous();
+}
+
+REGISTER_NPU_IMPL(ms_deform_attn_impl_backward, ms_deform_attn_backward_npu);
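One detail of the backward adapter worth noting: sampling_loc is handed to the kernel as transpose(4, 5).contiguous(), and grad_sampling_loc is transposed back the same way after the op runs, so the gradient returned to the caller keeps the original layout. A small shape sketch of that round trip (dimension names follow the usual MMCV convention of (bs, num_queries, num_heads, num_levels, num_points, 2) and are illustrative, not part of the commit):

import torch

# MMCV's sampling_loc layout ends in (num_points, 2); transpose(4, 5) swaps
# the last two axes before the kernel call, and the gradient is transposed
# back so callers see the original layout again.
bs, num_queries, num_heads, num_levels, num_points = 2, 10, 4, 4, 8
sampling_loc = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)

kernel_layout = sampling_loc.transpose(4, 5).contiguous()
print(kernel_layout.shape)   # torch.Size([2, 10, 4, 4, 2, 8])

grad_back = kernel_layout.transpose(4, 5).contiguous()
assert grad_back.shape == sampling_loc.shape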
@@ -337,3 +337,67 @@ def test_gradient_numerical(channels,
                im2col_step),
         eps=eps,
         atol=1e-2)
+
+
+@pytest.mark.skipif(not IS_NPU_AVAILABLE, reason='requires NPU support')
+def test_backward_equal_with_pytorch_npu():
+    N, M, D = 6, 4, 8
+    Lq, L, P = 10000, 4, 8
+    shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
+                             dtype=torch.int32)
+    level_start_index = torch.cat((shapes.new_zeros(
+        (1, )), shapes.prod(1).cumsum(0)[:-1]))
+    S = sum((H * W).item() for H, W in shapes)
+
+    torch.manual_seed(3)
+    value = torch.rand(N, S, M, D) * 0.01
+    sampling_locations = torch.rand(N, Lq, M, L, P, 2)
+    attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
+    attention_weights /= attention_weights.sum(
+        -1, keepdim=True).sum(
+            -2, keepdim=True)
+    im2col_step = 2
+    value.requires_grad = True
+    sampling_locations.requires_grad = True
+    attention_weights.requires_grad = True
+    output_pytorch = multi_scale_deformable_attn_pytorch(
+        value.float(), shapes, sampling_locations.float(),
+        attention_weights.float())
+    grad_output_pytorch = torch.ones_like(output_pytorch)
+    output_pytorch.backward(grad_output_pytorch)
+    grad_value = value.grad.detach().cpu()
+    grad_location = sampling_locations.grad.detach().cpu()
+    grad_attn_weight = attention_weights.grad.detach().cpu()
+
+    value_npu = value.npu()
+    shapes_npu = shapes.npu()
+    level_start_index_npu = level_start_index.npu()
+    sampling_locations_npu = sampling_locations.npu()
+    attention_weights_npu = attention_weights.npu()
+    output_npu = MultiScaleDeformableAttnFunction.apply(
+        value_npu.float(), shapes_npu, level_start_index_npu,
+        sampling_locations_npu.float(), attention_weights_npu.float(),
+        im2col_step)
+    grad_output_npu = torch.ones_like(output_npu)
+    output_npu.backward(grad_output_npu)
+    grad_value_npu = value_npu.grad.detach().cpu()
+    grad_location_npu = sampling_locations_npu.grad.detach().cpu()
+    grad_attn_weight_npu = attention_weights_npu.grad.detach().cpu()
+    assert torch.allclose(grad_value_npu, grad_value)
+    max_abs_err_1 = (grad_value_npu - grad_value).abs().max()
+    max_rel_err_1 = ((grad_value_npu - grad_value).abs() /
+                     grad_value.abs()).max()
+    assert max_abs_err_1 < 1e-5
+    assert max_rel_err_1 < 1e-4
+    assert torch.allclose(grad_location_npu, grad_location)
+    max_abs_err_2 = (grad_location_npu - grad_location).abs().max()
+    max_rel_err_2 = ((grad_location_npu - grad_location).abs() /
+                     grad_location.abs()).max()
+    assert max_abs_err_2 < 1e-5
+    assert max_rel_err_2 < 1e-4
+    assert torch.allclose(grad_attn_weight_npu, grad_attn_weight)
+    max_abs_err_3 = (grad_attn_weight_npu - grad_attn_weight).abs().max()
+    max_rel_err_3 = ((grad_attn_weight_npu - grad_attn_weight).abs() /
+                     grad_attn_weight.abs()).max()
+    assert max_abs_err_3 < 1e-5
+    assert max_rel_err_3 < 1e-4
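For reference, the level_start_index used by the new test is derived from the per-level spatial shapes: the flattened value tensor concatenates all levels, so each level's slice starts at the cumulative sum of the preceding levels' H * W. A standalone check of that arithmetic, using the same shapes as the test (runs on CPU, no NPU needed):

import torch

# Each level's slice of the flattened value tensor starts where the previous
# levels end, i.e. at the cumulative sum of their H * W products.
shapes = torch.as_tensor([(60, 40), (30, 20), (16, 24), (53, 32)],
                         dtype=torch.int32)
level_start_index = torch.cat((shapes.new_zeros(
    (1, )), shapes.prod(1).cumsum(0)[:-1]))
print(level_start_index.tolist())  # [0, 2400, 3000, 3384]
S = sum((H * W).item() for H, W in shapes)
print(S)                           # 5080 rows in the flattened value tensor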