mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520)
* [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. * [Fix] update operator data layout setting. * [Fix] add cxx compile option to avoid symbol conflict. * [Fix] fix lint errors. * [Fix] update ops.md with info of ball_query support by MLU backend. * [Feature] Fix typo. * [Fix] Remove print. * [Fix] get mlu-ops from MMCV_MLU_OPS_PATH env. * [Fix] update MMCV_MLU_OPS_PATH check logic. * [Fix] update error info when failed to download mlu-ops. * [Fix] check mlu-ops version matching info in mmcv. * [Fix] revise wrong filename. * [Fix] remove f.close and re.pull/2572/head
parent
84f60c178c
commit
dfb03806a1
|
@ -6,7 +6,7 @@ We implement common ops used in detection, segmentation, etc.
|
|||
| ---------------------------- | --- | ---- | --- | --- | ------ |
|
||||
| ActiveRotatedFilter | √ | √ | | | |
|
||||
| AssignScoreWithK | | √ | | | |
|
||||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
|
|
|
@ -6,7 +6,7 @@ MMCV 提供了检测、分割等任务中常用的算子
|
|||
| ---------------------------- | --- | ---- | --- | --- | ------ |
|
||||
| ActiveRotatedFilter | √ | √ | | | |
|
||||
| AssignScoreWithK | | √ | | | |
|
||||
| BallQuery | | √ | √ | | |
|
||||
| BBoxOverlaps | | √ | √ | √ | |
|
||||
| BorderAlign | | √ | | | |
|
||||
| BoxIouRotated | √ | √ | | | |
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "mlu_common_helper.h"
|
||||
|
||||
void ball_query_forward_mlu(int b, int n, int m, float min_radius,
|
||||
float max_radius, int nsample, const Tensor new_xyz,
|
||||
const Tensor xyz, Tensor idx) {
|
||||
MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc;
|
||||
new_xyz_desc.set(new_xyz);
|
||||
xyz_desc.set(xyz);
|
||||
idx_desc.set(idx);
|
||||
|
||||
auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
new_xyz, new_xyz.suggest_memory_format());
|
||||
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
xyz, new_xyz.suggest_memory_format());
|
||||
auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
idx, new_xyz.suggest_memory_format());
|
||||
|
||||
auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous);
|
||||
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
|
||||
auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
|
||||
auto new_xyz_ptr = new_xyz_impl->cnnlMalloc();
|
||||
auto xyz_ptr = xyz_impl->cnnlMalloc();
|
||||
auto idx_ptr = idx_impl->cnnlMalloc();
|
||||
|
||||
auto handle = mluOpGetCurrentHandle();
|
||||
mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(),
|
||||
xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(),
|
||||
idx_ptr);
|
||||
}
|
||||
|
||||
void ball_query_forward_impl(int b, int n, int m, float min_radius,
|
||||
float max_radius, int nsample,
|
||||
const Tensor new_xyz, const Tensor xyz,
|
||||
Tensor idx);
|
||||
|
||||
REGISTER_DEVICE_IMPL(ball_query_forward_impl, MLU, ball_query_forward_mlu);
|
|
@ -0,0 +1,103 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#include "mlu_common_helper.h"
|
||||
|
||||
// Descriptors
|
||||
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type) {
|
||||
const std::map<std::string, mluOpDataType_t> mapping_type = {
|
||||
{std::string("c10::Half"), MLUOP_DTYPE_HALF},
|
||||
{std::string("float"), MLUOP_DTYPE_FLOAT},
|
||||
{std::string("double"), MLUOP_DTYPE_DOUBLE},
|
||||
{std::string("int8"), MLUOP_DTYPE_INT8},
|
||||
{std::string("signed char"), MLUOP_DTYPE_INT8},
|
||||
{std::string("short int"), MLUOP_DTYPE_INT16},
|
||||
{std::string("short"), MLUOP_DTYPE_INT16},
|
||||
{std::string("int"), MLUOP_DTYPE_INT32},
|
||||
{std::string("long int"), MLUOP_DTYPE_INT64},
|
||||
{std::string("long"), MLUOP_DTYPE_INT64},
|
||||
{std::string("unsigned char"), MLUOP_DTYPE_UINT8},
|
||||
{std::string("bool"), MLUOP_DTYPE_BOOL},
|
||||
{std::string("c10::complex<c10::Half>"), MLUOP_DTYPE_COMPLEX_HALF},
|
||||
{std::string("c10::complex<float>"), MLUOP_DTYPE_COMPLEX_FLOAT}};
|
||||
|
||||
if (mapping_type.find(std::string(data_type.name())) != mapping_type.end()) {
|
||||
return mapping_type.find(std::string(data_type.name()))->second;
|
||||
}
|
||||
return MLUOP_DTYPE_INVALID;
|
||||
}
|
||||
|
||||
// laytout
|
||||
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
|
||||
auto suggest_memory_format = input.suggest_memory_format();
|
||||
mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY;
|
||||
switch (input.dim()) {
|
||||
case 4:
|
||||
layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast)
|
||||
? MLUOP_LAYOUT_NHWC
|
||||
: MLUOP_LAYOUT_NCHW;
|
||||
break;
|
||||
case 5:
|
||||
layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast3d)
|
||||
? MLUOP_LAYOUT_NDHWC
|
||||
: MLUOP_LAYOUT_NCDHW;
|
||||
break;
|
||||
default:
|
||||
layout = MLUOP_LAYOUT_ARRAY;
|
||||
}
|
||||
return layout;
|
||||
}
|
||||
|
||||
// Fill this descriptor from an ATen tensor: dtype, suggested layout and the
// dimension sizes. A 0-dim (scalar, 1-item) tensor is described as shape [1].
void MluOpTensorDescriptor::set(Tensor t) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);
  std::vector<int> dim_array;
  if (t.dim() == 0) {
    // ScalarTensor (0-dim 1-item Tensor) is viewed as size = 1 by default.
    dim_array.push_back(1);
  } else {
    // t.sizes() is a cheap non-owning view; the original called
    // t.sizes().vec() inside the loop, materializing a fresh std::vector on
    // every iteration.
    const auto sizes = t.sizes();
    dim_array.reserve(sizes.size());
    for (const auto s : sizes) {
      dim_array.push_back(static_cast<int>(s));
    }
  }
  set_desc(t, layout, data_type, dim_array);
}
|
||||
|
||||
// Thin wrapper over mluOpSetTensorDescriptor for this object's desc_.
// NOTE(review): the mluOp status return is discarded here — presumably
// errors surface when the descriptor is consumed; confirm against mlu-ops.
void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
                                     mluOpTensorLayout_t layout,
                                     mluOpDataType_t dtype,
                                     std::vector<int>& dims) {
  const int dim_count = static_cast<int>(dims.size());
  mluOpSetTensorDescriptor(desc_, layout, dtype, dim_count, dims.data());
}
|
||||
|
||||
// Handles
|
||||
std::once_flag mmcv_mluop_init_flag;
|
||||
std::mutex mmcv_mluop_mutex;
|
||||
static std::vector<MluOpHandle> mmcv_mluop_handles;
|
||||
|
||||
mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index) {
|
||||
std::call_once(mmcv_mluop_init_flag,
|
||||
[]() // Init mmcv_mluop_handles 1-device <-> 1-handle
|
||||
{
|
||||
c10::DeviceIndex num_devices = torch_mlu::device_count();
|
||||
mmcv_mluop_handles.resize(num_devices);
|
||||
});
|
||||
|
||||
if (device_index == -1) {
|
||||
device_index = torch_mlu::current_device();
|
||||
}
|
||||
std::lock_guard<std::mutex> mmcv_mluop_guard(mmcv_mluop_mutex);
|
||||
auto queue = torch_mlu::getCurrentQueue(device_index).queue();
|
||||
mmcv_mluop_handles[device_index].setQueue(queue);
|
||||
return mmcv_mluop_handles[device_index].handle;
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*************************************************************************
|
||||
* Copyright (C) 2022 Cambricon.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#pragma once
#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>

#include <utility>
#include <vector>

#include "aten.h"
#include "mlu_op.h"
#include "pytorch_device_registry.hpp"

// Version of mlu-ops this mmcv checkout was written against; setup.py reads
// these macros to verify the bundled/downloaded mlu-ops matches.
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 4
#define MLUOP_PATCHLEVEL 1

mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type);
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input);

