add sparse gpt (#499)

init

Co-authored-by: liukai <your_email@abc.example>
pull/510/head
LKJacky 2023-04-11 16:14:35 +08:00 committed by GitHub
parent 6c06849ab7
commit 316977b036
11 changed files with 739 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import group_fisher
+from . import group_fisher, sparse_gpt
-__all__ = ['group_fisher']
+__all__ = ['group_fisher', 'sparse_gpt']

View File

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mutator import SparseGptMutator
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops
__all__ = [
'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
'SparseGptMutator'
]

View File

@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmrazor.utils import print_log
from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops
class SparseGptMutator():
    """Manager that applies SparseGPT 2:4 pruning to all supported ops in a
    model."""

    def __init__(self, sparse_model: nn.Module) -> None:
        self.model = sparse_model
def start_init_hessian(self):
for module in self.sparse_ops:
module.start_init_hessian()
def end_init_hessian(self):
for module in self.sparse_ops:
module.end_init_hessian()
    def prune_24(self):
        """Prune every SparseGPT op to 2:4 sparsity, logging per-op error."""
        for name, module in self.named_sparse_ops:
            try:
                error = module.prune_24()
                print_log(f'prune {name} success \t error = {error}')
            except Exception as e:
                print_log(f'prune {name} failed: {e}')
@property
def sparse_ops(self):
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
yield module
@property
def named_sparse_ops(self):
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module
@classmethod
def init_from_a_model(cls, model: nn.Module):
replace_with_dynamic_ops(model, {
nn.Linear: SparseGptLinear,
nn.Conv2d: SparseGptConv2d
})
mutator = cls(model)
return mutator
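
A minimal usage sketch of the mutator (the toy model, shapes, and batch counts below are illustrative, not part of this commit):

import torch
import torch.nn as nn

from mmrazor.implementations.pruning import sparse_gpt

toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(toy)

mutator.start_init_hessian()  # register forward pre-hooks
with torch.no_grad():
    for _ in range(8):  # calibration forward passes to accumulate Hessians
        toy(torch.rand(4, 16))
mutator.end_init_hessian()  # remove the hooks
mutator.prune_24()  # prune every SparseGPT op to 2:4 sparsity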

View File

@@ -0,0 +1,216 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Protocol
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
DynamicLinear)
from .utils import ModuleProtocol
class SparseGptMixIn(ModuleProtocol):
# init
def _sparse_gpt_mix_in_init(self):
self.sparse_gpt_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
_hessian = torch.zeros([self.columns, self.columns])
self.register_buffer('_hessian', _hessian)
self._hessian: torch.Tensor
self.hessian_batch = 0
# weight and input adaptive
    @property
    def weight_matrix(self):
        """Return the weight as a 2-D matrix of shape (out_features, in_features)."""
        return self.weight.flatten(1)  # out in
@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
self.weight.data = value
def format_input(self, input: torch.Tensor):
"""Return input with shape (B N C)"""
if len(input.shape) == 2: # N C
input = input.unsqueeze(0) # 1 N C
return input
# compute hessian
@property
def hessian(self):
"""hessian always return float."""
self._hessian = self._hessian.float()
return self._hessian
@hessian.setter
def hessian(self, value: torch.Tensor):
with torch.no_grad():
self._hessian = value.float()
    @torch.no_grad()
    def update_hessian(self, input: torch.Tensor):
        """Update the running estimate of the Hessian H = 2XX^T with a batch."""
        input = self.format_input(input).float()
        assert len(input.shape) == 3
        B = input.shape[0]  # B N C
        input = input.transpose(0, -1).flatten(1)  # C D
        H = input @ input.T * 2  # C C
        # Batch-weighted running average of the Hessian estimate.
        self.hessian = (self.hessian * self.hessian_batch + H) / (
            self.hessian_batch + B)
        self.hessian_batch = self.hessian_batch + B
def start_init_hessian(self):
@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
assert len(input) == 1
self.update_hessian(input[0])
handle = self.register_forward_pre_hook(forward_pre_hook)
self.sparse_gpt_handles.append(handle)
def end_init_hessian(self):
for h in self.sparse_gpt_handles:
h.remove()
# prune
    @torch.no_grad()
    def prune_24(self):
        """Prune the weight matrix to 2:4 sparsity with the SparseGPT solver."""
        # Converted from https://github.com/ist-daslab/sparsegpt
        percdamp = 0.01  # Hessian dampening ratio
        blocksize = 128  # columns processed per block
        prunem = 4  # group size
        prunen = 2  # weights pruned per group
        sparsity = 0.5
assert self.hessian is not None
W: torch.Tensor = self.weight_matrix.float() # out in
H = self.hessian
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
Losses = torch.zeros(self.rows, device=W.device)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=W.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
mask = None
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
if prunen == 0:
if mask is not None:
mask1 = mask[:, i1:i2]
else:
tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
sparsity)]
mask1 = tmp <= thresh
else:
mask1 = torch.zeros_like(W1) == 1
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if prunen != 0 and i % prunem == 0:
tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
i + prunem)].reshape((1, -1)))**2
mask1.scatter_(
1,
i + torch.topk(tmp, prunen, dim=1, largest=False)[1],
True)
q = w.clone()
q[mask1[:, i]] = 0
Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
i:].unsqueeze(0))
Err1[:, i] = err1
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        from .sparse24_utils import is_weight_sparse_24
        assert is_weight_sparse_24(
            W, -1), f'Weight does not satisfy 2:4 sparsity, shape {W.shape}'
        error = torch.sum(Losses)
        if torch.isnan(error).any():
            raise Exception('NaN encountered when computing the pruning error')
else:
self.weight_matrix = W.data
return error
# SparseGpt Ops for Linear and Conv2d
class SparseGptLinear(DynamicLinear, SparseGptMixIn):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()
@classmethod
def convert_from(cls, module: nn.Linear):
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
return new_module
class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()
@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
return new_module
def format_input(self, input: torch.Tensor):
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
stride=self.stride) # B C D
return input.transpose(-1, -2)
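
For reference, a standalone sketch of the 2:4 mask-selection step that prune_24 applies per group (toy tensors only; the real method also performs the block-wise error compensation shown above):

import torch

W = torch.rand(3, 8)  # 3 output rows, 8 input columns -> two groups of 4
hinv_diag = torch.rand(8) + 0.1  # stand-in for torch.diag(Hinv1)
saliency = W**2 / hinv_diag.reshape(1, -1)**2  # same criterion as prune_24

mask = torch.zeros_like(W, dtype=torch.bool)
for i in range(0, W.shape[1], 4):
    # in every group of 4 columns, mark the 2 lowest-saliency weights
    idx = torch.topk(saliency[:, i:i + 4], 2, dim=1, largest=False)[1]
    mask.scatter_(1, i + idx, True)
W[mask] = 0  # exactly 2 zeros per group of 4 in every row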

View File

@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
@torch.no_grad()
def is_weight_sparse_24(weight: torch.Tensor, dim=-1):
    """Check if the weight satisfies 2:4 sparsity along ``dim``."""
    weight = weight.transpose(-1, dim).reshape(-1, 4)  # N 4
    is_zero = (weight == 0).sum(-1)  # N
    # 2:4 sparsity requires at least 2 zeros in every group of 4 weights.
    return (is_zero >= 2).all()
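
A quick sanity check of the helper on hand-built toy tensors:

import torch

w = torch.tensor([[1., 0., 2., 0., 0., 3., 0., 4.]])  # 2 zeros per group of 4
assert is_weight_sparse_24(w, dim=-1)
assert not is_weight_sparse_24(torch.ones(1, 8), dim=-1)  # dense weight fails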

View File

@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Protocol, Type
import torch
import torch.nn as nn
from mmrazor.models.architectures.dynamic_ops import DynamicMixin
from mmrazor.models.utils import get_module_device
class ModuleProtocol(Protocol):
weight: torch.Tensor
def forward(self, x):
pass
def register_forward_hook(self, hook):
pass
def register_backward_hook(self, hook):
pass
def register_forward_pre_hook(self, hook):
pass
def register_buffer(self, name, tensor):
pass
def replace_with_dynamic_ops(model: nn.Module,
dynamicop_map: Dict[Type[nn.Module],
Type[DynamicMixin]]):
"""Replace torch modules with dynamic-ops."""
def replace_op(model: nn.Module, name: str, module: nn.Module):
names = name.split('.')
for sub_name in names[:-1]:
model = getattr(model, sub_name)
setattr(model, names[-1], module)
for name, module in model.named_modules():
if type(module) in dynamicop_map:
new_module = dynamicop_map[type(module)].convert_from(module).to(
get_module_device(module))
replace_op(model, name, new_module)
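
A minimal sketch of using the helper directly (toy model; SparseGptMutator.init_from_a_model wires this up for Linear and Conv2d automatically):

import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt import (
    SparseGptLinear, replace_with_dynamic_ops)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
replace_with_dynamic_ops(model, {nn.Linear: SparseGptLinear})
assert isinstance(model[0], SparseGptLinear)  # swapped in place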

View File

@@ -0,0 +1,107 @@
import numpy as np
import torch
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)
def get_wikitext2(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_ptb(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_c4(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
split='train')
valdata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
valenc = TokenizerWrapper(valenc)
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, model)
    if 'ptb' in name:
        return get_ptb(nsamples, seed, seqlen, model)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, model)
    raise NotImplementedError(f'Unsupported dataset: {name}')
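
Usage sketch (requires the datasets and transformers packages; the model name below is illustrative):

trainloader, testenc = get_loaders(
    'wikitext2', nsamples=128, seed=0, seqlen=2048, model='facebook/opt-125m')
inp, tar = trainloader[0]  # tar masks all but the last token with -100
print(inp.shape)  # torch.Size([1, 2048])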

View File

@@ -0,0 +1,129 @@
# Example for OPT, converted from https://github.com/ist-daslab/sparsegpt
import torch
import torch.nn as nn
from transformers import OPTForCausalLM
has_wandb = False
def get_opt(model):
import torch
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = OPTForCausalLM.from_pretrained(
model,
torch_dtype='auto',
mirror='https://mirror.nju.edu.cn/hugging-face-models',
local_files_only=True)
model.seqlen = model.config.max_position_embeddings
return model
@torch.no_grad()
def opt_eval(model: OPTForCausalLM,
testenc,
dev,
dataset: str,
log_wandb: bool = False):
print('Evaluating ...')
testenc: torch.Tensor = testenc.input_ids # type: ignore
nsamples = testenc.numel() // model.seqlen
use_cache = model.config.use_cache
model.config.use_cache = False
nlls = []
for i in range(nsamples):
batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
out = model(batch)[0] # 1
shift_logits = out[:, :-1, :].contiguous() # 1 N C
shift_labels = batch[:, 1:] # 1 N
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
neg_log_likelihood = loss.float() * model.seqlen
nlls.append(neg_log_likelihood)
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(f'Perplexity: {ppl.item():.3f}')
model.config.use_cache = use_cache
@torch.no_grad()
def opt_infer(
model: OPTForCausalLM,
testenc,
dev,
num_samples=128,
):
print('Infer ...')
testenc: torch.Tensor = testenc.input_ids # type: ignore
nsamples = testenc.numel() // model.seqlen
model.config.use_cache = False
for i in range(nsamples):
batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
_ = model(batch)[0] # 1
        if i + 1 >= num_samples:  # stop after num_samples calibration batches
            break
if __name__ == '__main__':
import argparse
from datautils import get_loaders
parser = argparse.ArgumentParser()
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
args = parser.parse_args()
model = get_opt(args.model)
model.eval()
model = model.cuda()
print('load model over')
DEV = torch.device('cuda:0')
dataloader, testloader = get_loaders(
'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)
from mmrazor.implementations.pruning import sparse_gpt
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
mutator.start_init_hessian()
opt_infer(model, testloader, DEV, num_samples=128)
mutator.end_init_hessian()
mutator.prune_24()
for dataset in ['wikitext2', 'ptb', 'c4']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print(dataset)
opt_eval(model, testloader, DEV, dataset)
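
Assuming this script is saved as opt_example.py (filename illustrative) and the OPT checkpoint is cached locally (the loader passes local_files_only=True), a typical run would be:

python opt_example.py facebook/opt-125m c4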

View File

@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
from mmrazor.implementations.pruning import sparse_gpt
def infer(model: nn.Module,
          dataloader: torch.utils.data.DataLoader,
          num_batchs=256):
    """Run forward passes until about ``num_batchs`` samples are consumed."""
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
accumulate_batch = 0
for x, _ in dataloader:
x = x.to(device)
model(x)
B = x.shape[0]
accumulate_batch += B
if accumulate_batch > num_batchs:
break
def sparse_model(model: nn.Module,
                 dataloader: torch.utils.data.DataLoader,
                 num_batchs=256):
    """Calibrate Hessians on ``dataloader``, then prune ``model`` to 2:4 sparsity."""
mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
mutator.start_init_hessian()
infer(model, dataloader, num_batchs)
mutator.end_init_hessian()
mutator.prune_24()
return model

View File

@@ -0,0 +1,88 @@
# model settings
import os.path as osp
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pipe import sparse_model
from torch.utils.data import DataLoader
def get_dataloaders(batch_size, n_workers, path=''):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
osp.join(path, 'train'),
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
test_dataset = datasets.ImageFolder(
osp.join(path, 'val'),
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]),
)
dataloader_train = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=True,
)
dataloader_test = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_workers,
pin_memory=True,
)
return dataloader_train, dataloader_test
def eval(model: nn.Module, dataloader_test: DataLoader):
total = 0
correct = 0
device = next(model.parameters()).device
model.eval()
with torch.no_grad():
for x, y in dataloader_test:
x: torch.Tensor # type: ignore
y: torch.Tensor # type: ignore
x = x.to(device)
outputs = model(x)
_, predicted = outputs.max(1)
y = y.to(device)
correct += (y == predicted).long().sum()
total += y.numel()
acc = correct / total
return acc
if __name__ == '__main__':
# sparse_model(model, train_loader, 512)
model = torchvision.models.resnet18(pretrained=True)
train_loader, test_loader = get_dataloaders(128, 4, 'data/imagenet_torch')
model = model.cuda()
    model = sparse_model(model, test_loader, num_batchs=512)
    print('start evaluation')
acc = eval(model, test_loader)
print('accuracy:', acc.item())

View File

@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
import torch.nn as nn
from mmrazor.implementations.pruning import sparse_gpt
class TestSparseGptOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
def get_loss(linear, linear1, data):
y = linear(data)
y1 = linear1(data)
return (y - y1).square().sum()
def infer(model, dataset):
for x in dataset:
model(x)
        for device in ['cpu', 'cuda']:
            if device == 'cuda' and not torch.cuda.is_available():
                continue  # skip the GPU case on CPU-only machines
            device = torch.device(device)
# prepare
linear = nn.Linear(12, 20, bias=False).to(device)
sparse_linear = sparse_gpt.SparseGptLinear(
12, 20, bias=False).to(device)
sparse_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([100, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
self.assertTrue(get_loss(linear, sparse_linear, data_0) == 0)
# prune
sparse_linear.start_init_hessian()
infer(sparse_linear, random_data)
sparse_linear.end_init_hessian()
sparse_linear.prune_24()
# compare
print('norm:', linear(data_0).norm(2))
print('distance:', get_loss(linear, sparse_linear, data_0))