diff --git a/setup.py b/setup.py
index a045b76..bdc9eb5 100644
--- a/setup.py
+++ b/setup.py
@@ -70,7 +70,7 @@ def get_extensions():
     extra_compile_args = {"cxx": []}
     define_macros = []
 
-    if torch.cuda.is_available() and CUDA_HOME is not None:
+    if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
         print("Compiling with CUDA")
         extension = CUDAExtension
         sources += source_cuda