// RAII owner of an mluOpTensorDescriptor_t.
class MluOpTensorDescriptor {
 public:
  MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); }
  ~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); }
  // Non-copyable: a copy would destroy desc_ twice.
  MluOpTensorDescriptor(const MluOpTensorDescriptor&) = delete;
  MluOpTensorDescriptor& operator=(const MluOpTensorDescriptor&) = delete;

  // Describe `t` (dtype, suggested layout, sizes) in this descriptor.
  void set(at::Tensor);
  mluOpTensorDescriptor_t desc() { return desc_; }

 private:
  mluOpTensorDescriptor_t desc_;
  void set_desc(const at::Tensor&, mluOpTensorLayout_t, mluOpDataType_t,
                std::vector<int>& dims);
};

mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1);

// RAII owner of an mluOpHandle_t. Movable (so it can live in a
// std::vector) but non-copyable, guaranteeing the handle is destroyed
// exactly once.
class MluOpHandle {
 public:
  MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); }
  ~MluOpHandle() {
    if (handle) {
      mluOpDestroy(handle);
      handle = nullptr;
    }
  }
  MluOpHandle(const MluOpHandle&) = delete;
  MluOpHandle& operator=(const MluOpHandle&) = delete;
  MluOpHandle(MluOpHandle&& other) noexcept
      : handle(std::exchange(other.handle, nullptr)) {}
  MluOpHandle& operator=(MluOpHandle&& other) noexcept {
    if (this != &other) {
      if (handle) mluOpDestroy(handle);
      handle = std::exchange(other.handle, nullptr);
    }
    return *this;
  }
  // Bind the given queue to this handle for subsequent mlu-ops launches.
  void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); }
  mluOpHandle_t handle;
};
|
88
setup.py
88
setup.py
|
@ -270,6 +270,7 @@ def get_extensions():
|
|||
|
||||
include_dirs = []
|
||||
|
||||
extra_objects = []
|
||||
is_rocm_pytorch = False
|
||||
try:
|
||||
from torch.utils.cpp_extension import ROCM_HOME
|
||||
|
@ -300,16 +301,98 @@ def get_extensions():
|
|||
torch.is_mlu_available()) or \
|
||||
os.getenv('FORCE_MLU', '0') == '1':
|
||||
from torch_mlu.utils.cpp_extension import MLUExtension
|
||||
|
||||
def get_mluops_version(file_path):
    """Parse the mlu-ops version macros out of a C header.

    Looks for ``#define MLUOP_MAJOR/MLUOP_MINOR/MLUOP_PATCHLEVEL`` lines.
    The original implementation split on single spaces and raised an
    obscure ``NameError`` when a macro was missing; this version parses
    with a regex (tolerant of extra whitespace) and fails explicitly.

    Args:
        file_path (str): path to a header defining the three macros.

    Returns:
        str: version formatted as ``v{major}.{minor}.{patchlevel}``.

    Raises:
        ValueError: if any of the three macros is not found.
    """
    pattern = re.compile(
        r'#define\s+(MLUOP_MAJOR|MLUOP_MINOR|MLUOP_PATCHLEVEL)\s+(\S+)')
    found = {}
    with open(file_path) as f:
        for line in f:
            match = pattern.search(line)
            if match:
                found[match.group(1)] = match.group(2)
    missing = {'MLUOP_MAJOR', 'MLUOP_MINOR', 'MLUOP_PATCHLEVEL'} - set(found)
    if missing:
        raise ValueError(
            f'{file_path} does not define {sorted(missing)}')
    return (f"v{found['MLUOP_MAJOR']}.{found['MLUOP_MINOR']}"
            f".{found['MLUOP_PATCHLEVEL']}")
