mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00
[NFC] Update sparse linear add the __repr__
This commit is contained in:
parent
a2852e8941
commit
be01a30ac2
@ -13,7 +13,6 @@ def compute_mask(t, N, M):
|
|||||||
nparams_topprune = int(M * (1-percentile))
|
nparams_topprune = int(M * (1-percentile))
|
||||||
if nparams_topprune != 0:
|
if nparams_topprune != 0:
|
||||||
topk = torch.topk(torch.abs(t_reshaped), k=nparams_topprune, largest=False, dim = -1)
|
topk = torch.topk(torch.abs(t_reshaped), k=nparams_topprune, largest=False, dim = -1)
|
||||||
#print(topk.indices)
|
|
||||||
mask_reshaped = mask_reshaped.scatter(dim = -1, index = topk.indices, value = 0)
|
mask_reshaped = mask_reshaped.scatter(dim = -1, index = topk.indices, value = 0)
|
||||||
|
|
||||||
return mask_reshaped.reshape(out_channel, in_channel)
|
return mask_reshaped.reshape(out_channel, in_channel)
|
||||||
@ -21,6 +20,8 @@ def compute_mask(t, N, M):
|
|||||||
class SparseLinearSuper(nn.Module):
|
class SparseLinearSuper(nn.Module):
|
||||||
def __init__(self, in_features, out_features, bias=True):
|
def __init__(self, in_features, out_features, bias=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
self.weight = nn.Parameter(torch.ones(out_features, in_features))
|
self.weight = nn.Parameter(torch.ones(out_features, in_features))
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = nn.Parameter(torch.ones(out_features))
|
self.bias = nn.Parameter(torch.ones(out_features))
|
||||||
@ -39,7 +40,9 @@ class SparseLinearSuper(nn.Module):
|
|||||||
n, m = self.sparsity_config
|
n, m = self.sparsity_config
|
||||||
self.mask = compute_mask(self.weight, n, m)
|
self.mask = compute_mask(self.weight, n, m)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"SparseLinearSuper(in_features={self.in_features}, out_features={self.out_features}, sparse_config:{self.sparsity_config})"
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
weight = self.weight * self.mask
|
weight = self.weight * self.mask
|
||||||
#weight = self.weight
|
#weight = self.weight
|
||||||
@ -58,7 +61,7 @@ if __name__ == '__main__':
|
|||||||
m = SparseLinearSuper(12, 12)
|
m = SparseLinearSuper(12, 12)
|
||||||
input = torch.randn(12)
|
input = torch.randn(12)
|
||||||
print(m(input))
|
print(m(input))
|
||||||
m.set_sample_config((2,4))
|
m.set_sample_config((1,4))
|
||||||
print(m(input))
|
print(m(input))
|
||||||
print(m.num_pruned_params())
|
print(m.num_pruned_params())
|
||||||
#print(sum(p.numel() for p in m.parameters() if p.requires_grad))
|
#print(sum(p.numel() for p in m.parameters() if p.requires_grad))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user