mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support DCNv1 on Ascend device (#2480)
* update lately npu modification--DCNv1 update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * check code * Add ops to EN/ZH documentspull/2528/head
parent
a9535377bf
commit
f76de9077b
|
@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ConvexIoU | | √ | | | |
|
||||
| CornerPool | | √ | | | |
|
||||
| Correlation | | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||
| Deformable RoIPool | | √ | √ | | √ |
|
||||
| DiffIoURotated | | √ | | | |
|
||||
| DynamicScatter | | √ | | | |
|
||||
|
|
|
@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ConvexIoU | | √ | | | |
|
||||
| CornerPool | | √ | | | |
|
||||
| Correlation | | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | |
|
||||
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||
| Deformable RoIPool | | √ | √ | | √ |
|
||||
| DiffIoURotated | | √ | | | |
|
||||
| DynamicScatter | | √ | | | |
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch.nn.modules.utils import _pair, _single
|
|||
from mmcv.utils import deprecated_api_warning
|
||||
from ..cnn import CONV_LAYERS
|
||||
from ..utils import ext_loader, print_log
|
||||
from .modulated_deform_conv import ModulatedDeformConv2dFunction
|
||||
|
||||
ext_module = ext_loader.load_ext('_ext', [
|
||||
'deform_conv_forward', 'deform_conv_backward_input',
|
||||
|
@ -46,6 +47,23 @@ class DeformConv2dFunction(Function):
|
|||
bias_i=bias,
|
||||
im2col_step_i=im2col_step)
|
||||
|
||||
@staticmethod
|
||||
def _npu_backward(ctx, grad_output):
|
||||
input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
|
||||
ctx.saved_tensors
|
||||
grad_input, grad_weight, grad_offset_all, grad_bias = \
|
||||
torch.npu_deformable_conv2dbk(
|
||||
input_tensor, grad_output, offset_out, weight, offset_all,
|
||||
kernel_size=[weight.shape[3], weight.shape[2]],
|
||||
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
|
||||
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
|
||||
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
|
||||
groups=ctx.groups, deformable_groups=ctx.deform_groups,
|
||||
modulated=True)
|
||||
grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp)
|
||||
return grad_input, grad_offset, grad_weight, \
|
||||
None, None, None, None, None, None, None
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
input: Tensor,
|
||||
|
@ -69,6 +87,7 @@ class DeformConv2dFunction(Function):
|
|||
ctx.groups = groups
|
||||
ctx.deform_groups = deform_groups
|
||||
ctx.im2col_step = im2col_step
|
||||
ctx.device = input.device.type
|
||||
|
||||
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
|
||||
# amp won't cast the type of model (float32), but "offset" is cast
|
||||
|
@ -79,6 +98,13 @@ class DeformConv2dFunction(Function):
|
|||
# whatever the pytorch version is.
|
||||
input = input.type_as(offset)
|
||||
weight = weight.type_as(input)
|
||||
if ctx.device == 'npu':
|
||||
mask_shape, _ = torch.chunk(offset, 2, dim=1)
|
||||
mask = torch.ones_like(mask_shape).to(input.device)
|
||||
bias = input.new_empty(0)
|
||||
output = ModulatedDeformConv2dFunction._npu_forward(
|
||||
ctx, input, offset, mask, weight, bias)
|
||||
return output
|
||||
ctx.save_for_backward(input, offset, weight)
|
||||
|
||||
output = input.new_empty(
|
||||
|
@ -115,6 +141,8 @@ class DeformConv2dFunction(Function):
|
|||
ctx, grad_output: Tensor
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
|
||||
None, None, None, None, None, None]:
|
||||
if ctx.device == 'npu':
|
||||
return DeformConv2dFunction._npu_backward(ctx, grad_output)
|
||||
input, offset, weight = ctx.saved_tensors
|
||||
|
||||
grad_input = grad_offset = grad_weight = None
|
||||
|
|
|
@ -39,7 +39,7 @@ class ModulatedDeformConv2dFunction(Function):
|
|||
split_num = deformable_group * 2 * kernel_h * kernel_w
|
||||
sort_index = list(range(split_num))
|
||||
sort_index_fp = (sort_index[1::2] + sort_index[::2])
|
||||
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)}
|
||||
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
|
||||
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
|
||||
sort_index_fp = torch.IntTensor(sort_index_fp)
|
||||
sort_index_bp = torch.IntTensor(sort_index_bp)
|
||||
|
|
Loading…
Reference in New Issue