fix cpp header error (#371)

* 1. use macro USE_PARROTS control header include
2. add clang-format google style in pre-commit

* use MMCV_ macros
pull/372/head
zhuyuanhao 2020-06-29 18:48:50 +08:00 committed by GitHub
parent 2c6fc5fd9b
commit d9549fba04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 216 additions and 266 deletions

View File

@ -33,6 +33,6 @@ repos:
- id: clang-format
name: clang-format
description: Format files with ClangFormat
entry: clang-format -i
entry: clang-format -style=google -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$

View File

@ -1,6 +1,12 @@
#ifndef BBOX_OVERLAPS_CUDA_KERNEL_CUH
#define BBOX_OVERLAPS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
T* ious, const int num_bbox1,
@ -73,4 +79,5 @@ __global__ void bbox_overlaps_cuda_kernel(const T* bbox1, const T* bbox2,
}
}
}
#endif
#endif // BBOX_OVERLAPS_CUDA_KERNEL_CUH

View File

@ -1,3 +1,12 @@
#ifndef CARAFE_CUDA_KERNEL_CUH
#define CARAFE_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define WARP_SIZE 32
#define THREADS_PER_PIXEL 32
#define MAX_SHARED_MEMORY 49152
@ -301,3 +310,5 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
mask_diff[mask_id] = output_val;
}
}
#endif // CARAFE_CUDA_KERNEL_CUH

View File

@ -1,6 +1,12 @@
#ifndef CARAFE_NAIVE_CUDA_KERNEL_CUH
#define CARAFE_NAIVE_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
__device__ inline int Loc2Index(const int n, const int c, const int h,
const int w, const int channel_num,
const int height, const int width) {
@ -101,4 +107,4 @@ __global__ void carafe_naive_backward_cuda_kernel(
}
}
#endif
#endif // CARAFE_NAIVE_CUDA_KERNEL_CUH

View File

@ -1,5 +1,11 @@
#ifndef CA_CUDA_KERNEL_CUH
#define CA_CUDA_KERNEL_CUH
#ifndef CC_ATTENTION_CUDA_KERNEL_CUH
#define CC_ATTENTION_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
@ -176,4 +182,4 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
}
}
#endif
#endif // CC_ATTENTION_CUDA_KERNEL_CUH

View File

@ -63,8 +63,14 @@
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#ifndef DEFORM_CONV_KERNEL_CUH
#define DEFORM_CONV_KERNEL_CUH
#ifndef DEFORM_CONV_CUDA_KERNEL_CUH
#define DEFORM_CONV_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__device__ T deformable_im2col_bilinear(const T *input, const int data_width,
@ -353,4 +359,4 @@ __global__ void deformable_col2im_coord_gpu_kernel(
}
}
#endif
#endif // DEFORM_CONV_CUDA_KERNEL_CUH

View File

@ -1,5 +1,11 @@
#ifndef DEFORM_POOL_KERNEL_CUH
#define DEFORM_POOL_KERNEL_CUH
#ifndef DEFORM_ROI_POOL_CUDA_KERNEL_CUH
#define DEFORM_ROI_POOL_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void deform_roi_pool_forward_cuda_kernel(
@ -174,4 +180,4 @@ __global__ void deform_roi_pool_backward_cuda_kernel(
}
}
#endif
#endif // DEFORM_ROI_POOL_CUDA_KERNEL_CUH

View File

@ -1,3 +1,12 @@
#ifndef MASKED_CONV2D_CUDA_KERNEL_CUH
#define MASKED_CONV2D_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename scalar_t>
__global__ void MaskedIm2colForward(const int n, const scalar_t *data_im,
const int height, const int width,
@ -48,3 +57,5 @@ __global__ void MaskedCol2imForward(const int n, const scalar_t *data_col,
data_im[(c_im * height + h_im) * width + w_im] = data_col[index];
}
}
#endif // MASKED_CONV2D_CUDA_KERNEL_CUH

View File

