Merge branch 'dev_mobilenetv2' into dev_shufflenetv1

pull/2/head
lixiaojie 2020-06-07 21:10:42 +08:00
commit 2ee95c44ce
2 changed files with 224 additions and 46 deletions

View File

@ -2,8 +2,8 @@ import logging
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.runner import load_checkpoint
from ..runner import load_checkpoint
from .base_backbone import BaseBackbone
from .weight_init import constant_init, kaiming_init
@ -22,13 +22,12 @@ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
def conv_1x1_bn(inp, oup, activation=nn.ReLU6):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
activation(inplace=True)
)
nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup),
activation(inplace=True))
class ConvBNReLU(nn.Sequential):
def __init__(self,
in_planes,
out_planes,
@ -39,16 +38,15 @@ class ConvBNReLU(nn.Sequential):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False),
nn.BatchNorm2d(out_planes),
activation(inplace=True)
)
nn.Conv2d(
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups=groups,
bias=False), nn.BatchNorm2d(out_planes),
activation(inplace=True))
def _make_divisible(v, divisor, min_value=None):
@ -62,6 +60,7 @@ def _make_divisible(v, divisor, min_value=None):
class InvertedResidual(nn.Module):
def __init__(self,
inplanes,
outplanes,
@ -79,17 +78,18 @@ class InvertedResidual(nn.Module):
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inplanes,
hidden_dim,
kernel_size=1,
activation=activation))
layers.append(
ConvBNReLU(
inplanes, hidden_dim, kernel_size=1,
activation=activation))
layers.extend([
# dw
ConvBNReLU(hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
activation=activation),
ConvBNReLU(
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
activation=activation),
# pw-linear
nn.Conv2d(hidden_dim, outplanes, 1, 1, 0, bias=False),
nn.BatchNorm2d(outplanes),
@ -97,6 +97,7 @@ class InvertedResidual(nn.Module):
self.conv = nn.Sequential(*layers)
def forward(self, x):
def _inner_forward(x):
if self.use_res_connect:
return x + self.conv(x)
@ -122,15 +123,23 @@ def make_inverted_res_layer(block,
layers = []
for i in range(num_blocks):
if i == 0:
layers.append(block(inplanes, planes, stride,
expand_ratio=expand_ratio,
activation=activation,
with_cp=with_cp))
layers.append(
block(
inplanes,
planes,
stride,
expand_ratio=expand_ratio,
activation=activation,
with_cp=with_cp))
else:
layers.append(block(inplanes, planes, 1,
expand_ratio=expand_ratio,
activation=activation,
with_cp=with_cp))
layers.append(
block(
inplanes,
planes,
1,
expand_ratio=expand_ratio,
activation=activation,
with_cp=with_cp))
inplanes = planes
return nn.Sequential(*layers)
@ -154,7 +163,7 @@ class MobileNetv2(BaseBackbone):
def __init__(self,
widen_factor=1.,
activation=nn.ReLU6,
out_indices=(0, 1, 2, 3, 4, 5, 6),
out_indices=(0, 1, 2, 3, 4, 5, 6, 7),
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
@ -162,21 +171,17 @@ class MobileNetv2(BaseBackbone):
super(MobileNetv2, self).__init__()
block = InvertedResidual
# expand_ratio, out_channel, n, stride
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1]
]
inverted_residual_setting = [[1, 16, 1, 1], [6, 24, 2,
2], [6, 32, 3, 2],
[6, 64, 4, 2], [6, 96, 3, 1],
[6, 160, 3, 2], [6, 320, 1, 1]]
self.widen_factor = widen_factor
if isinstance(activation, str):
activation = eval(activation)
self.activation = activation(inplace=True)
self.out_indices = out_indices
assert frozen_stages <= 7
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
@ -210,9 +215,8 @@ class MobileNetv2(BaseBackbone):
self.out_channel = int(self.out_channel * widen_factor) \
if widen_factor > 1.0 else self.out_channel
self.conv_last = nn.Conv2d(self.inplanes,
self.out_channel,
1, 1, 0, bias=False)
self.conv_last = nn.Conv2d(
self.inplanes, self.out_channel, 1, 1, 0, bias=False)
self.bn_last = nn.BatchNorm2d(self.out_channel)
self.feat_dim = self.out_channel

View File

@ -1,11 +1,118 @@
import pytest
import torch
import torch.nn as nn
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import MobileNetv2
from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
def is_block(modules):
"""Check if is ResNet building block."""
if isinstance(modules, (InvertedResidual, )):
return True
return False
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_mobilenetv2_invertedresidual():
with pytest.raises(AssertionError):
# stride must be in [1, 2]
InvertedResidual(64, 16, stride=3, expand_ratio=6)
# Test InvertedResidual with checkpoint forward, stride=1
block = InvertedResidual(64, 16, stride=1, expand_ratio=6)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 56, 56])
# Test InvertedResidual with checkpoint forward, stride=2
block = InvertedResidual(64, 16, stride=2, expand_ratio=6)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 28, 28])
# Test InvertedResidual with checkpoint forward
block = InvertedResidual(64, 16, stride=1, expand_ratio=6, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 56, 56])
# Test InvertedResidual with activation=nn.ReLU
block = InvertedResidual(
64, 16, stride=1, expand_ratio=6, activation=nn.ReLU)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 56, 56])
def test_mobilenetv2_backbone():
# Test MobileNetv2 with widen_factor 1.0, activation nn.ReLU6
with pytest.raises(TypeError):
# pretrained must be a string path
model = MobileNetv2()
model.init_weights(pretrained=0)
with pytest.raises(AssertionError):
# frozen_stages must less than 7
MobileNetv2(frozen_stages=8)
# Test MobileNetv2
model = MobileNetv2()
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
# Test MobileNetv2 with first stage frozen
frozen_stages = 1
model = MobileNetv2(frozen_stages=frozen_stages)
model.init_weights()
model.train()
assert model.bn1.training is False
for layer in [model.conv1, model.bn1]:
for param in layer.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test MobileNetv2 with bn frozen
model = MobileNetv2(bn_frozen=True)
model.init_weights()
model.train()
assert model.bn1.training is False
for i in range(1, 8):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for params in mod.parameters():
params.requires_grad = False
# Test MobileNetv2 forward with widen_factor=1.0
model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
model.init_weights()
model.train()
@ -20,3 +127,70 @@ def test_mobilenetv2_backbone():
assert feat[4].shape == torch.Size([1, 96, 14, 14])
assert feat[5].shape == torch.Size([1, 160, 7, 7])
assert feat[6].shape == torch.Size([1, 320, 7, 7])
# Test MobileNetv2 forward with activation=nn.ReLU
model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 8
assert feat[0].shape == torch.Size([1, 16, 112, 112])
assert feat[1].shape == torch.Size([1, 24, 56, 56])
assert feat[2].shape == torch.Size([1, 32, 28, 28])
assert feat[3].shape == torch.Size([1, 64, 14, 14])
assert feat[4].shape == torch.Size([1, 96, 14, 14])
assert feat[5].shape == torch.Size([1, 160, 7, 7])
assert feat[6].shape == torch.Size([1, 320, 7, 7])
# Test MobileNetv2 with BatchNorm forward
model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6)
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 8
assert feat[0].shape == torch.Size([1, 16, 112, 112])
assert feat[1].shape == torch.Size([1, 24, 56, 56])
assert feat[2].shape == torch.Size([1, 32, 28, 28])
assert feat[3].shape == torch.Size([1, 64, 14, 14])
assert feat[4].shape == torch.Size([1, 96, 14, 14])
assert feat[5].shape == torch.Size([1, 160, 7, 7])
assert feat[6].shape == torch.Size([1, 320, 7, 7])
# Test MobileNetv2 with layers 1, 3, 5 out forward
model = MobileNetv2(
widen_factor=1.0, activation=nn.ReLU6, out_indices=(0, 2, 4))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 16, 112, 112])
assert feat[1].shape == torch.Size([1, 32, 28, 28])
assert feat[2].shape == torch.Size([1, 96, 14, 14])
# Test MobileNetv2 with checkpoint forward
model = MobileNetv2(widen_factor=1.0, activation=nn.ReLU6, with_cp=True)
for m in model.modules():
if is_block(m):
assert m.with_cp
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 8
assert feat[0].shape == torch.Size([1, 16, 112, 112])
assert feat[1].shape == torch.Size([1, 24, 56, 56])
assert feat[2].shape == torch.Size([1, 32, 28, 28])
assert feat[3].shape == torch.Size([1, 64, 14, 14])
assert feat[4].shape == torch.Size([1, 96, 14, 14])
assert feat[5].shape == torch.Size([1, 160, 7, 7])
assert feat[6].shape == torch.Size([1, 320, 7, 7])