parent 6c06849ab7
commit 316977b036
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import group_fisher
+from . import group_fisher, sparse_gpt
 
-__all__ = ['group_fisher']
+__all__ = ['group_fisher', 'sparse_gpt']
@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mutator import SparseGptMutator
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops

__all__ = [
    'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
    'SparseGptMutator'
]
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmrazor.utils import print_log
from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops


class SparseGptMutator:
    """Drive SparseGPT pruning over all SparseGpt ops in a model."""

    def __init__(self, sparse_model: nn.Module) -> None:
        self.model = sparse_model

    def start_init_hessian(self):
        # Register forward pre-hooks that accumulate Hessian statistics.
        for module in self.sparse_ops:
            module.start_init_hessian()

    def end_init_hessian(self):
        # Remove the accumulation hooks once calibration is done.
        for module in self.sparse_ops:
            module.end_init_hessian()

    def prune_24(self):
        # Apply one-shot 2:4 structured pruning to every SparseGpt op.
        for name, module in self.named_sparse_ops:
            try:
                error = module.prune_24()
                print_log(f'prune {name} success \t error = {error}')
            except Exception as e:
                print_log(f'prune {name} failed as {e}')

    @property
    def sparse_ops(self):
        for module in self.model.modules():
            if isinstance(module, SparseGptMixIn):
                yield module

    @property
    def named_sparse_ops(self):
        for name, module in self.model.named_modules():
            if isinstance(module, SparseGptMixIn):
                yield name, module

    @classmethod
    def init_from_a_model(cls, model: nn.Module):
        # Swap plain Linear/Conv2d layers for their SparseGPT counterparts,
        # then wrap the model in a mutator.
        replace_with_dynamic_ops(model, {
            nn.Linear: SparseGptLinear,
            nn.Conv2d: SparseGptConv2d
        })
        mutator = cls(model)
        return mutator
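A minimal sketch of the intended calibrate-then-prune flow, mirroring the
example scripts later in this diff (`model` and `calib_loader` are
placeholder names):

    import torch
    from mmrazor.implementations.pruning import sparse_gpt

    # Wrap the model: Linear/Conv2d become SparseGptLinear/SparseGptConv2d.
    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)

    mutator.start_init_hessian()        # hooks start recording Hessians
    with torch.no_grad():
        for x, _ in calib_loader:       # a few calibration batches suffice
            model(x)
    mutator.end_init_hessian()          # stop recording
    mutator.prune_24()                  # one-shot 2:4 pruning per layer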
@@ -0,0 +1,216 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Protocol

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
                                                      DynamicLinear)
from .utils import ModuleProtocol


class SparseGptMixIn(ModuleProtocol):
    """Equip a dynamic op with Hessian tracking and one-shot 2:4 pruning."""

    # init

    def _sparse_gpt_mix_in_init(self):
        self.sparse_gpt_handles = []
        self.rows = self.weight_matrix.shape[0]
        self.columns = self.weight_matrix.shape[1]

        # Hessian accumulator over pairs of input channels.
        _hessian = torch.zeros([self.columns, self.columns])
        self.register_buffer('_hessian', _hessian)
        self._hessian: torch.Tensor
        self.hessian_batch = 0

    # weight and input adaptive

    @property
    def weight_matrix(self):
        """Return the weight as a 2-D matrix of shape (out, in)."""
        return self.weight.flatten(1)  # out in

    @weight_matrix.setter
    def weight_matrix(self, value: torch.Tensor):
        with torch.no_grad():
            value = value.reshape(self.weight.shape).to(self.weight.device).to(
                self.weight.dtype)
            self.weight.data = value

    def format_input(self, input: torch.Tensor):
        """Return the input reshaped to (B, N, C)."""
        if len(input.shape) == 2:  # N C
            input = input.unsqueeze(0)  # 1 N C
        return input

    # compute hessian

    @property
    def hessian(self):
        """The Hessian buffer is always returned as float32."""
        self._hessian = self._hessian.float()
        return self._hessian

    @hessian.setter
    def hessian(self, value: torch.Tensor):
        with torch.no_grad():
            self._hessian = value.float()

    @torch.no_grad()
    def update_hessian(self, input: torch.Tensor):

        input = self.format_input(input).float()

        assert len(input.shape) == 3
        B = input.shape[0]  # B N C
        input = input.transpose(0, -1).flatten(1)  # C D

        # Weighted running average of H = 2 * X @ X.T over all
        # calibration batches seen so far.
        H = input @ input.T * 2  # C C
        self.hessian = (self.hessian * self.hessian_batch + H) / (
            self.hessian_batch + B)
        self.hessian_batch = self.hessian_batch + B

    def start_init_hessian(self):

        @torch.no_grad()
        def forward_pre_hook(module: Protocol, input: tuple):
            assert len(input) == 1
            self.update_hessian(input[0])

        handle = self.register_forward_pre_hook(forward_pre_hook)
        self.sparse_gpt_handles.append(handle)

    def end_init_hessian(self):
        for h in self.sparse_gpt_handles:
            h.remove()

    # prune

    @torch.no_grad()
    def prune_24(self):
        # Converted from https://github.com/ist-daslab/sparsegpt
        percdamp = 0.01
        blocksize = 128
        prunem = 4
        prunen = 2
        sparsity = 0.5

        assert self.hessian is not None
        W: torch.Tensor = self.weight_matrix.float()  # out in

        H = self.hessian

        # Input channels that never fired: fix their diagonal and zero the
        # corresponding weights so the inverse is well defined.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=W.device)

        # Dampen the diagonal for numerical stability, then take the upper
        # Cholesky factor of H^-1, as in the reference implementation.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=W.device)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prunen == 0:
                # Unstructured pruning: threshold on global saliency.
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
                                                              sparsity)]
                    mask1 = tmp <= thresh
            else:
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    # For each group of `prunem` columns, mark the `prunen`
                    # weights with the smallest saliency w^2 / [H^-1]_jj^2.
                    tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
                        i + prunem)].reshape((1, -1)))**2
                    mask1.scatter_(
                        1,
                        i + torch.topk(tmp, prunen, dim=1, largest=False)[1],
                        True)

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q)**2 / d**2

                # OBS-style update: spread the pruning error of column i
                # over the remaining columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
                                                            i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the block's accumulated error to later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        from .sparse24_utils import is_weight_sparse_24
        assert is_weight_sparse_24(
            W, -1), f'Weight does not satisfy 2:4 sparsity, shape {W.shape}'
        error = torch.sum(Losses)

        if torch.isnan(error).any():
            raise Exception('got nan error')
        else:
            self.weight_matrix = W.data

        return error


# SparseGpt Ops for Linear and Conv2d


class SparseGptLinear(DynamicLinear, SparseGptMixIn):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._sparse_gpt_mix_in_init()

    @classmethod
    def convert_from(cls, module: nn.Linear):
        # strict=False: the source module has no `_hessian` buffer.
        new_module = super().convert_from(module)
        new_module.load_state_dict(module.state_dict(), strict=False)
        return new_module


class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._sparse_gpt_mix_in_init()

    @classmethod
    def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
        new_module = super().convert_from(module)
        new_module.load_state_dict(module.state_dict(), strict=False)
        return new_module

    def format_input(self, input: torch.Tensor):
        # input B C H W: unfold patches so a conv behaves like a matmul.
        input = F.unfold(
            input, self.kernel_size, padding=self.padding,
            stride=self.stride)  # B C D
        return input.transpose(-1, -2)
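The ops can also be driven directly, without the mutator. A minimal sketch
of the per-op API, mirroring the unit test at the end of this diff:

    import torch
    from mmrazor.implementations.pruning import sparse_gpt

    op = sparse_gpt.SparseGptLinear(12, 20, bias=False)

    op.start_init_hessian()                  # hook starts accumulating H
    with torch.no_grad():
        for x in torch.rand([100, 5, 12]):   # fake calibration batches
            op(x)
    op.end_init_hessian()

    error = op.prune_24()                    # 2:4 prune, returns the loss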
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


@torch.no_grad()
def is_weight_sparse_24(weight: torch.Tensor, dim=-1):
    """Check if the weight satisfies 2:4 sparsity along `dim`."""
    weight = weight.transpose(-1, dim).reshape(-1, 4)  # N 4
    is_zero = (weight == 0).sum(-1)  # N
    # Every group of 4 weights must contain at least 2 zeros.
    return (is_zero >= 2).all()
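A hand-checked illustration of the 2:4 predicate (tensor values invented for
the example; the module path follows the package imports above):

    import torch
    from mmrazor.implementations.pruning.sparse_gpt.sparse24_utils import \
        is_weight_sparse_24

    w_ok = torch.tensor([[1., 0., 0., 2.],
                         [0., 0., 3., 4.]])   # 2 zeros per group of 4
    w_bad = torch.tensor([[1., 2., 3., 0.]])  # only 1 zero in the group

    assert is_weight_sparse_24(w_ok, -1)
    assert not is_weight_sparse_24(w_bad, -1)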
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Protocol, Type

import torch
import torch.nn as nn

from mmrazor.models.architectures.dynamic_ops import DynamicMixin
from mmrazor.models.utils import get_module_device


class ModuleProtocol(Protocol):
    """Structural type for the nn.Module methods the mixin relies on."""
    weight: torch.Tensor

    def forward(self, x):
        pass

    def register_forward_hook(self, hook):
        pass

    def register_backward_hook(self, hook):
        pass

    def register_forward_pre_hook(self, hook):
        pass

    def register_buffer(self, name, tensor):
        pass


def replace_with_dynamic_ops(model: nn.Module,
                             dynamicop_map: Dict[Type[nn.Module],
                                                 Type[DynamicMixin]]):
    """Replace torch modules with dynamic-ops."""

    def replace_op(model: nn.Module, name: str, module: nn.Module):
        # Walk the dotted module path to the parent, then swap the child.
        names = name.split('.')
        for sub_name in names[:-1]:
            model = getattr(model, sub_name)

        setattr(model, names[-1], module)

    for name, module in model.named_modules():
        if type(module) in dynamicop_map:
            new_module = dynamicop_map[type(module)].convert_from(module).to(
                get_module_device(module))
            replace_op(model, name, new_module)
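A minimal sketch of what `replace_with_dynamic_ops` does to a nested model
(toy module invented for illustration; the imports follow the package
`__init__` above):

    import torch.nn as nn
    from mmrazor.implementations.pruning.sparse_gpt import (
        SparseGptLinear, replace_with_dynamic_ops)

    toy = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 4)))
    replace_with_dynamic_ops(toy, {nn.Linear: SparseGptLinear})

    # Both linears, including the nested one at dotted path '1.0', are now
    # SparseGptLinear instances with weights copied from the originals.
    assert isinstance(toy[0], SparseGptLinear)
    assert isinstance(toy[1][0], SparseGptLinear)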
@@ -0,0 +1,107 @@
import numpy as np
import torch


def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)


def get_wikitext2(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
    testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Each calibration sample is a random window of `seqlen` tokens.
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100  # mask all but the last target (-100 = ignored)
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_ptb(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
    trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
    testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, model):
    from datasets import load_dataset
    traindata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
        split='train')
    valdata = load_dataset(
        'allenai/c4',
        'allenai--c4',
        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation')

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)

    import random
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Resample documents until one is at least `seqlen` tokens long.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] >= seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]

    class TokenizerWrapper:
        # Mimics a tokenizer output: downstream code only reads `.input_ids`.

        def __init__(self, input_ids):
            self.input_ids = input_ids

    valenc = TokenizerWrapper(valenc)

    return trainloader, valenc


def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, model)
    if 'ptb' in name:
        return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, model)
    raise NotImplementedError(f'Unsupported dataset: {name}')
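Typical use of `get_loaders` (the model name here is only an example, in the
`facebook/opt-X` form the OPT script below expects):

    trainloader, testenc = get_loaders(
        'wikitext2', nsamples=128, seed=0, seqlen=2048,
        model='facebook/opt-125m')
    # trainloader: list of (input_ids, target) windows for calibration
    # testenc: tokenized test split consumed by opt_eval below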
@@ -0,0 +1,129 @@
# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
import torch
import torch.nn as nn
from transformers import OPTForCausalLM

has_wandb = False


def get_opt(model):

    def skip(*args, **kwargs):
        pass

    # Skip default weight init: the weights are overwritten by the
    # pretrained checkpoint anyway, so initialization is wasted work.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip
    model = OPTForCausalLM.from_pretrained(
        model,
        torch_dtype='auto',
        mirror='https://mirror.nju.edu.cn/hugging-face-models',
        local_files_only=True)
    model.seqlen = model.config.max_position_embeddings
    return model


@torch.no_grad()
def opt_eval(model: OPTForCausalLM,
             testenc,
             dev,
             dataset: str,
             log_wandb: bool = False):
    print('Evaluating ...')

    testenc: torch.Tensor = testenc.input_ids  # type: ignore
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    nlls = []

    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
        out = model(batch)[0]  # logits, shape (1, seqlen, vocab)

        # Next-token prediction: logits at position t predict token t + 1.
        shift_logits = out[:, :-1, :].contiguous()  # 1 N C
        shift_labels = batch[:, 1:]  # 1 N

        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1))
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f'Perplexity: {ppl.item():.3f}')
    model.config.use_cache = use_cache


@torch.no_grad()
def opt_infer(
    model: OPTForCausalLM,
    testenc,
    dev,
    num_samples=128,
):
    print('Infer ...')

    testenc: torch.Tensor = testenc.input_ids  # type: ignore
    nsamples = testenc.numel() // model.seqlen

    model.config.use_cache = False

    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
        _ = model(batch)[0]

        if i > num_samples:
            break


if __name__ == '__main__':
    import argparse

    from datautils import get_loaders

    parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
    parser.add_argument(
        'dataset',
        type=str,
        choices=['wikitext2', 'ptb', 'c4'],
        help='Where to extract calibration data from.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Seed for sampling the calibration data.')
    parser.add_argument(
        '--nsamples',
        type=int,
        default=128,
        help='Number of calibration data samples.')
    args = parser.parse_args()

    model = get_opt(args.model)
    model.eval()
    model = model.cuda()
    print('load model over')
    DEV = torch.device('cuda:0')

    dataloader, testloader = get_loaders(
        'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)

    from mmrazor.implementations.pruning import sparse_gpt
    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)

    # Calibrate: forward passes feed the Hessian hooks, then every layer
    # is pruned to 2:4 sparsity in one shot.
    mutator.start_init_hessian()
    opt_infer(model, testloader, DEV, num_samples=128)
    mutator.end_init_hessian()
    mutator.prune_24()

    for dataset in ['wikitext2', 'ptb', 'c4']:
        dataloader, testloader = get_loaders(
            dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
        print(dataset)
        opt_eval(model, testloader, DEV, dataset)
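For reference, opt_eval above computes the usual non-overlapping-window
perplexity: the test set is cut into windows of `seqlen` tokens, the
per-window negative log-likelihoods nll_i are summed, and

    ppl = exp( sum_i nll_i / (nsamples * seqlen) )

so the reported number is the exponential of the mean per-token
cross-entropy.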
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn

from mmrazor.implementations.pruning import sparse_gpt


def infer(model: nn.Module,
          dataloader: torch.utils.data.DataLoader,
          num_samples=256):
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        accumulate_batch = 0
        for x, _ in dataloader:
            x = x.to(device)
            model(x)
            B = x.shape[0]
            accumulate_batch += B
            # Stop once enough calibration samples have been seen.
            if accumulate_batch > num_samples:
                break


def sparse_model(model: nn.Module,
                 dataloader: torch.utils.data.DataLoader,
                 num_samples=256):
    # Calibrate-then-prune pipeline: collect Hessians over a few batches,
    # then apply one-shot 2:4 pruning to every supported layer.
    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
    mutator.start_init_hessian()
    infer(model, dataloader, num_samples)
    mutator.end_init_hessian()
    mutator.prune_24()
    return model
@@ -0,0 +1,88 @@
# Example: apply 2:4 sparsity to a torchvision ResNet-18 and evaluate it.
import os.path as osp

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pipe import sparse_model
from torch.utils.data import DataLoader


def get_dataloaders(batch_size, n_workers, path=''):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = datasets.ImageFolder(
        osp.join(path, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    test_dataset = datasets.ImageFolder(
        osp.join(path, 'val'),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    dataloader_train = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        pin_memory=True,
    )
    dataloader_test = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=n_workers,
        pin_memory=True,
    )
    return dataloader_train, dataloader_test


def evaluate(model: nn.Module, dataloader_test: DataLoader):
    """Top-1 accuracy on the test dataloader."""
    total = 0
    correct = 0

    device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        for x, y in dataloader_test:
            x: torch.Tensor  # type: ignore
            y: torch.Tensor  # type: ignore
            x = x.to(device)
            outputs = model(x)
            _, predicted = outputs.max(1)
            y = y.to(device)
            correct += (y == predicted).long().sum()
            total += y.numel()
    acc = correct / total
    return acc


if __name__ == '__main__':
    model = torchvision.models.resnet18(pretrained=True)
    train_loader, test_loader = get_dataloaders(128, 4, 'data/imagenet_torch')

    model = model.cuda()
    model = sparse_model(model, test_loader, num_samples=512)

    print('start evaluation')
    model = model.cuda()
    acc = evaluate(model, test_loader)
    print('accuracy:', acc.item())
@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
import torch.nn as nn

from mmrazor.implementations.pruning import sparse_gpt


class TestSparseGptOps(unittest.TestCase):

    @torch.no_grad()
    def test_op(self):

        def get_loss(linear, linear1, data):
            y = linear(data)
            y1 = linear1(data)
            return (y - y1).square().sum()

        def infer(model, dataset):
            for x in dataset:
                model(x)

        for device in ['cpu', 'cuda']:
            if device == 'cuda' and not torch.cuda.is_available():
                continue
            device = torch.device(device)

            # prepare

            linear = nn.Linear(12, 20, bias=False).to(device)
            sparse_linear = sparse_gpt.SparseGptLinear(
                12, 20, bias=False).to(device)
            sparse_linear.load_state_dict(linear.state_dict(), strict=False)

            random_data = torch.rand([100, 5, 12]).to(
                device)  # [loader_batch, batch, feature]
            data_0 = random_data[0]

            # the conversion must be lossless before pruning
            self.assertTrue(get_loss(linear, sparse_linear, data_0) == 0)

            # prune

            sparse_linear.start_init_hessian()
            infer(sparse_linear, random_data)
            sparse_linear.end_init_hessian()

            sparse_linear.prune_24()

            # compare the pruned op against the dense original
            print('norm:', linear(data_0).norm(2))
            print('distance:', get_loss(linear, sparse_linear, data_0))