fix bug in test

This commit is contained in:
FIRST_NAME LAST_NAME 2023-05-19 10:39:33 +08:00
parent 86da5b3f01
commit 3825e04378

View File

@ -22,6 +22,8 @@ class TestSparseGptOps(unittest.TestCase):
model(x)
for device in ['cpu', 'cuda']:
if device == 'cuda' and (not torch.cuda.is_available()):
self.skipTest('cuda is not available')
device = torch.device(device)
# prepare
@ -31,7 +33,7 @@ class TestSparseGptOps(unittest.TestCase):
12, 20, bias=False).to(device)
sparse_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([100, 5, 12]).to(
random_data = torch.rand([10, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
@ -39,11 +41,12 @@ class TestSparseGptOps(unittest.TestCase):
# prune
sparse_linear.init_hessian()
sparse_linear.register_hessian_hook()
infer(sparse_linear, random_data)
sparse_linear.remove_hessian_hook()
sparse_linear.prune()
sparse_linear.prune(0.5)
# compare
@ -56,9 +59,10 @@ class TestSparseGptOps(unittest.TestCase):
model = torchvision.models.resnet18()
mutator = sparse_gpt.SparseGptCompressor()
mutator.prepare_from_supernet(model)
mutator.prepare(model)
x = torch.rand(10, 3, 224, 224)
mutator.init_hessian()
mutator.register_hessian_hooks()
model(x)
mutator.remove_hessian_hooks()