diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 9784e459b..6967467c9 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -40,7 +40,7 @@ We implement common ops used in detection, segmentation, etc. | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | | PointsInPolygons | | √ | | | | -| PSAMask | √ | √ | √ | | | +| PSAMask | √ | √ | √ | | √ | | RotatedFeatureAlign | √ | √ | | | | | RoIPointPool3d | | √ | √ | | | | RoIPool | | √ | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 715b38a7f..cbfb39d3d 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -40,7 +40,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | | PointsInPolygons | | √ | | | | -| PSAMask | √ | √ | √ | | | +| PSAMask | √ | √ | √ | | √ | | RotatedFeatureAlign | √ | √ | | | | | RoIPointPool3d | | √ | √ | | | | RoIPool | | √ | √ | | | diff --git a/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp b/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp new file mode 100644 index 000000000..7fc0ad69b --- /dev/null +++ b/mmcv/ops/csrc/pytorch/npu/psa_mask_npu.cpp @@ -0,0 +1,95 @@ +#include "pytorch_npu_helper.hpp" + +using namespace NPU_NAME_SPACE; +using namespace std; + +void psamask_forward_npu(const int psa_type, + const Tensor x, + Tensor y, + const int num, + const int h_feature, + const int w_feature, + const int h_mask, + const int w_mask, + const int half_h_mask, + const int half_w_mask) { + int64_t psa_type_i64 = psa_type; + int64_t num_i64 = num; + int64_t h_feature_i64 = h_feature; + int64_t w_feature_i64 = w_feature; + int64_t h_mask_i64 = h_mask; + int64_t w_mask_i64 = w_mask; + int64_t half_h_mask_i64 = half_h_mask; + int64_t half_w_mask_i64 = half_w_mask; + OpCommand cmd; + cmd.Name("PSAMask") + .Input(x) + .Output(y) + .Attr("psa_type", psa_type_i64) + .Attr("num", num_i64) + .Attr("h_feature", h_feature_i64) + .Attr("w_feature", w_feature_i64) + .Attr("h_mask", h_mask_i64) + .Attr("w_mask", w_mask_i64) + .Attr("half_h_mask", half_h_mask_i64) + .Attr("half_w_mask", half_w_mask_i64) + .Run(); +} + +void psamask_forward_impl(const int psa_type, + const Tensor x, + Tensor y, + const int num, + const int h_feature, + const int w_feature, + const int h_mask, + const int w_mask, + const int half_h_mask, + const int half_w_mask); + +void psamask_backward_npu(const int psa_type, + const Tensor y_grad, + Tensor x_grad, + const int num, + const int h_feature, + const int w_feature, + const int h_mask, + const int w_mask, + const int half_h_mask, + const int half_w_mask) { + int64_t psa_type_i64 = psa_type; + int64_t num_i64 = num; + int64_t h_feature_i64 = h_feature; + int64_t w_feature_i64 = w_feature; + int64_t h_mask_i64 = h_mask; + int64_t w_mask_i64 = w_mask; + int64_t half_h_mask_i64 = half_h_mask; + int64_t half_w_mask_i64 = half_w_mask; + OpCommand cmd; + cmd.Name("PSAMaskGrad") + .Input(y_grad) + .Output(x_grad) + .Attr("psa_type", psa_type_i64) + .Attr("num", num_i64) + .Attr("h_feature", h_feature_i64) + .Attr("w_feature", w_feature_i64) + .Attr("h_mask", h_mask_i64) + .Attr("w_mask", w_mask_i64) + .Attr("half_h_mask", half_h_mask_i64) + .Attr("half_w_mask", half_w_mask_i64) + .Run(); +} + +void psamask_backward_impl(const int psa_type, + const Tensor y_grad, + Tensor x_grad, + const int num, + const int h_feature, + const int w_feature, + const int h_mask, + const int w_mask, + const int half_h_mask, + const int half_w_mask); + +REGISTER_NPU_IMPL(psamask_forward_impl, psamask_forward_npu); +REGISTER_NPU_IMPL(psamask_backward_impl, psamask_backward_npu); diff --git a/tests/test_ops/test_psa_mask.py b/tests/test_ops/test_psa_mask.py index 8c1f3101a..b0fd86e8f 100644 --- a/tests/test_ops/test_psa_mask.py +++ b/tests/test_ops/test_psa_mask.py @@ -4,7 +4,7 @@ import pytest import torch import torch.nn as nn -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class Loss(nn.Module): @@ -28,7 +28,11 @@ class TestPSAMask: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_psa_mask_collect(self, device): from mmcv.ops import PSAMask @@ -76,7 +80,11 @@ class TestPSAMask: pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'npu', + marks=pytest.mark.skipif( + not IS_NPU_AVAILABLE, reason='requires NPU support')) ]) def test_psa_mask_distribute(self, device): from mmcv.ops import PSAMask