mirror of https://github.com/open-mmlab/mmcv.git
remove cuda args (#372)
parent
d9549fba04
commit
e43fe0e243
19
setup.py
19
setup.py
|
@ -153,22 +153,17 @@ def get_extensions():
|
|||
ext_name = 'mmcv._ext'
|
||||
if torch.__version__ == 'parrots':
|
||||
from parrots.utils.build_extension import BuildExtension, Extension
|
||||
cuda_args = [
|
||||
'-gencode=arch=compute_60,code=sm_60',
|
||||
'-gencode=arch=compute_61,code=sm_61',
|
||||
'-gencode=arch=compute_70,code=sm_70',
|
||||
'-gencode=arch=compute_70,code=compute_70'
|
||||
]
|
||||
define_macros = [('MMCV_USE_PARROTS', None)]
|
||||
op_files = glob.glob('./mmcv/ops/csrc/parrots/*')
|
||||
include_path = os.path.abspath('./mmcv/ops/csrc')
|
||||
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||
ext_ops = Extension(
|
||||
name=ext_name,
|
||||
sources=op_files,
|
||||
include_dirs=[include_path],
|
||||
define_macros=define_macros,
|
||||
extra_compile_args={
|
||||
'nvcc': cuda_args,
|
||||
'nvcc': [cuda_args] if cuda_args else [],
|
||||
'cxx': [],
|
||||
},
|
||||
cuda=True)
|
||||
|
@ -178,20 +173,14 @@ def get_extensions():
|
|||
CUDAExtension, CppExtension)
|
||||
# prevent ninja from using too many resources
|
||||
os.environ.setdefault('MAX_JOBS', '4')
|
||||
cuda_args = [
|
||||
'-gencode=arch=compute_52,code=sm_52',
|
||||
'-gencode=arch=compute_60,code=sm_60',
|
||||
'-gencode=arch=compute_61,code=sm_61',
|
||||
'-gencode=arch=compute_70,code=sm_70',
|
||||
'-gencode=arch=compute_70,code=compute_70'
|
||||
]
|
||||
define_macros = []
|
||||
extra_compile_args = {'cxx': []}
|
||||
|
||||
if (torch.cuda.is_available()
|
||||
or os.getenv('FORCE_CUDA', '0') == '1'):
|
||||
define_macros += [('MMCV_WITH_CUDA', None)]
|
||||
extra_compile_args['nvcc'] = cuda_args
|
||||
cuda_args = os.getenv('MMCV_CUDA_ARGS')
|
||||
extra_compile_args['nvcc'] = [cuda_args] if cuda_args else []
|
||||
op_files = glob.glob('./mmcv/ops/csrc/pytorch/*')
|
||||
extension = CUDAExtension
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue