mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
fix bug in test
This commit is contained in:
parent
86da5b3f01
commit
3825e04378
@ -22,6 +22,8 @@ class TestSparseGptOps(unittest.TestCase):
|
||||
model(x)
|
||||
|
||||
for device in ['cpu', 'cuda']:
|
||||
if device == 'cuda' and (not torch.cuda.is_available()):
|
||||
self.skipTest('cuda is not available')
|
||||
device = torch.device(device)
|
||||
|
||||
# prepare
|
||||
@ -31,7 +33,7 @@ class TestSparseGptOps(unittest.TestCase):
|
||||
12, 20, bias=False).to(device)
|
||||
sparse_linear.load_state_dict(linear.state_dict(), strict=False)
|
||||
|
||||
random_data = torch.rand([100, 5, 12]).to(
|
||||
random_data = torch.rand([10, 5, 12]).to(
|
||||
device) # [loader_batch,batch,feature]
|
||||
data_0 = random_data[0]
|
||||
|
||||
@ -39,11 +41,12 @@ class TestSparseGptOps(unittest.TestCase):
|
||||
|
||||
# prune
|
||||
|
||||
sparse_linear.init_hessian()
|
||||
sparse_linear.register_hessian_hook()
|
||||
infer(sparse_linear, random_data)
|
||||
sparse_linear.remove_hessian_hook()
|
||||
|
||||
sparse_linear.prune()
|
||||
sparse_linear.prune(0.5)
|
||||
|
||||
# compare
|
||||
|
||||
@ -56,9 +59,10 @@ class TestSparseGptOps(unittest.TestCase):
|
||||
model = torchvision.models.resnet18()
|
||||
|
||||
mutator = sparse_gpt.SparseGptCompressor()
|
||||
mutator.prepare_from_supernet(model)
|
||||
mutator.prepare(model)
|
||||
|
||||
x = torch.rand(10, 3, 224, 224)
|
||||
mutator.init_hessian()
|
||||
mutator.register_hessian_hooks()
|
||||
model(x)
|
||||
mutator.remove_hessian_hooks()
|
||||
|
Loading…
x
Reference in New Issue
Block a user