mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* [Feature] Support Voxelization with cambricon MLU device (#2500) * [Feature] Support hard_voxelize with cambricon MLU backend * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Enhance] Optimize the performance of ms_deform_attn for MLU device (#2510) * ms_opt * ms_opt * ms_opt * ms_opt * ms_opt * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520) * [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. * [Fix] update operator data layout setting. * [Fix] add cxx compile option to avoid symbol conflict. * [Fix] fix lint errors. * [Fix] update ops.md with info of ball_query support by MLU backend. * [Feature] Fix typo. * [Fix] Remove print. * [Fix] get mlu-ops from MMCV_MLU_OPS_PATH env. * [Fix] update MMCV_MLU_OPS_PATH check logic. * [Fix] update error info when failed to download mlu-ops. * [Fix] check mlu-ops version matching info in mmcv. * [Fix] revise wrong filename. * [Fix] remove f.close and re. 
* [Docs] Steps to compile mmcv-full on MLU machine (#2571) * [Docs] Steps to compile mmcv-full on MLU machine * [Docs] Adjust paragraph order * Update docs/zh_cn/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * [Docs] Modify the format --------- Co-authored-by: budefei <budefei@cambricon.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * [Fix] Fix tensor descriptor setting in MLU ball_query. (#2579) * [Feature] Add MLU support for Sparse Convolution op (#2589) * [Feature] Add sparse convolution MLU API * [Feature] update cpp code style * end-of-file * delete libext.a * code style * update ops.md --------- Co-authored-by: budefei <budefei@cambricon.com> * [Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598) * [Feature] Replace the implementation of deform_roi_pool with mlu-ops * [Feature] Modify code --------- Co-authored-by: budefei <budefei@cambricon.com> * [Enhancement] ms_deform_attn performance optimization (#2616) * ms_opt_v2 * ms_opt_v2_1 * optimize MultiScaleDeformableAttention ops for MLU * ms_opt_v2_1 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 --------- Co-authored-by: dongchengwei <dongchengwei@cambricon.com> * [Feature] Support NmsRotated with cambricon MLU backend (#2643) * [Feature] 
Support NmsRotated with cambricon MLU backend * [Feature] remove foolproofs in nms_rotated_mlu.cpp * [Feature] fix lint in test_nms_rotated.py * [Feature] fix kMLU not found in nms_rotated.cpp * [Feature] modify mlu support in nms.py * [Feature] modify nms_rotated support in ops.md * [Feature] modify ops/nms.py * [Enhance] Add a default value for MMCV_MLU_ARGS (#2688) * add mlu_args * add mlu_args * Modify the code --------- Co-authored-by: budefei <budefei@cambricon.com> * [Enhance] Ignore mlu-ops files (#2691) Co-authored-by: budefei <budefei@cambricon.com> --------- Co-authored-by: ZShaopeng <108382403+ZShaopeng@users.noreply.github.com> Co-authored-by: BinZheng <38182684+Wickyzheng@users.noreply.github.com> Co-authored-by: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com> Co-authored-by: budefei <budefei@cambricon.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: duzekun <108381389+duzekunKTH@users.noreply.github.com> Co-authored-by: dongchengwei <dongchengwei@cambricon.com> Co-authored-by: liuyuan1-v <125547457+liuyuan1-v@users.noreply.github.com>
146 lines
5.3 KiB
Python
146 lines
5.3 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
|
|
from mmcv.cnn import build_conv_layer, build_norm_layer
|
|
from mmcv.ops import (SparseConvTensor, SparseInverseConv3d, SparseSequential,
|
|
SubMConv3d)
|
|
|
|
# Skip every test in this module when ``torch.__version__`` reports
# 'parrots'; the skip reason below states these ops are not supported there.
# ``allow_module_level=True`` is required for calling ``pytest.skip`` outside
# of a test function.
if torch.__version__ == 'parrots':
    pytest.skip('not supported in parrots now', allow_module_level=True)
|
|
|
|
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
|
|
|
|
|
def make_sparse_convmodule(in_channels,
                           out_channels,
                           kernel_size,
                           indice_key,
                           stride=1,
                           padding=0,
                           conv_type='SubMConv3d',
                           norm_cfg=None,
                           order=('conv', 'norm', 'act')):
    """Make sparse convolution module.

    Args:
        in_channels (int): the number of input channels
        out_channels (int): the number of out channels
        kernel_size (int|tuple(int)): kernel size of convolution
        indice_key (str): the indice key used for sparse tensor
        stride (int|tuple(int)): the stride of convolution. Not forwarded
            when ``conv_type`` is one of the ``SparseInverseConv*`` types.
        padding (int or list[int]): the padding number of input. Not
            forwarded when ``conv_type`` is one of the
            ``SparseInverseConv*`` types.
        conv_type (str): sparse conv type in spconv
        norm_cfg (dict[str]): config of normalization layer. It is passed
            to ``build_norm_layer`` as-is whenever ``order`` contains
            "norm".
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").

    Returns:
        SparseSequential: sparse convolution module holding the requested
        layers in the order given by ``order``.

    Raises:
        AssertionError: if ``order`` is not a tuple, has more than three
            entries, or contains a name other than "conv", "norm", "act".
    """
    valid_layers = {'conv', 'norm', 'act'}
    assert isinstance(order, tuple) and len(order) <= 3, \
        'order must be a tuple with at most 3 entries'
    assert set(order) <= valid_layers, \
        'order may only contain "conv", "norm" and "act"'

    conv_cfg = dict(type=conv_type, indice_key=indice_key)

    # Inverse convolutions are built without stride/padding; every other
    # conv type receives them explicitly.
    inverse_conv_types = ('SparseInverseConv3d', 'SparseInverseConv2d',
                          'SparseInverseConv1d')
    conv_kwargs = dict(bias=False)
    if conv_type not in inverse_conv_types:
        conv_kwargs.update(stride=stride, padding=padding)

    layers = []
    for layer in order:
        if layer == 'conv':
            layers.append(
                build_conv_layer(conv_cfg, in_channels, out_channels,
                                 kernel_size, **conv_kwargs))
        elif layer == 'norm':
            # NOTE: the norm layer is always sized with ``out_channels``,
            # regardless of whether it comes before or after the conv layer.
            # ``build_norm_layer`` returns a pair; index 1 is the layer.
            layers.append(build_norm_layer(norm_cfg, out_channels)[1])
        elif layer == 'act':
            layers.append(nn.ReLU(inplace=True))

    return SparseSequential(*layers)
|
|
|
|
|
|
@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'mlu',
        marks=pytest.mark.skipif(
            not IS_MLU_AVAILABLE, reason='requires MLU support'))
])
def test_make_sparse_convmodule(device):
    """Check the layers built by ``make_sparse_convmodule`` on ``device``.

    Builds a SubMConv3d block in ('conv', 'norm', 'act') order, checks the
    type and hyper-parameters of each layer and runs one forward pass. On
    every device except 'mlu' it also builds a SparseInverseConv3d block in
    ('norm', 'act', 'conv') order and checks the layer types only.

    Args:
        device (str): device name supplied by the parametrization above,
            either 'cuda' or 'mlu'.
    """
    torch.cuda.empty_cache()
    voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315],
                                   [6.8162713, -2.480431, -1.3616394, 0.36],
                                   [11.643568, -4.744306, -1.3580885, 0.16],
                                   [23.482342, 6.5036807, 0.5806964, 0.35]],
                                  dtype=torch.float32,
                                  device=device)  # n, point_features
    coordinates = torch.tensor(
        [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232],
         [1, 35, 930, 469]],
        dtype=torch.int32,
        device=device)  # n, 4(batch, ind_x, ind_y, ind_z)

    # Wrap the four voxels in a sparse tensor.
    # NOTE(review): the last two arguments are presumably the spatial shape
    # and the batch size (the batch indices above are 0 and 1) -- confirm
    # against SparseConvTensor.
    input_sp_tensor = SparseConvTensor(voxel_features, coordinates,
                                       [41, 1600, 1408], 2)

    sparse_block0 = make_sparse_convmodule(
        4,
        16,
        3,
        'test0',
        stride=1,
        padding=0,
        conv_type='SubMConv3d',
        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
        order=('conv', 'norm', 'act')).to(device)
    # Layers must appear in the requested order with the requested settings.
    assert isinstance(sparse_block0[0], SubMConv3d)
    assert sparse_block0[0].in_channels == 4
    assert sparse_block0[0].out_channels == 16
    assert isinstance(sparse_block0[1], torch.nn.BatchNorm1d)
    assert sparse_block0[1].eps == 0.001
    assert sparse_block0[1].momentum == 0.01
    assert isinstance(sparse_block0[2], torch.nn.ReLU)

    # test forward: 4 input voxels in, 16 output channels per voxel
    out_features = sparse_block0(input_sp_tensor)
    assert out_features.features.shape == torch.Size([4, 16])

    # device == mlu: not support inverse==1 yet
    if device != 'mlu':
        sparse_block1 = make_sparse_convmodule(
            4,
            16,
            3,
            'test1',
            stride=1,
            padding=0,
            conv_type='SparseInverseConv3d',
            norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
            order=('norm', 'act', 'conv')).to(device)
        # Only the layer types are checked here; no forward pass is run.
        assert isinstance(sparse_block1[2], SparseInverseConv3d)
        assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d)
        assert isinstance(sparse_block1[1], torch.nn.ReLU)
|