[Refactor] Refactor the directory of csrc (#1206)

* [Refactor] Refactor the csrc directory

* update MANIFEST.in

* fix hip

* add csrc readme

* trailing whitespace

* fix syntax error in setup.py

* add compatibility docs

* move parrots_cudawarpfunction.cuh to common/cuda

* fix grammar, update directory tree

* fix MANIFEST.in

* Add new structre of csrc in compatibility.md

* Add original structre of csrc in compatibility.md

* fix typo

* remove TODO

* modify according to comment

* format

Co-authored-by: grimoire <yaoqian@sensetime.com>
pull/1254/head
Zaida Zhou 2021-08-10 15:09:19 +08:00 committed by GitHub
parent 9fa5de8b9b
commit dfb48c87ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
66 changed files with 404 additions and 951 deletions

View File

@ -1,5 +1,5 @@
include requirements/runtime.txt
include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json mmcv/model_zoo/mmcls.json
include mmcv/ops/csrc/*.cuh mmcv/ops/csrc/*.hpp
include mmcv/ops/csrc/pytorch/*.cu mmcv/ops/csrc/pytorch/*.cpp
include mmcv/ops/csrc/parrots/*.cu mmcv/ops/csrc/parrots/*.cpp
include mmcv/ops/csrc/common/cuda/*.cuh mmcv/ops/csrc/common/cuda/*.hpp mmcv/ops/csrc/common/*.hpp
include mmcv/ops/csrc/pytorch/*.cpp mmcv/ops/csrc/pytorch/cuda/*.cu
include mmcv/ops/csrc/parrots/*.h mmcv/ops/csrc/parrots/*.cpp

View File

@ -0,0 +1,105 @@
## Compatibility of MMCV
### MMCV v1.3.11
In order to flexibly support more backends and hardwares like `NVIDIA GPUs` and `AMD GPUs`, the directory of `mmcv/ops/csrc` is refactored. Note that this refactoring will not affect the usage in API. For related information, please refer to [PR1206](https://github.com/open-mmlab/mmcv/pull/1206).
The original directory was organized as follows.
```
.
├── common_cuda_helper.hpp
├── ops_cuda_kernel.cuh
├── pytorch_cpp_helper.hpp
├── pytorch_cuda_helper.hpp
├── parrots_cpp_helper.hpp
├── parrots_cuda_helper.hpp
├── parrots_cudawarpfunction.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│ ├── onnxruntime_register.cpp
│      ├── ...
│      └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── pybind.cpp
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
├── trt_plugin.hpp
├── trt_serialize.hpp
├── ...
├── trt_ops.hpp
└── plugins
   ├── trt_cuda_helper.cu
   ├── trt_plugin.cpp
   ├── ...
   ├── trt_ops.cpp
   └── trt_ops_kernel.cu
```
After refactored, it is organized as follows.
```
.
├── common
│ ├── box_iou_rotated_utils.hpp
│ ├── parrots_cpp_helper.hpp
│ ├── parrots_cuda_helper.hpp
│ ├── pytorch_cpp_helper.hpp
│ ├── pytorch_cuda_helper.hpp
│   └── cuda
│   ├── common_cuda_helper.hpp
│   ├── parrots_cudawarpfunction.cuh
│   ├── ...
│   └── ops_cuda_kernel.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│ ├── onnxruntime_register.cpp
│      ├── ...
│      └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── info.cpp
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│      ├── ...
│      └── ops_cuda.cu
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
├── trt_plugin.hpp
├── trt_serialize.hpp
├── ...
├── trt_ops.hpp
└── plugins
   ├── trt_cuda_helper.cu
   ├── trt_plugin.cpp
   ├── ...
   ├── trt_ops.cpp
   └── trt_ops_kernel.cu
```

View File

@ -10,6 +10,7 @@ You can switch between Chinese and English documents in the lower-left corner of
deployment.rst
understand_mmcv.rst
api.rst
compatibility.md
faq.md
community.rst

View File

@ -0,0 +1,105 @@
## MMCV 兼容性说明
### MMCV v1.3.11
为了灵活地支持更多的后端和硬件,例如 `NVIDIA GPUs` 、`AMD GPUs`,我们重构了 `mmcv/ops/csrc` 目录。注意,这次重构不会影响 API 的使用。更多相关信息,请参考 [PR1206](https://github.com/open-mmlab/mmcv/pull/1206)。
原始的目录结构如下所示
```
.
├── common_cuda_helper.hpp
├── ops_cuda_kernel.cuh
├── pytorch_cpp_helper.hpp
├── pytorch_cuda_helper.hpp
├── parrots_cpp_helper.hpp
├── parrots_cuda_helper.hpp
├── parrots_cudawarpfunction.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│ ├── onnxruntime_register.cpp
│      ├── ...
│      └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── pybind.cpp
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
├── trt_plugin.hpp
├── trt_serialize.hpp
├── ...
├── trt_ops.hpp
└── plugins
   ├── trt_cuda_helper.cu
   ├── trt_plugin.cpp
   ├── ...
   ├── trt_ops.cpp
   └── trt_ops_kernel.cu
```
重构之后,它的结构如下所示
```
.
├── common
│ ├── box_iou_rotated_utils.hpp
│ ├── parrots_cpp_helper.hpp
│ ├── parrots_cuda_helper.hpp
│ ├── pytorch_cpp_helper.hpp
│ ├── pytorch_cuda_helper.hpp
│   └── cuda
│   ├── common_cuda_helper.hpp
│   ├── parrots_cudawarpfunction.cuh
│   ├── ...
│   └── ops_cuda_kernel.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│ ├── onnxruntime_register.cpp
│      ├── ...
│      └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── info.cpp
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│      ├── ...
│      └── ops_cuda.cu
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
├── trt_plugin.hpp
├── trt_serialize.hpp
├── ...
├── trt_ops.hpp
└── plugins
   ├── trt_cuda_helper.cu
   ├── trt_plugin.cpp
   ├── ...
   ├── trt_ops.cpp
   └── trt_ops_kernel.cu
```

View File

@ -10,6 +10,7 @@
deployment.rst
understand_mmcv.rst
api.rst
compatibility.md
faq.md
community.rst

View File

@ -0,0 +1,169 @@
# Code Structure of CUDA operators
This folder contains all non-python code for MMCV custom ops. Please follow the same architecture if you want to add new ops.
## Directories Tree
```folder
.
├── common
│ ├── box_iou_rotated_utils.hpp
│ ├── parrots_cpp_helper.hpp
│ ├── parrots_cuda_helper.hpp
│ ├── pytorch_cpp_helper.hpp
│ ├── pytorch_cuda_helper.hpp
│   └── cuda
│   ├── common_cuda_helper.hpp
│   ├── parrots_cudawarpfunction.cuh
│   ├── ...
│   └── ops_cuda_kernel.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│ ├── onnxruntime_register.cpp
│      ├── ...
│      └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── info.cpp
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│      ├── ...
│      └── ops_cuda.cu
└── tensorrt
├── trt_cuda_helper.cuh
├── trt_plugin_helper.hpp
├── trt_plugin.hpp
├── trt_serialize.hpp
├── ...
├── trt_ops.hpp
└── plugins
   ├── trt_cuda_helper.cu
   ├── trt_plugin.cpp
   ├── ...
   ├── trt_ops.cpp
   └── trt_ops_kernel.cu
```
## Components
- `common`: This directory contains all tools and shared codes.
- `cuda`: The cuda kernels which can be shared by all backends. **HIP** kernel is also here since they have similar syntax.
- `onnxruntime`: **ONNX Runtime** support for custom ops.
- `cpu`: CPU implementation of supported ops.
- `parrots`: **Parrots** is a deep learning frame for model training and inference. Parrots custom ops are placed in this directory.
- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The ops implementation and binding codes are placed in this directory.
- `cuda`: This directory contains cuda kernel launchers, which feed memory pointers of tensor to the cuda kernel in `common/cuda`. The launchers provide c++ interface of cuda implementation of corresponding custom ops.
- `tensorrt`: **TensorRT** support for custom ops.
- `plugins`: This directory contains the implementation of the supported custom ops. Some ops might also use shared cuda kernel in `common/cuda`.
## How to add new PyTorch ops?
1. (Optional) Add shared kernel in `common` to support special hardware platform.
```c++
// src/common/cuda/new_ops_cuda_kernel.cuh
template <typename T>
__global__ void new_ops_forward_cuda_kernel(const T* input, T* output, ...) {
// forward here
}
```
Add cuda kernel launcher in `pytorch/cuda`.
```c++
// src/pytorch/cuda
#include <new_ops_cuda_kernel.cuh>
void NewOpsForwardCUDAKernelLauncher(Tensor input, Tensor output, ...){
// initialize
at::cuda::CUDAGuard device_guard(input.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "new_ops_forward_cuda_kernel", ([&] {
new_ops_forward_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),...);
}));
AT_CUDA_CHECK(cudaGetLastError());
}
```
2. Add ops implementation in `pytorch` directory. Select different implementations according to device type.
```c++
// src/pytorch/new_ops.cpp
#ifdef MMCV_WITH_CUDA
Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){
// implement cuda forward here
// use `NewOpsForwardCUDAKernelLauncher` here
}
#else
Tensor new_ops_forward_cpu(Tensor input, Tensor output, ...){
// implement cpu forward here
}
...
Tensor new_ops_forward(Tensor input, Tensor output, ...){
// select implementation by input device type
if (boxes.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(output);
return new_ops_forward_cuda(input, output, ...);
#else
AT_ERROR("new ops is not compiled with GPU support");
#endif
} else {
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(output);
return new_ops_forward_cpu(input, output, ...);
}
}
```
3. Binding the implementation in `pytorch/pybind.cpp`
```c++
// src/pytorch/pybind.cpp
...
Tensor new_ops_forward(Tensor input, Tensor output, ...);
...
// bind with pybind11
m.def("new_ops_forward", &new_ops_forward, "new_ops_forward",
py::arg("input"), py::arg("output"), ...);
...
```
4. Build MMCV again. Enjoy new ops in python
```python
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['new_ops_forward'])
...
ext_module.new_ops_forward(input, output, ...)
```

View File

@ -1,109 +0,0 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(
scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
const scalar_t *p_ref, int act, int grad, scalar_t alpha, scalar_t scale,
int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
// act = 1: linear layer
// act = 3: leaky relu layer
// grad = 0: direct forward path
// grad = 1: first order deviation
// grad = 2: second order deviation
switch (act * 10 + grad) {
default:
case 10:
y = x;
break;
case 11:
y = x;
break;
case 12:
y = 0.0;
break;
case 30:
y = (x > 0.0) ? x : x * alpha;
break;
case 31:
y = (ref > 0.0) ? x : x * alpha;
break;
case 32:
y = 0.0;
break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_leakyrelu_op(const torch::Tensor &input,
const torch::Tensor &bias,
const torch::Tensor &refer, int act,
int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
});
return y;
}

View File

@ -1,360 +0,0 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from
*https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int64_t *data_spatial_shapes,
const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc,
const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size,
const int num_heads, const int channels,
const int num_levels, const int num_query,
const int num_point, scalar_t *data_col) {
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
const int num_threads = CUDA_NUM_THREADS;
ms_deformable_im2col_gpu_kernel<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0, stream>>>(
num_kernels, data_value, data_spatial_shapes, data_level_start_index,
data_sampling_loc, data_attn_weight, batch_size, spatial_size,
num_heads, channels, num_levels, num_query, num_point, data_col);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
template <typename scalar_t>
void ms_deformable_col2im_cuda(
cudaStream_t stream, const scalar_t *grad_col, const scalar_t *data_value,
const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
const scalar_t *data_sampling_loc, const scalar_t *data_attn_weight,
const int batch_size, const int spatial_size, const int num_heads,
const int channels, const int num_levels, const int num_query,
const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
scalar_t *grad_attn_weight) {
const int num_threads =
(channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
const int num_kernels = batch_size * num_query * num_heads * channels;
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
if (channels > 1024) {
if ((channels & 1023) == 0) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
}
} else {
switch (channels) {
case 1:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
1>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 2:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
2>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 4:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
4>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 8:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
8>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 16:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
16>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 32:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t,
32>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 64:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
64>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 128:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
128>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 256:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
256>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 512:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
512>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
case 1024:
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t,
1024>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
stream>>>(num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc,
data_attn_weight, batch_size, spatial_size, num_heads,
channels, num_levels, num_query, num_point, grad_value,
grad_sampling_loc, grad_attn_weight);
break;
default:
if (channels < 64) {
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
} else {
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
num_threads * 3 * sizeof(scalar_t), stream>>>(
num_kernels, grad_col, data_value, data_spatial_shapes,
data_level_start_index, data_sampling_loc, data_attn_weight,
batch_size, spatial_size, num_heads, channels, num_levels,
num_query, num_point, grad_value, grad_sampling_loc,
grad_attn_weight);
}
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
auto output =
at::zeros({batch, num_query, num_heads, channels}, value.options());
const int batch_n = im2col_step_;
auto output_n = output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
for (int n = 0; n < batch / im2col_step_; ++n) {
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(
at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point, columns.data<scalar_t>());
}));
}
output = output.view({batch, num_query, num_heads * channels});
return output;
}
void ms_deform_attn_cuda_backward(
const at::Tensor &value, const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
const at::Tensor &attn_weight, const at::Tensor &grad_output,
at::Tensor &grad_value, at::Tensor &grad_sampling_loc,
at::Tensor &grad_attn_weight, const int im2col_step) {
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(spatial_shapes.is_contiguous(),
"spatial_shapes tensor has to be contiguous");
AT_ASSERTM(level_start_index.is_contiguous(),
"level_start_index tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(),
"sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(),
"attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(),
"grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(),
"spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(),
"level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(),
"sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
const int batch = value.size(0);
const int spatial_size = value.size(1);
const int num_heads = value.size(2);
const int channels = value.size(3);
const int num_levels = spatial_shapes.size(0);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(4);
const int im2col_step_ = std::min(batch, im2col_step);
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)",
batch, im2col_step_);
const int batch_n = im2col_step_;
auto per_value_size = spatial_size * num_heads * channels;
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
auto grad_output_n = grad_output.view(
{batch / im2col_step_, batch_n, num_query, num_heads, channels});
for (int n = 0; n < batch / im2col_step_; ++n) {
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(
value.type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(
at::cuda::getCurrentCUDAStream(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query,
num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() +
n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() +
n * im2col_step_ * per_attn_weight_size);
}));
}
}

View File

@ -1,370 +0,0 @@
// from
// https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 6:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
default:
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
}
});
return out;
}

View File

@ -1,25 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#include "box_iou_rotated_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
using scalar_t = float;
AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
int output_size = ious.numel();
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_rotated_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -15,9 +15,10 @@
#include <cuda_runtime.h>
#include <THC/THCAtomics.cuh>
#include <ms_deform_attn_cuda_kernel.cuh>
#include <vector>
#include "ms_deform_attn_cuda_kernel.cuh"
template <typename scalar_t>
void ms_deformable_im2col_cuda(cudaStream_t stream, const scalar_t *data_value,
const int64_t *data_spatial_shapes,

View File

@ -1,61 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
#include "nms_rotated_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores,
const Tensor order_t, const Tensor dets_sorted,
float iou_threshold, const int multi_label) {
// using scalar_t = float;
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
int dets_num = dets.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);
Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.type(), "nms_rotated_kernel_cuda", [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num, iou_threshold, dets_sorted.data<scalar_t>(),
(unsigned long long*)mask.data<int64_t>(), multi_label);
});
Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}

View File

@ -4,7 +4,7 @@
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include "../pytorch_cpp_helper.hpp"
#include "pytorch_cpp_helper.hpp"
// implementation taken from Caffe2
template <typename T>

View File

@ -4,7 +4,7 @@
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include "../pytorch_cpp_helper.hpp"
#include "pytorch_cpp_helper.hpp"
// implementation taken from Caffe2
template <typename T>

BIN
model.pth 100644

Binary file not shown.

BIN
modelA.pth 100644

Binary file not shown.

View File

@ -145,11 +145,10 @@ def get_extensions():
library_dirs += [tensorrt_lib_path]
libraries += ['nvinfer', 'nvparsers', 'nvinfer_plugin']
libraries += ['cudart']
kwargs = {}
define_macros = []
extra_compile_args = {'cxx': []}
include_path = os.path.abspath('./mmcv/ops/csrc')
include_path = os.path.abspath('./mmcv/ops/csrc/common/cuda')
include_trt_path = os.path.abspath('./mmcv/ops/csrc/tensorrt')
include_dirs.append(include_path)
include_dirs.append(include_trt_path)
@ -163,9 +162,6 @@ def get_extensions():
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
library_dirs += library_paths(cuda=True)
kwargs['library_dirs'] = library_dirs
kwargs['libraries'] = libraries
from setuptools import Extension
ext_ops = Extension(
name=ext_name,
@ -187,9 +183,11 @@ def get_extensions():
# new parrots op impl do not use MMCV_USE_PARROTS
# define_macros = [('MMCV_USE_PARROTS', None)]
define_macros = []
op_files = glob.glob('./mmcv/ops/csrc/parrots/*.cu') +\
include_dirs = []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') +\
glob.glob('./mmcv/ops/csrc/parrots/*.cpp')
include_dirs = [os.path.abspath('./mmcv/ops/csrc')]
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args = {
'nvcc': [cuda_args] if cuda_args else [],
@ -219,6 +217,7 @@ def get_extensions():
os.environ.setdefault('MAX_JOBS', '4')
define_macros = []
extra_compile_args = {'cxx': []}
include_dirs = []
is_rocm_pytorch = False
if parse_version(torch.__version__) >= parse_version('1.5'):
@ -226,13 +225,13 @@ def get_extensions():
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
this_dir = 'mmcv/ops/csrc/'
project_dir = 'mmcv/ops/csrc/'
if is_rocm_pytorch:
from torch.utils.hipify import hipify_python
hipify_python.hipify(
project_directory=this_dir,
output_directory=this_dir,
project_directory=project_dir,
output_directory=project_dir,
includes='mmcv/ops/csrc/*',
show_detailed=True,
is_pytorch_extension=True,
@ -243,25 +242,26 @@ def get_extensions():
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/hip/*')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc/hip')
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/hip'))
elif torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('MMCV_WITH_CUDA', None)]
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \
glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu')
extension = CUDAExtension
include_path = os.path.abspath('./mmcv/ops/csrc')
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
else:
print(f'Compiling {ext_name} without CUDA')
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp')
extension = CppExtension
include_path = os.path.abspath('./mmcv/ops/csrc')
include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
ext_ops = extension(
name=ext_name,
sources=op_files,
include_dirs=[include_path],
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)
@ -276,7 +276,6 @@ def get_extensions():
ort_path = os.getenv('ONNXRUNTIME_DIR', '0')
library_dirs += [os.path.join(ort_path, 'lib')]
libraries.append('onnxruntime')
kwargs = {}
define_macros = []
extra_compile_args = {'cxx': []}
@ -297,9 +296,6 @@ def get_extensions():
include_dirs += include_paths(cuda=False)
library_dirs += library_paths(cuda=False)
kwargs['library_dirs'] = library_dirs
kwargs['libraries'] = libraries
from setuptools import Extension
ext_ops = Extension(
name=ext_name,