remove cuda args ()

pull/373/head
zhuyuanhao 2020-06-29 21:14:09 +08:00 committed by GitHub
parent d9549fba04
commit e43fe0e243
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 15 deletions

View File

@ -153,22 +153,17 @@ def get_extensions():
ext_name = 'mmcv._ext'
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension, Extension
cuda_args = [
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = [('MMCV_USE_PARROTS', None)]
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
include_path = os.path.abspath('./mmcv/ops/csrc')
cuda_args = os.getenv('MMCV_CUDA_ARGS')
ext_ops = Extension(
name=ext_name,
sources=op_files,
include_dirs=[include_path],
define_macros=define_macros,
extra_compile_args={
'nvcc': cuda_args,
'nvcc': [cuda_args] if cuda_args else [],
'cxx': [],
},
cuda=True)
@ -178,20 +173,14 @@ def get_extensions():
CUDAExtension, CppExtension)
# prevent ninja from using too many resources
os.environ.setdefault('MAX_JOBS', '4')
cuda_args = [
'-gencode=arch=compute_52,code=sm_52',
'-gencode=arch=compute_60,code=sm_60',
'-gencode=arch=compute_61,code=sm_61',
'-gencode=arch=compute_70,code=sm_70',
'-gencode=arch=compute_70,code=compute_70'
]
define_macros = []
extra_compile_args = {'cxx': []}
if (torch.cuda.is_available()
or os.getenv('FORCE_CUDA', '0') == '1'):
define_macros += [('MMCV_WITH_CUDA', None)]
extra_compile_args['nvcc'] = cuda_args
cuda_args = os.getenv('MMCV_CUDA_ARGS')
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
extension = CUDAExtension
else: