[Feature] Add ChamferDistance op in gpu (#1933)

* update

* fix lint

* fix lint

* Update mmcv/ops/chamfer_distance.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update chamfer_distance.py

* Update test_chamfer_distance.py

* fix

* Update chamfer_distance_cuda_kernel.cuh

* Update chamfer_distance_cuda_kernel.cuh

* fix

* Update mmcv/ops/chamfer_distance.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update chamfer_distance.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
Yue Zhou 2022-06-08 18:12:57 +08:00 committed by GitHub
parent 5601427bbc
commit 834d5978df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 408 additions and 1 deletions

View File

@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- ContextBlock
- ConvexIoU
- CornerPool
- ChamferDistance
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated

View File

@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- ContextBlock
- ConvexIoU
- CornerPool
- ChamferDistance
- Deformable Convolution v1/v2
- Deformable RoIPool
- DiffIoURotated

View File

@ -7,6 +7,7 @@ from .border_align import BorderAlign, border_align
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .chamfer_distance import chamfer_distance
from .contour_expand import contour_expand
from .convex_iou import convex_giou, convex_iou
from .corner_pool import CornerPool
@ -99,5 +100,5 @@ __all__ = [
'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part',
'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons',
'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou',
'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance'
]

View File

@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple
import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])
class ChamferDistanceFunction(Function):
    """This is an implementation of the 2D Chamfer Distance.

    It has been used in the paper `Oriented RepPoints for Aerial Object
    Detection (CVPR 2022) <https://arxiv.org/abs/2105.11111>`_.
    """

    @staticmethod
    def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
        """
        Args:
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

                - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
                    shape (B, N).
                - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
                    shape (B, M).
                - idx1 (Tensor): Index of chamfer distance (xyz1 to xyz2)
                    with shape (B, N), which is used to compute the
                    gradient.
                - idx2 (Tensor): Index of chamfer distance (xyz2 to xyz1)
                    with shape (B, M), which is used to compute the
                    gradient.
        """
        batch_size, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # The extension picks its scalar type from ``xyz1`` and reads the
        # distance buffers with that same type, so allocate them with the
        # dtype of ``xyz1`` (and directly on its device) rather than as
        # default float32 CPU tensors.
        dist1 = xyz1.new_zeros((batch_size, n))
        dist2 = xyz1.new_zeros((batch_size, m))
        # The extension reads the index buffers as ``int`` (int32).
        idx1 = torch.zeros((batch_size, n),
                           dtype=torch.int32,
                           device=xyz1.device)
        idx2 = torch.zeros((batch_size, m),
                           dtype=torch.int32,
                           device=xyz1.device)

        # Fills dist1/dist2/idx1/idx2 in place.
        ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
                                            idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
                 grad_idx1: Tensor,
                 grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1 (Tensor): Gradient slot of the ``idx1`` output.
                It is not used.
            grad_idx2 (Tensor): Gradient slot of the ``idx2`` output.
                It is not used.

        Returns:
            Tuple[Tensor, Tensor]:

            - grad_xyz1 (Tensor): Gradient of the point set with shape \
                (B, N, 2).
            - grad_xyz2 (Tensor): Gradient of the point set with shape \
                (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()
        # Accumulators must share dtype and device with the saved inputs;
        # the extension adds into them in place.
        grad_xyz1 = torch.zeros_like(xyz1)
        grad_xyz2 = torch.zeros_like(xyz2)

        ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
                                             grad_dist1, grad_dist2, idx1,
                                             idx2)
        return grad_xyz1, grad_xyz2
# Functional entry point: ``chamfer_distance(xyz1, xyz2)`` invokes
# ``ChamferDistanceFunction`` through autograd's ``apply``.
chamfer_distance = ChamferDistanceFunction.apply

View File

