[Feature] Support masked_conv on Ascend devices (#2387)

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* clean code

* fix autocast bugs on npu (#2273)

* code format

* code format

* code format

* bug fix

* pytorch_npu_helper.hpp clean code

* Npu dev (#2306)

* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather

* raise ImportError when compiling with npu

* add npu test case (#2307)

* add npu test case

* Update focal_loss.py

* add comment

* clean lint

* update dtype assert

* update DDP forward and comment

* fix bug

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* sigmoidfocalloss npu adapter bug fix

* BugFix: modify softmaxFocalLoss adapter

* BugFix: remove equal sign in the code

* add npu install information in README

* add modulatedDeformConv npu adapter

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* add modulatedDeformConv npu adapter

* merge master branch 20221103

* Add masked_conv2d operator on NPU

* add nms_npu

* fix bug

* fix code check

* fix code check

* fix code check

* Masked_conv2d NPU

* Masked_conv2d NPU

* Masked_conv2d NPU

* remove npu-install-info in README.md

* comment out clang-format in pre-commit-config-zh-cn.yaml

* Clean code: fix code-cleanliness issues in masked_conv2d and modulated_deform_conv

Co-authored-by: wangjiangben <wangjiangben_hw@126.com>
Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zcc-zjut <zcczxy2019@163.com>
Co-authored-by: wangxiaoxin_sherie <wangxiaoxin7@huawei.com>
Co-authored-by: momo609 <963372609@qq.com>
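
For context, a minimal usage sketch of masked convolution on an Ascend device (not part of the diff below; it assumes torch_npu and an NPU-enabled mmcv build are installed, and the tensor shapes are made up):

import torch
import torch_npu  # noqa: F401  # assumed installed; registers the 'npu' device
from mmcv.ops import MaskedConv2d

x = torch.randn(1, 3, 16, 16).npu()                  # NCHW feature map
mask = (torch.rand(1, 16, 16) > 0.5).float().npu()   # (N, H, W) binary mask
conv = MaskedConv2d(3, 8, kernel_size=3, stride=1, padding=1).npu()
out = conv(x, mask)   # (1, 8, 16, 16); positions where mask == 0 stay zero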

@@ -1,10 +1,19 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
int64_t n_class = input.size(1);
at::Tensor target_y = at::ones_like(input);
if (n_class == 1) {
target_y = at::reshape(target, input.sizes());
target_y = at::mul(target_y, -1.0);
target_y = at::add(target_y, 1.0);
} else {
target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
}
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
int64_t weight_size = weight.size(0);
@@ -14,6 +23,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
input.sizes());
}
OpCommand cmd;
string reduction = "none";
cmd.Name("SigmoidFocalLoss")
.Input(input)
.Input(target_y)
@@ -21,7 +31,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
.Output(output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Attr("reduction", reduction)
.Run();
}
@@ -31,7 +41,15 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha) {
at::Tensor target_y = at::reshape(target, input.sizes());
int64_t n_class = input.size(1);
at::Tensor target_y = at::ones_like(input);
if (n_class == 1) {
target_y = at::reshape(target, input.sizes());
} else {
target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
target_y = at::mul(target_y, -1.0);
target_y = at::add(target_y, 1.0);
}
target_y =
at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
at::Tensor grad_up = at::ones_like(input);
@@ -42,6 +60,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
input.sizes());
}
OpCommand cmd;
string reduction = "none";
cmd.Name("SigmoidFocalLossGrad")
.Input(input)
.Input(target_y)
@@ -50,7 +69,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Attr("reduction", reduction)
.Run();
}
@@ -71,16 +90,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
at::Tensor op_output = at::ones_like(input);
OpCommand cmd;
string reduction = "none";
cmd.Name("SoftmaxFocalLoss")
.Input(input)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Output(op_output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Attr("reduction", reduction)
.Run();
int64_t n_batch = input.size(0);
c10::SmallVector<int64_t, 2> offsets = {0, 0};
c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
at::IntArrayRef offset = at::IntArrayRef(offsets);
at::IntArrayRef size = at::IntArrayRef(sizes);
at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size,
output);
}
void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -102,8 +130,8 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
input.sizes());
}
OpCommand cmd;
string reduction = "none";
cmd.Name("SoftmaxFocalLossGrad")
.Input(input)
.Input(target_y)
@@ -112,7 +140,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
.Output(grad_input)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", "none")
.Attr("reduction", reduction)
.Run();
}
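
As a side note on the label handling above, here is a plain-PyTorch restatement of what the adapter feeds to the CANN kernels as the target tensor (illustration only; the helper name and dtypes are assumptions, not part of the commit):

import torch
import torch.nn.functional as F

def _npu_focal_target(input, target, forward=True):
    # n_class == 1: per-element binary labels; the forward pass complements them.
    # n_class > 1: class indices are one-hot encoded; the backward pass complements them.
    n_class = input.size(1)
    if n_class == 1:
        t = target.reshape(input.shape).float()
        t = 1.0 - t if forward else t
    else:
        t = F.one_hot(target.long(), n_class).float()
        t = t if forward else 1.0 - t
    return t.int()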


@@ -0,0 +1,45 @@
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes);
at::Tensor ones_tensor =
at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1);
at::add_out(boxed_offest, boxes, ones_tensor, offset);
at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes)
.fill_(iou_threshold);
at::Tensor scores_threshold_y =
at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes)
.fill_(0);
at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kInt), boxes)
.fill_(boxes.size(0));
c10::SmallVector<int64_t, SIZE> outputsize = {boxes.size(0)};
at::Tensor output = at_npu::native::OpPreparation::ApplyTensor(
outputsize, boxes.options().dtype(at::kInt), boxes)
.fill_(-1);
OpCommand cmd;
cmd.Name("NonMaxSuppressionV3")
.Input(boxes)
.Input(scores)
.Input(max_outputsize_y)
.Input(iou_threshold_y)
.Input(scores_threshold_y)
.Output(output)
.Run();
auto outputsizeBool = at::gt(output, -1);
auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int);
auto countLen = at::sum(outputsizeInt, at::ScalarType::Int);
at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong());
actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
actual_output, at::kLong);
return actual_output;
}
Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset);
REGISTER_NPU_IMPL(nms_impl, nms_npu);
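
A minimal sketch of calling the Python entry point that now dispatches to nms_npu (assumes torch_npu is available; the box and score values are made up):

import torch
from mmcv.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]]).npu()
scores = torch.tensor([0.9, 0.8, 0.7]).npu()
dets, keep = nms(boxes, scores, iou_threshold=0.5)   # keep holds the surviving box indices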


@@ -45,6 +45,22 @@ class MaskedConv2dFunction(Function):
'Stride could not only be 1 in masked_conv2d currently.')
out_channel, in_channel, kernel_h, kernel_w = weight.size()
if features.device.type == 'npu':
import torch_npu
output = torch_npu.npu_conv2d(
features,
weight,
bias,
stride=(stride_h, stride_w),
padding=(pad_h, pad_w),
dilation=(1, 1),
groups=1)
if mask.size()[1:] != output.size()[2:]:
raise ValueError(
'The mask is inconsistent with the shape of output_conv.')
output = output * mask
return output
batch_size = features.size(0)
out_h = int(
math.floor(

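Conceptually, the NPU branch above computes a dense convolution and then masks it elementwise; a short restatement in plain PyTorch (shapes are assumptions for illustration):

import torch
import torch.nn.functional as F

features = torch.randn(1, 3, 8, 8)
weight = torch.randn(8, 3, 3, 3)
bias = torch.zeros(8)
mask = (torch.rand(1, 8, 8) > 0.5).float()   # (N, H_out, W_out)

out = F.conv2d(features, weight, bias, stride=1, padding=1)
assert mask.shape[1:] == out.shape[2:], \
    'The mask is inconsistent with the shape of output_conv.'
out = out * mask   # positions where mask == 0 are zeroed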

@@ -34,6 +34,66 @@ class ModulatedDeformConv2dFunction(Function):
groups_i=groups,
deform_groups_i=deform_groups)
@staticmethod
def _calculate_sort_index(kernel_h, kernel_w, deformable_group):
split_num = deformable_group * 2 * kernel_h * kernel_w
sort_index = list(range(split_num))
sort_index_fp = (sort_index[1::2] + sort_index[::2])
sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
sort_index_fp = torch.IntTensor(sort_index_fp)
sort_index_bp = torch.IntTensor(sort_index_bp)
sort_index_fp = sort_index_fp.npu()
sort_index_bp = sort_index_bp.npu()
return sort_index_fp, sort_index_bp
@staticmethod
def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
_, _, kernel_h, kernel_w = weight.shape
conv2d_bias = bias if len(bias) > 0 else None
sort_index_fp, sort_index_bp = \
ModulatedDeformConv2dFunction._calculate_sort_index(
kernel_w, kernel_h, ctx.deform_groups)
select_offset = offset.index_select(1, sort_index_fp)
offset_all = torch.cat([select_offset, mask], dim=1)
output, offset_out = torch.npu_deformable_conv2d(
input_tensor,
weight,
offset_all,
conv2d_bias,
kernel_size=[kernel_w, kernel_h],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups,
deformable_groups=ctx.deform_groups,
modulated=True)
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input_tensor.requires_grad:
ctx.save_for_backward(input_tensor, weight, offset_out, offset_all,
sort_index_bp)
return output
@staticmethod
def _npu_backward(ctx, grad_output):
input_tensor, weight, offset_out, offset_all, sort_index_bp = \
ctx.saved_tensors
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True)
grad_offset = grad_offset_all.index_select(1, sort_index_bp)
grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :]
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None, None, None, None)
@staticmethod
def forward(ctx,
input: torch.Tensor,
@@ -56,6 +116,7 @@ class ModulatedDeformConv2dFunction(Function):
ctx.groups = groups
ctx.deform_groups = deform_groups
ctx.with_bias = bias is not None
ctx.device = input.device.type
if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
@@ -69,6 +130,10 @@ class ModulatedDeformConv2dFunction(Function):
weight = weight.type_as(input)
bias = bias.type_as(input) # type: ignore
mask = mask.type_as(input)
if ctx.device == 'npu':
output = ModulatedDeformConv2dFunction._npu_forward(
ctx, input, offset, mask, weight, bias)
return output
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
@@ -98,6 +163,9 @@ class ModulatedDeformConv2dFunction(Function):
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
if ctx.device == 'npu':
return ModulatedDeformConv2dFunction._npu_backward(
ctx, grad_output)
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
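
A small check (illustration only, not part of the commit) that the backward index built by _calculate_sort_index really inverts the forward channel permutation; it runs on CPU, so no NPU is needed:

import torch

kernel_h, kernel_w, deform_groups = 3, 3, 1
split_num = deform_groups * 2 * kernel_h * kernel_w
idx = list(range(split_num))
fp = idx[1::2] + idx[::2]            # forward order: odd channels first, then even
bp = [fp.index(i) for i in idx]      # inverse permutation

x = torch.arange(split_num)
assert torch.equal(x[fp][bp], x)     # applying bp after fp restores the original order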


@@ -23,4 +23,4 @@ default_section = THIRDPARTY
# than "BA"
[codespell]
quiet-level = 3
ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping
ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping,cann


@@ -3,7 +3,7 @@ import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
class TestMaskedConv2d:
@@ -16,7 +16,11 @@ class TestMaskedConv2d:
pytest.param(
'mlu',
marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support'))
not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
def test_masked_conv2d_all_close(self, device):
from mmcv.ops import MaskedConv2d
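
To exercise the new case on an Ascend machine, something like the following should work (hypothetical invocation; the test file path is assumed from mmcv's test layout):

import pytest

# Run only the NPU-parametrized case of TestMaskedConv2d.
pytest.main(['tests/test_ops/test_masked_conv2d.py', '-k', 'npu', '-v'])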