# Copyright (c) OpenMMLab. All rights reserved.
# this file includes models for testing.
from collections import OrderedDict
from typing import Dict
import math

from torch.nn import Module
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmengine.model import BaseModel
from mmrazor.models.architectures.dynamic_ops import (DynamicBatchNorm2d,
                                                      DynamicConv2d,
                                                      DynamicLinear,
                                                      DynamicChannelMixin,
                                                      DynamicPatchEmbed,
                                                      DynamicSequential)
from mmrazor.models.mutables.mutable_channel import MutableChannelContainer
from mmrazor.models.mutables import MutableChannelUnit
from mmrazor.models.mutables import DerivedMutable
from mmrazor.models.mutables import BaseMutable
from mmrazor.models.mutables import OneShotMutableChannelUnit, OneShotMutableChannel
from mmrazor.models.mutables import OneShotMutableValue
from mmrazor.models.architectures.backbones.searchable_autoformer import TransformerEncoderLayer
from mmrazor.registry import MODELS
from mmrazor.models.utils.parse_values import parse_values

from mmrazor.models.architectures.ops.mobilenet_series import MBBlock
from mmcv.cnn import ConvModule
from mmengine.model import Sequential
from mmrazor.models.architectures.utils.mutable_register import (
    mutate_conv_module, mutate_mobilenet_layer)

# models to test fx tracer


def untracable_function(x: torch.Tensor):
    if x.sum() > 0:
        x = x - 1
    else:
        x = x + 1
    return x
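

# A minimal sketch of why the function above is "untracable" (assumption:
# ``torch.fx`` is available in this environment). It is not executed on
# import; symbolic tracing is expected to fail because the branch depends on
# tensor values, which is why a custom fx tracer has to treat such
# functions/methods as leaves.
def _example_trace_untracable_function():
    import torch.fx
    try:
        torch.fx.symbolic_trace(untracable_function)
    except Exception:
        # Data-dependent control flow cannot be captured symbolically.
        pass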


class UntracableModule(nn.Module):

    def __init__(self, in_channel, out_channel) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        if x.sum() > 0:
            x = x * 2
        else:
            x = x * -2
        x = self.conv2(x)
        return x


class ModuleWithUntracableMethod(nn.Module):

    def __init__(self, in_channel, out_channel) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, 1, 1)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.untracable_method(x)
        x = self.conv2(x)
        return x

    def untracable_method(self, x):
        if x.sum() > 0:
            x = x * 2
        else:
            x = x * -2
        return x


@MODELS.register_module()
class UntracableBackBone(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 2)
        self.untracable_module = UntracableModule(16, 8)
        self.module_with_untracable_method = ModuleWithUntracableMethod(8, 16)

    def forward(self, x):
        x = self.conv(x)
        x = untracable_function(x)
        x = self.untracable_module(x)
        x = self.module_with_untracable_method(x)
        return x


class UntracableModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.backbone = UntracableBackBone()
        self.head = LinearHeadForTest(16, 1000)

    def forward(self, x):
        return self.head(self.backbone(x))


class ConvAttnModel(Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, 1, 1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(8, 16, 3, 1, 1)
        self.head = LinearHeadForTest(16, 1000)

    def forward(self, x):
        x1 = self.conv(x)
        attn = F.sigmoid(self.pool(x1))
        x_attn = x1 * attn
        x_last = self.conv2(x_attn)
        return self.head(x_last)


@MODELS.register_module()
class LinearHeadForTest(Module):

    def __init__(self, in_channel, num_class=1000) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(in_channel, num_class)

    def forward(self, x):
        pool = self.pool(x).flatten(1)
        return self.linear(pool)


class MultiConcatModel(Module):
    """
        x ---------------
        |op1   |op2   |op4
        x1     x2     x4
        |      |      |
        |cat---|      |
        cat1          |
        |op3          |
        x3            |
        |cat----------|
        cat2
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(16, 8, 1)
        self.op4 = nn.Conv2d(3, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)
        x4 = self.op4(x)
        cat2 = torch.cat([x3, x4], dim=1)
        x_pool = self.avg_pool(cat2).flatten(1)
        output = self.fc(x_pool)

        return output


class MultiConcatModel2(Module):
    """
        x --------------
        |op1   |op2   |op3
        x1     x2     x3
        |      |      |
        |cat---|      |
        cat1          |
        |cat----------|
        cat2
        |op4
        x4
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.op3 = nn.Conv2d(3, 8, 1)
        self.op4 = nn.Conv2d(24, 8, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        cat1 = torch.cat([x1, x2], dim=1)
        cat2 = torch.cat([cat1, x3], dim=1)
        x4 = self.op4(cat2)

        x_pool = self.avg_pool(x4).reshape([x4.shape[0], -1])
        output = self.fc(x_pool)

        return output


class ConcatModel(Module):
    """
        x ------------
        |op1,bn1   |op2,bn2
        x1         x2
        |cat-------|
        cat1
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(3, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(16, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x))
        cat1 = torch.cat([x1, x2], dim=1)
        x3 = self.op3(cat1)

        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)

        return output


class ResBlock(Module):
    """
        x
        |op1,bn1
        x1---------
        |op2,bn2  |
        x2        |
        +---------|
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.op1 = nn.Conv2d(3, 8, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 8, 1)
        self.bn2 = nn.BatchNorm2d(8)
        self.op3 = nn.Conv2d(8, 8, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 1000)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.bn1(self.op1(x))
        x2 = self.bn2(self.op2(x1))
        x3 = self.op3(x2 + x1)
        x_pool = self.avg_pool(x3).flatten(1)
        output = self.fc(x_pool)
        return output


class SingleLineModel(nn.Module):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)


class AddCatModel(Module):
    """
        x ------------------------
        |op1    |op2    |op3    |op4
        x1      x2      x3      x4
        |       |       |       |
        |cat----|       |cat----|
        cat1            cat2
        |               |
        +---------------|
        x5
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 2, 3)
        self.op2 = nn.Conv2d(3, 6, 3)
        self.op3 = nn.Conv2d(3, 4, 3)
        self.op4 = nn.Conv2d(3, 4, 3)
        self.op5 = nn.Conv2d(8, 16, 3)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)

        cat1 = torch.cat((x1, x2), dim=1)
        cat2 = torch.cat((x3, x4), dim=1)
        x5 = self.op5(cat1 + cat2)
        x_pool = self.avg_pool(x5).flatten(1)
        y = self.fc(x_pool)
        return y


class GroupWiseConvModel(nn.Module):
    """
        x
        |op1,bn1
        x1
        |op2,bn2
        x2
        |op3
        x3
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(8)
        self.op2 = nn.Conv2d(8, 16, 3, 1, 1, groups=2)
        self.bn2 = nn.BatchNorm2d(16)
        self.op3 = nn.Conv2d(16, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x1 = self.bn1(x1)
        x2 = self.op2(x1)
        x2 = self.bn2(x2)
        x3 = self.op3(x2)
        x_pool = self.avg_pool(x3).flatten(1)
        return self.fc(x_pool)


class Xmodel(nn.Module):
    """
        x--------
        |op1    |op2
        x1      x2
        |       |
        *--------
        x12------
        |op3    |op4
        x3      x4
        |       |
        +--------
        x34
        |avg_pool
        x_pool
        |fc
        y
    """

    def __init__(self) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.op3 = nn.Conv2d(8, 16, 3, 1, 1)
        self.op4 = nn.Conv2d(8, 16, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 1000)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x12 = x1 * x2
        x3 = self.op3(x12)
        x4 = self.op4(x12)
        x34 = x3 + x4
        x_pool = self.avg_pool(x34).flatten(1)
        return self.fc(x_pool)


class MultipleUseModel(nn.Module):
    """
        x----------------------------
        |conv0    |conv1    |conv2    |conv3
        xs.0      xs.1      xs.2      xs.3
        |convm    |convm    |convm    |convm
        xs_.0     xs_.1     xs_.2     xs_.3
        |         |         |         |
        cat---------------------------
        |
        x_cat
        |conv_last
        feature
        |avg_pool
        pool
        |linear
        output

    ``convm`` denotes the shared ``conv_multiple_use`` module, which is
    applied to every branch.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv0 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv_multiple_use = nn.Conv2d(8, 16, 3, 1, 1)
        self.conv_last = nn.Conv2d(16 * 4, 32, 3, 1, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(32, 1000)

    def forward(self, x):
        xs = [
            conv(x)
            for conv in [self.conv0, self.conv1, self.conv2, self.conv3]
        ]
        xs_ = [self.conv_multiple_use(x_) for x_ in xs]
        x_cat = torch.cat(xs_, dim=1)
        feature = self.conv_last(x_cat)
        pool = self.avg_pool(feature).flatten(1)
        return self.linear(pool)


class IcepBlock(nn.Module):
    """
        x------------------------
        |op1    |op2    |op3    |op4
        x1      x2      x3      x4
        |       |       |       |
        cat----------------------
        |
        y_
    """

    def __init__(self, in_c=3, out_c=32) -> None:
        super().__init__()
        self.op1 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op2 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op3 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        self.op4 = nn.Conv2d(in_c, out_c, 3, 1, 1)
        # self.op5 = nn.Conv2d(out_c*4, out_c, 3)

    def forward(self, x):
        x1 = self.op1(x)
        x2 = self.op2(x)
        x3 = self.op3(x)
        x4 = self.op4(x)
        y_ = [x1, x2, x3, x4]
        y_ = torch.cat(y_, 1)
        return y_


class Icep(nn.Module):

    def __init__(self, num_icep_blocks=2) -> None:
        super().__init__()
        self.icps = nn.Sequential(*[
            IcepBlock(32 * 4 if i != 0 else 3, 32)
            for i in range(num_icep_blocks)
        ])
        self.op = nn.Conv2d(32 * 4, 32, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 1000)

    def forward(self, x):
        y_ = self.icps(x)
        y = self.op(y_)
        pool = self.avg_pool(y).flatten(1)
        return self.fc(pool)


class ExpandLineModel(Module):
    """
        x
        |net0,net1,net2
        |net3,net4
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(),
            nn.Conv2d(8, 16, 3, 1, 1), nn.BatchNorm2d(16),
            nn.AdaptiveAvgPool2d(2))
        self.linear = nn.Linear(64, 1000)

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)


class MultiBindModel(Module):

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv2 = nn.Conv2d(3, 8, 3, 1, 1)
        self.conv3 = nn.Conv2d(8, 8, 3, 1, 1)
        self.head = LinearHeadForTest(8, 1000)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x12 = x1 + x2
        x3 = self.conv3(x12)
        x123 = x12 + x3
        return self.head(x123)


class DwConvModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(),
            nn.Conv2d(48, 48, 3, 1, 1, groups=48), nn.BatchNorm2d(48),
            nn.ReLU())
        self.head = LinearHeadForTest(48, 1000)

    def forward(self, x):
        return self.head(self.net(x))


class SelfAttention(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.stem = nn.Conv2d(3, 32, 4, 4, 4)

        self.num_head = 4
        self.qkv = nn.Linear(32, 32 * 3)
        self.proj = nn.Linear(32, 32)

        self.head = LinearHeadForTest(32, 1000)

    def forward(self, x: torch.Tensor):
        x = self.stem(x)
        h, w = x.shape[-2:]
        x = self._to_token(x)
        x = x + self._forward_attention(x)
        x = self._to_img(x, h, w)
        return self.head(x)

    def _to_img(self, x, h, w):
        x = x.reshape([x.shape[0], h, w, x.shape[2]])
        x = x.permute(0, 3, 1, 2)
        return x

    def _to_token(self, x):
        x = x.flatten(2).transpose(-1, -2)
        return x

    def _forward_attention(self, x: torch.Tensor):
        qkv = self.qkv(x)
        qkv = qkv.reshape([
            x.shape[0], x.shape[1], 3, self.num_head,
            x.shape[2] // self.num_head
        ]).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv
        attn = q @ k.transpose(-1, -2) / math.sqrt(32 // self.num_head)
        y = attn @ v  # B H N h
        y = y.permute(0, 2, 1, 3).flatten(-2)
        return self.proj(y)


def MMClsResNet18() -> BaseModel:
    model_cfg = dict(
        _scope_='mmcls',
        type='ImageClassifier',
        backbone=dict(
            type='ResNet',
            depth=18,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=1000,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
            topk=(1, 5),
        ))
    return MODELS.build(model_cfg)
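

# A minimal usage sketch (assumption: ``mmcls`` is installed so that the
# ``mmcls`` scope is resolvable by the registry). Not executed on import.
def _example_mmcls_resnet18_usage():
    model = MMClsResNet18()
    # ``mode='tensor'`` asks the mmengine BaseModel for raw head outputs.
    logits = model(torch.rand(1, 3, 224, 224), mode='tensor')
    return logits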


# models with dynamic ops


def register_mutable(module: DynamicChannelMixin,
                     mutable: MutableChannelUnit,
                     is_out=True,
                     start=0,
                     end=-1):
    if end == -1:
        end = mutable.num_channels + start
    if is_out:
        container: MutableChannelContainer = module.get_mutable_attr(
            'out_channels')
    else:
        container: MutableChannelContainer = module.get_mutable_attr(
            'in_channels')
    container.register_mutable(mutable, start, end)
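

# A minimal sketch of how ``register_mutable`` above can be used (the module
# and mutable below are illustrative only). The channel container has to be
# registered on the dynamic op first, mirroring what the dynamic models in
# this file do in their ``__init__``. Not executed on import.
def _example_register_mutable_usage():
    conv = DynamicConv2d(3, 8, 3, 1, 1)
    MutableChannelUnit._register_channel_container(
        nn.Sequential(conv), MutableChannelContainer)
    mutable = OneShotMutableChannel(8, candidate_choices=[4, 8])
    register_mutable(conv, mutable, is_out=True)
    return conv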


class SampleExpandDerivedMutable(BaseMutable):

    def __init__(self, expand_ratio=1) -> None:
        super().__init__()
        self.ratio = expand_ratio

    def __mul__(self, other):
        if isinstance(other, OneShotMutableChannel):

            def _expand_mask():
                mask = other.current_mask
                mask = torch.unsqueeze(
                    mask,
                    -1).expand(list(mask.shape) + [self.ratio]).flatten(-2)
                return mask

            return DerivedMutable(_expand_mask, _expand_mask, [self, other])
        else:
            raise NotImplementedError()

    def dump_chosen(self):
        return super().dump_chosen()

    def export_chosen(self):
        return super().export_chosen()

    def fix_chosen(self, chosen):
        return super().fix_chosen(chosen)

    def num_choices(self) -> int:
        return super().num_choices

    @property
    def current_choice(self):
        return super().current_choice

    @current_choice.setter
    def current_choice(self, choice):
        super().current_choice(choice)


class DynamicLinearModel(nn.Module):
    """
        x
        |net0,net1
        |net2
        |net3
        x1
        |fc
        output
    """

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            DynamicConv2d(3, 8, 3, 1, 1), DynamicBatchNorm2d(8), nn.ReLU(),
            DynamicConv2d(8, 16, 3, 1, 1), DynamicBatchNorm2d(16),
            nn.AdaptiveAvgPool2d(1))
        self.linear = DynamicLinear(16, 1000)

        MutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)
        self._register_mutable()

    def forward(self, x):
        x1 = self.net(x)
        x1 = x1.reshape([x1.shape[0], -1])
        return self.linear(x1)

    def _register_mutable(self):
        mutable1 = OneShotMutableChannel(8, candidate_choices=[1, 4, 8])
        mutable2 = OneShotMutableChannel(16, candidate_choices=[2, 8, 16])
        mutable_value = SampleExpandDerivedMutable(1)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[0], mutable1, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[1], mutable1.expand_mutable_channel(1), True, 0, 8)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable_value * mutable1, False, 0, 8)

        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[3], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.net[4], mutable2, True)
        MutableChannelContainer.register_mutable_channel_to_module(
            self.linear, mutable2, False)
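

# A minimal forward-pass sketch for the dynamic model above (assumption: the
# default, i.e. maximal, channel choices are active, so the output matches
# the static 1000-class definition). Not executed on import.
def _example_dynamic_linear_usage():
    model = DynamicLinearModel()
    out = model(torch.rand(2, 3, 32, 32))
    assert out.shape == (2, 1000)
    return out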


class DynamicAttention(nn.Module):
    """
        x
        |blocks: DynamicSequential(depth)
        |(blocks)
        x1
        |fc (OneShotMutableChannel * OneShotMutableValue)
        output
    """

    def __init__(self) -> None:
        super().__init__()

        self.mutable_depth = OneShotMutableValue(
            value_list=[1, 2], default_value=2)
        self.mutable_embed_dims = OneShotMutableChannel(
            num_channels=624, candidate_choices=[576, 624])
        self.base_embed_dims = OneShotMutableChannel(
            num_channels=64, candidate_choices=[64])
        self.mutable_num_heads = [
            OneShotMutableValue(value_list=[8, 10], default_value=10)
            for _ in range(2)
        ]
        self.mutable_mlp_ratios = [
            OneShotMutableValue(value_list=[3.0, 3.5, 4.0], default_value=4.0)
            for _ in range(2)
        ]
        self.mutable_q_embed_dims = [
            i * self.base_embed_dims for i in self.mutable_num_heads
        ]

        self.patch_embed = DynamicPatchEmbed(
            img_size=224,
            in_channels=3,
            embed_dims=self.mutable_embed_dims.num_channels)

        # cls token and pos embed
        self.pos_embed = nn.Parameter(
            torch.zeros(1, 197, self.mutable_embed_dims.num_channels))
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.mutable_embed_dims.num_channels))

        layers = []
        for i in range(self.mutable_depth.max_choice):
            layer = TransformerEncoderLayer(
                embed_dims=self.mutable_embed_dims.num_channels,
                num_heads=self.mutable_num_heads[i].max_choice,
                mlp_ratio=self.mutable_mlp_ratios[i].max_choice)
            layers.append(layer)
        self.blocks = DynamicSequential(*layers)

        # OneShotMutableChannelUnit
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        self.register_mutables()

    def register_mutables(self):
        # mutablevalue
        self.blocks.register_mutable_attr('depth', self.mutable_depth)
        # mutablechannel
        MutableChannelContainer.register_mutable_channel_to_module(
            self.patch_embed, self.mutable_embed_dims, True)

        for i in range(self.mutable_depth.max_choice):
            layer = self.blocks[i]
            layer.register_mutables(
                mutable_num_heads=self.mutable_num_heads[i],
                mutable_mlp_ratios=self.mutable_mlp_ratios[i],
                mutable_q_embed_dims=self.mutable_q_embed_dims[i],
                mutable_head_dims=self.base_embed_dims,
                mutable_embed_dims=self.mutable_embed_dims)

    def forward(self, x: torch.Tensor):
        B = x.shape[0]
        x = self.patch_embed(x)
        embed_dims = self.mutable_embed_dims.current_choice
        cls_tokens = self.cls_token[..., :embed_dims].expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed[..., :embed_dims]
        x = self.blocks(x)
        return torch.mean(x[:, 1:], dim=1)
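

# A minimal sketch of exercising the depth mutable registered above
# (assumption: setting ``current_choice`` directly is enough for the dynamic
# ops to pick up the new depth; in real pipelines a mutator usually drives
# this). Not executed on import.
def _example_dynamic_attention_depth():
    model = DynamicAttention()
    model.mutable_depth.current_choice = 1  # run only the first encoder layer
    out = model(torch.rand(1, 3, 224, 224))
    return out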


class DynamicMMBlock(nn.Module):

    arch_setting = dict(
        kernel_size=[  # [min_kernel_size, max_kernel_size, step]
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
            [3, 5, 2],
        ],
        num_blocks=[  # [min_num_blocks, max_num_blocks, step]
            [1, 2, 1],
            [3, 5, 1],
            [3, 6, 1],
            [3, 6, 1],
            [3, 8, 1],
            [3, 8, 1],
            [1, 2, 1],
        ],
        expand_ratio=[  # [min_expand_ratio, max_expand_ratio, step]
            [1, 1, 1],
            [4, 6, 1],
            [4, 6, 1],
            [4, 6, 1],
            [4, 6, 1],
            [6, 6, 1],
            [6, 6, 1],
        ],
        num_out_channels=[  # [min_channel, max_channel, step]
            [16, 24, 8],
            [24, 32, 8],
            [32, 40, 8],
            [64, 72, 8],
            [112, 128, 8],
            [192, 216, 8],
            [216, 224, 8],
        ])

    def __init__(
        self,
        conv_cfg: Dict = dict(type='mmrazor.BigNasConv2d'),
        norm_cfg: Dict = dict(type='mmrazor.DynamicBatchNorm2d'),
        fine_grained_mode: bool = False,
    ) -> None:
        super().__init__()

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_list = ['Swish'] * 7
        self.stride_list = [1, 2, 2, 2, 1, 2, 1]
        self.with_se_list = [False, False, True, False, True, True, True]
        self.kernel_size_list = parse_values(self.arch_setting['kernel_size'])
        self.num_blocks_list = parse_values(self.arch_setting['num_blocks'])
        self.expand_ratio_list = \
            parse_values(self.arch_setting['expand_ratio'])
        self.num_channels_list = \
            parse_values(self.arch_setting['num_out_channels'])
        assert len(self.kernel_size_list) == len(self.num_blocks_list) == \
            len(self.expand_ratio_list) == len(self.num_channels_list)

        self.fine_grained_mode = fine_grained_mode
        self.with_attentive_shortcut = True
        self.in_channels = 24

        self.first_out_channels_list = [16]
        self.first_conv = ConvModule(
            in_channels=3,
            out_channels=24,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='Swish'))

        self.layers = []
        for i, (num_blocks, kernel_sizes, expand_ratios, num_channels) in \
                enumerate(zip(self.num_blocks_list, self.kernel_size_list,
                              self.expand_ratio_list,
                              self.num_channels_list)):
            inverted_res_layer = self._make_single_layer(
                out_channels=num_channels,
                num_blocks=num_blocks,
                kernel_sizes=kernel_sizes,
                expand_ratios=expand_ratios,
                act_cfg=self.act_list[i],
                stride=self.stride_list[i],
                use_se=self.with_se_list[i])
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(inverted_res_layer)

        last_expand_channels = 1344
        self.out_channels = 1984
        self.last_out_channels_list = [1792, 1984]
        self.last_expand_ratio_list = [6]

        last_layers = Sequential(
            OrderedDict([('final_expand_layer',
                          ConvModule(
                              in_channels=self.in_channels,
                              out_channels=last_expand_channels,
                              kernel_size=1,
                              padding=0,
                              conv_cfg=self.conv_cfg,
                              norm_cfg=self.norm_cfg,
                              act_cfg=dict(type='Swish'))),
                         ('pool', nn.AdaptiveAvgPool2d((1, 1))),
                         ('feature_mix_layer',
                          ConvModule(
                              in_channels=last_expand_channels,
                              out_channels=self.out_channels,
                              kernel_size=1,
                              padding=0,
                              bias=False,
                              conv_cfg=self.conv_cfg,
                              norm_cfg=None,
                              act_cfg=dict(type='Swish')))]))
        self.add_module('last_conv', last_layers)
        self.layers.append(last_layers)

        self.register_mutables()

    def _make_single_layer(self, out_channels, num_blocks, kernel_sizes,
                           expand_ratios, act_cfg, stride, use_se):
        _layers = []
        for i in range(max(num_blocks)):
            if i >= 1:
                stride = 1
            if use_se:
                se_cfg = dict(
                    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')),
                    ratio=4,
                    conv_cfg=self.conv_cfg)
            else:
                se_cfg = None  # type: ignore

            mb_layer = MBBlock(
                in_channels=self.in_channels,
                out_channels=max(out_channels),
                kernel_size=max(kernel_sizes),
                stride=stride,
                expand_ratio=max(expand_ratios),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act_cfg),
                se_cfg=se_cfg,
                with_attentive_shortcut=self.with_attentive_shortcut)

            _layers.append(mb_layer)
            self.in_channels = max(out_channels)

        dynamic_seq = DynamicSequential(*_layers)
        return dynamic_seq

    def register_mutables(self):
        """Mutate the BigNAS-style MobileNetV3."""
        OneShotMutableChannelUnit._register_channel_container(
            self, MutableChannelContainer)

        self.first_mutable_channels = OneShotMutableChannel(
            alias='backbone.first_channels',
            num_channels=max(self.first_out_channels_list),
            candidate_choices=self.first_out_channels_list)

        mutate_conv_module(
            self.first_conv, mutable_out_channels=self.first_mutable_channels)

        mid_mutable = self.first_mutable_channels
        # mutate the built mobilenet layers
        for i, layer in enumerate(self.layers[:-1]):
            num_blocks = self.num_blocks_list[i]
            kernel_sizes = self.kernel_size_list[i]
            expand_ratios = self.expand_ratio_list[i]
            out_channels = self.num_channels_list[i]

            prefix = 'backbone.layers.' + str(i + 1) + '.'

            mutable_out_channels = OneShotMutableChannel(
                alias=prefix + 'out_channels',
                candidate_choices=out_channels,
                num_channels=max(out_channels))

            if not self.fine_grained_mode:
                mutable_kernel_size = OneShotMutableValue(
                    alias=prefix + 'kernel_size', value_list=kernel_sizes)

                mutable_expand_ratio = OneShotMutableValue(
                    alias=prefix + 'expand_ratio', value_list=expand_ratios)

            mutable_depth = OneShotMutableValue(
                alias=prefix + 'depth', value_list=num_blocks)
            layer.register_mutable_attr('depth', mutable_depth)

            for k in range(max(self.num_blocks_list[i])):

                if self.fine_grained_mode:
                    mutable_kernel_size = OneShotMutableValue(
                        alias=prefix + str(k) + '.kernel_size',
                        value_list=kernel_sizes)

                    mutable_expand_ratio = OneShotMutableValue(
                        alias=prefix + str(k) + '.expand_ratio',
                        value_list=expand_ratios)

                mutate_mobilenet_layer(layer[k], mid_mutable,
                                       mutable_out_channels,
                                       mutable_expand_ratio,
                                       mutable_kernel_size)
                mid_mutable = mutable_out_channels

        self.last_mutable_channels = OneShotMutableChannel(
            alias='backbone.last_channels',
            num_channels=self.out_channels,
            candidate_choices=self.last_out_channels_list)

        last_mutable_expand_value = OneShotMutableValue(
            value_list=self.last_expand_ratio_list,
            default_value=max(self.last_expand_ratio_list))

        derived_expand_channels = mid_mutable * last_mutable_expand_value
        mutate_conv_module(
            self.layers[-1].final_expand_layer,
            mutable_in_channels=mid_mutable,
            mutable_out_channels=derived_expand_channels)
        mutate_conv_module(
            self.layers[-1].feature_mix_layer,
            mutable_in_channels=derived_expand_channels,
            mutable_out_channels=self.last_mutable_channels)

    def forward(self, x):
        x = self.first_conv(x)
        for _, layer in enumerate(self.layers):
            x = layer(x)

        return tuple([x])
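

# A minimal forward-pass sketch for the dynamic backbone above (assumption:
# the maximal architecture is active by default, so the single returned
# feature map has ``out_channels`` channels, pooled to 1x1 by the last conv
# stage). Not executed on import.
def _example_dynamic_mmblock_usage():
    model = DynamicMMBlock()
    feats = model(torch.rand(1, 3, 224, 224))
    assert isinstance(feats, tuple)
    assert feats[0].shape[1] == model.out_channels
    return feats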