@ -0,0 +1,101 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
#ifndef CHAMFER_DISTANCE_CUDA_KERNEL_CUH
#define CHAMFER_DISTANCE_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144
// Nearest-neighbour pass of the 2D Chamfer distance.
//
// For every point of `xyz` (n points per batch item, stored as x,y pairs)
// this searches the m points of `xyz2` in the same batch item and writes the
// smallest squared Euclidean distance to `result[i * n + j]` and the index of
// the point that produced it to `result_i[i * n + j]`.
//
// `xyz2` is consumed in tiles of THREADS_PER_BLOCK points. Each tile is first
// copied into the shared buffer `buf`, then every thread scans the tile for
// the points of `xyz` it owns.
template <typename scalar_t>
__global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
                                                     const scalar_t* xyz, int m,
                                                     const scalar_t* xyz2,
                                                     scalar_t* result,
                                                     int* result_i) {
  __shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
      // Number of xyz2 points in this tile (the last tile may be partial).
      int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
      for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
        buf[j] = xyz2[(i * m + k2) * 2 + j];
      }
      // The whole tile must be loaded before any thread reads it.
      __syncthreads();
      for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
        scalar_t x1 = xyz[(i * n + j) * 2 + 0];
        scalar_t y1 = xyz[(i * n + j) * 2 + 1];
        int best_i = 0;
        // Sentinel: any candidate closer than this replaces it.
        scalar_t best = 1e10;
        // Largest multiple of 4 that does not exceed end_k. The unrolled
        // loops below consume 4 points per step, so they must stop here;
        // otherwise they would read buf slots that were not loaded for this
        // tile. (The previous mask `~2` only cleared bit 1 and let tiles
        // with end_k % 4 == 1 or 3 overrun.)
        int end_ka = end_k & (~3);
        if (end_ka == THREADS_PER_BLOCK) {
          // Full tile: same body as the else-branch but with a constant trip
          // count (presumably kept separate to help the compiler unroll).
          for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
#pragma unroll
            for (int u = 0; u < 4; ++u) {
              scalar_t x2 = buf[(k + u) * 2] - x1;
              scalar_t y2 = buf[(k + u) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + u;
              }
            }
          }
        } else {
          for (int k = 0; k < end_ka; k += 4) {
#pragma unroll
            for (int u = 0; u < 4; ++u) {
              scalar_t x2 = buf[(k + u) * 2] - x1;
              scalar_t y2 = buf[(k + u) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + u;
              }
            }
          }
        }
        // Remaining (end_k % 4) points of the tile, one at a time.
        for (int k = end_ka; k < end_k; k++) {
          scalar_t x2 = buf[k * 2 + 0] - x1;
          scalar_t y2 = buf[k * 2 + 1] - y1;
          scalar_t d = x2 * x2 + y2 * y2;
          if (k == 0 || d < best) {
            best = d;
            best_i = k + k2;
          }
        }
        // The first tile initialises the output; later tiles only improve it.
        if (k2 == 0 || result[(i * n + j)] > best) {
          result[(i * n + j)] = best;
          result_i[(i * n + j)] = best_i;
        }
      }
      // Do not overwrite buf with the next tile while it is still being read.
      __syncthreads();
    }
  }
}
// Gradient of one direction of the 2D Chamfer distance.
//
// For each of the n points of `xyz1` in every batch item, looks up its matched
// point in `xyz2` through `idx1` and accumulates 2 * grad_dist1 * (p1 - p2)
// into `grad_xyz1` at the point itself and the negated value into `grad_xyz2`
// at the matched point. All four writes use atomicAdd.
template <typename scalar_t>
__global__ void chamfer_distance_backward_cuda_kernel(
    int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
    const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
    scalar_t* grad_xyz2) {
  for (int batch = blockIdx.x; batch < b; batch += gridDim.x) {
    for (int p = threadIdx.x; p < n; p += blockDim.x * gridDim.y) {
      // Flat offsets of the source point and of its matched neighbour.
      const int src = batch * n + p;
      const int dst = batch * m + idx1[src];

      const scalar_t x1 = xyz1[src * 2 + 0];
      const scalar_t y1 = xyz1[src * 2 + 1];
      const scalar_t x2 = xyz2[dst * 2 + 0];
      const scalar_t y2 = xyz2[dst * 2 + 1];

      // d/dp of |p1 - p2|^2 is 2 * (p1 - p2), scaled by the upstream gradient.
      const scalar_t g = grad_dist1[src] * 2;
      const scalar_t gx = g * (x1 - x2);
      const scalar_t gy = g * (y1 - y2);

      atomicAdd(&(grad_xyz1[src * 2 + 0]), gx);
      atomicAdd(&(grad_xyz1[src * 2 + 1]), gy);
      atomicAdd(&(grad_xyz2[dst * 2 + 0]), -gx);
      atomicAdd(&(grad_xyz2[dst * 2 + 1]), -gy);
    }
  }
}
#endif // CHAMFER_DISTANCE_CUDA_KERNEL_CUH

View File

