[Feature] Support DCNv1 on Ascend device (#2480)

* Update the latest NPU modifications for DCNv1

* check code

* Add ops to EN/ZH documents
Ryan Wang 2023-01-06 14:17:29 +08:00 committed by GitHub
parent a9535377bf
commit f76de9077b
4 changed files with 31 additions and 3 deletions


@@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | |
+| Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | | | |


@@ -18,7 +18,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
 | ConvexIoU | | √ | | | |
 | CornerPool | | √ | | | |
 | Correlation | | √ | | | |
-| Deformable Convolution v1/v2 | √ | √ | | | |
+| Deformable Convolution v1/v2 | √ | √ | | | √ |
 | Deformable RoIPool | | √ | √ | | √ |
 | DiffIoURotated | | √ | | | |
 | DynamicScatter | | √ | | | |


@@ -12,6 +12,7 @@ from torch.nn.modules.utils import _pair, _single
 from mmcv.utils import deprecated_api_warning
 from ..cnn import CONV_LAYERS
 from ..utils import ext_loader, print_log
+from .modulated_deform_conv import ModulatedDeformConv2dFunction
 
 ext_module = ext_loader.load_ext('_ext', [
     'deform_conv_forward', 'deform_conv_backward_input',
@@ -46,6 +47,23 @@ class DeformConv2dFunction(Function):
             bias_i=bias,
             im2col_step_i=im2col_step)
 
+    @staticmethod
+    def _npu_backward(ctx, grad_output):
+        input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
+            ctx.saved_tensors
+        grad_input, grad_weight, grad_offset_all, grad_bias = \
+            torch.npu_deformable_conv2dbk(
+                input_tensor, grad_output, offset_out, weight, offset_all,
+                kernel_size=[weight.shape[3], weight.shape[2]],
+                stride=[1, 1, ctx.stride[0], ctx.stride[1]],
+                padding=[1, 1, ctx.padding[0], ctx.padding[1]],
+                dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
+                groups=ctx.groups, deformable_groups=ctx.deform_groups,
+                modulated=True)
+        grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp)
+        return grad_input, grad_offset, grad_weight, \
+            None, None, None, None, None, None, None
+
     @staticmethod
     def forward(ctx,
                 input: Tensor,
@@ -69,6 +87,7 @@ class DeformConv2dFunction(Function):
         ctx.groups = groups
         ctx.deform_groups = deform_groups
         ctx.im2col_step = im2col_step
+        ctx.device = input.device.type
 
         # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
         # amp won't cast the type of model (float32), but "offset" is cast
@@ -79,6 +98,13 @@ class DeformConv2dFunction(Function):
         # whatever the pytorch version is.
         input = input.type_as(offset)
         weight = weight.type_as(input)
+        if ctx.device == 'npu':
+            mask_shape, _ = torch.chunk(offset, 2, dim=1)
+            mask = torch.ones_like(mask_shape).to(input.device)
+            bias = input.new_empty(0)
+            output = ModulatedDeformConv2dFunction._npu_forward(
+                ctx, input, offset, mask, weight, bias)
+            return output
         ctx.save_for_backward(input, offset, weight)
         output = input.new_empty(
@@ -115,6 +141,8 @@ class DeformConv2dFunction(Function):
             ctx, grad_output: Tensor
     ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
                None, None, None, None, None, None]:
+        if ctx.device == 'npu':
+            return DeformConv2dFunction._npu_backward(ctx, grad_output)
         input, offset, weight = ctx.saved_tensors
         grad_input = grad_offset = grad_weight = None
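With these hooks in place, a DeformConv2d call on an Ascend tensor takes the NPU path: the forward branch builds an all-ones mask and delegates to ModulatedDeformConv2dFunction._npu_forward (a modulated/DCNv2 kernel with a constant mask of one is numerically DCNv1), and backward is routed to the new _npu_backward. A minimal usage sketch, assuming the Ascend PyTorch adapter (torch_npu) is installed and an NPU is visible; the device string, shapes, and hyper-parameters are illustrative:

```python
import torch
import torch_npu  # registers the 'npu' device (assumption: Ascend adapter installed)
from mmcv.ops import DeformConv2d

device = 'npu:0'
x = torch.randn(2, 16, 32, 32).to(device)

# DCNv1: offset has 2 * deform_groups * kernel_h * kernel_w = 18 channels here
conv = DeformConv2d(16, 32, kernel_size=3, padding=1).to(device)
offset = torch.randn(2, 18, 32, 32).to(device)

out = conv(x, offset)   # forward dispatches to the NPU branch above
out.sum().backward()    # backward dispatches to _npu_backward
```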


@@ -39,7 +39,7 @@ class ModulatedDeformConv2dFunction(Function):
         split_num = deformable_group * 2 * kernel_h * kernel_w
         sort_index = list(range(split_num))
         sort_index_fp = (sort_index[1::2] + sort_index[::2])
-        sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)}
+        sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
        sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
         sort_index_fp = torch.IntTensor(sort_index_fp)
         sort_index_bp = torch.IntTensor(sort_index_bp)
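The one-character fix above matters because the NPU operator consumes the offset channels in a reordered layout: the forward path permutes them with sort_index_fp, so the offset gradient returned by the NPU kernel has to be restored with the inverse permutation sort_index_bp, as the index_select in _npu_backward above suggests. Building the lookup from sort_index (which is just range(split_num)) produced an identity mapping instead of the inverse. A small round-trip check, mirroring the construction in the hunk above; the kernel size and deformable_group are illustrative:

```python
import torch

# Illustrative sizes; the construction mirrors the hunk above.
deformable_group, kernel_h, kernel_w = 1, 3, 3
split_num = deformable_group * 2 * kernel_h * kernel_w

sort_index = list(range(split_num))
sort_index_fp = sort_index[1::2] + sort_index[::2]

# Fixed version: build the lookup from sort_index_fp so that sort_index_bp
# becomes the inverse permutation of sort_index_fp.
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]

x = torch.arange(split_num)
assert torch.equal(x[sort_index_fp][sort_index_bp], x)  # reorder, then undo

# The old code built the dict from sort_index (i.e. range(split_num)), which
# yields the identity permutation and fails to undo the forward reordering.
```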