Align accuracy, add save/load checkpoint support, and add unit tests

pull/538/head
humu789 2023-05-23 20:17:55 +08:00
parent daacadc3dc
commit f71040e13d
6 changed files with 137 additions and 88 deletions


@@ -108,6 +108,7 @@ class GPTQCompressor():
module: GPTQMixIn = module.to(device)
quantizer = Quantizer()
quantizer.configure(**qconfig)
# print_log(f'quant {name}...')
error = module.quant(
quantizer=quantizer,
blocksize=blocksize,
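For context, the hunk above is the per-module step of GPTQCompressor: each target module is moved to the quantization device, a fresh Quantizer is configured from the module's qconfig, and module.quant() returns the quantization error. A minimal hedged sketch of that pattern, reusing the qconfig that appears in the new unit test at the end of this diff; the blocksize value and the helper name are illustrative assumptions, not code from the commit:

```python
from mmrazor.implementations.quantization import gptq

def quant_one_module(module, device='cuda:0', blocksize=128):
    """Illustrative sketch only; mirrors the loop body in the hunk above."""
    module = module.to(device)
    quantizer = gptq.Quantizer()
    # same qconfig as the new unit test: 4-bit, per-channel, asymmetric
    quantizer.configure(bits=4, perchannel=True, sym=False)
    error = module.quant(quantizer=quantizer, blocksize=blocksize)
    return error
```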


@@ -173,7 +173,7 @@ class GPTQMixIn(ModuleProtocol):
self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
intweight = intweight.cpu().numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
dtype=np.uint32)
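The `.cpu()` added above (and again in the qzeros hunk below) is the actual fix here: NumPy cannot view CUDA memory, so `Tensor.numpy()` raises a TypeError on a GPU tensor. A small self-contained illustration of the round trip the pack path now performs, in plain PyTorch/NumPy with no mmrazor code:

```python
import numpy as np
import torch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
intweight = torch.randint(0, 16, (64, 32), dtype=torch.int32, device=device)

# torch -> numpy requires host memory, so copy to CPU first ...
packed = intweight.cpu().numpy().astype(np.uint32)

# ... and move the repacked result back to the module's device afterwards,
# matching the added `.to(self.weight.device)` calls below.
qweight = torch.from_numpy(packed.astype(np.int32)).to(device)
```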
@@ -189,10 +189,10 @@ class GPTQMixIn(ModuleProtocol):
raise NotImplementedError('Only 2,4,8 bits are supported.')
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
self.qweight = torch.from_numpy(qweight).to(self.weight.device)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
zeros = zeros.cpu().numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
dtype=np.uint32)
i = 0
@@ -207,7 +207,7 @@ class GPTQMixIn(ModuleProtocol):
raise NotImplementedError('Only 2,4,8 bits are supported.')
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)
@torch.no_grad()
def quant(self,
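Both pack loops above store 32 // bits integer values per uint32 word, which is why qweight gets `intweight.shape[0] // 32 * self.bits` rows and qzeros gets `zeros.shape[1] // 32 * self.bits` columns. A standalone sketch of the same idea for bits=4 (eight values per word); the exact layout expected by the Triton kernel is not reproduced here:

```python
import numpy as np

bits = 4
vals = np.arange(8, dtype=np.uint32)   # eight 4-bit values fit in one 32-bit word

word = np.uint32(0)
for i, v in enumerate(vals):
    word |= v << np.uint32(bits * i)   # place each value at its bit offset

# unpack to verify the round trip
unpacked = [int(word >> np.uint32(bits * i)) & 0xF for i in range(8)]
assert unpacked == list(range(8))
```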
@@ -298,7 +298,6 @@ class GPTQMixIn(ModuleProtocol):
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
W = W[:, invperm]
Q = Q[:, invperm]
g_idx = g_idx[invperm]
@@ -307,10 +306,10 @@ class GPTQMixIn(ModuleProtocol):
zero.append(quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
self.weight_matrix = Q.data
self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
if self.is_custom_kernel:
self.pack(scale, zero, g_idx)
del self.weight
return error
def free(self):
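The hunks above also drop the `W = W[:, invperm]` line, leaving only Q and g_idx to be restored to the original column order after act-order quantization, and cast the quantized matrix back to the original weight dtype before storing it. The inverse permutation itself is just an argsort of the permutation, as a tiny standalone check shows:

```python
import torch

perm = torch.randperm(8)          # act-order column permutation
invperm = torch.argsort(perm)     # inverse permutation, as in the hunk above

x = torch.arange(8)
assert torch.equal(x[perm][invperm], x)   # perm followed by invperm restores order
```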


@@ -479,12 +479,18 @@ class TritonGPTQLinear(nn.Module, GPTQMixIn):
def forward(self, x):
"""Custom forward."""
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros,
self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
if torch.all(self.qweight == 0):
out = F.linear(x, self.weight, self.bias)
else:
# import pdb;pdb.set_trace()
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
out = out.reshape(out_shape)
# import pdb;pdb.set_trace()
return out
class GPTQLinear(DynamicLinear, GPTQMixIn):
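The rewritten forward above dispatches on whether pack() has run yet: while qweight is still all zeros the layer falls back to a dense F.linear on the retained floating-point weight, and only afterwards does it call the Triton QuantLinearFunction. A reduced sketch of that dispatch pattern in plain PyTorch, with the quantized path stubbed out:

```python
import torch
import torch.nn.functional as F


class LinearWithQuantFallback(torch.nn.Module):
    """Illustrative only: dense fallback until a packed qweight exists."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = None
        # stays all-zero until the real module's pack() fills it in
        self.register_buffer('qweight', torch.zeros(1, dtype=torch.int32))

    def forward(self, x):
        if torch.all(self.qweight == 0):
            return F.linear(x, self.weight, self.bias)
        # the real TritonGPTQLinear would call QuantLinearFunction here
        raise NotImplementedError('packed path not shown in this sketch')
```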


@@ -7,7 +7,8 @@ from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import GPTQLinear
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
@@ -25,6 +26,13 @@ def disable_observer_linear(model):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
def get_model(model):
def skip(*args, **kwargs):
@@ -48,7 +56,7 @@ if __name__ == '__main__':
parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'dataset',
'--dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
@@ -69,6 +77,10 @@ if __name__ == '__main__':
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
@@ -77,28 +89,25 @@ if __name__ == '__main__':
args = parser.parse_args()
DEV = torch.device('cuda:0')
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
compressor.prepare(
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
# # # quantize activation for linear
# # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.layers,
@@ -108,23 +117,38 @@ if __name__ == '__main__':
# # a_qconfig=a_qconfig
# )
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model,
wrap_modules=[LlamaDecoderLayer],
enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
@@ -133,8 +157,6 @@ if __name__ == '__main__':
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save:
# model = compressor.to_static_model(model)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
# model.save_pretrained(args.save)
torch.save(model.state_dict(), args.save)
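Taken together, the llama script changes add a plain state-dict checkpoint flow for the packed model: on the first run the model is calibrated, quantized, and saved; on a later run --quant_ckpt skips calibration, the redundant dense weight is deleted from every TritonGPTQLinear, and the packed tensors are loaded. Condensed, the two sides of that flow look like this (names exactly as in the script above, not runnable in isolation):

```python
import torch

# first run (no --quant_ckpt): calibrate, quantize, then persist the packed weights
torch.save(model.state_dict(), args.save)

# later run (--quant_ckpt given): drop the dense .weight on each TritonGPTQLinear
# so the packed state dict matches, then load it
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
```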


@@ -8,7 +8,8 @@ from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import GPTQLinear
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
@@ -26,6 +27,13 @@ def disable_observer_linear(model):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
def get_model(model):
def skip(*args, **kwargs):
@@ -44,10 +52,9 @@ if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
'--dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
@@ -64,10 +71,14 @@ if __name__ == '__main__':
parser.add_argument(
'--batch_size',
type=int,
default=64,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
@@ -76,53 +87,63 @@ if __name__ == '__main__':
args = parser.parse_args()
DEV = torch.device('cuda:0')
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# # use_triton_ops is True
# compressor.prepare(model.model.layers,
# quant_conv=True,
# use_triton_ops=True,
# quant_linear=True,
# bits=4,
# groupsize=128)
# # quantize activation for linear
# a_qconfig = dict(bits=4, perchannel=False, sym=False)
# use_triton_ops is True
compressor.prepare(
model.model.decoder,
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
use_triton_ops=False,
# a_qconfig=a_qconfig
)
bits=4,
groupsize=128)
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
# # # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.decoder,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
@@ -131,7 +152,6 @@ if __name__ == '__main__':
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save:
model = compressor.to_static_model(model)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)
torch.save(model.state_dict(), args.save)
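When no checkpoint is supplied, both scripts follow the same calibrate-then-quantize branch: prepare the layers with the Triton GPTQ ops, collect Hessian statistics during a calibration forward pass, then quantize with the default qconfig. A stripped-down sketch of that branch using the compressor API exactly as it appears in this diff (model, testloader, DEV and args come from the surrounding script):

```python
from mmrazor.implementations.quantization import gptq

compressor = gptq.GPTQCompressor()
compressor.prepare(
    model.model.layers,
    quant_conv=True,
    quant_linear=True,
    use_triton_ops=True,
    bits=4,
    groupsize=128)

compressor.init_hessian()
compressor.register_hessian_hooks()
opt_infer(model, testloader, DEV,
          batch_size=args.batch_size, num_samples=args.nsamples)  # calibration pass
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
```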


@@ -7,6 +7,7 @@ import torch.nn as nn
from mmrazor import digit_version
from mmrazor.implementations.quantization import gptq
class TestGPTQOps(unittest.TestCase):
@torch.no_grad()
@@ -30,7 +31,7 @@ class TestGPTQOps(unittest.TestCase):
linear = nn.Linear(12, 20, bias=False).to(device)
gptq_linear = gptq.GPTQLinear(
12, 20, bias=False).to(device)
in_features=12, out_features=20, bias=False).to(device)
gptq_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([10, 5, 12]).to(
@@ -39,13 +40,13 @@ class TestGPTQOps(unittest.TestCase):
self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)
# prune
# quant
gptq_linear.init_hessian()
gptq_linear.register_hessian_hook()
infer(gptq_linear, random_data)
gptq_linear.remove_hessian_hook()
qconfig = dict(bits=4, perchannel=True, sym=False)
quantizer = gptq.Quantizer()
quantizer.configure(**qconfig)
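The new test wires a single GPTQLinear through the same hook/quantize cycle. A hedged sketch of how the rest of such a test might continue, reusing the quant() signature from the compressor hunk at the top of this diff; the blocksize and tolerance are illustrative assumptions:

```python
# continue the test: quantize the layer, then check the quantized output
# stays close to the dense layer on the calibration batch.
error = gptq_linear.quant(quantizer=quantizer, blocksize=128)
loss_after_quant = get_loss(linear, gptq_linear, data_0)
self.assertTrue(loss_after_quant < 0.1)  # tolerance chosen for illustration
```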