mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Feature] Add ChamferDistance op in gpu (#1933)
* update * fix lint * fix lint * Update mmcv/ops/chamfer_distance.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update chamfer_distance.py * Update test_chamfer_distance.py * fix * Update chamfer_distance_cuda_kernel.cuh * Update chamfer_distance_cuda_kernel.cuh * fix * Update mmcv/ops/chamfer_distance.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update chamfer_distance.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
5601427bbc
commit
834d5978df
@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
|
||||
- ContextBlock
|
||||
- ConvexIoU
|
||||
- CornerPool
|
||||
- ChamferDistance
|
||||
- Deformable Convolution v1/v2
|
||||
- Deformable RoIPool
|
||||
- DiffIoURotated
|
||||
|
@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
|
||||
- ContextBlock
|
||||
- ConvexIoU
|
||||
- CornerPool
|
||||
- ChamferDistance
|
||||
- Deformable Convolution v1/v2
|
||||
- Deformable RoIPool
|
||||
- DiffIoURotated
|
||||
|
@ -7,6 +7,7 @@ from .border_align import BorderAlign, border_align
|
||||
from .box_iou_rotated import box_iou_rotated
|
||||
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
|
||||
from .cc_attention import CrissCrossAttention
|
||||
from .chamfer_distance import chamfer_distance
|
||||
from .contour_expand import contour_expand
|
||||
from .convex_iou import convex_giou, convex_iou
|
||||
from .corner_pool import CornerPool
|
||||
@ -99,5 +100,5 @@ __all__ = [
|
||||
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
|
||||
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
|
||||
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
|
||||
'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
|
||||
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance'
|
||||
]
|
||||
|
95
mmcv/ops/chamfer_distance.py
Normal file
95
mmcv/ops/chamfer_distance.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from ..utils import ext_loader
|
||||
|
||||
ext_module = ext_loader.load_ext(
|
||||
'_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])
|
||||
|
||||
|
||||
class ChamferDistanceFunction(Function):
|
||||
"""This is an implementation of the 2D Chamfer Distance.
|
||||
|
||||
It has been used in the paper `Oriented RepPoints for Aerial Object
|
||||
Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>_`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
|
||||
"""
|
||||
Args:
|
||||
xyz1 (Tensor): Point set with shape (B, N, 2).
|
||||
xyz2 (Tensor): Point set with shape (B, N, 2).
|
||||
|
||||
Returns:
|
||||
Sequence[Tensor]:
|
||||
|
||||
- dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
|
||||
shape (B, N).
|
||||
- dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
|
||||
shape (B, N).
|
||||
- idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
|
||||
with shape (B, N), which be used in compute gradient.
|
||||
- idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2)
|
||||
with shape (B, N), which be used in compute gradient.
|
||||
"""
|
||||
batch_size, n, _ = xyz1.size()
|
||||
_, m, _ = xyz2.size()
|
||||
device = xyz1.device
|
||||
xyz1 = xyz1.contiguous()
|
||||
xyz2 = xyz2.contiguous()
|
||||
|
||||
dist1 = torch.zeros(batch_size, n).to(device)
|
||||
dist2 = torch.zeros(batch_size, m).to(device)
|
||||
idx1 = torch.zeros(batch_size, n).type(torch.IntTensor).to(device)
|
||||
idx2 = torch.zeros(batch_size, m).type(torch.IntTensor).to(device)
|
||||
|
||||
ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
|
||||
idx2)
|
||||
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
|
||||
return dist1, dist2, idx1, idx2
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
|
||||
grad_idx1: Tensor,
|
||||
grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
grad_dist1 (Tensor): Gradient of chamfer distance
|
||||
(xyz1 to xyz2) with shape (B, N).
|
||||
grad_dist2 (Tensor): Gradient of chamfer distance
|
||||
(xyz2 to xyz1) with shape (B, N).
|
||||
grad_idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
|
||||
with shape (B, N), which be used in compute gradient.
|
||||
grad_idx2 (Tensor): Index of chamfer distance (xyz2 to xyz2)
|
||||
with shape (B, N), which be used in compute gradient.
|
||||
|
||||
Returns:
|
||||
Tuple[Tensor, Tensor]:
|
||||
|
||||
- grad_xyz1 (Tensor): Gradient of the point set with shape \
|
||||
(B, N, 2).
|
||||
- grad_xyz2 (Tensor):Gradient of the point set with shape \
|
||||
(B, N, 2).
|
||||
"""
|
||||
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
|
||||
device = grad_dist1.device
|
||||
grad_dist1 = grad_dist1.contiguous()
|
||||
grad_dist2 = grad_dist2.contiguous()
|
||||
grad_xyz1 = torch.zeros(xyz1.size()).to(device)
|
||||
grad_xyz2 = torch.zeros(xyz2.size()).to(device)
|
||||
|
||||
ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
|
||||
grad_dist1, grad_dist2, idx1,
|
||||
idx2)
|
||||
return grad_xyz1, grad_xyz2
|
||||
|
||||
|
||||
chamfer_distance = ChamferDistanceFunction.apply
|
101
mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh
Normal file
101
mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh
Normal file
@ -0,0 +1,101 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
|
||||
#ifndef CHAMFER_DISTANCE_CUDA_KERNEL_CUH
|
||||
#define CHAMFER_DISTANCE_CUDA_KERNEL_CUH
|
||||
|
||||
#ifdef MMCV_USE_PARROTS
|
||||
#include "parrots_cuda_helper.hpp"
|
||||
#else
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
#endif
|
||||
|
||||
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
|
||||
const scalar_t* xyz, int m,
|
||||
const scalar_t* xyz2,
|
||||
scalar_t* result,
|
||||
int* result_i) {
|
||||
__shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
||||
for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
|
||||
int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
|
||||
for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
|
||||
buf[j] = xyz2[(i * m + k2) * 2 + j];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
|
||||
scalar_t x1 = xyz[(i * n + j) * 2 + 0];
|
||||
scalar_t y1 = xyz[(i * n + j) * 2 + 1];
|
||||
int best_i = 0;
|
||||
scalar_t best = 1e10;
|
||||
int end_ka = end_k & (~2);
|
||||
if (end_ka == THREADS_PER_BLOCK) {
|
||||
for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
scalar_t x2 = buf[(k + j) * 2] - x1;
|
||||
scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
|
||||
scalar_t d = x2 * x2 + y2 * y2;
|
||||
if (d < best) {
|
||||
best = d;
|
||||
best_i = k + k2 + j;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int k = 0; k < end_ka; k += 4) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
scalar_t x2 = buf[(k + j) * 2] - x1;
|
||||
scalar_t y2 = buf[(k + j) * 2 + 1] - y1;
|
||||
scalar_t d = x2 * x2 + y2 * y2;
|
||||
if (d < best) {
|
||||
best = d;
|
||||
best_i = k + k2 + j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int k = end_ka; k < end_k; k++) {
|
||||
scalar_t x2 = buf[k * 2 + 0] - x1;
|
||||
scalar_t y2 = buf[k * 2 + 1] - y1;
|
||||
scalar_t d = x2 * x2 + y2 * y2;
|
||||
if (k == 0 || d < best) {
|
||||
best = d;
|
||||
best_i = k + k2;
|
||||
}
|
||||
}
|
||||
if (k2 == 0 || result[(i * n + j)] > best) {
|
||||
result[(i * n + j)] = best;
|
||||
result_i[(i * n + j)] = best_i;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void chamfer_distance_backward_cuda_kernel(
|
||||
int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
|
||||
const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
|
||||
scalar_t* grad_xyz2) {
|
||||
for (int i = blockIdx.x; i < b; i += gridDim.x) {
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
|
||||
scalar_t x1 = xyz1[(i * n + j) * 2 + 0];
|
||||
scalar_t y1 = xyz1[(i * n + j) * 2 + 1];
|
||||
int j2 = idx1[i * n + j];
|
||||
scalar_t x2 = xyz2[(i * m + j2) * 2 + 0];
|
||||
scalar_t y2 = xyz2[(i * m + j2) * 2 + 1];
|
||||
scalar_t g = grad_dist1[i * n + j] * 2;
|
||||
atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2));
|
||||
atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2));
|
||||
atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2)));
|
||||
atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2)));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // CHAMFER_DISTANCE_CUDA_KERNEL_CUH
|
35
mmcv/ops/csrc/pytorch/chamfer_distance.cpp
Normal file
35
mmcv/ops/csrc/pytorch/chamfer_distance.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
|
||||
|
||||
#include "pytorch_cpp_helper.hpp"
|
||||
#include "pytorch_device_registry.hpp"
|
||||
|
||||
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
|
||||
const Tensor dist1, const Tensor dist2,
|
||||
const Tensor idx1, const Tensor idx2) {
|
||||
DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2,
|
||||
idx1, idx2);
|
||||
}
|
||||
|
||||
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
|
||||
Tensor gradxyz1, Tensor gradxyz2,
|
||||
Tensor graddist1, Tensor graddist2,
|
||||
Tensor idx1, Tensor idx2) {
|
||||
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1,
|
||||
gradxyz2, graddist1, graddist2, idx1, idx2);
|
||||
}
|
||||
|
||||
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
|
||||
const Tensor dist1, const Tensor dist2,
|
||||
const Tensor idx1, const Tensor idx2) {
|
||||
chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2);
|
||||
}
|
||||
|
||||
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
|
||||
Tensor gradxyz1, Tensor gradxyz2,
|
||||
Tensor graddist1, Tensor graddist2, Tensor idx1,
|
||||
Tensor idx2) {
|
||||
chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,
|
||||
graddist2, idx1, idx2);
|
||||
}
|
63
mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu
Normal file
63
mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
// Modified from
|
||||
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
|
||||
#include "chamfer_distance_cuda_kernel.cuh"
|
||||
#include "pytorch_cuda_helper.hpp"
|
||||
|
||||
void ChamferDistanceForwardCUDAKernelLauncher(
|
||||
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
|
||||
const Tensor dist2, const Tensor idx1, const Tensor idx2) {
|
||||
int batch_size = xyz1.size(0);
|
||||
int n = xyz1.size(1);
|
||||
int m = xyz2.size(1);
|
||||
|
||||
at::cuda::CUDAGuard device_guard(xyz1.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
|
||||
chamfer_distance_forward_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
batch_size, n, xyz1.data_ptr<scalar_t>(), m,
|
||||
xyz2.data_ptr<scalar_t>(), dist1.data_ptr<scalar_t>(),
|
||||
idx1.data_ptr<int>());
|
||||
});
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
|
||||
chamfer_distance_forward_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK, 0, stream>>>(
|
||||
batch_size, m, xyz2.data_ptr<scalar_t>(), n,
|
||||
xyz1.data_ptr<scalar_t>(), dist2.data_ptr<scalar_t>(),
|
||||
idx2.data_ptr<int>());
|
||||
});
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
void ChamferDistanceBackwardCUDAKernelLauncher(
|
||||
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
|
||||
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) {
|
||||
int batch_size = xyz1.size(0);
|
||||
int n = xyz1.size(1);
|
||||
int m = xyz2.size(1);
|
||||
|
||||
at::cuda::CUDAGuard device_guard(xyz1.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
|
||||
chamfer_distance_backward_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK / 2, 0, stream>>>(
|
||||
batch_size, m, xyz1.data_ptr<scalar_t>(), n,
|
||||
xyz2.data_ptr<scalar_t>(), grad_dist1.data_ptr<scalar_t>(),
|
||||
idx1.data_ptr<int>(), grad_xyz1.data_ptr<scalar_t>(),
|
||||
grad_xyz2.data_ptr<scalar_t>());
|
||||
});
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||
xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
|
||||
chamfer_distance_backward_cuda_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK / 2, 0, stream>>>(
|
||||
batch_size, n, xyz2.data_ptr<scalar_t>(), m,
|
||||
xyz1.data_ptr<scalar_t>(), grad_dist2.data_ptr<scalar_t>(),
|
||||
idx2.data_ptr<int>(), grad_xyz2.data_ptr<scalar_t>(),
|
||||
grad_xyz1.data_ptr<scalar_t>());
|
||||
});
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
@ -1700,3 +1700,40 @@ Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
|
||||
|
||||
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
|
||||
diff_iou_rotated_sort_vertices_forward_cuda);
|
||||
|
||||
void ChamferDistanceForwardCUDAKernelLauncher(
|
||||
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
|
||||
const Tensor dist2, const Tensor idx1, const Tensor idx2);
|
||||
|
||||
void ChamferDistanceBackwardCUDAKernelLauncher(
|
||||
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
|
||||
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2);
|
||||
|
||||
void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
|
||||
const Tensor dist1, const Tensor dist2,
|
||||
const Tensor idx1, const Tensor idx2) {
|
||||
ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
|
||||
idx2);
|
||||
};
|
||||
|
||||
void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
|
||||
Tensor gradxyz1, Tensor gradxyz2,
|
||||
Tensor graddist1, Tensor graddist2,
|
||||
Tensor idx1, Tensor idx2) {
|
||||
ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2,
|
||||
graddist1, graddist2, idx1, idx2);
|
||||
};
|
||||
|
||||
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
|
||||
const Tensor dist1, const Tensor dist2,
|
||||
const Tensor idx1, const Tensor idx2);
|
||||
|
||||
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
|
||||
Tensor gradxyz1, Tensor gradxyz2,
|
||||
Tensor graddist1, Tensor graddist2,
|
||||
Tensor idx1, Tensor idx2);
|
||||
|
||||
REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
|
||||
chamfer_distance_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
|
||||
chamfer_distance_backward_cuda);
|
||||
|
@ -401,6 +401,15 @@ at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices,
|
||||
at::Tensor mask,
|
||||
at::Tensor num_valid);
|
||||
|
||||
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
|
||||
const Tensor dist1, const Tensor dist2,
|
||||
const Tensor idx1, const Tensor idx);
|
||||
|
||||
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
|
||||
Tensor gradxyz1, Tensor gradxyz2,
|
||||
Tensor graddist1, Tensor graddist2, Tensor idx1,
|
||||
Tensor idx2);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
|
||||
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
|
||||
@ -811,4 +820,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
&diff_iou_rotated_sort_vertices_forward,
|
||||
"diff_iou_rotated_sort_vertices_forward", py::arg("vertices"),
|
||||
py::arg("mask"), py::arg("num_valid"));
|
||||
m.def("chamfer_distance_forward", &chamfer_distance_forward,
|
||||
"chamfer_distance_forward", py::arg("xyz1"), py::arg("xyz2"),
|
||||
py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2"));
|
||||
m.def("chamfer_distance_backward", &chamfer_distance_backward,
|
||||
"chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
|
||||
py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
|
||||
py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
|
||||
}
|
||||
|
57
tests/test_ops/test_chamfer_distance.py
Normal file
57
tests/test_ops/test_chamfer_distance.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmcv.ops import chamfer_distance
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_chamfer_distance():
|
||||
pointset1 = torch.tensor(
|
||||
[[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
|
||||
[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
|
||||
[[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]],
|
||||
device='cuda',
|
||||
requires_grad=True)
|
||||
|
||||
pointset2 = torch.tensor(
|
||||
[[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]],
|
||||
[[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]],
|
||||
[[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]],
|
||||
device='cuda',
|
||||
requires_grad=True)
|
||||
|
||||
expected_dist1 = torch.tensor(
|
||||
[[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
|
||||
[0.5200, 0.6500, 0.4900, 0.3600]],
|
||||
device='cuda')
|
||||
expected_dist2 = torch.tensor(
|
||||
[[0.0900, 0.4900, 0.4900, 0.0900], [0.0900, 0.4900, 0.4900, 0.0900],
|
||||
[0.7200, 0.8500, 0.4900, 0.3600]],
|
||||
device='cuda')
|
||||
|
||||
expected_pointset1_grad = torch.tensor(
|
||||
[[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
|
||||
[0.6000, 0.0000]],
|
||||
[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
|
||||
[-0.6000, 0.0000]],
|
||||
[[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000],
|
||||
[1.2000, 0.0000]]],
|
||||
device='cuda')
|
||||
|
||||
expected_pointset2_grad = torch.tensor(
|
||||
[[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
|
||||
[-0.6000, 0.0000]],
|
||||
[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
|
||||
[0.6000, 0.0000]],
|
||||
[[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000],
|
||||
[-2.4000, 0.8000]]],
|
||||
device='cuda')
|
||||
|
||||
dist1, dist2, idx1, idx2 = chamfer_distance(pointset1, pointset2)
|
||||
dist1.backward(torch.ones_like(dist1))
|
||||
assert torch.allclose(dist1, expected_dist1, 1e-2)
|
||||
assert torch.allclose(dist2, expected_dist2, 1e-2)
|
||||
assert torch.allclose(pointset1.grad.data, expected_pointset1_grad, 1e-2)
|
||||
assert torch.allclose(pointset2.grad.data, expected_pointset2_grad, 1e-2)
|
Loading…
x
Reference in New Issue
Block a user