mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support DCNv1 on Ascend device (#2480)
* update lately npu modification--DCNv1 update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * update lately npu modification--DCNv1 * check code * Add ops to EN/ZH documentspull/2528/head
parent
a9535377bf
commit
f76de9077b
|
@ -18,7 +18,7 @@ We implement common ops used in detection, segmentation, etc.
|
||||||
| ConvexIoU | | √ | | | |
|
| ConvexIoU | | √ | | | |
|
||||||
| CornerPool | | √ | | | |
|
| CornerPool | | √ | | | |
|
||||||
| Correlation | | √ | | | |
|
| Correlation | | √ | | | |
|
||||||
| Deformable Convolution v1/v2 | √ | √ | | | |
|
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||||
| Deformable RoIPool | | √ | √ | | √ |
|
| Deformable RoIPool | | √ | √ | | √ |
|
||||||
| DiffIoURotated | | √ | | | |
|
| DiffIoURotated | | √ | | | |
|
||||||
| DynamicScatter | | √ | | | |
|
| DynamicScatter | | √ | | | |
|
||||||
|
|
|
@ -18,7 +18,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
||||||
| ConvexIoU | | √ | | | |
|
| ConvexIoU | | √ | | | |
|
||||||
| CornerPool | | √ | | | |
|
| CornerPool | | √ | | | |
|
||||||
| Correlation | | √ | | | |
|
| Correlation | | √ | | | |
|
||||||
| Deformable Convolution v1/v2 | √ | √ | | | |
|
| Deformable Convolution v1/v2 | √ | √ | | | √ |
|
||||||
| Deformable RoIPool | | √ | √ | | √ |
|
| Deformable RoIPool | | √ | √ | | √ |
|
||||||
| DiffIoURotated | | √ | | | |
|
| DiffIoURotated | | √ | | | |
|
||||||
| DynamicScatter | | √ | | | |
|
| DynamicScatter | | √ | | | |
|
||||||
|
|
|
@ -12,6 +12,7 @@ from torch.nn.modules.utils import _pair, _single
|
||||||
from mmcv.utils import deprecated_api_warning
|
from mmcv.utils import deprecated_api_warning
|
||||||
from ..cnn import CONV_LAYERS
|
from ..cnn import CONV_LAYERS
|
||||||
from ..utils import ext_loader, print_log
|
from ..utils import ext_loader, print_log
|
||||||
|
from .modulated_deform_conv import ModulatedDeformConv2dFunction
|
||||||
|
|
||||||
ext_module = ext_loader.load_ext('_ext', [
|
ext_module = ext_loader.load_ext('_ext', [
|
||||||
'deform_conv_forward', 'deform_conv_backward_input',
|
'deform_conv_forward', 'deform_conv_backward_input',
|
||||||
|
@ -46,6 +47,23 @@ class DeformConv2dFunction(Function):
|
||||||
bias_i=bias,
|
bias_i=bias,
|
||||||
im2col_step_i=im2col_step)
|
im2col_step_i=im2col_step)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _npu_backward(ctx, grad_output):
|
||||||
|
input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
|
||||||
|
ctx.saved_tensors
|
||||||
|
grad_input, grad_weight, grad_offset_all, grad_bias = \
|
||||||
|
torch.npu_deformable_conv2dbk(
|
||||||
|
input_tensor, grad_output, offset_out, weight, offset_all,
|
||||||
|
kernel_size=[weight.shape[3], weight.shape[2]],
|
||||||
|
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
|
||||||
|
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
|
||||||
|
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
|
||||||
|
groups=ctx.groups, deformable_groups=ctx.deform_groups,
|
||||||
|
modulated=True)
|
||||||
|
grad_offset = grad_offset_all.index_select(1, sort_index_for_npu_bp)
|
||||||
|
return grad_input, grad_offset, grad_weight, \
|
||||||
|
None, None, None, None, None, None, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx,
|
def forward(ctx,
|
||||||
input: Tensor,
|
input: Tensor,
|
||||||
|
@ -69,6 +87,7 @@ class DeformConv2dFunction(Function):
|
||||||
ctx.groups = groups
|
ctx.groups = groups
|
||||||
ctx.deform_groups = deform_groups
|
ctx.deform_groups = deform_groups
|
||||||
ctx.im2col_step = im2col_step
|
ctx.im2col_step = im2col_step
|
||||||
|
ctx.device = input.device.type
|
||||||
|
|
||||||
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
|
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
|
||||||
# amp won't cast the type of model (float32), but "offset" is cast
|
# amp won't cast the type of model (float32), but "offset" is cast
|
||||||
|
@ -79,6 +98,13 @@ class DeformConv2dFunction(Function):
|
||||||
# whatever the pytorch version is.
|
# whatever the pytorch version is.
|
||||||
input = input.type_as(offset)
|
input = input.type_as(offset)
|
||||||
weight = weight.type_as(input)
|
weight = weight.type_as(input)
|
||||||
|
if ctx.device == 'npu':
|
||||||
|
mask_shape, _ = torch.chunk(offset, 2, dim=1)
|
||||||
|
mask = torch.ones_like(mask_shape).to(input.device)
|
||||||
|
bias = input.new_empty(0)
|
||||||
|
output = ModulatedDeformConv2dFunction._npu_forward(
|
||||||
|
ctx, input, offset, mask, weight, bias)
|
||||||
|
return output
|
||||||
ctx.save_for_backward(input, offset, weight)
|
ctx.save_for_backward(input, offset, weight)
|
||||||
|
|
||||||
output = input.new_empty(
|
output = input.new_empty(
|
||||||
|
@ -115,6 +141,8 @@ class DeformConv2dFunction(Function):
|
||||||
ctx, grad_output: Tensor
|
ctx, grad_output: Tensor
|
||||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
|
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], None,
|
||||||
None, None, None, None, None, None]:
|
None, None, None, None, None, None]:
|
||||||
|
if ctx.device == 'npu':
|
||||||
|
return DeformConv2dFunction._npu_backward(ctx, grad_output)
|
||||||
input, offset, weight = ctx.saved_tensors
|
input, offset, weight = ctx.saved_tensors
|
||||||
|
|
||||||
grad_input = grad_offset = grad_weight = None
|
grad_input = grad_offset = grad_weight = None
|
||||||
|
|
|
@ -39,7 +39,7 @@ class ModulatedDeformConv2dFunction(Function):
|
||||||
split_num = deformable_group * 2 * kernel_h * kernel_w
|
split_num = deformable_group * 2 * kernel_h * kernel_w
|
||||||
sort_index = list(range(split_num))
|
sort_index = list(range(split_num))
|
||||||
sort_index_fp = (sort_index[1::2] + sort_index[::2])
|
sort_index_fp = (sort_index[1::2] + sort_index[::2])
|
||||||
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)}
|
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
|
||||||
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
|
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
|
||||||
sort_index_fp = torch.IntTensor(sort_index_fp)
|
sort_index_fp = torch.IntTensor(sort_index_fp)
|
||||||
sort_index_bp = torch.IntTensor(sort_index_bp)
|
sort_index_bp = torch.IntTensor(sort_index_bp)
|
||||||
|
|
Loading…
Reference in New Issue