|
||||
|
||||
# Version of mlu-ops this mmcv checkout expects (macros in the MLU helper).
mmcv_mluops_version = get_mluops_version(
    './mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h')
mlu_ops_path = os.getenv('MMCV_MLU_OPS_PATH')
if mlu_ops_path:
    # Use a user-provided mlu-ops checkout after verifying its version.
    exists_mluops_version = get_mluops_version(mlu_ops_path +
                                               '/bangc-ops/mlu_op.h')
    if exists_mluops_version != mmcv_mluops_version:
        # Originally print(...) + exit(), which exits with status 0 and
        # lets the build look successful; SystemExit exits non-zero.
        raise SystemExit('the version of mlu-ops provided is %s,'
                         ' while %s is needed.' %
                         (exists_mluops_version, mmcv_mluops_version))
    try:
        # (Re)point the local `mlu-ops` symlink at the provided checkout.
        if os.path.exists('mlu-ops'):
            if os.path.islink('mlu-ops'):
                os.remove('mlu-ops')
                os.symlink(mlu_ops_path, 'mlu-ops')
            elif os.path.abspath('mlu-ops') != mlu_ops_path:
                # A real directory is in the way; symlink() raises here.
                os.symlink(mlu_ops_path, 'mlu-ops')
        else:
            os.symlink(mlu_ops_path, 'mlu-ops')
    except Exception:
        raise FileExistsError('mlu-ops already exists, please move it out,'
                              'or rename or remove it.')
