### v1.3.18

部分自定义算子对于不同的设备有不同实现，为此添加的大量宏命令与类型检查使得代码变得难以维护。例如：

```c++
if (input.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
  CHECK_CUDA_INPUT(input);
  CHECK_CUDA_INPUT(rois);
  CHECK_CUDA_INPUT(output);
  CHECK_CUDA_INPUT(argmax_y);
  CHECK_CUDA_INPUT(argmax_x);

  roi_align_forward_cuda(input, rois, output, argmax_y, argmax_x,
                         aligned_height, aligned_width, spatial_scale,
                         sampling_ratio, pool_mode, aligned);
#else
  AT_ERROR("RoIAlign is not compiled with GPU support");
#endif
} else {
  CHECK_CPU_INPUT(input);
  CHECK_CPU_INPUT(rois);
  CHECK_CPU_INPUT(output);
  CHECK_CPU_INPUT(argmax_y);
  CHECK_CPU_INPUT(argmax_x);
  roi_align_forward_cpu(input, rois, output, argmax_y, argmax_x,
                        aligned_height, aligned_width, spatial_scale,
                        sampling_ratio, pool_mode, aligned);
}
```

为此我们设计了注册与分发的机制以更好地管理这些算子实现。

```c++
void ROIAlignForwardCUDAKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
                                       int aligned_height, int aligned_width,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned);

void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  ROIAlignForwardCUDAKernelLauncher(
      input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width,
      spatial_scale, sampling_ratio, pool_mode, aligned);
}

// 注册算子的cuda实现
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned);
REGISTER_DEVICE_IMPL(roi_align_forward_impl, CUDA, roi_align_forward_cuda);

// roi_align.cpp
// 使用dispatcher根据参数中的Tensor device类型对实现进行分发
void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
                            Tensor argmax_y, Tensor argmax_x,
                            int aligned_height, int aligned_width,
                            float spatial_scale, int sampling_ratio,
                            int pool_mode, bool aligned) {
  DISPATCH_DEVICE_IMPL(roi_align_forward_impl, input, rois, output, argmax_y,
                       argmax_x, aligned_height, aligned_width, spatial_scale,
                       sampling_ratio, pool_mode, aligned);
}
```

### v1.3.11

为了灵活地支持更多的后端和硬件，例如 `NVIDIA GPUs`、`AMD GPUs`，我们重构了 `mmcv/ops/csrc` 目录。注意，这次重构不会影响 API 的使用。更多相关信息，请参考 [PR1206](https://github.com/open-mmlab/mmcv/pull/1206)。

原始的目录结构如下所示

```
.
├── common_cuda_helper.hpp
├── ops_cuda_kernel.cuh
├── pytorch_cpp_helper.hpp
├── pytorch_cuda_helper.hpp
├── parrots_cpp_helper.hpp
├── parrots_cuda_helper.hpp
├── parrots_cudawarpfunction.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│       ├── onnxruntime_register.cpp
│       ├── ...
│       └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── ...
│   ├── ops.cpp
│   ├── ops_cuda.cu
│   └── pybind.cpp
└── tensorrt
    ├── trt_cuda_helper.cuh
    ├── trt_plugin_helper.hpp
    ├── trt_plugin.hpp
    ├── trt_serialize.hpp
    ├── ...
    ├── trt_ops.hpp
    └── plugins
        ├── trt_cuda_helper.cu
        ├── trt_plugin.cpp
        ├── ...
        ├── trt_ops.cpp
        └── trt_ops_kernel.cu
```

重构之后，它的结构如下所示

```
.
├── common
│   ├── box_iou_rotated_utils.hpp
│   ├── parrots_cpp_helper.hpp
│   ├── parrots_cuda_helper.hpp
│   ├── pytorch_cpp_helper.hpp
│   ├── pytorch_cuda_helper.hpp
│   └── cuda
│       ├── common_cuda_helper.hpp
│       ├── parrots_cudawarpfunction.cuh
│       ├── ...
│       └── ops_cuda_kernel.cuh
├── onnxruntime
│   ├── onnxruntime_register.h
│   ├── onnxruntime_session_options_config_keys.h
│   ├── ort_mmcv_utils.h
│   ├── ...
│   ├── onnx_ops.h
│   └── cpu
│       ├── onnxruntime_register.cpp
│       ├── ...
│       └── onnx_ops_impl.cpp
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
├── pytorch
│   ├── info.cpp
│   ├── pybind.cpp
│   ├── ...
│   ├── ops.cpp
│   └── cuda
│       ├── ...
│       └── ops_cuda.cu
└── tensorrt
    ├── trt_cuda_helper.cuh
    ├── trt_plugin_helper.hpp
    ├── trt_plugin.hpp
    ├── trt_serialize.hpp
    ├── ...
    ├── trt_ops.hpp
    └── plugins
        ├── trt_cuda_helper.cu
        ├── trt_plugin.cpp
        ├── ...
        ├── trt_ops.cpp
        └── trt_ops_kernel.cu
```