pruning and sparsity initial commit
parent 997ba7b346
commit 38f5c1ad1d
@@ -48,7 +48,7 @@ class Model(nn.Module):
         if type(model_cfg) is dict:
             self.md = model_cfg  # model dict
         else:  # is *.yaml
-            import yaml
+            import yaml  # for torch hub
             with open(model_cfg) as f:
                 self.md = yaml.load(f, Loader=yaml.FullLoader)  # model dict
@@ -76,6 +76,26 @@ def find_modules(model, mclass=nn.Conv2d):
     return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
 
 
+def sparsity(model):
+    # Return global model sparsity
+    a, b = 0., 0.
+    for p in model.parameters():
+        a += p.numel()
+        b += (p == 0).sum()
+    return b / a
+
+
+def prune(model, amount=0.3):
+    # Prune model to requested global sparsity
+    import torch.nn.utils.prune as prune
+    print('Pruning model... ', end='')
+    for name, m in model.named_modules():
+        if isinstance(m, torch.nn.Conv2d):
+            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
+            prune.remove(m, 'weight')  # make permanent
+    print(' %.3g global sparsity' % sparsity(model))
+
+
 def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
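A minimal usage sketch of the two helpers this commit adds, assuming they live in utils/torch_utils.py as the hunk above suggests; the import path and the toy model below are illustrative, not part of the diff:

    import torch.nn as nn

    from utils.torch_utils import prune, sparsity  # assumed module path for the new helpers

    # toy model for illustration only
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))

    prune(model, amount=0.3)  # zero the smallest 30% of each Conv2d layer's weights by L1 magnitude
    print('%.3g global sparsity' % sparsity(model))  # fraction of all parameters that are exactly zero

Note that prune() applies l1_unstructured to each Conv2d layer independently, so the printed figure measured over all parameters (biases, BN, non-Conv layers included) will typically come out somewhat below the requested amount.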