From de6b0021afb7a6735667d514008945922eca47b6 Mon Sep 17 00:00:00 2001 From: ckirchhoff <515629648@qq.com> Date: Tue, 15 Nov 2022 10:23:17 +0800 Subject: [PATCH] [Feature] Support masked_conv in Ascend device (#2387) * init npu * add npu extension and focal loss adapter * clean code * clean code * clean code * clean code * fix autocast bugs on npu (#2273) fix autocast bugs on npu (#2273) * code format * code format * code format * bug fix * pytorch_npu_helper.hpp clean code * Npu dev (#2306) * fix autocast bugs on npu * using scatter_kwargs in mmcv.device.scatter_gather * raise ImportError when compile with npu * add npu test case (#2307) * add npu test case * Update focal_loss.py * add comment * clean lint * update dtype assert * update DDP forward and comment * fix bug Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * sigmoidfocalloss npu adapter bug fix * BugFix: modify softmaxFocalLoss adapter * BugFix: remove equal sign in the code * add npu install information in README * add modulatedDeformConv npu adapter * init npu * add npu extension and focal loss adapter * clean code * clean code * clean code * add modulatedDeformConv npu adapter * merge master branch 20221103 * Add masked_ Conv2d operator in NPU * add nms_npu * fix bug * fix code check * fix code check * fix code check * Masked_conv2d NPU * Masked_conv2d NPU * Masked_conv2d NPU * remove npu-install-info in README.md * annotate the clang-format in pre-commit-config-zh-ch.yaml * Clean code: fix the clean code problem in masked_conv2d and modulated_deform_conv Co-authored-by: wangjiangben Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: zcc-zjut Co-authored-by: wangxiaoxin_sherie Co-authored-by: momo609 <963372609@qq.com> --- mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp | 44 ++++++++++--- mmcv/ops/csrc/pytorch/npu/nms_npu.cpp | 45 +++++++++++++ mmcv/ops/masked_conv.py | 16 +++++ mmcv/ops/modulated_deform_conv.py | 68 ++++++++++++++++++++ setup.cfg | 2 +- tests/test_ops/test_masked_conv2d.py | 8 ++- 6 files changed, 172 insertions(+), 11 deletions(-) create mode 100644 mmcv/ops/csrc/pytorch/npu/nms_npu.cpp diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp index bd8282468..c949bf953 100644 --- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp +++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp @@ -1,10 +1,19 @@ #include "pytorch_npu_helper.hpp" using namespace NPU_NAME_SPACE; +using namespace std; void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, Tensor output, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if (n_class == 1) { + target_y = at::reshape(target, input.sizes()); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); int64_t weight_size = weight.size(0); @@ -14,6 +23,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLoss") .Input(input) .Input(target_y) @@ -21,7 +31,7 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, .Output(output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -31,7 +41,15 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, Tensor grad_input, float gamma, float alpha) { - at::Tensor target_y = at::reshape(target, input.sizes()); + int64_t n_class = input.size(1); + at::Tensor target_y = at::ones_like(input); + if (n_class == 1) { + target_y = at::reshape(target, input.sizes()); + } else { + target_y = at_npu::native::NPUNativeFunctions::one_hot(target, n_class); + target_y = at::mul(target_y, -1.0); + target_y = at::add(target_y, 1.0); + } target_y = at_npu::native::NPUNativeFunctions::npu_dtype_cast(target_y, at::kInt); at::Tensor grad_up = at::ones_like(input); @@ -42,6 +60,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, input.sizes()); } OpCommand cmd; + string reduction = "none"; cmd.Name("SigmoidFocalLossGrad") .Input(input) .Input(target_y) @@ -50,7 +69,7 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } @@ -71,16 +90,25 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } + at::Tensor op_output = at::ones_like(input); OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLoss") .Input(input) .Input(target_y) .Input(weight_y) - .Output(output) + .Output(op_output) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); + int64_t n_batch = input.size(0); + c10::SmallVector offsets = {0, 0}; + c10::SmallVector sizes = {n_batch, 1}; + at::IntArrayRef offset = at::IntArrayRef(offsets); + at::IntArrayRef size = at::IntArrayRef(sizes); + at_npu::native::NPUNativeFunctions::npu_slice_out(op_output, offset, size, + output); } void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, @@ -102,8 +130,8 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, weight_y = at_npu::native::NPUNativeFunctions::npu_broadcast(weight, input.sizes()); } - OpCommand cmd; + string reduction = "none"; cmd.Name("SoftmaxFocalLossGrad") .Input(input) .Input(target_y) @@ -112,7 +140,7 @@ void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight, .Output(grad_input) .Attr("gamma", gamma) .Attr("alpha", alpha) - .Attr("reduction", "none") + .Attr("reduction", reduction) .Run(); } diff --git a/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp new file mode 100644 index 000000000..2f86893ea --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/nms_npu.cpp @@ -0,0 +1,45 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { + at::Tensor boxed_offest = at_npu::native::OpPreparation::ApplyTensor(boxes); + at::Tensor ones_tensor = + at_npu::native::OpPreparation::ApplyTensor(boxes).fill_(1); + at::add_out(boxed_offest, boxes, ones_tensor, offset); + at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), boxes) + .fill_(iou_threshold); + at::Tensor scores_threshold_y = + at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kFloat), boxes) + .fill_(0); + at::Tensor max_outputsize_y = at_npu::native::OpPreparation::ApplyTensor( + {}, boxes.options().dtype(at::kInt), boxes) + .fill_(boxes.size(0)); + c10::SmallVector outputsize = {boxes.size(0)}; + at::Tensor output = at_npu::native::OpPreparation::ApplyTensor( + outputsize, boxes.options().dtype(at::kInt), boxes) + .fill_(-1); + OpCommand cmd; + cmd.Name("NonMaxSuppressionV3") + .Input(boxes) + .Input(scores) + .Input(max_outputsize_y) + .Input(iou_threshold_y) + .Input(scores_threshold_y) + .Output(output) + .Run(); + auto outputsizeBool = at::gt(output, -1); + auto outputsizeInt = outputsizeBool.to(at::ScalarType::Int); + auto countLen = at::sum(outputsizeInt, at::ScalarType::Int); + at::Tensor actual_output = output.slice(0, 0, countLen.item().toLong()); + actual_output = at_npu::native::NPUNativeFunctions::npu_dtype_cast( + actual_output, at::kLong); + return actual_output; +} + +Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); + +REGISTER_NPU_IMPL(nms_impl, nms_npu); diff --git a/mmcv/ops/masked_conv.py b/mmcv/ops/masked_conv.py index 6706ae9b2..e00a98b99 100644 --- a/mmcv/ops/masked_conv.py +++ b/mmcv/ops/masked_conv.py @@ -45,6 +45,22 @@ class MaskedConv2dFunction(Function): 'Stride could not only be 1 in masked_conv2d currently.') out_channel, in_channel, kernel_h, kernel_w = weight.size() + if features.device.type == 'npu': + import torch_npu + output = torch_npu.npu_conv2d( + features, + weight, + bias, + stride=(stride_h, stride_w), + padding=(pad_h, pad_w), + dilation=(1, 1), + groups=1) + if mask.size()[1:] != output.size()[2:]: + raise ValueError( + 'The mask is inconsistent with the shape of output_conv.') + output = output * mask + return output + batch_size = features.size(0) out_h = int( math.floor( diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index df5095f2e..7970d5323 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -34,6 +34,66 @@ class ModulatedDeformConv2dFunction(Function): groups_i=groups, deform_groups_i=deform_groups) + @staticmethod + def _calculate_sort_index(kernel_h, kernel_w, deformable_group): + split_num = deformable_group * 2 * kernel_h * kernel_w + sort_index = list(range(split_num)) + sort_index_fp = (sort_index[1::2] + sort_index[::2]) + sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index)} + sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] + sort_index_fp = torch.IntTensor(sort_index_fp) + sort_index_bp = torch.IntTensor(sort_index_bp) + sort_index_fp = sort_index_fp.npu() + sort_index_bp = sort_index_bp.npu() + return sort_index_fp, sort_index_bp + + @staticmethod + def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): + _, _, kernel_h, kernel_w = weight.shape + conv2d_bias = bias if len(bias) > 0 else None + sort_index_fp, sort_index_bp = \ + ModulatedDeformConv2dFunction._calculate_sort_index( + kernel_w, kernel_h, ctx.deform_groups) + select_offset = offset.index_select(1, sort_index_fp) + offset_all = torch.cat([select_offset, mask], dim=1) + output, offset_out = torch.npu_deformable_conv2d( + input_tensor, + weight, + offset_all, + conv2d_bias, + kernel_size=[kernel_w, kernel_h], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, + deformable_groups=ctx.deform_groups, + modulated=True) + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input_tensor.requires_grad: + ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, + sort_index_bp) + return output + + @staticmethod + def _npu_backward(ctx, grad_output): + input_tensor, weight, offset_out, offset_all, sort_index_bp = \ + ctx.saved_tensors + grad_input, grad_weight, grad_offset_all, grad_bias = \ + torch.npu_deformable_conv2dbk( + input_tensor, grad_output, offset_out, weight, offset_all, + kernel_size=[weight.shape[3], weight.shape[2]], + stride=[1, 1, ctx.stride[0], ctx.stride[1]], + padding=[1, 1, ctx.padding[0], ctx.padding[1]], + dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], + groups=ctx.groups, deformable_groups=ctx.deform_groups, + modulated=True) + grad_offset = grad_offset_all.index_select(1, sort_index_bp) + grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] + if not ctx.with_bias: + grad_bias = None + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, + None, None, None, None, None, None, None, None) + @staticmethod def forward(ctx, input: torch.Tensor, @@ -56,6 +116,7 @@ class ModulatedDeformConv2dFunction(Function): ctx.groups = groups ctx.deform_groups = deform_groups ctx.with_bias = bias is not None + ctx.device = input.device.type if not ctx.with_bias: bias = input.new_empty(0) # fake tensor # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; @@ -69,6 +130,10 @@ class ModulatedDeformConv2dFunction(Function): weight = weight.type_as(input) bias = bias.type_as(input) # type: ignore mask = mask.type_as(input) + if ctx.device == 'npu': + output = ModulatedDeformConv2dFunction._npu_forward( + ctx, input, offset, mask, weight, bias) + return output ctx.save_for_backward(input, offset, mask, weight, bias) output = input.new_empty( ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) @@ -98,6 +163,9 @@ class ModulatedDeformConv2dFunction(Function): @staticmethod @once_differentiable def backward(ctx, grad_output: torch.Tensor) -> tuple: + if ctx.device == 'npu': + return ModulatedDeformConv2dFunction._npu_backward( + ctx, grad_output) input, offset, mask, weight, bias = ctx.saved_tensors grad_input = torch.zeros_like(input) grad_offset = torch.zeros_like(offset) diff --git a/setup.cfg b/setup.cfg index 9609b41f4..6feae1f9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,4 +23,4 @@ default_section = THIRDPARTY # than "BA" [codespell] quiet-level = 3 -ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping +ignore-words-list = inout,hist,ba,inh,ro,tne,warmup,warpped,warpping,cann diff --git a/tests/test_ops/test_masked_conv2d.py b/tests/test_ops/test_masked_conv2d.py index a292f6a4f..072b2f7f6 100644 --- a/tests/test_ops/test_masked_conv2d.py +++ b/tests/test_ops/test_masked_conv2d.py @@ -3,7 +3,7 @@ import numpy as np import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class TestMaskedConv2d: @@ -16,7 +16,11 @@ class TestMaskedConv2d: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_masked_conv2d_all_close(self, device): from mmcv.ops import MaskedConv2d