mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add FusedBiasLeakyRelu npu adapter (#2474)
* init npu
* add npu extension and focal loss adapter
* clean code
* clean code
* clean code
* clean code
* fix autocast bugs on npu (#2273)
* code format
* code format
* code format
* bug fix
* pytorch_npu_helper.hpp clean code
* Npu dev (#2306)
* fix autocast bugs on npu
* using scatter_kwargs in mmcv.device.scatter_gather
* raise ImportError when compile with npu
* add npu test case (#2307)
* add npu test case
* Update focal_loss.py
* add comment
* clean lint
* update dtype assert
* update DDP forward and comment
* fix bug

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* sigmoidfocalloss npu adapter bug fix
* BugFix: modify softmaxFocalLoss adapter
* BugFix: remove equal sign in the code
* add npu install information in README
* add modulatedDeformConv npu adapter
* init npu
* add npu extension and focal loss adapter
* clean code
* clean code
* clean code
* add modulatedDeformConv npu adapter
* merge master branch 20221103
* Add masked_conv2d operator in NPU
* add nms_npu
* fix bug
* fix code check
* fix code check
* fix code check
* Masked_conv2d NPU
* Masked_conv2d NPU
* Masked_conv2d NPU
* remove npu-install-info in README.md
* annotate the clang-format in pre-commit-config-zh-ch.yaml
* Clean code: fix the clean code problem in masked_conv2d and modulated_deform_conv
* Create fused_bias_leakyrelu_npu.cpp: add NPU adapter for fused_bias_leaky_relu operator
* Update fused_bias_leakyrelu_npu.cpp
* Update fused_bias_leakyrelu_npu.cpp
* Update ops.md
* Update ops.md
* Update fused_bias_leakyrelu_npu.cpp
* Update fused_bias_leakyrelu_npu.cpp
* Update test_fused_bias_leakyrelu.py
* Update fused_bias_leakyrelu.py
* Update test_fused_bias_leakyrelu.py
* Update fused_bias_leakyrelu.py
* Update test_fused_bias_leakyrelu.py
* Update ops.md
* amend for CI
* bugfix
* amend ops.md
* Update test_fused_bias_leakyrelu.py
* clean code
* bugfix
* clean code
* Update fused_bias_leakyrelu_npu.cpp
* Update fused_bias_leakyrelu_npu.cpp

Co-authored-by: wangjiangben <wangjiangben_hw@126.com>
Co-authored-by: ckirchhoff <515629648@qq.com>
Co-authored-by: wangjiangben-hw <111729245+wangjiangben-hw@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zcc-zjut <zcczxy2019@163.com>
Co-authored-by: wangxiaoxin_sherie <wangxiaoxin7@huawei.com>
Co-authored-by: momo609 <963372609@qq.com>
parent 4c4ac6b43f
commit 4c51afce2a
ops.md (en):

@@ -24,7 +24,7 @@ We implement common ops used in detection, segmentation, etc.
 | DynamicScatter | | √ | | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | |
 | GroupPoints | | √ | | | |
 | Iou3d | | √ | √ | | |
ops.md (zh_cn):

@@ -24,7 +24,7 @@ MMCV provides common operators used in detection, segmentation, and other tasks
 | DynamicScatter | | √ | | | |
 | FurthestPointSample | | √ | | | |
 | FurthestPointSampleWithDist | | √ | | | |
-| FusedBiasLeakyrelu | | √ | | | |
+| FusedBiasLeakyrelu | | √ | | | √ |
 | GatherPoints | | √ | | | |
 | GroupPoints | | √ | | | |
 | Iou3d | | √ | √ | | |
fused_bias_leakyrelu_npu.cpp (new file):

@@ -0,0 +1,54 @@
+#include "pytorch_npu_helper.hpp"
+
+using namespace NPU_NAME_SPACE;
+using namespace std;
+
+Tensor fused_bias_leakyrelu_op_impl(const Tensor &input, const Tensor &bias,
+                                    const Tensor &refer, int act, int grad,
+                                    float alpha, float scale);
+
+Tensor fused_bias_leakyrelu_npu(const Tensor &input, const Tensor &bias,
+                                const Tensor &refer, int act, int grad,
+                                float alpha, float scale) {
+  at::Tensor py = at::empty_like(input);
+  // forward
+  if (grad == 0) {
+    auto input_size = input.sizes();
+    int input_length = input_size.size();
+    c10::SmallVector<int64_t, SIZE> input_size_tmp;
+    input_size_tmp = array_to_small_vector(input_size);
+    if (input_length > 1) {
+      for (int i = 0; i < input_length; i++) {
+        if (i != 1) {
+          input_size_tmp[i] = 1;
+        }
+      }
+    }
+    at::Tensor bias_tmp = at::reshape(bias, input_size_tmp);
+    at::Tensor bias_ = at_npu::native::NPUNativeFunctions::npu_broadcast(
+        bias_tmp, input.sizes());
+    OpCommand cmd;
+    cmd.Name("FusedBiasLeakyRelu")
+        .Input(input)
+        .Input(bias_)
+        .Output(py)
+        .Attr("scale", scale)
+        .Attr("negative_slope", alpha)
+        .Run();
+  }
+
+  // backward
+  if (grad == 1) {
+    OpCommand cmd;
+    cmd.Name("FusedBiasLeakyReluGrad")
+        .Input(input)
+        .Input(refer)
+        .Output(py)
+        .Attr("scale", scale)
+        .Attr("negative_slope", alpha)
+        .Run();
+  }
+  return py;
+}
+
+REGISTER_NPU_IMPL(fused_bias_leakyrelu_op_impl, fused_bias_leakyrelu_npu);
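For intuition: in the forward branch, the adapter reshapes bias so that only dim 1 (the channel dimension) keeps its size, then broadcasts it to the full input shape before handing both tensors to the FusedBiasLeakyRelu NPU operator. Below is a minimal PyTorch sketch of that same preprocessing, for illustration only; the helper name is invented and expand_as stands in for npu_broadcast:

import torch

def broadcast_bias_like(input: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Mirror of the C++ loop that sets input_size_tmp[i] = 1 for all i != 1
    # (only applied when the input has more than one dimension).
    shape = list(input.shape)
    if len(shape) > 1:
        shape = [s if i == 1 else 1 for i, s in enumerate(shape)]
    bias_tmp = bias.reshape(shape)
    # npu_broadcast is approximated here by expand_as.
    return bias_tmp.expand_as(input)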
fused_bias_leakyrelu.py:

@@ -258,7 +258,7 @@ def fused_bias_leakyrelu(input: torch.Tensor,
         torch.Tensor: Feature map after non-linear activation.
     """
 
-    if not input.is_cuda:
+    if not input.is_cuda and input.device.type != 'npu':
         return bias_leakyrelu_ref(input, bias, negative_slope, scale)
 
     return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
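With this dispatch change, an 'npu' tensor no longer falls back to the pure-PyTorch bias_leakyrelu_ref but reaches FusedBiasLeakyReLUFunction, which routes to the adapter registered via REGISTER_NPU_IMPL above. A hedged usage sketch, assuming a working Ascend environment with the torch_npu package installed:

import torch
import torch_npu  # assumed available; provides the .npu() tensor method
from mmcv.ops import fused_bias_leakyrelu

x = torch.randn(2, 2, 2, 2).npu()
bias = torch.zeros(2).npu()
# Defaults shown explicitly; they match the CUDA path.
out = fused_bias_leakyrelu(x, bias, negative_slope=0.2, scale=2**0.5)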
test_fused_bias_leakyrelu.py:

@@ -2,6 +2,8 @@
 import pytest
 import torch
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
+
 _USING_PARROTS = True
 try:
     from parrots.autograd import gradcheck
@@ -14,36 +16,59 @@ class TestFusedBiasLeakyReLU:
 
     @classmethod
     def setup_class(cls):
-        if not torch.cuda.is_available():
+        if not IS_CUDA_AVAILABLE and not IS_NPU_AVAILABLE:
             return
-        cls.input_tensor = torch.randn((2, 2, 2, 2), requires_grad=True).cuda()
-        cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        if IS_CUDA_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).cuda()
+            cls.bias = torch.zeros(2, requires_grad=True).cuda()
+        elif IS_NPU_AVAILABLE:
+            cls.input_tensor = torch.randn((2, 2, 2, 2),
+                                           requires_grad=True).npu()
+            cls.bias = torch.zeros(2, requires_grad=True).npu()
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
-    def test_gradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradient(self, device):
 
         from mmcv.ops import FusedBiasLeakyReLU
         if _USING_PARROTS:
-            gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
-                self.input_tensor,
-                delta=1e-4,
-                pt_atol=1e-3)
+            if IS_CUDA_AVAILABLE:
+                gradcheck(
+                    FusedBiasLeakyReLU(2).cuda(),
+                    self.input_tensor,
+                    delta=1e-4,
+                    pt_atol=1e-3)
         else:
             gradcheck(
-                FusedBiasLeakyReLU(2).cuda(),
+                FusedBiasLeakyReLU(2).to(device),
                 self.input_tensor,
                 eps=1e-4,
                 atol=1e-3)
 
-    @pytest.mark.skipif(
-        not torch.cuda.is_available() or _USING_PARROTS,
-        reason='requires cuda')
-    def test_gradgradient(self):
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'npu',
+            marks=pytest.mark.skipif(
+                not IS_NPU_AVAILABLE, reason='requires NPU support'))
+    ])
+    def test_gradgradient(self, device):
 
         from mmcv.ops import FusedBiasLeakyReLU
         gradgradcheck(
-            FusedBiasLeakyReLU(2).cuda(),
+            FusedBiasLeakyReLU(2).to(device),
             self.input_tensor,
             eps=1e-4,
             atol=1e-3)
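For context, the gradcheck/gradgradcheck calls above validate analytic first and second derivatives against finite differences. A self-contained sketch of the same kind of check run against a pure-PyTorch reference; the body of bias_leakyrelu_ref here is a plausible reimplementation for illustration, not copied from mmcv:

import torch

def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
    # Add a per-channel bias, apply leaky ReLU, then rescale.
    shape = [1, -1] + [1] * (x.dim() - 2)
    return torch.nn.functional.leaky_relu(
        x + bias.reshape(shape), negative_slope) * scale

# Double precision keeps finite-difference noise below the tolerances.
x = torch.randn(2, 2, 2, 2, dtype=torch.double, requires_grad=True)
b = torch.zeros(2, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(bias_leakyrelu_ref, (x, b), eps=1e-6, atol=1e-4)
assert torch.autograd.gradgradcheck(bias_leakyrelu_ref, (x, b), eps=1e-6, atol=1e-4)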