@ -63,8 +63,14 @@
// modified from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#ifndef MODULATED_DEFORM_CONV_KERNEL_CUH
#define MODULATED_DEFORM_CONV_KERNEL_CUH
#ifndef MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
#define MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__device__ T dmcn_im2col_bilinear(const T *input, const int data_width,
@ -385,4 +391,4 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(
}
}
#endif
#endif // MODULATED_DEFORM_CONV_CUDA_KERNEL_CUH

View File

@ -1,5 +1,11 @@
#ifndef NMS_KERNEL_CUH
#define NMS_KERNEL_CUH
#ifndef NMS_CUDA_KERNEL_CUH
#define NMS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long int) * 8;
@ -60,4 +66,4 @@ __global__ void nms_cuda(const int n_boxes, const float iou_threshold,
dev_mask[cur_box_idx * gridDim.y + col_start] = t;
}
}
#endif
#endif // NMS_CUDA_KERNEL_CUH

View File

@ -406,7 +406,7 @@ void DeformConvBackwardInputCUDAKernelLauncher(
}
}
int DeformConvBackwardParametersCUDAKernelLauncher(
void DeformConvBackwardParametersCUDAKernelLauncher(
DArrayLite input, DArrayLite offset, DArrayLite gradOutput,
DArrayLite gradWeight, DArrayLite columns, DArrayLite ones, int kW, int kH,
int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group,

View File

@ -1,6 +1,6 @@
#include "parrots_cuda_helper.hpp"
#include "sigmoid_focal_loss_kernel.cuh"
#include "softmax_focal_loss_kernel.cuh"
#include "sigmoid_focal_loss_cuda_kernel.cuh"
#include "softmax_focal_loss_cuda_kernel.cuh"
void SigmoidFocalLossForwardCUDAKernelLauncher(
const DArrayLite input, const DArrayLite target, const DArrayLite weight,

View File

@ -1,4 +1,4 @@
#include "nms_kernel.cuh"
#include "nms_cuda_kernel.cuh"
#include "parrots_cuda_helper.hpp"
DArrayLite NMSCUDAKernelLauncher(const DArrayLite boxes_sorted,

View File

@ -1,5 +1,5 @@
#include "parrots_cuda_helper.hpp"
#include "roi_align_kernel.cuh"
#include "roi_align_cuda_kernel.cuh"
void ROIAlignForwardCUDAKernelLauncher(const DArrayLite input,
const DArrayLite rois, DArrayLite output,

View File

@ -1,5 +1,5 @@
#include "parrots_cuda_helper.hpp"
#include "roi_pool_kernel.cuh"
#include "roi_pool_cuda_kernel.cuh"
void ROIPoolForwardCUDAKernelLauncher(const DArrayLite input,
const DArrayLite rois, DArrayLite output,

View File

@ -1,5 +1,12 @@
#ifndef PSAMASK_CUDA_CUH
#define PSAMASK_CUDA_CUH
#ifndef PSAMASK_CUDA_KERNEL_CUH
#define PSAMASK_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
// CUDA: grid stride looping
#ifndef CUDA_KERNEL_LOOP
#define CUDA_KERNEL_LOOP(i, n) \
@ -130,4 +137,4 @@ __global__ void psamask_distribute_backward_cuda(
}
}
#endif
#endif // PSAMASK_CUDA_KERNEL_CUH

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,
const bool aligned, const int offset);
@ -14,7 +14,7 @@ void bbox_overlaps_cuda(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset) {
if (bboxes1.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(bboxes1);
CHECK_CUDA_INPUT(bboxes2);
CHECK_CUDA_INPUT(ious);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
Tensor rfeatures, Tensor routput,
Tensor rmasks, Tensor output,
@ -38,7 +38,7 @@ void carafe_forward(Tensor features, Tensor masks, Tensor rfeatures,
Tensor routput, Tensor rmasks, Tensor output,
int kernel_size, int group_size, int scale_factor) {
if (features.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(features);
CHECK_CUDA_INPUT(masks);
CHECK_CUDA_INPUT(rfeatures);
@ -61,7 +61,7 @@ void carafe_backward(Tensor top_grad, Tensor rfeatures, Tensor masks,
Tensor mask_grad, int kernel_size, int group_size,
int scale_factor) {
if (top_grad.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(top_grad);
CHECK_CUDA_INPUT(rfeatures);
CHECK_CUDA_INPUT(masks);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void CARAFENAIVEForwardCUDAKernelLauncher(const Tensor features,
const Tensor masks, Tensor output,
const int kernel_size,
@ -32,7 +32,7 @@ void carafe_naive_backward_cuda(Tensor top_grad, Tensor features, Tensor masks,
void carafe_naive_forward(Tensor features, Tensor masks, Tensor output,
int kernel_size, int group_size, int scale_factor) {
if (features.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(features);
CHECK_CUDA_INPUT(masks);
CHECK_CUDA_INPUT(output);
@ -50,7 +50,7 @@ void carafe_naive_backward(Tensor top_grad, Tensor features, Tensor masks,
Tensor bottom_grad, Tensor mask_grad,
int kernel_size, int group_size, int scale_factor) {
if (top_grad.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(top_grad);
CHECK_CUDA_INPUT(features);
CHECK_CUDA_INPUT(masks);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f, Tensor weight);
void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
@ -33,7 +33,7 @@ void ca_map_backward_cuda(const Tensor dout, const Tensor weight,
void ca_forward(const Tensor t, const Tensor f, Tensor weight) {
if (t.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
CHECK_CUDA_INPUT(weight);
@ -49,7 +49,7 @@ void ca_forward(const Tensor t, const Tensor f, Tensor weight) {
void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
Tensor df) {
if (dw.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dw);
CHECK_CUDA_INPUT(t);
CHECK_CUDA_INPUT(f);
@ -66,7 +66,7 @@ void ca_backward(const Tensor dw, const Tensor t, const Tensor f, Tensor dt,
void ca_map_forward(const Tensor weight, const Tensor g, Tensor out) {
if (weight.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);
CHECK_CUDA_INPUT(out);
@ -82,7 +82,7 @@ void ca_map_forward(const Tensor weight, const Tensor g, Tensor out) {
void ca_map_backward(const Tensor dout, const Tensor weight, const Tensor g,
Tensor dw, Tensor dg) {
if (dout.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(dout);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(g);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void DeformConvForwardCUDAKernelLauncher(Tensor input, Tensor weight,
Tensor offset, Tensor output,
Tensor columns, Tensor ones, int kW,
@ -62,7 +62,7 @@ void deform_conv_forward(Tensor input, Tensor weight, Tensor offset,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(offset);
CHECK_CUDA_INPUT(weight);
@ -88,7 +88,7 @@ void deform_conv_backward_input(Tensor input, Tensor offset, Tensor gradOutput,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(offset);
CHECK_CUDA_INPUT(gradOutput);
@ -117,7 +117,7 @@ void deform_conv_backward_parameters(Tensor input, Tensor offset,
int deformable_group, float scale,
int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(offset);
CHECK_CUDA_INPUT(gradOutput);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void DeformRoIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois,
Tensor offset, Tensor output,
int pooled_height, int pooled_width,
@ -38,7 +38,7 @@ void deform_roi_pool_forward(Tensor input, Tensor rois, Tensor offset,
float spatial_scale, int sampling_ratio,
float gamma) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(offset);
@ -61,7 +61,7 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma) {
if (grad_output.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
Tensor weight, Tensor output,
const float gamma,
@ -54,7 +54,7 @@ void softmax_focal_loss_backward_cuda(Tensor input, Tensor target,
void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(target);
CHECK_CUDA_INPUT(weight);
@ -73,7 +73,7 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma, float alpha) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(target);
CHECK_CUDA_INPUT(weight);
@ -92,7 +92,7 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(target);
CHECK_CUDA_INPUT(weight);
@ -112,7 +112,7 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
Tensor buff, Tensor grad_input, float gamma,
float alpha) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(target);
CHECK_CUDA_INPUT(weight);

View File

@ -1,6 +1,6 @@
#include "pytorch_cuda_helper.hpp"
#include "sigmoid_focal_loss_kernel.cuh"
#include "softmax_focal_loss_kernel.cuh"
#include "sigmoid_focal_loss_cuda_kernel.cuh"
#include "softmax_focal_loss_cuda_kernel.cuh"
void SigmoidFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
Tensor weight, Tensor output,

View File

@ -2,13 +2,13 @@
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/vision.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
std::string get_compiling_cuda_version() {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void MaskedIm2colForwardCUDAKernelLauncher(const Tensor bottom_data,
const Tensor mask_h_idx,
const Tensor mask_w_idx,
@ -39,7 +39,7 @@ void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w) {
if (im.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(im);
CHECK_CUDA_INPUT(mask_h_idx);
CHECK_CUDA_INPUT(mask_w_idx);
@ -58,7 +58,7 @@ void masked_col2im_forward(const Tensor col, const Tensor mask_h_idx,
const Tensor mask_w_idx, Tensor im, int height,
int width, int channels) {
if (col.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(col);
CHECK_CUDA_INPUT(mask_h_idx);
CHECK_CUDA_INPUT(mask_w_idx);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void ModulatedDeformConvForwardCUDAKernelLauncher(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
@ -50,7 +50,7 @@ void modulated_deform_conv_forward(
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(bias);
@ -80,7 +80,7 @@ void modulated_deform_conv_backward(
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(bias);

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,
int offset);
@ -62,7 +62,7 @@ Tensor nms_cpu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
if (boxes.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes);
CHECK_CUDA_INPUT(scores);
return nms_cuda(boxes, scores, iou_threshold, offset);

View File

@ -1,4 +1,4 @@
#include "nms_kernel.cuh"
#include "nms_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor NMSCUDAKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold,

View File

@ -182,7 +182,7 @@ void psamask_backward_cpu(const int psa_type, const Tensor grad_output,
grad_input);
}
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void PSAMaskForwardCUDAKernelLauncher(const int psa_type, const Tensor input,
Tensor output, const int num_,
const int h_feature, const int w_feature,
@ -221,7 +221,7 @@ void psamask_forward(const Tensor input, Tensor output, const int psa_type,
const int h_mask, const int w_mask, const int half_h_mask,
const int half_w_mask) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(output);
psamask_forward_cuda(psa_type, input, output, num_, h_feature, w_feature,
@ -240,7 +240,7 @@ void psamask_backward(Tensor grad_output, const Tensor grad_input,
const int w_feature, const int h_mask, const int w_mask,
const int half_h_mask, const int half_w_mask) {
if (grad_input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_input);
CHECK_CUDA_INPUT(grad_output);
psamask_backward_cuda(psa_type, grad_output, grad_input, num_, h_feature,

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,
int aligned_height, int aligned_width,
@ -40,7 +40,7 @@ void roi_align_forward(Tensor input, Tensor rois, Tensor output,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(output);
@ -63,7 +63,7 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y,
int aligned_width, float spatial_scale,
int sampling_ratio, int pool_mode, bool aligned) {
if (grad_output.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(argmax_y);

View File

@ -1,5 +1,5 @@
#include "pytorch_cuda_helper.hpp"
#include "roi_align_kernel.cuh"
#include "roi_align_cuda_kernel.cuh"
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x,

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void ROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,
int pooled_width, float spatial_scale);
@ -29,7 +29,7 @@ void roi_pool_forward(Tensor input, Tensor rois, Tensor output, Tensor argmax,
int pooled_height, int pooled_width,
float spatial_scale) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(output);
@ -49,7 +49,7 @@ void roi_pool_backward(Tensor grad_output, Tensor rois, Tensor argmax,
Tensor grad_input, int pooled_height, int pooled_width,
float spatial_scale) {
if (grad_output.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(rois);
CHECK_CUDA_INPUT(argmax);

View File

@ -1,5 +1,5 @@
#include "pytorch_cuda_helper.hpp"
#include "roi_pool_kernel.cuh"
#include "roi_pool_cuda_kernel.cuh"
void ROIPoolForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
Tensor argmax, int pooled_height,

View File

@ -1,6 +1,6 @@
#include "pytorch_cpp_helper.hpp"
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
void SyncBNForwardMeanCUDAKernelLauncher(const Tensor input, Tensor mean);
void SyncBNForwardVarCUDAKernelLauncher(const Tensor input, const Tensor mean,
@ -61,7 +61,7 @@ void sync_bn_backward_data_cuda(const Tensor grad_output, const Tensor weight,
void sync_bn_forward_mean(const Tensor input, Tensor mean) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(mean);
sync_bn_forward_mean_cuda(input, mean);
@ -75,7 +75,7 @@ void sync_bn_forward_mean(const Tensor input, Tensor mean) {
void sync_bn_forward_var(const Tensor input, const Tensor mean, Tensor var) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(mean);
CHECK_CUDA_INPUT(var);
@ -95,7 +95,7 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
Tensor output, float eps, float momentum,
int group_size) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(mean);
CHECK_CUDA_INPUT(var);
@ -120,7 +120,7 @@ void sync_bn_forward_output(const Tensor input, const Tensor mean,
void sync_bn_backward_param(const Tensor grad_output, const Tensor norm,
Tensor grad_weight, Tensor grad_bias) {
if (grad_output.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(norm);
CHECK_CUDA_INPUT(grad_weight);
@ -139,7 +139,7 @@ void sync_bn_backward_data(const Tensor grad_output, const Tensor weight,
const Tensor norm, const Tensor std,
Tensor grad_input) {
if (grad_output.device().is_cuda()) {
#ifdef WITH_CUDA
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(grad_output);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(grad_weight);

View File

@ -1,5 +1,11 @@
#ifndef ROI_ALIGN_KERNEL_CUH
#define ROI_ALIGN_KERNEL_CUH
#ifndef ROI_ALIGN_CUDA_KERNEL_CUH
#define ROI_ALIGN_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
/*** Forward ***/
template <typename T>
@ -196,4 +202,4 @@ __global__ void roi_align_backward_cuda_kernel(
}
}
#endif // ROI_ALIGN_KERNEL_CUH
#endif // ROI_ALIGN_CUDA_KERNEL_CUH

View File

@ -1,7 +1,11 @@
#ifndef ROI_POOL_KERNEL_CUH
#define ROI_POOL_KERNEL_CUH
#ifndef ROI_POOL_CUDA_KERNEL_CUH
#define ROI_POOL_CUDA_KERNEL_CUH
#include <cuda.h>
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void roi_pool_forward_cuda_kernel(
@ -85,4 +89,4 @@ __global__ void roi_pool_backward_cuda_kernel(
}
}
#endif
#endif // ROI_POOL_CUDA_KERNEL_CUH

View File

@ -1,5 +1,11 @@
#ifndef SIGMOID_FOCAL_LOSS_KERNEL_CUH
#define SIGMOID_FOCAL_LOSS_KERNEL_CUH
#ifndef SIGMOID_FOCAL_LOSS_CUDA_KERNEL_CUH
#define SIGMOID_FOCAL_LOSS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void sigmoid_focal_loss_forward_cuda_kernel(
@ -60,4 +66,5 @@ __global__ void sigmoid_focal_loss_backward_cuda_kernel(
}
}
}
#endif
#endif // SIGMOID_FOCAL_LOSS_CUDA_KERNEL_CUH

View File

@ -1,5 +1,11 @@
#ifndef SOFTMAX_FOCAL_LOSS_KERNEL_CUH
#define SOFTMAX_FOCAL_LOSS_KERNEL_CUH
#ifndef SOFTMAX_FOCAL_LOSS_CUDA_KERNEL_CUH
#define SOFTMAX_FOCAL_LOSS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void softmax_focal_loss_forward_cuda_kernel(
@ -61,4 +67,5 @@ __global__ void softmax_focal_loss_backward_cuda2_kernel(
}
}
}
#endif
#endif // SOFTMAX_FOCAL_LOSS_CUDA_KERNEL_CUH

View File

@ -1,160 +0,0 @@
#ifndef SOFTNMS_KERNEL_CUH
#define SOFTNMS_KERNEL_CUH
#include <cuda.h>
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
int const threadsPerBlock = sizeof(unsigned long long int) * 8;
template <typename scalar_t>
__device__ inline scalar_t devIoU(scalar_t const *const a,
scalar_t const *const b) {
scalar_t left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
scalar_t top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
scalar_t width = fmaxf(right - left + 1, 0.f),
height = fmaxf(bottom - top + 1, 0.f);
scalar_t interS = width * height;
scalar_t Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
scalar_t Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
return interS / (Sa + Sb - interS);
}
template <typename scalar_t>
__global__ void softnms_max_kernel(const int n_boxes,
const scalar_t overlap_thresh,
const scalar_t *dev_boxes, int *order,
float *max_value, int *max_index) {
__shared__ float maximum[threadsPerBlock];
__shared__ int max_id[threadsPerBlock];
unsigned int tid = threadIdx.x;
unsigned int idx = blockIdx.x * threadsPerBlock + threadIdx.x;
if (idx >= n_boxes) {
return;
}
const int block_size = fminf(n_boxes + tid - idx, threadsPerBlock);
int *l_order = order + (idx - tid);
if (l_order[tid] == 0 && dev_boxes[idx * 5 + 4] >= overlap_thresh) {
maximum[tid] = dev_boxes[idx * 5 + 4];
} else {
maximum[tid] = -1.0;
}
max_id[tid] = tid;
__syncthreads();
if (block_size >= 1024 && tid < 512) {
if (maximum[tid] < maximum[tid + 512]) {
maximum[tid] = maximum[tid + 512];
max_id[tid] = max_id[tid + 512];
}
}
if (block_size >= 512 && tid < 256) {
if (maximum[tid] < maximum[tid + 256]) {
maximum[tid] = maximum[tid + 256];
max_id[tid] = max_id[tid + 256];
}
}
if (block_size >= 256 && tid < 128) {
if (maximum[tid] < maximum[tid + 128]) {
maximum[tid] = maximum[tid + 128];
max_id[tid] = max_id[tid + 128];
}
}
if (block_size >= 128 && tid < 64) {
if (maximum[tid] < maximum[tid + 64]) {
maximum[tid] = maximum[tid + 64];
max_id[tid] = max_id[tid + 64];
}
}
if (tid < 32) {
volatile float *vmaximum = maximum;
volatile int *vmax_id = max_id;
if (block_size >= 64 && vmaximum[tid] < vmaximum[tid + 32]) {
vmaximum[tid] = vmaximum[tid + 32];
vmax_id[tid] = vmax_id[tid + 32];
}
if (block_size >= 32 && tid < 16 && vmaximum[tid] < vmaximum[tid + 16]) {
vmaximum[tid] = vmaximum[tid + 16];
vmax_id[tid] = vmax_id[tid + 16];
}
if (block_size >= 16 && tid < 8 && vmaximum[tid] < vmaximum[tid + 8]) {
vmaximum[tid] = vmaximum[tid + 8];
vmax_id[tid] = vmax_id[tid + 8];
}
if (block_size >= 8 && tid < 4 && vmaximum[tid] < vmaximum[tid + 4]) {
vmaximum[tid] = vmaximum[tid + 4];
vmax_id[tid] = vmax_id[tid + 4];
}
if (block_size >= 4 && tid < 2 && vmaximum[tid] < vmaximum[tid + 2]) {
vmaximum[tid] = vmaximum[tid + 2];
vmax_id[tid] = vmax_id[tid + 2];
}
if (block_size >= 2 && tid < 1 && vmaximum[tid] < vmaximum[tid + 1]) {
vmaximum[tid] = vmaximum[tid + 1];
vmax_id[tid] = vmax_id[tid + 1];
}
}
if (tid == 0) {
max_value[blockIdx.x] = maximum[0];
max_index[blockIdx.x] = max_id[0];
}
}
template <typename scalar_t>
__global__ void softnms_update_kernel(const int n_boxes, const scalar_t sigma,
const scalar_t n_thresh,
const unsigned int method,
const scalar_t overlap_thresh,
scalar_t *dev_boxes, int *order,
unsigned long long *keep, int max_id) {
const int col_start = blockIdx.x;
const int col_size =
fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
const int cur_idx = threadsPerBlock * col_start + threadIdx.x;
const int tid = threadIdx.x;
if (cur_idx >= n_boxes) {
return;
}
__shared__ scalar_t cur_max_boxes[5];
cur_max_boxes[0] = dev_boxes[max_id * 5 + 0];
cur_max_boxes[1] = dev_boxes[max_id * 5 + 1];
cur_max_boxes[2] = dev_boxes[max_id * 5 + 2];
cur_max_boxes[3] = dev_boxes[max_id * 5 + 3];
cur_max_boxes[4] = dev_boxes[max_id * 5 + 4];
__syncthreads();
if (cur_idx != max_id && tid < col_size && order[cur_idx] == 0 &&
(!(keep[col_start] & (1ULL << tid)))) {
scalar_t block_boxes[5];
block_boxes[0] = dev_boxes[cur_idx * 5 + 0];
block_boxes[1] = dev_boxes[cur_idx * 5 + 1];
block_boxes[2] = dev_boxes[cur_idx * 5 + 2];
block_boxes[3] = dev_boxes[cur_idx * 5 + 3];
block_boxes[4] = dev_boxes[cur_idx * 5 + 4];
scalar_t ovr = devIoU(cur_max_boxes, block_boxes);
scalar_t weight = 1.0;
if (method == 1) {
if (ovr > n_thresh) {
weight = 1.0 - ovr;
}
} else if (method == 2) {
weight = exp(-(ovr * ovr) / sigma);
} else if (ovr >= n_thresh) {
weight = 0.0;
}
block_boxes[4] *= weight;
dev_boxes[cur_idx * 5 + 4] = block_boxes[4];
if (block_boxes[4] < overlap_thresh) {
keep[col_start] |= 1ULL << tid;
}
}
}
#endif

View File

@ -1,5 +1,11 @@
#ifndef SYNC_BN_KERNEL_CUH
#define SYNC_BN_KERNEL_CUH
#ifndef SYNCBN_CUDA_KERNEL_CUH
#define SYNCBN_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
template <typename T>
__global__ void sync_bn_forward_mean_cuda_kernel(const T *input, float *mean,
@ -321,4 +327,4 @@ __global__ void sync_bn_backward_data_cuda_kernel(
}
}
#endif // SYNC_BN_KERNEL_CUH
#endif // SYNCBN_CUDA_KERNEL_CUH

View File

@ -150,22 +150,23 @@ def get_extensions():
try:
import torch
cuda_args = [
'-gencode=arch=compute_52,code=sm_52',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
ext_name = 'mmcv._ext'
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
cuda_args = [
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = [('MMCV_USE_PARROTS', None)]
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
include_path = os.path.abspath('./mmcv/ops/csrc')
ext_ops = Extension(
name=ext_name,
sources=op_files,
include_dirs=[include_path],
define_macros=define_macros,
extra_compile_args={
'nvcc': cuda_args,
'cxx': [],
@ -177,12 +178,19 @@ def get_extensions():
CUDAExtension, CppExtension)
# prevent ninja from using too many resources
os.environ.setdefault('MAX_JOBS', '4')
cuda_args = [
'-gencode=arch=compute_52,code=sm_52',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = []
extra_compile_args = {'cxx': []}
if (torch.cuda.is_available()
or os.getenv('FORCE_CUDA', '0') == '1'):
define_macros += [('WITH_CUDA', None)]
define_macros += [('MMCV_WITH_CUDA', None)]
extra_compile_args['nvcc'] = cuda_args
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
extension = CUDAExtension