else:
    if not os.path.exists('mlu-ops'):
        # No local copy: download the matching release tag from GitHub.
        import requests
        mluops_url = 'https://github.com/Cambricon/mlu-ops/' + \
            'archive/refs/tags/' + mmcv_mluops_version + '.zip'
        req = requests.get(mluops_url)
        with open('./mlu-ops.zip', 'wb') as f:
            try:
                f.write(req.content)
            except Exception:
                raise ImportError('failed to download mlu-ops')

        from zipfile import BadZipFile, ZipFile
        with ZipFile('./mlu-ops.zip', 'r') as archive:
            try:
                archive.extractall()
                dir_name = archive.namelist()[0].split('/')[0]
                os.rename(dir_name, 'mlu-ops')
            except BadZipFile:
                print('invalid mlu-ops.zip file')
    else:
        # A local copy exists; make sure its version matches mmcv's.
        exists_mluops_version = get_mluops_version(
            './mlu-ops/bangc-ops/mlu_op.h')
        if exists_mluops_version != mmcv_mluops_version:
            raise SystemExit('the version of provided mlu-ops is %s,'
                             ' while %s is needed.' %
                             (exists_mluops_version, mmcv_mluops_version))

define_macros += [('MMCV_WITH_MLU', None)]
mlu_args = os.getenv('MMCV_MLU_ARGS')
mluops_includes = [
    '-I' + os.path.abspath('./mlu-ops/bangc-ops'),
    '-I' + os.path.abspath('./mlu-ops/bangc-ops/kernels'),
]
# cncc flags: optional user args plus the mlu-ops include paths.
# (Originally assigned twice; the first assignment was dead.)
extra_compile_args['cncc'] = \
    ([mlu_args] if mlu_args else []) + mluops_includes
