mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Use ROCm backend within the PyTorch framework (#1918)

* modified code for using ROCm backend within the PyTorch framework
* added hip runtime header
* flake8 linting fix
parent 2fb2b91aed
commit 2046a394a2
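The hunks below rename the `HIP_DIFF` macro to `MMCV_WITH_HIP` and route a few more code paths through it. As a rough orientation, this is the compile-time dispatch pattern the sources rely on, sketched with a placeholder kernel; the kernel itself is not part of mmcv, and the macros are the ones defined by setup.py at the end of this diff.

    // Minimal sketch of the guard pattern used throughout this diff.
    // MMCV_WITH_CUDA is defined for both CUDA and ROCm builds; MMCV_WITH_HIP is
    // defined in addition when PyTorch was built against ROCm (see setup.py).
    // The kernel below is a placeholder for illustration only.
    #ifdef MMCV_WITH_CUDA
    #ifdef MMCV_WITH_HIP
    #include <hip/hip_runtime.h>
    #else
    #include <cuda_runtime.h>
    #endif

    __global__ void add_kernel(const float* a, const float* b, float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = a[i] + b[i];  // identical source for CUDA and HIP
    }
    #endif  // MMCV_WITH_CUDA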
@@ -8,7 +8,7 @@
 #include "pytorch_cuda_helper.hpp"
 #endif

-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
 #define WARP_SIZE 64
 #else
 #define WARP_SIZE 32
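The `WARP_SIZE` switch reflects that AMD GCN/CDNA GPUs execute 64-lane wavefronts while NVIDIA GPUs use 32-lane warps. A small host-side sanity check, sketched here with a hypothetical helper that is not part of mmcv, could compare the compile-time constant against what the device reports:

    // Hypothetical helper (not in mmcv): verify WARP_SIZE against the device.
    #include <cassert>
    #ifdef MMCV_WITH_HIP
    #include <hip/hip_runtime.h>
    #define WARP_SIZE 64
    #else
    #include <cuda_runtime.h>
    #define WARP_SIZE 32
    #endif

    void check_warp_size(int device) {
      int warp = 0;
    #ifdef MMCV_WITH_HIP
      hipDeviceGetAttribute(&warp, hipDeviceAttributeWarpSize, device);
    #else
      cudaDeviceGetAttribute(&warp, cudaDevAttrWarpSize, device);
    #endif
      assert(warp == WARP_SIZE);  // 64 lanes on GCN/CDNA, 32 on NVIDIA
    }

Note that consumer RDNA parts report a 32-wide warp under HIP, so the fixed 64 assumes the GCN/CDNA server GPUs that ROCm builds of mmcv typically target.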
@@ -29,7 +29,7 @@ __device__ inline int Loc2Index(const int n, const int c, const int h,
   int index = w + (h + (c + n * channel_num) * height) * width;
   return index;
 }
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
 /* TODO: move this to a common place */
 template <typename scalar_t>
 __device__ inline scalar_t min(scalar_t a, scalar_t b) {
@@ -44,7 +44,7 @@ __device__ inline scalar_t max(scalar_t a, scalar_t b) {
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
     val += __shfl_down(val, offset);
 #else
     val += __shfl_down_sync(FULL_MASK, val, offset);
@@ -55,8 +55,8 @@ __device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) {
 template <>
 __device__ __forceinline__ phalf warpReduceSum(phalf val) {
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
-#ifdef HIP_DIFF
-    __PHALF(val) += __shfl_down(FULL_MASK, val, offset);
+#ifdef MMCV_WITH_HIP
+    __PHALF(val) += __shfl_down(val, offset);
 #else
     __PHALF(val) +=
         __shfl_down_sync(FULL_MASK, static_cast<__half>(__PHALF(val)), offset);
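Besides the rename, the half-precision branch drops `FULL_MASK` from the HIP call: HIP's `__shfl_down(var, lane_delta, width)` takes the shuffled value first and has no mask parameter, so the old `__shfl_down(FULL_MASK, val, offset)` shifted the mask rather than the value. A self-contained sketch of the corrected call, using a hypothetical demo kernel that is not part of mmcv:

    // Hypothetical HIP-only demo kernel showing the corrected shuffle:
    // the value goes first, and there is no mask argument.
    #ifdef MMCV_WITH_HIP
    #include <hip/hip_runtime.h>

    __global__ void wave_sum_demo(const float* in, float* out) {
      float val = in[threadIdx.x];
      for (int offset = warpSize / 2; offset > 0; offset /= 2)
        val += __shfl_down(val, offset);  // shift values down the wavefront and accumulate
      if (threadIdx.x == 0) *out = val;   // lane 0 ends up with the wavefront sum
    }
    #endif  // MMCV_WITH_HIP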
@@ -316,7 +316,7 @@ __global__ void CARAFEBackward_Mask(const int num_kernels,
       output_val += top_diff[top_id] * bottom_data[bottom_id];
     }
   }
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
   __syncthreads();
 #else
   __syncwarp();
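HIP does not provide `__syncwarp()`, so the ROCm path falls back to a block-wide `__syncthreads()`, which is a stronger barrier. A hedged sketch of a small wrapper this could be factored into (the name is hypothetical, not part of mmcv); on the HIP path every thread of the block must reach the call:

    // Hypothetical wrapper: warp-level barrier on CUDA, block-level
    // fallback on HIP, which lacks __syncwarp().
    __device__ __forceinline__ void sync_lanes() {
    #ifdef MMCV_WITH_HIP
      __syncthreads();  // must be reached by all threads in the block
    #else
      __syncwarp();     // synchronizes the lanes of the calling warp
    #endif
    }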
@@ -78,7 +78,11 @@ __global__ void correlation_forward_cuda_kernel(
     }
     // accumulate
     for (int offset = 16; offset > 0; offset /= 2)
+#ifdef MMCV_WITH_HIP
+      prod_sum += __shfl_down(float(prod_sum), offset);
+#else
       prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
+#endif
     if (thread == 0) {
       output[n][ph][pw][h][w] = prod_sum;
     }
@@ -34,7 +34,7 @@ __device__ __forceinline__ static void reduceMax(double *address, double val) {
 }

 // get rid of meaningless warnings when compiling host code
-#ifdef HIP_DIFF
+#ifdef MMCV_WITH_HIP
 __device__ __forceinline__ static void reduceAdd(float *address, float val) {
   atomicAdd(address, val);
 }
@@ -86,7 +86,7 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
 #endif
 }
 #endif  // __CUDA_ARCH__
-#endif  // HIP_DIFF
+#endif  // MMCV_WITH_HIP

 template <typename T>
 __global__ void feats_reduce_kernel(
@@ -27,7 +27,7 @@

 namespace tv {

-#ifdef __NVCC__
+#if defined(__NVCC__) || defined(__HIP__)
 #define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
 #define TV_DEVICE_INLINE __forceinline__ __device__
 #define TV_HOST_DEVICE __device__ __host__
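hipcc defines `__HIP__` rather than `__NVCC__`, so keying the tensorview qualifier macros off `__NVCC__` alone left them without device qualifiers on ROCm. A short sketch of how such a qualifier macro is typically consumed; the `div_up` function is illustrative, and the host-only fallback branch is an assumption, not copied from tensorview.h:

    // Illustrative use of the qualifier macro; the host-only fallback is assumed.
    #if defined(__NVCC__) || defined(__HIP__)
    #define TV_HOST_DEVICE_INLINE __forceinline__ __device__ __host__
    #else
    #define TV_HOST_DEVICE_INLINE inline
    #endif

    // Callable from host code and from CUDA/HIP kernels alike.
    TV_HOST_DEVICE_INLINE int div_up(int a, int b) { return (a + b - 1) / b; }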
@@ -4,7 +4,14 @@
 #include "pytorch_cpp_helper.hpp"

 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifdef MMCV_WITH_HIP
+#include <hip/hip_runtime_api.h>
+int get_hiprt_version() {
+  int runtimeVersion;
+  hipRuntimeGetVersion(&runtimeVersion);
+  return runtimeVersion;
+}
+#else
 #include <cuda_runtime_api.h>
 int get_cudart_version() { return CUDART_VERSION; }
 #endif
@@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; }

 std::string get_compiling_cuda_version() {
 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
   std::ostringstream oss;
   // copied from
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
   printCudaStyleVersion(get_cudart_version());
   return oss.str();
 #else
-  return std::string("rocm not available");
+  std::ostringstream oss;
+  oss << get_hiprt_version();
+  return oss.str();
 #endif
 #else
   return std::string("not available");
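With ROCm there is no `CUDART_VERSION` to report at compile time, so the new path queries `hipRuntimeGetVersion` at run time and streams the integer into the returned string. A standalone sketch of the two paths; the `main` driver is illustrative, while the version helpers mirror the ones added above:

    // Standalone sketch: print the toolkit version the extension was built with.
    #include <iostream>
    #include <sstream>

    #ifdef MMCV_WITH_HIP
    #include <hip/hip_runtime_api.h>
    int get_hiprt_version() {
      int runtimeVersion;
      hipRuntimeGetVersion(&runtimeVersion);  // encoded HIP runtime version
      return runtimeVersion;
    }
    #else
    #include <cuda_runtime_api.h>
    // CUDART_VERSION encodes 1000 * major + 10 * minor, e.g. 11030 for CUDA 11.3.
    int get_cudart_version() { return CUDART_VERSION; }
    #endif

    int main() {
      std::ostringstream oss;
    #ifdef MMCV_WITH_HIP
      oss << get_hiprt_version();
    #else
      oss << get_cudart_version();
    #endif
      std::cout << "compiling toolkit version: " << oss.str() << std::endl;
      return 0;
    }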
@@ -3,7 +3,7 @@
 #include "pytorch_cuda_helper.hpp"

 // Disable fp16 on ROCm device
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
 #if __CUDA_ARCH__ >= 530
 template <>
 __global__ void bbox_overlaps_cuda_kernel<at::Half>(
@@ -15,8 +15,9 @@ __global__ void bbox_overlaps_cuda_kernel<at::Half>(
       reinterpret_cast<__half*>(ious), num_bbox1,
       num_bbox2, mode, aligned, offset);
 }
+
 #endif  // __CUDA_ARCH__ >= 530
-#endif  // HIP_DIFF
+#endif  // MMCV_WITH_HIP

 void BBoxOverlapsCUDAKernelLauncher(const Tensor bboxes1, const Tensor bboxes2,
                                     Tensor ious, const int mode,
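The `at::Half` specialization stays compiled out on ROCm and on devices below compute capability 5.3, where native fp16 arithmetic is unavailable. A minimal sketch of the same double guard around a placeholder fp16 kernel that is not from mmcv:

    // Placeholder fp16 kernel guarded the same way as the specialization above.
    #ifndef MMCV_WITH_HIP
    #if __CUDA_ARCH__ >= 530
    #include <cuda_fp16.h>

    __global__ void scale_half(__half* data, __half factor, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) data[i] = __hmul(data[i], factor);  // native half multiply
    }
    #endif  // __CUDA_ARCH__ >= 530
    #endif  // MMCV_WITH_HIP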
@@ -1,9 +1,9 @@
 #include <cuda_runtime_api.h>
 #include <torch/script.h>
-#include "../spconv_utils.h"
 #include <utils/spconv/spconv/indice.h>
 #include <utils/spconv/spconv/reordering.h>

+#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 torch::Tensor FusedIndiceConvBatchnormCUDAKernelLauncher(
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include "../spconv_utils.h"
|
||||
#include <utils/spconv/spconv/indice.h>
|
||||
#include <utils/spconv/spconv/mp_helper.h>
|
||||
#include <utils/spconv/tensorview/helper_launch.h>
|
||||
|
@@ -23,7 +24,6 @@
 #include <spconv/indice.cuh>
 #include <type_traits>

-#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 namespace functor {
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include "../spconv_utils.h"
|
||||
#include <utils/spconv/spconv/maxpool.h>
|
||||
#include <utils/spconv/spconv/mp_helper.h>
|
||||
#include <utils/spconv/tensorview/helper_launch.h>
|
||||
|
@@ -23,7 +24,6 @@
 #include <type_traits>
 #include <utils/spconv/tensorview/helper_kernel.cuh>

-#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 template <typename scalar_t, typename Index, int NumTLP, int NumILP>
@@ -1,8 +1,8 @@
 #include <cuda_runtime_api.h>
 #include <torch/script.h>
-#include "../spconv_utils.h"
 #include <utils/spconv/spconv/maxpool.h>

+#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 torch::Tensor IndiceMaxpoolForwardCUDAKernelLauncher(torch::Tensor features,
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include <ATen/ATen.h>
+#include "../spconv_utils.h"
 #include <utils/spconv/spconv/mp_helper.h>
 #include <utils/spconv/spconv/reordering.h>
 #include <utils/spconv/tensorview/helper_launch.h>
@@ -24,7 +25,6 @@
 #include <type_traits>
 #include <utils/spconv/tensorview/helper_kernel.cuh>

-#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 namespace functor {
@@ -1,9 +1,9 @@
 #include <cuda_runtime_api.h>
 #include <torch/script.h>
-#include "../spconv_utils.h"
 #include <utils/spconv/spconv/indice.h>
 #include <utils/spconv/spconv/reordering.h>

+#include "../spconv_utils.h"
 #include "pytorch_cuda_helper.hpp"

 template <unsigned NDim>
@@ -4,7 +4,14 @@
 #include "pytorch_cpp_helper.hpp"

 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifdef MMCV_WITH_HIP
+#include <hip/hip_runtime_api.h>
+int get_hiprt_version() {
+  int runtimeVersion;
+  hipRuntimeGetVersion(&runtimeVersion);
+  return runtimeVersion;
+}
+#else
 #include <cuda_runtime_api.h>
 int get_cudart_version() { return CUDART_VERSION; }
 #endif
@@ -12,7 +19,7 @@ int get_cudart_version() { return CUDART_VERSION; }

 std::string get_compiling_cuda_version() {
 #ifdef MMCV_WITH_CUDA
-#ifndef HIP_DIFF
+#ifndef MMCV_WITH_HIP
   std::ostringstream oss;
   // copied from
   // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
@@ -25,7 +32,9 @@ std::string get_compiling_cuda_version() {
   printCudaStyleVersion(get_cudart_version());
   return oss.str();
 #else
-  return std::string("rocm not available");
+  std::ostringstream oss;
+  oss << get_hiprt_version();
+  return oss.str();
 #endif
 #else
   return std::string("not available");
@@ -55,15 +55,27 @@ def collect_env():
     env_info['CUDA_HOME'] = CUDA_HOME

     if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
-        try:
-            nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
-            nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
-            nvcc = nvcc.decode('utf-8').strip()
-            release = nvcc.rfind('Cuda compilation tools')
-            build = nvcc.rfind('Build ')
-            nvcc = nvcc[release:build].strip()
-        except subprocess.SubprocessError:
-            nvcc = 'Not Available'
+        if CUDA_HOME == '/opt/rocm':
+            try:
+                nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc')
+                nvcc = subprocess.check_output(
+                    f'"{nvcc}" --version', shell=True)
+                nvcc = nvcc.decode('utf-8').strip()
+                release = nvcc.rfind('HIP version:')
+                build = nvcc.rfind('')
+                nvcc = nvcc[release:build].strip()
+            except subprocess.SubprocessError:
+                nvcc = 'Not Available'
+        else:
+            try:
+                nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+                nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
+                nvcc = nvcc.decode('utf-8').strip()
+                release = nvcc.rfind('Cuda compilation tools')
+                build = nvcc.rfind('Build ')
+                nvcc = nvcc[release:build].strip()
+            except subprocess.SubprocessError:
+                nvcc = 'Not Available'
         env_info['NVCC'] = nvcc

     try:
setup.py
@@ -280,7 +280,7 @@ def get_extensions():
         if is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
                 'FORCE_CUDA', '0') == '1':
             if is_rocm_pytorch:
-                define_macros += [('HIP_DIFF', None)]
+                define_macros += [('MMCV_WITH_HIP', None)]
             define_macros += [('MMCV_WITH_CUDA', None)]
             cuda_args = os.getenv('MMCV_CUDA_ARGS')
             extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
@@ -289,6 +289,7 @@ def get_extensions():
                 glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \
                 glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cpp')
             extension = CUDAExtension
+            include_dirs.append(os.path.abspath('./mmcv/ops/csrc/pytorch'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
             include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/cuda'))
         elif (hasattr(torch, 'is_mlu_available') and