Merge branch 'dev_mobilenetv2' into 'master'
add mobilenetv2 See merge request open-mmlab/mmclassification!4pull/2/head
commit
3a5b25162e
|
@ -1,13 +1,10 @@
|
||||||
|
from .mobilenet_v2 import MobileNetV2
|
||||||
from .resnet import ResNet, ResNetV1d
|
from .resnet import ResNet, ResNetV1d
|
||||||
from .resnext import ResNeXt
|
from .resnext import ResNeXt
|
||||||
from .shufflenet_v1 import ShuffleNetv1
|
from .shufflenet_v1 import ShuffleNetV1
|
||||||
from .shufflenet_v2 import ShuffleNetv2
|
from .shufflenet_v2 import ShuffleNetV2
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ResNet',
|
'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'ShuffleNetV1',
|
||||||
'ResNeXt',
|
'ShuffleNetV2', 'MobileNetV2'
|
||||||
'ResNetV1d',
|
|
||||||
'ResNetV1d',
|
|
||||||
'ShuffleNetv1',
|
|
||||||
'ShuffleNetv2',
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,265 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint as cp
|
||||||
|
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.models.utils import make_divisible
|
||||||
|
from ..builder import BACKBONES
|
||||||
|
from .base_backbone import BaseBackbone
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Module):
|
||||||
|
"""InvertedResidual block for MobileNetV2.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inplanes (int): The input channels of the InvertedResidual block.
|
||||||
|
planes (int): The output channels of the InvertedResidual block.
|
||||||
|
stride (int): Stride of the middle (first) 3x3 convolution.
|
||||||
|
expand_ratio (int): adjusts number of channels of the hidden layer
|
||||||
|
in InvertedResidual by this amount.
|
||||||
|
conv_cfg (dict): Config dict for convolution layer.
|
||||||
|
Default: None, which means using conv2d.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
|
Default: dict(type='BN').
|
||||||
|
act_cfg (dict): Config dict for activation layer.
|
||||||
|
Default: dict(type='ReLU6').
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Default: False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: The output tensor
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
inplanes,
|
||||||
|
planes,
|
||||||
|
stride,
|
||||||
|
expand_ratio,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU6'),
|
||||||
|
with_cp=False):
|
||||||
|
super(InvertedResidual, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
assert stride in [1, 2], f'stride must in [1, 2]. ' \
|
||||||
|
f'But received {stride}.'
|
||||||
|
self.with_cp = with_cp
|
||||||
|
self.use_res_connect = self.stride == 1 and inplanes == planes
|
||||||
|
hidden_dim = int(round(inplanes * expand_ratio))
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
if expand_ratio != 1:
|
||||||
|
layers.append(
|
||||||
|
ConvModule(
|
||||||
|
in_channels=inplanes,
|
||||||
|
out_channels=hidden_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg))
|
||||||
|
layers.extend([
|
||||||
|
ConvModule(
|
||||||
|
in_channels=hidden_dim,
|
||||||
|
out_channels=hidden_dim,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
padding=1,
|
||||||
|
groups=hidden_dim,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=act_cfg),
|
||||||
|
ConvModule(
|
||||||
|
in_channels=hidden_dim,
|
||||||
|
out_channels=planes,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
conv_cfg=conv_cfg,
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
act_cfg=None)
|
||||||
|
])
|
||||||
|
self.conv = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
def _inner_forward(x):
|
||||||
|
if self.use_res_connect:
|
||||||
|
return x + self.conv(x)
|
||||||
|
else:
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
if self.with_cp and x.requires_grad:
|
||||||
|
out = cp.checkpoint(_inner_forward, x)
|
||||||
|
else:
|
||||||
|
out = _inner_forward(x)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
|
||||||
|
class MobileNetV2(BaseBackbone):
|
||||||
|
"""MobileNetV2 backbone.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
widen_factor (float): Width multiplier, multiply number of
|
||||||
|
channels in each layer by this amount. Default: 1.0.
|
||||||
|
out_indices (None or Sequence[int]): Output from which stages.
|
||||||
|
Default: None
|
||||||
|
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||||
|
Default: -1, which means not freezing any parameters.
|
||||||
|
conv_cfg (dict): Config dict for convolution layer.
|
||||||
|
Default: None, which means using conv2d.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
|
Default: dict(type='BN').
|
||||||
|
act_cfg (dict): Config dict for activation layer.
|
||||||
|
Default: dict(type='ReLU6').
|
||||||
|
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||||
|
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||||
|
and its variants only. Default: False.
|
||||||
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||||
|
memory while slowing down the training speed. Default: False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Parameters to build layers. 4 parameters are needed to construct a
|
||||||
|
# layer, from left to right: expand_ratio, channel, num_blocks, stride.
|
||||||
|
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
|
||||||
|
[6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
|
||||||
|
[6, 320, 1, 1]]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
widen_factor=1.,
|
||||||
|
out_indices=None,
|
||||||
|
frozen_stages=-1,
|
||||||
|
conv_cfg=None,
|
||||||
|
norm_cfg=dict(type='BN'),
|
||||||
|
act_cfg=dict(type='ReLU6'),
|
||||||
|
norm_eval=False,
|
||||||
|
with_cp=False):
|
||||||
|
super(MobileNetV2, self).__init__()
|
||||||
|
self.widen_factor = widen_factor
|
||||||
|
self.out_indices = out_indices
|
||||||
|
if out_indices is not None:
|
||||||
|
assert max(out_indices) < len(self.arch_settings)
|
||||||
|
self.frozen_stages = frozen_stages
|
||||||
|
assert frozen_stages < len(self.arch_settings)
|
||||||
|
self.conv_cfg = conv_cfg
|
||||||
|
self.norm_cfg = norm_cfg
|
||||||
|
self.act_cfg = act_cfg
|
||||||
|
self.norm_eval = norm_eval
|
||||||
|
self.with_cp = with_cp
|
||||||
|
|
||||||
|
self.inplanes = make_divisible(32 * widen_factor, 8)
|
||||||
|
|
||||||
|
self.conv1 = ConvModule(
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=self.inplanes,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
|
||||||
|
self.inverted_res_layers = []
|
||||||
|
|
||||||
|
for i, layer_cfg in enumerate(self.arch_settings):
|
||||||
|
expand_ratio, channel, num_blocks, stride = layer_cfg
|
||||||
|
planes = make_divisible(channel * widen_factor, 8)
|
||||||
|
inverted_res_layer = self.make_layer(
|
||||||
|
planes=planes,
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
stride=stride,
|
||||||
|
expand_ratio=expand_ratio)
|
||||||
|
layer_name = f'layer{i + 1}'
|
||||||
|
self.add_module(layer_name, inverted_res_layer)
|
||||||
|
self.inverted_res_layers.append(layer_name)
|
||||||
|
|
||||||
|
if widen_factor > 1.0:
|
||||||
|
self.out_channel = int(1280 * widen_factor)
|
||||||
|
else:
|
||||||
|
self.out_channel = 1280
|
||||||
|
|
||||||
|
self.conv2 = ConvModule(
|
||||||
|
in_channels=self.inplanes,
|
||||||
|
out_channels=self.out_channel,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg)
|
||||||
|
|
||||||
|
def make_layer(self, planes, num_blocks, stride, expand_ratio):
|
||||||
|
""" Stack InvertedResidual blocks to build a layer for MobileNetV2.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
planes (int): planes of block.
|
||||||
|
num_blocks (int): number of blocks.
|
||||||
|
stride (int): stride of the first block. Default: 1
|
||||||
|
expand_ratio (int): Expand the number of channels of the
|
||||||
|
hidden layer in InvertedResidual by this ratio. Default: 6.
|
||||||
|
"""
|
||||||
|
layers = []
|
||||||
|
for i in range(num_blocks):
|
||||||
|
if i >= 1:
|
||||||
|
stride = 1
|
||||||
|
layers.append(
|
||||||
|
InvertedResidual(
|
||||||
|
self.inplanes,
|
||||||
|
planes,
|
||||||
|
stride,
|
||||||
|
expand_ratio=expand_ratio,
|
||||||
|
conv_cfg=self.conv_cfg,
|
||||||
|
norm_cfg=self.norm_cfg,
|
||||||
|
act_cfg=self.act_cfg,
|
||||||
|
with_cp=self.with_cp))
|
||||||
|
self.inplanes = planes
|
||||||
|
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def init_weights(self, pretrained=None):
|
||||||
|
if pretrained is None:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
kaiming_init(m)
|
||||||
|
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||||
|
constant_init(m, 1)
|
||||||
|
else:
|
||||||
|
raise TypeError('pretrained must be a str or None')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for i, layer_name in enumerate(self.inverted_res_layers):
|
||||||
|
inverted_res_layer = getattr(self, layer_name)
|
||||||
|
x = inverted_res_layer(x)
|
||||||
|
if self.out_indices is not None and i in self.out_indices:
|
||||||
|
outs.append(x)
|
||||||
|
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
if self.out_indices is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return tuple(outs)
|
||||||
|
|
||||||
|
def _freeze_stages(self):
|
||||||
|
if self.frozen_stages >= 0:
|
||||||
|
for param in self.conv1.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
for i in range(1, self.frozen_stages + 1):
|
||||||
|
layer = getattr(self, f'layer{i}')
|
||||||
|
layer.eval()
|
||||||
|
for param in layer.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def train(self, mode=True):
|
||||||
|
super(MobileNetV2, self).train(mode)
|
||||||
|
self._freeze_stages()
|
||||||
|
if mode and self.norm_eval:
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, _BatchNorm):
|
||||||
|
m.eval()
|
|
@ -6,6 +6,7 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.utils import channel_shuffle, make_divisible
|
from mmcls.models.utils import channel_shuffle, make_divisible
|
||||||
|
from ..builder import BACKBONES
|
||||||
from .base_backbone import BaseBackbone
|
from .base_backbone import BaseBackbone
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,8 +140,9 @@ class ShuffleUnit(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ShuffleNetv1(BaseBackbone):
|
@BACKBONES.register_module()
|
||||||
"""ShuffleNetv1 backbone.
|
class ShuffleNetV1(BaseBackbone):
|
||||||
|
"""ShuffleNetV1 backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
groups (int, optional): The number of groups to be used in grouped 1x1
|
groups (int, optional): The number of groups to be used in grouped 1x1
|
||||||
|
@ -174,7 +176,7 @@ class ShuffleNetv1(BaseBackbone):
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(ShuffleNetv1, self).__init__()
|
super(ShuffleNetV1, self).__init__()
|
||||||
self.stage_blocks = [3, 7, 3]
|
self.stage_blocks = [3, 7, 3]
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
|
|
||||||
|
@ -294,7 +296,7 @@ class ShuffleNetv1(BaseBackbone):
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(ShuffleNetv1, self).train(mode)
|
super(ShuffleNetV1, self).train(mode)
|
||||||
self._freeze_stages()
|
self._freeze_stages()
|
||||||
if mode and self.norm_eval:
|
if mode and self.norm_eval:
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
|
|
|
@ -5,6 +5,7 @@ from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.utils import channel_shuffle
|
from mmcls.models.utils import channel_shuffle
|
||||||
|
from ..builder import BACKBONES
|
||||||
from .base_backbone import BaseBackbone
|
from .base_backbone import BaseBackbone
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,8 +126,9 @@ class InvertedResidual(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ShuffleNetv2(BaseBackbone):
|
@BACKBONES.register_module()
|
||||||
"""ShuffleNetv2 backbone.
|
class ShuffleNetV2(BaseBackbone):
|
||||||
|
"""ShuffleNetV2 backbone.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
groups (int): The number of groups to be used in grouped 1x1
|
groups (int): The number of groups to be used in grouped 1x1
|
||||||
|
@ -160,7 +162,7 @@ class ShuffleNetv2(BaseBackbone):
|
||||||
act_cfg=dict(type='ReLU'),
|
act_cfg=dict(type='ReLU'),
|
||||||
norm_eval=False,
|
norm_eval=False,
|
||||||
with_cp=False):
|
with_cp=False):
|
||||||
super(ShuffleNetv2, self).__init__()
|
super(ShuffleNetV2, self).__init__()
|
||||||
self.stage_blocks = [4, 8, 4]
|
self.stage_blocks = [4, 8, 4]
|
||||||
self.groups = groups
|
self.groups = groups
|
||||||
self.out_indices = out_indices
|
self.out_indices = out_indices
|
||||||
|
@ -273,7 +275,7 @@ class ShuffleNetv2(BaseBackbone):
|
||||||
return tuple(outs)
|
return tuple(outs)
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
super(ShuffleNetv2, self).train(mode)
|
super(ShuffleNetV2, self).train(mode)
|
||||||
self._freeze_stages()
|
self._freeze_stages()
|
||||||
if mode and self.norm_eval:
|
if mode and self.norm_eval:
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
|
|
|
@ -0,0 +1,255 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules import GroupNorm
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.models.backbones import MobileNetV2
|
||||||
|
from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
|
||||||
|
|
||||||
|
|
||||||
|
def is_block(modules):
|
||||||
|
"""Check if is ResNet building block."""
|
||||||
|
if isinstance(modules, (InvertedResidual, )):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_norm(modules):
|
||||||
|
"""Check if is one of the norms."""
|
||||||
|
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_norm_state(modules, train_state):
|
||||||
|
"""Check if norm layer is in correct train state."""
|
||||||
|
for mod in modules:
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
if mod.training != train_state:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_mobilenetv2_invertedresidual():
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# stride must be in [1, 2]
|
||||||
|
InvertedResidual(16, 24, stride=3, expand_ratio=6)
|
||||||
|
|
||||||
|
# Test InvertedResidual with checkpoint forward, stride=1
|
||||||
|
block = InvertedResidual(16, 24, stride=1, expand_ratio=6)
|
||||||
|
x = torch.randn(1, 16, 56, 56)
|
||||||
|
x_out = block(x)
|
||||||
|
assert x_out.shape == torch.Size((1, 24, 56, 56))
|
||||||
|
|
||||||
|
# Test InvertedResidual with expand_ratio=1
|
||||||
|
block = InvertedResidual(16, 16, stride=1, expand_ratio=1)
|
||||||
|
assert len(block.conv) == 2
|
||||||
|
|
||||||
|
# Test InvertedResidual with use_res_connect
|
||||||
|
block = InvertedResidual(16, 16, stride=1, expand_ratio=6)
|
||||||
|
x = torch.randn(1, 16, 56, 56)
|
||||||
|
x_out = block(x)
|
||||||
|
assert block.use_res_connect is True
|
||||||
|
assert x_out.shape == torch.Size((1, 16, 56, 56))
|
||||||
|
|
||||||
|
# Test InvertedResidual with checkpoint forward, stride=2
|
||||||
|
block = InvertedResidual(16, 24, stride=2, expand_ratio=6)
|
||||||
|
x = torch.randn(1, 16, 56, 56)
|
||||||
|
x_out = block(x)
|
||||||
|
assert x_out.shape == torch.Size((1, 24, 28, 28))
|
||||||
|
|
||||||
|
# Test InvertedResidual with checkpoint forward
|
||||||
|
block = InvertedResidual(16, 24, stride=1, expand_ratio=6, with_cp=True)
|
||||||
|
assert block.with_cp
|
||||||
|
x = torch.randn(1, 16, 56, 56)
|
||||||
|
x_out = block(x)
|
||||||
|
assert x_out.shape == torch.Size((1, 24, 56, 56))
|
||||||
|
|
||||||
|
# Test InvertedResidual with act_cfg=dict(type='ReLU')
|
||||||
|
block = InvertedResidual(
|
||||||
|
16, 24, stride=1, expand_ratio=6, act_cfg=dict(type='ReLU'))
|
||||||
|
x = torch.randn(1, 16, 56, 56)
|
||||||
|
x_out = block(x)
|
||||||
|
assert x_out.shape == torch.Size((1, 24, 56, 56))
|
||||||
|
|
||||||
|
|
||||||
|
def test_mobilenetv2_backbone():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
# pretrained must be a string path
|
||||||
|
model = MobileNetV2()
|
||||||
|
model.init_weights(pretrained=0)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# frozen_stages must less than 7
|
||||||
|
MobileNetV2(frozen_stages=8)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
# the max value in out_indices must less than 7
|
||||||
|
MobileNetV2(out_indices=[8])
|
||||||
|
|
||||||
|
# Test MobileNetV2 with first stage frozen
|
||||||
|
frozen_stages = 1
|
||||||
|
model = MobileNetV2(frozen_stages=frozen_stages)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
for mod in model.conv1.modules():
|
||||||
|
for param in mod.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
for i in range(1, frozen_stages + 1):
|
||||||
|
layer = getattr(model, f'layer{i}')
|
||||||
|
for mod in layer.modules():
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
assert mod.training is False
|
||||||
|
for param in layer.parameters():
|
||||||
|
assert param.requires_grad is False
|
||||||
|
|
||||||
|
# Test MobileNetV2 with norm_eval=True
|
||||||
|
model = MobileNetV2(norm_eval=True)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
|
# Test MobileNetV2 forward with widen_factor=1.0
|
||||||
|
model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 24, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 64, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 320, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 forward with widen_factor=0.5
|
||||||
|
model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 8, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 16, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 16, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 32, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 48, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 80, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 forward with widen_factor=2.0
|
||||||
|
model = MobileNetV2(widen_factor=2.0, out_indices=None)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat.shape == torch.Size((1, 2560, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 forward with out_indices=None
|
||||||
|
model = MobileNetV2(widen_factor=1.0, out_indices=None)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert feat.shape == torch.Size((1, 1280, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 forward with dict(type='ReLU')
|
||||||
|
model = MobileNetV2(
|
||||||
|
widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 24, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 64, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 320, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 with GroupNorm forward
|
||||||
|
model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
|
||||||
|
for m in model.modules():
|
||||||
|
if is_norm(m):
|
||||||
|
assert isinstance(m, _BatchNorm)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 24, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 64, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 320, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 with BatchNorm forward
|
||||||
|
model = MobileNetV2(
|
||||||
|
widen_factor=1.0,
|
||||||
|
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
|
||||||
|
out_indices=range(0, 7))
|
||||||
|
for m in model.modules():
|
||||||
|
if is_norm(m):
|
||||||
|
assert isinstance(m, GroupNorm)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 24, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 64, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 320, 7, 7))
|
||||||
|
|
||||||
|
# Test MobileNetV2 with layers 1, 3, 5 out forward
|
||||||
|
model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 3
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[2].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
|
||||||
|
# Test MobileNetV2 with checkpoint forward
|
||||||
|
model = MobileNetV2(
|
||||||
|
widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
|
||||||
|
for m in model.modules():
|
||||||
|
if is_block(m):
|
||||||
|
assert m.with_cp
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
imgs = torch.randn(1, 3, 224, 224)
|
||||||
|
feat = model(imgs)
|
||||||
|
assert len(feat) == 7
|
||||||
|
assert feat[0].shape == torch.Size((1, 16, 112, 112))
|
||||||
|
assert feat[1].shape == torch.Size((1, 24, 56, 56))
|
||||||
|
assert feat[2].shape == torch.Size((1, 32, 28, 28))
|
||||||
|
assert feat[3].shape == torch.Size((1, 64, 14, 14))
|
||||||
|
assert feat[4].shape == torch.Size((1, 96, 14, 14))
|
||||||
|
assert feat[5].shape == torch.Size((1, 160, 7, 7))
|
||||||
|
assert feat[6].shape == torch.Size((1, 320, 7, 7))
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch.nn.modules import GroupNorm
|
from torch.nn.modules import GroupNorm
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.backbones import ShuffleNetv1
|
from mmcls.models.backbones import ShuffleNetV1
|
||||||
from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
|
from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,30 +66,30 @@ def test_shufflenetv1_backbone():
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# frozen_stages must be in range(-1, 4)
|
# frozen_stages must be in range(-1, 4)
|
||||||
ShuffleNetv1(frozen_stages=10)
|
ShuffleNetV1(frozen_stages=10)
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# the item in out_indices must be in range(0, 4)
|
# the item in out_indices must be in range(0, 4)
|
||||||
ShuffleNetv1(out_indices=[5])
|
ShuffleNetV1(out_indices=[5])
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# groups must be in [1, 2, 3, 4, 8]
|
# groups must be in [1, 2, 3, 4, 8]
|
||||||
ShuffleNetv1(groups=10)
|
ShuffleNetV1(groups=10)
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# pretrained must be str or None
|
# pretrained must be str or None
|
||||||
model = ShuffleNetv1()
|
model = ShuffleNetV1()
|
||||||
model.init_weights(pretrained=1)
|
model.init_weights(pretrained=1)
|
||||||
|
|
||||||
# Test ShuffleNetv1 norm state
|
# Test ShuffleNetV1 norm state
|
||||||
model = ShuffleNetv1()
|
model = ShuffleNetV1()
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
assert check_norm_state(model.modules(), True)
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
# Test ShuffleNetv1 with first stage frozen
|
# Test ShuffleNetV1 with first stage frozen
|
||||||
frozen_stages = 1
|
frozen_stages = 1
|
||||||
model = ShuffleNetv1(frozen_stages=frozen_stages)
|
model = ShuffleNetV1(frozen_stages=frozen_stages)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
for param in model.conv1.parameters():
|
for param in model.conv1.parameters():
|
||||||
|
@ -102,8 +102,8 @@ def test_shufflenetv1_backbone():
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
assert param.requires_grad is False
|
assert param.requires_grad is False
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with groups=1
|
# Test ShuffleNetV1 forward with groups=1
|
||||||
model = ShuffleNetv1(groups=1)
|
model = ShuffleNetV1(groups=1)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -118,8 +118,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 288, 14, 14))
|
assert feat[1].shape == torch.Size((1, 288, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 576, 7, 7))
|
assert feat[2].shape == torch.Size((1, 576, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with groups=2
|
# Test ShuffleNetV1 forward with groups=2
|
||||||
model = ShuffleNetv1(groups=2)
|
model = ShuffleNetV1(groups=2)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -134,8 +134,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 400, 14, 14))
|
assert feat[1].shape == torch.Size((1, 400, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 800, 7, 7))
|
assert feat[2].shape == torch.Size((1, 800, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with groups=3
|
# Test ShuffleNetV1 forward with groups=3
|
||||||
model = ShuffleNetv1(groups=3)
|
model = ShuffleNetV1(groups=3)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -150,8 +150,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 480, 14, 14))
|
assert feat[1].shape == torch.Size((1, 480, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 960, 7, 7))
|
assert feat[2].shape == torch.Size((1, 960, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with groups=4
|
# Test ShuffleNetV1 forward with groups=4
|
||||||
model = ShuffleNetv1(groups=4)
|
model = ShuffleNetV1(groups=4)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -166,8 +166,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 544, 14, 14))
|
assert feat[1].shape == torch.Size((1, 544, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 1088, 7, 7))
|
assert feat[2].shape == torch.Size((1, 1088, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with groups=8
|
# Test ShuffleNetV1 forward with groups=8
|
||||||
model = ShuffleNetv1(groups=8)
|
model = ShuffleNetV1(groups=8)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -182,8 +182,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 768, 14, 14))
|
assert feat[1].shape == torch.Size((1, 768, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 1536, 7, 7))
|
assert feat[2].shape == torch.Size((1, 1536, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with GroupNorm forward
|
# Test ShuffleNetV1 forward with GroupNorm forward
|
||||||
model = ShuffleNetv1(
|
model = ShuffleNetV1(
|
||||||
groups=3, norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
|
groups=3, norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -199,8 +199,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 480, 14, 14))
|
assert feat[1].shape == torch.Size((1, 480, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 960, 7, 7))
|
assert feat[2].shape == torch.Size((1, 960, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with layers 1, 2 forward
|
# Test ShuffleNetV1 forward with layers 1, 2 forward
|
||||||
model = ShuffleNetv1(groups=3, out_indices=(1, 2))
|
model = ShuffleNetV1(groups=3, out_indices=(1, 2))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -214,8 +214,8 @@ def test_shufflenetv1_backbone():
|
||||||
assert feat[0].shape == torch.Size((1, 480, 14, 14))
|
assert feat[0].shape == torch.Size((1, 480, 14, 14))
|
||||||
assert feat[1].shape == torch.Size((1, 960, 7, 7))
|
assert feat[1].shape == torch.Size((1, 960, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with layers 2 forward
|
# Test ShuffleNetV1 forward with layers 2 forward
|
||||||
model = ShuffleNetv1(groups=3, out_indices=(2, ))
|
model = ShuffleNetV1(groups=3, out_indices=(2, ))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -228,14 +228,14 @@ def test_shufflenetv1_backbone():
|
||||||
assert isinstance(feat, torch.Tensor)
|
assert isinstance(feat, torch.Tensor)
|
||||||
assert feat.shape == torch.Size((1, 960, 7, 7))
|
assert feat.shape == torch.Size((1, 960, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv1 forward with checkpoint forward
|
# Test ShuffleNetV1 forward with checkpoint forward
|
||||||
model = ShuffleNetv1(groups=3, with_cp=True)
|
model = ShuffleNetV1(groups=3, with_cp=True)
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if is_block(m):
|
if is_block(m):
|
||||||
assert m.with_cp
|
assert m.with_cp
|
||||||
|
|
||||||
# Test ShuffleNetv1 with norm_eval
|
# Test ShuffleNetV1 with norm_eval
|
||||||
model = ShuffleNetv1(norm_eval=True)
|
model = ShuffleNetV1(norm_eval=True)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
||||||
from torch.nn.modules import GroupNorm
|
from torch.nn.modules import GroupNorm
|
||||||
from torch.nn.modules.batchnorm import _BatchNorm
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
from mmcls.models.backbones import ShuffleNetv2
|
from mmcls.models.backbones import ShuffleNetV2
|
||||||
from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
|
from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,26 +59,26 @@ def test_shufflenetv2_backbone():
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
# groups must be in 0.5, 1.0, 1.5, 2.0]
|
# groups must be in 0.5, 1.0, 1.5, 2.0]
|
||||||
ShuffleNetv2(widen_factor=3.0)
|
ShuffleNetV2(widen_factor=3.0)
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
# frozen_stages must be in [0, 1, 2]
|
# frozen_stages must be in [0, 1, 2]
|
||||||
ShuffleNetv2(widen_factor=3.0, frozen_stages=3)
|
ShuffleNetV2(widen_factor=3.0, frozen_stages=3)
|
||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
# pretrained must be str or None
|
# pretrained must be str or None
|
||||||
model = ShuffleNetv2()
|
model = ShuffleNetV2()
|
||||||
model.init_weights(pretrained=1)
|
model.init_weights(pretrained=1)
|
||||||
|
|
||||||
# Test ShuffleNetv2 norm state
|
# Test ShuffleNetV2 norm state
|
||||||
model = ShuffleNetv2()
|
model = ShuffleNetV2()
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
assert check_norm_state(model.modules(), True)
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
# Test ShuffleNetv2 with first stage frozen
|
# Test ShuffleNetV2 with first stage frozen
|
||||||
frozen_stages = 1
|
frozen_stages = 1
|
||||||
model = ShuffleNetv2(frozen_stages=frozen_stages)
|
model = ShuffleNetV2(frozen_stages=frozen_stages)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
for param in model.conv1.parameters():
|
for param in model.conv1.parameters():
|
||||||
|
@ -91,15 +91,15 @@ def test_shufflenetv2_backbone():
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
assert param.requires_grad is False
|
assert param.requires_grad is False
|
||||||
|
|
||||||
# Test ShuffleNetv2 with norm_eval
|
# Test ShuffleNetV2 with norm_eval
|
||||||
model = ShuffleNetv2(norm_eval=True)
|
model = ShuffleNetV2(norm_eval=True)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
assert check_norm_state(model.modules(), False)
|
assert check_norm_state(model.modules(), False)
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with widen_factor=0.5
|
# Test ShuffleNetV2 forward with widen_factor=0.5
|
||||||
model = ShuffleNetv2(widen_factor=0.5)
|
model = ShuffleNetV2(widen_factor=0.5)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -114,8 +114,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 96, 14, 14))
|
assert feat[1].shape == torch.Size((1, 96, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 192, 7, 7))
|
assert feat[2].shape == torch.Size((1, 192, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with widen_factor=1.0
|
# Test ShuffleNetV2 forward with widen_factor=1.0
|
||||||
model = ShuffleNetv2(widen_factor=1.0)
|
model = ShuffleNetV2(widen_factor=1.0)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -130,8 +130,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 232, 14, 14))
|
assert feat[1].shape == torch.Size((1, 232, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 464, 7, 7))
|
assert feat[2].shape == torch.Size((1, 464, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with widen_factor=1.5
|
# Test ShuffleNetV2 forward with widen_factor=1.5
|
||||||
model = ShuffleNetv2(widen_factor=1.5)
|
model = ShuffleNetV2(widen_factor=1.5)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -146,8 +146,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 352, 14, 14))
|
assert feat[1].shape == torch.Size((1, 352, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 704, 7, 7))
|
assert feat[2].shape == torch.Size((1, 704, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with widen_factor=2.0
|
# Test ShuffleNetV2 forward with widen_factor=2.0
|
||||||
model = ShuffleNetv2(widen_factor=2.0)
|
model = ShuffleNetV2(widen_factor=2.0)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -162,8 +162,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert feat[1].shape == torch.Size((1, 488, 14, 14))
|
assert feat[1].shape == torch.Size((1, 488, 14, 14))
|
||||||
assert feat[2].shape == torch.Size((1, 976, 7, 7))
|
assert feat[2].shape == torch.Size((1, 976, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with layers 3 forward
|
# Test ShuffleNetV2 forward with layers 3 forward
|
||||||
model = ShuffleNetv2(widen_factor=1.0, out_indices=(2, ))
|
model = ShuffleNetV2(widen_factor=1.0, out_indices=(2, ))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -176,8 +176,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert isinstance(feat, torch.Tensor)
|
assert isinstance(feat, torch.Tensor)
|
||||||
assert feat.shape == torch.Size((1, 464, 7, 7))
|
assert feat.shape == torch.Size((1, 464, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with layers 1 2 forward
|
# Test ShuffleNetV2 forward with layers 1 2 forward
|
||||||
model = ShuffleNetv2(widen_factor=1.0, out_indices=(1, 2))
|
model = ShuffleNetV2(widen_factor=1.0, out_indices=(1, 2))
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -191,8 +191,8 @@ def test_shufflenetv2_backbone():
|
||||||
assert feat[0].shape == torch.Size((1, 232, 14, 14))
|
assert feat[0].shape == torch.Size((1, 232, 14, 14))
|
||||||
assert feat[1].shape == torch.Size((1, 464, 7, 7))
|
assert feat[1].shape == torch.Size((1, 464, 7, 7))
|
||||||
|
|
||||||
# Test ShuffleNetv2 forward with checkpoint forward
|
# Test ShuffleNetV2 forward with checkpoint forward
|
||||||
model = ShuffleNetv2(widen_factor=1.0, with_cp=True)
|
model = ShuffleNetV2(widen_factor=1.0, with_cp=True)
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if is_block(m):
|
if is_block(m):
|
||||||
assert m.with_cp
|
assert m.with_cp
|
||||||
|
|
Loading…
Reference in New Issue