diff --git a/mmrazor/implementations/quantization/gptq/compressor.py b/mmrazor/implementations/quantization/gptq/compressor.py
index 5fca4fe8..4a5aadd8 100644
--- a/mmrazor/implementations/quantization/gptq/compressor.py
+++ b/mmrazor/implementations/quantization/gptq/compressor.py
@@ -108,6 +108,7 @@ class GPTQCompressor():
             module: GPTQMixIn = module.to(device)
             quantizer = Quantizer()
             quantizer.configure(**qconfig)
+            # print_log(f'quant {name}...')
             error = module.quant(
                 quantizer=quantizer,
                 blocksize=blocksize,
diff --git a/mmrazor/implementations/quantization/gptq/gptq.py b/mmrazor/implementations/quantization/gptq/gptq.py
index 313b0b2b..84cfd3a4 100644
--- a/mmrazor/implementations/quantization/gptq/gptq.py
+++ b/mmrazor/implementations/quantization/gptq/gptq.py
@@ -173,7 +173,7 @@ class GPTQMixIn(ModuleProtocol):
                     self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
-        intweight = intweight.numpy().astype(np.uint32)
+        intweight = intweight.cpu().numpy().astype(np.uint32)
         qweight = np.zeros(
             (intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
             dtype=np.uint32)
@@ -189,10 +189,10 @@ class GPTQMixIn(ModuleProtocol):
             raise NotImplementedError('Only 2,4,8 bits are supported.')
 
         qweight = qweight.astype(np.int32)
-        self.qweight = torch.from_numpy(qweight)
+        self.qweight = torch.from_numpy(qweight).to(self.weight.device)
         zeros -= 1
-        zeros = zeros.numpy().astype(np.uint32)
+        zeros = zeros.cpu().numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
                           dtype=np.uint32)
         i = 0
@@ -207,7 +207,7 @@ class GPTQMixIn(ModuleProtocol):
             raise NotImplementedError('Only 2,4,8 bits are supported.')
 
         qzeros = qzeros.astype(np.int32)
-        self.qzeros = torch.from_numpy(qzeros)
+        self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)
 
     @torch.no_grad()
     def quant(self,
@@ -298,7 +298,6 @@ class GPTQMixIn(ModuleProtocol):
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
         if actorder:
             invperm = torch.argsort(perm)
-            W = W[:, invperm]
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]
 
@@ -307,10 +306,10 @@ class GPTQMixIn(ModuleProtocol):
             zero.append(quantizer.zero)
         scale = torch.cat(scale, dim=1)
         zero = torch.cat(zero, dim=1)
-        self.weight_matrix = Q.data
+        self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
         if self.is_custom_kernel:
             self.pack(scale, zero, g_idx)
-
+            del self.weight
         return error
 
     def free(self):
diff --git a/mmrazor/implementations/quantization/gptq/ops.py b/mmrazor/implementations/quantization/gptq/ops.py
index a2e50637..590febc0 100644
--- a/mmrazor/implementations/quantization/gptq/ops.py
+++ b/mmrazor/implementations/quantization/gptq/ops.py
@@ -479,12 +479,18 @@ class TritonGPTQLinear(nn.Module, GPTQMixIn):
 
     def forward(self, x):
         """Custom forward."""
-        out_shape = x.shape[:-1] + (self.out_features, )
-        out = QuantLinearFunction.apply(
-            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros,
-            self.g_idx, self.bits, self.maxq)
-        out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        if torch.all(self.qweight == 0):
+            out = F.linear(x, self.weight, self.bias)
+        else:
+            # import pdb;pdb.set_trace()
+            out_shape = x.shape[:-1] + (self.out_features, )
+            out = QuantLinearFunction.apply(
+                x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
+                self.qzeros, self.g_idx, self.bits, self.maxq)
+            out = out + self.bias if self.bias is not None else out
+            out = out.reshape(out_shape)
+        # import pdb;pdb.set_trace()
+        return out
 
 
 class GPTQLinear(DynamicLinear, GPTQMixIn):
diff --git a/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
index 0e329e9f..0eae9b4f 100644
--- a/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
+++ b/projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
@@ -7,7 +7,8 @@ from utils import opt_eval, opt_infer
 
 from mmrazor.implementations.pruning.sparse_gpt.utils import \
     memory_efficient_forward
-from mmrazor.implementations.quantization.gptq import GPTQLinear
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+                                                       TritonGPTQLinear)
 from mmrazor.utils import print_log
 
 
@@ -25,6 +26,13 @@ def disable_observer_linear(model):
             module.fix_qparams = True
 
 
+def del_redundant_attr(model):
+    print_log('Del redundant weight for GPTQLinear!')
+    for _, module in model.named_modules():
+        if isinstance(module, TritonGPTQLinear):
+            del module.weight
+
+
 def get_model(model):
 
     def skip(*args, **kwargs):
@@ -48,7 +56,7 @@ if __name__ == '__main__':
 
     parser.add_argument('model', type=str, help='Llama model to load')
     parser.add_argument(
-        'dataset',
+        '--dataset',
         type=str,
         choices=['wikitext2', 'ptb', 'c4'],
         help='Where to extract calibration data from.')
@@ -69,6 +77,10 @@
         help='Batchsize for calibration and evaluation.')
     parser.add_argument(
         '--save', type=str, default='', help='Path to saved model.')
+    parser.add_argument(
+        '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+    parser.add_argument(
+        '--dev', type=str, default='cuda:0', help='Use which device.')
     parser.add_argument(
         '-m',
         type=bool,
@@ -77,28 +89,25 @@
 
     args = parser.parse_args()
 
-    DEV = torch.device('cuda:0')
+    DEV = args.dev
 
     model = get_model(args.model)
     model.to(DEV)
     model.eval()
     print_log('load model over')
 
-    dataloader, testloader = get_loaders(
-        args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
-    print_log('load data for infer over')
-
     from mmrazor.implementations.quantization import gptq
     compressor = gptq.GPTQCompressor()
 
     # use_triton_ops is True
-    compressor.prepare(model.model.layers,
-                       quant_conv=True,
-                       use_triton_ops=True,
-                       quant_linear=True,
-                       bits=4,
-                       groupsize=128)
+    compressor.prepare(
+        model.model.layers,
+        quant_conv=True,
+        use_triton_ops=True,
+        quant_linear=True,
+        bits=4,
+        groupsize=128)
 
-    # # # quantize activation for linear
+    # # quantize activation for linear
     # # a_qconfig = dict(bits=4, perchannel=False, sym=False)
     # compressor.prepare(
     #     model.model.layers,
@@ -108,23 +117,38 @@
     #     # a_qconfig=a_qconfig
     # )
 
-    compressor.init_hessian()
-    enable_observer_linear(model)
-    with memory_efficient_forward(
-            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
-        compressor.register_hessian_hooks()
-        opt_infer(
-            model,
-            testloader,
-            DEV,
-            batch_size=args.batch_size,
-            num_samples=args.nsamples)
-        compressor.remove_hessian_hooks()
-    compressor.quant_with_default_qconfig(device=DEV)
+    if args.quant_ckpt:
+        del_redundant_attr(model)
+        model.load_state_dict(torch.load(args.quant_ckpt))
+    else:
+        dataloader, testloader = get_loaders(
+            args.dataset,
+            seed=args.seed,
+            model=args.model,
+            seqlen=model.seqlen)
+        print_log('load data for infer over')
+
+        compressor.init_hessian()
+        enable_observer_linear(model)
+        with memory_efficient_forward(
+                model,
+                wrap_modules=[LlamaDecoderLayer],
+                enabled=args.m,
+                device=DEV):
+            compressor.register_hessian_hooks()
+            opt_infer(
+                model,
+                testloader,
+                DEV,
+                batch_size=args.batch_size,
+                num_samples=args.nsamples)
+            compressor.remove_hessian_hooks()
+        compressor.quant_with_default_qconfig(device=DEV)
 
     disable_observer_linear(model)
     with memory_efficient_forward(
-            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
+            model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
+            device=DEV):
 
         # for dataset in ['wikitext2', 'ptb', 'c4']:
         for dataset in ['wikitext2']:
@@ -133,8 +157,6 @@
             print_log(dataset)
             opt_eval(model, testloader, DEV, batch_size=args.batch_size)
 
-    if args.save:
-        # model = compressor.to_static_model(model)
+    if args.save and not args.quant_ckpt:
         print_log(f'save model in {args.save}')
-        # model.save_pretrained(args.save)
         torch.save(model.state_dict(), args.save)
diff --git a/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
index 20fb6eab..5cd48e56 100644
--- a/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
+++ b/projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
@@ -8,7 +8,8 @@ from utils import opt_eval, opt_infer
 
 from mmrazor.implementations.pruning.sparse_gpt.utils import \
     memory_efficient_forward
-from mmrazor.implementations.quantization.gptq import GPTQLinear
+from mmrazor.implementations.quantization.gptq import (GPTQLinear,
+                                                       TritonGPTQLinear)
 from mmrazor.utils import print_log
 
 
@@ -26,6 +27,13 @@ def disable_observer_linear(model):
             module.fix_qparams = True
 
 
+def del_redundant_attr(model):
+    print_log('Del redundant weight for GPTQLinear!')
+    for _, module in model.named_modules():
+        if isinstance(module, TritonGPTQLinear):
+            del module.weight
+
+
 def get_model(model):
 
     def skip(*args, **kwargs):
@@ -44,10 +52,9 @@ if __name__ == '__main__':
 
     import argparse
     parser = argparse.ArgumentParser()
+    parser.add_argument('model', type=str, help='OPT model to load')
     parser.add_argument(
-        'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
-    parser.add_argument(
-        'dataset',
+        '--dataset',
         type=str,
         choices=['wikitext2', 'ptb', 'c4'],
         help='Where to extract calibration data from.')
@@ -64,10 +71,14 @@
     parser.add_argument(
         '--batch_size',
         type=int,
-        default=64,
+        default=16,
         help='Batchsize for calibration and evaluation.')
     parser.add_argument(
         '--save', type=str, default='', help='Path to saved model.')
+    parser.add_argument(
+        '--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
+    parser.add_argument(
+        '--dev', type=str, default='cuda:0', help='Use which device.')
     parser.add_argument(
         '-m',
         type=bool,
@@ -76,53 +87,63 @@
 
     args = parser.parse_args()
 
-    DEV = torch.device('cuda:0')
+    DEV = args.dev
 
     model = get_model(args.model)
+    model.to(DEV)
     model.eval()
     print_log('load model over')
 
-    dataloader, testloader = get_loaders(
-        args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
-    print_log('load data for infer over')
-
     from mmrazor.implementations.quantization import gptq
     compressor = gptq.GPTQCompressor()
 
-    # # use_triton_ops is True
-    # compressor.prepare(model.model.layers,
-    #                    quant_conv=True,
-    #                    use_triton_ops=True,
-    #                    quant_linear=True,
-    #                    bits=4,
-    #                    groupsize=128)
-
-    # # quantize activation for linear
-    # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+    # use_triton_ops is True
     compressor.prepare(
-        model.model.decoder,
+        model.model.decoder.layers,
         quant_conv=True,
+        use_triton_ops=True,
         quant_linear=True,
-        use_triton_ops=False,
-        # a_qconfig=a_qconfig
-    )
+        bits=4,
+        groupsize=128)
 
-    compressor.init_hessian()
-    enable_observer_linear(model)
-    with memory_efficient_forward(
-            model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
-        compressor.register_hessian_hooks()
-        opt_infer(
-            model,
-            testloader,
-            DEV,
-            batch_size=args.batch_size,
-            num_samples=args.nsamples)
-        compressor.remove_hessian_hooks()
-    compressor.quant_with_default_qconfig(device=DEV)
+    # # # quantize activation for linear
+    # # a_qconfig = dict(bits=4, perchannel=False, sym=False)
+    # compressor.prepare(
+    #     model.model.decoder,
+    #     quant_conv=True,
+    #     quant_linear=True,
+    #     use_triton_ops=False,
+    #     # a_qconfig=a_qconfig
+    # )
+
+    if args.quant_ckpt:
+        del_redundant_attr(model)
+        model.load_state_dict(torch.load(args.quant_ckpt))
+    else:
+        dataloader, testloader = get_loaders(
+            args.dataset,
+            seed=args.seed,
+            model=args.model,
+            seqlen=model.seqlen)
+        print_log('load data for infer over')
+
+        compressor.init_hessian()
+        enable_observer_linear(model)
+        with memory_efficient_forward(
+                model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
+                device=DEV):
+            compressor.register_hessian_hooks()
+            opt_infer(
+                model,
+                testloader,
+                DEV,
+                batch_size=args.batch_size,
+                num_samples=args.nsamples)
+            compressor.remove_hessian_hooks()
+        compressor.quant_with_default_qconfig(device=DEV)
 
     disable_observer_linear(model)
     with memory_efficient_forward(
-            model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
+        model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):
 
         # for dataset in ['wikitext2', 'ptb', 'c4']:
         for dataset in ['wikitext2']:
@@ -131,7 +152,6 @@
             print_log(dataset)
             opt_eval(model, testloader, DEV, batch_size=args.batch_size)
 
-    if args.save:
-        model = compressor.to_static_model(model)
+    if args.save and not args.quant_ckpt:
         print_log(f'save model in {args.save}')
-        model.save_pretrained(args.save)
+        torch.save(model.state_dict(), args.save)
diff --git a/tests/test_impl/test_quantization/test_gptq/test_op.py b/tests/test_impl/test_quantization/test_gptq/test_op.py
index 5ca4bef0..fd7f411f 100644
--- a/tests/test_impl/test_quantization/test_gptq/test_op.py
+++ b/tests/test_impl/test_quantization/test_gptq/test_op.py
@@ -7,6 +7,7 @@ import torch.nn as nn
 from mmrazor import digit_version
 from mmrazor.implementations.quantization import gptq
 
+
 class TestGPTQOps(unittest.TestCase):
 
     @torch.no_grad()
@@ -30,7 +31,7 @@ class TestGPTQOps(unittest.TestCase):
 
         linear = nn.Linear(12, 20, bias=False).to(device)
         gptq_linear = gptq.GPTQLinear(
-            12, 20, bias=False).to(device)
+            in_features=12, out_features=20, bias=False).to(device)
         gptq_linear.load_state_dict(linear.state_dict(), strict=False)
 
         random_data = torch.rand([10, 5, 12]).to(
@@ -39,13 +40,13 @@ class TestGPTQOps(unittest.TestCase):
 
         self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)
 
-        # prune
+        # quant
        gptq_linear.init_hessian()
         gptq_linear.register_hessian_hook()
         infer(gptq_linear, random_data)
         gptq_linear.remove_hessian_hook()
-
+
         qconfig = dict(bits=4, perchannel=True, sym=False)
         quantizer = gptq.Quantizer()
         quantizer.configure(**qconfig)
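The sketch below is not part of the patch; it is a minimal Python outline of the new quant_ckpt workflow the example scripts gain in this diff (quantize once, save the state dict, then reload it with the float weights stripped). It only uses the API already shown above; `build_model()`, `run_calibration()` and the checkpoint path are hypothetical placeholders standing in for `get_model()`, `opt_infer()` and `args.quant_ckpt`.

# Usage sketch (assumptions as noted in the lead-in above).
import torch

from mmrazor.implementations.quantization import gptq
from mmrazor.implementations.quantization.gptq import TritonGPTQLinear

DEV = 'cuda:0'
model = build_model().to(DEV).eval()  # hypothetical helper, cf. get_model()

compressor = gptq.GPTQCompressor()
# Swap Linear/Conv layers for Triton GPTQ ops (4-bit weights, group size 128).
compressor.prepare(
    model.model.layers,  # the decoder block list, as in llama_gptq.py
    quant_conv=True,
    use_triton_ops=True,
    quant_linear=True,
    bits=4,
    groupsize=128)

quant_ckpt = ''  # hypothetical path to an already-quantized state dict
if quant_ckpt:
    # A quantized state dict has no float `weight` for TritonGPTQLinear,
    # so drop the redundant attribute before loading (cf. del_redundant_attr).
    for _, module in model.named_modules():
        if isinstance(module, TritonGPTQLinear):
            del module.weight
    model.load_state_dict(torch.load(quant_ckpt))
else:
    # Collect Hessian statistics on calibration data, quantize, then save.
    compressor.init_hessian()
    compressor.register_hessian_hooks()
    run_calibration(model)  # hypothetical stand-in for opt_infer(...)
    compressor.remove_hessian_hooks()
    compressor.quant_with_default_qconfig(device=DEV)
    torch.save(model.state_dict(), 'model_gptq_4bit.pth')  # made-up path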