mmrazor/tests/data/models.py
qiufeng b0b3fbdb49
[Feature] Add BigNAS algorithm (#219)
* add calibrate-bn-statistics

* add test calibrate-bn-statistics

* fix mixins

* fix mixins

* fix mixin tests

* remove slimmable channel mutable and refactor dynamic op

* refactor dynamic batch norm

* add progressive dynamic conv2d

* add center crop dynamic conv2d

* refactor dynamic directory

* refactor dynamic sequential

* rename length to depth in dynamic sequential

* add test for derived mutable

* refactor dynamic op

* refactor api of dynamic op

* add derive mutable mixin

* add bignas algorithm

* refactor bignas structure

* add input resizer

* add input resizer to bignas

* move input resizer from algorithm into classifier

* remove components

* add attentive mobilenet

* delete json file

* nearly (within 0.2) align inference accuracy with gml

* move mutate separately into bignas mobilenet backbone

* add zero_init_residual

* add set_dropout

* set dropout in bignas algorithm

* fix registry

* add subnet yaml and nearly align inference accuracy with gml

* add rsb config for bignas

* remove base in config

* add gml bignas config

* convert to iter based

* bignas forward and backward fly

* fix merge conflict

* fix dynamicseq bug

* fix bug and refactor bignas

* arrange configs of bignas

* fix typo

* refactor attentive_mobilenet

* fix channel mismatch due to registration of DerivedMutable

* update bignas & fix se channel mismatch

* add AutoAugmentV2 & remove unnecessary configs

* fix lint

* recover channel assertion in channel unit

* fix a group bug

* fix comments

* add docstring

* add norm in dynamic_embed

* fix search loop & other minor changes

* fix se expansion

* minor change

* add ut for bignas & attentive_mobilenet

* fix ut

* update bignas readme

* remove unnecessary ut & supplement get_placeholder

* fix lint

* fix ut

* add subnet deployment in downstream tasks.

* minor change

* update ofa backbone

* minor fix

* Continued improvements of searchable backbone

* minor change

* drop ratio in backbone

* fix comments

* fix ci test

* fix test

* add dynamic shortcut UT

* modify strategy to fit bignas

* fix test

* fix bug in neck

* fix error

* fix error

* fix yaml

* save subnet ckpt

* merge autoslim_val/test_loop into subnet_val_loop

* move calibrate_bn_mixin to utils

* fix bugs and add docstring

* clean code

* fix register bug

* clean code

* update

Co-authored-by: wangshiguang <wangshiguang@sensetime.com>
Co-authored-by: gaoyang07 <1546308416@qq.com>
Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>
2022-12-07 11:28:10 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import Dict
from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch
from mmengine.model import BaseModel
from mmrazor.models.architectures.dynamic_ops import (
    DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin,
    DynamicPatchEmbed, DynamicSequential)
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutables import OneShotMutableChannelUnit, OneShotMutableChannel
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
from mmrazor.models.utils.parse_values import parse_values
from mmrazor.models.architectures.ops.mobilenet_series import MBBlock
from mmcv.cnn import ConvModule
from mmengine.model import Sequential
from mmrazor.models.architectures.utils.mutable_register import (
mutate_conv_module, mutate_mobilenet_layer)
class LinearHead(Module):
def __init__(self, in_channel, num_class=1000) -> None:
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(in_channel, num_class)
def forward(self, x):
pool = self.pool(x).flatten(1)
return self.linear(pool)
class MultiConcatModel(Module):
"""
x----------------
|op1 |op2 |op4
x1 x2 x4
| | |
|cat----- |
cat1 |
|op3 |
x3 |
|cat-------------
cat2
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(16, 8, 1)
self.op4 = nn.Conv2d(3, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x4 = self.op4(x)
cat2 = torch.cat([x3, x4], dim=1)
x_pool = self.avg_pool(cat2).flatten(1)
output = self.fc(x_pool)
return output
class MultiConcatModel2(Module):
"""
x---------------
|op1 |op2 |op3
x1 x2 x3
| | |
|cat----- |
cat1 |
|cat-------------
cat2
|op4
x4
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(3, 8, 1)
self.op4 = nn.Conv2d(24, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
cat1 = torch.cat([x1, x2], dim=1)
cat2 = torch.cat([cat1, x3], dim=1)
x4 = self.op4(cat2)
x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
output = self.fc(x_pool)
return output
class ConcatModel(Module):
"""
x------------
|op1,bn1 |op2,bn2
x1 x2
|cat--------|
cat1
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(3, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(16, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x))
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
class ResBlock(Module):
"""
x
|op1,bn1
x1-----------
|op2,bn2 |
x2 |
+------------
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(8, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x1))
x3 = self.op3(x2 + x1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
class LineModel(BaseModel):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class AddCatModel(Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
|cat----- |cat-----
cat1 cat2
| |
+----------------
x5
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 2, 3)
self.op2 = nn.Conv2d(3, 6, 3)
self.op3 = nn.Conv2d(3, 4, 3)
self.op4 = nn.Conv2d(3, 4, 3)
self.op5 = nn.Conv2d(8, 16, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
cat1 = torch.cat((x1, x2), dim=1)
cat2 = torch.cat((x3, x4), dim=1)
x5 = self.op5(cat1 + cat2)
x_pool = self.avg_pool(x5).flatten(1)
y = self.fc(x_pool)
return y
class GroupWiseConvModel(nn.Module):
"""
x
|op1,bn1
x1
|op2,bn2
x2
|op3
x3
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
self.bn2 = nn.BatchNorm2d(16)
self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
x1 = self.op1(x)
x1 = self.bn1(x1)
x2 = self.op2(x1)
x2 = self.bn2(x2)
x3 = self.op3(x2)
x_pool = self.avg_pool(x3).flatten(1)
return self.fc(x_pool)
class Xmodel(nn.Module):
"""
x--------
|op1 |op2
x1 x2
| |
+--------
x12------
|op3 |op4
x3 x4
| |
+--------
x34
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x12 = x1 * x2
x3 = self.op3(x12)
x4 = self.op4(x12)
x34 = x3 + x4
x_pool = self.avg_pool(x34).flatten(1)
return self.fc(x_pool)
class MultipleUseModel(nn.Module):
"""
x------------------------
|conv0 |conv1 |conv2 |conv3
xs.0 xs.1 xs.2 xs.3
|convm |convm |convm |convm
xs_.0 xs_.1 xs_.2 xs_.3
| | | |
+------------------------
|
x_sum
|conv_last
feature
|avg_pool
pool
|linear
output
"""
def __init__(self) -> None:
super().__init__()
self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(32, 1000)
def forward(self, x):
xs = [
conv(x)
for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
]
xs_ = [self.conv_multiple_use(x_) for x_ in xs]
x_cat = torch.cat(xs_, dim=1)
feature = self.conv_last(x_cat)
pool = self.avg_pool(feature).flatten(1)
return self.linear(pool)
class IcepBlock(nn.Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
cat----------------------
|
y_
"""
def __init__(self, in_c=3, out_c=32) -> None:
super().__init__()
self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
# self.op5 = nn.Conv2d(out_c*4, out_c, 3)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
y_ = [x1, x2, x3, x4]
y_ = torch.cat(y_, 1)
return y_
class Icep(nn.Module):
def __init__(self, num_icep_blocks=2) -> None:
super().__init__()
self.icps = nn.Sequential(*[
IcepBlock(32 * 4 if i != 0 else 3, 32)
for i in range(num_icep_blocks)
])
self.op = nn.Conv2d(32 * 4, 32, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
y_ = self.icps(x)
y = self.op(y_)
pool = self.avg_pool(y).flatten(1)
return self.fc(pool)
class ExpandLineModel(Module):
"""
x
|net0,net1,net2
|net3,net4
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(2))
self.linear = nn.Linear(64, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class MultiBindModel(Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
self.head = LinearHead(8, 1000)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x12 = x1 + x2
x3 = self.conv3(x12)
x123 = x12 + x3
return self.head(x123)
class DwConvModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
nn.ReLU())
self.head = LinearHead(48, 1000)
def forward(self, x):
return self.head(self.net(x))
# models with dynamicop
def register_mutable(module: DynamicChannelMixin,
mutable: MutableChannelUnit,
is_out=True,
start=0,
end=-1):
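    """Attach ``mutable`` to the in/out channel container of a dynamic module.

    ``start``/``end`` give the channel range the mutable covers; ``end=-1``
    means ``start + mutable.num_channels``.
    """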
if end == -1:
end = mutable.num_channels + start
if is_out:
container: MutableChannelContainer = module.get_mutable_attr(
'out_channels')
else:
container: MutableChannelContainer = module.get_mutable_attr(
'in_channels')
container.register_mutable(mutable, start, end)
class SampleExpandDerivedMutable(BaseMutable):
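    """Toy mutable used in tests: multiplying it with a
    ``OneShotMutableChannel`` yields a ``DerivedMutable`` whose mask repeats
    every channel entry ``expand_ratio`` times."""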
def __init__(self, expand_ratio=1) -> None:
super().__init__()
self.ratio = expand_ratio
def __mul__(self, other):
if isinstance(other, OneShotMutableChannel):
def _expand_mask():
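                # repeat every mask entry `ratio` times along a new axis,
                # then flatten back to a 1-D channel mask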
mask = other.current_mask
mask = torch.unsqueeze(
mask,
-1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
return mask
return DerivedMutable(_expand_mask, _expand_mask, [self, other])
else:
raise NotImplementedError()
def dump_chosen(self):
return super().dump_chosen()
def export_chosen(self):
return super().export_chosen()
def fix_chosen(self, chosen):
return super().fix_chosen(chosen)
def num_choices(self) -> int:
return super().num_choices
@property
def current_choice(self):
return super().current_choice
@current_choice.setter
def current_choice(self, choice):
super().current_choice(choice)
class DynamicLinearModel(nn.Module):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = DynamicLinear(16, 1000)
MutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self._register_mutable()
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
def _register_mutable(self):
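        # Wire one mutable per conv/bn pair: `mutable1` controls the first
        # conv/bn width and, expanded, the input of the second conv;
        # `mutable2` controls the second conv/bn width and the linear input.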
mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
mutable_value = SampleExpandDerivedMutable(1)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[0], mutable1, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable_value * mutable1, False, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[4], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.linear, mutable2, False)
class DynamicAttention(nn.Module):
"""
x
|blocks: DynamicSequential(depth)
|(blocks)
x1
|fc (OneShotMutableChannel * OneShotMutableValue)
output
"""
def __init__(self) -> None:
super().__init__()
self.mutable_depth = OneShotMutableValue(
value_list=[1, 2], default_value=2)
self.mutable_embed_dims = OneShotMutableChannel(
num_channels=624, candidate_choices=[576, 624])
self.base_embed_dims = OneShotMutableChannel(
num_channels=64, candidate_choices=[64])
self.mutable_num_heads = [
OneShotMutableValue(
value_list=[8, 10],
default_value=10)
for _ in range(2)
]
self.mutable_mlp_ratios = [
OneShotMutableValue(
value_list=[3.0, 3.5, 4.0],
default_value=4.0)
for _ in range(2)
]
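        # query embed dims are derived mutables: num_heads * base head dim (64)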
self.mutable_q_embed_dims = [
i * self.base_embed_dims for i in self.mutable_num_heads
]
self.patch_embed = DynamicPatchEmbed(
img_size=224,
in_channels=3,
embed_dims=self.mutable_embed_dims.num_channels)
# cls token and pos embed
self.pos_embed = nn.Parameter(
torch.zeros(1, 197,
self.mutable_embed_dims.num_channels))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
layers = []
for i in range(self.mutable_depth.max_choice):
layer = TransformerEncoderLayer(
embed_dims=self.mutable_embed_dims.num_channels,
num_heads=self.mutable_num_heads[i].max_choice,
mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
layers.append(layer)
self.blocks = DynamicSequential(*layers)
# OneShotMutableChannelUnit
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.register_mutables()
def register_mutables(self):
# mutablevalue
self.blocks.register_mutable_attr('depth', self.mutable_depth)
# mutablechannel
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)
for i in range(self.mutable_depth.max_choice):
layer = self.blocks[i]
layer.register_mutables(
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
mutable_head_dims=self.base_embed_dims,
mutable_embed_dims=self.mutable_embed_dims)
def forward(self, x: torch.Tensor):
B = x.shape[0]
x = self.patch_embed(x)
embed_dims = self.mutable_embed_dims.current_choice
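        # slice cls token and pos embed to the currently sampled embed dim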
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[..., :embed_dims]
x = self.blocks(x)
return torch.mean(x[:, 1:], dim=1)
class DynamicMMBlock(nn.Module):
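    """Searchable MobileNet-like backbone used as test data: each stage's
    kernel size, number of blocks, expand ratio and output width is sampled
    from ``arch_setting``."""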
arch_setting = dict(
kernel_size=[ # [min_kernel_size, max_kernel_size, step]
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
],
num_blocks=[ # [min_num_blocks, max_num_blocks, step]
[1, 2, 1],
[3, 5, 1],
[3, 6, 1],
[3, 6, 1],
[3, 8, 1],
[3, 8, 1],
[1, 2, 1],
],
expand_ratio=[ # [min_expand_ratio, max_expand_ratio, step]
[1, 1, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[6, 6, 1],
[6, 6, 1]
],
num_out_channels=[ # [min_channel, max_channel, step]
[16, 24, 8],
[24, 32, 8],
[32, 40, 8],
[64, 72, 8],
[112, 128, 8],
[192, 216, 8],
[216, 224, 8]
])
def __init__(
self,
conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
) -> None:
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_list = ['Swish'] * 7
self.stride_list = [1, 2, 2, 2, 1, 2, 1]
self.with_se_list = [False, False, True, False, True, True, True]
self.kernel_size_list = parse_values(self.arch_setting['kernel_size'])
self.num_blocks_list = parse_values(self.arch_setting['num_blocks'])
self.expand_ratio_list = \
parse_values(self.arch_setting['expand_ratio'])
self.num_channels_list = \
parse_values(self.arch_setting['num_out_channels'])
assert len(self.kernel_size_list) == len(self.num_blocks_list) == \
len(self.expand_ratio_list) == len(self.num_channels_list)
self.with_attentive_shortcut = True
self.in_channels = 24
self.first_conv = ConvModule(
in_channels=3,
out_channels=24,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))
        self.last_mutable = OneShotMutableChannel(
            num_channels=24, candidate_choices=[16, 24])
self.layers = []
for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
enumerate(zip(self.num_blocks_list, self.kernel_size_list,
self.expand_ratio_list, self.num_channels_list)):
inverted_res_layer = self._make_single_layer(
out_channels=num_channels,
num_blocks=num_blocks,
kernel_sizes=kernel_sizes,
expand_ratios=expand_ratios,
act_cfg=self.act_list[i],
stride=self.stride_list[i],
use_se=self.with_se_list[i])
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(inverted_res_layer)
last_expand_channels = 1344
self.out_channels = 1984
self.last_out_channels_list = [1792, 1984]
self.last_expand_ratio_list = [6]
last_layers = Sequential(
OrderedDict([('final_expand_layer',
ConvModule(
in_channels=self.in_channels,
out_channels=last_expand_channels,
kernel_size=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))),
('pool', nn.AdaptiveAvgPool2d((1, 1))),
('feature_mix_layer',
ConvModule(
in_channels=last_expand_channels,
out_channels=self.out_channels,
kernel_size=1,
padding=0,
bias=False,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=dict(type='Swish')))]))
self.add_module('last_conv', last_layers)
self.layers.append(last_layers)
self.register_mutables()
def _make_single_layer(self, out_channels, num_blocks,
kernel_sizes, expand_ratios,
act_cfg, stride, use_se):
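        """Build one stage of MBBlocks at the maximum kernel size / expand
        ratio / width; per-block mutables are attached later in
        ``register_mutables``."""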
_layers = []
for i in range(max(num_blocks)):
if i >= 1:
stride = 1
if use_se:
se_cfg = dict(
act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')),
ratio=4,
conv_cfg=self.conv_cfg)
else:
se_cfg = None # type: ignore
mb_layer = MBBlock(
in_channels=self.in_channels,
out_channels=max(out_channels),
kernel_size=max(kernel_sizes),
stride=stride,
expand_ratio=max(expand_ratios),
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type=act_cfg),
se_cfg=se_cfg,
with_attentive_shortcut=self.with_attentive_shortcut)
_layers.append(mb_layer)
self.in_channels = max(out_channels)
dynamic_seq = DynamicSequential(*_layers)
return dynamic_seq
def register_mutables(self):
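        """Attach channel/value mutables to the first conv, every MB stage and
        the final expand / feature-mix convs."""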
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
# mutate the first conv
mutate_conv_module(
self.first_conv, mutable_out_channels=self.last_mutable)
# mutate the built mobilenet layers
for i, layer in enumerate(self.layers[:-1]):
num_blocks = self.num_blocks_list[i]
kernel_sizes = self.kernel_size_list[i]
expand_ratios = self.expand_ratio_list[i]
out_channels = self.num_channels_list[i]
mutable_kernel_size = OneShotMutableValue(
value_list=kernel_sizes, default_value=max(kernel_sizes))
mutable_expand_value = OneShotMutableValue(
value_list=expand_ratios, default_value=max(expand_ratios))
mutable_out_channels = OneShotMutableChannel(
num_channels=max(out_channels), candidate_choices=out_channels)
se_ratios = [i / 4 for i in expand_ratios]
mutable_se_channels = OneShotMutableValue(
value_list=se_ratios, default_value=max(se_ratios))
for k in range(max(self.num_blocks_list[i])):
mutate_mobilenet_layer(layer[k], self.last_mutable,
mutable_out_channels,
mutable_se_channels,
mutable_expand_value,
mutable_kernel_size)
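            # the widest output of this stage feeds the next stage's inputs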
self.last_mutable = mutable_out_channels
mutable_depth = OneShotMutableValue(
value_list=num_blocks, default_value=max(num_blocks))
layer.register_mutable_attr('depth', mutable_depth)
mutable_out_channels = OneShotMutableChannel(
num_channels=self.out_channels,
candidate_choices=self.last_out_channels_list)
last_mutable_expand_value = OneShotMutableValue(
value_list=self.last_expand_ratio_list,
default_value=max(self.last_expand_ratio_list))
derived_expand_channels = self.last_mutable * last_mutable_expand_value
mutate_conv_module(
self.layers[-1].final_expand_layer,
mutable_in_channels=self.last_mutable,
mutable_out_channels=derived_expand_channels)
mutate_conv_module(
self.layers[-1].feature_mix_layer,
mutable_in_channels=derived_expand_channels,
mutable_out_channels=mutable_out_channels)
self.last_mutable = mutable_out_channels
def forward(self, x):
x = self.first_conv(x)
for _, layer in enumerate(self.layers):
x = layer(x)
return tuple([x])
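
# plain torch models (no dynamic ops) used as default test fixtures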
default_models = [
LineModel,
ResBlock,
AddCatModel,
ConcatModel,
MultiConcatModel,
MultiConcatModel2,
GroupWiseConvModel,
Xmodel,
MultipleUseModel,
Icep,
ExpandLineModel,
DwConvModel,
]
class ModelLibrary:
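    """Collect ``torchvision.models`` constructors and filter them by the
    substrings given in ``include``."""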
# includes = [
# 'alexnet', # pass
# 'densenet', # pass
# # 'efficientnet', # pass
# # 'googlenet', # pass.
# # googlenet return a tuple when training,
# # so it should trace in eval mode
# # 'inception', # failed
# # 'mnasnet', # pass
# # 'mobilenet', # pass
# # 'regnet', # failed
# # 'resnet', # pass
# # 'resnext', # failed
# # 'shufflenet', # failed
# # 'squeezenet', # pass
# # 'vgg', # pass
# # 'wide_resnet', # pass
# ]
def __init__(self, include=[]) -> None:
self.include_key = include
self.model_creator = self.get_torch_models()
def __repr__(self) -> str:
s = f'model: {len(self.model_creator)}\n'
for creator in self.model_creator:
s += creator.__name__ + '\n'
return s
def get_torch_models(self):
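        # every plain function exposed by torchvision.models is treated as a
        # model constructor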
from inspect import isfunction
import torchvision
attrs = dir(torchvision.models)
models = []
for name in attrs:
module = getattr(torchvision.models, name)
if isfunction(module):
models.append(module)
return models
def export_models(self):
models = []
for creator in self.model_creator:
if self.is_include(creator.__name__):
models.append(creator)
return models
def is_include(self, name):
for key in self.include_key:
if key in name:
return True
return False
def include(self):
include = []
for creator in self.model_creator:
for key in self.include_key:
if key in creator.__name__:
include.append(creator)
return include
def uninclude(self):
include = self.include()
uninclude = []
for creator in self.model_creator:
if creator not in include:
uninclude.append(creator)
return uninclude