diff --git a/setup.py b/setup.py index bdc9eb5..275b6fc 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,18 @@ import glob import os import subprocess +import subprocess +import sys + +def install_torch(): + try: + import torch + except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"]) + +# Call the function to ensure torch is installed +install_torch() + import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension