add unit test

pull/2/head
lixiaojie 2020-06-07 22:40:36 +08:00
parent 2ee95c44ce
commit 98f5b49ffe
3 changed files with 120 additions and 155 deletions

View File

@@ -1,5 +1,3 @@
-from .mobilenet_v2 import MobileNetv2
+from .shufflenet_v1 import ShuffleNetv1

-__all__ = [
-    'MobileNetv2',
-]
+__all__ = ['ShuffleNetv1']

View File

@@ -4,8 +4,8 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.runner import load_checkpoint
+from ..runner import load_checkpoint

 from .base_backbone import BaseBackbone
 from .weight_init import constant_init, kaiming_init
@@ -28,12 +28,7 @@ def conv1x1(inplanes, planes, groups=1):
         - Normal pointwise convolution when groups == 1
         - Grouped pointwise convolution when groups > 1
     """
-    return nn.Conv2d(
-        inplanes,
-        planes,
-        kernel_size=1,
-        groups=groups,
-        stride=1)
+    return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)


 def channel_shuffle(x, groups):
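Note: channel_shuffle itself is untouched by this hunk; only conv1x1 above was reflowed. For context, the standard ShuffleNet channel shuffle reshapes to (N, groups, C // groups, H, W), swaps the two channel axes, and flattens back. A minimal sketch of that technique (not necessarily the exact body in this file):

    def channel_shuffle(x, groups):
        n, c, h, w = x.size()
        # split the channel axis into (groups, channels_per_group)
        x = x.view(n, groups, c // groups, h, w)
        # swap the group and per-group axes, then flatten back to (N, C, H, W)
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(n, -1, h, w)

This interleaving lets the following grouped 1x1 convolution mix information across groups.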
@@ -65,8 +60,8 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v


-# noinspection PyShadowingNames,PyShadowingNames
 class ShuffleUnit(nn.Module):

     def __init__(self,
                  inplanes,
                  planes,
@@ -96,20 +91,24 @@ class ShuffleUnit(nn.Module):
                              "Only \"add\" and \"concat\" are "
                              "supported".format(self.combine))

+        if combine == 'add':
+            assert inplanes == planes, \
+                'inplanes must be equal to outplanes when combine is add'
+
         self.first_1x1_groups = self.groups if first_block else 1
         self.g_conv_1x1_compress = self._make_grouped_conv1x1(
             self.inplanes,
             self.bottleneck_channels,
             self.first_1x1_groups,
             batch_norm=True,
-            relu=True
-        )
+            relu=True)

-        self.depthwise_conv3x3 = conv3x3(self.bottleneck_channels,
-                                         self.bottleneck_channels,
-                                         stride=self.depthwise_stride,
-                                         groups=self.bottleneck_channels)
-        self.nn.BatchNorm2d_after_depthwise = \
+        self.depthwise_conv3x3 = conv3x3(
+            self.bottleneck_channels,
+            self.bottleneck_channels,
+            stride=self.depthwise_stride,
+            groups=self.bottleneck_channels)
+        self.bn_after_depthwise = \
             nn.BatchNorm2d(self.bottleneck_channels)

         self.g_conv_1x1_expand = \
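Note: besides renaming the broken attribute (self.nn.BatchNorm2d_after_depthwise would raise AttributeError, since a module has no .nn member), this hunk adds a guard for combine='add'. The two combine modes treat the residual differently: 'add' sums it element-wise, so input and output channels must match, while 'concat' stacks it along the channel axis. A shape sketch using the values from the unit tests added below (in the standard ShuffleNet v1 design the concat branch emits planes - inplanes channels, so the totals work out):

    # combine='add':    ShuffleUnit(24, 24), stride 1
    #   (1, 24, 56, 56) + branch(1, 24, 56, 56)      -> (1, 24, 56, 56)
    # combine='concat': ShuffleUnit(24, 240), depthwise stride 2
    #   cat[(1, 24, 28, 28), branch(1, 216, 28, 28)] -> (1, 240, 28, 28)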
@@ -132,8 +131,11 @@ class ShuffleUnit(nn.Module):
         return torch.cat((x, out), 1)

     @staticmethod
-    def _make_grouped_conv1x1(inplanes, planes, groups,
-                              batch_norm=True, relu=False):
+    def _make_grouped_conv1x1(inplanes,
+                              planes,
+                              groups,
+                              batch_norm=True,
+                              relu=False):

         modules = OrderedDict()
@@ -150,6 +152,7 @@ class ShuffleUnit(nn.Module):
         return conv

     def forward(self, x):
+
         def _inner_forward(x):
             residual = x

@@ -159,7 +162,7 @@ class ShuffleUnit(nn.Module):
             out = self.g_conv_1x1_compress(x)
             out = channel_shuffle(out, self.groups)
             out = self.depthwise_conv3x3(out)
-            out = self.nn.BatchNorm2d_after_depthwise(out)
+            out = self.bn_after_depthwise(out)
             out = self.g_conv_1x1_expand(out)

             out = self._combine_func(residual, out)
@@ -230,10 +233,10 @@ class ShuffleNetv1(BaseBackbone):
         self.conv1 = conv3x3(3, self.inplanes, stride=2)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

-        self.layer2 = self._make_layer(channels[0], blocks[0],
-                                       first_block=False, with_cp=with_cp)
-        self.layer3 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
-        self.layer4 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)
+        self.layer1 = self._make_layer(
+            channels[0], blocks[0], first_block=False, with_cp=with_cp)
+        self.layer2 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
+        self.layer3 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
@@ -248,21 +251,23 @@ class ShuffleNetv1(BaseBackbone):
         else:
             raise TypeError('pretrained must be a str or None')

-    def _make_layer(self,
-                    outplanes,
-                    blocks,
-                    first_block=True,
-                    with_cp=False):
+    def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
         layers = []
         for i in range(blocks):
             if i == 0:
-                layers.append(ShuffleUnit(self.inplanes, outplanes,
-                                          groups=self.groups,
-                                          first_block=first_block,
-                                          combine='concat',
-                                          with_cp=with_cp))
+                layers.append(
+                    ShuffleUnit(
+                        self.inplanes,
+                        outplanes,
+                        groups=self.groups,
+                        first_block=first_block,
+                        combine='concat',
+                        with_cp=with_cp))
             else:
-                layers.append(ShuffleUnit(self.inplanes, outplanes,
-                                          groups=self.groups,
-                                          first_block=True,
-                                          combine='add',
+                layers.append(
+                    ShuffleUnit(
+                        self.inplanes,
+                        outplanes,
+                        groups=self.groups,
+                        first_block=True,
+                        combine='add',
@@ -274,7 +279,9 @@ class ShuffleNetv1(BaseBackbone):
     def forward(self, x):
         x = self.conv1(x)
         x = self.maxpool(x)
+
         outs = []
+        x = self.layer1(x)
         if 0 in self.out_indices:
             outs.append(x)
         x = self.layer2(x)

@@ -283,8 +290,7 @@ class ShuffleNetv1(BaseBackbone):
         x = self.layer3(x)
         if 2 in self.out_indices:
             outs.append(x)
-        x = self.layer4(x)
-        if 3 in self.out_indices:
-            outs.append(x)
+        outs.append(x)

         if len(outs) == 1:
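Note: with the stages renamed to layer1 through layer3 and layer4 removed, forward() now also appends the final feature map unconditionally, which is why the updated tests expect four outputs from the default configuration. A usage sketch consistent with those tests:

    import torch
    from mmcls.models.backbones import ShuffleNetv1

    model = ShuffleNetv1(groups=3)
    model.init_weights()
    model.train()

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    # len(feat) == 4: (1, 240, 28, 28), (1, 480, 14, 14),
    # (1, 960, 7, 7), (1, 960, 7, 7); the last stage appears twice
    # because it is appended both via out_indices and unconditionally.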

View File

@@ -1,16 +1,15 @@
 import pytest
 import torch
-import torch.nn as nn
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm

-from mmcls.models.backbones import MobileNetv2
-from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
+from mmcls.models.backbones import ShuffleNetv1
+from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit


 def is_block(modules):
     """Check if is ResNet building block."""
-    if isinstance(modules, (InvertedResidual, )):
+    if isinstance(modules, (ShuffleUnit, )):
         return True
     return False
@@ -31,62 +30,58 @@ def check_norm_state(modules, train_state):
     return True


-def test_mobilenetv2_invertedresidual():
+def test_shufflenetv1_shuffleuint():

+    with pytest.raises(ValueError):
+        # combine must be in ['add', 'concat']
+        ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
+
+    with pytest.raises(ValueError):
+        # in_channels must be divisible by groups
+        ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')
+
     with pytest.raises(AssertionError):
-        # stride must be in [1, 2]
-        InvertedResidual(64, 16, stride=3, expand_ratio=6)
+        # inplanes must be equal to outplanes when combine='add'
+        ShuffleUnit(64, 24, groups=3, first_block=True, combine='add')

-    # Test InvertedResidual with checkpoint forward, stride=1
-    block = InvertedResidual(64, 16, stride=1, expand_ratio=6)
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with combine='add'
+    block = ShuffleUnit(24, 24, groups=3, first_block=True, combine='add')
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
+    assert x_out.shape == torch.Size([1, 24, 56, 56])

-    # Test InvertedResidual with checkpoint forward, stride=2
-    block = InvertedResidual(64, 16, stride=2, expand_ratio=6)
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with combine='concat'
+    block = ShuffleUnit(24, 240, groups=3, first_block=True, combine='concat')
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 28, 28])
+    assert x_out.shape == torch.Size([1, 240, 28, 28])

-    # Test InvertedResidual with checkpoint forward
-    block = InvertedResidual(64, 16, stride=1, expand_ratio=6, with_cp=True)
-    assert block.with_cp
-    x = torch.randn(1, 64, 56, 56)
+    # Test ShuffleUnit with checkpoint forward
+    block = ShuffleUnit(
+        24, 24, groups=3, first_block=True, combine='add', with_cp=True)
+    x = torch.randn(1, 24, 56, 56)
     x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
-
-    # Test InvertedResidual with activation=nn.ReLU
-    block = InvertedResidual(
-        64, 16, stride=1, expand_ratio=6, activation=nn.ReLU)
-    x = torch.randn(1, 64, 56, 56)
-    x_out = block(x)
-    assert x_out.shape == torch.Size([1, 16, 56, 56])
+    assert x_out.shape == torch.Size([1, 24, 56, 56])


-def test_mobilenetv2_backbone():
-    with pytest.raises(TypeError):
-        # pretrained must be a string path
-        model = MobileNetv2()
-        model.init_weights(pretrained=0)
-
-    with pytest.raises(AssertionError):
-        # frozen_stages must less than 7
-        MobileNetv2(frozen_stages=8)
+def test_shufflenetv1_backbone():
+
+    with pytest.raises(ValueError):
+        # groups must be in [1, 2, 3, 4, 8]
+        ShuffleNetv1(groups=10)

-    # Test MobileNetv2
-    model = MobileNetv2()
+    # Test ShuffleNetv1 norm state
+    model = ShuffleNetv1()
     model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), False)

-    # Test MobileNetv2 with first stage frozen
+    # Test ShuffleNetv1 with first stage frozen
     frozen_stages = 1
-    model = MobileNetv2(frozen_stages=frozen_stages)
+    model = ShuffleNetv1(frozen_stages=frozen_stages)
     model.init_weights()
     model.train()
-    assert model.bn1.training is False
-    for layer in [model.conv1, model.bn1]:
+    for layer in [model.conv1]:
         for param in layer.parameters():
             assert param.requires_grad is False
     for i in range(1, frozen_stages + 1):
@@ -97,13 +92,12 @@ def test_mobilenetv2_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False

-    # Test MobileNetv2 with bn frozen
-    model = MobileNetv2(bn_frozen=True)
+    # Test ShuffleNetv1 with bn frozen
+    model = ShuffleNetv1(bn_frozen=True)
     model.init_weights()
     model.train()
-    assert model.bn1.training is False

-    for i in range(1, 8):
+    for i in range(1, 4):
         layer = getattr(model, f'layer{i}')
         for mod in layer.modules():
@@ -112,85 +106,52 @@ def test_mobilenetv2_backbone():
             for params in mod.parameters():
                 params.requires_grad = False

-    # Test MobileNetv2 forward with widen_factor=1.0
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
+    # Test ShuffleNetv1 forward with groups=3
+    model = ShuffleNetv1(groups=3)
     model.init_weights()
     model.train()

-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 forward with activation=nn.ReLU
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU)
-    model.init_weights()
-    model.train()
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 with BatchNorm forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
     for m in model.modules():
         if is_norm(m):
             assert isinstance(m, _BatchNorm)
-    model.init_weights()
-    model.train()
-
-    imgs = torch.randn(1, 3, 224, 224)
-    feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
-
-    # Test MobileNetv2 with layers 1, 3, 5 out forward
-    model = MobileNetv2(
-        widen_factor=1.0, activation=nn.ReLU6, out_indices=(0, 2, 4))
-    model.init_weights()
-    model.train()

     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
     assert len(feat) == 4
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 32, 28, 28])
-    assert feat[2].shape == torch.Size([1, 96, 14, 14])
+    assert feat[0].shape == torch.Size([1, 240, 28, 28])
+    assert feat[1].shape == torch.Size([1, 480, 14, 14])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+    assert feat[3].shape == torch.Size([1, 960, 7, 7])

-    # Test MobileNetv2 with checkpoint forward
-    model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6, with_cp=True)
-    for m in model.modules():
-        if is_block(m):
-            assert m.with_cp
+    # Test ShuffleNetv1 forward with layers 1, 2 forward
+    model = ShuffleNetv1(groups=3, out_indices=(1, 2))
     model.init_weights()
     model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)

     imgs = torch.randn(1, 3, 224, 224)
     feat = model(imgs)
-    assert len(feat) == 8
-    assert feat[0].shape == torch.Size([1, 16, 112, 112])
-    assert feat[1].shape == torch.Size([1, 24, 56, 56])
-    assert feat[2].shape == torch.Size([1, 32, 28, 28])
-    assert feat[3].shape == torch.Size([1, 64, 14, 14])
-    assert feat[4].shape == torch.Size([1, 96, 14, 14])
-    assert feat[5].shape == torch.Size([1, 160, 7, 7])
-    assert feat[6].shape == torch.Size([1, 320, 7, 7])
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size([1, 480, 14, 14])
+    assert feat[1].shape == torch.Size([1, 960, 7, 7])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+
+    # Test ShuffleNetv1 forward with checkpoint forward
+    model = ShuffleNetv1(groups=3, with_cp=True)
+    model.init_weights()
+    model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 240, 28, 28])
+    assert feat[1].shape == torch.Size([1, 480, 14, 14])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+    assert feat[3].shape == torch.Size([1, 960, 7, 7])
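Note: assuming the test file lives at tests/test_backbones.py (its path is not shown in this view), both new tests can be run in isolation with pytest's keyword filter, for example: pytest tests/test_backbones.py -k shufflenetv1 -v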