From 2046a394a2657b3cb22efe90fd3640e051406fa0 Mon Sep 17 00:00:00 2001
From: Zachary Streeter <90640993+zstreet87@users.noreply.github.com>
Date: Tue, 13 Sep 2022 02:43:48 -0500
Subject: [PATCH] [Fix] Use ROCm backend within the PyTorch framework (#1918)

* modified code for using ROCm backend within the PyTorch framework
* added hip runtime header
* flake8 linting fix
---
 .../csrc/common/cuda/carafe_cuda_kernel.cuh   | 12 ++++----
 .../ops/csrc/common/cuda/correlation_cuda.cuh |  4 +++
 .../cuda/scatter_points_cuda_kernel.cuh       |  4 +--
 .../utils/spconv/tensorview/tensorview.h      |  2 +-
 mmcv/ops/csrc/parrots/info.cpp                | 15 ++++++--
 .../csrc/pytorch/cuda/bbox_overlaps_cuda.cu   |  5 ++--
 .../pytorch/cuda/fused_spconv_ops_cuda.cu     |  2 +-
 mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu   |  2 +-
 mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu  |  2 +-
 .../csrc/pytorch/cuda/sparse_pool_ops_cuda.cu |  2 +-
 .../csrc/pytorch/cuda/sparse_reordering.cu    |  2 +-
 mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu |  2 +-
 mmcv/ops/csrc/pytorch/info.cpp                | 15 ++++++--
 mmcv/utils/env.py                             | 30 +++++++++++++------
 setup.py                                      |  3 +-
 15 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
index e7fa990fe..7c5f4f3b8 100644
--- a/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/carafe_cuda_kernel.cuh
@@ -8,7 +8,7 @@
 #include "pytorch_cuda_helper.hpp"
 #endif

-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
 #define WARP_SIZE 64
 #else
 #define WARP_SIZE 32
@@ -29,7 +29,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
   int index = w + (h + (c + n * channel_num) * height) * width;
   return index;
 }
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
 /* TODO: move this to a common place */
 template <typename scalar_t>
 __device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -44,7 +44,7 @@ __device__ inline scalar_t max(scalar_t a, scalar_t b) {
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
     val += __shfl_down(val, offset);
 #else
     val += __shfl_down_sync(FULL_MASK, val, offset);
@@ -55,8 +55,8 @@ __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
 template <>
 __device__ __forceinline__ phalf warpReduceSum(phalf val) {
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-#ifdef HIP_DIFF
-    __PHALF(val) += __shfl_down(FULL_MASK, val, offset);
+#ifdef MMCV_WITH_HIP
+    __PHALF(val) += __shfl_down(val, offset);
 #else
     __PHALF(val) += __shfl_down_sync(FULL_MASK,
                                      static_cast<__half>(__PHALF(val)), offset);
@@ -316,7 +316,7 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
       output_val += top_diff[top_id] * bottom_data[bottom_id];
     }
   }
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
   __syncthreads();
 #else
   __syncwarp();
diff --git a/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
index 2f7f11298..703de8232 100644
--- a/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
+++ b/mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
@@ -78,7 +78,11 @@ __global__ void correlation_forward_cuda_kernel(
   }
   // accumulate
   for (int offset = 16; offset > 0; offset /= 2)
+#ifdef MMCV_WITH_HIP
+    prod_sum += __shfl_down(float(prod_sum), offset);
+#else
     prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
+#endif
   if (thread == 0) {
     output[n][ph][pw][h][w] = prod_sum;
   }
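Note (illustrative sketch, not part of the patch): the carafe and correlation hunks above key the warp intrinsics on MMCV_WITH_HIP because AMD wavefronts are 64 lanes wide (hence WARP_SIZE 64) and HIP exposes only the legacy __shfl_down, not the CUDA 9+ __shfl_down_sync that takes a participation mask. A minimal, self-contained warp-sum reduction written against the same switches; the helper and kernel names below are hypothetical and not part of mmcv:

    // Portable warp-level sum reduction, mirroring the MMCV_WITH_HIP switches
    // used in carafe_cuda_kernel.cuh. Builds with either nvcc or hipcc.
    #ifdef MMCV_WITH_HIP
    #define WARP_SIZE 64
    #else
    #define WARP_SIZE 32
    #endif
    #define FULL_MASK 0xffffffff

    template <typename scalar_t>
    __device__ __forceinline__ scalar_t warp_reduce_sum(scalar_t val) {
      // Halve the shuffle stride each step; lane 0 ends up with the warp total.
      for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    #ifdef MMCV_WITH_HIP
        // ROCm: wavefronts execute in lockstep, the unsynchronized shuffle suffices.
        val += __shfl_down(val, offset);
    #else
        // CUDA: the _sync variant needs an explicit mask of participating lanes.
        val += __shfl_down_sync(FULL_MASK, val, offset);
    #endif
      return val;
    }

    __global__ void block_sums(const float *in, float *out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      float v = (i < n) ? in[i] : 0.f;
      v = warp_reduce_sum(v);
      // One atomic update per warp/wavefront, regardless of its width.
      if (threadIdx.x % WARP_SIZE == 0) atomicAdd(out, v);
    }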
diff --git a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
index 7f9c40202..af5b9f67b 100644
--- a/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/scatter_points_cuda_kernel.cuh
@@ -34,7 +34,7 @@ __device__ __forceinline__ static void reduceMax(double *address, double val) {
 }

 // get rid of meaningless warnings when compiling host code
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
 __device__ __forceinline__ static void reduceAdd(float *address, float val) {
   atomicAdd(address, val);
 }
@@ -86,7 +86,7 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
 #endif
 }
 #endif // __CUDA_ARCH__
-#endif // HIP_DIFF
+#endif // MMCV_WITH_HIP

 template
 __global__ void feats_reduce_kernel(
diff --git a/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h b/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h
index cb2f018a9..27745beaa 100644
--- a/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h
+++ b/mmcv/ops/csrc/common/utils/spconv/tensorview/tensorview.h
@@ -27,7 +27,7 @@

 namespace tv {

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIP__)
 #define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
 #define TV_DEVICE_INLINE __forceinline__ __device__
 #define TV_HOST_DEVICE __device__ __host__
diff --git a/mmcv/ops/csrc/parrots/info.cpp b/mmcv/ops/csrc/parrots/info.cpp
index a08d227d4..a4cc41861 100644
--- a/mmcv/ops/csrc/parrots/info.cpp
+++ b/mmcv/ops/csrc/parrots/info.cpp
@@ -4,7 +4,14 @@
 #include "pytorch_cpp_helper.hpp"

 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifdef MMCV_WITH_HIP
+#include <hip/hip_runtime_api.h>
+int get_hiprt_version() {
+  int runtimeVersion;
+  hipRuntimeGetVersion(&runtimeVersion);
+  return runtimeVersion;
+}
+#else
 #include <cuda_runtime_api.h>
 int get_cudart_version() { return CUDART_VERSION; }
 #endif
@@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; }

 std::string get_compiling_cuda_version() {
 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
   std::ostringstream oss;
   // copied from
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
   printCudaStyleVersion(get_cudart_version());
   return oss.str();
 #else
-  return std::string("rocm not available");
+  std::ostringstream oss;
+  oss << get_hiprt_version();
+  return oss.str();
 #endif
 #else
   return std::string("not available");
diff --git a/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
index b3272539b..7dae535cf 100644
--- a/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/bbox_overlaps_cuda.cu
@@ -3,7 +3,7 @@
 #include "pytorch_cuda_helper.hpp"

 // Disable fp16 on ROCm device
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
 #if __CUDA_ARCH__ >= 530
 template <>
 __global__ void bbox_overlaps_cuda_kernel(
@@ -15,8 +15,9 @@ __global__ void bbox_overlaps_cuda_kernel(
       reinterpret_cast<__half*>(ious), num_bbox1, num_bbox2, mode, aligned,
       offset);
 }
+
 #endif // __CUDA_ARCH__ >= 530
-#endif // HIP_DIFF
+#endif // MMCV_WITH_HIP

 void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
                                     Tensor ious, const int mode,
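Note (illustrative sketch, not part of the patch): in the scatter_points hunks above, the MMCV_WITH_HIP branch maps reduceAdd straight onto atomicAdd, while the CUDA branch (whose body is elided context here, see the closing #endif // __CUDA_ARCH__) keeps an architecture-dependent fallback, because NVIDIA GPUs older than sm_60 have no native double-precision atomicAdd. A hedged sketch of that pattern using the standard compare-and-swap emulation; the helper name below is hypothetical:

    // Double-precision atomic accumulation with the same compile-time split
    // as reduceAdd in scatter_points_cuda_kernel.cuh (sketch only).
    __device__ __forceinline__ static void reduce_add_double(double *address,
                                                             double val) {
    #if defined(MMCV_WITH_HIP) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600)
      // ROCm and sm_60+ NVIDIA devices provide atomicAdd(double*, double).
      atomicAdd(address, val);
    #else
      // Older NVIDIA architectures: emulate the atomic add with a CAS loop.
      unsigned long long int *address_as_ull =
          reinterpret_cast<unsigned long long int *>(address);
      unsigned long long int old = *address_as_ull, assumed;
      do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
      } while (assumed != old);
    #endif
    }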
"../spconv_utils.h" #include "pytorch_cuda_helper.hpp" torch::Tensor FusedIndiceConvBatchnormCUDAKernelLauncher( diff --git a/mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu b/mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu index 89a2d3af8..9226673e6 100644 --- a/mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu +++ b/mmcv/ops/csrc/pytorch/cuda/sparse_indice.cu @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "../spconv_utils.h" #include #include #include @@ -23,7 +24,6 @@ #include #include -#include "../spconv_utils.h" #include "pytorch_cuda_helper.hpp" namespace functor { diff --git a/mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu b/mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu index 1addf2e98..746fd0f2f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu +++ b/mmcv/ops/csrc/pytorch/cuda/sparse_maxpool.cu @@ -13,6 +13,7 @@ // limitations under the License. #include +#include "../spconv_utils.h" #include #include #include @@ -23,7 +24,6 @@ #include #include -#include "../spconv_utils.h" #include "pytorch_cuda_helper.hpp" template diff --git a/mmcv/ops/csrc/pytorch/cuda/sparse_pool_ops_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/sparse_pool_ops_cuda.cu index 44ca42e3f..6d4135f09 100644 --- a/mmcv/ops/csrc/pytorch/cuda/sparse_pool_ops_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/sparse_pool_ops_cuda.cu @@ -1,8 +1,8 @@ #include #include +#include "../spconv_utils.h" #include -#include "../spconv_utils.h" #include "pytorch_cuda_helper.hpp" torch::Tensor IndiceMaxpoolForwardCUDAKernelLauncher(torch::Tensor features, diff --git a/mmcv/ops/csrc/pytorch/cuda/sparse_reordering.cu b/mmcv/ops/csrc/pytorch/cuda/sparse_reordering.cu index 2929a7577..d004a1cd9 100644 --- a/mmcv/ops/csrc/pytorch/cuda/sparse_reordering.cu +++ b/mmcv/ops/csrc/pytorch/cuda/sparse_reordering.cu @@ -13,6 +13,7 @@ // limitations under the License. 
#include +#include "../spconv_utils.h" #include #include #include @@ -24,7 +25,6 @@ #include #include -#include "../spconv_utils.h" #include "pytorch_cuda_helper.hpp" namespace functor { diff --git a/mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu index e1a0e1a73..985e98694 100644 --- a/mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu @@ -1,9 +1,9 @@ #include #include +#include "../spconv_utils.h" #include #include -#include "../spconv_utils.h" #include "pytorch_cuda_helper.hpp" template diff --git a/mmcv/ops/csrc/pytorch/info.cpp b/mmcv/ops/csrc/pytorch/info.cpp index a08d227d4..a4cc41861 100644 --- a/mmcv/ops/csrc/pytorch/info.cpp +++ b/mmcv/ops/csrc/pytorch/info.cpp @@ -4,7 +4,14 @@ #include "pytorch_cpp_helper.hpp" #ifdef MMCV_WITH_CUDA -#ifndef HIP_DIFF +#ifdef MMCV_WITH_HIP +#include +int get_hiprt_version() { + int runtimeVersion; + hipRuntimeGetVersion(&runtimeVersion); + return runtimeVersion; +} +#else #include int get_cudart_version() { return CUDART_VERSION; } #endif @@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; } std::string get_compiling_cuda_version() { #ifdef MMCV_WITH_CUDA -#ifndef HIP_DIFF +#ifndef MMCV_WITH_HIP std::ostringstream oss; // copied from // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 @@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() { printCudaStyleVersion(get_cudart_version()); return oss.str(); #else - return std::string("rocm not available"); + std::ostringstream oss; + oss << get_hiprt_version(); + return oss.str(); #endif #else return std::string("not available"); diff --git a/mmcv/utils/env.py b/mmcv/utils/env.py index 511332506..83f42b600 100644 --- a/mmcv/utils/env.py +++ b/mmcv/utils/env.py @@ -55,15 +55,27 @@ def collect_env(): env_info['CUDA_HOME'] = CUDA_HOME if CUDA_HOME is not None and osp.isdir(CUDA_HOME): - try: - nvcc = osp.join(CUDA_HOME, 'bin/nvcc') - nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True) - nvcc = nvcc.decode('utf-8').strip() - release = nvcc.rfind('Cuda compilation tools') - build = nvcc.rfind('Build ') - nvcc = nvcc[release:build].strip() - except subprocess.SubprocessError: - nvcc = 'Not Available' + if CUDA_HOME == '/opt/rocm': + try: + nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc') + nvcc = subprocess.check_output( + f'"{nvcc}" --version', shell=True) + nvcc = nvcc.decode('utf-8').strip() + release = nvcc.rfind('HIP version:') + build = nvcc.rfind('') + nvcc = nvcc[release:build].strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' + else: + try: + nvcc = osp.join(CUDA_HOME, 'bin/nvcc') + nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True) + nvcc = nvcc.decode('utf-8').strip() + release = nvcc.rfind('Cuda compilation tools') + build = nvcc.rfind('Build ') + nvcc = nvcc[release:build].strip() + except subprocess.SubprocessError: + nvcc = 'Not Available' env_info['NVCC'] = nvcc try: diff --git a/setup.py b/setup.py index 274c13de3..e9ccc3717 100644 --- a/setup.py +++ b/setup.py @@ -280,7 +280,7 @@ def get_extensions(): if is_rocm_pytorch or torch.cuda.is_available() or os.getenv( 'FORCE_CUDA', '0') == '1': if is_rocm_pytorch: - define_macros += [('HIP_DIFF', None)] + define_macros += [('MMCV_WITH_HIP', None)] define_macros += [('MMCV_WITH_CUDA', None)] cuda_args = os.getenv('MMCV_CUDA_ARGS') extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] @@ -289,6 +289,7 @@ def get_extensions(): 
diff --git a/setup.py b/setup.py
index 274c13de3..e9ccc3717 100644
--- a/setup.py
+++ b/setup.py
@@ -280,7 +280,7 @@ def get_extensions():
         if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
                 'FORCE_CUDA', '0') == '1':
             if is_rocm_pytorch:
-                define_macros += [('HIP_DIFF', None)]
+                define_macros += [('MMCV_WITH_HIP', None)]
             define_macros += [('MMCV_WITH_CUDA', None)]
             cuda_args = os.getenv('MMCV_CUDA_ARGS')
             extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
@@ -289,6 +289,7 @@ def get_extensions():
                 glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
                 glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
             extension = CUDAExtension
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
         elif (hasattr(torch, 'is_mlu_available') and
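Note (illustrative sketch, not part of the patch): with the setup.py change above, a ROCm build defines MMCV_WITH_HIP in addition to MMCV_WITH_CUDA, which is why the C++ sources nest #ifdef MMCV_WITH_HIP inside #ifdef MMCV_WITH_CUDA rather than treating the two as mutually exclusive. A minimal sketch of how the two macros combine at compile time; the function below is hypothetical:

    // Compile-time backend selection driven by the macros defined in setup.py.
    #include <cstdio>

    const char *gpu_backend() {
    #ifdef MMCV_WITH_CUDA
    #ifdef MMCV_WITH_HIP
      return "ROCm/HIP";  // is_rocm_pytorch: both macros are defined
    #else
      return "CUDA";      // plain CUDA build: only MMCV_WITH_CUDA
    #endif
    #else
      return "CPU only";  // neither macro: CPU-only build
    #endif
    }

    int main() {
      std::printf("compiled for: %s\n", gpu_backend());
      return 0;
    }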