Dev mobilenetv3
This commit is contained in:
parent e5c7556d87
commit 03b75789c6
@@ -1,4 +1,5 @@
 from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetv3
 from .resnet import ResNet, ResNetV1d
 from .resnext import ResNeXt
 from .seresnet import SEResNet
@@ -8,5 +9,5 @@ from .shufflenet_v2 import ShuffleNetV2

 __all__ = [
     'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet', 'SEResNeXt',
-    'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2'
+    'ShuffleNetV1', 'ShuffleNetV2', 'MobileNetV2', 'MobileNetv3'
 ]
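Once exported here, the registered backbone can also be built from a config dict. A minimal sketch, assuming the usual mmcls builder helper (build_backbone) and this change installed; the config values are illustrative:

from mmcls.models import build_backbone

# 'MobileNetv3' is the name registered by the decorator in mobilenet_v3.py below.
backbone_cfg = dict(type='MobileNetv3', arch='small', out_indices=(10, ))
backbone = build_backbone(backbone_cfg)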
mmcls/models/backbones/mobilenet_v3.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import logging

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
from ..utils import InvertedResidual
from .base_backbone import BaseBackbone


@BACKBONES.register_module()
class MobileNetv3(BaseBackbone):
    """MobileNetv3 backbone.

    Args:
        arch (str): Architecture of MobileNetV3, from {small, big}.
            Default: small.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (None or Sequence[int]): Output from which stages.
            Default: None, which means output tensors from final stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    # [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],
                  [3, 72, 24, False, 'ReLU', 2],
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'big': [[3, 16, 16, False, 'ReLU', 1],
                [3, 64, 24, False, 'ReLU', 2],
                [3, 72, 24, False, 'ReLU', 1],
                [5, 72, 40, True, 'ReLU', 2],
                [5, 120, 40, True, 'ReLU', 1],
                [5, 120, 40, True, 'ReLU', 1],
                [3, 240, 80, False, 'HSwish', 2],
                [3, 200, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 480, 112, True, 'HSwish', 1],
                [3, 672, 112, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 2],
                [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=None,
                 frozen_stages=-1,
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetv3, self).__init__()
        assert arch in self.arch_settings
        if out_indices is None:
            out_indices = []
        assert isinstance(out_indices, (int, tuple, list))
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert frozen_stages <= len(self.arch_settings[arch])
        if len(out_indices):
            assert max(out_indices) < len(self.arch_settings[arch])
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = 16
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='HSwish'))

        self.layers = self._make_layer()
        self.feat_dim = self.arch_settings[arch][-1][2]

    def _make_layer(self):
        layers = []
        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params
            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=self.in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            self.in_channels = out_channels
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, layer)
            layers.append(layer_name)
        return layers

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 0:
            return x
        elif len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(MobileNetv3, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
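For reference, a minimal usage sketch of the new backbone (illustrative, not part of the commit; it assumes this change is installed and feeds a 224x224 input to the default 'small' arch):

import torch

from mmcls.models.backbones import MobileNetv3

# Return features from the last three stages of the 'small' arch.
model = MobileNetv3(arch='small', out_indices=(8, 9, 10))
model.init_weights()
model.eval()

imgs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(imgs)

# The selected stages all end with 96 channels at 7x7 resolution.
for feat in feats:
    print(feat.shape)  # torch.Size([1, 96, 7, 7])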
@@ -1,4 +1,6 @@
 from .channel_shuffle import channel_shuffle
+from .inverted_residual import InvertedResidual
 from .make_divisible import make_divisible
+from .se_layer import SELayer

-__all__ = ['channel_shuffle', 'make_divisible']
+__all__ = ['channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer']
mmcls/models/utils/inverted_residual.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule

from .se_layer import SELayer


class InvertedResidual(nn.Module):
    """Inverted Residual Block.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for se layer. Default: None, which means no
            se layer.
        with_expand_conv (bool): Use expand conv or not. If set False,
            mid_channels must be the same as in_channels.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidual, self).__init__()
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2]
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.with_se:
            self.se = SELayer(**se_cfg)
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, x):

        def _inner_forward(x):
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
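A short usage sketch for the new InvertedResidual block (illustrative, not part of the diff): with stride 1 and matching in/out channels the residual shortcut is used, and se_cfg attaches an SE gate on the expanded (mid_channels) features, mirroring how mobilenet_v3.py configures it:

import torch

from mmcls.models.utils import InvertedResidual

# 16 -> 16 channels via a 72-channel expansion, 3x3 depthwise conv, SE with ratio 4.
block = InvertedResidual(
    in_channels=16,
    out_channels=16,
    mid_channels=72,
    kernel_size=3,
    stride=1,
    se_cfg=dict(channels=72, ratio=4),
    act_cfg=dict(type='ReLU'))

x = torch.randn(1, 16, 56, 56)
out = block(x)
print(block.with_res_shortcut)  # True: stride == 1 and in_channels == out_channels
print(out.shape)                # torch.Size([1, 16, 56, 56])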
@@ -1,4 +1,6 @@
+import mmcv
 import torch.nn as nn
+from mmcv.cnn import ConvModule


 class SELayer(nn.Module):
@@ -8,24 +10,44 @@ class SELayer(nn.Module):
         channels (int): The input (and output) channels of the SE layer.
         ratio (int): Squeeze ratio in SELayer, the intermediate channel will be
             ``int(channels/ratio)``. Default: 16.
+        conv_cfg (None or dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict. If act_cfg is a sequence of dicts, the first
+            activation layer will be configured by the first dict and the
+            second activation layer will be configured by the second dict.
+            Default: (dict(type='ReLU'), dict(type='Sigmoid'))
     """

-    def __init__(self, channels, ratio=16):
+    def __init__(self,
+                 channels,
+                 ratio=16,
+                 conv_cfg=None,
+                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
         super(SELayer, self).__init__()
+        if isinstance(act_cfg, dict):
+            act_cfg = (act_cfg, act_cfg)
+        assert len(act_cfg) == 2
+        assert mmcv.is_tuple_of(act_cfg, dict)
         self.global_avgpool = nn.AdaptiveAvgPool2d(1)
-        self.conv1 = nn.Conv2d(
-            channels, int(channels / ratio), kernel_size=1, stride=1)
-        self.conv2 = nn.Conv2d(
-            int(channels / ratio), channels, kernel_size=1, stride=1)
-        self.relu = nn.ReLU(inplace=True)
-        self.sigmoid = nn.Sigmoid()
+        self.conv1 = ConvModule(
+            in_channels=channels,
+            out_channels=int(channels / ratio),
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[0])
+        self.conv2 = ConvModule(
+            in_channels=int(channels / ratio),
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[1])

     def forward(self, x):
         out = self.global_avgpool(x)
-
         out = self.conv1(out)
-        out = self.relu(out)
-
         out = self.conv2(out)
-        out = self.sigmoid(out)
         return x * out
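A quick sketch of the reworked SELayer (illustrative only): the ConvModule-based version takes both gate activations from act_cfg instead of hard-coded ReLU/Sigmoid, so the HSigmoid gate used by the new backbone is just a config change:

import torch

from mmcls.models.utils import SELayer

# Default gate: ReLU then Sigmoid, squeezing 64 channels to 4 internally (ratio=16).
se = SELayer(channels=64, ratio=16)

# MobileNetV3-style gate: ReLU then HSigmoid, as passed in by the new backbone.
se_hsig = SELayer(
    channels=64,
    ratio=4,
    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))

x = torch.randn(1, 64, 28, 28)
print(se(x).shape)       # torch.Size([1, 64, 28, 28])
print(se_hsig(x).shape)  # torch.Size([1, 64, 28, 28])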
tests/test_backbones/test_mobilenet_v3.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.backbones import MobileNetv3
from mmcls.models.utils import InvertedResidual


def is_norm(modules):
    """Check if is one of the norms."""
    if isinstance(modules, (GroupNorm, _BatchNorm)):
        return True
    return False


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_mobilenetv3_backbone():
    with pytest.raises(TypeError):
        # pretrained must be a string path
        model = MobileNetv3()
        model.init_weights(pretrained=0)

    with pytest.raises(AssertionError):
        # arch must be in [small, big]
        MobileNetv3(arch='others')

    with pytest.raises(AssertionError):
        # frozen_stages must be less than 12 when arch is small
        MobileNetv3(arch='small', frozen_stages=12)

    with pytest.raises(AssertionError):
        # frozen_stages must be less than 16 when arch is big
        MobileNetv3(arch='big', frozen_stages=16)

    with pytest.raises(AssertionError):
        # max out_indices must be less than 11 when arch is small
        MobileNetv3(arch='small', out_indices=(11))

    with pytest.raises(AssertionError):
        # max out_indices must be less than 15 when arch is big
        MobileNetv3(arch='big', out_indices=(15))

    # Test MobileNetv3
    model = MobileNetv3()
    model.init_weights()
    model.train()

    # Test MobileNetv3 with first stage frozen
    frozen_stages = 1
    model = MobileNetv3(frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    for param in model.conv1.parameters():
        assert param.requires_grad is False
    for i in range(1, frozen_stages + 1):
        layer = getattr(model, f'layer{i}')
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test MobileNetv3 with norm eval
    model = MobileNetv3(norm_eval=True)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), False)

    # Test MobileNetv3 forward with small arch
    model = MobileNetv3(out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 11
    assert feat[0].shape == torch.Size([1, 16, 56, 56])
    assert feat[1].shape == torch.Size([1, 24, 28, 28])
    assert feat[2].shape == torch.Size([1, 24, 28, 28])
    assert feat[3].shape == torch.Size([1, 40, 14, 14])
    assert feat[4].shape == torch.Size([1, 40, 14, 14])
    assert feat[5].shape == torch.Size([1, 40, 14, 14])
    assert feat[6].shape == torch.Size([1, 48, 14, 14])
    assert feat[7].shape == torch.Size([1, 48, 14, 14])
    assert feat[8].shape == torch.Size([1, 96, 7, 7])
    assert feat[9].shape == torch.Size([1, 96, 7, 7])
    assert feat[10].shape == torch.Size([1, 96, 7, 7])

    # Test MobileNetv3 forward with small arch and GroupNorm
    model = MobileNetv3(
        out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, GroupNorm)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 11
    assert feat[0].shape == torch.Size([1, 16, 56, 56])
    assert feat[1].shape == torch.Size([1, 24, 28, 28])
    assert feat[2].shape == torch.Size([1, 24, 28, 28])
    assert feat[3].shape == torch.Size([1, 40, 14, 14])
    assert feat[4].shape == torch.Size([1, 40, 14, 14])
    assert feat[5].shape == torch.Size([1, 40, 14, 14])
    assert feat[6].shape == torch.Size([1, 48, 14, 14])
    assert feat[7].shape == torch.Size([1, 48, 14, 14])
    assert feat[8].shape == torch.Size([1, 96, 7, 7])
    assert feat[9].shape == torch.Size([1, 96, 7, 7])
    assert feat[10].shape == torch.Size([1, 96, 7, 7])

    # Test MobileNetv3 forward with big arch
    model = MobileNetv3(
        arch='big',
        out_indices=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert len(feat) == 15
    assert feat[0].shape == torch.Size([1, 16, 112, 112])
    assert feat[1].shape == torch.Size([1, 24, 56, 56])
    assert feat[2].shape == torch.Size([1, 24, 56, 56])
    assert feat[3].shape == torch.Size([1, 40, 28, 28])
    assert feat[4].shape == torch.Size([1, 40, 28, 28])
    assert feat[5].shape == torch.Size([1, 40, 28, 28])
    assert feat[6].shape == torch.Size([1, 80, 14, 14])
    assert feat[7].shape == torch.Size([1, 80, 14, 14])
    assert feat[8].shape == torch.Size([1, 80, 14, 14])
    assert feat[9].shape == torch.Size([1, 80, 14, 14])
    assert feat[10].shape == torch.Size([1, 112, 14, 14])
    assert feat[11].shape == torch.Size([1, 112, 14, 14])
    assert feat[12].shape == torch.Size([1, 160, 14, 14])
    assert feat[13].shape == torch.Size([1, 160, 7, 7])
    assert feat[14].shape == torch.Size([1, 160, 7, 7])

    # Test MobileNetv3 forward with big arch and a single out index
    model = MobileNetv3(arch='big', out_indices=(0))
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size([1, 16, 112, 112])

    # Test MobileNetv3 with checkpoint forward
    model = MobileNetv3(with_cp=True)
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            assert m.with_cp
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert feat.shape == torch.Size([1, 96, 7, 7])
@@ -1,7 +1,17 @@
 import pytest
 import torch
+from torch.nn.modules import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm

-from mmcls.models.utils import channel_shuffle, make_divisible
+from mmcls.models.utils import (InvertedResidual, SELayer, channel_shuffle,
+                                make_divisible)
+
+
+def is_norm(modules):
+    """Check if is one of the norms."""
+    if isinstance(modules, (GroupNorm, _BatchNorm)):
+        return True
+    return False


 def test_make_divisible():
@@ -35,3 +45,72 @@ def test_channel_shuffle():
         for i in range(height):
             for j in range(width):
                 assert x[b, c, i, j] == out[b, c_out, i, j]
+
+
+def test_inverted_residual():
+
+    with pytest.raises(AssertionError):
+        # stride must be in [1, 2]
+        InvertedResidual(16, 16, 32, stride=3)
+
+    with pytest.raises(AssertionError):
+        # se_cfg must be None or dict
+        InvertedResidual(16, 16, 32, se_cfg=list())
+
+    with pytest.raises(AssertionError):
+        # in_channels and out_channels must be the same if
+        # with_expand_conv is False
+        InvertedResidual(16, 16, 32, with_expand_conv=False)
+
+    # Test InvertedResidual forward, stride=1
+    block = InvertedResidual(16, 16, 32, stride=1)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert getattr(block, 'se', None) is None
+    assert block.with_res_shortcut
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual forward, stride=2
+    block = InvertedResidual(16, 16, 32, stride=2)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert not block.with_res_shortcut
+    assert x_out.shape == torch.Size((1, 16, 28, 28))
+
+    # Test InvertedResidual forward with se layer
+    se_cfg = dict(channels=32)
+    block = InvertedResidual(16, 16, 32, stride=1, se_cfg=se_cfg)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert isinstance(block.se, SELayer)
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual forward, with_expand_conv=False
+    block = InvertedResidual(32, 16, 32, with_expand_conv=False)
+    x = torch.randn(1, 32, 56, 56)
+    x_out = block(x)
+    assert getattr(block, 'expand_conv', None) is None
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual forward with GroupNorm
+    block = InvertedResidual(
+        16, 16, 32, norm_cfg=dict(type='GN', num_groups=2))
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    for m in block.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual forward with HSigmoid
+    block = InvertedResidual(16, 16, 32, act_cfg=dict(type='HSigmoid'))
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual forward with checkpoint
+    block = InvertedResidual(16, 16, 32, with_cp=True)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert block.with_cp
+    assert x_out.shape == torch.Size((1, 16, 56, 56))