diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e07df3e43..a29ce85c5 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - ContextBlock - ConvexIoU - CornerPool +- ChamferDistance - Deformable Convolution v1/v2 - Deformable RoIPool - DiffIoURotated diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 199de199e..1b8930915 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - ContextBlock - ConvexIoU - CornerPool +- ChamferDistance - Deformable Convolution v1/v2 - Deformable RoIPool - DiffIoURotated diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index 487c6df9b..d0e7a5bd5 100755 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -7,6 +7,7 @@ from .border_align import BorderAlign, border_align from .box_iou_rotated import box_iou_rotated from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive from .cc_attention import CrissCrossAttention +from .chamfer_distance import chamfer_distance from .contour_expand import contour_expand from .convex_iou import convex_giou, convex_iou from .corner_pool import CornerPool @@ -99,5 +100,5 @@ __all__ = [ 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', 'convex_iou', 'convex_giou', - 'diff_iou_rotated_2d', 'diff_iou_rotated_3d' + 'diff_iou_rotated_2d', 'diff_iou_rotated_3d', 'chamfer_distance' ] diff --git a/mmcv/ops/chamfer_distance.py b/mmcv/ops/chamfer_distance.py new file mode 100644 index 000000000..d68eafb47 --- /dev/null +++ b/mmcv/ops/chamfer_distance.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
from typing import Sequence, Tuple

import torch
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

ext_module = ext_loader.load_ext(
    '_ext', ['chamfer_distance_forward', 'chamfer_distance_backward'])


class ChamferDistanceFunction(Function):
    """Autograd function for the 2D Chamfer Distance.

    For every point of one set, the squared Euclidean distance to its
    nearest neighbour in the other set is computed, in both directions.

    It has been used in the paper "Oriented RepPoints for Aerial Object
    Detection" (CVPR 2022).
    """

    @staticmethod
    def forward(ctx, xyz1: Tensor, xyz2: Tensor) -> Sequence[Tensor]:
        """
        Args:
            xyz1 (Tensor): Point set with shape (B, N, 2).
            xyz2 (Tensor): Point set with shape (B, M, 2).

        Returns:
            Sequence[Tensor]:

                - dist1 (Tensor): Chamfer distance (xyz1 to xyz2) with
                    shape (B, N).
                - dist2 (Tensor): Chamfer distance (xyz2 to xyz1) with
                    shape (B, M).
                - idx1 (Tensor): int32 index into xyz2 of the nearest
                    neighbour of each point of xyz1, with shape (B, N).
                    It is used to compute the gradient.
                - idx2 (Tensor): int32 index into xyz1 of the nearest
                    neighbour of each point of xyz2, with shape (B, M).
                    It is used to compute the gradient.

        Raises:
            ValueError: If the inputs are not 3-dimensional, if their last
                dimension is not 2, or if their batch sizes differ.
        """
        batch_size, n, dim1 = xyz1.size()
        batch_size2, m, dim2 = xyz2.size()
        # The extension indexes both inputs with a hard-coded stride of 2
        # and a shared batch index, so other layouts would be read out of
        # bounds instead of failing cleanly.
        if dim1 != 2 or dim2 != 2:
            raise ValueError(
                'chamfer_distance expects point sets of shape (B, N, 2), '
                f'but got {tuple(xyz1.size())} and {tuple(xyz2.size())}')
        if batch_size != batch_size2:
            raise ValueError(
                'xyz1 and xyz2 must have the same batch size, but got '
                f'{batch_size} and {batch_size2}')

        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Allocate the outputs directly on the input device. The distance
        # buffers follow the dtype of ``xyz1`` because the extension
        # dispatches on ``xyz1.scalar_type()``.
        dist1 = xyz1.new_zeros((batch_size, n))
        dist2 = xyz1.new_zeros((batch_size, m))
        idx1 = xyz1.new_zeros((batch_size, n), dtype=torch.int32)
        idx2 = xyz1.new_zeros((batch_size, m), dtype=torch.int32)

        ext_module.chamfer_distance_forward(xyz1, xyz2, dist1, dist2, idx1,
                                            idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        # The index outputs are integer bookkeeping, not differentiable.
        ctx.mark_non_differentiable(idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_dist1: Tensor, grad_dist2: Tensor,
                 grad_idx1: Tensor,
                 grad_idx2: Tensor) -> Tuple[Tensor, Tensor]:
        """

        Args:
            grad_dist1 (Tensor): Gradient of chamfer distance
                (xyz1 to xyz2) with shape (B, N).
            grad_dist2 (Tensor): Gradient of chamfer distance
                (xyz2 to xyz1) with shape (B, M).
            grad_idx1 (Tensor): Placeholder for the ``idx1`` output of
                :meth:`forward`. It is not used.
            grad_idx2 (Tensor): Placeholder for the ``idx2`` output of
                :meth:`forward`. It is not used.

        Returns:
            Tuple[Tensor, Tensor]:

            - grad_xyz1 (Tensor): Gradient of the first point set with
                shape (B, N, 2).
            - grad_xyz2 (Tensor): Gradient of the second point set with
                shape (B, M, 2).
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        grad_dist1 = grad_dist1.contiguous()
        grad_dist2 = grad_dist2.contiguous()
        # The extension accumulates into these buffers, so they must start
        # at zero and share dtype/device with the saved inputs.
        grad_xyz1 = torch.zeros_like(xyz1)
        grad_xyz2 = torch.zeros_like(xyz2)

        ext_module.chamfer_distance_backward(xyz1, xyz2, grad_xyz1, grad_xyz2,
                                             grad_dist1, grad_dist2, idx1,
                                             idx2)
        return grad_xyz1, grad_xyz2


chamfer_distance = ChamferDistanceFunction.apply

# --- patch boundary: mmcv/ops/csrc/common/cuda/chamfer_distance_cuda_kernel.cuh (new file) ---
# Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cu
#ifndef CHAMFER_DISTANCE_CUDA_KERNEL_CUH
#define CHAMFER_DISTANCE_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

#define MAX_SHARED_SCALAR_T 6144  // 49152 / 8 = 6144

// For every point of `xyz` (b batches of n 2D points) find the point of
// `xyz2` (b batches of m 2D points) with the smallest squared Euclidean
// distance. The distance is written to `result[i * n + j]` and the index of
// the matching `xyz2` point to `result_i[i * n + j]`.
//
// `xyz2` is processed in tiles of at most THREADS_PER_BLOCK points that are
// staged in the `__shared__` buffer `buf`; the running minimum is kept in
// `result` between tiles.
// NOTE(review): THREADS_PER_BLOCK is not defined in this header; presumably
// it comes from the included helper header -- confirm.
template <typename scalar_t>
__global__ void chamfer_distance_forward_cuda_kernel(int b, int n,
                                                     const scalar_t* xyz, int m,
                                                     const scalar_t* xyz2,
                                                     scalar_t* result,
                                                     int* result_i) {
  __shared__ scalar_t buf[MAX_SHARED_SCALAR_T];
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int k2 = 0; k2 < m; k2 += THREADS_PER_BLOCK) {
      // Number of xyz2 points in the current tile.
      int end_k = min(m, k2 + THREADS_PER_BLOCK) - k2;
      // Stage the tile (two scalars per point) in shared memory.
      for (int j = threadIdx.x; j < end_k * 2; j += blockDim.x) {
        buf[j] = xyz2[(i * m + k2) * 2 + j];
      }
      __syncthreads();
      for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
        scalar_t x1 = xyz[(i * n + j) * 2 + 0];
        scalar_t y1 = xyz[(i * n + j) * 2 + 1];
        int best_i = 0;
        scalar_t best = 1e10;
        // Round the tile size down to a multiple of 4 so that the 4-wide
        // unrolled loops below only touch slots that were actually loaded.
        // The mask must clear both low bits (~3); clearing only bit 1 (~2)
        // leaves odd sizes unchanged and lets the unrolled loop read past
        // the end of the tile.
        int end_ka = end_k & (~3);
        if (end_ka == THREADS_PER_BLOCK) {
          // Full tile: compile-time trip count.
          for (int k = 0; k < THREADS_PER_BLOCK; k += 4) {
#pragma unroll
            for (int u = 0; u < 4; ++u) {
              scalar_t x2 = buf[(k + u) * 2] - x1;
              scalar_t y2 = buf[(k + u) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + u;
              }
            }
          }
        } else {
          for (int k = 0; k < end_ka; k += 4) {
#pragma unroll
            for (int u = 0; u < 4; ++u) {
              scalar_t x2 = buf[(k + u) * 2] - x1;
              scalar_t y2 = buf[(k + u) * 2 + 1] - y1;
              scalar_t d = x2 * x2 + y2 * y2;
              if (d < best) {
                best = d;
                best_i = k + k2 + u;
              }
            }
          }
        }
        // Remaining (end_k % 4) points of the tile.
        for (int k = end_ka; k < end_k; k++) {
          scalar_t x2 = buf[k * 2 + 0] - x1;
          scalar_t y2 = buf[k * 2 + 1] - y1;
          scalar_t d = x2 * x2 + y2 * y2;
          if (k == 0 || d < best) {
            best = d;
            best_i = k + k2;
          }
        }
        // The first tile initialises the output; later tiles only improve it.
        if (k2 == 0 || result[(i * n + j)] > best) {
          result[(i * n + j)] = best;
          result_i[(i * n + j)] = best_i;
        }
      }
      // Do not overwrite `buf` with the next tile while it is still in use.
      __syncthreads();
    }
  }
}

// Accumulate the gradient of the squared nearest-neighbour distances.
//
// `xyz1` holds b batches of n 2D points, `xyz2` holds b batches of m 2D
// points and `idx1[i * n + j]` is the index (into batch i of `xyz2`) of the
// point matched to point j of `xyz1`. For each such pair,
// 2 * grad_dist1 * (p1 - p2) is added to `grad_xyz1` and its negation to
// `grad_xyz2`. Both gradient buffers are only added to, so the caller has to
// zero them beforehand.
template <typename scalar_t>
__global__ void chamfer_distance_backward_cuda_kernel(
    int b, int n, const scalar_t* xyz1, int m, const scalar_t* xyz2,
    const scalar_t* grad_dist1, const int* idx1, scalar_t* grad_xyz1,
    scalar_t* grad_xyz2) {
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int j = threadIdx.x; j < n; j += blockDim.x * gridDim.y) {
      scalar_t x1 = xyz1[(i * n + j) * 2 + 0];
      scalar_t y1 = xyz1[(i * n + j) * 2 + 1];
      int j2 = idx1[i * n + j];
      scalar_t x2 = xyz2[(i * m + j2) * 2 + 0];
      scalar_t y2 = xyz2[(i * m + j2) * 2 + 1];
      scalar_t g = grad_dist1[i * n + j] * 2;
      atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 0]), g * (x1 - x2));
      atomicAdd(&(grad_xyz1[(i * n + j) * 2 + 1]), g * (y1 - y2));
      // Several xyz1 points can share the same nearest xyz2 point.
      atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 0]), -(g * (x1 - x2)));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 2 + 1]), -(g * (y1 - y2)));
    }
  }
}
#endif  // CHAMFER_DISTANCE_CUDA_KERNEL_CUH

// --- patch boundary: mmcv/ops/csrc/pytorch/chamfer_distance.cpp (new file) ---
// Copyright (c) OpenMMLab. All rights reserved.
+// Modified from +// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp + +#include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" + +void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + DISPATCH_DEVICE_IMPL(chamfer_distance_forward_impl, xyz1, xyz2, dist1, dist2, + idx1, idx2); +} + +void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2, + Tensor gradxyz1, Tensor gradxyz2, + Tensor graddist1, Tensor graddist2, + Tensor idx1, Tensor idx2) { + DISPATCH_DEVICE_IMPL(chamfer_distance_backward_impl, xyz1, xyz2, gradxyz1, + gradxyz2, graddist1, graddist2, idx1, idx2); +} + +void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx2) { + chamfer_distance_forward_impl(xyz1, xyz2, dist1, dist2, idx1, idx2); +} + +void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, + Tensor gradxyz1, Tensor gradxyz2, + Tensor graddist1, Tensor graddist2, Tensor idx1, + Tensor idx2) { + chamfer_distance_backward_impl(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, + graddist2, idx1, idx2); +} diff --git a/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu new file mode 100644 index 000000000..980482eb5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/chamfer_distance_cuda.cu @@ -0,0 +1,63 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
// Modified from
// https://github.com/chrdiller/pyTorchChamferDistance/blob/master/chamfer_distance/chamfer_distance.cpp
#include "chamfer_distance_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

// NOTE(review): in the reviewed text the kernel template arguments, the
// `data_ptr` template arguments and the `<<<...>>>` launch configuration had
// been stripped. They are reconstructed below (scalar_t / int from the kernel
// signatures; grid/block/stream from the otherwise unused `stream` variable
// and the helper macros) -- confirm against the original commit.

// Launch the forward kernel twice: once for xyz1 -> xyz2 (filling dist1 and
// idx1) and once with the roles swapped for xyz2 -> xyz1 (filling dist2 and
// idx2).
void ChamferDistanceForwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
    const Tensor dist2, const Tensor idx1, const Tensor idx2) {
  int batch_size = xyz1.size(0);
  int n = xyz1.size(1);
  int m = xyz2.size(1);

  at::cuda::CUDAGuard device_guard(xyz1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
        chamfer_distance_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK, 0, stream>>>(
                batch_size, n, xyz1.data_ptr<scalar_t>(), m,
                xyz2.data_ptr<scalar_t>(), dist1.data_ptr<scalar_t>(),
                idx1.data_ptr<int>());
      });
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_forward_cuda_kernel", [&] {
        chamfer_distance_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK, 0, stream>>>(
                batch_size, m, xyz2.data_ptr<scalar_t>(), n,
                xyz1.data_ptr<scalar_t>(), dist2.data_ptr<scalar_t>(),
                idx2.data_ptr<int>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

// Launch the backward kernel twice, mirroring the forward pass. Each launch
// passes the point count right before the point set it belongs to, exactly
// as the kernel signature (b, n, xyz1, m, xyz2, ...) and the forward
// launcher do; otherwise sets with n != m would be indexed with the wrong
// stride and bound.
void ChamferDistanceBackwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
    Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2) {
  int batch_size = xyz1.size(0);
  int n = xyz1.size(1);
  int m = xyz2.size(1);

  at::cuda::CUDAGuard device_guard(xyz1.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Contribution of dist1: pairs (xyz1[j], xyz2[idx1[j]]), j < n.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
        chamfer_distance_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * n), THREADS_PER_BLOCK, 0, stream>>>(
                batch_size, n, xyz1.data_ptr<scalar_t>(), m,
                xyz2.data_ptr<scalar_t>(), grad_dist1.data_ptr<scalar_t>(),
                idx1.data_ptr<int>(), grad_xyz1.data_ptr<scalar_t>(),
                grad_xyz2.data_ptr<scalar_t>());
      });
  // Contribution of dist2: pairs (xyz2[j], xyz1[idx2[j]]), j < m.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      xyz1.scalar_type(), "chamfer_distance_backward_cuda_kernel", [&] {
        chamfer_distance_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(batch_size * m), THREADS_PER_BLOCK, 0, stream>>>(
                batch_size, m, xyz2.data_ptr<scalar_t>(), n,
                xyz1.data_ptr<scalar_t>(), grad_dist2.data_ptr<scalar_t>(),
                idx2.data_ptr<int>(), grad_xyz2.data_ptr<scalar_t>(),
                grad_xyz1.data_ptr<scalar_t>());
      });
  AT_CUDA_CHECK(cudaGetLastError());
}

// --- patch boundary: mmcv/ops/csrc/pytorch/cuda/cudabind.cpp (hunk @@ -1700,3 +1700,40 @@) ---
// Existing context line of that hunk:
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, CUDA,
                     diff_iou_rotated_sort_vertices_forward_cuda);

void ChamferDistanceForwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, const Tensor dist1,
    const Tensor dist2, const Tensor idx1, const Tensor idx2);

void ChamferDistanceBackwardCUDAKernelLauncher(
    const Tensor xyz1, const Tensor xyz2, Tensor grad_xyz1, Tensor grad_xyz2,
    Tensor grad_dist1, Tensor grad_dist2, Tensor idx1, Tensor idx2);

// CUDA implementation registered for `chamfer_distance_forward_impl`.
void chamfer_distance_forward_cuda(const Tensor xyz1, const Tensor xyz2,
                                   const Tensor dist1, const Tensor dist2,
                                   const Tensor idx1, const Tensor idx2) {
  ChamferDistanceForwardCUDAKernelLauncher(xyz1, xyz2, dist1, dist2, idx1,
                                           idx2);
}

// CUDA implementation registered for `chamfer_distance_backward_impl`.
void chamfer_distance_backward_cuda(const Tensor xyz1, const Tensor xyz2,
                                    Tensor gradxyz1, Tensor gradxyz2,
                                    Tensor graddist1, Tensor graddist2,
                                    Tensor idx1, Tensor idx2) {
  ChamferDistanceBackwardCUDAKernelLauncher(xyz1, xyz2, gradxyz1, gradxyz2,
                                            graddist1, graddist2, idx1, idx2);
}

void chamfer_distance_forward_impl(const Tensor xyz1, const Tensor xyz2,
                                   const Tensor dist1, const Tensor dist2,
                                   const Tensor idx1, const Tensor idx2);

void chamfer_distance_backward_impl(const Tensor xyz1, const Tensor xyz2,
                                    Tensor gradxyz1, Tensor gradxyz2,
                                    Tensor graddist1, Tensor graddist2,
                                    Tensor idx1, Tensor idx2);

REGISTER_DEVICE_IMPL(chamfer_distance_forward_impl, CUDA,
                     chamfer_distance_forward_cuda);
+REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, CUDA, + chamfer_distance_backward_cuda); diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 7909aac19..69adef3b6 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -401,6 +401,15 @@ at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid); +void chamfer_distance_forward(const Tensor xyz1, const Tensor xyz2, + const Tensor dist1, const Tensor dist2, + const Tensor idx1, const Tensor idx); + +void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2, + Tensor gradxyz1, Tensor gradxyz2, + Tensor graddist1, Tensor graddist2, Tensor idx1, + Tensor idx2); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), @@ -811,4 +820,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &diff_iou_rotated_sort_vertices_forward, "diff_iou_rotated_sort_vertices_forward", py::arg("vertices"), py::arg("mask"), py::arg("num_valid")); + m.def("chamfer_distance_forward", &chamfer_distance_forward, + "chamfer_distance_forward", py::arg("xyz1"), py::arg("xyz2"), + py::arg("dist1"), py::arg("dist2"), py::arg("idx1"), py::arg("idx2")); + m.def("chamfer_distance_backward", &chamfer_distance_backward, + "chamfer_distance_backward", py::arg("xyz1"), py::arg("xyz2"), + py::arg("gradxyz1"), py::arg("gradxyz2"), py::arg("graddist1"), + py::arg("graddist2"), py::arg("idx1"), py::arg("idx2")); } diff --git a/tests/test_ops/test_chamfer_distance.py b/tests/test_ops/test_chamfer_distance.py new file mode 100644 index 000000000..522dcdddc --- /dev/null +++ b/tests/test_ops/test_chamfer_distance.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import pytest
import torch

from mmcv.ops import chamfer_distance


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_chamfer_distance():
    """Check distances and gradients on three batches of four 2D points.

    Only ``dist1`` is back-propagated, so the expected gradients contain the
    contribution of the xyz1 -> xyz2 direction alone.
    """

    def on_gpu(values, **kwargs):
        return torch.tensor(values, device='cuda', **kwargs)

    # Corner lists that appear several times in the fixtures.
    narrow_box = [[1.3, 9.39], [2.3, 9.39], [2.3, 10.39], [1.3, 10.39]]
    wide_box = [[1.0, 9.39], [3.0, 9.39], [3.0, 10.39], [1.0, 10.39]]
    short_box = [[1.6, 9.99], [2.3, 9.99], [2.3, 10.39], [1.6, 10.39]]

    pointset1 = on_gpu([narrow_box, wide_box, short_box], requires_grad=True)
    pointset2 = on_gpu([wide_box, narrow_box, wide_box], requires_grad=True)

    box_to_box = [0.0900, 0.4900, 0.4900, 0.0900]
    expected_dist1 = on_gpu(
        [box_to_box, box_to_box, [0.5200, 0.6500, 0.4900, 0.3600]])
    expected_dist2 = on_gpu(
        [box_to_box, box_to_box, [0.7200, 0.8500, 0.4900, 0.3600]])

    grad_a = [[0.6000, 0.0000], [-1.4000, 0.0000], [-1.4000, 0.0000],
              [0.6000, 0.0000]]
    grad_b = [[-0.6000, 0.0000], [1.4000, 0.0000], [1.4000, 0.0000],
              [-0.6000, 0.0000]]
    expected_pointset1_grad = on_gpu([
        grad_a,
        grad_b,
        [[1.2000, -0.8000], [-1.4000, -0.8000], [-1.4000, 0.0000],
         [1.2000, 0.0000]],
    ])
    expected_pointset2_grad = on_gpu([
        grad_b,
        grad_a,
        [[0.0000, 0.0000], [0.0000, 0.0000], [2.8000, 0.8000],
         [-2.4000, 0.8000]],
    ])

    dist1, dist2, _, _ = chamfer_distance(pointset1, pointset2)
    dist1.backward(torch.ones_like(dist1))

    comparisons = (
        (dist1, expected_dist1),
        (dist2, expected_dist2),
        (pointset1.grad.data, expected_pointset1_grad),
        (pointset2.grad.data, expected_pointset2_grad),
    )
    for actual, expected in comparisons:
        assert torch.allclose(actual, expected, 1e-2)