mirror of https://github.com/open-mmlab/mmcv.git
177 lines
6.0 KiB
Markdown
177 lines
6.0 KiB
Markdown
### v1.3.18
|
||
|
||
部分自定义算子对于不同的设备有不同实现,为此添加的大量宏命令与类型检查使得代码变得难以维护。例如:
|
||
|
||
```c++
|
||
if (input.device().is_cuda()) {
|
||
#ifdef MMCV_WITH_CUDA
|
||
CHECK_CUDA_INPUT(input);
|
||
CHECK_CUDA_INPUT(rois);
|
||
CHECK_CUDA_INPUT(output);
|
||
CHECK_CUDA_INPUT(argmax_y);
|
||
CHECK_CUDA_INPUT(argmax_x);
|
||
|
||
roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
|
||
aligned_height, aligned_width, spatial_scale,
|
||
sampling_ratio, pool_mode, aligned);
|
||
#else
|
||
AT_ERROR("RoIAlign is not compiled with GPU support");
|
||
#endif
|
||
} else {
|
||
CHECK_CPU_INPUT(input);
|
||
CHECK_CPU_INPUT(rois);
|
||
CHECK_CPU_INPUT(output);
|
||
CHECK_CPU_INPUT(argmax_y);
|
||
CHECK_CPU_INPUT(argmax_x);
|
||
roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
|
||
aligned_height, aligned_width, spatial_scale,
|
||
sampling_ratio, pool_mode, aligned);
|
||
}
|
||
```
|
||
|
||
为此我们设计了注册与分发的机制以更好的管理这些算子实现。
|
||
|
||
```c++
|
||
|
||
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
|
||
Tensor argmax_y, Tensor argmax_x,
|
||
int aligned_height, int aligned_width,
|
||
float spatial_scale, int sampling_ratio,
|
||
int pool_mode, bool aligned);
|
||
|
||
void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
|
||
Tensor argmax_y, Tensor argmax_x,
|
||
int aligned_height, int aligned_width,
|
||
float spatial_scale, int sampling_ratio,
|
||
int pool_mode, bool aligned) {
|
||
ROIAlignForwardCUDAKernelLauncher(
|
||
input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
|
||
spatial_scale, sampling_ratio, pool_mode, aligned);
|
||
}
|
||
|
||
// 注册算子的cuda实现
|
||
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
|
||
Tensor argmax_y, Tensor argmax_x,
|
||
int aligned_height, int aligned_width,
|
||
float spatial_scale, int sampling_ratio,
|
||
int pool_mode, bool aligned);
|
||
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);
|
||
|
||
// roi_align.cpp
|
||
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
|
||
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
|
||
Tensor argmax_y, Tensor argmax_x,
|
||
int aligned_height, int aligned_width,
|
||
float spatial_scale, int sampling_ratio,
|
||
int pool_mode, bool aligned) {
|
||
DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
|
||
argmax_x, aligned_height, aligned_width, spatial_scale,
|
||
sampling_ratio, pool_mode, aligned);
|
||
}
|
||
|
||
```
|
||
|
||
### v1.3.11
|
||
|
||
为了灵活地支持更多的后端和硬件,例如 `NVIDIA GPUs` 、`AMD GPUs`,我们重构了 `mmcv/ops/csrc` 目录。注意,这次重构不会影响 API 的使用。更多相关信息,请参考 [PR1206](https://github.com/open-mmlab/mmcv/pull/1206)。
|
||
|
||
原始的目录结构如下所示
|
||
|
||
```
|
||
.
|
||
├── common_cuda_helper.hpp
|
||
├── ops_cuda_kernel.cuh
|
||
├── pytorch_cpp_helper.hpp
|
||
├── pytorch_cuda_helper.hpp
|
||
├── parrots_cpp_helper.hpp
|
||
├── parrots_cuda_helper.hpp
|
||
├── parrots_cudawarpfunction.cuh
|
||
├── onnxruntime
|
||
│ ├── onnxruntime_register.h
|
||
│ ├── onnxruntime_session_options_config_keys.h
|
||
│ ├── ort_mmcv_utils.h
|
||
│ ├── ...
|
||
│ ├── onnx_ops.h
|
||
│ └── cpu
|
||
│ ├── onnxruntime_register.cpp
|
||
│ ├── ...
|
||
│ └── onnx_ops_impl.cpp
|
||
├── parrots
|
||
│ ├── ...
|
||
│ ├── ops.cpp
|
||
│ ├── ops_cuda.cu
|
||
│ ├── ops_parrots.cpp
|
||
│ └── ops_pytorch.h
|
||
├── pytorch
|
||
│ ├── ...
|
||
│ ├── ops.cpp
|
||
│ ├── ops_cuda.cu
|
||
│ ├── pybind.cpp
|
||
└── tensorrt
|
||
├── trt_cuda_helper.cuh
|
||
├── trt_plugin_helper.hpp
|
||
├── trt_plugin.hpp
|
||
├── trt_serialize.hpp
|
||
├── ...
|
||
├── trt_ops.hpp
|
||
└── plugins
|
||
├── trt_cuda_helper.cu
|
||
├── trt_plugin.cpp
|
||
├── ...
|
||
├── trt_ops.cpp
|
||
└── trt_ops_kernel.cu
|
||
```
|
||
|
||
重构之后,它的结构如下所示
|
||
|
||
```
|
||
.
|
||
├── common
|
||
│ ├── box_iou_rotated_utils.hpp
|
||
│ ├── parrots_cpp_helper.hpp
|
||
│ ├── parrots_cuda_helper.hpp
|
||
│ ├── pytorch_cpp_helper.hpp
|
||
│ ├── pytorch_cuda_helper.hpp
|
||
│ └── cuda
|
||
│ ├── common_cuda_helper.hpp
|
||
│ ├── parrots_cudawarpfunction.cuh
|
||
│ ├── ...
|
||
│ └── ops_cuda_kernel.cuh
|
||
├── onnxruntime
|
||
│ ├── onnxruntime_register.h
|
||
│ ├── onnxruntime_session_options_config_keys.h
|
||
│ ├── ort_mmcv_utils.h
|
||
│ ├── ...
|
||
│ ├── onnx_ops.h
|
||
│ └── cpu
|
||
│ ├── onnxruntime_register.cpp
|
||
│ ├── ...
|
||
│ └── onnx_ops_impl.cpp
|
||
├── parrots
|
||
│ ├── ...
|
||
│ ├── ops.cpp
|
||
│ ├── ops_parrots.cpp
|
||
│ └── ops_pytorch.h
|
||
├── pytorch
|
||
│ ├── info.cpp
|
||
│ ├── pybind.cpp
|
||
│ ├── ...
|
||
│ ├── ops.cpp
|
||
│ └── cuda
|
||
│ ├── ...
|
||
│ └── ops_cuda.cu
|
||
└── tensorrt
|
||
├── trt_cuda_helper.cuh
|
||
├── trt_plugin_helper.hpp
|
||
├── trt_plugin.hpp
|
||
├── trt_serialize.hpp
|
||
├── ...
|
||
├── trt_ops.hpp
|
||
└── plugins
|
||
├── trt_cuda_helper.cu
|
||
├── trt_plugin.cpp
|
||
├── ...
|
||
├── trt_ops.cpp
|
||
└── trt_ops_kernel.cu
|
||
```
|