mmrazor/tests/data/models.py

# Copyright (c) OpenMMLab. All rights reserved.
# This file includes models for testing.
from collections import OrderedDict
from typing import Dict
import math
from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmengine.model import BaseModel
from mmrazor.models.architectures.dynamic_ops import DynamicBatchNorm2d, DynamicConv2d, DynamicLinear, DynamicChannelMixin, DynamicPatchEmbed, DynamicSequential
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutables import OneShotMutableChannelUnit, OneShotMutableChannel
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
from mmrazor.registry import MODELS
from mmrazor.models.utils.parse_values import parse_values
from mmrazor.models.architectures.ops.mobilenet_series import MBBlock
from mmcv.cnn import ConvModule
from mmengine.model import Sequential
from mmrazor.models.architectures.utils.mutable_register import (
mutate_conv_module, mutate_mobilenet_layer)
# models to test fx tracer
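# These models contain data-dependent control flow (e.g. `if x.sum() > 0`),
# which symbolic tracers such as torch.fx cannot trace through.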
def untracable_function(x: torch.Tensor):
if x.sum() > 0:
x = x - 1
else:
x = x + 1
return x
class UntracableModule(nn.Module):
def __init__(self, in_channel, out_channel) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
def forward(self, x: torch.Tensor):
x = self.conv(x)
if x.sum() > 0:
x = x * 2
else:
x = x * -2
x = self.conv2(x)
return x
class ModuleWithUntracableMethod(nn.Module):
def __init__(self, in_channel, out_channel) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)
def forward(self, x: torch.Tensor):
x = self.conv(x)
x = self.untracable_method(x)
x = self.conv2(x)
return x
def untracable_method(self, x):
if x.sum() > 0:
x = x * 2
else:
x = x * -2
return x
@MODELS.register_module()
class UntracableBackBone(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, 3, 2)
self.untracable_module = UntracableModule(16, 8)
self.module_with_untracable_method = ModuleWithUntracableMethod(8, 16)
def forward(self, x):
x = self.conv(x)
x = untracable_function(x)
x = self.untracable_module(x)
x = self.module_with_untracable_method(x)
return x
class UntracableModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.backbone = UntracableBackBone()
self.head = LinearHeadForTest(16, 1000)
def forward(self, x):
return self.head(self.backbone(x))
class ConvAttnModel(Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 8, 3, 1, 1)
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
self.head = LinearHeadForTest(16, 1000)
def forward(self, x):
x1 = self.conv(x)
        attn = torch.sigmoid(self.pool(x1))
x_attn = x1 * attn
x_last = self.conv2(x_attn)
return self.head(x_last)
@MODELS.register_module()
class LinearHeadForTest(Module):
def __init__(self, in_channel, num_class=1000) -> None:
super().__init__()
self.pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(in_channel, num_class)
def forward(self, x):
pool = self.pool(x).flatten(1)
return self.linear(pool)
class MultiConcatModel(Module):
"""
x----------------
|op1 |op2 |op4
x1 x2 x4
| | |
|cat----- |
cat1 |
|op3 |
x3 |
|cat-------------
cat2
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(16, 8, 1)
self.op4 = nn.Conv2d(3, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x4 = self.op4(x)
cat2 = torch.cat([x3, x4], dim=1)
x_pool = self.avg_pool(cat2).flatten(1)
output = self.fc(x_pool)
return output
class MultiConcatModel2(Module):
"""
x---------------
|op1 |op2 |op3
x1 x2 x3
| | |
|cat----- |
cat1 |
|cat-------------
cat2
|op4
x4
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.op2 = nn.Conv2d(3, 8, 1)
self.op3 = nn.Conv2d(3, 8, 1)
self.op4 = nn.Conv2d(24, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
cat1 = torch.cat([x1, x2], dim=1)
cat2 = torch.cat([cat1, x3], dim=1)
x4 = self.op4(cat2)
x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
output = self.fc(x_pool)
return output
class ConcatModel(Module):
"""
x------------
|op1,bn1 |op2,bn2
x1 x2
|cat--------|
cat1
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(3, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(16, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x))
cat1 = torch.cat([x1, x2], dim=1)
x3 = self.op3(cat1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
class ResBlock(Module):
"""
x
|op1,bn1
x1-----------
|op2,bn2 |
x2 |
+------------
|op3
x3
|avg_pool
x_pool
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 8, 1)
self.bn2 = nn.BatchNorm2d(8)
self.op3 = nn.Conv2d(8, 8, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(8, 1000)
def forward(self, x: Tensor) -> Tensor:
x1 = self.bn1(self.op1(x))
x2 = self.bn2(self.op2(x1))
x3 = self.op3(x2 + x1)
x_pool = self.avg_pool(x3).flatten(1)
output = self.fc(x_pool)
return output
class SingleLineModel(nn.Module):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class AddCatModel(Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
|cat----- |cat-----
cat1 cat2
| |
+----------------
x5
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 2, 3)
self.op2 = nn.Conv2d(3, 6, 3)
self.op3 = nn.Conv2d(3, 4, 3)
self.op4 = nn.Conv2d(3, 4, 3)
self.op5 = nn.Conv2d(8, 16, 3)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
cat1 = torch.cat((x1, x2), dim=1)
cat2 = torch.cat((x3, x4), dim=1)
x5 = self.op5(cat1 + cat2)
x_pool = self.avg_pool(x5).flatten(1)
y = self.fc(x_pool)
return y
class GroupWiseConvModel(nn.Module):
"""
x
|op1,bn1
x1
|op2,bn2
x2
|op3
x3
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.bn1 = nn.BatchNorm2d(8)
self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
self.bn2 = nn.BatchNorm2d(16)
self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
x1 = self.op1(x)
x1 = self.bn1(x1)
x2 = self.op2(x1)
x2 = self.bn2(x2)
x3 = self.op3(x2)
x_pool = self.avg_pool(x3).flatten(1)
return self.fc(x_pool)
class Xmodel(nn.Module):
"""
x--------
|op1 |op2
x1 x2
| |
+--------
x12------
|op3 |op4
x3 x4
| |
+--------
x34
|avg_pool
x_pool
|fc
y
"""
def __init__(self) -> None:
super().__init__()
self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(16, 1000)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x12 = x1 * x2
x3 = self.op3(x12)
x4 = self.op4(x12)
x34 = x3 + x4
x_pool = self.avg_pool(x34).flatten(1)
return self.fc(x_pool)
class MultipleUseModel(nn.Module):
"""
x------------------------
|conv0 |conv1 |conv2 |conv3
xs.0 xs.1 xs.2 xs.3
|convm |convm |convm |convm
xs_.0 xs_.1 xs_.2 xs_.3
| | | |
        cat----------------------
        |
        x_cat
|conv_last
feature
|avg_pool
pool
|linear
output
"""
def __init__(self) -> None:
super().__init__()
self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(32, 1000)
def forward(self, x):
xs = [
conv(x)
for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
]
xs_ = [self.conv_multiple_use(x_) for x_ in xs]
x_cat = torch.cat(xs_, dim=1)
feature = self.conv_last(x_cat)
pool = self.avg_pool(feature).flatten(1)
return self.linear(pool)
class IcepBlock(nn.Module):
"""
x------------------------
|op1 |op2 |op3 |op4
x1 x2 x3 x4
| | | |
cat----------------------
|
y_
"""
def __init__(self, in_c=3, out_c=32) -> None:
super().__init__()
self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
# self.op5 = nn.Conv2d(out_c*4, out_c, 3)
def forward(self, x):
x1 = self.op1(x)
x2 = self.op2(x)
x3 = self.op3(x)
x4 = self.op4(x)
y_ = [x1, x2, x3, x4]
y_ = torch.cat(y_, 1)
return y_
class Icep(nn.Module):
def __init__(self, num_icep_blocks=2) -> None:
super().__init__()
self.icps = nn.Sequential(*[
IcepBlock(32 * 4 if i != 0 else 3, 32)
for i in range(num_icep_blocks)
])
self.op = nn.Conv2d(32 * 4, 32, 1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(32, 1000)
def forward(self, x):
y_ = self.icps(x)
y = self.op(y_)
pool = self.avg_pool(y).flatten(1)
return self.fc(pool)
class ExpandLineModel(Module):
"""
x
|net0,net1,net2
|net3,net4
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
nn.AdaptiveAvgPool2d(2))
self.linear = nn.Linear(64, 1000)
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
class MultiBindModel(Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
self.head = LinearHeadForTest(8, 1000)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x)
x12 = x1 + x2
x3 = self.conv3(x12)
x123 = x12 + x3
return self.head(x123)
class DwConvModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
nn.ReLU())
self.head = LinearHeadForTest(48, 1000)
def forward(self, x):
return self.head(self.net(x))
class SelfAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
self.stem = nn.Conv2d(3, 32, 4, 4, 4)
self.num_head = 4
self.qkv = nn.Linear(32, 32 * 3)
self.proj = nn.Linear(32, 32)
self.head = LinearHeadForTest(32, 1000)
def forward(self, x: torch.Tensor):
x = self.stem(x)
h, w = x.shape[-2:]
x = self._to_token(x)
x = x + self._forward_attention(x)
x = self._to_img(x, h, w)
return self.head(x)
def _to_img(self, x, h, w):
x = x.reshape([x.shape[0], h, w, x.shape[2]])
x = x.permute(0, 3, 1, 2)
return x
def _to_token(self, x):
x = x.flatten(2).transpose(-1, -2)
return x
def _forward_attention(self, x: torch.Tensor):
qkv = self.qkv(x)
qkv = qkv.reshape([
x.shape[0], x.shape[1], 3, self.num_head,
x.shape[2] // self.num_head
]).permute(2, 0, 3, 1, 4).contiguous()
q, k, v = qkv
attn = q @ k.transpose(-1, -2) / math.sqrt(32 // self.num_head)
y = attn @ v # B H N h
y = y.permute(0, 2, 1, 3).flatten(-2)
return self.proj(y)
def MMClsResNet18() -> BaseModel:
model_cfg = dict(
_scope_='mmcls',
type='ImageClassifier',
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
style='pytorch'),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
return MODELS.build(model_cfg)
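# A minimal usage sketch (illustration only, not part of the test data): most
# of the plain models above map an image batch to 1000-way logits, e.g.
#
#   model = ConcatModel()
#   out = model(torch.rand(2, 3, 32, 32))  # -> torch.Size([2, 1000])
#
# MMClsResNet18() additionally requires `mmcls` to be installed, since the
# model is built from an mmcls ImageClassifier config through the MODELS
# registry.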
# models with dynamic ops
def register_mutable(module: DynamicChannelMixin,
mutable: MutableChannelUnit,
is_out=True,
start=0,
end=-1):
if end == -1:
end = mutable.num_channels + start
if is_out:
container: MutableChannelContainer = module.get_mutable_attr(
'out_channels')
else:
container: MutableChannelContainer = module.get_mutable_attr(
'in_channels')
container.register_mutable(mutable, start, end)
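# Usage sketch for `register_mutable` (illustrative only): the helper fills the
# slice [start, end) of a dynamic module's MutableChannelContainer, e.g.
#
#   conv = DynamicConv2d(8, 16, 3, 1, 1)
#   MutableChannelUnit._register_channel_container(conv, MutableChannelContainer)
#   register_mutable(conv, OneShotMutableChannel(16, candidate_choices=[8, 16]),
#                    is_out=True)  # occupies slots [0, 16) of `out_channels`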
class SampleExpandDerivedMutable(BaseMutable):
def __init__(self, expand_ratio=1) -> None:
super().__init__()
self.ratio = expand_ratio
def __mul__(self, other):
if isinstance(other, OneShotMutableChannel):
def _expand_mask():
mask = other.current_mask
mask = torch.unsqueeze(
mask,
-1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
return mask
return DerivedMutable(_expand_mask, _expand_mask, [self, other])
else:
raise NotImplementedError()
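    # Note on `__mul__` above: the derived mask repeats each element of
    # `other`'s mask `ratio` times, e.g. ratio=2 turns [1, 0, 1] into
    # [1, 1, 0, 0, 1, 1].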
def dump_chosen(self):
return super().dump_chosen()
def export_chosen(self):
return super().export_chosen()
def fix_chosen(self, chosen):
return super().fix_chosen(chosen)
def num_choices(self) -> int:
return super().num_choices
@property
def current_choice(self):
return super().current_choice
@current_choice.setter
def current_choice(self, choice):
super().current_choice(choice)
class DynamicLinearModel(nn.Module):
"""
x
|net0,net1
|net2
|net3
x1
|fc
output
"""
def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(
DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
nn.AdaptiveAvgPool2d(1))
self.linear = DynamicLinear(16, 1000)
MutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self._register_mutable()
def forward(self, x):
x1 = self.net(x)
x1 = x1.reshape([x1.shape[0], -1])
return self.linear(x1)
def _register_mutable(self):
mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
mutable_value = SampleExpandDerivedMutable(1)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[0], mutable1, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable_value * mutable1, False, 0, 8)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[3], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.net[4], mutable2, True)
MutableChannelContainer.register_mutable_channel_to_module(
self.linear, mutable2, False)
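# Illustrative sketch (not executed): a dynamic model like the one above is
# normally driven through a channel mutator, which groups the registered
# mutable channels and samples choices for them, roughly:
#
#   from mmrazor.models.mutators import ChannelMutator  # assumed import path
#   mutator = ChannelMutator()
#   mutator.prepare_from_supernet(DynamicLinearModel())
#   mutator.set_choices(mutator.sample_choices())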
class DynamicAttention(nn.Module):
"""
x
|blocks: DynamicSequential(depth)
|(blocks)
x1
|fc (OneShotMutableChannel * OneShotMutableValue)
output
"""
def __init__(self) -> None:
super().__init__()
self.mutable_depth = OneShotMutableValue(
value_list=[1, 2], default_value=2)
self.mutable_embed_dims = OneShotMutableChannel(
num_channels=624, candidate_choices=[576, 624])
self.base_embed_dims = OneShotMutableChannel(
num_channels=64, candidate_choices=[64])
self.mutable_num_heads = [
OneShotMutableValue(value_list=[8, 10], default_value=10)
for _ in range(2)
]
self.mutable_mlp_ratios = [
OneShotMutableValue(value_list=[3.0, 3.5, 4.0], default_value=4.0)
for _ in range(2)
]
self.mutable_q_embed_dims = [
i * self.base_embed_dims for i in self.mutable_num_heads
]
self.patch_embed = DynamicPatchEmbed(
img_size=224,
in_channels=3,
embed_dims=self.mutable_embed_dims.num_channels)
# cls token and pos embed
self.pos_embed = nn.Parameter(
torch.zeros(1, 197, self.mutable_embed_dims.num_channels))
self.cls_token = nn.Parameter(
torch.zeros(1, 1, self.mutable_embed_dims.num_channels))
layers = []
for i in range(self.mutable_depth.max_choice):
layer = TransformerEncoderLayer(
embed_dims=self.mutable_embed_dims.num_channels,
num_heads=self.mutable_num_heads[i].max_choice,
mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
layers.append(layer)
self.blocks = DynamicSequential(*layers)
# OneShotMutableChannelUnit
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.register_mutables()
def register_mutables(self):
# mutablevalue
self.blocks.register_mutable_attr('depth', self.mutable_depth)
# mutablechannel
MutableChannelContainer.register_mutable_channel_to_module(
self.patch_embed, self.mutable_embed_dims, True)
for i in range(self.mutable_depth.max_choice):
layer = self.blocks[i]
layer.register_mutables(
mutable_num_heads=self.mutable_num_heads[i],
mutable_mlp_ratios=self.mutable_mlp_ratios[i],
mutable_q_embed_dims=self.mutable_q_embed_dims[i],
mutable_head_dims=self.base_embed_dims,
mutable_embed_dims=self.mutable_embed_dims)
def forward(self, x: torch.Tensor):
B = x.shape[0]
x = self.patch_embed(x)
embed_dims = self.mutable_embed_dims.current_choice
cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed[..., :embed_dims]
x = self.blocks(x)
return torch.mean(x[:, 1:], dim=1)
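# Illustrative note: with `mutable_embed_dims` fixed to 576 and `mutable_depth`
# fixed to 1, the forward above slices cls_token / pos_embed down to 576 dims,
# runs a single encoder layer, and returns a (B, 576) tensor (the mean over
# the patch tokens).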
class DynamicMMBlock(nn.Module):
arch_setting = dict(
kernel_size=[ # [min_kernel_size, max_kernel_size, step]
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
[3, 5, 2],
],
num_blocks=[ # [min_num_blocks, max_num_blocks, step]
[1, 2, 1],
[3, 5, 1],
[3, 6, 1],
[3, 6, 1],
[3, 8, 1],
[3, 8, 1],
[1, 2, 1],
],
expand_ratio=[ # [min_expand_ratio, max_expand_ratio, step]
[1, 1, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[4, 6, 1],
[6, 6, 1],
[6, 6, 1],
],
num_out_channels=[ # [min_channel, max_channel, step]
[16, 24, 8],
[24, 32, 8],
[32, 40, 8],
[64, 72, 8],
[112, 128, 8],
[192, 216, 8],
[216, 224, 8],
])
def __init__(
self,
conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
fine_grained_mode: bool = False,
) -> None:
super().__init__()
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_list = ['Swish'] * 7
self.stride_list = [1, 2, 2, 2, 1, 2, 1]
self.with_se_list = [False, False, True, False, True, True, True]
self.kernel_size_list = parse_values(self.arch_setting['kernel_size'])
self.num_blocks_list = parse_values(self.arch_setting['num_blocks'])
self.expand_ratio_list = \
parse_values(self.arch_setting['expand_ratio'])
self.num_channels_list = \
parse_values(self.arch_setting['num_out_channels'])
assert len(self.kernel_size_list) == len(self.num_blocks_list) == \
len(self.expand_ratio_list) == len(self.num_channels_list)
self.fine_grained_mode = fine_grained_mode
self.with_attentive_shortcut = True
self.in_channels = 24
self.first_out_channels_list = [16]
self.first_conv = ConvModule(
in_channels=3,
out_channels=24,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))
self.layers = []
for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
enumerate(zip(self.num_blocks_list, self.kernel_size_list,
self.expand_ratio_list, self.num_channels_list)):
inverted_res_layer = self._make_single_layer(
out_channels=num_channels,
num_blocks=num_blocks,
kernel_sizes=kernel_sizes,
expand_ratios=expand_ratios,
act_cfg=self.act_list[i],
stride=self.stride_list[i],
use_se=self.with_se_list[i])
layer_name = f'layer{i + 1}'
self.add_module(layer_name, inverted_res_layer)
self.layers.append(inverted_res_layer)
last_expand_channels = 1344
self.out_channels = 1984
self.last_out_channels_list = [1792, 1984]
self.last_expand_ratio_list = [6]
last_layers = Sequential(
OrderedDict([('final_expand_layer',
ConvModule(
in_channels=self.in_channels,
out_channels=last_expand_channels,
kernel_size=1,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type='Swish'))),
('pool', nn.AdaptiveAvgPool2d((1, 1))),
('feature_mix_layer',
ConvModule(
in_channels=last_expand_channels,
out_channels=self.out_channels,
kernel_size=1,
padding=0,
bias=False,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=dict(type='Swish')))]))
self.add_module('last_conv', last_layers)
self.layers.append(last_layers)
self.register_mutables()
def _make_single_layer(self, out_channels, num_blocks, kernel_sizes,
expand_ratios, act_cfg, stride, use_se):
_layers = []
for i in range(max(num_blocks)):
if i >= 1:
stride = 1
if use_se:
se_cfg = dict(
act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')),
ratio=4,
conv_cfg=self.conv_cfg)
else:
se_cfg = None # type: ignore
mb_layer = MBBlock(
in_channels=self.in_channels,
out_channels=max(out_channels),
kernel_size=max(kernel_sizes),
stride=stride,
expand_ratio=max(expand_ratios),
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=dict(type=act_cfg),
se_cfg=se_cfg,
with_attentive_shortcut=self.with_attentive_shortcut)
_layers.append(mb_layer)
self.in_channels = max(out_channels)
dynamic_seq = DynamicSequential(*_layers)
return dynamic_seq
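    # Note: each layer above is built at its maximum configuration (largest
    # kernel size, expand ratio and block count); `register_mutables` below
    # attaches mutable values/channels so that smaller sub-configurations can
    # be sampled from this supernet.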
def register_mutables(self):
"""Mutate the BigNAS-style MobileNetV3."""
OneShotMutableChannelUnit._register_channel_container(
self, MutableChannelContainer)
self.first_mutable_channels = OneShotMutableChannel(
alias='backbone.first_channels',
num_channels=max(self.first_out_channels_list),
candidate_choices=self.first_out_channels_list)
mutate_conv_module(
self.first_conv, mutable_out_channels=self.first_mutable_channels)
mid_mutable = self.first_mutable_channels
# mutate the built mobilenet layers
for i, layer in enumerate(self.layers[:-1]):
num_blocks = self.num_blocks_list[i]
kernel_sizes = self.kernel_size_list[i]
expand_ratios = self.expand_ratio_list[i]
out_channels = self.num_channels_list[i]
prefix = 'backbone.layers.' + str(i + 1) + '.'
mutable_out_channels = OneShotMutableChannel(
alias=prefix + 'out_channels',
candidate_choices=out_channels,
num_channels=max(out_channels))
if not self.fine_grained_mode:
mutable_kernel_size = OneShotMutableValue(
alias=prefix + 'kernel_size', value_list=kernel_sizes)
mutable_expand_ratio = OneShotMutableValue(
alias=prefix + 'expand_ratio', value_list=expand_ratios)
mutable_depth = OneShotMutableValue(
alias=prefix + 'depth', value_list=num_blocks)
layer.register_mutable_attr('depth', mutable_depth)
for k in range(max(self.num_blocks_list[i])):
if self.fine_grained_mode:
mutable_kernel_size = OneShotMutableValue(
alias=prefix + str(k) + '.kernel_size',
value_list=kernel_sizes)
mutable_expand_ratio = OneShotMutableValue(
alias=prefix + str(k) + '.expand_ratio',
value_list=expand_ratios)
mutate_mobilenet_layer(layer[k], mid_mutable,
mutable_out_channels,
mutable_expand_ratio,
mutable_kernel_size)
mid_mutable = mutable_out_channels
self.last_mutable_channels = OneShotMutableChannel(
alias='backbone.last_channels',
num_channels=self.out_channels,
candidate_choices=self.last_out_channels_list)
last_mutable_expand_value = OneShotMutableValue(
value_list=self.last_expand_ratio_list,
default_value=max(self.last_expand_ratio_list))
derived_expand_channels = mid_mutable * last_mutable_expand_value
mutate_conv_module(
self.layers[-1].final_expand_layer,
mutable_in_channels=mid_mutable,
mutable_out_channels=derived_expand_channels)
mutate_conv_module(
self.layers[-1].feature_mix_layer,
mutable_in_channels=derived_expand_channels,
mutable_out_channels=self.last_mutable_channels)
def forward(self, x):
x = self.first_conv(x)
for _, layer in enumerate(self.layers):
x = layer(x)
return tuple([x])
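# Minimal smoke test (illustrative, not used by the unit tests): run a few of
# the plain models on a dummy batch and print the output shapes.
if __name__ == '__main__':
    dummy_input = torch.rand(2, 3, 64, 64)
    for model_cls in [ConcatModel, ResBlock, AddCatModel, Xmodel]:
        output = model_cls()(dummy_input)
        print(model_cls.__name__, tuple(output.shape))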