mirror of
https://github.com/facebookresearch/deit.git
synced 2025-06-03 14:52:20 +08:00

import torch
from torch.nn.utils import prune
import numpy as np

from .utils import get_weights, get_modules, get_modules_with_name

ALL_PRUNERS = ['lamp', 'glob', 'unif', 'unifplus', 'erk', 'custom']


def check_valid_pruner(name):
    return name in ALL_PRUNERS


def weight_pruner_loader(pruner_string):
    """
    Gives you the pruning methods: LAMP, Glob, Unif, Unif+, ERK, and custom
    (N:M or predefined layerwise amounts).
    """
    if pruner_string == 'lamp':
        return prune_weights_lamp
    elif pruner_string == 'glob':
        return prune_weights_global
    elif pruner_string == 'unif':
        return prune_weights_uniform
    elif pruner_string == 'unifplus':
        return prune_weights_unifplus
    elif pruner_string == 'erk':
        return prune_weights_erk
    elif pruner_string == 'custom':
        return prune_weights_from_config
    else:
        raise ValueError('Unknown pruner')
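

# A minimal usage sketch (the `build_model` constructor here is hypothetical):
# attach masks first, then fetch and apply a pruner. LAMP, Unif+, and ERK read
# `weight_mask`, so prune_weights_reparam must run before they are called.
#
#   model = build_model()                      # hypothetical
#   prune_weights_reparam(model)
#   pruner = weight_pruner_loader('lamp')
#   pruner(model, amount=0.5)                  # remove 50% of the weights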


"""
prune_weights_reparam: Allocate an identity mask to every weight tensor.
prune_weights_l1predefined: Perform layerwise pruning w.r.t. given amounts.
prune_weights_semistructured: Perform layerwise N:M pruning w.r.t. a given
    configuration (remove N weights out of every M consecutive ones).
"""


def prune_weights_semistructured(module, configs=None):
    """
    Remove weights according to the given N:M configs.
    configs is a 2D list with the following format:
        configs = [[N1, M1], [N2, M2], ...]
    Please make sure that len(configs) == number of pruning layers.
    """
    def compute_mask(t, N, M):
        out_channel, in_channel = t.shape
        percentile = N / M
        t_reshaped = t.reshape(out_channel, -1, M)
        mask = torch.ones_like(t)
        mask_reshaped = mask.reshape(out_channel, -1, M)

        # Zero out the N smallest-magnitude weights in every group of M.
        nparams_toprune = int(M * percentile)
        if nparams_toprune != 0:
            topk = torch.topk(torch.abs(t_reshaped), k=nparams_toprune, largest=False, dim=-1)
            mask_reshaped = mask_reshaped.scatter(dim=-1, index=topk.indices, value=0)

        return mask_reshaped.reshape(out_channel, in_channel)

    if configs is None:
        raise ValueError("Currently N:M pruning is only supported with a manual config. "
                         "Please provide a config with the sparsity level for each pruning target.")

    mlist = get_modules_with_name(module)
    for idx, (name, m) in enumerate(mlist):
        weight_tensor = m.weight
        config = configs[idx]
        N, M = config[0], config[1]
        print(f"module: {name}, N:M = ({N}, {M})")
        mask = compute_mask(weight_tensor, N, M)
        prune.custom_from_mask(m, name='weight', mask=mask)
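

# A minimal worked example of compute_mask (hypothetical values): with N = 2 and
# M = 4, the row below forms a single group of four, and the two
# smallest-magnitude entries (-0.1 and 0.05) are masked out:
#
#   w = torch.tensor([[0.9, -0.1, 0.4, 0.05]])
#   # compute_mask(w, N=2, M=4) -> tensor([[1., 0., 1., 0.]])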


def prune_weights_reparam(model):
    # Attach an identity (all-ones) mask to every prunable module, so that
    # `weight_mask` exists before any mask-dependent bookkeeping runs.
    module_list = get_modules(model)
    for m in module_list:
        prune.identity(m, name="weight")


def prune_weights_l1predefined(model, amounts):
    mlist = get_modules_with_name(model)
    for idx, (name, m) in enumerate(mlist):
        print(f"module: {name}, amount of removed weight: {float(amounts[idx])}")
        prune.l1_unstructured(m, name="weight", amount=float(amounts[idx]))


"""
Methods: All weights
"""


def prune_weights_global(model, amount):
    parameters_to_prune = _extract_weight_tuples(model)
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount)


def prune_weights_lamp(model, amount):
    assert amount <= 1
    amounts = _compute_lamp_amounts(model, amount)
    prune_weights_l1predefined(model, amounts)


def prune_weights_uniform(model, amount):
    module_list = get_modules(model)
    assert amount <= 1  # Can be updated later to handle > 1.
    for m in module_list:
        print("module:", m, " remove amount:", amount)
        prune.l1_unstructured(m, name="weight", amount=amount)


def prune_weights_unifplus(model, amount):
    assert amount <= 1
    amounts = _compute_unifplus_amounts(model, amount)
    prune_weights_l1predefined(model, amounts)


def prune_weights_erk(model, amount):
    assert amount <= 1
    amounts = _compute_erk_amounts(model, amount)
    prune_weights_l1predefined(model, amounts)


def prune_weights_from_config(model, mode, configs):
    if mode == 'nxm':
        prune_weights_semistructured(model, configs)
    elif mode == 'unstructured':
        prune_weights_l1predefined(model, amounts=configs)
    else:
        raise ValueError(f"Unknown pruning mode: {mode}")
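

# A minimal usage sketch (hypothetical values): N:M mode expects one [N, M] pair
# per prunable module, in the order returned by get_modules_with_name(model).
#
#   n_layers = len(get_modules(model))
#   prune_weights_from_config(model, mode='nxm', configs=[[2, 4]] * n_layers)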


"""
These are not intended to be exported.
"""


def _extract_weight_tuples(model):
    """
    Gives you well-packed weight tensors for global pruning.
    """
    mlist = get_modules(model)
    return tuple((m, 'weight') for m in mlist)


def _compute_unifplus_amounts(model, amount):
    """
    Compute the fraction of weights to prune in each layer.
    """
    amounts = []
    wlist = get_weights(model)
    unmaskeds = _count_unmasked_weights(model)
    totals = _count_total_weights(model)

    last_layer_minimum = np.round(totals[-1] * 0.2)  # Minimum number of last-layer weights to keep
    total_to_prune = np.round(unmaskeds.sum() * amount)

    if wlist[0].dim() == 4:
        amounts.append(0)  # Leave the first layer unpruned.
        frac_to_prune = (total_to_prune * 1.0) / (unmaskeds[1:].sum())
        if frac_to_prune > 1.0:
            raise ValueError("Cannot be pruned further by the Unif+ scheme! (first layer exception)")
        last_layer_to_surv_planned = np.round((1.0 - frac_to_prune) * unmaskeds[-1])
        if last_layer_to_surv_planned < last_layer_minimum:
            last_layer_to_prune = unmaskeds[-1] - last_layer_minimum
            frac_to_prune_middle = ((total_to_prune - last_layer_to_prune) * 1.0) / (unmaskeds[1:-1].sum())
            if frac_to_prune_middle > 1.0:
                raise ValueError("Cannot be pruned further by the Unif+ scheme! (first+last layer exception)")
            amounts.extend([frac_to_prune_middle] * (unmaskeds.size(0) - 2))
            amounts.append((last_layer_to_prune * 1.0) / unmaskeds[-1])
        else:
            amounts.extend([frac_to_prune] * (unmaskeds.size(0) - 1))
    else:
        frac_to_prune = (total_to_prune * 1.0) / (unmaskeds.sum())
        last_layer_to_surv_planned = np.round((1.0 - frac_to_prune) * unmaskeds[-1])
        if last_layer_to_surv_planned < last_layer_minimum:
            last_layer_to_prune = unmaskeds[-1] - last_layer_minimum
            frac_to_prune_middle = ((total_to_prune - last_layer_to_prune) * 1.0) / (unmaskeds[:-1].sum())
            if frac_to_prune_middle > 1.0:
                raise ValueError("Cannot be pruned further by the Unif+ scheme! (last layer exception)")
            amounts.extend([frac_to_prune_middle] * (unmaskeds.size(0) - 1))
            amounts.append((last_layer_to_prune * 1.0) / unmaskeds[-1])
        else:
            amounts.extend([frac_to_prune] * (unmaskeds.size(0)))
    return amounts
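

# A small worked example (hypothetical numbers): for a conv-first model with
# unmasked counts [100, 400, 500] and amount = 0.5, total_to_prune = 500.
# The first layer is skipped, so frac_to_prune = 500 / 900 ~= 0.556; the last
# layer would keep round(0.444 * 500) = 222 weights, above the 100-weight
# minimum, so the returned amounts are [0, 0.556, 0.556].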


def _compute_erk_amounts(model, amount):
    unmaskeds = _count_unmasked_weights(model)
    erks = _compute_erks(model)

    return _amounts_from_eps(unmaskeds, erks, amount)


def _amounts_from_eps(unmaskeds, ers, amount):
    num_layers = ers.size(0)
    layers_to_keep_dense = torch.zeros(num_layers)
    total_to_survive = (1.0 - amount) * unmaskeds.sum()  # Total to keep.

    # Determine which layers to keep dense: if the ER-proportional allocation
    # would assign a layer more survivors than it has weights, pin that layer
    # to dense and re-solve for the rest.
    is_eps_invalid = True
    while is_eps_invalid:
        to_survive_among_prunables = total_to_survive - (layers_to_keep_dense * unmaskeds).sum()

        ers_of_prunables = ers * (1.0 - layers_to_keep_dense)
        survs_of_prunables = torch.round(to_survive_among_prunables * ers_of_prunables / ers_of_prunables.sum())

        layer_to_make_dense = -1
        max_ratio = 1.0
        for idx in range(num_layers):
            if layers_to_keep_dense[idx] == 0:
                if survs_of_prunables[idx] / unmaskeds[idx] > max_ratio:
                    layer_to_make_dense = idx
                    max_ratio = survs_of_prunables[idx] / unmaskeds[idx]

        if layer_to_make_dense == -1:
            is_eps_invalid = False
        else:
            layers_to_keep_dense[layer_to_make_dense] = 1

    amounts = torch.zeros(num_layers)

    for idx in range(num_layers):
        if layers_to_keep_dense[idx] == 1:
            amounts[idx] = 0.0
        else:
            amounts[idx] = 1.0 - (survs_of_prunables[idx] / unmaskeds[idx])
    return amounts
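

# A small worked example (hypothetical numbers): with unmaskeds = [100, 100],
# ers = [10, 30], and amount = 0.5, 100 weights survive in total. Survivors are
# allocated proportionally to the ER scores: 100 * 10/40 = 25 and
# 100 * 30/40 = 75, both feasible, so the returned amounts are [0.75, 0.25].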


def _compute_lamp_amounts(model, amount):
    """
    Compute layerwise pruning amounts for LAMP: keep the global top-k of the
    normalized scores, then translate per-layer survivor counts into fractions.
    """
    unmaskeds = _count_unmasked_weights(model)
    num_surv = int(np.round(unmaskeds.sum() * (1.0 - amount)))

    flattened_scores = [_normalize_scores(w ** 2).view(-1) for w in get_weights(model)]
    concat_scores = torch.cat(flattened_scores, dim=0)
    topks, _ = torch.topk(concat_scores, num_surv)
    threshold = topks[-1]

    # We don't care much about tiebreakers, for now.
    final_survs = [torch.ge(score, threshold * torch.ones(score.size()).to(score.device)).sum()
                   for score in flattened_scores]
    amounts = []
    for idx, final_surv in enumerate(final_survs):
        amounts.append(1.0 - (final_surv / unmaskeds[idx]))

    return amounts


def _compute_erks(model):
    # Erdos-Renyi(-Kernel) score: the sum of a layer's dimensions, so wider
    # (and kernel-bearing) layers are allotted proportionally more survivors.
    wlist = get_weights(model)
    erks = torch.zeros(len(wlist))
    for idx, w in enumerate(wlist):
        if w.dim() == 4:
            erks[idx] = w.size(0) + w.size(1) + w.size(2) + w.size(3)
        else:
            erks[idx] = w.size(0) + w.size(1)
    return erks
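

# For example, a 64 x 3 x 3 x 3 conv weight gets an ERK score of
# 64 + 3 + 3 + 3 = 73, while a 1000 x 512 linear weight gets 1000 + 512 = 1512.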


def _count_unmasked_weights(model):
    """
    Return a 1-dimensional tensor of #unmasked weights.
    (Assumes masks are already attached, e.g. via prune_weights_reparam.)
    """
    mlist = get_modules(model)
    unmaskeds = []
    for m in mlist:
        unmaskeds.append(m.weight_mask.sum())
    return torch.FloatTensor(unmaskeds)


def _count_total_weights(model):
    """
    Return a 1-dimensional tensor of #total weights.
    """
    wlist = get_weights(model)
    numels = []
    for w in wlist:
        numels.append(w.numel())
    return torch.FloatTensor(numels)


def _normalize_scores(scores):
    """
    Normalizing scheme for LAMP.
    """
    # Sort scores in ascending order.
    sorted_scores, sorted_idx = scores.view(-1).sort(descending=False)
    # Compute the cumulative sum, shifted by one, so that entry i holds the sum
    # of all scores ranked strictly below score i.
    scores_cumsum_temp = sorted_scores.cumsum(dim=0)
    scores_cumsum = torch.zeros(scores_cumsum_temp.shape, device=scores.device)
    scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp) - 1]
    # Normalize each score by the sum of all scores ranked at or above it.
    sorted_scores /= (scores.sum() - scores_cumsum)
    # Tidy up and output.
    new_scores = torch.zeros(scores_cumsum.shape, device=scores.device)
    new_scores[sorted_idx] = sorted_scores

    return new_scores.view(scores.shape)
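

# A tiny worked example (hypothetical values): for squared weights [1., 2., 3.],
# the shifted cumulative sums are [0., 1., 3.], so each score is divided by
# [6., 5., 3.], giving [1/6, 2/5, 3/3]; the largest-magnitude weight in a layer
# always receives a score of exactly 1.
#
#   _normalize_scores(torch.tensor([1., 2., 3.]))
#   # -> tensor([0.1667, 0.4000, 1.0000])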