[Feature] Add GPTQ and uniform interfaces (#538)

* add gptq implementation

* pre-commit

* passed resnet example

* passed llama example

* align gptq acc

* add activation quantization

* uniform interfaces

* add gptq readme

* update mmrazor_large readme

* add gptq opt example

* fix sparse_gpt example for opt

* fix Protocol import for Python 3.7

* fix wrong function name

* fix bug in test

* fix bug

* fix bug

* limit sparsegpt test with torch>=1.12

* add docstring for gptq and sparse_gpt

* pre-commit

* align acc & add save load ckpt & add ut

* fix ut

* fix ut

* fix ut

* fix ut & add torch2.0 for ci

* del torch2.0 for ci

* fix ut

---------

Co-authored-by: FIRST_NAME LAST_NAME <MY_NAME@example.com>
pull/542/head
humu789 2023-05-24 16:38:51 +08:00 committed by GitHub
parent a578fad2bc
commit dcf7bfa1a2
32 changed files with 2419 additions and 140 deletions


@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mutator import SparseGptMutator
from .compressor import SparseGptCompressor
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops
__all__ = [
'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
'SparseGptMutator'
'SparseGptCompressor'
]


@ -8,6 +8,7 @@ from .utils import replace_with_dynamic_ops
def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
@ -15,17 +16,17 @@ def to_static_model(model: nn.Module):
return model
class SparseGptMutator():
# init
class SparseGptCompressor():
"""The compressor with SparseGPT."""
def __init__(self) -> None:
self.model: nn.Module = None
def prepare_from_supernet(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
def prepare(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
"""Prepare for compressing model."""
self.model = model
prune_modules: dict = {}
if prune_conv:
@ -36,19 +37,23 @@ class SparseGptMutator():
@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)
# hessian
def start_init_hessian(self):
def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.start_init_hessian()
module.register_hessian_hook()
def end_init_hessian(self):
def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.end_init_hessian()
module.remove_hessian_hook()
def init_hessian(self, device=None):
"""Init hessian."""
for op in self.sparse_ops:
op.init_hessian(device=device)
@ -60,6 +65,7 @@ class SparseGptMutator():
blocksize=128,
percdamp=.01,
device=torch.device('cuda')):
"""Apply the compression algorithm to the model."""
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
@ -78,12 +84,15 @@ class SparseGptMutator():
print_log(f'prune {name} failed as {e}')
def prune_24(self, device=torch.device('cuda:0')):
"""Apply the compression algorithm to the model with the specified
setting."""
self.prune(0.5, prunen=2, prunem=4, device=device)
# ops
@property
def sparse_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
@ -91,6 +100,7 @@ class SparseGptMutator():
@property
def named_sparse_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module


@ -1,5 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Protocol
import sys
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import torch
import torch.distributed as dist
@ -12,10 +17,10 @@ from .utils import ModuleProtocol, torch_setting
class SparseGptMixIn(ModuleProtocol):
# init
"""The core algorithm implementation for SparseGpt."""
def _sparse_gpt_mix_in_init(self):
"""Init mixin."""
self.sparse_gpt_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
@ -32,6 +37,7 @@ class SparseGptMixIn(ModuleProtocol):
@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
@ -64,6 +70,7 @@ class SparseGptMixIn(ModuleProtocol):
@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
@ -77,6 +84,7 @@ class SparseGptMixIn(ModuleProtocol):
@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
@ -94,7 +102,8 @@ class SparseGptMixIn(ModuleProtocol):
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B
def start_init_hessian(self):
def register_hessian_hook(self):
"""Register updating hessian hook."""
@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
@ -104,11 +113,13 @@ class SparseGptMixIn(ModuleProtocol):
handle = self.register_forward_pre_hook(forward_pre_hook)
self.sparse_gpt_handles.append(handle)
def end_init_hessian(self):
def remove_hessian_hook(self):
"""Remove updating hessian hook."""
for h in self.sparse_gpt_handles:
h.remove()
def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
@ -125,6 +136,7 @@ class SparseGptMixIn(ModuleProtocol):
@torch.no_grad()
def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
"""The implementation for SparseGPT."""
with torch_setting(dtype=torch.float):
# Converted from https://github.com/ist-daslab/sparsegpt
@ -199,7 +211,8 @@ class SparseGptMixIn(ModuleProtocol):
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
torch.cuda.synchronize()
if W.device.type == 'cuda':
torch.cuda.synchronize()
from .sparse24_utils import is_weight_sparse_24
if prunen == 2 and prunem == 4:
assert is_weight_sparse_24(
@ -218,6 +231,7 @@ class SparseGptMixIn(ModuleProtocol):
class SparseGptLinear(DynamicLinear, SparseGptMixIn):
"""Custom Linear for SparseGpt."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -225,6 +239,7 @@ class SparseGptLinear(DynamicLinear, SparseGptMixIn):
@classmethod
def convert_from(cls, module: nn.Linear) -> 'DynamicLinear':
"""Convert to cls from torch's module."""
if module.out_features < module.in_features:
return module
new_module = super().convert_from(module)
@ -237,6 +252,7 @@ class SparseGptLinear(DynamicLinear, SparseGptMixIn):
class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
"""Custom Conv2d for SparseGpt."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@ -244,6 +260,7 @@ class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
@ -253,6 +270,7 @@ class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
return new_module
def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,


@ -1,5 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Protocol, Type
import sys
from typing import Dict, Type
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import torch
import torch.nn as nn
@ -9,21 +15,27 @@ from mmrazor.utils import print_log
class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor
def forward(self, x):
"""The abstract method."""
pass
def register_forward_hook(self, hook):
"""The abstract method."""
pass
def register_backward_hook(self, hook):
"""The abstract method."""
pass
def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass
def register_buffer(self, name, tensor):
"""The abstract method."""
pass
@ -47,6 +59,7 @@ def replace_with_dynamic_ops(model: nn.Module,
def register_efficient_forward_hook(module: nn.Module,
device=torch.device('cuda:0')):
"""Register efficient forward hook."""
def forward_pre_hook(module: nn.Module, input):
module.to(device)
@ -63,6 +76,7 @@ def register_efficient_forward_hook(module: nn.Module,
def enable_efficient_forward(model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]):
"""Enable efficient forward."""
handles = []
blocks = []
for name, module in model.named_children():
@ -79,6 +93,7 @@ def enable_efficient_forward(model: nn.Module,
class memory_efficient_forward:
"""The class for Memory efficient forward."""
def __init__(self,
model: nn.Module,
@ -95,6 +110,7 @@ class memory_efficient_forward:
model.to(device)
def __enter__(self, ):
"""Enter."""
if self.enabled:
handles, blocks = enable_efficient_forward(self.model, self.device,
self.wrap_modules)
@ -102,19 +118,23 @@ class memory_efficient_forward:
self.handlers = handles
def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
for h in self.handlers:
h.remove()
class torch_setting():
"""Set the default torch dtype setting."""
def __init__(self, dtype=None) -> None:
self.origianl_dtype = torch.get_default_dtype()
self.original_dtype = torch.get_default_dtype()
self.dtype = dtype
def __enter__(self):
"""Enter."""
if self.dtype is not None:
torch.set_default_dtype(self.dtype)
def __exit__(self, exc_type, exc_value, exc_traceback):
torch.set_default_dtype(self.origianl_dtype)
"""Exit."""
torch.set_default_dtype(self.original_dtype)


@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import GPTQCompressor
from .gptq import GPTQMixIn
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
from .quantizer import Quantizer
__all__ = [
'GPTQCompressor',
'GPTQMixIn',
'GPTQConv2d',
'GPTQLinear',
'TritonGPTQLinear',
'Quantizer',
]


@ -0,0 +1,146 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Type
import torch
import torch.nn as nn
from mmrazor.utils import print_log
from .ops import GPTQConv2d, GPTQLinear, GPTQMixIn, TritonGPTQLinear
from .quantizer import Quantizer
def replace_with_dynamic_ops(model: nn.Module,
dynamicop_map: Dict[Type[nn.Module], Type[Any]],
skipped_layers=[],
a_qconfig=None,
**kwargs):
"""Replace torch modules with dynamic-ops."""
def replace_op(model: nn.Module, name: str, module: nn.Module):
names = name.split('.')
for sub_name in names[:-1]:
model = getattr(model, sub_name)
setattr(model, names[-1], module)
for name, module in model.named_modules():
if type(module) in dynamicop_map and name not in skipped_layers:
if isinstance(module, nn.Linear):
if a_qconfig:
a_fakequant = Quantizer()
a_fakequant.configure(**a_qconfig)
kwargs.update({'a_fakequant': a_fakequant})
new_module = dynamicop_map[type(module)].convert_from(
module, **kwargs)
else:
new_module = dynamicop_map[type(module)].convert_from(module)
replace_op(model, name, new_module)
def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
load_fix_subnet(model, fix_subnet)
return model
class GPTQCompressor():
"""The compressor with GPTQ."""
def __init__(self) -> None:
self.model: nn.Module = None
def prepare(self,
model: nn.Module,
quant_conv=True,
quant_linear=True,
use_triton_ops=True,
skipped_layers=[],
a_qconfig=None,
**kwargs) -> None:
"""Prepare for compressing model."""
self.model = model
quant_modules: dict = {}
if quant_conv:
quant_modules[nn.Conv2d] = GPTQConv2d
if quant_linear:
gptq_linear = TritonGPTQLinear if use_triton_ops else GPTQLinear
quant_modules[nn.Linear] = gptq_linear
replace_with_dynamic_ops(model, quant_modules, skipped_layers,
a_qconfig, **kwargs)
@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)
# hessian
def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.register_hessian_hook()
def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.remove_hessian_hook()
def init_hessian(self, device=None):
"""Init hessian."""
for op in self.quant_ops:
op.init_hessian(device=device)
# quant
def quant(self,
blocksize=128,
percdamp=0.01,
groupsize=-1,
actorder=False,
device=torch.device('cuda:0'),
**qconfig):
"""Apply the compression algorithm to the model."""
for name, module in self.named_quant_ops:
try:
original_device = next(module.parameters()).device
module: GPTQMixIn = module.to(device)
quantizer = Quantizer()
quantizer.configure(**qconfig)
# print_log(f'quant {name}...')
error = module.quant(
quantizer=quantizer,
blocksize=blocksize,
percdamp=percdamp,
groupsize=groupsize,
actorder=actorder)
print_log(f'quant {name} success \t error = {error}')
module.to(original_device)
module.free()
except Exception as e:
print_log(f'quant {name} failed as {e}')
def quant_with_default_qconfig(self, groupsize=128, device='cpu'):
"""Apply the compression algorithm to the model with the specified
setting."""
qconfig = dict(bits=4, perchannel=True, sym=False)
self.quant(
groupsize=groupsize, actorder=True, device=device, **qconfig)
# ops
@property
def quant_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, GPTQMixIn):
yield module
@property
def named_quant_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, GPTQMixIn):
yield name, module


@ -0,0 +1,254 @@
# Copyright (c) OpenMMLab. All rights reserved.
# https://github.com/fpgaminer/GPTQ-triton
"""Mostly the same as the autotuner in Triton, but with a few changes like
using 40 runs instead of 100."""
import builtins
import math
import time
from typing import Dict
try:
import triton
except ImportError:
from mmrazor.utils import get_package_placeholder
triton = get_package_placeholder('triton >= 2.0.0')
class Autotuner(triton.KernelInterface):
"""Autotuner."""
def __init__(self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False):
'''prune_configs_by: a dict of functions that are used to prune
configs, fields:
'perf_model': performance model used to predicate running time
with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune
num_stages. It takes configs: List[Config] as its input and
returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments
to the nearest power of two when caching tuning results.'''
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache: Dict = {}
# hook to reset all required tensor to zeros before relaunching
# a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by[
'perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
"""Check for conflicts, i.e. meta-parameters both provided as kwargs
and by the autotuner."""
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current)
try:
# In testings using only 40 reps seems to be close enough and it
# appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any
# speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
except triton.compiler.OutOfResources:
return (float('inf'), float('inf'), float('inf'))
def run(self, *args, **kwargs):
"""Run."""
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to
# the nearest power of two
# In my testing this gives decent results, and greatly reduces
# the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2**int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs)
def prune_configs(self, kwargs):
"""Prune configs."""
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
"""Warm up."""
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(configs,
key,
prune_configs_by=None,
reset_to_zero=None,
nearest_power_of_two=False):
"""Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated
# anytime the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run
multiple times. This means that whatever value the kernel updates will
be updated multiple times. To avoid this undesired behavior, you can
use the `reset_to_zero` argument, which resets the value of the
provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger
the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune
configs, fields:
'perf_model': performance model used to predicate running time with
different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early pruning
(e.g., by num_stages). It takes configs: List[Config] as its input and
returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset
to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero,
prune_configs_by, nearest_power_of_two)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""The main purpose of this function is to shrink BLOCK_SIZE_* when the
corresponding dimension is smaller."""
m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
group_size_m = config.kwargs['GROUP_SIZE_M']
if (block_size_m, block_size_n, block_size_k, group_size_m,
config.num_stages, config.num_warps) in used:
continue
used.add((block_size_m, block_size_n, block_size_k, group_size_m,
config.num_stages, config.num_warps))
yield triton.Config(
{
'BLOCK_SIZE_M': block_size_m,
'BLOCK_SIZE_N': block_size_n,
'BLOCK_SIZE_K': block_size_k,
'GROUP_SIZE_M': group_size_m
},
num_stages=config.num_stages,
num_warps=config.num_warps)


@ -0,0 +1,318 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import numpy as np
import torch
import torch.distributed as dist
from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting
class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor
def forward(self, x):
"""The abstract method."""
pass
def register_forward_hook(self, hook):
"""The abstract method."""
pass
def register_backward_hook(self, hook):
"""The abstract method."""
pass
def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass
def register_buffer(self, name, tensor):
"""The abstract method."""
pass
class GPTQMixIn(ModuleProtocol):
"""The core algorithm implementation for GPTQ."""
def _gptq_mix_in_init(self):
"""Init mixin."""
self.gptq_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
self._hessian: torch.Tensor = None
self.hessian_batch = 0
# weight and input adaptive
@property
def weight_matrix(self):
"""Return weight with shape (out in)"""
return self.weight.flatten(1) # out in
@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
self.weight.data.copy_(value)
def format_input(self, input: torch.Tensor):
"""Return input with shape (B N C)"""
if len(input.shape) == 2: # N C
input = input.unsqueeze(0) # 1 N C
return input
# compute hessian
@property
def hessian(self):
"""hessian always return float."""
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.'
hessian = self._hessian.to(self.weight_matrix.device)
else:
hessian = torch.zeros(
self.columns,
self.columns,
device=self.weight_matrix.device)
dist.broadcast(hessian, 0)
return hessian
else:
return self._hessian
@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.' # noqa
self._hessian.data.copy_(
value.data.to(self._hessian.device))
else:
self._hessian = None
else:
self._hessian.data.copy_(value.data.to(self._hessian.device))
@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
assert len(input.shape) == 3
B = input.shape[0] # B N C
input = input.transpose(0, -1).flatten(1) # C D
H = input @ input.T * 2 # C C
if dist.is_initialized():
dist.all_reduce(H)
B *= dist.get_world_size()
H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B
def register_hessian_hook(self):
"""Register updating hessian hook."""
@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
assert len(input) == 1
self.update_hessian(input[0])
handle = self.register_forward_pre_hook(forward_pre_hook)
self.gptq_handles.append(handle)
def remove_hessian_hook(self):
"""Remove updating hessian hook."""
for h in self.gptq_handles:
h.remove()
def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
else:
self._hessian = None
else:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
def pack(self, scales, zeros, g_idx=None):
"""Pack and update qparams with groupsize_idx."""
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if self.bias is not None:
self.bias.half()
intweight = []
for idx in range(self.in_features):
intweight.append(
torch.round(
(self.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) /
self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.cpu().numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
dtype=np.uint32)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError('Only 2,4,8 bits are supported.')
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight).to(self.weight.device)
zeros -= 1
zeros = zeros.cpu().numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError('Only 2,4,8 bits are supported.')
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)
@torch.no_grad()
def quant(self,
quantizer,
blocksize=128,
percdamp=0.01,
groupsize=-1,
actorder=False):
"""The implementation for GPTQ."""
with torch_setting(dtype=torch.float):
assert self.hessian is not None
W: torch.Tensor = self.weight_matrix.float() # out in
if not quantizer.ready():
quantizer.find_params(W, weight=True)
H = self.hessian.float().to(W.device)
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=W.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if groupsize != -1:
if (i1 + i) % groupsize == 0:
quantizer.find_params(
W[:, (i1 + i):(i1 + i + groupsize)],
weight=True)
if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(quantizer.scale)
zero.append(quantizer.zero)
now_idx += 1
q = quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:,
i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
torch.cuda.synchronize()
error = torch.sum(Losses).item()
groupsize = groupsize if groupsize != -1 else self.columns
g_idx = [i // groupsize for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if scale == []:
scale.append(quantizer.scale)
zero.append(quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
if self.is_custom_kernel:
self.pack(scale, zero, g_idx)
del self.weight
return error
def free(self):
"""Free some cache and memory."""
self._hessian = None
torch.cuda.empty_cache()


@ -0,0 +1,566 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
DynamicLinear)
# from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting
from .gptq import GPTQMixIn
try:
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
'early_config_prune':
custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,
N, K, bits, maxq, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, stride_scales,
stride_zeros, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8
# times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk +
offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit
# word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused
# in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] *
stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs +
g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(
a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit
# values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[
None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 256,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True)
@triton.jit
def transpose_matmul_248_kernel(
a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,
maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8
# times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk +
offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit
# word from B
scales_ptrs = scales_ptr + offs_n[
None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits
) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for n in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused
# in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(
a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit
# values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[
None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except: # noqa: E722
print('triton not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
"""matmul248 function with matmul_248_kernel."""
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]),
device=input.device,
dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv( # noqa: E731
input.shape[0], META['BLOCK_SIZE_M']) * triton. # noqa: E731
cdiv( # noqa: E731
qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,
input.shape[0], qweight.shape[1],
input.shape[1], bits, maxq, input.stride(0),
input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0),
output.stride(1), scales.stride(0),
qzeros.stride(0))
return output
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
"""transpose_matmul248 function with transpose_matmul_248_kernel."""
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim),
device=input.device,
dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731
* triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731
transpose_matmul_248_kernel[grid](input, qweight, output, scales,
qzeros, g_idx, input.shape[0],
qweight.shape[1], output_dim,
bits, maxq, input.stride(0),
input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0),
output.stride(1), scales.stride(0),
qzeros.stride(0))
return output
class QuantLinearFunction(torch.autograd.Function):
"""Custom QuantLinearFunction."""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
"""Custom forward."""
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
"""Custom backward."""
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(grad_output, qweight, scales,
qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
class TritonGPTQLinear(nn.Module, GPTQMixIn):
"""Custom Linear for GPTQ with custom triton kernel."""
def __init__(self, bits, groupsize, weight, in_features, out_features,
bias):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError('Only 2,4,8 bits are supported.')
self.weight = weight
self.bias = bias
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else in_features
self.register_buffer(
'qweight',
torch.zeros((in_features // 32 * self.bits, out_features),
dtype=torch.int32))
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(
in_features / self.groupsize), out_features // 32 * self.bits),
dtype=torch.int32))
self.register_buffer(
'scales',
torch.zeros(
(math.ceil(in_features / self.groupsize), out_features),
dtype=torch.float16))
self.register_buffer(
'g_idx',
torch.tensor([i // self.groupsize for i in range(in_features)],
dtype=torch.int32))
self._gptq_mix_in_init()
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return True
@classmethod
def convert_from(cls, module: nn.Linear, bits, groupsize):
"""Convert to cls from torch's module."""
new_module = cls(
bits,
groupsize,
weight=module.weight,
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias)
return new_module
def forward(self, x):
"""Custom forward."""
if torch.all(self.qweight == 0):
out = F.linear(x, self.weight, self.bias)
else:
# import pdb;pdb.set_trace()
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
out = out.reshape(out_shape)
# import pdb;pdb.set_trace()
return out
class GPTQLinear(DynamicLinear, GPTQMixIn):
"""Custom Linear for GPTQ without custom triton kernel."""
def __init__(self, a_fakequant=None, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._gptq_mix_in_init()
self.a_fakequant = a_fakequant
self.fix_qparams = False
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return False
@classmethod
def convert_from(cls,
module: nn.Linear,
a_fakequant=None) -> 'DynamicLinear':
"""Convert to cls from torch's module."""
new_module = cls(
a_fakequant=a_fakequant,
in_features=module.in_features,
out_features=module.out_features,
bias=True if module.bias is not None else False)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
def forward(self, input: Tensor) -> Tensor:
"""Custom forward."""
if self.a_fakequant:
dtype = self.weight.dtype
if not self.fix_qparams:
self.a_fakequant.find_params(input)
input = self.a_fakequant.quantize(input).to(dtype)
return super().forward(input)
class GPTQConv2d(DynamicConv2d, GPTQMixIn):
"""Custom Conv2d for GPTQ without custom triton kernel."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._gptq_mix_in_init()
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return False
@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
stride=self.stride) # B C D
return input.transpose(-1, -2)


@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class Quantizer(nn.Module):
"""Quantizer for some basic quantization functions."""
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=.8,
trits=False):
"""Configure qconfig."""
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
self.scale = torch.zeros_like(self.scale)
def _quantize(self, x, scale, zero, maxq):
"""Fakequant."""
if maxq < 0:
return (x > scale / 2).float() * scale + (x <
zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
def find_params(self, x, weight=False):
"""Observe the specified data and calculate the qparams."""
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 /
scale1) if not self.sym else self.zero
q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1),
self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
"""Fakequant."""
if self.ready():
return self._quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
"""Whether is enabled."""
return self.maxq > 0
def ready(self):
"""Whether is ready."""
return torch.all(self.scale != 0)


@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py # noqa: E501
def torch_snr_error(y_pred: torch.Tensor,
y_real: torch.Tensor,
reduction: str = 'mean') -> torch.Tensor:
"""Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calculted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of
SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
y_pred = y_pred.type(torch.float32)
y_real = y_real.type(torch.float32)
if y_pred.shape != y_real.shape:
raise ValueError(
f'Can not compute snr loss for tensors with different shape. '
f'({y_pred.shape} and {y_real.shape})')
reduction = str(reduction).lower()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)
y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == 'mean':
return torch.mean(snr)
elif reduction == 'sum':
return torch.sum(snr)
elif reduction == 'none':
return snr
else:
raise ValueError('Unsupported reduction method.')


@ -2,33 +2,41 @@
<img src="../../resources/mmrazor-logo.png" width="600"/>
</div>
# MMRazor Examples for Large Models
# MMRazor for Large Models
## Introduction
MMRazor is dedicated to the development of general-purpose model compression tools. Now, MMRazor not only supports conventional CV model compression but also extends to support large models. This project will provide examples of MMRazor's compression for various large models, including LLama, stable diffusion, and more.
MMRazor is dedicated to the development of general-purpose model compression tools. Now, MMRazor not only supports conventional CV model compression but also extends to support large models. This project will provide examples of MMRazor's compression for various large models, including LLaMA, stable diffusion, and more.
## Installation
Code structure overview for large models.
```shell
pip install openmim
mim install mmcv
mim install mmengine
pip install git+https://github.com/open-mmlab/mmrazor.git
git clone github.com/open-mmlab/mmrazor-examples.git
```
mmrazor
├── implementations # core algorithm components
├── pruning
└── quantization
projects
└── mmrazor_large
├── algorithms # algorithms usage introduction
└── examples # examples for various models about algorithms
├── language_models
│ ├── LLaMA
│ └── OPT
└── ResNet
```
## Model-Algorithm Example Matrix
| | ResNet | OPT | LLama | Stable diffusion |
| ------------------------------------ | ---------------------------------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------- | ---------------- |
| [SparseGPT](algorithms/SparseGPT.md) | [:white_check_mark:](examples/ResNet/sparse_gpt/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/Llama/README.md) | |
| | ResNet | OPT | LLama | Stable diffusion |
| ------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------- | ---------------- |
| [SparseGPT](algorithms/SparseGPT.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
| [GPTQ](algorithms/GPTQ.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
## PaperList
We provide a paperlist for researchers in the field of model compression for large models. If you want to add your paper to this list, please submit a PR.
| Paper | Title | Type | MMRazor |
| --------- | --------------------------------------------------------------------------------------------------------------------- | ------- | --------------------------------------------- |
| SparseGPT | [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) | Pruning | [:white_check_mark:](algorithms/SparseGPT.md) |
| GPTQ | [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) | Quant | |
| Paper | Title | Type | MMRazor |
| --------- | --------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------- |
| SparseGPT | [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) | Pruning | [:white_check_mark:](algorithms/SparseGPT.md) |
| GPTQ | [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) | Quantization | [:white_check_mark:](algorithms/GPTQ.md) |


@ -0,0 +1,56 @@
# GPTQ
> [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323)
<!-- [ALGORITHM] -->
## Abstract
Generative Pre-trained Transformer models, known as GPT or OPT, set themselves apart through breakthrough performance across complex language modelling tasks, but also by their extremely high computational and storage costs. Specifically, due to their massive size, even inference for large, highly-accurate GPT models may require multiple performant GPUs, which limits the usability of such models. While there is emerging work on relieving this pressure via model compression, the applicability and performance of existing compression techniques is limited by the scale and complexity of GPT models. In this paper, we address this challenge, and propose GPTQ, a new one-shot weight quantization method based on approximate second-order information, that is both highly-accurate and highly-efficient. Specifically, GPTQ can quantize GPT models with 175 billion parameters in approximately four GPU hours, reducing the bitwidth down to 3 or 4 bits per weight, with negligible accuracy degradation relative to the uncompressed baseline. Our method more than doubles the compression gains relative to previously-proposed one-shot quantization methods, preserving accuracy, allowing us for the first time to execute an 175 billion-parameter model inside a single GPU for generative inference. Moreover, we also show that our method can still provide reasonable accuracy in the extreme quantization regime, in which weights are quantized to 2-bit or even ternary quantization levels. We show experimentally that these improvements can be leveraged for end-to-end inference speedups over FP16, of around 3.25x when using high-end GPUs (NVIDIA A100) and 4.5x when using more cost-effective ones (NVIDIA A6000). The implementation is available at https://github.com/IST-DASLab/gptq.
## Usage
GPTQ is easy to use in mmrazor. You can use it like this:
```python
from mmrazor.implementations.quantization import gptq
# initial model, dataloaders
model
train_loader, test_loader
## init gptq compressor and prepare for quantization
compressor = gptq.GPTQCompressor()
compressor.prepare(model)
## get hessian matrix
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
compressor.remove_hessian_hooks()
## quant
compressor.quant_with_default_qconfig()
## to a normal torch model
model = compressor.to_static_model(model)
```
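`prepare` also exposes a few options (`quant_conv`, `quant_linear`, `use_triton_ops`, `skipped_layers`, `a_qconfig`). Below is a minimal sketch of enabling activation fake-quantization; the `a_qconfig` values are only illustrative and are forwarded to `Quantizer.configure`:

```python
# a minimal sketch: prepare with activation fake-quantization
# use_triton_ops=False selects GPTQLinear, which supports a_fakequant;
# TritonGPTQLinear uses the packed triton kernel instead
compressor = gptq.GPTQCompressor()
compressor.prepare(
    model,
    quant_conv=True,
    quant_linear=True,
    use_triton_ops=False,
    a_qconfig=dict(bits=8, perchannel=False, sym=True))
```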
## Full Examples
- [ResNet](../examples/ResNet/README.md)
- [LLaMA](../examples/language_models/LLaMA/README.md)
## Cite
```latex
@misc{
Frantar_Ashkboos_Hoefler_Alistarh_2022,
title={GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers},
author={Frantar, Elias and Ashkboos, Saleh and Hoefler, Torsten and Alistarh, Dan},
year={2022},
month={Oct},
language={en-US}
}
```


@ -1,6 +1,10 @@
# SparseGPT
## abstract
> [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774)
<!-- [ALGORITHM] -->
## Abstract
We show for the first time that large-scale generative pretrained transformer (GPT) family models can be pruned to at least 50% sparsity in one-shot, without any retraining, at minimal loss of accuracy. This is achieved via a new pruning method called SparseGPT, specifically designed to work efficiently and accurately on massive GPT-family models. We can execute SparseGPT on the largest available open-source models, OPT-175B and BLOOM-176B, in under 4.5 hours, and can reach 60% unstructured sparsity with negligible increase in perplexity: remarkably, more than 100 billion weights from these models can be ignored at inference time. SparseGPT generalizes to semi-structured (2:4 and 4:8) patterns, and is compatible with weight quantization approaches.
@ -15,28 +19,29 @@ from mmrazor.implementations.pruning import sparse_gpt
model
train_loader, test_loader
## init sparse gpt mutator and prepare for pruning
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)
## init sparse gpt compressor and prepare for pruning
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model)
## init hessian matrix
mutator.start_init_hessian()
## get hessian matrix
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
mutator.end_init_hessian()
compressor.remove_hessian_hooks()
## prune
mutator.prune_24()
compressor.prune_24()
## to a normal torch model
model = mutator.to_static_model(model)
model = compressor.to_static_model(model)
```
## Full Examples
- [ResNet](../examples/ResNet/sparse_gpt/README.md)
- [ResNet](../examples/ResNet/README.md)
- [OPT](../examples/language_models/OPT/README.md)
- [Llama](../examples/language_models/Llama/README.md)
- [LLaMA](../examples/language_models/LLaMA/README.md)
## Cite

View File

@ -0,0 +1,25 @@
# Examples for ResNet
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md)
### Usage
```shell
python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batch_size 128 --num_samples 512
```
**Note**: the ImageNet folder must follow the torchvision `ImageFolder` layout (`train/` and `val/` subfolders, one directory per class).
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md)
### Usage
```shell
python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batch_size 128 --num_samples 512
```
**Note**: the ImageNet folder must follow the torchvision `ImageFolder` layout (`train/` and `val/` subfolders, one directory per class), as sketched below.
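A minimal sketch of the expected layout and of how the example scripts load it (the path is a placeholder):

```python
import os.path as osp

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# expected layout:
#   {imagenet_path}/train/<class_name>/*.JPEG
#   {imagenet_path}/val/<class_name>/*.JPEG
imagenet_path = 'data/imagenet_torch'  # placeholder path
train_set = datasets.ImageFolder(
    osp.join(imagenet_path, 'train'), transform=transforms.ToTensor())
val_set = datasets.ImageFolder(
    osp.join(imagenet_path, 'val'), transform=transforms.ToTensor())
```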

View File

@ -0,0 +1,187 @@
# Copyright (c) OpenMMLab. All rights reserved.
# model settings
import os.path as osp
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from mmrazor.implementations.quantization.gptq import (GPTQCompressor,
GPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def get_dataloaders(batch_size, n_workers, path=''):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
osp.join(path, 'train'),
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
test_dataset = datasets.ImageFolder(
osp.join(path, 'val'),
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]),
)
dataloader_train = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=True,
)
dataloader_test = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_workers,
pin_memory=True,
)
return dataloader_train, dataloader_test
@torch.no_grad()
def eval(model: nn.Module,
dataloader_test: DataLoader,
device=torch.device('cuda:0'),
is_half=True):
total = 0
correct = 0
model.eval()
with torch.no_grad():
for x, y in dataloader_test:
x: torch.Tensor # type: ignore
y: torch.Tensor # type: ignore
x = x.to(device)
y = y.to(device)
if is_half:
x = x.half()
y = y.half()
outputs = model(x)
_, predicted = outputs.max(1)
correct += (y == predicted).long().sum()
total += y.numel()
acc = correct / total
return acc
@torch.no_grad()
def infer(model: nn.Module,
dataloader: torch.utils.data.DataLoader,
num_samples=256,
device=torch.device('cuda:0'),
is_half=True):
model.eval()
with torch.no_grad():
accumulate_batch = 0
for x, _ in dataloader:
x = x.to(device)
if is_half:
x = x.half()
model(x)
B = x.shape[0]
accumulate_batch += B
if accumulate_batch > num_samples:
break
if __name__ == '__main__':
import argparse
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
'--data',
type=str,
default='data/imagenet_torch',
help='path to imagenet in torch folder format')
arg_parser.add_argument(
'--num_samples',
type=int,
default=512,
help='number of samples to estimate hessian matrix')
arg_parser.add_argument(
'--batch_size',
type=int,
default=128,
help='batch size for evaluation and inference')
arg_parser.add_argument(
'--fp16',
type=bool,
default=False,
help='whether to use fp16 for evaluation and inference')
args = arg_parser.parse_args()
data_path = args.data
num_samples = args.num_samples
batch_size = args.batch_size
model = torchvision.models.resnet18(pretrained=True)
if args.fp16:
model = model.half()
train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
compressor = GPTQCompressor()
# # use_triton_ops is True
# compressor.prepare(model,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# skipped_layers=['conv1'],
# bits=4,
# groupsize=128)
# # quantize activation for linear
# a_qconfig = dict(bits=4, perchannel=True, sym=False)
compressor.prepare(
model,
quant_conv=True,
quant_linear=True,
use_triton_ops=False,
skipped_layers=['conv1'],
# a_qconfig=a_qconfig
)
model.cuda()
enable_observer_linear(model)
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples, is_half=args.fp16)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig()
print('start evaluation')
disable_observer_linear(model)
model = model.cuda()
acc = eval(model, test_loader, is_half=args.fp16)
print('accuracy:', acc.item())

View File

@ -119,17 +119,17 @@ if __name__ == '__main__':
model = torchvision.models.resnet18(pretrained=True)
train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model)
model.cuda()
mutator.init_hessian()
mutator.start_init_hessian()
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
mutator.end_init_hessian()
mutator.prune_24()
model = mutator.to_static_model(model)
compressor.remove_hessian_hooks()
compressor.prune_24()
model = compressor.to_static_model(model)
print('start evaluation')
model = model.cuda()

View File

@ -1,11 +0,0 @@
# SparseGPT for ResNet
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
## Usage
```shell
python examples/model_examples/ResNet/sparse_gpt/resnet18_sparse_gpt.py --data {imagenet_path} --batchsize 128 --num_samples 512
```
**Note**: this imagenet folder follows torch format.

View File

@ -0,0 +1,55 @@
# Examples for LLaMA
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
### Usage
```shell
# example for decapoda-research/llama-7b-hf
python projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py decapoda-research/llama-7b-hf c4
# help
usage: llama_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model Llama model to load
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
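Internally, the script applies the SparseGPT compressor to the decoder layers (`model.model.layers`); a condensed sketch of that flow, where `model` is the loaded `LlamaForCausalLM` and `run_calibration` is a forward-only loop over calibration data (both placeholders):

```python
from mmrazor.implementations.pruning import sparse_gpt

compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)

# estimate hessians from calibration forward passes
compressor.init_hessian()
compressor.register_hessian_hooks()
run_calibration(model)
compressor.remove_hessian_hooks()

# 2:4 semi-structured pruning, then convert back to a plain torch model
compressor.prune_24()
model = compressor.to_static_model(model)
```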
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md)
### Usage
```shell
# example for decapoda-research/llama-7b-hf
python projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py decapoda-research/llama-7b-hf c4
# help
usage: llama_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model Llama model to load
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
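The `-m` flag wraps each decoder layer with `memory_efficient_forward`, which the scripts use to keep peak GPU memory manageable during calibration and evaluation. A condensed sketch of how it is used (`model` and `run_forward` are placeholders):

```python
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from mmrazor.implementations.pruning.sparse_gpt.utils import \
    memory_efficient_forward

# run any forward-only loop with the decoder layers wrapped
with memory_efficient_forward(
        model, wrap_modules=[LlamaDecoderLayer], enabled=True,
        device='cuda:0'):
    run_forward(model)  # e.g. opt_infer(...) or opt_eval(...) as in the scripts
```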

View File

@ -0,0 +1,162 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datautils import get_loaders
from transformers.models.llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
model,
torch_dtype='auto',
)
model.seqlen = 2048
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
        'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
# # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.layers,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model,
wrap_modules=[LlamaDecoderLayer],
enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
torch.save(model.state_dict(), args.save)

View File

@ -51,7 +51,7 @@ if __name__ == '__main__':
parser.add_argument(
'--batch_size',
type=int,
default=64,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
@ -71,26 +71,27 @@ if __name__ == '__main__':
print_log('load model over')
dataloader, testloader = get_loaders(
'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.pruning import sparse_gpt
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.layers)
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)
compressor.init_hessian()
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
mutator.start_init_hessian()
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
mutator.end_init_hessian()
mutator.prune_24()
compressor.remove_hessian_hooks()
compressor.prune_24()
model = mutator.to_static_model(model)
model = compressor.to_static_model(model)
if args.save:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)

View File

@ -77,13 +77,13 @@ def main(rank, world_size=8, args=None):
def build():
model = get_model(model_name)
# init mutator
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.layers)
return model, mutator
# init compressor
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)
return model, compressor
with init_on_meta(enable=True):
model, mutator = build()
model, compressor = build()
if rank == 0:
model_copy, _ = build() # init on cpu
@ -106,8 +106,8 @@ def main(rank, world_size=8, args=None):
# init hessian
mutator.init_hessian(device='cuda')
mutator.start_init_hessian()
compressor.init_hessian(device='cuda')
compressor.register_hessian_hooks()
_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
@ -115,7 +115,7 @@ def main(rank, world_size=8, args=None):
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)
mutator.end_init_hessian()
compressor.remove_hessian_hooks()
# prune
name2module = dict(model.named_modules())

View File

@ -71,6 +71,7 @@ def opt_infer(
testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N
use_cache = model.config.use_cache
model.config.use_cache = False
for i, batch in enumerate(torch.split(testenc, batch_size)):
@ -80,6 +81,7 @@ def opt_infer(
if (i + 1) * batch_size >= num_samples:
break
model.config.use_cache = use_cache
class init_on_meta:

View File

@ -1,29 +0,0 @@
# Llama
## SparseGPT for Llama
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
### Usage
```shell
python examples/model_examples/language_models/Llama/llama_sparse_gpt.py -h
usage: llama_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M]
model {wikitext2,ptb,c4}
positional arguments:
model Llama model to load
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
# For example, prune decapoda-research/llama-7b-hf
python examples/model_examples/language_models/Llama/llama_sparse_gpt.py decapoda-research/llama-7b-hf c4
```

View File

@ -1,15 +1,44 @@
# OPT
# Examples for OPT
## SparseGPT for OPT
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
### Usage
```shell
python examples/model_examples/language_models/OPT/opt_sparse_gpt.py -h
usage: opt_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M]
model {wikitext2,ptb,c4}
# example for facebook/opt-125m
python projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py facebook/opt-125m c4
# help
usage: opt_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model OPT model to load; pass `facebook/opt-X`.
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md)
### Usage
```shell
# example for facebook/opt-125m
python projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py facebook/opt-125m c4
# help
usage: opt_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model OPT model to load; pass `facebook/opt-X`.
@ -23,7 +52,4 @@ optional arguments:
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
# For example, prune facebook/opt-125m
python examples/model_examples/language_models/OPT/opt_sparse_gpt.py facebook/opt-125m c4
```
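With GPTQ, `--save` stores the quantized `state_dict` and `--quant_ckpt` reloads it. As the script below shows, the model must be prepared with the same quantization config first, and the redundant fp weights kept by `TritonGPTQLinear` dropped before loading. A rough sketch, where `model` is an already-loaded `OPTForCausalLM` and the checkpoint path is a placeholder:

```python
import torch

from mmrazor.implementations.quantization import gptq
from mmrazor.implementations.quantization.gptq import TritonGPTQLinear

compressor = gptq.GPTQCompressor()
compressor.prepare(
    model.model.decoder,
    quant_conv=True,
    quant_linear=True,
    use_triton_ops=True,
    bits=4,
    groupsize=128)

# drop the redundant fp weight kept by TritonGPTQLinear before loading
for module in model.modules():
    if isinstance(module, TritonGPTQLinear):
        del module.weight

model.load_state_dict(torch.load('opt125m_gptq.pth'))  # placeholder path
```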

View File

@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Example for opt is converted from https://github.com/ist-daslab/sparsegpt
import torch
from datautils import get_loaders
from transformers import OPTForCausalLM
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
model.seqlen = model.config.max_position_embeddings
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
    parser.add_argument('model', type=str, help='OPT model to load')
parser.add_argument(
'--dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(
        model.model.decoder,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
# # # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.decoder,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
torch.save(model.state_dict(), args.save)

View File

@ -69,27 +69,28 @@ if __name__ == '__main__':
print_log('load model over')
dataloader, testloader = get_loaders(
'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.pruning import sparse_gpt
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.decoder)
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.decoder)
compressor.init_hessian()
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
mutator.start_init_hessian()
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
mutator.end_init_hessian()
mutator.prune_24()
compressor.remove_hessian_hooks()
compressor.prune_24()
model = mutator.to_static_model(model)
model = compressor.to_static_model(model)
if args.save:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)

View File

@ -78,8 +78,8 @@ def main(rank, world_size=8, args=None):
model = get_model(model_name)
# init mutator
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.decoder)
mutator = sparse_gpt.SparseGptCompressor()
mutator.prepare(model.model.decoder)
return model, mutator
with init_on_meta(enable=True):
@ -107,7 +107,7 @@ def main(rank, world_size=8, args=None):
# init hessian
mutator.init_hessian(device='cuda')
mutator.start_init_hessian()
    mutator.register_hessian_hooks()
_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
@ -115,7 +115,7 @@ def main(rank, world_size=8, args=None):
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)
mutator.end_init_hessian()
mutator.remove_hessian_hooks()
# prune
name2module = dict(model.named_modules())

View File

@ -7,5 +7,6 @@ nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx
pytest
triton==2.0.0
xdoctest >= 0.10.0
yapf

View File

@ -4,6 +4,7 @@ import unittest
import torch
import torch.nn as nn
from mmrazor import digit_version
from mmrazor.implementations.pruning import sparse_gpt
@ -11,6 +12,8 @@ class TestSparseGptOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('torch<1.12.0')
def get_loss(linear, linear1, data):
y = linear(data)
@ -21,7 +24,7 @@ class TestSparseGptOps(unittest.TestCase):
for x in dataset:
model(x)
for device in ['cpu', 'cuda']:
for device in ['cpu']:
device = torch.device(device)
# prepare
@ -31,7 +34,7 @@ class TestSparseGptOps(unittest.TestCase):
12, 20, bias=False).to(device)
sparse_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([100, 5, 12]).to(
random_data = torch.rand([10, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
@ -39,11 +42,12 @@ class TestSparseGptOps(unittest.TestCase):
# prune
sparse_linear.start_init_hessian()
sparse_linear.init_hessian()
sparse_linear.register_hessian_hook()
infer(sparse_linear, random_data)
sparse_linear.end_init_hessian()
sparse_linear.remove_hessian_hook()
sparse_linear.prune()
sparse_linear.prune(0.5)
# compare
@ -52,16 +56,19 @@ class TestSparseGptOps(unittest.TestCase):
@torch.no_grad()
def test_model(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('torch<1.12.0')
import torchvision
model = torchvision.models.resnet18()
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)
mutator = sparse_gpt.SparseGptCompressor()
mutator.prepare(model)
x = torch.rand(10, 3, 224, 224)
mutator.start_init_hessian()
mutator.init_hessian()
mutator.register_hessian_hooks()
model(x)
mutator.end_init_hessian()
mutator.remove_hessian_hooks()
mutator.prune_24()
model = mutator.to_static_model(model)

View File

@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
import torch.nn as nn
from mmrazor import digit_version
from mmrazor.implementations.quantization import gptq
class TestGPTQOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
            self.skipTest('torch<1.12.0 or cuda is unavailable')
def get_loss(linear, linear1, data):
y = linear(data)
y1 = linear1(data)
return (y - y1).square().sum()
def infer(model, dataset):
for x in dataset:
model(x)
for device in ['cpu']:
device = torch.device(device)
# prepare
linear = nn.Linear(12, 20, bias=False).to(device)
gptq_linear = gptq.GPTQLinear(
in_features=12, out_features=20, bias=False).to(device)
gptq_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([10, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)
# quant
gptq_linear.init_hessian()
gptq_linear.register_hessian_hook()
infer(gptq_linear, random_data)
gptq_linear.remove_hessian_hook()
qconfig = dict(bits=4, perchannel=True, sym=False)
quantizer = gptq.Quantizer()
quantizer.configure(**qconfig)
gptq_linear.quant(quantizer=quantizer)
# compare
print('norm:', linear(data_0).norm(2))
print('distance:', get_loss(linear, gptq_linear, data_0))
@torch.no_grad()
def test_model(self):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
            self.skipTest('torch<1.12.0 or cuda is unavailable')
import torchvision
model = torchvision.models.resnet18()
compressor = gptq.GPTQCompressor()
compressor.prepare(model, use_triton_ops=False)
x = torch.rand(10, 3, 224, 224)
compressor.init_hessian()
compressor.register_hessian_hooks()
model(x)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig()
model = compressor.to_static_model(model)
assert type(model.conv1) is nn.Conv2d