# Copyright (c) OpenMMLab. All rights reserved.
from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch
from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm2d,
                                                       DynamicConv2d,
                                                       DynamicLinear,
                                                       DynamicChannelMixin)
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutables import (OneShotMutableChannelUnit,
                                     SquentialMutableChannel,
                                     OneShotMutableChannel)
from mmrazor.registry import MODELS
from mmengine.model import BaseModel

# This file includes models for testing.

class LinearHead(Module):

    def __init__(self, in_channel, num_class=1000) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(in_channel, num_class)

    def forward(self, x):
        pool = self.pool(x).flatten(1)
        return self.linear(pool)

class MultiConcatModel(Module):
    """
        x--------------------------
        |op1       |op2       |op4
        x1         x2         x4
        |          |          |
        |cat--------          |
        cat1                  |
        |op3                  |
        x3                    |
        |cat-------------------
        cat2
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(16, 8, 1)
        self.op4 = nn.Conv2d(3, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)
        x4 = self.op4(x)
        cat2 = torch.cat([x3, x4], dim=1)
        x_pool = self.avg_pool(cat2).flatten(1)
        output = self.fc(x_pool)

        return output

class MultiConcatModel2(Module):
    """
        x-----------------------
        |op1       |op2       |op3
        x1         x2         x3
        |          |          |
        |cat--------          |
        cat1                  |
        |cat-------------------
        cat2
        |op4
        x4
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(3, 8, 1)
        self.op4 = nn.Conv2d(24, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        cat1 = torch.cat([x1, x2], dim=1)
        cat2 = torch.cat([cat1, x3], dim=1)
        x4 = self.op4(cat2)

        x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
        output = self.fc(x_pool)

        return output

class ConcatModel(Module):
    """
        x------------------
        |op1,bn1   |op2,bn2
        x1         x2
        |cat--------|
        cat1
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(16, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x))
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)

        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)

        return output

class ResBlock(Module):
    """
        x
        |op1,bn1
        x1-----------
        |op2,bn2    |
        x2          |
        +------------
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x1))
        x3 = self.op3(x2 + x1)
        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)
        return output

class LineModel(BaseModel):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)

class AddCatModel(Module):
    """
        x---------------------------------
        |op1       |op2       |op3       |op4
        x1         x2         x3         x4
        |          |          |          |
        |cat--------          |cat--------
        cat1                  cat2
        |                     |
        +----------------------
        x5
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 2, 3)
        self.op2 = nn.Conv2d(3, 6, 3)
        self.op3 = nn.Conv2d(3, 4, 3)
        self.op4 = nn.Conv2d(3, 4, 3)
        self.op5 = nn.Conv2d(8, 16, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)

        cat1 = torch.cat((x1, x2), dim=1)
        cat2 = torch.cat((x3, x4), dim=1)
        x5 = self.op5(cat1 + cat2)
        x_pool = self.avg_pool(x5).flatten(1)
        y = self.fc(x_pool)
        return y

class GroupWiseConvModel(nn.Module):
    """
        x
        |op1,bn1
        x1
        |op2,bn2
        x2
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x1 = self.bn1(x1)
        x2 = self.op2(x1)
        x2 = self.bn2(x2)
        x3 = self.op3(x2)
        x_pool = self.avg_pool(x3).flatten(1)
        return self.fc(x_pool)

class Xmodel(nn.Module):
    """
        x---------
        |op1     |op2
        x1       x2
        |        |
        +---------
        x12------
        |op3     |op4
        x3       x4
        |        |
        +---------
        x34
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x12 = x1 * x2
        x3 = self.op3(x12)
        x4 = self.op4(x12)
        x34 = x3 + x4
        x_pool = self.avg_pool(x34).flatten(1)
        return self.fc(x_pool)

class MultipleUseModel(nn.Module):
    """
        x---------------------------------------
        |conv0     |conv1     |conv2     |conv3
        xs.0       xs.1       xs.2       xs.3
        |convm     |convm     |convm     |convm
        xs_.0      xs_.1      xs_.2      xs_.3
        |          |          |          |
        cat-------------------------------
        |
        x_cat
        |conv_last
        feature
        |avg_pool
        pool
        |linear
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(32, 1000)

    def forward(self, x):
        xs = [
            conv(x)
            for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
        ]
        xs_ = [self.conv_multiple_use(x_) for x_ in xs]
        x_cat = torch.cat(xs_, dim=1)
        feature = self.conv_last(x_cat)
        pool = self.avg_pool(feature).flatten(1)
        return self.linear(pool)

class IcepBlock(nn.Module):
    """
        x---------------------------------
        |op1       |op2       |op3       |op4
        x1         x2         x3         x4
        |          |          |          |
        cat--------------------------------
        |
        y_
    """

    def __init__(self, in_c=3, out_c=32) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        # self.op5 = nn.Conv2d(out_c*4, out_c, 3)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)
        y_ = [x1, x2, x3, x4]
        y_ = torch.cat(y_, 1)
        return y_

class Icep(nn.Module):

    def __init__(self, num_icep_blocks=2) -> None:
        super().__init__()
        self.icps = nn.Sequential(*[
            IcepBlock(32 * 4 if i != 0 else 3, 32)
            for i in range(num_icep_blocks)
        ])
        self.op = nn.Conv2d(32 * 4, 32, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        y_ = self.icps(x)
        y = self.op(y_)
        pool = self.avg_pool(y).flatten(1)
        return self.fc(pool)

class ExpandLineModel(Module):
    """
        x
        |net0,net1,net2
        |net3,net4
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(2))
        self.linear = nn.Linear(64, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)

class MultiBindModel(Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
        self.head = LinearHead(8, 1000)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x12 = x1 + x2
        x3 = self.conv3(x12)
        x123 = x12 + x3
        return self.head(x123)

class DwConvModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
            nn.ReLU())
        self.head = LinearHead(48, 1000)

    def forward(self, x):
        return self.head(self.net(x))

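# The static models above are mainly consumed by MMRazor's pruning tests. The
# sketch below is illustrative only: the import path and the default arguments of
# ChannelMutator are assumptions, not something this file depends on.
def _example_parse_with_channel_mutator():
    from mmrazor.models.mutators import ChannelMutator

    # Trace a test model into mutable channel units before pruning.
    model = ResBlock()
    mutator = ChannelMutator()
    mutator.prepare_from_supernet(model)
    return mutator
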
# models with dynamicop

def register_mutable(module: DynamicChannelMixin,
                     mutable: OneShotMutableChannelUnit,
                     is_out=True,
                     start=0,
                     end=-1):
    if end == -1:
        end = mutable.num_channels + start
    if is_out:
        container: MutableChannelContainer = module.get_mutable_attr(
            'out_channels')
    else:
        container: MutableChannelContainer = module.get_mutable_attr(
            'in_channels')
    container.register_mutable(mutable, start, end)

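# A minimal usage sketch for the helper above (an illustrative assumption, not
# used by the tests): attach channel containers to a small dynamic-op network,
# then bind one mutable to the matching output containers.
def _example_register_mutable_usage():
    net = nn.Sequential(DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8))
    MutableChannelUnit._register_channel_container(net, MutableChannelContainer)
    mutable = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
    register_mutable(net[0], mutable, is_out=True)
    register_mutable(net[1], mutable, is_out=True)
    return net
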
class SampleExpandDerivedMutable(BaseMutable):

    def __init__(self, expand_ratio=1) -> None:
        super().__init__()
        self.ratio = expand_ratio

    def __mul__(self, other):
        if isinstance(other, OneShotMutableChannel):

            def _expand_mask():
                mask = other.current_mask
                mask = torch.unsqueeze(
                    mask,
                    -1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
                return mask

            return DerivedMutable(_expand_mask, _expand_mask, [self, other])
        else:
            raise NotImplementedError()

    def dump_chosen(self):
        return super().dump_chosen()

    def fix_chosen(self, chosen):
        return super().fix_chosen(chosen)

    def num_choices(self) -> int:
        return super().num_choices

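# SampleExpandDerivedMutable above derives a new channel mask by repeating each
# element of an OneShotMutableChannel mask `expand_ratio` times. A minimal,
# illustrative sketch of the combination used by DynamicLinearModel below:
def _example_expand_derived_mutable():
    channel_mutable = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
    expand_mutable = SampleExpandDerivedMutable(expand_ratio=1)
    # __mul__ returns a DerivedMutable whose mask tracks channel_mutable's mask.
    derived = expand_mutable * channel_mutable
    return derived
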
class DynamicLinearModel(nn.Module):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
            DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = DynamicLinear(16, 1000)

        MutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)
        self._register_mutable()

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)

    def _register_mutable(self):
        mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
        mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
        mutable_value = SampleExpandDerivedMutable(1)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[0], mutable1, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable_value * mutable1, False, 0, 8)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[4], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.linear, mutable2, False)

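# Illustrative check (an assumption about typical usage, not part of the original
# tests): after __init__, each dynamic op in DynamicLinearModel exposes a
# MutableChannelContainer through get_mutable_attr, as used by register_mutable.
def _example_dynamic_linear_model_containers():
    model = DynamicLinearModel()
    container = model.net[0].get_mutable_attr('out_channels')
    assert isinstance(container, MutableChannelContainer)
    return model
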
default_models = [
    LineModel,
    ResBlock,
    AddCatModel,
    ConcatModel,
    MultiConcatModel,
    MultiConcatModel2,
    GroupWiseConvModel,
    Xmodel,
    MultipleUseModel,
    Icep,
    ExpandLineModel,
    DwConvModel,
]

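# A smoke-test sketch (illustrative only) for the models collected above: every
# default model maps a batch of 3-channel images to 1000-way logits.
def _example_run_default_models():
    inputs = torch.rand(2, 3, 224, 224)
    for model_cls in default_models:
        outputs = model_cls()(inputs)
        assert outputs.shape == (2, 1000)
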
class ModelLibrary:

    # includes = [
    #     'alexnet',  # pass
    #     'densenet',  # pass
    #     # 'efficientnet', # pass
    #     # 'googlenet', # pass.
    #     # googlenet return a tuple when training,
    #     # so it should trace in eval mode
    #     # 'inception', # failed
    #     # 'mnasnet', # pass
    #     # 'mobilenet', # pass
    #     # 'regnet', # failed
    #     # 'resnet', # pass
    #     # 'resnext', # failed
    #     # 'shufflenet', # failed
    #     # 'squeezenet', # pass
    #     # 'vgg', # pass
    #     # 'wide_resnet', # pass
    # ]

    def __init__(self, include=[]) -> None:

        self.include_key = include

        self.model_creator = self.get_torch_models()

    def __repr__(self) -> str:
        s = f'model: {len(self.model_creator)}\n'
        for creator in self.model_creator:
            s += creator.__name__ + '\n'
        return s

    def get_torch_models(self):
        from inspect import isfunction

        import torchvision

        attrs = dir(torchvision.models)
        models = []
        for name in attrs:
            module = getattr(torchvision.models, name)
            if isfunction(module):
                models.append(module)
        return models

    def export_models(self):
        models = []
        for creator in self.model_creator:
            if self.is_include(creator.__name__):
                models.append(creator)
        return models

    def is_include(self, name):
        for key in self.include_key:
            if key in name:
                return True
        return False

    def include(self):
        include = []
        for creator in self.model_creator:
            for key in self.include_key:
                if key in creator.__name__:
                    include.append(creator)
        return include

    def uninclude(self):
        include = self.include()
        uninclude = []
        for creator in self.model_creator:
            if creator not in include:
                uninclude.append(creator)
        return uninclude

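# Typical ModelLibrary usage (illustrative sketch): pick the torchvision model
# constructors whose names contain the given keys, e.g. for tracer coverage tests.
def _example_model_library_usage():
    library = ModelLibrary(include=['resnet', 'vgg'])
    picked = library.export_models()  # constructors whose names match a key
    skipped = library.uninclude()  # the remaining torchvision constructors
    return picked, skipped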