From 4c51afce2a3e26de31db10174906a2f6db5cc9e0 Mon Sep 17 00:00:00 2001
From: jayggh <35617559+jayggh@users.noreply.github.com>
Date: Tue, 13 Dec 2022 11:13:46 +0800
Subject: [PATCH] [Feature] Add FusedBiasLeakyRelu npu adapter (#2474)

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* clean code

* fix autocast bugs on npu (#2273)

fix autocast bugs on npu (#2273)

* code format

* code format

* code format

* bug fix

* pytorch_npu_helper.hpp clean code

* Npu dev (#2306)

* fix autocast bugs on npu

* using scatter_kwargs in mmcv.device.scatter_gather

* raise ImportError when compile with npu

* add npu test case (#2307)

* add npu test case

* Update focal_loss.py

* add comment

* clean lint

* update dtype assert

* update DDP forward and comment

* fix bug

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* sigmoidfocalloss npu adapter bug fix

* BugFix: modify softmaxFocalLoss adapter

* BugFix: remove equal sign in the code

* add npu install information in README

* add modulatedDeformConv npu adapter

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* add modulatedDeformConv npu adapter

* merge master branch 20221103

* Add masked_ Conv2d operator in NPU

* add nms_npu

* fix bug

* fix code check

* fix code check

* fix code check

* Masked_conv2d NPU

* Masked_conv2d NPU

* Masked_conv2d NPU

* remove npu-install-info in README.md

* annotate the clang-format in pre-commit-config-zh-ch.yaml

* Clean code: fix the clean code problem in masked_conv2d and modulated_deform_conv

* Create fused_bias_leakyrelu_npu.cpp

Add NPU adapter for fused_bias_leaky_relu operator

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp

* Update ops.md

* Update ops.md

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp

* Update test_fused_bias_leakyrelu.py

* Update fused_bias_leakyrelu.py

* Update test_fused_bias_leakyrelu.py

* Update fused_bias_leakyrelu.py

* Update test_fused_bias_leakyrelu.py

* Update ops.md

* amend for CI

* bugfix

* amend ops.md

* Update test_fused_bias_leakyrelu.py

* clean code

* bugfix

* clean code

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp

Co-authored-by: wangjiangben
Co-authored-by: ckirchhoff <515629648@qq.com>
Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zcc-zjut
Co-authored-by: wangxiaoxin_sherie
Co-authored-by: momo609 <963372609@qq.com>
---
 docs/en/understand_mmcv/ops.md               |  2 +-
 docs/zh_cn/understand_mmcv/ops.md            |  2 +-
 .../pytorch/npu/fused_bias_leakyrelu_npu.cpp | 54 ++++++++++++++++++
 mmcv/ops/fused_bias_leakyrelu.py             |  2 +-
 tests/test_ops/test_fused_bias_leakyrelu.py  | 57 +++++++++++++------
 5 files changed, 98 insertions(+), 19 deletions(-)
 create mode 100644 mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp

diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md
index 058309e25..cfc70e773 100644
--- a/docs/en/understand_mmcv/ops.md
+++ b/docs/en/understand_mmcv/ops.md
@@ -24,7 +24,7 @@ We implement common ops used in detection, segmentation, etc.
 | DynamicScatter | | √ | | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | |
 | GroupPoints | | √ | | | |
 | Iou3d | | √ | √ | | |
diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md
index 51cc3eec4..7fbba8768 100644
--- a/docs/zh_cn/understand_mmcv/ops.md
+++ b/docs/zh_cn/understand_mmcv/ops.md
@@ -24,7 +24,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | DynamicScatter | | √ | | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | |
 | GroupPoints | | √ | | | |
 | Iou3d | | √ | √ | | |
diff --git a/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
new file mode 100644
index 000000000..cd052b586
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/npu/fused_bias_leakyrelu_npu.cpp
@@ -0,0 +1,54 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias,
+                                    const Tensor &refer, int act, int grad,
+                                    float alpha, float scale);
+
+Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
+                                const Tensor &refer, int act, int grad,
+                                float alpha, float scale) {
+  at::Tensor py = at::empty_like(input);
+  // forward
+  if (grad == 0) {
+    auto input_size = input.sizes();
+    int input_length = input_size.size();
+    c10::SmallVector<int64_t, SIZE> input_size_tmp;
+    input_size_tmp = array_to_small_vector(input_size);
+    if (input_length > 1) {
+      for (int i = 0; i < input_length; i++) {
+        if (i != 1) {
+          input_size_tmp[i] = 1;
+        }
+      }
+    }
+    at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
+    at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
+        bias_tmp, input.sizes());
+    OpCommand cmd;
+    cmd.Name("FusedBiasLeakyRelu")
+        .Input(input)
+        .Input(bias_)
+        .Output(py)
+        .Attr("scale", scale)
+        .Attr("negative_slope", alpha)
+        .Run();
+  }
+
+  // backward
+  if (grad == 1) {
+    OpCommand cmd;
+    cmd.Name("FusedBiasLeakyReluGrad")
+        .Input(input)
+        .Input(refer)
+        .Output(py)
+        .Attr("scale", scale)
+        .Attr("negative_slope", alpha)
+        .Run();
+  }
+  return py;
+}
+
+REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu);
diff --git a/mmcv/ops/fused_bias_leakyrelu.py b/mmcv/ops/fused_bias_leakyrelu.py
index e23617fb3..fe17d2db7 100644
--- a/mmcv/ops/fused_bias_leakyrelu.py
+++ b/mmcv/ops/fused_bias_leakyrelu.py
@@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor,
         torch.Tensor: Feature map after non-linear activation.
 
""" - if not input.is_cuda: + if not input.is_cuda and input.device.type != 'npu': return bias_leakyrelu_ref(input, bias, negative_slope, scale) return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype), diff --git a/tests/test_ops/test_fused_bias_leakyrelu.py b/tests/test_ops/test_fused_bias_leakyrelu.py index 47357860d..e6f6fb9f7 100644 --- a/tests/test_ops/test_fused_bias_leakyrelu.py +++ b/tests/test_ops/test_fused_bias_leakyrelu.py @@ -2,6 +2,8 @@ import pytest import torch +from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE + _USING_PARROTS = True try: from parrots.autograd import gradcheck @@ -14,36 +16,59 @@ class TestFusedBiasLeakyReLU: @classmethod def setup_class(cls): - if not torch.cuda.is_available(): + if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE: return - cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda() - cls.bias = torch.zeros(2, requires_grad=True).cuda() + if IS_CUDA_AVAILABLE: + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).cuda() + cls.bias = torch.zeros(2, requires_grad=True).cuda() + elif IS_NPU_AVAILABLE: + cls.input_tensor = torch.randn((2, 2, 2, 2), + requires_grad=True).npu() + cls.bias = torch.zeros(2, requires_grad=True).npu() - @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') - def test_gradient(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) + def test_gradient(self, device): from mmcv.ops import FusedBiasLeakyReLU if _USING_PARROTS: - gradcheck( - FusedBiasLeakyReLU(2).cuda(), - self.input_tensor, - delta=1e-4, - pt_atol=1e-3) + if IS_CUDA_AVAILABLE: + gradcheck( + FusedBiasLeakyReLU(2).cuda(), + self.input_tensor, + delta=1e-4, + pt_atol=1e-3) else: gradcheck( - FusedBiasLeakyReLU(2).cuda(), + FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, atol=1e-3) - @pytest.mark.skipif( - not torch.cuda.is_available() or _USING_PARROTS, - reason='requires cuda') - def test_gradgradient(self): + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) + ]) + def test_gradgradient(self, device): from mmcv.ops import FusedBiasLeakyReLU gradgradcheck( - FusedBiasLeakyReLU(2).cuda(), + FusedBiasLeakyReLU(2).to(device), self.input_tensor, eps=1e-4, atol=1e-3)