mirror of https://github.com/open-mmlab/mmcv.git
remove cuda args (#372)
parent
d9549fba04
commit
e43fe0e243
19
setup.py
19
setup.py
|
@ -153,22 +153,17 @@ def get_extensions():
|
||||||
ext_name = 'mmcv._ext'
|
ext_name = 'mmcv._ext'
|
||||||
if torch.__version__ == 'parrots':
|
if torch.__version__ == 'parrots':
|
||||||
from parrots.utils.build_extension import BuildExtension, Extension
|
from parrots.utils.build_extension import BuildExtension, Extension
|
||||||
cuda_args = [
|
|
||||||
'-gencode=arch=compute_60,code=sm_60',
|
|
||||||
'-gencode=arch=compute_61,code=sm_61',
|
|
||||||
'-gencode=arch=compute_70,code=sm_70',
|
|
||||||
'-gencode=arch=compute_70,code=compute_70'
|
|
||||||
]
|
|
||||||
define_macros = [('MMCV_USE_PARROTS', None)]
|
define_macros = [('MMCV_USE_PARROTS', None)]
|
||||||
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
|
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
|
||||||
include_path = os.path.abspath('./mmcv/ops/csrc')
|
include_path = os.path.abspath('./mmcv/ops/csrc')
|
||||||
|
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||||
ext_ops = Extension(
|
ext_ops = Extension(
|
||||||
name=ext_name,
|
name=ext_name,
|
||||||
sources=op_files,
|
sources=op_files,
|
||||||
include_dirs=[include_path],
|
include_dirs=[include_path],
|
||||||
define_macros=define_macros,
|
define_macros=define_macros,
|
||||||
extra_compile_args={
|
extra_compile_args={
|
||||||
'nvcc': cuda_args,
|
'nvcc': [cuda_args] if cuda_args else [],
|
||||||
'cxx': [],
|
'cxx': [],
|
||||||
},
|
},
|
||||||
cuda=True)
|
cuda=True)
|
||||||
|
@ -178,20 +173,14 @@ def get_extensions():
|
||||||
CUDAExtension, CppExtension)
|
CUDAExtension, CppExtension)
|
||||||
# prevent ninja from using too many resources
|
# prevent ninja from using too many resources
|
||||||
os.environ.setdefault('MAX_JOBS', '4')
|
os.environ.setdefault('MAX_JOBS', '4')
|
||||||
cuda_args = [
|
|
||||||
'-gencode=arch=compute_52,code=sm_52',
|
|
||||||
'-gencode=arch=compute_60,code=sm_60',
|
|
||||||
'-gencode=arch=compute_61,code=sm_61',
|
|
||||||
'-gencode=arch=compute_70,code=sm_70',
|
|
||||||
'-gencode=arch=compute_70,code=compute_70'
|
|
||||||
]
|
|
||||||
define_macros = []
|
define_macros = []
|
||||||
extra_compile_args = {'cxx': []}
|
extra_compile_args = {'cxx': []}
|
||||||
|
|
||||||
if (torch.cuda.is_available()
|
if (torch.cuda.is_available()
|
||||||
or os.getenv('FORCE_CUDA', '0') == '1'):
|
or os.getenv('FORCE_CUDA', '0') == '1'):
|
||||||
define_macros += [('MMCV_WITH_CUDA', None)]
|
define_macros += [('MMCV_WITH_CUDA', None)]
|
||||||
extra_compile_args['nvcc'] = cuda_args
|
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||||
|
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
|
||||||
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
|
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
|
||||||
extension = CUDAExtension
|
extension = CUDAExtension
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue