[Fix] Use ROCm backend within the PyTorch framework (#1918)

* modified code for using the ROCm backend within the PyTorch framework

* added hip runtime header

* flake8 linting fix
Zachary Streeter 2022-09-13 02:43:48 -05:00 committed by GitHub
parent 2fb2b91aed
commit 2046a394a2
15 changed files with 69 additions and 33 deletions

View File

@@ -8,7 +8,7 @@
#include "pytorch_cuda_helper.hpp"
#endif
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
@@ -29,7 +29,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
int index = w + (h + (c + n * channel_num) * height) * width;
return index;
}
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
/* TODO: move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -44,7 +44,7 @@ __device__ inline scalar_t max(scalar_t a, scalar_t b) {
template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
val += __shfl_down(val, offset);
#else
val += __shfl_down_sync(FULL_MASK, val, offset);
@@ -55,8 +55,8 @@ __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
template <>
__device__ __forceinline__ phalf warpReduceSum(phalf val) {
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef HIP_DIFF
__PHALF(val) += __shfl_down(FULL_MASK, val, offset);
#ifdef MMCV_WITH_HIP
__PHALF(val) += __shfl_down(val, offset);
#else
__PHALF(val) +=
__shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
@@ -316,7 +316,7 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
output_val += top_diff[top_id] * bottom_data[bottom_id];
}
}
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
__syncthreads();
#else
__syncwarp();

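The hunks above all apply the same CUDA/HIP portability pattern, so here is a consolidated, minimal sketch of it (assuming, as elsewhere in this header, that FULL_MASK is defined as 0xffffffff): ROCm wavefronts are 64 lanes wide and expose the unsynchronized __shfl_down, while CUDA warps are 32 lanes wide and use the mask-taking __shfl_down_sync.

```cpp
// Minimal sketch of the pattern the hunks above introduce; FULL_MASK is
// assumed to be 0xffffffff as elsewhere in this header.
#ifdef MMCV_WITH_HIP
#define WARP_SIZE 64  // ROCm wavefront width
#else
#define WARP_SIZE 32  // CUDA warp width
#endif

template <typename scalar_t>
__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
  // Tree reduction across one warp/wavefront; lane 0 ends up with the sum.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
    val += __shfl_down(val, offset);  // HIP: no explicit lane mask
#else
    val += __shfl_down_sync(FULL_MASK, val, offset);  // CUDA: explicit mask
#endif
  return val;
}
```

For the same reason, the final hunk falls back to __syncthreads() on HIP, where __syncwarp() is not available.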
View File

@@ -78,7 +78,11 @@ __global__ void correlation_forward_cuda_kernel(
}
// accumulate
for (int offset = 16; offset > 0; offset /= 2)
#ifdef MMCV_WITH_HIP
prod_sum += __shfl_down(float(prod_sum), offset);
#else
prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
#endif
if (thread == 0) {
output[n][ph][pw][h][w] = prod_sum;
}

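The correlation kernel repeats the same #ifdef at its accumulation step. Purely as an illustration (this helper is not part of the PR or of mmcv), the backend-specific shuffle could also be hidden behind one wrapper so call sites stay branch-free:

```cpp
// Hypothetical helper, not taken from the PR: one name for both backends.
__device__ __forceinline__ float shfl_down_full(float val, int offset) {
#ifdef MMCV_WITH_HIP
  return __shfl_down(val, offset);
#else
  return __shfl_down_sync(0xffffffff, val, offset);
#endif
}

// The accumulate loop above would then read:
//   for (int offset = 16; offset > 0; offset /= 2)
//     prod_sum += shfl_down_full(float(prod_sum), offset);
```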
View File

@@ -34,7 +34,7 @@ __device__ __forceinline__ static void reduceMax(double *address, double val) {
}
// get rid of meaningless warnings when compiling host code
#ifdef HIP_DIFF
#ifdef MMCV_WITH_HIP
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
atomicAdd(address, val);
}
@@ -86,7 +86,7 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
#endif
}
#endif // __CUDA_ARCH__
#endif // HIP_DIFF
#endif // MMCV_WITH_HIP
template <typename T>
__global__ void feats_reduce_kernel(

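Context for the hunk above: under MMCV_WITH_HIP the reduceAdd wrappers forward straight to atomicAdd, while the CUDA branch of this header (elided in the hunk, guarded by __CUDA_ARCH__) needs a compare-and-swap fallback for double on architectures older than sm_60 and is what triggers the host-compilation warnings the comment mentions. A condensed sketch of the HIP side, assuming the double overload mirrors the float one shown in the hunk:

```cpp
// Condensed from the hunk above: with HIP no architecture guard is needed.
#ifdef MMCV_WITH_HIP
__device__ __forceinline__ static void reduceAdd(float *address, float val) {
  atomicAdd(address, val);
}
__device__ __forceinline__ static void reduceAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif  // MMCV_WITH_HIP
```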
View File

@@ -27,7 +27,7 @@
namespace tv {
#ifdef __NVCC__
#if defined(__NVCC__) || defined(__HIP__)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_DEVICE_INLINE __forceinline__ __device__
#define TV_HOST_DEVICE __device__ __host__

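This one-line change matters because hipcc targeting AMD GPUs defines __HIP__ but not __NVCC__; without it the TV_* macros would lose their __device__ __host__ qualifiers under ROCm. A self-contained sketch of the idea follows; the #else branch is an assumption about the rest of the header, which the hunk does not show.

```cpp
// Helpers tagged with these macros compile as host+device functions under
// nvcc or hipcc, and as ordinary inline host functions otherwise.
#if defined(__NVCC__) || defined(__HIP__)
#define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
#define TV_HOST_DEVICE __device__ __host__
#else
#define TV_HOST_DEVICE_INLINE inline
#define TV_HOST_DEVICE
#endif

// Example: an index helper usable from both launch code and kernels.
TV_HOST_DEVICE_INLINE int div_up(int a, int b) { return (a + b - 1) / b; }
```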
View File

@@ -4,7 +4,14 @@
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
int get_hiprt_version() {
int runtimeVersion;
hipRuntimeGetVersion(&runtimeVersion);
return runtimeVersion;
}
#else
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
@@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; }
std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not available");
std::ostringstream oss;
oss << get_hiprt_version();
return oss.str();
#endif
#else
return std::string("not available");

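With this change, get_compiling_cuda_version() reports the HIP runtime version on ROCm builds instead of the old "rocm not available" string. Below is a simplified, standalone variant of that logic (the wrapper name is illustrative only; the real CUDA branch additionally formats the number with printCudaStyleVersion):

```cpp
#include <sstream>
#include <string>

#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
static int get_hiprt_version() {
  int runtimeVersion;
  hipRuntimeGetVersion(&runtimeVersion);  // encoded HIP runtime version
  return runtimeVersion;
}
#else
#include <cuda_runtime_api.h>
static int get_cudart_version() { return CUDART_VERSION; }  // e.g. 11030 == CUDA 11.3
#endif

// Illustrative wrapper mirroring the structure of get_compiling_cuda_version().
std::string compiled_runtime_version() {
  std::ostringstream oss;
#ifdef MMCV_WITH_HIP
  oss << get_hiprt_version();
#else
  oss << get_cudart_version();
#endif
  return oss.str();
}
```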
View File

@@ -3,7 +3,7 @@
#include "pytorch_cuda_helper.hpp"
// Disable fp16 on ROCm device
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
#if __CUDA_ARCH__ >= 530
template <>
__global__ void bbox_overlaps_cuda_kernel<at::Half>(
@@ -15,8 +15,9 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
reinterpret_cast<__half*>(ious), num_bbox1,
num_bbox2, mode, aligned, offset);
}
#endif // __CUDA_ARCH__ >= 530
#endif // HIP_DIFF
#endif // MMCV_WITH_HIP
void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
Tensor ious, const int mode,

View File

@@ -1,9 +1,9 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/reordering.h>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
torch::Tensor FusedIndiceConvBatchnormCUDAKernelLauncher(

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -23,7 +24,6 @@
#include <spconv/indice.cuh>
#include <type_traits>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
namespace functor {

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/maxpool.h>
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -23,7 +24,6 @@
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.cuh>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
template <typename scalar_t, typename Index, int NumTLP, int NumILP>

View File

@@ -1,8 +1,8 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/maxpool.h>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
torch::Tensor IndiceMaxpoolForwardCUDAKernelLauncher(torch::Tensor features,

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include <ATen/ATen.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/mp_helper.h>
#include <utils/spconv/spconv/reordering.h>
#include <utils/spconv/tensorview/helper_launch.h>
@@ -24,7 +25,6 @@
#include <type_traits>
#include <utils/spconv/tensorview/helper_kernel.cuh>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
namespace functor {

View File

@@ -1,9 +1,9 @@
#include <cuda_runtime_api.h>
#include <torch/script.h>
#include "../spconv_utils.h"
#include <utils/spconv/spconv/indice.h>
#include <utils/spconv/spconv/reordering.h>
#include "../spconv_utils.h"
#include "pytorch_cuda_helper.hpp"
template <unsigned NDim>

View File

@@ -4,7 +4,14 @@
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifdef MMCV_WITH_HIP
#include <hip/hip_runtime_api.h>
int get_hiprt_version() {
int runtimeVersion;
hipRuntimeGetVersion(&runtimeVersion);
return runtimeVersion;
}
#else
#include <cuda_runtime_api.h>
int get_cudart_version() { return CUDART_VERSION; }
#endif
@@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; }
std::string get_compiling_cuda_version() {
#ifdef MMCV_WITH_CUDA
#ifndef HIP_DIFF
#ifndef MMCV_WITH_HIP
std::ostringstream oss;
// copied from
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
printCudaStyleVersion(get_cudart_version());
return oss.str();
#else
return std::string("rocm not available");
std::ostringstream oss;
oss << get_hiprt_version();
return oss.str();
#endif
#else
return std::string("not available");

View File

@@ -55,15 +55,27 @@ def collect_env():
        env_info['CUDA_HOME'] = CUDA_HOME
        if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
            try:
                nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
                nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
                nvcc = nvcc.decode('utf-8').strip()
                release = nvcc.rfind('Cuda compilation tools')
                build = nvcc.rfind('Build ')
                nvcc = nvcc[release:build].strip()
            except subprocess.SubprocessError:
                nvcc = 'Not Available'
            if CUDA_HOME == '/opt/rocm':
                try:
                    nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc')
                    nvcc = subprocess.check_output(
                        f'"{nvcc}" --version', shell=True)
                    nvcc = nvcc.decode('utf-8').strip()
                    release = nvcc.rfind('HIP version:')
                    build = nvcc.rfind('')
                    nvcc = nvcc[release:build].strip()
                except subprocess.SubprocessError:
                    nvcc = 'Not Available'
            else:
                try:
                    nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
                    nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
                    nvcc = nvcc.decode('utf-8').strip()
                    release = nvcc.rfind('Cuda compilation tools')
                    build = nvcc.rfind('Build ')
                    nvcc = nvcc[release:build].strip()
                except subprocess.SubprocessError:
                    nvcc = 'Not Available'
            env_info['NVCC'] = nvcc
    try:

View File

@@ -280,7 +280,7 @@ def get_extensions():
        if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
                'FORCE_CUDA', '0') == '1':
            if is_rocm_pytorch:
                define_macros += [('HIP_DIFF', None)]
                define_macros += [('MMCV_WITH_HIP', None)]
            define_macros += [('MMCV_WITH_CUDA', None)]
            cuda_args = os.getenv('MMCV_CUDA_ARGS')
            extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
@@ -289,6 +289,7 @@ def get_extensions():
                glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
                glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
            extension = CUDAExtension
            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
        elif (hasattr(torch, 'is_mlu_available') and