align acc & add save/load ckpt & add unit tests

commit f71040e13d (parent daacadc3dc)
@@ -108,6 +108,7 @@ class GPTQCompressor():
             module: GPTQMixIn = module.to(device)
             quantizer = Quantizer()
             quantizer.configure(**qconfig)
+            # print_log(f'quant {name}...')
             error = module.quant(
                 quantizer=quantizer,
                 blocksize=blocksize,
@@ -173,7 +173,7 @@ class GPTQMixIn(ModuleProtocol):
                 self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
-        intweight = intweight.numpy().astype(np.uint32)
+        intweight = intweight.cpu().numpy().astype(np.uint32)
         qweight = np.zeros(
             (intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
             dtype=np.uint32)
@@ -189,10 +189,10 @@ class GPTQMixIn(ModuleProtocol):
             raise NotImplementedError('Only 2,4,8 bits are supported.')

         qweight = qweight.astype(np.int32)
-        self.qweight = torch.from_numpy(qweight)
+        self.qweight = torch.from_numpy(qweight).to(self.weight.device)

         zeros -= 1
-        zeros = zeros.numpy().astype(np.uint32)
+        zeros = zeros.cpu().numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
                           dtype=np.uint32)
         i = 0
@@ -207,7 +207,7 @@ class GPTQMixIn(ModuleProtocol):
             raise NotImplementedError('Only 2,4,8 bits are supported.')

         qzeros = qzeros.astype(np.int32)
-        self.qzeros = torch.from_numpy(qzeros)
+        self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)

     @torch.no_grad()
     def quant(self,
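For readers following the qweight/qzeros hunks above: the pack step stores 32/bits low-bit integers per uint32 word, and `.numpy()` only works on CPU tensors, which is what the added `.cpu()` calls address; the packed result is then moved back onto `self.weight.device`. Below is a minimal, self-contained sketch of that packing idea for 4-bit values. It is illustrative only, not MMRazor's exact implementation, and the helper name is hypothetical.

import numpy as np
import torch

def pack_4bit(intweight: torch.Tensor) -> torch.Tensor:
    """Toy packing of a (rows, cols) tensor of 4-bit integers into uint32 words."""
    bits = 4
    # numpy cannot read CUDA tensors, hence the `.cpu()` added in the hunks above
    arr = intweight.cpu().numpy().astype(np.uint32)
    packed = np.zeros((arr.shape[0] // 32 * bits, arr.shape[1]), dtype=np.uint32)
    vals_per_word = 32 // bits  # 8 nibbles per uint32 word
    for row in range(packed.shape[0]):
        for j in range(vals_per_word):
            packed[row] |= arr[row * vals_per_word + j] << np.uint32(bits * j)
    # reinterpret as int32, mirroring `qweight.astype(np.int32)` above
    return torch.from_numpy(packed.astype(np.int32))

qw = pack_4bit(torch.randint(0, 16, (64, 8)))  # -> shape (8, 8), dtype int32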
@@ -298,7 +298,6 @@ class GPTQMixIn(ModuleProtocol):
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
         if actorder:
             invperm = torch.argsort(perm)
-            W = W[:, invperm]
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]

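For context on the act-order branch above: `perm` reorders columns for quantization and `invperm = torch.argsort(perm)` undoes it, so only the tensors kept afterwards (here Q and g_idx, no longer W) need to be mapped back. A one-line check of that inverse relationship, as a plain PyTorch sketch:

import torch

perm = torch.randperm(8)
invperm = torch.argsort(perm)
x = torch.arange(8.0)
# applying perm and then invperm restores the original column order
assert torch.equal(x[perm][invperm], x)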
@@ -307,10 +306,10 @@ class GPTQMixIn(ModuleProtocol):
             zero.append(quantizer.zero)
         scale = torch.cat(scale, dim=1)
         zero = torch.cat(zero, dim=1)
-        self.weight_matrix = Q.data
+        self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
        if self.is_custom_kernel:
            self.pack(scale, zero, g_idx)
            del self.weight
        return error

    def free(self):
@@ -479,12 +479,18 @@ class TritonGPTQLinear(nn.Module, GPTQMixIn):

     def forward(self, x):
         """Custom forward."""
-        out_shape = x.shape[:-1] + (self.out_features, )
-        out = QuantLinearFunction.apply(
-            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros,
-            self.g_idx, self.bits, self.maxq)
-        out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        if torch.all(self.qweight == 0):
+            out = F.linear(x, self.weight, self.bias)
+        else:
+            # import pdb;pdb.set_trace()
+            out_shape = x.shape[:-1] + (self.out_features, )
+            out = QuantLinearFunction.apply(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
+                self.qzeros, self.g_idx, self.bits, self.maxq)
+            out = out + self.bias if self.bias is not None else out
+            out = out.reshape(out_shape)
+            # import pdb;pdb.set_trace()
+        return out


 class GPTQLinear(DynamicLinear, GPTQMixIn):
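The rewritten forward falls back to a dense `F.linear` whenever `qweight` is still all zeros, i.e. before `pack()` has produced real packed weights, so the same module works both during calibration and after quantization. A hedged toy version of that control flow (an illustration only, not the TritonGPTQLinear implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyFallbackLinear(nn.Linear):
    """Behaves like nn.Linear until a packed weight buffer is filled in."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # stays all-zero until a (hypothetical) pack step overwrites it
        self.register_buffer('qweight', torch.zeros(1, dtype=torch.int32))

    def forward(self, x):
        if torch.all(self.qweight == 0):
            # calibration / pre-packing phase: plain dense linear
            return F.linear(x, self.weight, self.bias)
        raise NotImplementedError('the packed low-bit kernel would run here')

layer = ToyFallbackLinear(12, 20)
out = layer(torch.rand(4, 12))  # dense path, since qweight is still zero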
@@ -7,7 +7,8 @@ from utils import opt_eval, opt_infer

 from mmrazor.implementations.pruning.sparse_gpt.utils import \
     memory_efficient_forward
-from mmrazor.implementations.quantization.gptq import GPTQLinear
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+                                                       TritonGPTQLinear)
 from mmrazor.utils import print_log

@@ -25,6 +26,13 @@ def disable_observer_linear(model):
             module.fix_qparams = True


+def del_redundant_attr(model):
+    print_log('Del redundant weight for GPTQLinear!')
+    for _, module in model.named_modules():
+        if isinstance(module, TritonGPTQLinear):
+            del module.weight
+
+
 def get_model(model):

     def skip(*args, **kwargs):
@@ -48,7 +56,7 @@ if __name__ == '__main__':

     parser.add_argument('model', type=str, help='Llama model to load')
     parser.add_argument(
-        'dataset',
+        '--dataset',
         type=str,
         choices=['wikitext2', 'ptb', 'c4'],
         help='Where to extract calibration data from.')
@@ -69,6 +77,10 @@ if __name__ == '__main__':
         help='Batchsize for calibration and evaluation.')
     parser.add_argument(
         '--save', type=str, default='', help='Path to saved model.')
+    parser.add_argument(
+        '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+    parser.add_argument(
+        '--dev', type=str, default='cuda:0', help='Use which device.')
     parser.add_argument(
         '-m',
         type=bool,
@@ -77,28 +89,25 @@ if __name__ == '__main__':

     args = parser.parse_args()

-    DEV = torch.device('cuda:0')
+    DEV = args.dev

     model = get_model(args.model)
     model.to(DEV)
     model.eval()
     print_log('load model over')

-    dataloader, testloader = get_loaders(
-        args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
-    print_log('load data for infer over')
-
     from mmrazor.implementations.quantization import gptq
     compressor = gptq.GPTQCompressor()
     # use_triton_ops is True
-    compressor.prepare(model.model.layers,
-                       quant_conv=True,
-                       use_triton_ops=True,
-                       quant_linear=True,
-                       bits=4,
-                       groupsize=128)
+    compressor.prepare(
+        model.model.layers,
+        quant_conv=True,
+        use_triton_ops=True,
+        quant_linear=True,
+        bits=4,
+        groupsize=128)

-    # # # quantize activation for linear
+    # # quantize activation for linear
     # # a_qconfig = dict(bits=4, perchannel=False, sym=False)
     # compressor.prepare(
     #     model.model.layers,
@@ -108,23 +117,38 @@ if __name__ == '__main__':
     #     # a_qconfig=a_qconfig
     # )

-    compressor.init_hessian()
-    enable_observer_linear(model)
-    with memory_efficient_forward(
-            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
-        compressor.register_hessian_hooks()
-        opt_infer(
-            model,
-            testloader,
-            DEV,
-            batch_size=args.batch_size,
-            num_samples=args.nsamples)
-        compressor.remove_hessian_hooks()
-    compressor.quant_with_default_qconfig(device=DEV)
+    if args.quant_ckpt:
+        del_redundant_attr(model)
+        model.load_state_dict(torch.load(args.quant_ckpt))
+    else:
+        dataloader, testloader = get_loaders(
+            args.dataset,
+            seed=args.seed,
+            model=args.model,
+            seqlen=model.seqlen)
+        print_log('load data for infer over')
+
+        compressor.init_hessian()
+        enable_observer_linear(model)
+        with memory_efficient_forward(
+                model,
+                wrap_modules=[LlamaDecoderLayer],
+                enabled=args.m,
+                device=DEV):
+            compressor.register_hessian_hooks()
+            opt_infer(
+                model,
+                testloader,
+                DEV,
+                batch_size=args.batch_size,
+                num_samples=args.nsamples)
+            compressor.remove_hessian_hooks()
+        compressor.quant_with_default_qconfig(device=DEV)

     disable_observer_linear(model)
     with memory_efficient_forward(
-            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
+            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
+            device=DEV):

     # for dataset in ['wikitext2', 'ptb', 'c4']:
     for dataset in ['wikitext2']:
@@ -133,8 +157,6 @@ if __name__ == '__main__':
             print_log(dataset)
             opt_eval(model, testloader, DEV, batch_size=args.batch_size)

-    if args.save:
-        # model = compressor.to_static_model(model)
+    if args.save and not args.quant_ckpt:
         print_log(f'save model in {args.save}')
-        # model.save_pretrained(args.save)
         torch.save(model.state_dict(), args.save)
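Taken together with the new `--save` and `--quant_ckpt` flags, the intended workflow appears to be: run calibration and quantization once, save the packed state_dict, and on later runs load it directly instead of re-collecting Hessians. A hedged sketch of that round trip; the deletion loop mirrors `del_redundant_attr` from this diff, and the function names are illustrative:

import torch

def save_quantized(model, path):
    # after compressor.quant_with_default_qconfig(...) has packed the layers
    torch.save(model.state_dict(), path)

def load_quantized(model, path, quant_linear_cls):
    # the packed checkpoint holds qweight/qzeros/scales but no dense `weight`,
    # so drop the dense weights first (as del_redundant_attr does above)
    for _, module in model.named_modules():
        if isinstance(module, quant_linear_cls) and hasattr(module, 'weight'):
            del module.weight
    model.load_state_dict(torch.load(path))
    return model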
@@ -8,7 +8,8 @@ from utils import opt_eval, opt_infer

 from mmrazor.implementations.pruning.sparse_gpt.utils import \
     memory_efficient_forward
-from mmrazor.implementations.quantization.gptq import GPTQLinear
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+                                                       TritonGPTQLinear)
 from mmrazor.utils import print_log

@@ -26,6 +27,13 @@ def disable_observer_linear(model):
             module.fix_qparams = True


+def del_redundant_attr(model):
+    print_log('Del redundant weight for GPTQLinear!')
+    for _, module in model.named_modules():
+        if isinstance(module, TritonGPTQLinear):
+            del module.weight
+
+
 def get_model(model):

     def skip(*args, **kwargs):
@@ -44,10 +52,9 @@ if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()

-    parser.add_argument('model', type=str, help='Llama model to load')
+    parser.add_argument(
+        'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
     parser.add_argument(
-        'dataset',
+        '--dataset',
         type=str,
         choices=['wikitext2', 'ptb', 'c4'],
         help='Where to extract calibration data from.')
@@ -64,10 +71,14 @@ if __name__ == '__main__':
     parser.add_argument(
         '--batch_size',
         type=int,
-        default=64,
+        default=16,
         help='Batchsize for calibration and evaluation.')
     parser.add_argument(
         '--save', type=str, default='', help='Path to saved model.')
+    parser.add_argument(
+        '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+    parser.add_argument(
+        '--dev', type=str, default='cuda:0', help='Use which device.')
     parser.add_argument(
         '-m',
         type=bool,
@@ -76,53 +87,63 @@ if __name__ == '__main__':

     args = parser.parse_args()

-    DEV = torch.device('cuda:0')
+    DEV = args.dev

     model = get_model(args.model)
     model.to(DEV)
     model.eval()
     print_log('load model over')

-    dataloader, testloader = get_loaders(
-        args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
-    print_log('load data for infer over')
-
     from mmrazor.implementations.quantization import gptq
     compressor = gptq.GPTQCompressor()
-    # # use_triton_ops is True
-    # compressor.prepare(model.model.layers,
-    #                    quant_conv=True,
-    #                    use_triton_ops=True,
-    #                    quant_linear=True,
-    #                    bits=4,
-    #                    groupsize=128)
-
-    # # quantize activation for linear
-    # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+    # use_triton_ops is True
     compressor.prepare(
-        model.model.decoder,
+        model.model.layers,
         quant_conv=True,
+        use_triton_ops=True,
         quant_linear=True,
-        use_triton_ops=False,
-        # a_qconfig=a_qconfig
-    )
+        bits=4,
+        groupsize=128)

-    compressor.init_hessian()
-    enable_observer_linear(model)
-    with memory_efficient_forward(
-            model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
-        compressor.register_hessian_hooks()
-        opt_infer(
-            model,
-            testloader,
-            DEV,
-            batch_size=args.batch_size,
-            num_samples=args.nsamples)
-        compressor.remove_hessian_hooks()
-    compressor.quant_with_default_qconfig(device=DEV)
+    # # # quantize activation for linear
+    # # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+    # compressor.prepare(
+    #     model.model.decoder,
+    #     quant_conv=True,
+    #     quant_linear=True,
+    #     use_triton_ops=False,
+    #     # a_qconfig=a_qconfig
+    # )
+
+    if args.quant_ckpt:
+        del_redundant_attr(model)
+        model.load_state_dict(torch.load(args.quant_ckpt))
+    else:
+        dataloader, testloader = get_loaders(
+            args.dataset,
+            seed=args.seed,
+            model=args.model,
+            seqlen=model.seqlen)
+        print_log('load data for infer over')
+
+        compressor.init_hessian()
+        enable_observer_linear(model)
+        with memory_efficient_forward(
+                model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
+                device=DEV):
+            compressor.register_hessian_hooks()
+            opt_infer(
+                model,
+                testloader,
+                DEV,
+                batch_size=args.batch_size,
+                num_samples=args.nsamples)
+            compressor.remove_hessian_hooks()
+        compressor.quant_with_default_qconfig(device=DEV)

     disable_observer_linear(model)
     with memory_efficient_forward(
-            model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
+            model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):

     # for dataset in ['wikitext2', 'ptb', 'c4']:
     for dataset in ['wikitext2']:
@@ -131,7 +152,6 @@ if __name__ == '__main__':
             print_log(dataset)
             opt_eval(model, testloader, DEV, batch_size=args.batch_size)

-    if args.save:
-        model = compressor.to_static_model(model)
+    if args.save and not args.quant_ckpt:
         print_log(f'save model in {args.save}')
-        model.save_pretrained(args.save)
+        torch.save(model.state_dict(), args.save)
@@ -7,6 +7,7 @@ import torch.nn as nn
+from mmrazor import digit_version
 from mmrazor.implementations.quantization import gptq


 class TestGPTQOps(unittest.TestCase):

     @torch.no_grad()
@@ -30,7 +31,7 @@ class TestGPTQOps(unittest.TestCase):

         linear = nn.Linear(12, 20, bias=False).to(device)
         gptq_linear = gptq.GPTQLinear(
-            12, 20, bias=False).to(device)
+            in_features=12, out_features=20, bias=False).to(device)
         gptq_linear.load_state_dict(linear.state_dict(), strict=False)

         random_data = torch.rand([10, 5, 12]).to(
@@ -39,13 +40,13 @@ class TestGPTQOps(unittest.TestCase):

         self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)

-        # prune
+        # quant

         gptq_linear.init_hessian()
         gptq_linear.register_hessian_hook()
         infer(gptq_linear, random_data)
         gptq_linear.remove_hessian_hook()

         qconfig = dict(bits=4, perchannel=True, sym=False)
         quantizer = gptq.Quantizer()
         quantizer.configure(**qconfig)
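The test hunks above exercise GPTQLinear end to end: Hessian hooks first, then quantization with a 4-bit per-channel quantizer. A hedged sketch of that flow as a standalone script, with the `get_loss`/`infer` helpers from the real test inlined; the exact assertions and tolerances in MMRazor's test may differ:

import torch
import torch.nn as nn

from mmrazor.implementations.quantization import gptq

device = torch.device('cpu')
linear = nn.Linear(12, 20, bias=False).to(device)
gptq_linear = gptq.GPTQLinear(
    in_features=12, out_features=20, bias=False).to(device)
gptq_linear.load_state_dict(linear.state_dict(), strict=False)

x = torch.rand(10, 5, 12, device=device)

# collect the Hessian statistics that GPTQ needs
gptq_linear.init_hessian()
gptq_linear.register_hessian_hook()
gptq_linear(x)
gptq_linear.remove_hessian_hook()

# 4-bit, per-channel, asymmetric quantization, as in the test above
quantizer = gptq.Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=False)
error = gptq_linear.quant(quantizer=quantizer)

# the quantized layer should stay close to the fp32 reference
print(error, (linear(x) - gptq_linear(x)).abs().max())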