[Feature] Add FusedBiasLeakyRelu npu adapter (#2474)

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* clean code

* fix autocast bugs on npu (#2273)

* code format

* code format

* code format

* bug fix

* pytorch_npu_helper.hpp clean code

* Npu dev (#2306)

* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather

* raise ImportError when compiling with npu

* add npu test case (#2307)

* add npu test case

* Update focal_loss.py

* add comment

* clean lint

* update dtype assert

* update DDP forward and comment

* fix bug

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* sigmoidfocalloss npu adapter bug fix

* BugFix: modify softmaxFocalLoss adapter

* BugFix: remove equal sign in the code

* add npu install information in README

* add modulatedDeformConv npu adapter

* init npu

* add npu extension and focal loss adapter

* clean code

* clean code

* clean code

* add modulatedDeformConv npu adapter

* merge master branch 20221103

* Add masked_conv2d operator in NPU

* add nms_npu

* fix bug

* fix code check

* fix code check

* fix code check

* Masked_conv2d NPU

* Masked_conv2d NPU

* Masked_conv2d NPU

* remove npu-install-info in README.md

* comment out the clang-format hook in .pre-commit-config-zh-cn.yaml

* Clean code: fix code-cleanliness issues in masked_conv2d and modulated_deform_conv

* Create fused_bias_leakyrelu_npu.cpp

Add NPU adapter for the fused_bias_leakyrelu operator (a usage sketch follows the commit list below)

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp

* Update ops.md

* Update ops.md

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp

* Update test_fused_bias_leakyrelu.py

* Update fused_bias_leakyrelu.py

* Update test_fused_bias_leakyrelu.py

* Update fused_bias_leakyrelu.py

* Update test_fused_bias_leakyrelu.py

* Update ops.md

* amend for CI

* bugfix

* amend ops.md

* Update test_fused_bias_leakyrelu.py

* clean code

* bugfix

* clean code

* Update fused_bias_leakyrelu_npu.cpp

* Update fused_bias_leakyrelu_npu.cpp
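
As referenced above, a minimal usage sketch of the new NPU code path (not part of the commit history), assuming an Ascend environment where the torch_npu adapter is installed and an NPU device is visible:

import torch
import torch_npu  # noqa: F401  # assumed Ascend adapter; registers the 'npu' device type
from mmcv.ops import fused_bias_leakyrelu

x = torch.randn(2, 2, 2, 2).npu()  # feature map on the NPU
bias = torch.zeros(2).npu()        # one bias value per channel
out = fused_bias_leakyrelu(x, bias, negative_slope=0.2, scale=2**0.5)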

Co-authored-by: wangjiangben <wangjiangben_hw@126.com>
Co-authored-by: ckirchhoff <515629648@qq.com>
Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zcc-zjut <zcczxy2019@163.com>
Co-authored-by: wangxiaoxin_sherie <wangxiaoxin7@huawei.com>
Co-authored-by: momo609 <963372609@qq.com>
jayggh committed 2022-12-13 11:13:46 +08:00 (via GitHub)
commit 4c51afce2a, parent 4c4ac6b43f, branch pull/2477/head
5 changed files with 98 additions and 19 deletions

ops.md (English op list)

@@ -24,7 +24,7 @@ We implement common ops used in detection, segmentation, etc.
 | DynamicScatter              |     |  √   |     |     |     |
 | FurthestPointSample         |     |  √   |     |     |     |
 | FurthestPointSampleWithDist |     |  √   |     |     |     |
-| FusedBiasLeakyrelu          |     |  √   |     |     |     |
+| FusedBiasLeakyrelu          |     |  √   |     |     |  √  |
 | GatherPoints                |     |  √   |     |     |     |
 | GroupPoints                 |     |  √   |     |     |     |
 | Iou3d                       |     |  √   |  √  |     |     |

ops.md (Chinese op list)

@@ -24,7 +24,7 @@ MMCV 提供了检测、分割等任务中常用的算子
 | DynamicScatter              |     |  √   |     |     |     |
 | FurthestPointSample         |     |  √   |     |     |     |
 | FurthestPointSampleWithDist |     |  √   |     |     |     |
-| FusedBiasLeakyrelu          |     |  √   |     |     |     |
+| FusedBiasLeakyrelu          |     |  √   |     |     |  √  |
 | GatherPoints                |     |  √   |     |     |     |
 | GroupPoints                 |     |  √   |     |     |     |
 | Iou3d                       |     |  √   |  √  |     |     |

fused_bias_leakyrelu_npu.cpp (new file)

@@ -0,0 +1,54 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias,
                                    const Tensor &refer, int act, int grad,
                                    float alpha, float scale);

Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
                                const Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  at::Tensor py = at::empty_like(input);
  // forward
  if (grad == 0) {
    auto input_size = input.sizes();
    int input_length = input_size.size();
    c10::SmallVector<int64_t, SIZE> input_size_tmp;
    input_size_tmp = array_to_small_vector(input_size);
    // Collapse every dimension except the channel dimension (dim 1) to 1,
    // e.g. (N, C, H, W) -> (1, C, 1, 1), so the bias can be broadcast.
    if (input_length > 1) {
      for (int i = 0; i < input_length; i++) {
        if (i != 1) {
          input_size_tmp[i] = 1;
        }
      }
    }
    at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
    at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
        bias_tmp, input.sizes());
    OpCommand cmd;
    cmd.Name("FusedBiasLeakyRelu")
        .Input(input)
        .Input(bias_)
        .Output(py)
        .Attr("scale", scale)
        .Attr("negative_slope", alpha)
        .Run();
  }

  // backward
  if (grad == 1) {
    OpCommand cmd;
    cmd.Name("FusedBiasLeakyReluGrad")
        .Input(input)
        .Input(refer)
        .Output(py)
        .Attr("scale", scale)
        .Attr("negative_slope", alpha)
        .Run();
  }
  return py;
}

REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu);
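
For reference, the forward semantics this adapter must reproduce match the pure-PyTorch fallback in fused_bias_leakyrelu.py: add a per-channel bias, apply a leaky ReLU, then rescale. A short sketch of those semantics (an illustration, not the NPU code path), assuming a channels-first input:

import torch
import torch.nn.functional as F

def bias_leakyrelu_sketch(x: torch.Tensor, bias: torch.Tensor,
                          negative_slope: float = 0.2,
                          scale: float = 2**0.5) -> torch.Tensor:
    # Reshape bias to (1, C, 1, ..., 1) so it broadcasts over every
    # non-channel dimension -- the same trick the C++ adapter performs
    # with array_to_small_vector + npu_broadcast.
    shape = [1] * x.ndim
    shape[1] = -1
    return F.leaky_relu(x + bias.reshape(shape), negative_slope) * scale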

fused_bias_leakyrelu.py

@@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor,
         torch.Tensor: Feature map after non-linear activation.
     """
-    if not input.is_cuda:
+    if not input.is_cuda and input.device.type != 'npu':
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
 
     return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
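
The one-line change above routes NPU tensors into FusedBiasLeakyReLUFunction instead of the CPU fallback. It keys off torch.Tensor.device.type, which reports 'npu' for Ascend tensors once torch_npu is loaded; a hypothetical session illustrating the dispatch:

import torch

x = torch.randn(2, 2, 2, 2)
print(x.device.type)  # 'cpu'  -> bias_leakyrelu_ref fallback
# x.cuda().device.type == 'cuda' -> CUDA kernel (requires a GPU)
# x.npu().device.type == 'npu'   -> this NPU adapter (requires torch_npu)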

test_fused_bias_leakyrelu.py

@@ -2,6 +2,8 @@
 import pytest
 import torch
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+
 _USING_PARROTS = True
 try:
     from parrots.autograd import gradcheck
@@ -14,36 +16,59 @@ class TestFusedBiasLeakyReLU:
 
     @classmethod
     def setup_class(cls):
-        if not torch.cuda.is_available():
+        if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE:
             return
-        cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
-        cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        if IS_CUDA_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).cuda()
+            cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        elif IS_NPU_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).npu()
+            cls.bias = torch.zeros(2, requires_grad=True).npu()
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
-    def test_gradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradient(self, device):
 
         from mmcv.ops import FusedBiasLeakyReLU
         if _USING_PARROTS:
-            gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
-                self.input_tensor,
-                delta=1e-4,
-                pt_atol=1e-3)
+            if IS_CUDA_AVAILABLE:
+                gradcheck(
+                    FusedBiasLeakyReLU(2).cuda(),
+                    self.input_tensor,
+                    delta=1e-4,
+                    pt_atol=1e-3)
         else:
             gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
+                FusedBiasLeakyReLU(2).to(device),
                 self.input_tensor,
                 eps=1e-4,
                 atol=1e-3)
 
-    @pytest.mark.skipif(
-        not torch.cuda.is_available() or _USING_PARROTS,
-        reason='requires cuda')
-    def test_gradgradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradgradient(self, device):
 
         from mmcv.ops import FusedBiasLeakyReLU
         gradgradcheck(
-            FusedBiasLeakyReLU(2).cuda(),
+            FusedBiasLeakyReLU(2).to(device),
             self.input_tensor,
             eps=1e-4,
             atol=1e-3)
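
The IS_CUDA_AVAILABLE / IS_NPU_AVAILABLE flags used by the parametrized marks come from mmcv.utils and can gate device selection the same way outside the test suite; a small sketch:

import torch
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE

if IS_CUDA_AVAILABLE:
    device = 'cuda'
elif IS_NPU_AVAILABLE:
    device = 'npu'  # requires torch_npu to be loaded
else:
    device = 'cpu'
x = torch.randn(2, 2, 2, 2, device=device)  # tensor on the best available device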