mmcv/tests/test_ops/test_spconv.py
bdf 733e6ff84e
Pick MLU modifications from master (1.x) to main (2.x) (#2704)
* [Feature] Support Voxelization with cambricon MLU device (#2500)

* [Feature] Support hard_voxelize with cambricon MLU backend

* [Feature](bangc-ops): add voxelization op

* [Enhance] Optimize the performance of ms_deform_attn for MLU device (#2510)

* ms_opt

* [Feature] ms_deform_attn performance optimization

* [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520)

* [Feature] Support ball_query with cambricon MLU backend and mlu-ops library.

* [Fix] update operator data layout setting.

* [Fix] add cxx compile option to avoid symbol conflict.

* [Fix] fix lint errors.

* [Fix] update ops.md with info of ball_query support by MLU backend.

* [Feature] Fix typo.

* [Fix] Remove print.

* [Fix] get mlu-ops from MMCV_MLU_OPS_PATH env.

* [Fix] update MMCV_MLU_OPS_PATH check logic.

* [Fix] update error info when failed to download mlu-ops.

* [Fix] check mlu-ops version matching info in mmcv.

* [Fix] revise wrong filename.

* [Fix] remove f.close and re.

* [Docs] Steps to compile mmcv-full on MLU machine (#2571)

* [Docs] Steps to compile mmcv-full on MLU machine

* [Docs] Adjust paragraph order

* Update docs/zh_cn/get_started/build.md

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update docs/en/get_started/build.md

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Docs] Modify the format

---------

Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Fix] Fix tensor descriptor setting in MLU ball_query. (#2579)

* [Feature] Add MLU support for Sparse Convolution op (#2589)

* [Feature] Add sparse convolution MLU API

* [Feature] update cpp code style

* end-of-file

* delete libext.a

* code style

* update ops.md

---------

Co-authored-by: budefei <budefei@cambricon.com>

* [Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598)

* [Feature] Replace the implementation of deform_roi_pool with mlu-ops

* [Feature] Modify code

---------

Co-authored-by: budefei <budefei@cambricon.com>

* [Enhancement] ms_deform_attn performance optimization (#2616)

* ms_opt_v2

* ms_opt_v2_1

* optimize MultiScaleDeformableAttention ops for MLU

* [Feature] ms_deform_attn performance optimization V2

---------

Co-authored-by: dongchengwei <dongchengwei@cambricon.com>

* [Feature] Support NmsRotated with cambricon MLU backend (#2643)

* [Feature] Support NmsRotated with cambricon MLU backend

* [Feature] remove foolproofs in nms_rotated_mlu.cpp

* [Feature] fix lint in test_nms_rotated.py

* [Feature] fix kMLU not found in nms_rotated.cpp

* [Feature] modify mlu support in nms.py

* [Feature] modify nms_rotated support in ops.md

* [Feature] modify ops/nms.py

* [Enhance] Add a default value for MMCV_MLU_ARGS (#2688)

* add mlu_args

* Modify the code

---------

Co-authored-by: budefei <budefei@cambricon.com>

* [Enhance] Ignore mlu-ops files (#2691)

Co-authored-by: budefei <budefei@cambricon.com>

---------

Co-authored-by: ZShaopeng <108382403+ZShaopeng@users.noreply.github.com>
Co-authored-by: BinZheng <38182684+Wickyzheng@users.noreply.github.com>
Co-authored-by: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com>
Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: duzekun <108381389+duzekunKTH@users.noreply.github.com>
Co-authored-by: dongchengwei <dongchengwei@cambricon.com>
Co-authored-by: liuyuan1-v <125547457+liuyuan1-v@users.noreply.github.com>
2023-04-19 10:42:07 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch import nn

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
                      SubMConv3d)

if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)

from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


def make_sparse_convmodule(in_channels,
                           out_channels,
                           kernel_size,
                           indice_key,
                           stride=1,
                           padding=0,
                           conv_type='SubMConv3d',
                           norm_cfg=None,
                           order=('conv', 'norm', 'act')):
    """Make sparse convolution module.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of out channels
        kernel_size (int|tuple(int)): kernel size of convolution
        indice_key (str): the indice key used for sparse tensor
        stride (int|tuple(int)): the stride of convolution
        padding (int or list[int]): the padding number of input
        conv_type (str): sparse conv type in spconv
        norm_cfg (dict[str]): config of normalization layer
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").

    Returns:
        spconv.SparseSequential: sparse convolution module.
    """
    assert isinstance(order, tuple) and len(order) <= 3
    assert set(order) | {'conv', 'norm', 'act'} == {'conv', 'norm', 'act'}

    conv_cfg = dict(type=conv_type, indice_key=indice_key)

    layers = list()
    for layer in order:
        if layer == 'conv':
            if conv_type not in [
                    'SparseInverseConv3d', 'SparseInverseConv2d',
                    'SparseInverseConv1d'
            ]:
                layers.append(
                    build_conv_layer(
                        conv_cfg,
                        in_channels,
                        out_channels,
                        kernel_size,
                        stride=stride,
                        padding=padding,
                        bias=False))
            else:
                layers.append(
                    build_conv_layer(
                        conv_cfg,
                        in_channels,
                        out_channels,
                        kernel_size,
                        bias=False))
        elif layer == 'norm':
            layers.append(build_norm_layer(norm_cfg, out_channels)[1])
        elif layer == 'act':
            layers.append(nn.ReLU(inplace=True))

    layers = SparseSequential(*layers)
    return layers


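# The parametrized test below exercises make_sparse_convmodule on whichever
# of CUDA or MLU is available: it checks the layer order and the BatchNorm1d
# eps/momentum settings, runs a forward pass on a small 4-voxel
# SparseConvTensor, and skips the SparseInverseConv3d variant on MLU, which
# does not support it yet.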
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_make_sparse_convmodule(device):
    torch.cuda.empty_cache()
    voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                   [6.8162713, -2.480431, -1.3616394, 0.36],
                                   [11.643568, -4.744306, -1.3580885, 0.16],
                                   [23.482342, 6.5036807, 0.5806964, 0.35]],
                                  dtype=torch.float32,
                                  device=device)  # n, point_features
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32,
        device=device)  # n, 4(batch, ind_x, ind_y, ind_z)

    # test
    input_sp_tensor = SparseConvTensor(voxel_features, coordinates,
                                       [41, 1600, 1408], 2)

    sparse_block0 = make_sparse_convmodule(
        4,
        16,
        3,
        'test0',
        stride=1,
        padding=0,
        conv_type='SubMConv3d',
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        order=('conv', 'norm', 'act')).to(device)
    assert isinstance(sparse_block0[0], SubMConv3d)
    assert sparse_block0[0].in_channels == 4
    assert sparse_block0[0].out_channels == 16
    assert isinstance(sparse_block0[1], torch.nn.BatchNorm1d)
    assert sparse_block0[1].eps == 0.001
    assert sparse_block0[1].momentum == 0.01
    assert isinstance(sparse_block0[2], torch.nn.ReLU)

    # test forward
    out_features = sparse_block0(input_sp_tensor)
    assert out_features.features.shape == torch.Size([4, 16])

    # device == mlu: not support inverse==1 yet
    if device != 'mlu':
        sparse_block1 = make_sparse_convmodule(
            4,
            16,
            3,
            'test1',
            stride=1,
            padding=0,
            conv_type='SparseInverseConv3d',
            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
            order=('norm', 'act', 'conv')).to(device)
        assert isinstance(sparse_block1[2], SparseInverseConv3d)
        assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
        assert isinstance(sparse_block1[1], torch.nn.ReLU)
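

# A minimal convenience sketch (an editorial addition, not from the upstream
# file; assumes pytest is installed): running this module directly invokes
# the parametrized test above. On a host without CUDA or MLU, both device
# cases are skipped by the skipif markers.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])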