diff --git a/setup.py b/setup.py index a045b76..bdc9eb5 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,7 @@ def get_extensions(): extra_compile_args = {"cxx": []} define_macros = [] - if torch.cuda.is_available() and CUDA_HOME is not None: + if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ): print("Compiling with CUDA") extension = CUDAExtension sources += source_cuda