@ -0,0 +1,35 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
// Device-dispatch entry for the forward pass. Hands all six tensors to
// DISPATCH_DEVICE_IMPL under the key `chamfer_distance_forward_impl`.
// NOTE(review): the macro is defined outside this file; presumably it selects
// the implementation registered for the tensors' device - confirm.
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2,
idx1, idx2);
}
// Device-dispatch entry for the backward pass. Hands all eight tensors to
// DISPATCH_DEVICE_IMPL under the key `chamfer_distance_backward_impl`.
// NOTE(review): the macro is defined outside this file; presumably it selects
// the implementation registered for the tensors' device - confirm.
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2) {
DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1,
gradxyz2, graddist1, graddist2, idx1, idx2);
}
// Public forward entry point: forwards its arguments unchanged to
// chamfer_distance_forward_impl. Nothing is returned; the call passes
// dist1/dist2/idx1/idx2 through as the buffers to be filled.
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2) {
chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2);
}
// Public backward entry point: forwards its arguments unchanged to
// chamfer_distance_backward_impl. Nothing is returned; gradxyz1/gradxyz2 are
// passed through as the gradient buffers.
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2) {
chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,
graddist2, idx1, idx2);
}

View File

@ -0,0 +1,63 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
#include "chamfer_distance_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
// Launches the forward kernel twice on the current CUDA stream: once to match
// every point of xyz1 against xyz2 (filling dist1 / idx1) and once with the
// roles swapped (filling dist2 / idx2). The scalar type is taken from xyz1.
void ChamferDistanceForwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
    const Tensor dist2, const Tensor idx1, const Tensor idx2) {
  const int num_batches = xyz1.size(0);
  const int num_pts1 = xyz1.size(1);
  const int num_pts2 = xyz2.size(1);

  at::cuda::CUDAGuard device_guard(xyz1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Direction xyz1 -> xyz2.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
        const scalar_t* query = xyz1.data_ptr<scalar_t>();
        const scalar_t* target = xyz2.data_ptr<scalar_t>();
        scalar_t* out_dist = dist1.data_ptr<scalar_t>();
        int* out_idx = idx1.data_ptr<int>();
        chamfer_distance_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(num_batches * num_pts1), THREADS_PER_BLOCK, 0,
               stream>>>(num_batches, num_pts1, query, num_pts2, target,
                         out_dist, out_idx);
      });

  // Direction xyz2 -> xyz1.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
        const scalar_t* query = xyz2.data_ptr<scalar_t>();
        const scalar_t* target = xyz1.data_ptr<scalar_t>();
        scalar_t* out_dist = dist2.data_ptr<scalar_t>();
        int* out_idx = idx2.data_ptr<int>();
        chamfer_distance_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(num_batches * num_pts2), THREADS_PER_BLOCK, 0,
               stream>>>(num_batches, num_pts2, query, num_pts1, target,
                         out_dist, out_idx);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
// Launches the backward kernel twice on the current CUDA stream: once for the
// xyz1 -> xyz2 direction (grad_dist1 / idx1) and once for the xyz2 -> xyz1
// direction (grad_dist2 / idx2). Both launches accumulate into grad_xyz1 and
// grad_xyz2. The scalar type is taken from xyz1.
//
// The kernel signature is (b, n, pts, m, matched_pts, ...), where n is the
// point count of `pts` and m the point count of `matched_pts`. The counts are
// therefore passed as (n, m) when iterating over xyz1 and as (m, n) when
// iterating over xyz2; passing them the other way round is only harmless when
// both sets have the same number of points.
void ChamferDistanceBackwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
    Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) {
  int batch_size = xyz1.size(0);
  int n = xyz1.size(1);
  int m = xyz2.size(1);
  at::cuda::CUDAGuard device_guard(xyz1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Direction xyz1 -> xyz2: iterate over the n points of xyz1.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
        chamfer_distance_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK / 2, 0, stream>>>(
                batch_size, n, xyz1.data_ptr<scalar_t>(), m,
                xyz2.data_ptr<scalar_t>(), grad_dist1.data_ptr<scalar_t>(),
                idx1.data_ptr<int>(), grad_xyz1.data_ptr<scalar_t>(),
                grad_xyz2.data_ptr<scalar_t>());
      });
  // Direction xyz2 -> xyz1: iterate over the m points of xyz2.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
        chamfer_distance_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK / 2, 0, stream>>>(
                batch_size, m, xyz2.data_ptr<scalar_t>(), n,
                xyz1.data_ptr<scalar_t>(), grad_dist2.data_ptr<scalar_t>(),
                idx2.data_ptr<int>(), grad_xyz2.data_ptr<scalar_t>(),
                grad_xyz1.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -1700,3 +1700,40 @@ Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask,
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
diff_iou_rotated_sort_vertices_forward_cuda);
// Forward declarations of the Chamfer distance kernel launchers, which are
// called by the CUDA wrappers below.
void ChamferDistanceForwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
const Tensor dist2, const Tensor idx1, const Tensor idx2);
void ChamferDistanceBackwardCUDAKernelLauncher(
const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2);
// CUDA implementation of the forward pass: forwards every tensor, unchanged
// and in the same order, to ChamferDistanceForwardCUDAKernelLauncher.
void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
                                   const Tensor dist1, const Tensor dist2,
                                   const Tensor idx1, const Tensor idx2) {
  ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
                                           idx2);
}
// CUDA implementation of the backward pass: forwards every tensor, unchanged
// and in the same order, to ChamferDistanceBackwardCUDAKernelLauncher.
void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
                                    Tensor gradxyz1, Tensor gradxyz2,
                                    Tensor graddist1, Tensor graddist2,
                                    Tensor idx1, Tensor idx2) {
  ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2,
                                            graddist1, graddist2, idx1, idx2);
}
// Declarations of the dispatch entry points, followed by the registration of
// the CUDA wrappers above as their CUDA implementations.
void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
const Tensor dist1, const Tensor dist2,
const Tensor idx1, const Tensor idx2);
void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2,
Tensor idx1, Tensor idx2);
REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
chamfer_distance_forward_cuda);
REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA,
chamfer_distance_backward_cuda);

