mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
fix ut
This commit is contained in:
parent
521696e0aa
commit
cb1e4667d9
@ -12,7 +12,8 @@ class TestGPTQOps(unittest.TestCase):
|
||||
|
||||
@torch.no_grad()
|
||||
def test_op(self):
|
||||
if digit_version(torch.__version__) < digit_version('1.12.0'):
|
||||
if digit_version(torch.__version__) < digit_version(
|
||||
'1.12.0') or not torch.cuda.is_available():
|
||||
self.skipTest('torch<1.12.0')
|
||||
|
||||
def get_loss(linear, linear1, data):
|
||||
@ -59,7 +60,8 @@ class TestGPTQOps(unittest.TestCase):
|
||||
|
||||
@torch.no_grad()
|
||||
def test_model(self):
|
||||
if digit_version(torch.__version__) < digit_version('1.12.0'):
|
||||
if digit_version(torch.__version__) < digit_version(
|
||||
'1.12.0') or not torch.cuda.is_available():
|
||||
self.skipTest('torch<1.12.0')
|
||||
import torchvision
|
||||
model = torchvision.models.resnet18()
|
||||
|
Loading…
x
Reference in New Issue
Block a user