mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support masked_conv in Ascend device (#2387)
* init npu
* add npu extension and focal loss adapter
* clean code
* fix autocast bugs on npu (#2273)
* code format
* bug fix
* pytorch_npu_helper.hpp clean code
* Npu dev (#2306)
* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather
* raise ImportError when compiling with npu
* add npu test case (#2307)
* Update focal_loss.py
* add comment
* clean lint
* update dtype assert
* update DDP forward and comment
* fix bug

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* sigmoidfocalloss npu adapter bug fix
* BugFix: modify softmaxFocalLoss adapter
* BugFix: remove equal sign in the code
* add npu install information in README
* add modulatedDeformConv npu adapter
* merge master branch 20221103
* Add masked_conv2d operator in NPU
* add nms_npu
* fix bug
* fix code check
* Masked_conv2d NPU
* remove npu-install-info in README.md
* annotate the clang-format in .pre-commit-config-zh-cn.yaml
* Clean code: fix the clean code problems in masked_conv2d and modulated_deform_conv

Co-authored-by: wangjiangben <wangjiangben_hw@126.com>
Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zcc-zjut <zcczxy2019@163.com>
Co-authored-by: wangxiaoxin_sherie <wangxiaoxin7@huawei.com>
Co-authored-by: momo609 <963372609@qq.com>
parent 89253699da
commit de6b0021af
@@ -1,10 +1,19 @@
 #include "pytorch_npu_helper.hpp"

 using namespace NPU_NAME_SPACE;
 using namespace std;

 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
-  at::Tensor target_y = at::reshape(target, input.sizes());
+  int64_t n_class = input.size(1);
+  at::Tensor target_y = at::ones_like(input);
+  if (n_class == 1) {
+    target_y = at::reshape(target, input.sizes());
+    target_y = at::mul(target_y, -1.0);
+    target_y = at::add(target_y, 1.0);
+  } else {
+    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+  }
   target_y =
       at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
   int64_t weight_size = weight.size(0);
@@ -14,6 +23,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
         input.sizes());
   }
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
       .Input(input)
       .Input(target_y)
@@ -21,7 +31,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
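In this forward adapter the target tensor is normalized before the CANN SigmoidFocalLoss kernel runs: with a single class the labels are reshaped to the logit shape and flipped via mul(-1)/add(1), otherwise they are one-hot encoded with NPUNativeFunctions::one_hot, then cast to int. A rough Python mirror of that target preparation (the helper name is hypothetical, not part of mmcv):

import torch


def prepare_sigmoid_focal_targets(logits: torch.Tensor,
                                  target: torch.Tensor) -> torch.Tensor:
    # Sketch of the target handling in sigmoid_focal_loss_forward_npu above.
    n_class = logits.size(1)
    if n_class == 1:
        # Binary case: reshape to the logit shape and flip 0/1,
        # i.e. target_y = 1 - target.
        target_y = target.reshape(logits.shape).float() * -1.0 + 1.0
    else:
        # Multi-class case: one-hot encode the integer labels.
        target_y = torch.nn.functional.one_hot(target.long(), n_class)
    return target_y.int()


# Example: 4 samples, 3 classes.
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
print(prepare_sigmoid_focal_targets(logits, labels))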
@@ -31,7 +41,15 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
-  at::Tensor target_y = at::reshape(target, input.sizes());
+  int64_t n_class = input.size(1);
+  at::Tensor target_y = at::ones_like(input);
+  if (n_class == 1) {
+    target_y = at::reshape(target, input.sizes());
+  } else {
+    target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class);
+    target_y = at::mul(target_y, -1.0);
+    target_y = at::add(target_y, 1.0);
+  }
   target_y =
       at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt);
   at::Tensor grad_up = at::ones_like(input);
@@ -42,6 +60,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
         input.sizes());
   }
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
       .Input(input)
       .Input(target_y)
@@ -50,7 +69,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(grad_input)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
@@ -71,16 +90,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                  input.sizes());
   }
+  at::Tensor op_output = at::ones_like(input);
   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
       .Input(input)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
+  int64_t n_batch = input.size(0);
+  c10::SmallVector<int64_t, 2> offsets = {0, 0};
+  c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
+  at::IntArrayRef offset = at::IntArrayRef(offsets);
+  at::IntArrayRef size = at::IntArrayRef(sizes);
+  at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size,
+                                                    output);
 }

 void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
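The softmax variant now writes the kernel result into a scratch tensor op_output shaped like the logits and then copies a {n_batch, 1} slice starting at offset {0, 0} into the tensor mmcv expects. A plain-PyTorch sketch of that final slicing step (shapes chosen only for illustration):

import torch

# Stand-in for the kernel result written into op_output (N x C here).
op_output = torch.rand(8, 5)
# Equivalent of npu_slice_out with offsets = {0, 0} and sizes = {N, 1}:
output = op_output[:, 0:1].clone()
assert output.shape == (8, 1)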
@@ -102,8 +130,8 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
     weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight,
                                                                  input.sizes());
   }

   OpCommand cmd;
+  string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
       .Input(input)
       .Input(target_y)
@@ -112,7 +140,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
       .Output(grad_input)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
-      .Attr("reduction", "none")
+      .Attr("reduction", reduction)
       .Run();
 }
@@ -0,0 +1,45 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
+  at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes);
+  at::Tensor ones_tensor =
+      at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1);
+  at::add_out(boxed_offest, boxes, ones_tensor, offset);
+  at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
+                                   {}, boxes.options().dtype(at::kFloat), boxes)
+                                   .fill_(iou_threshold);
+  at::Tensor scores_threshold_y =
+      at_npu::native::OpPreparation::ApplyTensor(
+          {}, boxes.options().dtype(at::kFloat), boxes)
+          .fill_(0);
+  at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor(
+                                    {}, boxes.options().dtype(at::kInt), boxes)
+                                    .fill_(boxes.size(0));
+  c10::SmallVector<int64_t, SIZE> outputsize = {boxes.size(0)};
+  at::Tensor output = at_npu::native::OpPreparation::ApplyTensor(
+                          outputsize, boxes.options().dtype(at::kInt), boxes)
+                          .fill_(-1);
+  OpCommand cmd;
+  cmd.Name("NonMaxSuppressionV3")
+      .Input(boxes)
+      .Input(scores)
+      .Input(max_outputsize_y)
+      .Input(iou_threshold_y)
+      .Input(scores_threshold_y)
+      .Output(output)
+      .Run();
+  auto outputsizeBool = at::gt(output, -1);
+  auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int);
+  auto countLen = at::sum(outputsizeInt, at::ScalarType::Int);
+  at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong());
+  actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast(
+      actual_output, at::kLong);
+  return actual_output;
+}
+
+Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset);
+
+REGISTER_NPU_IMPL(nms_impl, nms_npu);
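This new NPU adapter maps mmcv's nms onto the CANN NonMaxSuppressionV3 operator: the kernel fills a fixed-size index buffer initialized to -1, so the adapter counts the entries greater than -1 and returns only that valid prefix, cast to int64. A hedged usage sketch through the Python API (assumes an NPU-enabled mmcv build with torch_npu installed; on such a device the call should dispatch to nms_npu):

import torch
from mmcv.ops import nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]]).npu()  # .npu() needs torch_npu
scores = torch.tensor([0.9, 0.8, 0.7]).npu()

dets, keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)  # indices of the kept boxes, e.g. tensor([0, 2])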
@@ -45,6 +45,22 @@ class MaskedConv2dFunction(Function):
                 'Stride could not only be 1 in masked_conv2d currently.')
         out_channel, in_channel, kernel_h, kernel_w = weight.size()

+        if features.device.type == 'npu':
+            import torch_npu
+            output = torch_npu.npu_conv2d(
+                features,
+                weight,
+                bias,
+                stride=(stride_h, stride_w),
+                padding=(pad_h, pad_w),
+                dilation=(1, 1),
+                groups=1)
+            if mask.size()[1:] != output.size()[2:]:
+                raise ValueError(
+                    'The mask is inconsistent with the shape of output_conv.')
+            output = output * mask
+            return output
+
         batch_size = features.size(0)
         out_h = int(
             math.floor(
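On Ascend devices the masked convolution is emulated as a dense torch_npu.npu_conv2d followed by an elementwise multiply with the mask, rather than the dedicated masked im2col path used elsewhere. A rough usage sketch of the op (MaskedConv2d and IS_NPU_AVAILABLE are existing mmcv names; running on 'npu' assumes torch_npu is installed, otherwise the snippet falls back to CUDA):

import torch
from mmcv.ops import MaskedConv2d
from mmcv.utils import IS_NPU_AVAILABLE

device = 'npu' if IS_NPU_AVAILABLE else 'cuda'
conv = MaskedConv2d(3, 8, kernel_size=3, padding=1).to(device)
features = torch.randn(1, 3, 16, 16, device=device)
mask = (torch.rand(1, 16, 16, device=device) > 0.5).float()

out = conv(features, mask)  # convolution applied only where mask is nonzero
print(out.shape)            # torch.Size([1, 8, 16, 16])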
@@ -34,6 +34,66 @@ class ModulatedDeformConv2dFunction(Function):
             groups_i=groups,
             deform_groups_i=deform_groups)

+    @staticmethod
+    def _calculate_sort_index(kernel_h, kernel_w, deformable_group):
+        split_num = deformable_group * 2 * kernel_h * kernel_w
+        sort_index = list(range(split_num))
+        sort_index_fp = (sort_index[1::2] + sort_index[::2])
+        sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)}
+        sort_index_bp = [sort_index_bp_dict[i] for i in sort_index]
+        sort_index_fp = torch.IntTensor(sort_index_fp)
+        sort_index_bp = torch.IntTensor(sort_index_bp)
+        sort_index_fp = sort_index_fp.npu()
+        sort_index_bp = sort_index_bp.npu()
+        return sort_index_fp, sort_index_bp
+
+    @staticmethod
+    def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
+        _, _, kernel_h, kernel_w = weight.shape
+        conv2d_bias = bias if len(bias) > 0 else None
+        sort_index_fp, sort_index_bp = \
+            ModulatedDeformConv2dFunction._calculate_sort_index(
+                kernel_w, kernel_h, ctx.deform_groups)
+        select_offset = offset.index_select(1, sort_index_fp)
+        offset_all = torch.cat([select_offset, mask], dim=1)
+        output, offset_out = torch.npu_deformable_conv2d(
+            input_tensor,
+            weight,
+            offset_all,
+            conv2d_bias,
+            kernel_size=[kernel_w, kernel_h],
+            stride=[1, 1, ctx.stride[0], ctx.stride[1]],
+            padding=[1, 1, ctx.padding[0], ctx.padding[1]],
+            dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
+            groups=ctx.groups,
+            deformable_groups=ctx.deform_groups,
+            modulated=True)
+        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+                or input_tensor.requires_grad:
+            ctx.save_for_backward(input_tensor, weight, offset_out, offset_all,
+                                  sort_index_bp)
+        return output
+
+    @staticmethod
+    def _npu_backward(ctx, grad_output):
+        input_tensor, weight, offset_out, offset_all, sort_index_bp = \
+            ctx.saved_tensors
+        grad_input, grad_weight, grad_offset_all, grad_bias = \
+            torch.npu_deformable_conv2dbk(
+                input_tensor, grad_output, offset_out, weight, offset_all,
+                kernel_size=[weight.shape[3], weight.shape[2]],
+                stride=[1, 1, ctx.stride[0], ctx.stride[1]],
+                padding=[1, 1, ctx.padding[0], ctx.padding[1]],
+                dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
+                groups=ctx.groups, deformable_groups=ctx.deform_groups,
+                modulated=True)
+        grad_offset = grad_offset_all.index_select(1, sort_index_bp)
+        grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :]
+        if not ctx.with_bias:
+            grad_bias = None
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None, None, None, None, None)
+
     @staticmethod
     def forward(ctx,
                 input: torch.Tensor,
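_calculate_sort_index builds the channel permutation used above: sort_index_fp de-interleaves the 2 * kernel_h * kernel_w * deform_groups offset channels (all odd-indexed channels first, then all even-indexed ones) before they are concatenated with the mask and handed to torch.npu_deformable_conv2d, and sort_index_bp is its inverse, used to scatter the offset gradients back into mmcv's interleaved layout. A pure-Python sketch of the same index computation:

# Sketch of the permutation logic in _calculate_sort_index above
# (plain lists instead of NPU tensors).
def calculate_sort_index(kernel_h, kernel_w, deformable_group):
    split_num = deformable_group * 2 * kernel_h * kernel_w
    sort_index = list(range(split_num))
    # Forward permutation: odd-indexed channels first, then even-indexed.
    sort_index_fp = sort_index[1::2] + sort_index[::2]
    # Backward permutation: the inverse mapping of sort_index_fp.
    inverse = {channel: pos for pos, channel in enumerate(sort_index_fp)}
    sort_index_bp = [inverse[channel] for channel in sort_index]
    return sort_index_fp, sort_index_bp


fp, bp = calculate_sort_index(kernel_h=1, kernel_w=2, deformable_group=1)
print(fp)                   # [1, 3, 0, 2]
print(bp)                   # [2, 0, 3, 1]
print([fp[i] for i in bp])  # [0, 1, 2, 3] -> bp undoes fp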
@@ -56,6 +116,7 @@ class ModulatedDeformConv2dFunction(Function):
         ctx.groups = groups
         ctx.deform_groups = deform_groups
         ctx.with_bias = bias is not None
+        ctx.device = input.device.type
         if not ctx.with_bias:
             bias = input.new_empty(0)  # fake tensor
         # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
@@ -69,6 +130,10 @@ class ModulatedDeformConv2dFunction(Function):
         weight = weight.type_as(input)
         bias = bias.type_as(input)  # type: ignore
         mask = mask.type_as(input)
+        if ctx.device == 'npu':
+            output = ModulatedDeformConv2dFunction._npu_forward(
+                ctx, input, offset, mask, weight, bias)
+            return output
         ctx.save_for_backward(input, offset, mask, weight, bias)
         output = input.new_empty(
             ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
@@ -98,6 +163,9 @@ class ModulatedDeformConv2dFunction(Function):
     @staticmethod
     @once_differentiable
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
+        if ctx.device == 'npu':
+            return ModulatedDeformConv2dFunction._npu_backward(
+                ctx, grad_output)
         input, offset, mask, weight, bias = ctx.saved_tensors
         grad_input = torch.zeros_like(input)
         grad_offset = torch.zeros_like(offset)
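With ctx.device recorded in forward, the same autograd Function transparently routes to _npu_forward/_npu_backward on Ascend devices while CUDA keeps the original extension path, so calling code does not change. A hedged end-to-end sketch (assumes torch_npu and an NPU-enabled mmcv build; ModulatedDeformConv2dPack is the existing mmcv module):

import torch
from mmcv.ops import ModulatedDeformConv2dPack

# On an NPU machine, forward/backward below go through the torch_npu kernels.
conv = ModulatedDeformConv2dPack(3, 16, kernel_size=3, padding=1).npu()
x = torch.randn(2, 3, 32, 32).npu()
y = conv(x)
y.sum().backward()
print(y.shape)  # torch.Size([2, 16, 32, 32])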
@@ -23,4 +23,4 @@ default_section = THIRDPARTY
 # than "BA"
 [codespell]
 quiet-level = 3
-ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping
+ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping,cann
@@ -3,7 +3,7 @@ import numpy as np
 import pytest
 import torch

-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE


 class TestMaskedConv2d:
@@ -16,7 +16,11 @@ class TestMaskedConv2d:
         pytest.param(
             'mlu',
             marks=pytest.mark.skipif(
-                not IS_MLU_AVAILABLE, reason='requires MLU support'))
+                not IS_MLU_AVAILABLE, reason='requires MLU support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
     ])
     def test_masked_conv2d_all_close(self, device):
         from mmcv.ops import MaskedConv2d
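The test now adds an 'npu' parameter guarded by IS_NPU_AVAILABLE, mirroring the existing CUDA/MLU entries. A minimal standalone sketch of the same skipif-parametrization pattern (the test name and body here are hypothetical; only the pattern matches the mmcv test):

import pytest
import torch

from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support')),
])
def test_device_is_usable(device):
    # Hypothetical check: the selected backend can allocate a tensor.
    x = torch.ones(2, 2, device=device)
    assert float(x.sum()) == 4.0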