# Avoid symbol conflicts between mmcv and mlu-ops object files.
extra_compile_args['cxx'] += ['-fno-gnu-unique']
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
    glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \
    glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \
    glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + \
    glob.glob(
        './mlu-ops/bangc-ops/core/**/*.cpp', recursive=True) + \
    glob.glob(
        './mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \
    glob.glob(
        './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
extra_objects = glob.glob(
    './mlu-ops/bangc-ops/kernels/*/x86_64/*.o')
extension = MLUExtension
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
include_dirs.append(os.path.abspath('./mlu-ops/bangc-ops'))
|
||||
elif (hasattr(torch.backends, 'mps')
|
||||
and torch.backends.mps.is_available()) or os.getenv(
|
||||
'FORCE_MPS', '0') == '1':
|
||||
|
@ -371,6 +454,7 @@ def get_extensions():
|
|||
sources=op_files,
|
||||
include_dirs=include_dirs,
|
||||
define_macros=define_macros,
|
||||
extra_objects=extra_objects,
|
||||
extra_compile_args=extra_compile_args)
|
||||
extensions.append(ext_ops)
|
||||
|
||||
|
|
|
@ -3,55 +3,59 @@ import pytest
|
|||
import torch
|
||||
|
||||
from mmcv.ops import ball_query
|
||||
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_ball_query():
|
||||
new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
|
||||
[-2.2769, 2.7817, -0.2334],
|
||||
[-0.4003, 2.4666, -0.5116],
|
||||
[-0.0740, 1.3147, -1.3625],
|
||||
[-0.0740, 1.3147, -1.3625]],
|
||||
[[-2.0289, 2.4952, -0.1708],
|
||||
[-2.0668, 6.0278, -0.4875],
|
||||
[0.4066, 1.4211, -0.2947],
|
||||
[-2.0289, 2.4952, -0.1708],
|
||||
[-2.0289, 2.4952, -0.1708]]]).cuda()
|
||||
@pytest.mark.parametrize('device', [
|
||||
pytest.param(
|
||||
'cuda',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
|
||||
pytest.param(
|
||||
'mlu',
|
||||
marks=pytest.mark.skipif(
|
||||
not IS_MLU_AVAILABLE, reason='requires MLU support'))
|
||||
])
|
||||
def test_ball_query(device):
|
||||
new_xyz = torch.tensor(
|
||||
[[[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334],
|
||||
[-0.4003, 2.4666, -0.5116], [-0.0740, 1.3147, -1.3625],
|
||||
[-0.0740, 1.3147, -1.3625]],
|
||||
[[-2.0289, 2.4952, -0.1708], [-2.0668, 6.0278, -0.4875],
|
||||
[0.4066, 1.4211, -0.2947], [-2.0289, 2.4952, -0.1708],
|
||||
[-2.0289, 2.4952, -0.1708]]],
|
||||
device=device)
|
||||
|
||||
xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
|
||||
[-0.4003, 2.4666,
|
||||
-0.5116], [-0.5251, 2.4379, -0.8466],
|
||||
[-0.9691, 1.1418,
|
||||
-1.3733], [-0.2232, 0.9561, -1.3626],
|
||||
[-2.2769, 2.7817, -0.2334],
|
||||
[-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432],
|
||||
[0.4917, 1.1529, -1.3496]],
|
||||
[[-2.0289, 2.4952,
|
||||
-0.1708], [-0.7188, 0.9956, -0.5096],
|
||||
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
|
||||
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
|
||||
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
|
||||
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
|
||||
-1.2000]]]).cuda()
|
||||
xyz = torch.tensor(
|
||||
[[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
|
||||
[-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
|
||||
[-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
|
||||
[-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645],
|
||||
[0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496]],
|
||||
[[-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096],
|
||||
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
|
||||
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
|
||||
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
|
||||
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]]],
|
||||
device=device)
|
||||
|
||||
idx = ball_query(0, 0.2, 5, xyz, new_xyz)
|
||||
expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
|
||||
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
|
||||
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]]).cuda()
|
||||
expected_idx = torch.tensor(
|
||||
[[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], [2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]],
|
||||
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]],
|
||||
device=device)
|
||||
assert torch.all(idx == expected_idx)
|
||||
|
||||
# test dilated ball query
|
||||
idx = ball_query(0.2, 0.4, 5, xyz, new_xyz)
|
||||
expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6],
|
||||
[2, 3, 2, 2, 2], [0, 5, 7, 0, 0],
|
||||
[0, 5, 7, 0, 0]],
|
||||
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
|
||||
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]]).cuda()
|
||||
expected_idx = torch.tensor(
|
||||
[[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6], [2, 3, 2, 2, 2], [0, 5, 7, 0, 0],
|
||||
[0, 5, 7, 0, 0]],
|
||||
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]]],
|
||||
device=device)
|
||||
assert torch.all(idx == expected_idx)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue