mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
1074 lines
32 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
# This file includes models for testing.
import math
from collections import OrderedDict
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModel, Sequential
from torch import Tensor
from torch.nn import Module

from mmrazor.models.architectures.backbones.searchable_autoformer import \
    TransformerEncoderLayer
from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm2d,
                                                      DynamicChannelMixin,
                                                      DynamicConv2d,
                                                      DynamicLinear,
                                                      DynamicPatchEmbed,
                                                      DynamicSequential)
from mmrazor.models.architectures.ops.mobilenet_series import MBBlock
from mmrazor.models.architectures.utils.mutable_register import (
    mutate_conv_module, mutate_mobilenet_layer)
from mmrazor.models.mutables import (BaseMutable, DerivedMutable,
                                     MutableChannelUnit, OneShotMutableChannel,
                                     OneShotMutableChannelUnit,
                                     OneShotMutableValue)
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.utils.parse_values import parse_values
from mmrazor.registry import MODELS

# models to test fx tracer


def untracable_function(x: torch.Tensor):
    if x.sum() > 0:
        x = x - 1
    else:
        x = x + 1
    return x
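

# The fixtures above exist to exercise mmrazor's tracers on code that plain
# torch.fx cannot handle. Below is a minimal sketch (ours, not part of the
# original file) of the failure mode on `untracable_function`; the helper name
# `_demo_fx_trace_failure` is hypothetical and is never called at import time.
def _demo_fx_trace_failure():
    import torch.fx
    try:
        torch.fx.symbolic_trace(untracable_function)
    except Exception as err:
        # Expected: a TraceError, because `x.sum() > 0` is data-dependent
        # control flow and cannot be evaluated on an fx Proxy.
        print(f'torch.fx tracing failed as expected: {err}')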


class UntracableModule(nn.Module):

    def __init__(self, in_channel, out_channel) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        if x.sum() > 0:
            x = x * 2
        else:
            x = x * -2
        x = self.conv2(x)
        return x


class ModuleWithUntracableMethod(nn.Module):

    def __init__(self, in_channel, out_channel) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.untracable_method(x)
        x = self.conv2(x)
        return x

    def untracable_method(self, x):
        if x.sum() > 0:
            x = x * 2
        else:
            x = x * -2
        return x


@MODELS.register_module()
class UntracableBackBone(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 2)
        self.untracable_module = UntracableModule(16, 8)
        self.module_with_untracable_method = ModuleWithUntracableMethod(8, 16)

    def forward(self, x):
        x = self.conv(x)
        x = untracable_function(x)
        x = self.untracable_module(x)
        x = self.module_with_untracable_method(x)
        return x


class UntracableModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.backbone = UntracableBackBone()
        self.head = LinearHeadForTest(16, 1000)

    def forward(self, x):
        return self.head(self.backbone(x))
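

# A quick sanity-check sketch (ours, not in the original file): the whole
# "untracable" family is still an ordinary nn.Module, so eager execution works
# even though fx tracing does not. The input size here is only illustrative.
def _demo_untracable_model_forward():
    model = UntracableModel()
    out = model(torch.rand(2, 3, 32, 32))
    assert out.shape == (2, 1000)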


class ConvAttnModel(Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, 1, 1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        self.head = LinearHeadForTest(16, 1000)

    def forward(self, x):
        x1 = self.conv(x)
        attn = F.sigmoid(self.pool(x1))
        x_attn = x1 * attn
        x_last = self.conv2(x_attn)
        return self.head(x_last)


@MODELS.register_module()
class LinearHeadForTest(Module):

    def __init__(self, in_channel, num_class=1000) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(in_channel, num_class)

    def forward(self, x):
        pool = self.pool(x).flatten(1)
        return self.linear(pool)


class MultiConcatModel(Module):
    """
        x----------------
        |op1    |op2    |op4
        x1      x2      x4
        |       |       |
        |cat-----       |
        cat1            |
        |op3            |
        x3              |
        |cat-------------
        cat2
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(16, 8, 1)
        self.op4 = nn.Conv2d(3, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)
        x4 = self.op4(x)
        cat2 = torch.cat([x3, x4], dim=1)
        x_pool = self.avg_pool(cat2).flatten(1)
        output = self.fc(x_pool)

        return output


class MultiConcatModel2(Module):
    """
        x---------------
        |op1    |op2    |op3
        x1      x2      x3
        |       |       |
        |cat-----       |
        cat1            |
        |cat-------------
        cat2
        |op4
        x4
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(3, 8, 1)
        self.op4 = nn.Conv2d(24, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        cat1 = torch.cat([x1, x2], dim=1)
        cat2 = torch.cat([cat1, x3], dim=1)
        x4 = self.op4(cat2)

        x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
        output = self.fc(x_pool)

        return output


class ConcatModel(Module):
    """
        x------------
        |op1,bn1    |op2,bn2
        x1          x2
        |cat--------|
        cat1
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(16, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x))
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)

        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)

        return output


class ResBlock(Module):
    """
        x
        |op1,bn1
        x1-----------
        |op2,bn2    |
        x2          |
        +------------
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x1))
        x3 = self.op3(x2 + x1)
        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)
        return output


class SingleLineModel(nn.Module):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)


class AddCatModel(Module):
    """
        x------------------------
        |op1    |op2    |op3    |op4
        x1      x2      x3      x4
        |       |       |       |
        |cat-----       |cat-----
        cat1            cat2
        |               |
        +----------------
        x5
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 2, 3)
        self.op2 = nn.Conv2d(3, 6, 3)
        self.op3 = nn.Conv2d(3, 4, 3)
        self.op4 = nn.Conv2d(3, 4, 3)
        self.op5 = nn.Conv2d(8, 16, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)

        cat1 = torch.cat((x1, x2), dim=1)
        cat2 = torch.cat((x3, x4), dim=1)
        x5 = self.op5(cat1 + cat2)
        x_pool = self.avg_pool(x5).flatten(1)
        y = self.fc(x_pool)
        return y


class GroupWiseConvModel(nn.Module):
    """
        x
        |op1,bn1
        x1
        |op2,bn2
        x2
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x1 = self.bn1(x1)
        x2 = self.op2(x1)
        x2 = self.bn2(x2)
        x3 = self.op3(x2)
        x_pool = self.avg_pool(x3).flatten(1)
        return self.fc(x_pool)


class Xmodel(nn.Module):
    """
        x--------
        |op1    |op2
        x1      x2
        |       |
        +--------
        x12------
        |op3    |op4
        x3      x4
        |       |
        +--------
        x34
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x12 = x1 * x2
        x3 = self.op3(x12)
        x4 = self.op4(x12)
        x34 = x3 + x4
        x_pool = self.avg_pool(x34).flatten(1)
        return self.fc(x_pool)


class MultipleUseModel(nn.Module):
    """
        x------------------------
        |conv0  |conv1  |conv2  |conv3
        xs.0    xs.1    xs.2    xs.3
        |convm  |convm  |convm  |convm
        xs_.0   xs_.1   xs_.2   xs_.3
        |       |       |       |
        +------------------------
        |
        x_sum
        |conv_last
        feature
        |avg_pool
        pool
        |linear
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(32, 1000)

    def forward(self, x):
        xs = [
            conv(x)
            for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
        ]
        xs_ = [self.conv_multiple_use(x_) for x_ in xs]
        x_cat = torch.cat(xs_, dim=1)
        feature = self.conv_last(x_cat)
        pool = self.avg_pool(feature).flatten(1)
        return self.linear(pool)


class IcepBlock(nn.Module):
    """
        x------------------------
        |op1    |op2    |op3    |op4
        x1      x2      x3      x4
        |       |       |       |
        cat----------------------
        |
        y_
    """

    def __init__(self, in_c=3, out_c=32) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        # self.op5 = nn.Conv2d(out_c*4, out_c, 3)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)
        y_ = [x1, x2, x3, x4]
        y_ = torch.cat(y_, 1)
        return y_


class Icep(nn.Module):

    def __init__(self, num_icep_blocks=2) -> None:
        super().__init__()
        self.icps = nn.Sequential(*[
            IcepBlock(32 * 4 if i != 0 else 3, 32)
            for i in range(num_icep_blocks)
        ])
        self.op = nn.Conv2d(32 * 4, 32, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        y_ = self.icps(x)
        y = self.op(y_)
        pool = self.avg_pool(y).flatten(1)
        return self.fc(pool)


class ExpandLineModel(Module):
    """
        x
        |net0,net1,net2
        |net3,net4
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(2))
        self.linear = nn.Linear(64, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)


class MultiBindModel(Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
        self.head = LinearHeadForTest(8, 1000)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x12 = x1 + x2
        x3 = self.conv3(x12)
        x123 = x12 + x3
        return self.head(x123)


class DwConvModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
            nn.ReLU())
        self.head = LinearHeadForTest(48, 1000)

    def forward(self, x):
        return self.head(self.net(x))


class SelfAttention(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.stem = nn.Conv2d(3, 32, 4, 4, 4)

        self.num_head = 4
        self.qkv = nn.Linear(32, 32 * 3)
        self.proj = nn.Linear(32, 32)

        self.head = LinearHeadForTest(32, 1000)

    def forward(self, x: torch.Tensor):
        x = self.stem(x)
        h, w = x.shape[-2:]
        x = self._to_token(x)
        x = x + self._forward_attention(x)
        x = self._to_img(x, h, w)
        return self.head(x)

    def _to_img(self, x, h, w):
        x = x.reshape([x.shape[0], h, w, x.shape[2]])
        x = x.permute(0, 3, 1, 2)
        return x

    def _to_token(self, x):
        x = x.flatten(2).transpose(-1, -2)
        return x

    def _forward_attention(self, x: torch.Tensor):
        qkv = self.qkv(x)
        qkv = qkv.reshape([
            x.shape[0], x.shape[1], 3, self.num_head,
            x.shape[2] // self.num_head
        ]).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv
        attn = q @ k.transpose(-1, -2) / math.sqrt(32 // self.num_head)
        y = attn @ v  # B H N h
        y = y.permute(0, 2, 1, 3).flatten(-2)
        return self.proj(y)
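

# A short shape walkthrough (ours, for illustration only) of SelfAttention,
# since the token/attention reshaping above is easy to misread. Sizes assume a
# 64x64 input; any spatial size the stem conv accepts works the same way.
def _demo_self_attention_shapes():
    model = SelfAttention()
    x = torch.rand(2, 3, 64, 64)
    # stem Conv2d(3, 32, 4, 4, 4) -> (2, 32, 18, 18); tokens -> (2, 324, 32);
    # attention keeps (2, 324, 32); _to_img restores (2, 32, 18, 18);
    # LinearHeadForTest pools and projects to (2, 1000).
    out = model(x)
    assert out.shape == (2, 1000)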


def MMClsResNet18() -> BaseModel:
    model_cfg = dict(
        _scope_='mmcls',
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ))
    return MODELS.build(model_cfg)
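

# A minimal usage sketch (ours): building this config requires mmcls to be
# installed so that MODELS.build can resolve the 'mmcls' scope. The forward
# call below assumes mmcls' ImageClassifier default 'tensor' mode, which
# returns classification logits.
def _demo_mmcls_resnet18():
    model = MMClsResNet18()
    logits = model(torch.rand(1, 3, 224, 224))  # expected shape: (1, 1000)
    return logits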


# models with dynamicop


def register_mutable(module: DynamicChannelMixin,
                     mutable: MutableChannelUnit,
                     is_out=True,
                     start=0,
                     end=-1):
    if end == -1:
        end = mutable.num_channels + start
    if is_out:
        container: MutableChannelContainer = module.get_mutable_attr(
            'out_channels')
    else:
        container: MutableChannelContainer = module.get_mutable_attr(
            'in_channels')
    container.register_mutable(mutable, start, end)


class SampleExpandDerivedMutable(BaseMutable):

    def __init__(self, expand_ratio=1) -> None:
        super().__init__()
        self.ratio = expand_ratio

    def __mul__(self, other):
        if isinstance(other, OneShotMutableChannel):

            def _expand_mask():
                mask = other.current_mask
                mask = torch.unsqueeze(
                    mask,
                    -1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
                return mask

            return DerivedMutable(_expand_mask, _expand_mask, [self, other])
        else:
            raise NotImplementedError()

    def dump_chosen(self):
        return super().dump_chosen()

    def export_chosen(self):
        return super().export_chosen()

    def fix_chosen(self, chosen):
        return super().fix_chosen(chosen)

    def num_choices(self) -> int:
        return super().num_choices

    @property
    def current_choice(self):
        return super().current_choice

    @current_choice.setter
    def current_choice(self, choice):
        super().current_choice(choice)

class DynamicLinearModel(nn.Module):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
            DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = DynamicLinear(16, 1000)

        MutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)
        self._register_mutable()

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)

    def _register_mutable(self):
        mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
        mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
        mutable_value = SampleExpandDerivedMutable(1)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[0], mutable1, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable_value * mutable1, False, 0, 8)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[4], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.linear, mutable2, False)

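
# --- Illustrative usage (not part of the original test models) --------------
# Smoke-test sketch, assuming the default (maximal) channel choices and that
# forward works before any subnet is fixed: the adaptive pooling makes the
# spatial size irrelevant, so a small 32x32 input should already yield
# (1, 1000) logits. The helper name is hypothetical.
def _demo_dynamic_linear_model():
    model = DynamicLinearModel()
    out = model(torch.rand(1, 3, 32, 32))
    assert out.shape == (1, 1000)
    return out
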

class DynamicAttention(nn.Module):
    """
        x
        |blocks: DynamicSequential(depth)
        |(blocks)
        x1
        |fc (OneShotMutableChannel * OneShotMutableValue)
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.mutable_depth = OneShotMutableValue(
            value_list=[1, 2], default_value=2)
        self.mutable_embed_dims = OneShotMutableChannel(
            num_channels=624, candidate_choices=[576, 624])
        self.base_embed_dims = OneShotMutableChannel(
            num_channels=64, candidate_choices=[64])
        self.mutable_num_heads = [
            OneShotMutableValue(value_list=[8, 10], default_value=10)
            for _ in range(2)
        ]
        self.mutable_mlp_ratios = [
            OneShotMutableValue(value_list=[3.0, 3.5, 4.0], default_value=4.0)
            for _ in range(2)
        ]
        self.mutable_q_embed_dims = [
            i * self.base_embed_dims for i in self.mutable_num_heads
        ]

        self.patch_embed = DynamicPatchEmbed(
            img_size=224,
            in_channels=3,
            embed_dims=self.mutable_embed_dims.num_channels)

        # cls token and pos embed
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 197, self.mutable_embed_dims.num_channels))
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.mutable_embed_dims.num_channels))

        layers = []
        for i in range(self.mutable_depth.max_choice):
            layer = TransformerEncoderLayer(
                embed_dims=self.mutable_embed_dims.num_channels,
                num_heads=self.mutable_num_heads[i].max_choice,
                mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
            layers.append(layer)
        self.blocks = DynamicSequential(*layers)

        # OneShotMutableChannelUnit
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        self.register_mutables()

    def register_mutables(self):
        # mutablevalue
        self.blocks.register_mutable_attr('depth', self.mutable_depth)
        # mutablechannel
        MutableChannelContainer.register_mutable_channel_to_module(
            self.patch_embed, self.mutable_embed_dims, True)

        for i in range(self.mutable_depth.max_choice):
            layer = self.blocks[i]
            layer.register_mutables(
                mutable_num_heads=self.mutable_num_heads[i],
                mutable_mlp_ratios=self.mutable_mlp_ratios[i],
                mutable_q_embed_dims=self.mutable_q_embed_dims[i],
                mutable_head_dims=self.base_embed_dims,
                mutable_embed_dims=self.mutable_embed_dims)

    def forward(self, x: torch.Tensor):
        B = x.shape[0]
        x = self.patch_embed(x)
        embed_dims = self.mutable_embed_dims.current_choice
        cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[..., :embed_dims]
        x = self.blocks(x)
        return torch.mean(x[:, 1:], dim=1)

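
# --- Illustrative usage (not part of the original test models) --------------
# Sketch only: with the default (maximal) choices, a 224x224 image gives a
# 197-token sequence (196 patches plus the cls token) and a mean-pooled
# feature of width 624, the largest embed dim above. The helper name is
# hypothetical.
def _demo_dynamic_attention():
    model = DynamicAttention()
    feats = model(torch.rand(1, 3, 224, 224))
    assert feats.shape == (1, 624)
    return feats
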

class DynamicMMBlock(nn.Module):

    arch_setting = dict(
        kernel_size=[  # [min_kernel_size, max_kernel_size, step]
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
        ],
        num_blocks=[  # [min_num_blocks, max_num_blocks, step]
            [1, 2, 1],
            [3, 5, 1],
            [3, 6, 1],
            [3, 6, 1],
            [3, 8, 1],
            [3, 8, 1],
            [1, 2, 1],
        ],
        expand_ratio=[  # [min_expand_ratio, max_expand_ratio, step]
            [1, 1, 1],
            [4, 6, 1],
            [4, 6, 1],
            [4, 6, 1],
            [4, 6, 1],
            [6, 6, 1],
            [6, 6, 1],
        ],
        num_out_channels=[  # [min_channel, max_channel, step]
            [16, 24, 8],
            [24, 32, 8],
            [32, 40, 8],
            [64, 72, 8],
            [112, 128, 8],
            [192, 216, 8],
            [216, 224, 8],
        ])

    def __init__(
        self,
        conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
        norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
        fine_grained_mode: bool = False,
    ) -> None:
        super().__init__()

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_list = ['Swish'] * 7
        self.stride_list = [1, 2, 2, 2, 1, 2, 1]
        self.with_se_list = [False, False, True, False, True, True, True]
        self.kernel_size_list = parse_values(self.arch_setting['kernel_size'])
        self.num_blocks_list = parse_values(self.arch_setting['num_blocks'])
        self.expand_ratio_list = \
            parse_values(self.arch_setting['expand_ratio'])
        self.num_channels_list = \
            parse_values(self.arch_setting['num_out_channels'])
        assert len(self.kernel_size_list) == len(self.num_blocks_list) == \
            len(self.expand_ratio_list) == len(self.num_channels_list)

        self.fine_grained_mode = fine_grained_mode
        self.with_attentive_shortcut = True
        self.in_channels = 24

        self.first_out_channels_list = [16]
        self.first_conv = ConvModule(
            in_channels=3,
            out_channels=24,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='Swish'))

        self.layers = []
        for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
                enumerate(zip(self.num_blocks_list, self.kernel_size_list,
                              self.expand_ratio_list,
                              self.num_channels_list)):
            inverted_res_layer = self._make_single_layer(
                out_channels=num_channels,
                num_blocks=num_blocks,
                kernel_sizes=kernel_sizes,
                expand_ratios=expand_ratios,
                act_cfg=self.act_list[i],
                stride=self.stride_list[i],
                use_se=self.with_se_list[i])
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(inverted_res_layer)

        last_expand_channels = 1344
        self.out_channels = 1984
        self.last_out_channels_list = [1792, 1984]
        self.last_expand_ratio_list = [6]

        last_layers = Sequential(
            OrderedDict([('final_expand_layer',
                          ConvModule(
                              in_channels=self.in_channels,
                              out_channels=last_expand_channels,
                              kernel_size=1,
                              padding=0,
                              conv_cfg=self.conv_cfg,
                              norm_cfg=self.norm_cfg,
                              act_cfg=dict(type='Swish'))),
                         ('pool', nn.AdaptiveAvgPool2d((1, 1))),
                         ('feature_mix_layer',
                          ConvModule(
                              in_channels=last_expand_channels,
                              out_channels=self.out_channels,
                              kernel_size=1,
                              padding=0,
                              bias=False,
                              conv_cfg=self.conv_cfg,
                              norm_cfg=None,
                              act_cfg=dict(type='Swish')))]))
        self.add_module('last_conv', last_layers)
        self.layers.append(last_layers)

        self.register_mutables()

    def _make_single_layer(self, out_channels, num_blocks, kernel_sizes,
                           expand_ratios, act_cfg, stride, use_se):
        _layers = []
        for i in range(max(num_blocks)):
            if i >= 1:
                stride = 1
            if use_se:
                se_cfg = dict(
                    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')),
                    ratio=4,
                    conv_cfg=self.conv_cfg)
            else:
                se_cfg = None  # type: ignore

            mb_layer = MBBlock(
                in_channels=self.in_channels,
                out_channels=max(out_channels),
                kernel_size=max(kernel_sizes),
                stride=stride,
                expand_ratio=max(expand_ratios),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act_cfg),
                se_cfg=se_cfg,
                with_attentive_shortcut=self.with_attentive_shortcut)

            _layers.append(mb_layer)
            self.in_channels = max(out_channels)

        dynamic_seq = DynamicSequential(*_layers)
        return dynamic_seq

    def register_mutables(self):
        """Mutate the BigNAS-style MobileNetV3."""
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        self.first_mutable_channels = OneShotMutableChannel(
            alias='backbone.first_channels',
            num_channels=max(self.first_out_channels_list),
            candidate_choices=self.first_out_channels_list)

        mutate_conv_module(
            self.first_conv, mutable_out_channels=self.first_mutable_channels)

        mid_mutable = self.first_mutable_channels
        # mutate the built mobilenet layers
        for i, layer in enumerate(self.layers[:-1]):
            num_blocks = self.num_blocks_list[i]
            kernel_sizes = self.kernel_size_list[i]
            expand_ratios = self.expand_ratio_list[i]
            out_channels = self.num_channels_list[i]

            prefix = 'backbone.layers.' + str(i + 1) + '.'

            mutable_out_channels = OneShotMutableChannel(
                alias=prefix + 'out_channels',
                candidate_choices=out_channels,
                num_channels=max(out_channels))

            if not self.fine_grained_mode:
                mutable_kernel_size = OneShotMutableValue(
                    alias=prefix + 'kernel_size', value_list=kernel_sizes)

                mutable_expand_ratio = OneShotMutableValue(
                    alias=prefix + 'expand_ratio', value_list=expand_ratios)

            mutable_depth = OneShotMutableValue(
                alias=prefix + 'depth', value_list=num_blocks)
            layer.register_mutable_attr('depth', mutable_depth)

            for k in range(max(self.num_blocks_list[i])):

                if self.fine_grained_mode:
                    mutable_kernel_size = OneShotMutableValue(
                        alias=prefix + str(k) + '.kernel_size',
                        value_list=kernel_sizes)

                    mutable_expand_ratio = OneShotMutableValue(
                        alias=prefix + str(k) + '.expand_ratio',
                        value_list=expand_ratios)

                mutate_mobilenet_layer(layer[k], mid_mutable,
                                       mutable_out_channels,
                                       mutable_expand_ratio,
                                       mutable_kernel_size)
                mid_mutable = mutable_out_channels

        self.last_mutable_channels = OneShotMutableChannel(
            alias='backbone.last_channels',
            num_channels=self.out_channels,
            candidate_choices=self.last_out_channels_list)

        last_mutable_expand_value = OneShotMutableValue(
            value_list=self.last_expand_ratio_list,
            default_value=max(self.last_expand_ratio_list))

        derived_expand_channels = mid_mutable * last_mutable_expand_value
        mutate_conv_module(
            self.layers[-1].final_expand_layer,
            mutable_in_channels=mid_mutable,
            mutable_out_channels=derived_expand_channels)
        mutate_conv_module(
            self.layers[-1].feature_mix_layer,
            mutable_in_channels=derived_expand_channels,
            mutable_out_channels=self.last_mutable_channels)

    def forward(self, x):
        x = self.first_conv(x)
        for _, layer in enumerate(self.layers):
            x = layer(x)

        return tuple([x])
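
# --- Illustrative usage (not part of the original test models) --------------
# Sketch only: the backbone returns a single-element tuple, and with the
# default (maximal) choices the final 1x1 conv after global pooling should
# emit ``out_channels`` (1984) channels. The helper name is hypothetical.
def _demo_dynamic_mm_block():
    model = DynamicMMBlock()
    feats = model(torch.rand(1, 3, 224, 224))
    assert feats[0].shape[1] == model.out_channels
    return feats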