pull/538/head
humu789 2023-05-24 15:54:37 +08:00
parent 521696e0aa
commit cb1e4667d9
1 changed files with 4 additions and 2 deletions

View File

@ -12,7 +12,8 @@ class TestGPTQOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
self.skipTest('torch<1.12.0')
def get_loss(linear, linear1, data):
@ -59,7 +60,8 @@ class TestGPTQOps(unittest.TestCase):
@torch.no_grad()
def test_model(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
self.skipTest('torch<1.12.0')
import torchvision
model = torchvision.models.resnet18()