mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix tensor descriptor setting in MLU ball_query. (#2579)
parent
497c2df495
commit
46ea829b67
|
@ -14,11 +14,6 @@
|
|||
void ball_query_forward_mlu(int b, int n, int m, float min_radius,
|
||||
float max_radius, int nsample, const Tensor new_xyz,
|
||||
const Tensor xyz, Tensor idx) {
|
||||
MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc;
|
||||
new_xyz_desc.set(new_xyz);
|
||||
xyz_desc.set(xyz);
|
||||
idx_desc.set(idx);
|
||||
|
||||
auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
new_xyz, new_xyz.suggest_memory_format());
|
||||
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
|
@ -26,6 +21,11 @@ void ball_query_forward_mlu(int b, int n, int m, float min_radius,
|
|||
auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
|
||||
idx, new_xyz.suggest_memory_format());
|
||||
|
||||
MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc;
|
||||
new_xyz_desc.set(new_xyz_contiguous);
|
||||
xyz_desc.set(xyz_contiguous);
|
||||
idx_desc.set(idx_contiguous);
|
||||
|
||||
auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous);
|
||||
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous);
|
||||
auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);
|
||||
|
|
Loading…
Reference in New Issue