mmrazor/tests/data/models.py

# Copyright (c) OpenMMLab. All rights reserved.
# This file includes models for testing.
from collections import OrderedDict
from typing import Dict
import math
from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmengine.model import BaseModel
from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin, DynamicPatchEmbed, DynamicSequential
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutables import OneShotMutableChannelUnit, OneShotMutableChannel
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
from mmrazor.registry import MODELS
from mmrazor.models.utils.parse_values import parse_values
from mmrazor.models.architectures.ops.mobilenet_series import MBBlock
from mmcv.cnn import ConvModule
from mmengine.model import Sequential
from mmrazor.models.architectures.utils.mutable_register import (
mutate_conv_module, mutate_mobilenet_layer)
# models to test fx tracer
def untracable_function(x: torch.Tensor):
if x.sum() > 0:
x = x - 1
else:
x = x + 1
return x
class UntracableModule(nn.Module):
def __init__(self, in_channel, out_channel) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
def forward(self, x: torch.Tensor):
x = self.conv(x)
if x.sum() > 0:
x = x * 2
else:
x = x * -2
x = self.conv2(x)
return x
class ModuleWithUntracableMethod(nn.Module):
def __init__(self, in_channel, out_channel) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
def forward(self, x: torch.Tensor):
x = self.conv(x)
x = self.untracable_method(x)
x = self.conv2(x)
return x
def untracable_method(self, x):
if x.sum() > 0:
x = x * 2
else:
x = x * -2
return x
@MODELS.register_module()
class UntracableBackBone(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, 2)
self.untracable_module = UntracableModule(16, 8)
self.module_with_untracable_method = ModuleWithUntracableMethod(8, 16)
def forward(self, x):
x = self.conv(x)
x = untracable_function(x)
x = self.untracable_module(x)
x = self.module_with_untracable_method(x)
return x
class UntracableModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.backbone = UntracableBackBone()
self.head = LinearHeadForTest(16, 1000)
def forward(self, x):
return self.head(self.backbone(x))
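# A minimal sketch, not part of the original file, showing why these models are
# "untracable": torch.fx.symbolic_trace cannot handle data-dependent control
# flow such as `if x.sum() > 0`, so tracing UntracableModel is expected to
# raise a TraceError. MMRazor's own tracer is exercised against exactly these
# patterns. The helper name below is hypothetical.
def _sketch_fx_trace_failure():
    try:
        torch.fx.symbolic_trace(UntracableModel())
    except Exception as e:  # typically torch.fx.proxy.TraceError
        print(f'symbolic_trace failed as expected: {e}')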
class ConvAttnModel(Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 8, 3, 1, 1)
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
self.head = LinearHeadForTest(16, 1000)
def forward(self, x):
x1 = self.conv(x)
attn = F.sigmoid(self.pool(x1))
x_attn = x1 * attn
x_last = self.conv2(x_attn)
return self.head(x_last)
@MODELS.register_module()
class LinearHeadForTest(Module):
def __init__(self, in_channel, num_class=1000) -> None:
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(in_channel, num_class)
def forward(self, x):
pool = self.pool(x).flatten(1)
return self.linear(pool)
class MultiConcatModel(Module):
"""
x----------------
|op1 |op2 |op4
x1 x2 x4
| | |
|cat----- |
cat1 |
|op3 |
x3 |
|cat-------------
cat2
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(16, 8, 1)
self.op4 = nn.Conv2d(3, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x4 = self.op4(x)
cat2 = torch.cat([x3, x4], dim=1)
x_pool = self.avg_pool(cat2).flatten(1)
output = self.fc(x_pool)
return output
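# A small shape-check sketch (not in the original file): cat1 concatenates two
# 8-channel maps, which is why op3 takes 16 in_channels, and cat2 again yields
# 8 + 8 = 16 channels, matching fc's in_features. The helper name is
# hypothetical.
def _sketch_multi_concat_shapes():
    out = MultiConcatModel()(torch.rand(2, 3, 16, 16))
    assert out.shape == (2, 1000)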
class MultiConcatModel2(Module):
"""
x---------------
|op1 |op2 |op3
x1 x2 x3
| | |
|cat----- |
cat1 |
|cat-------------
cat2
|op4
x4
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(3, 8, 1)
self.op4 = nn.Conv2d(24, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
cat1 = torch.cat([x1, x2], dim=1)
cat2 = torch.cat([cat1, x3], dim=1)
x4 = self.op4(cat2)
x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
output = self.fc(x_pool)
return output
class ConcatModel(Module):
"""
x------------
|op1,bn1 |op2,bn2
x1 x2
|cat--------|
cat1
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(3, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(16, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x))
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
class ResBlock(Module):
"""
x
|op1,bn1
x1-----------
|op2,bn2 |
x2 |
+------------
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(8, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x1))
x3 = self.op3(x2 + x1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
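# A quick shape-check sketch (not in the original file): the residual add
# `x2 + x1` only works because op2 keeps the channel count at 8, the same as
# op1's output, which is the constraint channel pruning has to respect here.
def _sketch_res_block_forward():
    out = ResBlock()(torch.rand(2, 3, 8, 8))
    assert out.shape == (2, 1000)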
class SingleLineModel(nn.Module):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class AddCatModel(Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
|cat----- |cat-----
cat1 cat2
| |
+----------------
x5
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 2, 3)
self.op2 = nn.Conv2d(3, 6, 3)
self.op3 = nn.Conv2d(3, 4, 3)
self.op4 = nn.Conv2d(3, 4, 3)
self.op5 = nn.Conv2d(8, 16, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
cat1 = torch.cat((x1, x2), dim=1)
cat2 = torch.cat((x3, x4), dim=1)
x5 = self.op5(cat1 + cat2)
x_pool = self.avg_pool(x5).flatten(1)
y = self.fc(x_pool)
return y
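# A shape-check sketch (not in the original file): cat1 has 2 + 6 = 8 channels
# and cat2 has 4 + 4 = 8, so the element-wise add is valid and op5 takes 8
# in_channels. The 3x3 convs here use no padding, hence the 9x9 input below.
def _sketch_add_cat_forward():
    out = AddCatModel()(torch.rand(2, 3, 9, 9))
    assert out.shape == (2, 1000)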
class GroupWiseConvModel(nn.Module):
"""
x
|op1,bn1
x1
|op2,bn2
x2
|op3
x3
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
self.bn2 = nn.BatchNorm2d(16)
self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
x1 = self.op1(x)
x1 = self.bn1(x1)
x2 = self.op2(x1)
x2 = self.bn2(x2)
x3 = self.op3(x2)
x_pool = self.avg_pool(x3).flatten(1)
return self.fc(x_pool)
class Xmodel(nn.Module):
"""
x--------
|op1 |op2
x1 x2
| |
+--------
x12------
|op3 |op4
x3 x4
| |
+--------
x34
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x12 = x1 * x2
x3 = self.op3(x12)
x4 = self.op4(x12)
x34 = x3 + x4
x_pool = self.avg_pool(x34).flatten(1)
return self.fc(x_pool)
class MultipleUseModel(nn.Module):
"""
x------------------------
|conv0 |conv1 |conv2 |conv3
xs.0 xs.1 xs.2 xs.3
|convm |convm |convm |convm
xs_.0 xs_.1 xs_.2 xs_.3
| | | |
+------------------------
|
x_sum
|conv_last
feature
|avg_pool
pool
|linear
output
"""
def __init__(self) -> None:
super().__init__()
self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(32, 1000)
def forward(self, x):
xs = [
conv(x)
for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
]
xs_ = [self.conv_multiple_use(x_) for x_ in xs]
x_cat = torch.cat(xs_, dim=1)
feature = self.conv_last(x_cat)
pool = self.avg_pool(feature).flatten(1)
return self.linear(pool)
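# A shape-check sketch (not in the original file): conv_multiple_use is shared
# by all four branches, so its 16 output channels show up four times in the
# concatenation, matching conv_last's 16 * 4 in_channels.
def _sketch_multiple_use_forward():
    out = MultipleUseModel()(torch.rand(2, 3, 8, 8))
    assert out.shape == (2, 1000)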
class IcepBlock(nn.Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
cat----------------------
|
y_
"""
def __init__(self, in_c=3, out_c=32) -> None:
super().__init__()
self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
# self.op5 = nn.Conv2d(out_c*4, out_c, 3)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
y_ = [x1, x2, x3, x4]
y_ = torch.cat(y_, 1)
return y_
class Icep(nn.Module):
def __init__(self, num_icep_blocks=2) -> None:
super().__init__()
self.icps = nn.Sequential(*[
IcepBlock(32 * 4 if i != 0 else 3, 32)
for i in range(num_icep_blocks)
])
self.op = nn.Conv2d(32 * 4, 32, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
y_ = self.icps(x)
y = self.op(y_)
pool = self.avg_pool(y).flatten(1)
return self.fc(pool)
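# A usage sketch (not in the original file): every IcepBlock concatenates four
# 32-channel branches, so each block after the first consumes 32 * 4 = 128
# channels, which is the `32 * 4 if i != 0 else 3` term in Icep.__init__.
def _sketch_icep_forward():
    out = Icep(num_icep_blocks=2)(torch.rand(2, 3, 8, 8))
    assert out.shape == (2, 1000)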
class ExpandLineModel(Module):
"""
x
|net0,net1,net2
|net3,net4
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(2))
self.linear = nn.Linear(64, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class MultiBindModel(Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
self.head = LinearHeadForTest(8, 1000)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x12 = x1 + x2
x3 = self.conv3(x12)
x123 = x12 + x3
return self.head(x123)
class DwConvModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
nn.ReLU())
self.head = LinearHeadForTest(48, 1000)
def forward(self, x):
return self.head(self.net(x))
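# A shape-check sketch (not in the original file): the middle conv is a
# depthwise conv (groups=48), so any pruning or expansion must keep its
# in_channels, out_channels and groups identical.
def _sketch_dw_conv_forward():
    out = DwConvModel()(torch.rand(2, 3, 8, 8))
    assert out.shape == (2, 1000)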
class SelfAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
self.stem = nn.Conv2d(3, 32, 4, 4, 4)
self.num_head = 4
self.qkv = nn.Linear(32, 32 * 3)
self.proj = nn.Linear(32, 32)
self.head = LinearHeadForTest(32, 1000)
def forward(self, x: torch.Tensor):
x = self.stem(x)
h, w = x.shape[-2:]
x = self._to_token(x)
x = x + self._forward_attention(x)
x = self._to_img(x, h, w)
return self.head(x)
def _to_img(self, x, h, w):
x = x.reshape([x.shape[0], h, w, x.shape[2]])
x = x.permute(0, 3, 1, 2)
return x
def _to_token(self, x):
x = x.flatten(2).transpose(-1, -2)
return x
def _forward_attention(self, x: torch.Tensor):
qkv = self.qkv(x)
qkv = qkv.reshape([
x.shape[0], x.shape[1], 3, self.num_head,
x.shape[2] // self.num_head
]).permute(2, 0, 3, 1, 4).contiguous()
q, k, v = qkv
attn = q @ k.transpose(-1, -2) / math.sqrt(32 // self.num_head)
        y = attn @ v  # (B, num_head, N, head_dim)
y = y.permute(0, 2, 1, 3).flatten(-2)
return self.proj(y)
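# A shape-flow sketch (not in the original file). With a 32x32 input the stem
# (kernel 4, stride 4, padding 4) yields a 10x10 feature map, i.e. 100 tokens
# of width 32, which _forward_attention splits into 4 heads of 8 channels.
def _sketch_self_attention_forward():
    out = SelfAttention()(torch.rand(2, 3, 32, 32))
    assert out.shape == (2, 1000)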
def MMClsResNet18() -> BaseModel:
model_cfg = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
return MODELS.build(model_cfg)
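# A usage sketch (not in the original file): building this config requires
# mmcls (MMClassification) to be installed so that the 'mmcls' scope is
# available to the registry; the result is an ordinary ResNet-18
# ImageClassifier wrapped as a BaseModel.
def _sketch_build_resnet18():
    model = MMClsResNet18()
    assert isinstance(model, BaseModel)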
# models with dynamicop
def register_mutable(module: DynamicChannelMixin,
mutable: MutableChannelUnit,
is_out=True,
start=0,
end=-1):
if end == -1:
end = mutable.num_channels + start
if is_out:
container: MutableChannelContainer = module.get_mutable_attr(
'out_channels')
else:
container: MutableChannelContainer = module.get_mutable_attr(
'in_channels')
container.register_mutable(mutable, start, end)
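# Note (added for clarity, not in the original file): the helper above places
# `mutable` over the channel slice [start, end) of the module's in/out
# MutableChannelContainer; with the defaults the slice covers all of the
# mutable's channels starting at index 0.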
class SampleExpandDerivedMutable(BaseMutable):
def __init__(self, expand_ratio=1) -> None:
super().__init__()
self.ratio = expand_ratio
def __mul__(self, other):
if isinstance(other, OneShotMutableChannel):
def _expand_mask():
mask = other.current_mask
mask = torch.unsqueeze(
mask,
-1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
return mask
return DerivedMutable(_expand_mask, _expand_mask, [self, other])
else:
raise NotImplementedError()
def dump_chosen(self):
return super().dump_chosen()
def export_chosen(self):
return super().export_chosen()
def fix_chosen(self, chosen):
return super().fix_chosen(chosen)
def num_choices(self) -> int:
return super().num_choices
@property
def current_choice(self):
return super().current_choice
@current_choice.setter
def current_choice(self, choice):
super().current_choice(choice)
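# A worked sketch (not in the original file) of the mask expansion performed in
# SampleExpandDerivedMutable.__mul__ with ratio 2: every entry of the source
# mask is repeated `ratio` times in the derived mask.
def _sketch_expand_mask():
    mask = torch.tensor([1., 0., 1.])
    expanded = mask.unsqueeze(-1).expand(list(mask.shape) + [2]).flatten(-2)
    assert torch.equal(expanded, torch.tensor([1., 1., 0., 0., 1., 1.]))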
class DynamicLinearModel(nn.Module):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = DynamicLinear(16, 1000)
MutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self._register_mutable()
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
def _register_mutable(self):
mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
mutable_value = SampleExpandDerivedMutable(1)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[0], mutable1, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable_value * mutable1, False, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[4], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.linear, mutable2, False)
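# A hedged usage sketch (not in the original file), assuming the dynamic ops
# consume the registered channel masks at forward time as they do in MMRazor's
# unit tests: with the default (full) choices the forward pass keeps all
# channels; shrinking mutable1/mutable2 via `current_choice` prunes the hidden
# widths while the output stays (B, 1000).
def _sketch_dynamic_linear_forward():
    model = DynamicLinearModel()
    out = model(torch.rand(2, 3, 8, 8))
    assert out.shape == (2, 1000)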
class DynamicAttention(nn.Module):
"""
x
|blocks: DynamicSequential(depth)
|(blocks)
x1
|fc (OneShotMutableChannel * OneShotMutableValue)
output
"""
def __init__(self) -> None:
super().__init__()
self.mutable_depth = OneShotMutableValue(
value_list=[1, 2], default_value=2)
self.mutable_embed_dims = OneShotMutableChannel(
num_channels=624, candidate_choices=[576, 624])
self.base_embed_dims = OneShotMutableChannel(
num_channels=64, candidate_choices=[64])
self.mutable_num_heads = [
OneShotMutableValue(value_list=[8, 10], default_value=10)
for _ in range(2)
]
self.mutable_mlp_ratios = [
OneShotMutableValue(value_list=[3.0, 3.5, 4.0], default_value=4.0)
for _ in range(2)
]
self.mutable_q_embed_dims = [
i * self.base_embed_dims for i in self.mutable_num_heads
]
self.patch_embed = DynamicPatchEmbed(
img_size=224,
in_channels=3,
embed_dims=self.mutable_embed_dims.num_channels)
# cls token and pos embed
self.pos_embed = nn.Parameter(
torch.zeros(1, 197, self.mutable_embed_dims.num_channels))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
layers = []
for i in range(self.mutable_depth.max_choice):
layer = TransformerEncoderLayer(
embed_dims=self.mutable_embed_dims.num_channels,
num_heads=self.mutable_num_heads[i].max_choice,
mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
layers.append(layer)
self.blocks = DynamicSequential(*layers)
# OneShotMutableChannelUnit
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.register_mutables()
def register_mutables(self):
# mutablevalue
self.blocks.register_mutable_attr('depth', self.mutable_depth)
# mutablechannel
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)
for i in range(self.mutable_depth.max_choice):
layer = self.blocks[i]
layer.register_mutables(
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
mutable_head_dims=self.base_embed_dims,
mutable_embed_dims=self.mutable_embed_dims)
def forward(self, x: torch.Tensor):
B = x.shape[0]
x = self.patch_embed(x)
embed_dims = self.mutable_embed_dims.current_choice
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[..., :embed_dims]
x = self.blocks(x)
return torch.mean(x[:, 1:], dim=1)
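# A forward sketch (not in the original file): with the default (maximum)
# choices, 224x224 inputs become 196 patch tokens plus a cls token with
# embed_dims 624, and the mean over the patch tokens gives a (B, 624) output.
def _sketch_dynamic_attention_forward():
    out = DynamicAttention()(torch.rand(2, 3, 224, 224))
    assert out.shape == (2, 624)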
class DynamicMMBlock(nn.Module):
arch_setting = dict(
kernel_size=[ # [min_kernel_size, max_kernel_size, step]
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
],
num_blocks=[ # [min_num_blocks, max_num_blocks, step]
[1, 2, 1],
[3, 5, 1],
[3, 6, 1],
[3, 6, 1],
[3, 8, 1],
[3, 8, 1],
[1, 2, 1],
],
expand_ratio=[ # [min_expand_ratio, max_expand_ratio, step]
[1, 1, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[6, 6, 1],
[6, 6, 1],
],
num_out_channels=[ # [min_channel, max_channel, step]
[16, 24, 8],
[24, 32, 8],
[32, 40, 8],
[64, 72, 8],
[112, 128, 8],
[192, 216, 8],
[216, 224, 8],
])
def __init__(
self,
conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
fine_grained_mode: bool = False,
) -> None:
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_list = ['Swish'] * 7
self.stride_list = [1, 2, 2, 2, 1, 2, 1]
self.with_se_list = [False, False, True, False, True, True, True]
self.kernel_size_list = parse_values(self.arch_setting['kernel_size'])
self.num_blocks_list = parse_values(self.arch_setting['num_blocks'])
self.expand_ratio_list = \
parse_values(self.arch_setting['expand_ratio'])
self.num_channels_list = \
parse_values(self.arch_setting['num_out_channels'])
assert len(self.kernel_size_list) == len(self.num_blocks_list) == \
len(self.expand_ratio_list) == len(self.num_channels_list)
self.fine_grained_mode = fine_grained_mode
self.with_attentive_shortcut = True
self.in_channels = 24
self.first_out_channels_list = [16]
self.first_conv = ConvModule(
in_channels=3,
out_channels=24,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))
self.layers = []
for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
enumerate(zip(self.num_blocks_list, self.kernel_size_list,
self.expand_ratio_list, self.num_channels_list)):
inverted_res_layer = self._make_single_layer(
out_channels=num_channels,
num_blocks=num_blocks,
kernel_sizes=kernel_sizes,
expand_ratios=expand_ratios,
act_cfg=self.act_list[i],
stride=self.stride_list[i],
use_se=self.with_se_list[i])
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(inverted_res_layer)
last_expand_channels = 1344
self.out_channels = 1984
self.last_out_channels_list = [1792, 1984]
self.last_expand_ratio_list = [6]
last_layers = Sequential(
OrderedDict([('final_expand_layer',
ConvModule(
in_channels=self.in_channels,
out_channels=last_expand_channels,
kernel_size=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))),
('pool', nn.AdaptiveAvgPool2d((1, 1))),
('feature_mix_layer',
ConvModule(
in_channels=last_expand_channels,
out_channels=self.out_channels,
kernel_size=1,
padding=0,
bias=False,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=dict(type='Swish')))]))
self.add_module('last_conv', last_layers)
self.layers.append(last_layers)
self.register_mutables()
def _make_single_layer(self, out_channels, num_blocks, kernel_sizes,
expand_ratios, act_cfg, stride, use_se):
_layers = []
for i in range(max(num_blocks)):
if i >= 1:
stride = 1
if use_se:
se_cfg = dict(
act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')),
ratio=4,
conv_cfg=self.conv_cfg)
else:
se_cfg = None # type: ignore
mb_layer = MBBlock(
in_channels=self.in_channels,
out_channels=max(out_channels),
kernel_size=max(kernel_sizes),
stride=stride,
expand_ratio=max(expand_ratios),
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type=act_cfg),
se_cfg=se_cfg,
with_attentive_shortcut=self.with_attentive_shortcut)
_layers.append(mb_layer)
self.in_channels = max(out_channels)
dynamic_seq = DynamicSequential(*_layers)
return dynamic_seq
def register_mutables(self):
"""Mutate the BigNAS-style MobileNetV3."""
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.first_mutable_channels = OneShotMutableChannel(
alias='backbone.first_channels',
num_channels=max(self.first_out_channels_list),
candidate_choices=self.first_out_channels_list)
mutate_conv_module(
self.first_conv, mutable_out_channels=self.first_mutable_channels)
mid_mutable = self.first_mutable_channels
# mutate the built mobilenet layers
for i, layer in enumerate(self.layers[:-1]):
num_blocks = self.num_blocks_list[i]
kernel_sizes = self.kernel_size_list[i]
expand_ratios = self.expand_ratio_list[i]
out_channels = self.num_channels_list[i]
prefix = 'backbone.layers.' + str(i + 1) + '.'
mutable_out_channels = OneShotMutableChannel(
alias=prefix + 'out_channels',
candidate_choices=out_channels,
num_channels=max(out_channels))
if not self.fine_grained_mode:
mutable_kernel_size = OneShotMutableValue(
alias=prefix + 'kernel_size', value_list=kernel_sizes)
mutable_expand_ratio = OneShotMutableValue(
alias=prefix + 'expand_ratio', value_list=expand_ratios)
mutable_depth = OneShotMutableValue(
alias=prefix + 'depth', value_list=num_blocks)
layer.register_mutable_attr('depth', mutable_depth)
for k in range(max(self.num_blocks_list[i])):
if self.fine_grained_mode:
mutable_kernel_size = OneShotMutableValue(
alias=prefix + str(k) + '.kernel_size',
value_list=kernel_sizes)
mutable_expand_ratio = OneShotMutableValue(
alias=prefix + str(k) + '.expand_ratio',
value_list=expand_ratios)
mutate_mobilenet_layer(layer[k], mid_mutable,
mutable_out_channels,
mutable_expand_ratio,
mutable_kernel_size)
mid_mutable = mutable_out_channels
self.last_mutable_channels = OneShotMutableChannel(
alias='backbone.last_channels',
num_channels=self.out_channels,
candidate_choices=self.last_out_channels_list)
last_mutable_expand_value = OneShotMutableValue(
value_list=self.last_expand_ratio_list,
default_value=max(self.last_expand_ratio_list))
derived_expand_channels = mid_mutable * last_mutable_expand_value
mutate_conv_module(
self.layers[-1].final_expand_layer,
mutable_in_channels=mid_mutable,
mutable_out_channels=derived_expand_channels)
mutate_conv_module(
self.layers[-1].feature_mix_layer,
mutable_in_channels=derived_expand_channels,
mutable_out_channels=self.last_mutable_channels)
def forward(self, x):
x = self.first_conv(x)
for _, layer in enumerate(self.layers):
x = layer(x)
return tuple([x])
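# A hedged end-to-end sketch (not in the original file): DynamicMMBlock is a
# BigNAS-style searchable MobileNet backbone; with the default (maximum)
# choices the forward pass is expected to return a single-element tuple of
# globally pooled features with self.out_channels (1984) channels.
def _sketch_dynamic_mm_block_forward():
    backbone = DynamicMMBlock()
    feats = backbone(torch.rand(2, 3, 64, 64))
    assert isinstance(feats, tuple)
    print(feats[0].shape)  # expected: torch.Size([2, 1984, 1, 1])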