[Feature] Merge dev-large into main (#543)

* add sparse gpt (#499)

init

Co-authored-by: liukai <your_email@abc.example>

* enhance sparsegpt (#505)

* update

* fix bug

* fix bug

* update opt

* add memory efficient forward for opt

* support to set device for pruning

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>

* Lk large (#510)

* update

* update

---------

Co-authored-by: liukai <your_email@abc.example>

* refine sparse gpt, support multiple gpus with fsdp (#520)

* add mmrazor large

* update readme

* add fsdp for opt

* update

* update

* rename

* update args

* support fsdp

* refine

* refine

* refine

* refine

* fix out of memory bug

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>

* refine sparse gpt (#526)

* save cpu memory

* update

* update

* update

* update

* refine

* update

* update

---------

Co-authored-by: Your Name <you@example.com>

* merge main (#527)

* fix bug for autoslim (#511)

* fix bug for autoslim

* delete resnet50 for dmcp

---------

Co-authored-by: liukai <your_email@abc.example>

* Add timm (#512)

* add timm to optional.txt

* fix deit paths

* [Feature] Add MMRazor quantization (#513)

* [FEATURE] add quant algo `Learned Step Size Quantization` (#346)

* update

* Fix a bug in make_divisible. (#333)

fix bug in make_divisible

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Fix counter mapping bug (#331)

* fix counter mapping bug

* move judgment into get_counter_type & update UT

* [Docs]Add MMYOLO projects link (#334)

* [Doc] fix typos in en/usr_guides (#299)

* Update README.md

* Update README_zh-CN.md

Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>

* [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)

* support MethodInputsRecorder and FunctionInputsRecorder

* fix bugs that the model can not be pickled

* WIP: add pytest for ema model

* fix bugs in recorder and delivery when ema_hook is used

* don't register the DummyDataset

* fix pytest

* updated

* retina loss & predict & tensor DONE

* [Feature] Add deit-base (#332)

* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme

* [Feature]Feature map visualization (#293)

* WIP: vis

* WIP: add visualization

* WIP: add visualization hook

* WIP: support razor visualizer

* WIP

* WIP: wrap draw_featmap

* support feature map visualization

* add a demo image for visualization

* fix typos

* change eps to 1e-6

* add pytest for visualization

* fix vis hook

* fix arguments' name

* fix img path

* support draw inference results

* add visualization doc

* fix figure url

* move files

Co-authored-by: weihan cao <HIT-cwh>

* [Feature] Add kd examples (#305)

* support kd for mbv2 and shufflenetv2

* WIP: fix ckpt path

* WIP: fix kd r34-r18

* add metafile

* fix metafile

* delete

* [Doc] add documents about pruning. (#313)

* init

* update user guide

* update images

* update

* update How to prune your model

* update how_to_use_config_tool_of_pruning.md

* update doc

* move location

* update

* update

* update

* add mutablechannels.md

* add references

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304)

* add pkd

* add pytest for pkd

* fix cfg

* WIP: support fcos3d

* WIP: support fcos3d pkd

* support mmdet3d

* fix cfgs

* change eps to 1e-6 and add some comments

* fix docstring

* fix cfg

* add assert

* add type hint

* WIP: add readme and metafile

* fix readme

* update metafiles and readme

* fix metafile

* fix pipeline figure

* for RFC

* Custom FX initialize

* add UT init

* [Refactor] Refactor Mutables and Mutators (#324)

* refactor mutables

* update load fix subnet

* add DumpChosen Typehint

* adapt UTs

* fix lint

* Add GroupMixin to ChannelMutator (temporarily)

* fix type hints

* add GroupMixin doc-string

* modified by comments

* fix type hits

* update subnet format

* fix channel group bugs and add UTs

* fix doc string

* fix comments

* refactor diff module forward

* fix error in channel mutator doc

* fix comments

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Update readme (#341)

* update kl readme

* update dsnas readme

* fix url

* Bump version to 1.0.0rc1 (#338)

update version

* init demo

* add customer_tracer

* add quantizer

* add fake_quant, loop, config

* remove CPatcher in custom_tracer

* demo_try

* init version

* modified base.py

* pre-rebase

* wip of adaround series

* adaround experiment

* transfer to s2

* update api

* point at sub_reconstruction

* pre-checkout

* export onnx

* add customtracer

* fix lint

* move custom tracer

* fix import

* TODO: UTs

* Successfully RUN

* update loop

* update loop docstrings

* update quantizer docstrings

* update qscheme docstrings

* update qobserver docstrings

* update tracer docstrings

* update UTs init

* update UTs init

* fix review comments

* fix CI

* fix UTs

* update torch requirements

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: humu789 <humu@pjlab.org.cn>

* [Features]Quantize pipeline (#350)

* init demo

* add customer_tracer

* add quantizer

* add fake_quant, loop, config

* remove CPatcher in custom_tracer

* demo_try

* init version

* modified base.py

* pre-rebase

* wip of adaround series

* adaround experiment

* transfer to s2

* update api

* point at sub_reconstruction

* pre-checkout

* export onnx

* add customtracer

* fix lint

* move custom tracer

* fix import

* update

* updated

* retina loss & predict & tensor DONE

* for RFC

* Custom FX initialize

* add UT init

* TODO: UTs

* Successfully RUN

* update loop

* update loop docstrings

* update quantizer docstrings

* update qscheme docstrings

* update qobserver docstrings

* update tracer docstrings

* update UTs init

* update UTs init

* fix bugs

* fix lsq

* refactor quantize pipeline

* fix quant

* WIP: debug qat

* fix lsq bugs

* fix qat, docstring in progress

* TODO: UTs

* fix bugs

* fix lsq

* refactor quantize pipeline

* fix quant

* WIP: debug qat

* fix lsq bugs

* fix qat, docstring in progress

* fixed DefaultQconfigs name

* fix bugs

* add comments and fix typos

* delete useless codes

* fix bugs and add comments

* rename prepare_module_dict

* update lsq config

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>

* [Feature] Add `prepare_for_mmdeploy` interface  (#365)

* remove useless code

* fix build graph module import bug

* refactor general quant

* rename GeneralQuant to MMArchitectureQuant

* fix some dtype bugs

* add prepare_for_mmdeploy interface

* update prepare for mmdeploy args

* fix some comments

Co-authored-by: humu789 <humu@pjlab.org.cn>

* CodeCamp #132 add MinMaxFloorObserver (#376)

* add minmaxfloor_observer.py

* add MinMaxFloorObserver and normative docstring

* add test for MinMaxFloorObserver

* Quant go (#409)

* add torch observer

* add torch fakequant

* refactor base quantizer

* add QConfigHander and QSchemeHander & finish quantizer_refactor_beta

* passed ptq_pipeline

* tmp-commit

* fix loop and algorithm

* delete fakequant

* refactor code structure

* remove lsq

* valid ptq pipeline

* wip

* fix del functions

* fix

* fix lint and pytest

Co-authored-by: HIT-cwh <2892770585@qq.com>

* [Refactor & Doc] Refactor graph_utils and add docstring and pytest (#420)

* refactor graph_utils and add docstring and pytest

* fix del fakequant

* delete useless codes

* Merge dev-1.x into quantize (#430)

* Fix a bug in make_divisible. (#333)

fix bug in make_divisible

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Fix counter mapping bug (#331)

* fix counter mapping bug

* move judgment into get_counter_type & update UT

* [Docs]Add MMYOLO projects link (#334)

* [Doc] fix typos in en/usr_guides (#299)

* Update README.md

* Update README_zh-CN.md

Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>

* [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)

* support MethodInputsRecorder and FunctionInputsRecorder

* fix bugs that the model can not be pickled

* WIP: add pytest for ema model

* fix bugs in recorder and delivery when ema_hook is used

* don't register the DummyDataset

* fix pytest

* [Feature] Add deit-base (#332)

* WIP: support deit

* WIP: add deithead

* WIP: fix checkpoint hook

* fix data preprocessor

* fix cfg

* WIP: add readme

* reset single_teacher_distill

* add metafile

* add model to model-index

* fix configs and readme

* [Feature]Feature map visualization (#293)

* WIP: vis

* WIP: add visualization

* WIP: add visualization hook

* WIP: support razor visualizer

* WIP

* WIP: wrap draw_featmap

* support feature map visualization

* add a demo image for visualization

* fix typos

* change eps to 1e-6

* add pytest for visualization

* fix vis hook

* fix arguments' name

* fix img path

* support draw inference results

* add visualization doc

* fix figure url

* move files

Co-authored-by: weihan cao <HIT-cwh>

* [Feature] Add kd examples (#305)

* support kd for mbv2 and shufflenetv2

* WIP: fix ckpt path

* WIP: fix kd r34-r18

* add metafile

* fix metafile

* delete

* [Doc] add documents about pruning. (#313)

* init

* update user guide

* update images

* update

* update How to prune your model

* update how_to_use_config_tool_of_pruning.md

* update doc

* move location

* update

* update

* update

* add mutablechannels.md

* add references

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304)

* add pkd

* add pytest for pkd

* fix cfg

* WIP: support fcos3d

* WIP: support fcos3d pkd

* support mmdet3d

* fix cfgs

* change eps to 1e-6 and add some comments

* fix docstring

* fix cfg

* add assert

* add type hint

* WIP: add readme and metafile

* fix readme

* update metafiles and readme

* fix metafile

* fix pipeline figure

* [Refactor] Refactor Mutables and Mutators (#324)

* refactor mutables

* update load fix subnet

* add DumpChosen Typehint

* adapt UTs

* fix lint

* Add GroupMixin to ChannelMutator (temporarily)

* fix type hints

* add GroupMixin doc-string

* modified by comments

* fix type hits

* update subnet format

* fix channel group bugs and add UTs

* fix doc string

* fix comments

* refactor diff module forward

* fix error in channel mutator doc

* fix comments

Co-authored-by: liukai <liukai@pjlab.org.cn>

* [Fix] Update readme (#341)

* update kl readme

* update dsnas readme

* fix url

* Bump version to 1.0.0rc1 (#338)

update version

* [Feature] Add Autoformer algorithm (#315)

* update candidates

* update subnet_sampler_loop

* update candidate

* add readme

* rename variable

* rename variable

* clean

* update

* add doc string

* Revert "[Improvement] Support for candidate multiple dimensional search constraints."

* [Improvement] Update Candidate with multi-dim search constraints. (#322)

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constrain

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* [Feature] Autoformer architecture and dynamicOPs (#327)

* add DynamicSequential

* dynamiclayernorm

* add dynamic_pathchembed

* add DynamicMultiheadAttention and DynamicRelativePosition2D

* add channel-level dynamicOP

* add autoformer algo

* clean notes

* adapt channel_mutator

* vit fly

* fix import

* mutable init

* remove annotation

* add DynamicInputResizer

* add unittest for mutables

* add OneShotMutableChannelUnit_VIT

* clean code

* reset unit for vit

* remove attr

* add autoformer backbone UT

* add valuemutator UT

* clean code

* add autoformer algo UT

* update classifier UT

* fix test error

* ignore

* make lint

* update

* fix lint

* mutable_attrs

* fix test

* fix error

* remove DynamicInputResizer

* fix test ci

* remove InputResizer

* rename variables

* modify type

* Continued improvements of ChannelUnit

* fix lint

* fix lint

* remove OneShotMutableChannelUnit

* adjust derived type

* combination mixins

* clean code

* fix sample subnet

* search loop fly

* more annotations

* avoid counter warning and modify batch_augment cfg by gy

* restore

* source_value_mutables restriction

* simplify arch_setting api

* update

* clean

* fix ut

* [Feature] Add performance predictor (#306)

* add predictor with 4 handlers

* [Improvement] Update Candidate with multi-dim search constraints. (#322)

* update doc

* add support type

* clean code

* update candidates

* clean

* xx

* set_resource -> set_score

* fix ci bug

* py36 lint

* fix bug

* fix check constrain

* py36 ci

* redesign candidate

* fix pre-commit

* update cfg

* add build_resource_estimator

* fix ci bug

* remove runner.epoch in testcase

* update metric_predictor:
1. update MetricPredictor;
2. add predictor config for searching;
3. add predictor in evolution_search_loop.

* add UT for predictor

* add MLPHandler

* patch optional.txt for predictors

* patch test_evolution_search_loop

* refactor apis of predictor and handlers

* fix ut and remove predictor_cfg in predictor

* adapt new mutable & mutator design

* fix ut

* remove unnecessary assert after rebase

* move predictor-build in __init__ & simplify estimator-build

Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>

* [Feature] Add DCFF (#295)

* add ChannelGroup (#250)

* rebase new dev-1.x

* modification for adding config_template

* add docstring to channel_group.py

* add docstring to mutable_channel_group.py

* rm channel_group_cfg from Graph2ChannelGroups

* change choice type of SequentialChannelGroup from float to int

* add a warning about group-wise conv

* restore __init__ of dynamic op

* in_channel_mutable  ->  mutable_in_channel

* rm abstractproperty

* add a comment about VT

* rm registry for ChannelGroup

* MUTABLECHANNELGROUP -> ChannelGroupType

* refine docstring of IndexDict

* update docstring

* update docstring

* is_prunable -> is_mutable

* update docstring

* fix error in pre-commit

* update unittest

* add return type

* unify init_xxx api

* add unittest about init of MutableChannelGroup

* update according to reviews

* sequential_channel_group -> sequential_mutable_channel_group

Co-authored-by: liukai <liukai@pjlab.org.cn>

* Add BaseChannelMutator and refactor Autoslim (#289)

* add BaseChannelMutator

* add autoslim

* tmp

* make SequentialMutableChannelGroup accept both num and ratio as choice, and support divisor

* update OneShotMutableChannelGroup

* pass supernet training of autoslim

* refine autoslim

* fix bug in OneShotMutableChannelGroup

* refactor make_divisible

* fix spell error:  channl -> channel

* init_using_backward_tracer -> init_from_backward_tracer
init_using_fx_tracer -> init_from_fx_tracer

* refine SequentialMutableChannelGroup

* let mutator support models with dynamicop

* support define search space in model

* tracer_cfg -> parse_cfg

* refine

* using -> from

* update docstring

* update docstring

Co-authored-by: liukai <liukai@pjlab.org.cn>

* tmpsave

* migrate ut

* tmpsave2

* add loss collector

* refactor slimmable and add l1-norm (#291)

* refactor slimmable and add l1-norm

* make l1-norm support convnd

* update get_channel_groups

* add  l1-norm_resnet34_8xb32_in1k.py

* add pretrained to resnet34-l1

* remove old channel mutator

* BaseChannelMutator -> ChannelMutator

* update according to reviews

* add readme to l1-norm

* MBV2_slimmable -> MBV2_slimmable_config

Co-authored-by: liukai <liukai@pjlab.org.cn>

* update config

* fix md & pytorch support <1.9.0 in batchnorm init

* Clean old codes. (#296)

* remove old dynamic ops

* move dynamic ops

* clean old mutable_channels

* rm OneShotMutableChannel

* rm MutableChannel

* refine

* refine

* use SquentialMutableChannel to replace OneshotMutableChannel

* refactor dynamicops folder

* let SquentialMutableChannel support float

Co-authored-by: liukai <liukai@pjlab.org.cn>

* fix ci

* ci fix py3.6.x & add mmpose

* ci fix py3.6.9 in utils/index_dict.py

* fix mmpose

* minimum_version_cpu=3.7

* fix ci 3.7.13

* fix pruning &meta ci

* support python3.6.9

* fix py3.6 import caused by circular import patch in py3.7

* fix py3.6.9

* Add channel-flow (#301)

* base_channel_mutator -> channel_mutator

* init

* update docstring

* allow omitting redundant configs for channel

* add register_mutable_channel_to_a_module to MutableChannelContainer

* update according to reviews 1

* update according to reviews 2

* update according to reviews 3

* remove old docstring

* fix error

* using->from

* update according to reviews

* support self-define input channel number

* update docstring

* chanenl -> channel_elem

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>

* support >=3.7

* support py3.6.9

* Rename: ChannelGroup -> ChannelUnit (#302)

* refine repr of MutableChannelGroup

* rename folder name

* ChannelGroup -> ChannelUnit

* filename in units folder

* channel_group -> channel_unit

* groups -> units

* group -> unit

* update

* get_mutable_channel_groups -> get_mutable_channel_units

* fix bug

* refine docstring

* fix ci

* fix bug in tracer

Co-authored-by: liukai <liukai@pjlab.org.cn>

* update new channel config format

* update pruning refactor

* update merged pruning

* update commit

* fix dynamic_conv_mixin

* update comments: readme&dynamic_conv_mixins.py

* update readme

* move kl softmax channel pooling to op by comments

* fix comments: fix redundant & split README.md

* dcff in ItePruneAlgorithm

* partial dynamic params for fuseconv

* add step_freq & prune_time check

* update comments

* update comments

* update comments

* fix ut

* fix gpu ut & revise step_freq in ItePruneAlgorithm

* update readme

* revise ItePruneAlgorithm

* fix docs

* fix dynamic_conv attr

* fix ci

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: jacky <jacky@xx.com>

* [Fix] Fix optional requirements (#357)

* fix optional requirements

* fix dcff ut

* fix import with get_placeholder

* supplement the previous commit

* [Fix] Fix configs of wrn models and ofd. (#361)

* 1. Revise the configs of wrn22, wrn24, and wrn40. 2. Revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10

* 1. Add README for vanilla-wrn.

* 1. Revise readme of wrn

Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>

* [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (#356)

fix bug on mmrazor visualization, mismatch argument in define and use.

Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>

* fix bug in benchmark_test (#364)

fix bug in configs

Co-authored-by: Your Name <you@example.com>

* [FIX] Fix wrn configs (#368)

* fix wrn configs

* fix wrn configs

* update online wrn model weight

* [Fix] fix bug on pkd config. Wrong import filename. (#373)

* [CI] Update ci to torch1.13 (#380)

update ci to torch1.13

* [Feature] Add BigNAS algorithm (#219)

* add calibrate-bn-statistics

* add test calibrate-bn-statistics

* fix mixins

* fix mixins

* fix mixin tests

* remove slimmable channel mutable and refactor dynamic op

* refactor dynamic batch norm

* add progressive dynamic conv2d

* add center crop dynamic conv2d

* refactor dynamic directory

* refactor dynamic sequential

* rename length to depth in dynamic sequential

* add test for derived mutable

* refactor dynamic op

* refactor api of dynamic op

* add derive mutable mixin

* add bignas algorithm

* refactor bignas structure

* add input resizer

* add input resizer to bignas

* move input resizer from algorithm into classifier

* remove components

* add attentive mobilenet

* delete json file

* nearly (less than 0.2) align inference accuracy with gml

* move mutate separated in bignas mobilenet backbone

* add zero_init_residual

* add set_dropout

* set dropout in bignas algorithm

* fix registry

* add subnet yaml and nearly align inference accuracy with gml

* add rsb config for bignas

* remove base in config

* add gml bignas config

* convert to iter based

* bignas forward and backward fly

* fix merge conflict

* fix dynamicseq bug

* fix bug and refactor bignas

* arrange configs of bignas

* fix typo

* refactor attentive_mobilenet

* fix channel mismatch due to registration of DerivedMutable

* update bignas & fix se channel mismatch

* add AutoAugmentV2 & remove unnecessary configs

* fix lint

* recover channel assertion in channel unit

* fix a group bug

* fix comments

* add docstring

* add norm in dynamic_embed

* fix search loop & other minor changes

* fix se expansion

* minor change

* add ut for bignas & attentive_mobilenet

* fix ut

* update bignas readme

* rm unnecessary ut & supplement get_placeholder

* fix lint

* fix ut

* add subnet deployment in downstream tasks.

* minor change

* update ofa backbone

* minor fix

* Continued improvements of searchable backbone

* minor change

* drop ratio in backbone

* fix comments

* fix ci test

* fix test

* add dynamic shortcut UT

* modify strategy to fit bignas

* fix test

* fix bug in neck

* fix error

* fix error

* fix yaml

* save subnet ckpt

* merge autoslim_val/test_loop into subnet_val_loop

* move calibrate_bn_mixin to utils

* fix bugs and add docstring

* clean code

* fix register bug

* clean code

* update

Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>

* [Bug] Fix ckpt (#372)

fix ckpt

* [Feature] Add tools to convert distill ckpt to student-only ckpt. (#381)

* [Feature] Add tools to convert distill ckpt to student-only ckpt.

* fix bug.

* add --model-only to only save model.

* Make changes according to PR review.

* Enhance the Abilities of the Tracer for Pruning. (#371)

* tmp

* add new mmdet models

* add docstring

* pass test and pre-commit

* rm razor tracer

* update fx tracer, now it can automatically wrap methods and functions.

* update tracer passed models

* add warning for torch <1.12.0

fix bug for python3.6

update placeholder to support placeholder.XXX

* fix bug

* update docs

* fix lint

* fix parse_cfg in configs

* restore mutablechannel

* test ite prune algorithm when using dist

* add get_model_from_path to MMModelLibrary

* add mm models to DefaultModelLibrary

* add uts

* fix bug

* fix bug

* add uts

* add uts

* add uts

* add uts

* fix bug

* restore ite_prune_algorithm

* update doc

* PruneTracer -> ChannelAnalyzer

* prune_tracer -> channel_analyzer

* add test for fxtracer

* fix bug

* fix bug

* PruneTracer -> ChannelAnalyzer

refine

* CustomFxTracer -> MMFxTracer

* fix bug when test with torch<1.12

* update print log

* fix lint

* rm useless code

Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: liukai <your_email@abc.example>

* fix bug in placer holder (#395)

* fix bug in placer holder

* remove redundant comment

Co-authored-by: liukai <your_email@abc.example>

* Add get_prune_config and a demo config_pruning (#389)

* update tools and test

* add demo

* disable test doc

* add switch for test tools and test_doc

* fix bug

* update doc

* update tools name

* mv get_channel_units

Co-authored-by: liukai <your_email@abc.example>

* [Improvement] Adapt OFA series with SearchableMobileNetV3 (#385)

* fix mutable bug in AttentiveMobileNetV3

* remove unnecessary code

* update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names

* unify the sampling usage in sandwich_rule-based NAS

* use alias to export subnet

* update OFA configs

* fix attr bug

* fix comments

* update convert_supernet2subnet.py

* correct the way to dump DerivedMutable

* fix convert index bug

* update OFA configs & models

* fix dynamic2static

* generalize convert_ofa_ckpt.py

* update input_resizer

* update README.md

* fix ut

* update export_fix_subnet

* update _dynamic_to_static

* update fix_subnet UT & minor fix bugs

* fix ut

* add new autoaug compared to attentivenas

* clean

* fix act

* fix act_cfg

* update fix_subnet

* fix lint

* add docstring

Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>

* [Fix]Dcff Deploy Revision (#383)

* dcff deploy revision

* tempsave

* update fix_subnet

* update mutator load

* export/load_fix_subnet revision for mutator

* update fix_subnet with dev-1.x

* update comments

* update docs

* update registry

* [Fix] Fix commands in README to adapt branch 1.x (#400)

* update commands in README for 1.x

* fix commands

Co-authored-by: gaoyang07 <1546308416@qq.com>

* Set requires_grad to False if the teacher is not trainable (#398)

* add choice and mask of units to checkpoint (#397)

* add choice and mask of units to checkpoint

* update

* fix bug

* remove device operation

* fix bug

* fix circle ci error

* fix error in numpy for circle ci

* fix bug in requirements

* restore

* add a note

* a new solution

* save mutable_channel.mask as float for dist training

* refine

* mv meta file test

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: jacky <jacky@xx.com>

* [Bug]Fix fpn teacher distill (#388)

fix fpn distill

* [CodeCamp #122] Support KD algorithm MGD for detection. (#377)

* [Feature] Support KD algorithm MGD for detection.

* use connector to beautify mgd.

* fix typo, add unittest.

* fix mgd loss unittest.

* fix mgd connector unittest.

* add model pth and log file.

* add mAP.

* update l1 config (#405)

* add l1 config

* update l1 config

Co-authored-by: jacky <jacky@xx.com>

* [Feature] Add greedy search for AutoSlim (#336)

* WIP: add greedysearch

* fix greedy search and add bn_training_mode to autoslim

* fix cfg files

* fix autoslim configs

* fix bugs when converting dynamic bn to static bn

* change to test loop

* refactor greedy search

* rebase and fix greedysearch

* fix lint

* fix and delete useless codes

* fix pytest

* fix pytest and add bn_training_mode

* fix lint

* add reference to AutoSlimGreedySearchLoop's docstring

* sort candidate_choices

* fix save subnet

* delete useless codes in channel container

* rename greedy_search_loop to autoslim_greedy_search_loop

* [Fix] Fix metafile (#422)

* fix ckpt path in metafile and readme

* fix darts file path

* fix docstring in ConfigurableDistiller

* fix darts

* fix error

* add darts of mmrazor version

* delete py36

Co-authored-by: liukai <your_email@abc.example>

* update bignas cfg (#412)

* check attentivenas training

* update ckpt link

* update supernet log

Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>

* Bump version to 1.0.0rc2 (#423)

bump version to 1.0.0rc2

Co-authored-by: liukai <your_email@abc.example>

* fix lint

* fix ci

* add tmp docstring for passed ci

* add tmp docstring for passed ci

* fix ci

* add get_placeholder for quant

* add skip for unittest

* fix package placeholder bug

* add version judgement in __init__

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

* update prev commit

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>

* [Docs] Add docstring and unittest about backendconfig & observer & fakequant (#428)

* add ut about backendconfig

* add ut about observers and fakequants in torch

* fix torch1.13 ci

* [Docs] Add docstring for `MMArchitectureQuant` & `NativeQuantizer` (#425)

* add docstring on mm_architecture& native_quantizer

* add naive openvino r18 qat config & dist_ptq.sh

* Added a more accurate description

* unittest & doc

* checkpoint url

* unittest

* passed_pre_commit

* unittest on native_quantizer & fix bugs

* remove dist_ptq

* add get_placeholder&skipTest

* complete arg descriptions

* fix import bugs

* fix pre-commit

* add get_placeholder

* add typehint and doctring

* update docstring&typehint

* update docstring

* pre-commit

* fix some problems

* fix bug

* [Docs] Add docstring and unitest about custom tracer (#427)

* rename QConfigHandler and QSchemeHandler

* add docstring about custom tracer

* add ut about custom tracer

* fix torch1.13 ci

* fix lint

* fix ci

* fix ci

* [Docs & Refactor] Add docstring and UT of other quantizers (#439)

* add quantizer docstring and refactor the interface of AcademicQuantizer

* add AcademicQuantizer unittest

* add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface

* adapt torch113 ci

* fix import

* fix lint

* update some docstring

* fix ci

* [Feature&Doc]Modify ptq pipeline and support lsq (#435)

* modify ptq pipeline and support lsq

* use placeholder

* fix lsq && quantloop

* add lsq pytest

* add quant loop pytest

* test lsq observer

* fix bug under pt13

* fix reset_min_max_vals

* fix bugs under pt13

* fix configs

* add get_qconfig_mapping

* delete is_qat, add doc and fix pytest

* delete useless codes in custom_tracer

* skip pytest under pt13

* add todo: check freezebn

* fix pytest bugs

* fix pytest

* fix pytest

* fix pytest

* [Docs] Add customize_quantization_tutorial (#440)

* [Docs] Add quantization user guide (#441)

* add quantization user guide

* fix layout

* fix layout

* update README

* [Bug] Fix del redundant fakequant (#447)

fix del redundant fakequant

* [Feature] Add onnx exporters (#475)

* fix del redundant fakequant

* add onnx exporters

* fix onnx exporters and add docstring

* fix comments

* delete useless codes

* fix export_onnx in native quantizer

---------

Co-authored-by: pppppM <gjf_mail@126.com>

* [Feature]Rewrite the origin model during prepare (#488)

* add rewriter

* add deploy_cfg arg

* modify post_process_for_mmdeploy

* fix bugs

* add det config

* [Feature] Using rewriter in mmrazor when building qmodels. (#490)

* add rewriter

* add deploy_cfg arg

* modify post_process_for_mmdeploy

* fix bugs

* add det config

* replace deepcopy

* pop detectors' forward

* [Feature] Quantization global optimization (#491)

* add trtquantizer

* unify all fakequant before deploy

* move to aide

* add yolox config

* pre-rebase

* add unittest

* add an arg to post_process_for_deploy

* test trt yolox deploy

* opt quantizer interface

* fix rebase

* add trt r50 config

* update trt setting

* del redundant code

* fix lint

* fix ut of quantizers

* del redundant file

* fix lint

* fix some comments

* Fix code syntax in UT (#470)

Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* passed lint and pytest

* try to fix ci

* [Bug] Try to fix CI (#502)

fix lint

* [Feature] Support lsq (#501)

* support deploy_cfg=None

* replace fakequant before load ckpt

* add _load_from_state_dict to lsq fakequant

* fix pre-commit

* test lsq load state dict

* change github ci: ubuntu 18.04 to ubuntu 20.04

* get_deploy_model order change back

* sync before save ckpt

* delete strict=False

* test context rewriter

* fix pre commit config

* try to fix ci

* [Bug] Try to fix CI (#502)

fix lint

---------

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>

* [Feature] Add exporter pytest (#504)

* add exporter pytest

* fix bugs

* delete useless codes

* handle onnx

* delete useless codes

* [Bug] Fix ci coverage setting (#508)

fix ci coverage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* [BUG] Fix quantization loop (#507)

* fix quantization loop

* fix quant loop

* fix quant loop

* fix qat configs

* [Bug] Fix ci coverage setting (#508)

fix ci coverage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* add freeze_bn_begin to lsq

* delete useless codes

---------

Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>

* add test ptq

* opt ptq pipeline

* refactor quant configs

* update config path

* add summary analyse tool

* fix benchmark_test:detnas_frcnn_shufflenet_subnet_coco_1x.py

* update quantization README.md

* update quantization metafile, readme, config path

* update quantization docs

* update git main link in workflow

* update benchmark_summary_analyse.py

* del dmcp results

* [Bug] fix a rebase error (#514)

fix a rebase error

* [Bug] Fix CI (#515)

* fix ci

* mmcv2.0 need torch1.8+

* Update CI config and Passed (#516)

* test ci

* update test.yml based on mmcv2.0.0

* [Docs] Fix cwd test accuracy (#517)

* test ci

* update test.yml based on mmcv2.0.0

* update cwd_logits_pspnet result

---------

Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* [Docs&Feature] Prepare for checkouting default branch and releasing new version (#518)

* prepare for checkout default branch

* update README.md and model zoo

* update installation.md and update dev-1.x links

* update README_zh-CN

* add changelog

* update ci config

* update some links in quantization readme

* update quantization user guide

* update calibrate_dataloader

* add interface pop_rewriter_function_record

* Bump version to 1.0.0 (#521)

* update release time

* bump version to 1.0.0

* [CI] Fix merge stage test (#523)

fix merge_stage_test in ci

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>

* move folders and update readme (#528)

* move folders

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>

* [Bug] Fix torch2 error (#536)

fix torch2 error

* [Feature] Add GPTQ and uniform interfaces (#538)

* add gptq implementation

* pre-checkout

* passed resnet example

* passed llama example

* align gptq acc

* add activation quantization

* uniform interfaces

* add gptq readme

* update mmrazor_large readme

* add gptq opt example

* fix sparse_gpt example for opt

* fix import Protocol from py37

* fix error function name

* fix bug in test

* fix bug

* fix bug

* limit sparsegpt test with torch>=1.12

* add docstring for gptq and sparse_gpt

* pre-commit

* align acc & add save load ckpt & add ut

* fix ut

* fix ut

* fix ut

* fix ut & add torch2.0 for ci

* del torch2.0 for ci

* fix ut

---------

Co-authored-by: FIRST_NAME LAST_NAME <MY_NAME@example.com>

---------

Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com>
Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
Co-authored-by: liukai <liukai@pjlab.org.cn>
Co-authored-by: Yang Gao <Gary1546308416AL@gmail.com>
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: jacky <jacky@xx.com>
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: FreakieHuang <frank0huang@foxmail.com>
Co-authored-by: pppppM <gjf_mail@126.com>
Co-authored-by: L-Icarus <30308843+L-Icarus@users.noreply.github.com>
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: Yue Sun <aptsunny@tongji.edu.cn>
Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com>
Co-authored-by: zengyi.vendor <zengyi.vendor@sensetime.com>
Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com>
Co-authored-by: zhangzhongyu <zhangzhongyu@pjlab.org.cn>
Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com>
Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com>
Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com>
Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
Co-authored-by: Ming-Hsuan-Tu <qrnnis2623891@gmail.com>
Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com>
Co-authored-by: Yue Sun <aptsunny@alumni.tongji.edu.cn>
Co-authored-by: Ivan Zhang <51170394+415905716@users.noreply.github.com>
Co-authored-by: wm901115nwpu <wmnwpu@gmail.com>
Co-authored-by: 王盟 <unicorn@MacBook-Pro.local>
Co-authored-by: FIRST_NAME LAST_NAME <MY_NAME@example.com>
@@ -69,4 +69,5 @@ repos:
| ^docs
| ^configs
| ^.*/configs*
| ^projects
)

@@ -61,6 +61,8 @@ English | [简体中文](README_zh-CN.md)
</div>
**:star: MMRazor for Large Models** is Available Now! Please refer to [MMRazorLarge](projects/mmrazor_large/README.md)
## Introduction
MMRazor is a model compression toolkit for model slimming and AutoML, which includes 4 mainstream technologies:

@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import group_fisher
+from . import group_fisher, sparse_gpt
-__all__ = ['group_fisher']
+__all__ = ['group_fisher', 'sparse_gpt']

@@ -0,0 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import SparseGptCompressor
from .ops import SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops
__all__ = [
'SparseGptLinear', 'SparseGptMixIn', 'replace_with_dynamic_ops',
'SparseGptCompressor'
]

@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmrazor.utils import print_log
from .ops import SparseGptConv2d, SparseGptLinear, SparseGptMixIn
from .utils import replace_with_dynamic_ops
def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
load_fix_subnet(model, fix_subnet)
return model
class SparseGptCompressor():
"""The compressor with SparseGPT."""
def __init__(self) -> None:
self.model: nn.Module = None
def prepare(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
"""Prepare for compressing model."""
self.model = model
prune_modules: dict = {}
if prune_conv:
prune_modules[nn.Conv2d] = SparseGptConv2d
if prune_linear:
prune_modules[nn.Linear] = SparseGptLinear
replace_with_dynamic_ops(model, prune_modules)
@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)
# hessian
def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.register_hessian_hook()
def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.remove_hessian_hook()
def init_hessian(self, device=None):
"""Init hessian."""
for op in self.sparse_ops:
op.init_hessian(device=device)
# prune
def prune(self,
sparsity,
prunen=0,
prunem=0,
blocksize=128,
percdamp=.01,
device=torch.device('cuda')):
"""Apply the compression algorithm to the model."""
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
module: SparseGptMixIn = module.to(device)
error = module.prune(
sparsity=sparsity,
prunen=prunen,
prunem=prunem,
blocksize=blocksize,
percdamp=percdamp,
)
print_log(f'prune {name} success \t error = {error}')
module.to(original_device)
torch.cuda.empty_cache()
except Exception as e:
print_log(f'prune {name} failed as {e}')
def prune_24(self, device=torch.device('cuda:0')):
"""Apply the compression algorithm to the model with the specified
setting."""
self.prune(0.5, prunen=2, prunem=4, device=device)
# ops
@property
def sparse_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
yield module
@property
def named_sparse_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module
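
A minimal usage sketch of the compressor above (not part of the diff): prepare the model, accumulate Hessians over a few calibration batches via the hooks, prune to a 2:4 pattern, then convert back to plain torch modules. `build_model`, `calib_loader` and the import path are assumptions, not taken from this PR.

# Hedged usage sketch; assumes a CUDA device and the package layout added here.
import torch

from mmrazor.implementations.pruning.sparse_gpt import \
    SparseGptCompressor  # assumed import path

model = build_model().cuda().eval()            # hypothetical model factory
compressor = SparseGptCompressor()
compressor.prepare(model, prune_conv=True, prune_linear=True)

# Accumulate Hessians from calibration data through the forward pre-hooks.
compressor.init_hessian(device='cuda')
compressor.register_hessian_hooks()
with torch.no_grad():
    for batch in calib_loader:                 # hypothetical calibration loader
        model(batch.cuda())
compressor.remove_hessian_hooks()

compressor.prune_24()                          # 2:4 structured sparsity
model = compressor.to_static_model(model)      # back to plain torch modules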

@@ -0,0 +1,278 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
DynamicLinear)
from .utils import ModuleProtocol, torch_setting
class SparseGptMixIn(ModuleProtocol):
"""The core algorithm implementation for SparseGpt."""
def _sparse_gpt_mix_in_init(self):
"""Init mixin."""
self.sparse_gpt_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
self._hessian: torch.Tensor = None
self.hessian_batch = 0
# weight and input adaptive
@property
def weight_matrix(self):
"""Return weight with shape (out in)"""
return self.weight.flatten(1) # out in
@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
self.weight.data.copy_(value)
def format_input(self, input: torch.Tensor):
"""Return input with shape (B N C)"""
if len(input.shape) == 2: # N C
input = input.unsqueeze(0) # 1 N C
return input
# compute hessian
@property
def hessian(self):
"""hessian always return float."""
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.'
hessian = self._hessian.to(self.weight_matrix.device)
else:
hessian = torch.zeros(
self.columns,
self.columns,
device=self.weight_matrix.device)
dist.broadcast(hessian, 0)
return hessian
else:
return self._hessian
@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.' # noqa
self._hessian.data.copy_(
value.data.to(self._hessian.device))
else:
self._hessian = None
else:
self._hessian.data.copy_(value.data.to(self._hessian.device))
@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
assert len(input.shape) == 3
B = input.shape[0] # B N C
input = input.transpose(0, -1).flatten(1) # C D
H = input @ input.T * 2 # C C
if dist.is_initialized():
dist.all_reduce(H)
B *= dist.get_world_size()
H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B
def register_hessian_hook(self):
"""Register updating hessian hook."""
@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
assert len(input) == 1
self.update_hessian(input[0])
handle = self.register_forward_pre_hook(forward_pre_hook)
self.sparse_gpt_handles.append(handle)
def remove_hessian_hook(self):
"""Remove updating hessian hook."""
for h in self.sparse_gpt_handles:
h.remove()
def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
else:
self._hessian = None
else:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
# prune
@torch.no_grad()
def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
"""The implementation for SparseGPT."""
with torch_setting(dtype=torch.float):
# Converted from https://github.com/ist-daslab/sparsegpt
assert self.hessian is not None
W: torch.Tensor = self.weight_matrix.float() # out in
H = self.hessian.float().to(W.device)
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
Losses = torch.zeros(self.rows, device=W.device)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=W.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
mask = None
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
if prunen == 0:
if mask is not None:
mask1 = mask[:, i1:i2]
else:
tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
sparsity)]
mask1 = tmp <= thresh
else:
mask1 = torch.zeros_like(W1) == 1
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if prunen != 0 and i % prunem == 0:
tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
i + prunem)].reshape((1, -1)))**2
mask1.scatter_(
1, i +
torch.topk(tmp, prunen, dim=1, largest=False)[1],
True)
q = w.clone()
q[mask1[:, i]] = 0
Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:,
i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
i:].unsqueeze(0))
Err1[:, i] = err1
W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if W.device.type == 'cuda':
torch.cuda.synchronize()
from .sparse24_utils import is_weight_sparse_24
if prunen == 2 and prunem == 4:
assert is_weight_sparse_24(
W, -1), f'Weight does not satisfy 2:4 sparsity with shape {W.shape}'
error = torch.sum(Losses)
if torch.isnan(error).any():
raise Exception('get nan error')
else:
self.weight_matrix = W.data
return error.item()
# SparseGpt Ops for Linear and Conv2d
class SparseGptLinear(DynamicLinear, SparseGptMixIn):
"""Custom Linear for SparseGpt."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()
@classmethod
def convert_from(cls, module: nn.Linear) -> 'SparseGptLinear':
"""Convert to cls from torch's module."""
if module.out_features < module.in_features:
return module
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
"""Custom Conv2d for SparseGpt."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()
@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'SparseGptConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
stride=self.stride) # B C D
return input.transpose(-1, -2)
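
The mixin can also be exercised on a single layer, which is a convenient sanity check of the pruning math. A small CPU-only sketch follows (not part of the diff; the import path is assumed from this PR's package layout).

# Hedged single-layer sketch; shapes are arbitrary.
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt import \
    SparseGptLinear  # assumed import path

layer = SparseGptLinear.convert_from(nn.Linear(512, 512))
layer.init_hessian()                     # Hessian buffer on CPU
layer.register_hessian_hook()
with torch.no_grad():
    for _ in range(8):                   # toy calibration passes
        layer(torch.randn(32, 512))
layer.remove_hessian_hook()

error = layer.prune(sparsity=0.5, prunen=2, prunem=4)   # 2:4 pattern
print(f'reconstruction error: {error:.4f}')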

@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
@torch.no_grad()
def is_weight_sparse_24(weight: torch.Tensor, dim=-1):
""""Check if the weight is sparse 24."""
weight = weight.transpose(-1, dim).reshape(-1, 4) # N 4
is_zero = (weight == 0).sum(-1) # N
return (is_zero >= 2).all()
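
A quick sanity check of the helper above (not part of the diff): every contiguous group of four values along `dim` must contain at least two zeros.

# Hedged example; the module path is assumed.
import torch

from mmrazor.implementations.pruning.sparse_gpt.sparse24_utils import \
    is_weight_sparse_24

w_ok = torch.tensor([[1., 0., 0., 3.], [0., 2., 4., 0.]])
w_bad = torch.tensor([[1., 2., 3., 0.]])
assert is_weight_sparse_24(w_ok, dim=-1)
assert not is_weight_sparse_24(w_bad, dim=-1)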

@@ -0,0 +1,140 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from typing import Dict, Type
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import torch
import torch.nn as nn
from mmrazor.models.architectures.dynamic_ops import DynamicMixin
from mmrazor.utils import print_log
class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor
def forward(self, x):
"""The abstract method."""
pass
def register_forward_hook(self, hook):
"""The abstract method."""
pass
def register_backward_hook(self, hook):
"""The abstract method."""
pass
def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass
def register_buffer(self, name, tensor):
"""The abstract method."""
pass
def replace_with_dynamic_ops(model: nn.Module,
dynamicop_map: Dict[Type[nn.Module],
Type[DynamicMixin]]):
"""Replace torch modules with dynamic-ops."""
def replace_op(model: nn.Module, name: str, module: nn.Module):
names = name.split('.')
for sub_name in names[:-1]:
model = getattr(model, sub_name)
setattr(model, names[-1], module)
for name, module in model.named_modules():
if type(module) in dynamicop_map:
new_module = dynamicop_map[type(module)].convert_from(module)
replace_op(model, name, new_module)
def register_efficient_forward_hook(module: nn.Module,
device=torch.device('cuda:0')):
"""Register efficient forward hook."""
def forward_pre_hook(module: nn.Module, input):
module.to(device)
def forward_hook(module: nn.Module, input, output):
module.to('cpu')
torch.cuda.empty_cache()
h1 = module.register_forward_pre_hook(forward_pre_hook)
h2 = module.register_forward_hook(forward_hook)
return [h1, h2]
def enable_efficient_forward(model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]):
"""Enable efficient forward."""
handles = []
blocks = []
for name, module in model.named_children():
if type(module) in wrap_modules or len(module._parameters) != 0 or len(
module._buffers) != 0:
handles_ = register_efficient_forward_hook(module, device)
blocks_ = [name]
else:
handles_, blocks_ = enable_efficient_forward(
module, device, wrap_modules)
handles += handles_
blocks += blocks_
return handles, blocks
class memory_efficient_forward:
"""The class for Memory efficient forward."""
def __init__(self,
model: nn.Module,
enabled=True,
device=torch.device('cuda:0'),
wrap_modules=[]) -> None:
self.model = model
self.device = device
self.wrap_modules = wrap_modules
self.enabled = enabled
self.handlers: list = []
if not enabled:
model.to(device)
def __enter__(self, ):
"""Enter."""
if self.enabled:
handles, blocks = enable_efficient_forward(self.model, self.device,
self.wrap_modules)
print_log(f'enable memory efficient forward for {blocks}')
self.handlers = handles
def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
for h in self.handlers:
h.remove()
class torch_setting():
"""Set the default torch dtype setting."""
def __init__(self, dtype=None) -> None:
self.original_dtype = torch.get_default_dtype()
self.dtype = dtype
def __enter__(self):
"""Enter."""
if self.dtype is not None:
torch.set_default_dtype(self.dtype)
def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
torch.set_default_dtype(self.original_dtype)
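A minimal sketch of how the two context managers above might be used (assuming a CUDA device; the module path is inferred from the `sparse_gpt.utils` imports used elsewhere in this PR, and the model choice is only illustrative):

```python
import torch
import torchvision

# Path assumed from the sparse_gpt.utils imports used elsewhere in this PR.
from mmrazor.implementations.pruning.sparse_gpt.utils import (
    memory_efficient_forward, torch_setting)

model = torchvision.models.resnet18()

# Each parameterized submodule is moved to cuda:0 by a forward pre-hook and
# offloaded back to the CPU right after its forward, so only one block is
# resident on the GPU at a time.
with torch.no_grad(), memory_efficient_forward(
        model, enabled=True, device=torch.device('cuda:0')):
    for _ in range(4):
        model(torch.rand(2, 3, 224, 224, device='cuda:0'))

# torch_setting temporarily overrides the default dtype and restores the
# previous default on exit.
with torch_setting(dtype=torch.float64):
    assert torch.get_default_dtype() == torch.float64
```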

View File

@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .compressor import GPTQCompressor
from .gptq import GPTQMixIn
from .ops import GPTQConv2d, GPTQLinear, TritonGPTQLinear
from .quantizer import Quantizer
__all__ = [
'GPTQCompressor',
'GPTQMixIn',
'GPTQConv2d',
'GPTQLinear',
'TritonGPTQLinear',
'Quantizer',
]

View File

@ -0,0 +1,146 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Type
import torch
import torch.nn as nn
from mmrazor.utils import print_log
from .ops import GPTQConv2d, GPTQLinear, GPTQMixIn, TritonGPTQLinear
from .quantizer import Quantizer
def replace_with_dynamic_ops(model: nn.Module,
dynamicop_map: Dict[Type[nn.Module], Type[Any]],
skipped_layers=[],
a_qconfig=None,
**kwargs):
"""Replace torch modules with dynamic-ops."""
def replace_op(model: nn.Module, name: str, module: nn.Module):
names = name.split('.')
for sub_name in names[:-1]:
model = getattr(model, sub_name)
setattr(model, names[-1], module)
for name, module in model.named_modules():
if type(module) in dynamicop_map and name not in skipped_layers:
if isinstance(module, nn.Linear):
if a_qconfig:
a_fakequant = Quantizer()
a_fakequant.configure(**a_qconfig)
kwargs.update({'a_fakequant': a_fakequant})
new_module = dynamicop_map[type(module)].convert_from(
module, **kwargs)
else:
new_module = dynamicop_map[type(module)].convert_from(module)
replace_op(model, name, new_module)
def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
load_fix_subnet(model, fix_subnet)
return model
class GPTQCompressor():
"""The compressor with GPTQ."""
def __init__(self) -> None:
self.model: nn.Module = None
def prepare(self,
model: nn.Module,
quant_conv=True,
quant_linear=True,
use_triton_ops=True,
skipped_layers=[],
a_qconfig=None,
**kwargs) -> None:
"""Prepare for compressing model."""
self.model = model
quant_modules: dict = {}
if quant_conv:
quant_modules[nn.Conv2d] = GPTQConv2d
if quant_linear:
gptq_linear = TritonGPTQLinear if use_triton_ops else GPTQLinear
quant_modules[nn.Linear] = gptq_linear
replace_with_dynamic_ops(model, quant_modules, skipped_layers,
a_qconfig, **kwargs)
@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)
# hessian
def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.register_hessian_hook()
def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.remove_hessian_hook()
def init_hessian(self, device=None):
"""Init hessian."""
for op in self.quant_ops:
op.init_hessian(device=device)
# quant
def quant(self,
blocksize=128,
percdamp=0.01,
groupsize=-1,
actorder=False,
device=torch.device('cuda:0'),
**qconfig):
"""Apply the compression algorithm to the model."""
for name, module in self.named_quant_ops:
try:
original_device = next(module.parameters()).device
module: GPTQMixIn = module.to(device)
quantizer = Quantizer()
quantizer.configure(**qconfig)
# print_log(f'quant {name}...')
error = module.quant(
quantizer=quantizer,
blocksize=blocksize,
percdamp=percdamp,
groupsize=groupsize,
actorder=actorder)
print_log(f'quant {name} success \t error = {error}')
module.to(original_device)
module.free()
except Exception as e:
print_log(f'quant {name} failed as {e}')
def quant_with_default_qconfig(self, groupsize=128, device='cpu'):
"""Apply the compression algorithm to the model with the specified
setting."""
qconfig = dict(bits=4, perchannel=True, sym=False)
self.quant(
groupsize=groupsize, actorder=True, device=device, **qconfig)
# ops
@property
def quant_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, GPTQMixIn):
yield module
@property
def named_quant_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, GPTQMixIn):
yield name, module

View File

@ -0,0 +1,254 @@
# Copyright (c) OpenMMLab. All rights reserved.
# https://github.com/fpgaminer/GPTQ-triton
"""Mostly the same as the autotuner in Triton, but with a few changes like
using 40 runs instead of 100."""
import builtins
import math
import time
from typing import Dict
try:
import triton
except ImportError:
from mmrazor.utils import get_package_placeholder
triton = get_package_placeholder('triton >= 2.0.0')
class Autotuner(triton.KernelInterface):
"""Autotuner."""
def __init__(self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False):
'''prune_configs_by: a dict of functions that are used to prune
configs, fields:
'perf_model': performance model used to predict running time
with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early pruning
(e.g., of num_stages). It takes configs: List[Config] as its input
and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments
to the nearest power of two when caching tuning results.'''
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache: Dict = {}
# hook to reset all required tensor to zeros before relaunching
# a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = prune_configs_by[
'perf_model'], prune_configs_by['top_k']
if 'early_config_prune' in prune_configs_by:
early_config_prune = prune_configs_by['early_config_prune']
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
"""Check for conflicts, i.e. meta-parameters both provided as kwargs
and by the autotuner."""
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols.")
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current)
try:
# In testing, using only 40 reps seems to be close enough and it
# appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any
# speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
except triton.compiler.OutOfResources:
return (float('inf'), float('inf'), float('inf'))
def run(self, *args, **kwargs):
"""Run."""
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to
# the nearest power of two
# In my testing this gives decent results, and greatly reduces
# the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2**int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs)
def prune_configs(self, kwargs):
"""Prune configs."""
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps)
for config in pruned_configs
}
pruned_configs = sorted(
est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
"""Warm up."""
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(configs,
key,
prune_configs_by=None,
reset_to_zero=None,
nearest_power_of_two=False):
"""Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated
# anytime the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run
multiple times. This means that whatever value the kernel updates will
be updated multiple times. To avoid this undesired behavior, you can
use the `reset_to_zero` argument, which resets the value of the
provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger
the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune
configs, fields:
'perf_model': performance model used to predict running time with
different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early pruning
(e.g., of num_stages). It takes configs: List[Config] as its input and
returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset
to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero,
prune_configs_by, nearest_power_of_two)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""The main purpose of this function is to shrink BLOCK_SIZE_* when the
corresponding dimension is smaller."""
m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
group_size_m = config.kwargs['GROUP_SIZE_M']
if (block_size_m, block_size_n, block_size_k, group_size_m,
config.num_stages, config.num_warps) in used:
continue
used.add((block_size_m, block_size_n, block_size_k, group_size_m,
config.num_stages, config.num_warps))
yield triton.Config(
{
'BLOCK_SIZE_M': block_size_m,
'BLOCK_SIZE_N': block_size_n,
'BLOCK_SIZE_K': block_size_k,
'GROUP_SIZE_M': group_size_m
},
num_stages=config.num_stages,
num_warps=config.num_warps)

View File

@ -0,0 +1,318 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
import numpy as np
import torch
import torch.distributed as dist
from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting
class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor
def forward(self, x):
"""The abstract method."""
pass
def register_forward_hook(self, hook):
"""The abstract method."""
pass
def register_backward_hook(self, hook):
"""The abstract method."""
pass
def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass
def register_buffer(self, name, tensor):
"""The abstract method."""
pass
class GPTQMixIn(ModuleProtocol):
"""The core algorithm implementation for GPTQ."""
def _gptq_mix_in_init(self):
"""Init mixin."""
self.gptq_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
self._hessian: torch.Tensor = None
self.hessian_batch = 0
# weight and input adaptive
@property
def weight_matrix(self):
"""Return weight with shape (out in)"""
return self.weight.flatten(1) # out in
@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
self.weight.data.copy_(value)
def format_input(self, input: torch.Tensor):
"""Return input with shape (B N C)"""
if len(input.shape) == 2: # N C
input = input.unsqueeze(0) # 1 N C
return input
# compute hessian
@property
def hessian(self):
"""hessian always return float."""
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.'
hessian = self._hessian.to(self.weight_matrix.device)
else:
hessian = torch.zeros(
self.columns,
self.columns,
device=self.weight_matrix.device)
dist.broadcast(hessian, 0)
return hessian
else:
return self._hessian
@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.' # noqa
self._hessian.data.copy_(
value.data.to(self._hessian.device))
else:
self._hessian = None
else:
self._hessian.data.copy_(value.data.to(self._hessian.device))
@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
assert len(input.shape) == 3
B = input.shape[0] # B N C
input = input.transpose(0, -1).flatten(1) # C D
H = input @ input.T * 2 # C C
if dist.is_initialized():
dist.all_reduce(H)
B *= dist.get_world_size()
H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B
def register_hessian_hook(self):
"""Register updating hessian hook."""
@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
assert len(input) == 1
self.update_hessian(input[0])
handle = self.register_forward_pre_hook(forward_pre_hook)
self.gptq_handles.append(handle)
def remove_hessian_hook(self):
"""Remove updating hessian hook."""
for h in self.gptq_handles:
h.remove()
def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
else:
self._hessian = None
else:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
def pack(self, scales, zeros, g_idx=None):
"""Pack and update qparams with groupsize_idx."""
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if self.bias is not None:
# convert the bias to half precision in place (a bare .half() call would be a no-op)
self.bias.data = self.bias.data.half()
intweight = []
for idx in range(self.in_features):
intweight.append(
torch.round(
(self.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) /
self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.cpu().numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
dtype=np.uint32)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError('Only 2,4,8 bits are supported.')
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight).to(self.weight.device)
zeros -= 1
zeros = zeros.cpu().numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError('Only 2,4,8 bits are supported.')
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)
@torch.no_grad()
def quant(self,
quantizer,
blocksize=128,
percdamp=0.01,
groupsize=-1,
actorder=False):
"""The implementation for GPTQ."""
with torch_setting(dtype=torch.float):
assert self.hessian is not None
W: torch.Tensor = self.weight_matrix.float() # out in
if not quantizer.ready():
quantizer.find_params(W, weight=True)
H = self.hessian.float().to(W.device)
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=W.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if groupsize != -1:
if (i1 + i) % groupsize == 0:
quantizer.find_params(
W[:, (i1 + i):(i1 + i + groupsize)],
weight=True)
if ((i1 + i) // groupsize) - now_idx == -1:
scale.append(quantizer.scale)
zero.append(quantizer.zero)
now_idx += 1
q = quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:,
i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
torch.cuda.synchronize()
error = torch.sum(Losses).item()
groupsize = groupsize if groupsize != -1 else self.columns
g_idx = [i // groupsize for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if scale == []:
scale.append(quantizer.scale)
zero.append(quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
if self.is_custom_kernel:
self.pack(scale, zero, g_idx)
del self.weight
return error
def free(self):
"""Free some cache and memory."""
self._hessian = None
torch.cuda.empty_cache()
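A hedged sketch of exercising the mixin above on a single layer, outside the compressor (assuming a CUDA device; the layer size, calibration tensor, and qconfig values are only illustrative, and the import path follows the package `__init__` in this PR):

```python
import torch
import torch.nn as nn

from mmrazor.implementations.quantization.gptq import GPTQLinear, Quantizer

layer = GPTQLinear.convert_from(nn.Linear(64, 32)).cuda()

# Collect the Hessian from a calibration forward pass.
layer.init_hessian(device='cuda')
layer.register_hessian_hook()
layer(torch.randn(8, 64, device='cuda'))
layer.remove_hessian_hook()

# Run the GPTQ weight update with a 4-bit asymmetric per-channel quantizer.
quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=False)
error = layer.quant(quantizer=quantizer, blocksize=32, groupsize=-1)
print('quantization error:', error)
```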

View File

@ -0,0 +1,566 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
DynamicLinear)
# from mmrazor.implementations.pruning.sparse_gpt.utils import torch_setting
from .gptq import GPTQMixIn
try:
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True,
prune_configs_by={
'early_config_prune':
custom_autotune.matmul248_kernel_config_pruner,
'perf_model': None,
'top_k': None,
},
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M,
N, K, bits, maxq, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, stride_scales,
stride_zeros, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N // 32 * bits) int32
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8
# times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk +
offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit
# word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused
# in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] *
stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs +
g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(
a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit
# values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[
None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@custom_autotune.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 256,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 8
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
},
num_stages=2,
num_warps=4),
],
key=['M', 'N', 'K'],
nearest_power_of_two=True)
@triton.jit
def transpose_matmul_248_kernel(
a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits,
maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm,
stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N // 32 * bits) int32
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = (offs_am[:, None] < M)
# b_ptrs is set up such that it repeats elements along the K axis 8
# times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk +
offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit
# word from B
scales_ptrs = scales_ptr + offs_n[
None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits
) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for n in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused
# in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1)
a = tl.load(
a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit
# values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[
None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
except: # noqa: E722
print('triton not installed.')
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
"""matmul248 function with matmul_248_kernel."""
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]),
device=input.device,
dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv( # noqa: E731
input.shape[0], META['BLOCK_SIZE_M']) * triton. # noqa: E731
cdiv( # noqa: E731
qweight.shape[1], META['BLOCK_SIZE_N']), ) # noqa: E731
matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx,
input.shape[0], qweight.shape[1],
input.shape[1], bits, maxq, input.stride(0),
input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0),
output.stride(1), scales.stride(0),
qzeros.stride(0))
return output
def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
"""transpose_matmul248 function with transpose_matmul_248_kernel."""
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim),
device=input.device,
dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) # noqa: E731
* triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) # noqa: E731
transpose_matmul_248_kernel[grid](input, qweight, output, scales,
qzeros, g_idx, input.shape[0],
qweight.shape[1], output_dim,
bits, maxq, input.stride(0),
input.stride(1), qweight.stride(0),
qweight.stride(1), output.stride(0),
output.stride(1), scales.stride(0),
qzeros.stride(0))
return output
class QuantLinearFunction(torch.autograd.Function):
"""Custom QuantLinearFunction."""
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
"""Custom forward."""
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
"""Custom backward."""
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_matmul248(grad_output, qweight, scales,
qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
class TritonGPTQLinear(nn.Module, GPTQMixIn):
"""Custom Linear for GPTQ with custom triton kernel."""
def __init__(self, bits, groupsize, weight, in_features, out_features,
bias):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError('Only 2,4,8 bits are supported.')
self.weight = weight
self.bias = bias
self.in_features = in_features
self.out_features = out_features
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize if groupsize != -1 else in_features
self.register_buffer(
'qweight',
torch.zeros((in_features // 32 * self.bits, out_features),
dtype=torch.int32))
self.register_buffer(
'qzeros',
torch.zeros((math.ceil(
in_features / self.groupsize), out_features // 32 * self.bits),
dtype=torch.int32))
self.register_buffer(
'scales',
torch.zeros(
(math.ceil(in_features / self.groupsize), out_features),
dtype=torch.float16))
self.register_buffer(
'g_idx',
torch.tensor([i // self.groupsize for i in range(in_features)],
dtype=torch.int32))
self._gptq_mix_in_init()
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return True
@classmethod
def convert_from(cls, module: nn.Linear, bits, groupsize):
"""Convert to cls from torch's module."""
new_module = cls(
bits,
groupsize,
weight=module.weight,
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias)
return new_module
def forward(self, x):
"""Custom forward."""
if torch.all(self.qweight == 0):
out = F.linear(x, self.weight, self.bias)
else:
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
out = out.reshape(out_shape)
return out
class GPTQLinear(DynamicLinear, GPTQMixIn):
"""Custom Linear for GPTQ without custom triton kernel."""
def __init__(self, a_fakequant=None, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._gptq_mix_in_init()
self.a_fakequant = a_fakequant
self.fix_qparams = False
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return False
@classmethod
def convert_from(cls,
module: nn.Linear,
a_fakequant=None) -> 'DynamicLinear':
"""Convert to cls from torch's module."""
new_module = cls(
a_fakequant=a_fakequant,
in_features=module.in_features,
out_features=module.out_features,
bias=True if module.bias is not None else False)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
def forward(self, input: Tensor) -> Tensor:
"""Custom forward."""
if self.a_fakequant:
dtype = self.weight.dtype
if not self.fix_qparams:
self.a_fakequant.find_params(input)
input = self.a_fakequant.quantize(input).to(dtype)
return super().forward(input)
class GPTQConv2d(DynamicConv2d, GPTQMixIn):
"""Custom Conv2d for GPTQ without custom triton kernel."""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._gptq_mix_in_init()
@property
def is_custom_kernel(self):
"""Whether use custom kernel."""
return False
@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)
dtype = next(module.parameters()).dtype
new_module = new_module.to(dtype)
return new_module
def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
stride=self.stride) # B C D
return input.transpose(-1, -2)

View File

@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
class Quantizer(nn.Module):
"""Quantizer for some basic quantization functions."""
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer('maxq', torch.tensor(0))
self.register_buffer('scale', torch.zeros(shape))
self.register_buffer('zero', torch.zeros(shape))
def configure(self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=.8,
trits=False):
"""Configure qconfig."""
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
self.scale = torch.zeros_like(self.scale)
def _quantize(self, x, scale, zero, maxq):
"""Fakequant."""
if maxq < 0:
return (x > scale / 2).float() * scale + (x <
zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
def find_params(self, x, weight=False):
"""Observe the specified data and calculate the qparams."""
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float('inf'), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 /
scale1) if not self.sym else self.zero
q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1),
self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
"""Fakequant."""
if self.ready():
return self._quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
"""Whether is enabled."""
return self.maxq > 0
def ready(self):
"""Whether is ready."""
return torch.all(self.scale != 0)
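A small standalone sketch of the quantizer above, outside the GPTQ loop (the 4-bit asymmetric per-channel config and the random weight are only illustrative; the import path follows the package `__init__` earlier in this PR):

```python
import torch

from mmrazor.implementations.quantization.gptq import Quantizer

quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=False)

w = torch.randn(8, 32)                 # (out_features, in_features)
quantizer.find_params(w, weight=True)  # derive per-output-channel scale/zero
w_q = quantizer.quantize(w)            # fake-quantized weight, same shape

print('max rounding error:', (w - w_q).abs().max().item())
```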

View File

@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py # noqa: E501
def torch_snr_error(y_pred: torch.Tensor,
y_real: torch.Tensor,
reduction: str = 'mean') -> torch.Tensor:
"""Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calculted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of
SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
y_pred = y_pred.type(torch.float32)
y_real = y_real.type(torch.float32)
if y_pred.shape != y_real.shape:
raise ValueError(
f'Can not compute snr loss for tensors with different shape. '
f'({y_pred.shape} and {y_real.shape})')
reduction = str(reduction).lower()
if y_pred.ndim == 1:
y_pred = y_pred.unsqueeze(0)
y_real = y_real.unsqueeze(0)
y_pred = y_pred.flatten(start_dim=1)
y_real = y_real.flatten(start_dim=1)
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == 'mean':
return torch.mean(snr)
elif reduction == 'sum':
return torch.sum(snr)
elif reduction == 'none':
return snr
else:
raise ValueError('Unsupported reduction method.')
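A small usage sketch (the tensors are arbitrary, and `torch_snr_error` is assumed to be importable from wherever this file ends up in the package):

```python
import torch

y_real = torch.randn(4, 16)
y_pred = y_real + 0.01 * torch.randn(4, 16)  # e.g. a fake-quantized copy

print(torch_snr_error(y_pred, y_real))                    # scalar (reduction='mean')
print(torch_snr_error(y_pred, y_real, reduction='none'))  # per-sample SNR error
```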

View File

@ -49,7 +49,9 @@ _ConvMetadata = namedtuple('_ConvMetadata', [
'relu_qat', 'bn_qat', 'bn_relu_qat', 'func'
])
if digit_version(torch.__version__) >= digit_version('1.13.0'):
if digit_version(
torch.__version__) >= digit_version('1.13.0') and digit_version(
torch.__version__) <= digit_version('1.13.1'):
_Conv1dMetadata = _ConvMetadata(
nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d,
nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d,

View File

@ -7,7 +7,9 @@ from .native import get_native_backend_config
from .openvino import get_openvino_backend_config
from .tensorrt import get_tensorrt_backend_config
if digit_version(torch.__version__) >= digit_version('1.13.0'):
if digit_version(
torch.__version__) >= digit_version('1.13.0') and digit_version(
torch.__version__) <= digit_version('1.13.1'):
BackendConfigs = {
'academic': get_academic_backend_config(),
'native': get_native_backend_config(),

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import torch.distributed as dist
from mmengine import MMLogger
from mmengine import print_log as engine_print_log
@ -17,8 +18,15 @@ def get_level(level='info'):
return level
def print_log(msg, logger='current', level='info'):
engine_print_log(msg, logger, get_level(level))
def print_log(msg, logger='current', level='info', only_rank0=True):
if only_rank0 and dist.is_initialized():
if dist.get_rank() == 0:
engine_print_log(msg, logger, get_level(level))
else:
pass
else:
engine_print_log(msg, logger, get_level(level))
def set_log_level(level='debug'):

View File

@ -0,0 +1,42 @@
<div align="center">
<img src="../../resources/mmrazor-logo.png" width="600"/>
</div>
# MMRazor for Large Models
## Introduction
MMRazor is dedicated to the development of general-purpose model compression tools. It now supports not only conventional CV model compression but also large models. This project provides examples of compressing various large models with MMRazor, including LLaMA, Stable Diffusion, and more.
Overview of the code structure for large models:
```
mmrazor
├── implementations # core algorithm components
├── pruning
└── quantization
projects
└── mmrazor_large
├── algorithms # algorithms usage introduction
└── examples # examples for various models about algorithms
├── language_models
│ ├── LLaMA
│ └── OPT
└── ResNet
```
## Model-Algorithm Example Matrix
| | ResNet | OPT | LLaMA | Stable Diffusion |
| ------------------------------------ | ----------------------------------------------- | ------------------------------------------------------------ | -------------------------------------------------------------- | ---------------- |
| [SparseGPT](algorithms/SparseGPT.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
| [GPTQ](algorithms/GPTQ.md) | [:white_check_mark:](examples/ResNet/README.md) | [:white_check_mark:](examples/language_models/OPT/README.md) | [:white_check_mark:](examples/language_models/LLaMA/README.md) | |
## Paper List
We provide a paper list for researchers in the field of model compression for large models. If you want to add your paper to the list, please submit a PR.
| Paper | Title | Type | MMRazor |
| --------- | --------------------------------------------------------------------------------------------------------------------- | ------------ | --------------------------------------------- |
| SparseGPT | [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774) | Pruning | [:white_check_mark:](algorithms/SparseGPT.md) |
| GPTQ | [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) | Quantization | [:white_check_mark:](algorithms/GPTQ.md) |

View File

@ -0,0 +1,56 @@
# GPTQ
> [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323)
<!-- [ALGORITHM] -->
## Abstract
Generative Pre-trained Transformer models, known as GPT or OPT, set themselves apart through breakthrough performance across complex language modelling tasks, but also by their extremely high computational and storage costs. Specifically, due to their massive size, even inference for large, highly-accurate GPT models may require multiple performant GPUs, which limits the usability of such models. While there is emerging work on relieving this pressure via model compression, the applicability and performance of existing compression techniques is limited by the scale and complexity of GPT models. In this paper, we address this challenge, and propose GPTQ, a new one-shot weight quantization method based on approximate second-order information, that is both highly-accurate and highly-efficient. Specifically, GPTQ can quantize GPT models with 175 billion parameters in approximately four GPU hours, reducing the bitwidth down to 3 or 4 bits per weight, with negligible accuracy degradation relative to the uncompressed baseline. Our method more than doubles the compression gains relative to previously-proposed one-shot quantization methods, preserving accuracy, allowing us for the first time to execute an 175 billion-parameter model inside a single GPU for generative inference. Moreover, we also show that our method can still provide reasonable accuracy in the extreme quantization regime, in which weights are quantized to 2-bit or even ternary quantization levels. We show experimentally that these improvements can be leveraged for end-to-end inference speedups over FP16, of around 3.25x when using high-end GPUs (NVIDIA A100) and 4.5x when using more cost-effective ones (NVIDIA A6000). The implementation is available at https://github.com/IST-DASLab/gptq.
## Usage
GPTQ is easy to use in mmrazor. You can use it like this:
```python
from mmrazor.implementations.quantization import gptq
# initial model, dataloaders
model
train_loader, test_loader
## init gptq compressor and prepare for quantization
compressor = gptq.GPTQCompressor()
compressor.prepare(model)
## get hessian matrix
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
compressor.remove_hessian_hooks()
## quant
compressor.quant_with_default_qconfig()
## to a normal torch model
model = compressor.to_static_model(model)
```
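By default, `prepare` replaces `nn.Linear` with the Triton-backed `TritonGPTQLinear`, whose `convert_from` additionally expects the target bit-width and group size. A hedged sketch of passing them through (the 4-bit / group-size-128 values are only illustrative, `model` is the same placeholder as above, and a CUDA device with `triton` installed is assumed):

```python
compressor = gptq.GPTQCompressor()
compressor.prepare(
    model,
    quant_conv=True,
    quant_linear=True,
    use_triton_ops=True,  # the default; maps nn.Linear -> TritonGPTQLinear
    bits=4,               # forwarded to TritonGPTQLinear.convert_from
    groupsize=128)        # one scale/zero pair per 128 input channels
```

After calibration, `quant_with_default_qconfig()` runs GPTQ with a matching 4-bit asymmetric per-channel configuration and packs the result into the int32 `qweight`/`qzeros` buffers consumed by the Triton matmul kernel at inference time.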
## Full Examples
- [ResNet](../examples/ResNet/README.md)
- [LLaMA](../examples/language_models/LLaMA/README.md)
## Cite
```latex
@misc{
Frantar_Ashkboos_Hoefler_Alistarh_2022,
title={GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers},
author={Frantar, Elias and Ashkboos, Saleh and Hoefler, Torsten and Alistarh, Dan},
year={2022},
month={Oct},
language={en-US}
}
```

View File

@ -0,0 +1,55 @@
# SparseGPT
> [SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot](https://arxiv.org/abs/2301.00774)
<!-- [ALGORITHM] -->
## Abstract
We show for the first time that large-scale generative pretrained transformer (GPT) family models can be pruned to at least 50% sparsity in one-shot, without any retraining, at minimal loss of accuracy. This is achieved via a new pruning method called SparseGPT, specifically designed to work efficiently and accurately on massive GPT-family models. We can execute SparseGPT on the largest available open-source models, OPT-175B and BLOOM-176B, in under 4.5 hours, and can reach 60% unstructured sparsity with negligible increase in perplexity: remarkably, more than 100 billion weights from these models can be ignored at inference time. SparseGPT generalizes to semi-structured (2:4 and 4:8) patterns, and is compatible with weight quantization approaches.
## Usage
SparseGPT is easy to use in mmrazor. You can use it like this:
```python
from mmrazor.implementations.pruning import sparse_gpt
# initial model, dataloaders
model
train_loader, test_loader
## init sparse gpt compressor and prepare for pruning
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model)
## get hessian matrix
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
compressor.remove_hessian_hooks()
## prune
compressor.prune_24()
## to a normal torch model
model = compressor.to_static_model(model)
```
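When the model is too large to keep on a single GPU during calibration, the forward pass can be wrapped with the memory-offloading context manager shipped with the SparseGPT implementation. A hedged sketch (the module path is assumed from the utility imports used in this PR; `model`, `test_loader`, `infer` and `num_samples` are the same placeholders as above):

```python
import torch

# Path assumed from the sparse_gpt.utils imports used in this PR.
from mmrazor.implementations.pruning.sparse_gpt.utils import \
    memory_efficient_forward

compressor.register_hessian_hooks()
# Idle submodules stay on the CPU and are moved to cuda:0 on demand by
# forward hooks, keeping peak GPU memory low while Hessians are collected.
with memory_efficient_forward(model, enabled=True,
                              device=torch.device('cuda:0')):
    infer(model, test_loader, num_samples=num_samples)
compressor.remove_hessian_hooks()
```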
## Full Examples
- [ResNet](../examples/ResNet/README.md)
- [OPT](../examples/language_models/OPT/README.md)
- [LLaMA](../examples/language_models/LLaMA/README.md)
## Cite
```latex
@article{frantar2023massive,
title={Massive Language Models Can Be Accurately Pruned in One-Shot},
author={Frantar, Elias and Alistarh, Dan},
journal={arXiv preprint arXiv:2301.00774},
year={2023}
}
```

View File

@ -0,0 +1,25 @@
# Examples for ResNet
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../algorithms/SparseGPT.md)
### Usage
```shell
python projects/mmrazor_large/examples/ResNet/resnet18_sparse_gpt.py --data {imagenet_path} --batch_size 128 --num_samples 512
```
**Note**: the ImageNet folder is expected in the torchvision `ImageFolder` layout, with `train/` and `val/` subdirectories.
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../algorithms/GPTQ.md)
### Usage
```shell
python projects/mmrazor_large/examples/ResNet/resnet18_gptq.py --data {imagenet_path} --batch_size 128 --num_samples 512
```
**Note**: the ImageNet folder is expected in the torchvision `ImageFolder` layout, with `train/` and `val/` subdirectories.

View File

@ -0,0 +1,187 @@
# Copyright (c) OpenMMLab. All rights reserved.
# model settings
import os.path as osp
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from mmrazor.implementations.quantization.gptq import (GPTQCompressor,
GPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def get_dataloaders(batch_size, n_workers, path=''):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
osp.join(path, 'train'),
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
test_dataset = datasets.ImageFolder(
osp.join(path, 'val'),
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]),
)
dataloader_train = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=True,
)
dataloader_test = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_workers,
pin_memory=True,
)
return dataloader_train, dataloader_test
@torch.no_grad()
def eval(model: nn.Module,
dataloader_test: DataLoader,
device=torch.device('cuda:0'),
is_half=True):
total = 0
correct = 0
model.eval()
with torch.no_grad():
for x, y in dataloader_test:
x: torch.Tensor # type: ignore
y: torch.Tensor # type: ignore
x = x.to(device)
y = y.to(device)
if is_half:
x = x.half()
y = y.half()
outputs = model(x)
_, predicted = outputs.max(1)
correct += (y == predicted).long().sum()
total += y.numel()
acc = correct / total
return acc
@torch.no_grad()
def infer(model: nn.Module,
dataloader: torch.utils.data.DataLoader,
num_samples=256,
device=torch.device('cuda:0'),
is_half=True):
model.eval()
with torch.no_grad():
accumulate_batch = 0
for x, _ in dataloader:
x = x.to(device)
if is_half:
x = x.half()
model(x)
B = x.shape[0]
accumulate_batch += B
if accumulate_batch > num_samples:
break
if __name__ == '__main__':
import argparse
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
'--data',
type=str,
default='data/imagenet_torch',
help='path to imagenet in torch folder format')
arg_parser.add_argument(
'--num_samples',
type=int,
default=512,
help='number of samples to estimate hessian matrix')
arg_parser.add_argument(
'--batch_size',
type=int,
default=128,
help='batch size for evaluation and inference')
arg_parser.add_argument(
'--fp16',
action='store_true',
help='whether to use fp16 for evaluation and inference')
args = arg_parser.parse_args()
data_path = args.data
num_samples = args.num_samples
batch_size = args.batch_size
model = torchvision.models.resnet18(pretrained=True)
if args.fp16:
model = model.half()
train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
compressor = GPTQCompressor()
# # use_triton_ops is True
# compressor.prepare(model,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=True,
# skipped_layers=['conv1'],
# bits=4,
# groupsize=128)
# # quantize activation for linear
# a_qconfig = dict(bits=4, perchannel=True, sym=False)
compressor.prepare(
model,
quant_conv=True,
quant_linear=True,
use_triton_ops=False,
skipped_layers=['conv1'],
# a_qconfig=a_qconfig
)
model.cuda()
enable_observer_linear(model)
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples, is_half=args.fp16)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig()
print('start evaluation')
disable_observer_linear(model)
model = model.cuda()
acc = eval(model, test_loader, is_half=args.fp16)
print('accuracy:', acc.item())

View File

@ -0,0 +1,137 @@
# Copyright (c) OpenMMLab. All rights reserved.
# model settings
import os.path as osp
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from mmrazor.implementations.pruning import sparse_gpt
def get_dataloaders(batch_size, n_workers, path=''):
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
osp.join(path, 'train'),
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]),
)
test_dataset = datasets.ImageFolder(
osp.join(path, 'val'),
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]),
)
dataloader_train = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=n_workers,
pin_memory=True,
)
dataloader_test = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=n_workers,
pin_memory=True,
)
return dataloader_train, dataloader_test
@torch.no_grad()
def eval(model: nn.Module,
dataloader_test: DataLoader,
device=torch.device('cuda:0')):
total = 0
correct = 0
model.eval()
with torch.no_grad():
for x, y in dataloader_test:
x: torch.Tensor # type: ignore
y: torch.Tensor # type: ignore
x = x.to(device)
outputs = model(x)
_, predicted = outputs.max(1)
y = y.to(device)
correct += (y == predicted).long().sum()
total += y.numel()
acc = correct / total
return acc
@torch.no_grad()
def infer(model: nn.Module,
dataloader: torch.utils.data.DataLoader,
num_samples=256,
device=torch.device('cuda:0')):
model.eval()
with torch.no_grad():
accumulate_batch = 0
for x, _ in dataloader:
x = x.to(device)
model(x)
B = x.shape[0]
accumulate_batch += B
if accumulate_batch > num_samples:
break
if __name__ == '__main__':
import argparse
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
'--data',
type=str,
default='data/imagenet_torch',
help='path to imagenet in torch folder format')
arg_parser.add_argument(
'--num_samples',
type=int,
default=512,
help='number of samples to estimate hessian matrix')
arg_parser.add_argument(
'--batch_size',
type=int,
default=128,
help='batch size for evaluation and inference')
args = arg_parser.parse_args()
data_path = args.data
num_samples = args.num_samples
batch_size = args.batch_size
model = torchvision.models.resnet18(pretrained=True)
train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model)
model.cuda()
compressor.init_hessian()
compressor.register_hessian_hooks()
infer(model, test_loader, num_samples=num_samples)
compressor.remove_hessian_hooks()
compressor.prune_24()
model = compressor.to_static_model(model)
print('start evaluation')
model = model.cuda()
acc = eval(model, test_loader)
print('accuracy:', acc.item())

View File

@ -0,0 +1,55 @@
# Examples for LLaMA
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md)
### Usage
```shell
# example for decapoda-research/llama-7b-hf
python projects/mmrazor_large/examples/language_models/LLaMA/llama_sparse_gpt.py decapoda-research/llama-7b-hf c4
# help
usage: llama_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model Llama model to load
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
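The script follows a simple calibrate-then-prune flow around `SparseGptCompressor`. Below is a minimal sketch of that flow; the random `calib_data` batches (and the 32000 vocabulary size used to generate them) are placeholders for real calibration data, so treat this as an illustration of the API rather than a substitute for `llama_sparse_gpt.py`.
```python
import torch
from transformers.models.llama import LlamaForCausalLM

from mmrazor.implementations.pruning import sparse_gpt

model = LlamaForCausalLM.from_pretrained(
    'decapoda-research/llama-7b-hf', torch_dtype='auto')
model.eval()

compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)  # wrap the decoder layers' linear ops

# placeholder calibration batches of token ids
calib_data = [torch.randint(0, 32000, (1, 2048)) for _ in range(8)]

compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for batch in calib_data:
        model(batch)
compressor.remove_hessian_hooks()

compressor.prune_24()                      # 2:4 structured sparsity
model = compressor.to_static_model(model)  # back to plain torch modules
```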
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md).
### Usage
```shell
# example for decapoda-research/llama-7b-hf
python projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py decapoda-research/llama-7b-hf c4
# help
usage: llama_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model Llama model to load
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```

View File

@ -0,0 +1,152 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)
def get_wikitext2(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_ptb(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
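# For C4, each calibration sample is built by resampling documents from a single
# train shard until one is at least seqlen tokens long, then cutting a random
# seqlen-token window out of it.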
def get_c4(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
split='train')
valdata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
valenc = TokenizerWrapper(valenc)
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
if 'wikitext2' in name:
return get_wikitext2(nsamples, seed, seqlen, model)
if 'ptb' in name:
return get_ptb(nsamples, seed, seqlen, model)
if 'c4' in name:
return get_c4(nsamples, seed, seqlen, model)
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens
class LanguageDataset(TorchDataset):
def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
super().__init__()
# seq: 1, N
self.seq_len = seq_len
self.seq = fold_tokens(seq) # B N
def __len__(self) -> int:
return self.seq.shape[0]
def __getitem__(self, index):
return self.seq[index]
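# Fold the flat token stream into fixed-length rows and shard them across ranks
# with a DistributedSampler, so every rank evaluates a disjoint slice.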
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
distributed_sampler = DistributedSampler(
val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
batch_size = min(len(val_dataset) // world_size, batch_size)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
drop_last=True,
sampler=distributed_sampler)
return val_dataloader

View File

@ -0,0 +1,162 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datautils import get_loaders
from transformers.models.llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
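# Disabling the default random initializers speeds up model construction; every
# weight is overwritten by the checkpoint loaded in from_pretrained anyway.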
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
model,
torch_dtype='auto',
)
model.seqlen = 2048
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Device to use.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
# # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.layers,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model,
wrap_modules=[LlamaDecoderLayer],
enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
torch.save(model.state_dict(), args.save)

View File

@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datautils import get_loaders
from transformers.models.llama import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.utils import print_log
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model: LlamaForCausalLM = LlamaForCausalLM.from_pretrained(
model,
torch_dtype='auto',
)
model.seqlen = 2048
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
torch.set_default_dtype(torch.half)
DEV = torch.device('cuda:0')
model = get_model(args.model)
model.eval()
print_log('load model over')
dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.pruning import sparse_gpt
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)
compressor.init_hessian()
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.prune_24()
model = compressor.to_static_model(model)
if args.save:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
for dataset in ['wikitext2', 'ptb', 'c4']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)

View File

@ -0,0 +1,198 @@
import functools
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import build_language_loader, get_loaders
from llama_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp
from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12356'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
print_log(f'init {rank}/{world_size}', only_rank0=False)
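# FSDP shards are initialized from meta tensors. Rank 0 keeps a CPU copy of the
# model, copies its weights into each materialized module, and FSDP then
# broadcasts them to the other ranks via sync_module_states=True.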
def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):
def find_module_in_model_copy(module: nn.Module):
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])
name = module2name[module]
return dict(model_copy.named_modules())[name]
def _materialize_meta_module(module: nn.Module, ):
def meta_to_empty(p: torch.Tensor):
if p.device == torch.device('meta'):
return p.new_empty(p.shape, device='cpu')
else:
return p
module._apply(meta_to_empty)
if dist.get_rank() == 0:
assert model_copy is not None
module_copy = find_module_in_model_copy(module)
name2p = dict(module_copy.named_parameters(remove_duplicate=False))
for n, p in module.named_parameters():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
name2p = dict(module_copy.named_buffers(remove_duplicate=False))
for n, p in module.named_buffers():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
return _materialize_meta_module
def main(rank, world_size=8, args=None):
setup(rank, world_size)
model_name = args.model
batch_size = args.batch_size
def build():
model = get_model(model_name)
# init compressor
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.layers)
return model, compressor
with init_on_meta(enable=True):
model, compressor = build()
if rank == 0:
model_copy, _ = build() # init on cpu
else:
model_copy = None
# init fsdp
size_based_auto_wrap_policy_x = functools.partial(
size_based_auto_wrap_policy, min_num_params=int(1e8))
model = FSDP(
model,
auto_wrap_policy=size_based_auto_wrap_policy_x,
cpu_offload=CPUOffload(True),
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=rank,
param_init_fn=init_fn_wrapper(model, model_copy),
sync_module_states=True)
print_log(model)
# init hessian
compressor.init_hessian(device='cuda')
compressor.register_hessian_hooks()
_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)
compressor.remove_hessian_hooks()
# prune
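# For each FSDP unit, gather its full parameters on all ranks and apply 2:4
# (prunen=2, prunem=4) sparsity to every SparseGPT-wrapped op it owns.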
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])
with torch.no_grad():
for fsdp in FSDP.fsdp_modules(model):
fsdp._reset_lazy_init()
with FSDP.summon_full_params(fsdp, recurse=False):
fsdp_name = module2name[fsdp]
for name, op in fsdp.named_modules():
if name.count('_fsdp_wrapped_module') <= 1:
if isinstance(op, sparse_gpt.SparseGptMixIn):
try:
op.prune(0.5, prunen=2, prunem=4)
print_log(
f'prune {fsdp_name}.{name} successfully.', # noqa
only_rank0=True)
except Exception as e:
print_log(
f'prune {fsdp_name}.{name} failed, as {e}', # noqa
only_rank0=True)
fsdp._reset_lazy_init()
# save
if args.save:
print_log(f'save model in {args.save}')
model._reset_lazy_init()
with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
if dist.get_rank() == 0:
model.save_pretrained(args.save)
# val
torch.cuda.empty_cache()
model._reset_lazy_init()
for dataset in ['wikitext2', 'ptb', 'c4']:
_, testloader = get_loaders(
dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
print_log(dataset)
opt_eval_fsdp(model, testloader, torch.device('cuda'))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=64,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--world_size', type=int, default=1, help='Number of GPUs to use.')
args = parser.parse_args()
WORLD_SIZE = args.world_size
mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

View File

@ -0,0 +1,173 @@
# Copyright (c) OpenMMLab. All rights reserved.
# The opt_* evaluation helpers are adapted from https://github.com/ist-daslab/sparsegpt
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.utils.data import DataLoader
from transformers import OPTForCausalLM
from mmrazor.utils import print_log
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens
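# Perplexity evaluation: accumulate the token-level negative log-likelihood over
# all folded sequences and report exp(total NLL / total number of tokens).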
@torch.no_grad()
def opt_eval(model: OPTForCausalLM,
testenc,
dev=torch.device('cuda:0'),
batch_size=16):
print_log('Evaluating ...')
seqlen = model.seqlen
testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N
use_cache = model.config.use_cache
model.config.use_cache = False
nlls = []
for i, batch in enumerate(torch.split(testenc, batch_size)):
B = batch.shape[0]
batch = batch.to(dev)
out: torch.Tensor = model(batch)[0] # 1
shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
shift_labels = batch[:, 1:].flatten() # (B N)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
neg_log_likelihood = loss.float() * seqlen * B
nlls.append(neg_log_likelihood)
print_log(f'{(i+1)*batch_size} / {len(testenc)}')
ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
print_log(f'Perplexity: {ppl.item():3f}')
model.config.use_cache = use_cache
@torch.no_grad()
def opt_infer(
model: OPTForCausalLM,
testenc,
dev,
batch_size=16,
num_samples=128,
):
print_log('Infer ...')
seqlen = model.seqlen
testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N
use_cache = model.config.use_cache
model.config.use_cache = False
for i, batch in enumerate(torch.split(testenc, batch_size)):
batch = batch.to(dev)
_ = model(batch)[0] # 1
print_log(f'{(i+1)*batch_size} / {num_samples}')
if (i + 1) * batch_size >= num_samples:
break
model.config.use_cache = use_cache
class init_on_meta:
def __init__(self, enable=True) -> None:
self.enable = enable
self.default_device = torch.ones([]).device
def __enter__(self):
if self.enable:
torch.set_default_device('meta')
def __exit__(self, exc_type, exc_value, traceback):
if self.enable:
torch.set_default_device(self.default_device)
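# Distributed variant: each rank sums its own NLL and token count, then both are
# all-reduced so that every rank reports the global perplexity.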
@torch.no_grad()
def opt_eval_fsdp(
model: nn.Module,
dataloader: DataLoader,
dev=torch.device('cuda:0'),
):
print_log('Evaluating ...')
use_cache = model.config.use_cache
model.config.use_cache = False
loss_sum = torch.zeros([1], device=dev)
total_seq_len = torch.zeros([1], device=dev, dtype=torch.long)
for i, batch in enumerate(dataloader):
B, seq_len = batch.shape[:2]
batch = batch.to(dev)
out: torch.Tensor = model(batch)[0] # 1
shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
shift_labels = batch[:, 1:].flatten() # (B N)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
neg_log_likelihood = loss.float() * seq_len * B
total_seq_len += seq_len * B
loss_sum += neg_log_likelihood
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
infered_batch = (i + 1) * B * world_size
print_log(f'{infered_batch} / {len(dataloader.dataset)}')
if dist.is_initialized():
dist.all_reduce(loss_sum)
dist.all_reduce(total_seq_len)
ppl = torch.exp(loss_sum / total_seq_len)
print_log(f'Perplexity: {ppl.item():3f}')
model.config.use_cache = use_cache
@torch.no_grad()
def opt_infer_fsdp(
model: nn.Module,
dataloader: DataLoader,
dev=torch.device('cuda:0'),
num_samples=128,
):
print_log('Inferring ...')
model.config.use_cache = False
for i, batch in enumerate(dataloader):
B = batch.shape[0]
batch = batch.to(dev)
model(batch)[0] # 1
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
infered_batch = (i + 1) * B * world_size
print_log(f'{infered_batch} / {len(dataloader.dataset)}')
if infered_batch >= num_samples:
break

View File

@ -0,0 +1,55 @@
# Examples for OPT
## SparseGPT
For more details about SparseGPT, please refer to [SparseGPT](../../../algorithms/SparseGPT.md).
### Usage
```shell
# example for facebook/opt-125m
python projects/mmrazor_large/examples/language_models/OPT/opt_sparse_gpt.py facebook/opt-125m c4
# help
usage: opt_sparse_gpt.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model OPT model to load; pass `facebook/opt-X`.
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
## GPTQ
For more details about GPTQ, please refer to [GPTQ](../../../algorithms/GPTQ.md).
### Usage
```shell
# example for facebook/opt-125m
python projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py facebook/opt-125m c4
# help
usage: opt_gptq.py [-h] [--seed SEED] [--nsamples NSAMPLES] [--batch_size BATCH_SIZE] [--save SAVE] [-m M] model {wikitext2,ptb,c4}
positional arguments:
model OPT model to load; pass `facebook/opt-X`.
{wikitext2,ptb,c4} Where to extract calibration data from.
optional arguments:
-h, --help show this help message and exit
--seed SEED Seed for sampling the calibration data.
--nsamples NSAMPLES Number of calibration data samples.
--batch_size BATCH_SIZE
Batchsize for calibration and evaluation.
--save SAVE Path to saved model.
-m M Whether to enable memory efficient forward
```
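`opt_gptq.py` follows the same calibrate-then-compress pattern, but quantizes weights instead of pruning them. The sketch below outlines the `GPTQCompressor` API; the random `calib_data` batches (and the 50272 vocabulary size) are placeholders for real calibration data, so treat it as an outline rather than the exact script.
```python
import torch
from transformers import OPTForCausalLM

from mmrazor.implementations.quantization import gptq

model = OPTForCausalLM.from_pretrained('facebook/opt-125m', torch_dtype='auto')
model.eval()

compressor = gptq.GPTQCompressor()
compressor.prepare(
    model.model.decoder,
    quant_conv=True,
    quant_linear=True,
    use_triton_ops=True,  # pack weights for the 4-bit Triton kernels
    bits=4,
    groupsize=128)

# placeholder calibration batches of token ids
calib_data = [torch.randint(0, 50272, (1, 2048)) for _ in range(8)]

compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for batch in calib_data:
        model(batch)
compressor.remove_hessian_hooks()

compressor.quant_with_default_qconfig()
```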

View File

@ -0,0 +1,152 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)
def get_wikitext2(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['text']), return_tensors='pt')
testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_ptb(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc
def get_c4(nsamples, seed, seqlen, model):
from datasets import load_dataset
traindata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
split='train')
valdata = load_dataset(
'allenai/c4',
'allenai--c4',
data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
split='validation')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
import random
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] >= seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids
valenc = TokenizerWrapper(valenc)
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
if 'wikitext2' in name:
return get_wikitext2(nsamples, seed, seqlen, model)
if 'ptb' in name:
return get_ptb(nsamples, seed, seqlen, model)
if 'c4' in name:
return get_c4(nsamples, seed, seqlen, model)
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens
class LanguageDataset(TorchDataset):
def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
super().__init__()
# seq: 1, N
self.seq_len = seq_len
self.seq = fold_tokens(seq) # B N
def __len__(self) -> int:
return self.seq.shape[0]
def __getitem__(self, index):
return self.seq[index]
def build_language_loader(testloader, world_size, rank, model, batch_size=128):
val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
distributed_sampler = DistributedSampler(
val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
batch_size = min(len(val_dataset) // world_size, batch_size)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
drop_last=True,
sampler=distributed_sampler)
return val_dataloader

View File

@ -0,0 +1,157 @@
# Copyright (c) OpenMMLab. All rights reserved.
# The OPT example is adapted from https://github.com/ist-daslab/sparsegpt
import torch
from datautils import get_loaders
from transformers import OPTForCausalLM
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log
def enable_observer_linear(model):
print_log('Enable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = False
def disable_observer_linear(model):
print_log('Disable updating qparams for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, GPTQLinear):
module.fix_qparams = True
def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
model.seqlen = model.config.max_position_embeddings
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Device to use.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
DEV = args.dev
model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')
from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(
model.model.decoder,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)
# # # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.decoder,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')
compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):
# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
torch.save(model.state_dict(), args.save)

View File

@ -0,0 +1,105 @@
# Copyright (c) OpenMMLab. All rights reserved.
# The OPT example is adapted from https://github.com/ist-daslab/sparsegpt
import torch
from datautils import get_loaders
from transformers import OPTForCausalLM
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from utils import opt_eval, opt_infer
from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.utils import print_log
def get_model(model):
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
model.seqlen = model.config.max_position_embeddings
return model
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=64,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'-m',
type=bool,
default=False,
help='Whether to enable memory efficient forward')
args = parser.parse_args()
DEV = torch.device('cuda:0')
model = get_model(args.model)
model.eval()
print_log('load model over')
dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')
from mmrazor.implementations.pruning import sparse_gpt
compressor = sparse_gpt.SparseGptCompressor()
compressor.prepare(model.model.decoder)
compressor.init_hessian()
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.prune_24()
model = compressor.to_static_model(model)
if args.save:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
for dataset in ['wikitext2', 'ptb', 'c4']:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)

View File

@ -0,0 +1,198 @@
import functools
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import build_language_loader, get_loaders
from opt_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp
from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12356'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
print_log(f'init {rank}/{world_size}', only_rank0=False)
def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):
def find_module_in_model_copy(module: nn.Module):
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])
name = module2name[module]
return dict(model_copy.named_modules())[name]
def _materialize_meta_module(module: nn.Module, ):
def meta_to_empty(p: torch.Tensor):
if p.device == torch.device('meta'):
return p.new_empty(p.shape, device='cpu')
else:
return p
module._apply(meta_to_empty)
if dist.get_rank() == 0:
assert model_copy is not None
module_copy = find_module_in_model_copy(module)
name2p = dict(module_copy.named_parameters(remove_duplicate=False))
for n, p in module.named_parameters():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
name2p = dict(module_copy.named_buffers(remove_duplicate=False))
for n, p in module.named_buffers():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
return _materialize_meta_module
def main(rank, world_size=8, args=None):
setup(rank, world_size)
model_name = args.model
batch_size = args.batch_size
def build():
model = get_model(model_name)
# init mutator
mutator = sparse_gpt.SparseGptCompressor()
mutator.prepare(model.model.decoder)
return model, mutator
with init_on_meta(enable=True):
model, mutator = build()
if rank == 0:
model_copy, _ = build() # init on cpu
else:
model_copy = None
# init fsdp
size_based_auto_wrap_policy_x = functools.partial(
size_based_auto_wrap_policy, min_num_params=int(1e8))
model = FSDP(
model,
auto_wrap_policy=size_based_auto_wrap_policy_x,
cpu_offload=CPUOffload(True),
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=rank,
param_init_fn=init_fn_wrapper(model, model_copy),
sync_module_states=True)
print_log(model)
# init hessian
mutator.init_hessian(device='cuda')
mutator.register_hessian_hooks(model)
_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)
mutator.remove_hessian_hooks()
# prune
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])
with torch.no_grad():
for fsdp in FSDP.fsdp_modules(model):
fsdp._reset_lazy_init()
with FSDP.summon_full_params(fsdp, recurse=False):
fsdp_name = module2name[fsdp]
for name, op in fsdp.named_modules():
if name.count('_fsdp_wrapped_module') <= 1:
if isinstance(op, sparse_gpt.SparseGptMixIn):
try:
op.prune(0.5, prunen=2, prunem=4)
print_log(
f'prune {fsdp_name}.{name} successfully.', # noqa
only_rank0=True)
except Exception as e:
print_log(
f'prune {fsdp_name}.{name} failed, as {e}', # noqa
only_rank0=True)
fsdp._reset_lazy_init()
# save
if args.save:
print_log(f'save model in {args.save}')
model._reset_lazy_init()
with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
if dist.get_rank() == 0:
model.save_pretrained(args.save)
# val
torch.cuda.empty_cache()
model._reset_lazy_init()
for dataset in ['wikitext2', 'ptb', 'c4']:
_, testloader = get_loaders(
dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
print_log(dataset)
opt_eval_fsdp(model, testloader, torch.device('cuda'))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=64,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--world_size', type=int, default=1, help='Number of GPUs to use.')
args = parser.parse_args()
WORLD_SIZE = args.world_size
mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

View File

@ -0,0 +1,171 @@
# Copyright (c) OpenMMLab. All rights reserved.
# The opt_* evaluation helpers are adapted from https://github.com/ist-daslab/sparsegpt
import torch
import torch.nn as nn
from torch import distributed as dist
from torch.utils.data import DataLoader
from transformers import OPTForCausalLM
from mmrazor.utils import print_log
def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens
@torch.no_grad()
def opt_eval(model: OPTForCausalLM,
testenc,
dev=torch.device('cuda:0'),
batch_size=16):
print_log('Evaluating ...')
seqlen = model.seqlen
testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N
use_cache = model.config.use_cache
model.config.use_cache = False
nlls = []
for i, batch in enumerate(torch.split(testenc, batch_size)):
B = batch.shape[0]
batch = batch.to(dev)
out: torch.Tensor = model(batch)[0] # 1
shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
shift_labels = batch[:, 1:].flatten() # (B N)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
neg_log_likelihood = loss.float() * seqlen * B
nlls.append(neg_log_likelihood)
print_log(f'{(i+1)*batch_size} / {len(testenc)}')
ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
print_log(f'Perplexity: {ppl.item():3f}')
model.config.use_cache = use_cache
@torch.no_grad()
def opt_infer(
model: OPTForCausalLM,
testenc,
dev,
batch_size=16,
num_samples=128,
):
print_log('Infer ...')
seqlen = model.seqlen
testenc: torch.Tensor = testenc.input_ids # type: ignore # 1, N
testenc = fold_tokens(testenc, seqlen) # B N
model.config.use_cache = False
for i, batch in enumerate(torch.split(testenc, batch_size)):
batch = batch.to(dev)
_ = model(batch)[0] # 1
print_log(f'{(i+1)*batch_size} / {num_samples}')
if (i + 1) * batch_size >= num_samples:
break
class init_on_meta:
def __init__(self, enable=True) -> None:
self.enable = enable
self.default_device = torch.ones([]).device
def __enter__(self):
if self.enable:
torch.set_default_device('meta')
def __exit__(self, exc_type, exc_value, traceback):
if self.enable:
torch.set_default_device(self.default_device)
@torch.no_grad()
def opt_eval_fsdp(
model: nn.Module,
dataloader: DataLoader,
dev=torch.device('cuda:0'),
):
print_log('Evaluating ...')
use_cache = model.config.use_cache
model.config.use_cache = False
loss_sum = torch.zeros([1], device=dev)
total_seq_len = torch.zeros([1], device=dev, dtype=torch.long)
for i, batch in enumerate(dataloader):
B, seq_len = batch.shape[:2]
batch = batch.to(dev)
out: torch.Tensor = model(batch)[0] # 1
shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C
shift_labels = batch[:, 1:].flatten() # (B N)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
neg_log_likelihood = loss.float() * seq_len * B
total_seq_len += seq_len * B
loss_sum += neg_log_likelihood
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
infered_batch = (i + 1) * B * world_size
print_log(f'{infered_batch} / {len(dataloader.dataset)}')
if dist.is_initialized():
dist.all_reduce(loss_sum)
dist.all_reduce(total_seq_len)
ppl = torch.exp(loss_sum / total_seq_len)
print_log(f'Perplexity: {ppl.item():3f}')
model.config.use_cache = use_cache
@torch.no_grad()
def opt_infer_fsdp(
model: nn.Module,
dataloader: DataLoader,
dev=torch.device('cuda:0'),
num_samples=128,
):
print_log('Inferring ...')
model.config.use_cache = False
for i, batch in enumerate(dataloader):
B = batch.shape[0]
batch = batch.to(dev)
model(batch)[0] # 1
if dist.is_initialized():
world_size = dist.get_world_size()
else:
world_size = 1
infered_batch = (i + 1) * B * world_size
print_log(f'{infered_batch} / {len(dataloader.dataset)}')
if infered_batch >= num_samples:
break

View File

@ -7,5 +7,6 @@ nbformat
numpy < 1.24.0 # A temporary solution for tests with mmdet.
onnx
pytest
triton==2.0.0
xdoctest >= 0.10.0
yapf

View File

@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
import torch.nn as nn
from mmrazor import digit_version
from mmrazor.implementations.pruning import sparse_gpt
class TestSparseGptOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('torch<1.12.0')
def get_loss(linear, linear1, data):
y = linear(data)
y1 = linear1(data)
return (y - y1).square().sum()
def infer(model, dataset):
for x in dataset:
model(x)
for device in ['cpu']:
device = torch.device(device)
# prepare
linear = nn.Linear(12, 20, bias=False).to(device)
sparse_linear = sparse_gpt.SparseGptLinear(
12, 20, bias=False).to(device)
sparse_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([10, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
self.assertTrue(get_loss(linear, sparse_linear, data_0) == 0)
# prune
sparse_linear.init_hessian()
sparse_linear.register_hessian_hook()
infer(sparse_linear, random_data)
sparse_linear.remove_hessian_hook()
sparse_linear.prune(0.5)
# compare
print('norm:', linear(data_0).norm(2))
print('distance:', get_loss(linear, sparse_linear, data_0))
@torch.no_grad()
def test_model(self):
if digit_version(torch.__version__) < digit_version('1.12.0'):
self.skipTest('torch<1.12.0')
import torchvision
model = torchvision.models.resnet18()
mutator = sparse_gpt.SparseGptCompressor()
mutator.prepare(model)
x = torch.rand(10, 3, 224, 224)
mutator.init_hessian()
mutator.register_hessian_hooks()
model(x)
mutator.remove_hessian_hooks()
mutator.prune_24()
model = mutator.to_static_model(model)
assert type(model.conv1) is nn.Conv2d

View File

@ -0,0 +1,80 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
import torch.nn as nn
from mmrazor import digit_version
from mmrazor.implementations.quantization import gptq
class TestGPTQOps(unittest.TestCase):
@torch.no_grad()
def test_op(self):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
self.skipTest('torch<1.12.0 or cuda is not available')
def get_loss(linear, linear1, data):
y = linear(data)
y1 = linear1(data)
return (y - y1).square().sum()
def infer(model, dataset):
for x in dataset:
model(x)
for device in ['cpu']:
device = torch.device(device)
# prepare
linear = nn.Linear(12, 20, bias=False).to(device)
gptq_linear = gptq.GPTQLinear(
in_features=12, out_features=20, bias=False).to(device)
gptq_linear.load_state_dict(linear.state_dict(), strict=False)
random_data = torch.rand([10, 5, 12]).to(
device) # [loader_batch,batch,feature]
data_0 = random_data[0]
self.assertTrue(get_loss(linear, gptq_linear, data_0) == 0)
# quant
gptq_linear.init_hessian()
gptq_linear.register_hessian_hook()
infer(gptq_linear, random_data)
gptq_linear.remove_hessian_hook()
qconfig = dict(bits=4, perchannel=True, sym=False)
quantizer = gptq.Quantizer()
quantizer.configure(**qconfig)
gptq_linear.quant(quantizer=quantizer)
# compare
print('norm:', linear(data_0).norm(2))
print('distance:', get_loss(linear, gptq_linear, data_0))
@torch.no_grad()
def test_model(self):
if digit_version(torch.__version__) < digit_version(
'1.12.0') or not torch.cuda.is_available():
self.skipTest('torch<1.12.0 or cuda is not available')
import torchvision
model = torchvision.models.resnet18()
compressor = gptq.GPTQCompressor()
compressor.prepare(model, use_triton_ops=False)
x = torch.rand(10, 3, 224, 224)
compressor.init_hessian()
compressor.register_hessian_hooks()
model(x)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig()
model = compressor.to_static_model(model)
assert type(model.conv1) is nn.Conv2d