View File

@ -401,6 +401,15 @@ at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices,
at::Tensor mask,
at::Tensor num_valid);
// Declaration of the Chamfer distance forward entry point that is bound to
// Python below. The last parameter is named idx2 to match the binding's
// py::arg("idx2").
void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2,
                              const Tensor dist1, const Tensor dist2,
                              const Tensor idx1, const Tensor idx2);
// Declaration of the Chamfer distance backward entry point that is bound to
// Python below.
void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
Tensor gradxyz1, Tensor gradxyz2,
Tensor graddist1, Tensor graddist2, Tensor idx1,
Tensor idx2);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@ -811,4 +820,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&diff_iou_rotated_sort_vertices_forward,
"diff_iou_rotated_sort_vertices_forward", py::arg("vertices"),
py::arg("mask"), py::arg("num_valid"));
// Expose the Chamfer distance forward pass to Python; keyword names follow
// the order of the C++ parameters.
m.def("chamfer_distance_forward", &chamfer_distance_forward,
"chamfer_distance_forward", py::arg("xyz1"), py::arg("xyz2"),
py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2"));
// Expose the Chamfer distance backward pass to Python.
m.def("chamfer_distance_backward", &chamfer_distance_backward,
"chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"),
py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"),
py::arg("graddist2"), py::arg("idx1"), py::arg("idx2"));
}

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmcv.ops import chamfer_distance
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_chamfer_distance():
    """Check forward distances and the gradients of ``dist1`` on CUDA.

    Three batch items of four 2D points each are compared against
    hand-computed squared nearest-neighbour distances. Only ``dist1`` is
    back-propagated, so the expected gradients cover the xyz1 -> xyz2
    direction.
    """
    # A small square, a wide rectangle and a narrow rectangle.
    small_box = [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]]
    wide_box = [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]
    narrow_box = [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]

    points_a = torch.tensor([small_box, wide_box, narrow_box],
                            device='cuda',
                            requires_grad=True)
    points_b = torch.tensor([wide_box, small_box, wide_box],
                            device='cuda',
                            requires_grad=True)

    want_dist_a = torch.tensor(
        [[0.0900, 0.4900, 0.4900, 0.0900],
         [0.0900, 0.4900, 0.4900, 0.0900],
         [0.5200, 0.6500, 0.4900, 0.3600]],
        device='cuda')
    want_dist_b = torch.tensor(
        [[0.0900, 0.4900, 0.4900, 0.0900],
         [0.0900, 0.4900, 0.4900, 0.0900],
         [0.7200, 0.8500, 0.4900, 0.3600]],
        device='cuda')
    want_grad_a = torch.tensor(
        [[[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
          [0.6000, 0.0000]],
         [[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
          [-0.6000, 0.0000]],
         [[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000],
          [1.2000, 0.0000]]],
        device='cuda')
    want_grad_b = torch.tensor(
        [[[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
          [-0.6000, 0.0000]],
         [[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
          [0.6000, 0.0000]],
         [[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000],
          [-2.4000, 0.8000]]],
        device='cuda')

    got_dist_a, got_dist_b, _idx_a, _idx_b = chamfer_distance(
        points_a, points_b)
    got_dist_a.backward(torch.ones_like(got_dist_a))

    # The third positional argument of allclose is the relative tolerance.
    comparisons = (
        (got_dist_a, want_dist_a),
        (got_dist_b, want_dist_b),
        (points_a.grad, want_grad_a),
        (points_b.grad, want_grad_b),
    )
    for actual, wanted in comparisons:
        assert torch.allclose(actual, wanted, 1e-2)