Dev se resnet

This commit is contained in:
yangmingmin 2020-06-17 14:20:20 +08:00 committed by chenkai
parent c168aa786e
commit f729a60f87
4 changed files with 406 additions and 1 deletions

View File

@ -1,10 +1,11 @@
from .mobilenet_v2 import MobileNetV2
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .seresnet import SEResNet
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
__all__ = [
'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'ShuffleNetV1',
'ResNet', 'ResNeXt', 'ResNetV1d', 'ResNetV1d', 'SEResNet', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2'
]

View File

@ -0,0 +1,115 @@
import torch.utils.checkpoint as cp
from ..builder import BACKBONES
from ..utils.se_layer import SELayer
from .resnet import Bottleneck, ResLayer, ResNet
class SEBottleneck(Bottleneck):
"""SEBottleneck block for SEResNet.
Args:
inplanes (int): The input channels of the SEBottleneck block.
planes (int): The output channel base of the SEBottleneck block.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
"""
expansion = 4
def __init__(self, inplanes, planes, se_ratio=16, **kwargs):
super(SEBottleneck, self).__init__(inplanes, planes, **kwargs)
self.se_layer = SELayer(planes * self.expansion, ratio=se_ratio)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.norm3(out)
out = self.se_layer(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@BACKBONES.register_module()
class SEResNet(ResNet):
"""SEResNet backbone.
Args:
depth (int): Depth of seresnet, from {50, 101, 152}.
in_channels (int): Number of input image channels. Normally 3.
base_channels (int): Number of base channels of hidden layer.
num_stages (int): Resnet stages, normally 4.
strides (Sequence[int]): Strides of the first block of each stage.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int]): Output from which stages.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters.
norm_cfg (dict): Dictionary to construct and config norm layer.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity.
Example:
>>> from mmcls.models import SEResNet
>>> import torch
>>> self = SEResNet(depth=50)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 224, 224)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 56, 56)
(1, 128, 28, 28)
(1, 256, 14, 14)
(1, 512, 7, 7)
"""
arch_settings = {
50: (SEBottleneck, (3, 4, 6, 3)),
101: (SEBottleneck, (3, 4, 23, 3)),
152: (SEBottleneck, (3, 8, 36, 3))
}
def __init__(self, depth, se_ratio=16, **kwargs):
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.se_ratio = se_ratio
super(SEResNet, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(se_ratio=self.se_ratio, **kwargs)

View File

@ -0,0 +1,30 @@
import torch.nn as nn
class SELayer(nn.Module):
"""Squeeze-and-Excitation Module.
Args:
inplanes (int): The input channels of the SEBottleneck block.
ratio (int): Squeeze ratio in SELayer. Default: 16
"""
def __init__(self, inplanes, ratio=16):
super(SELayer, self).__init__()
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
inplanes, int(inplanes / ratio), kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(
int(inplanes / ratio), inplanes, kernel_size=1, stride=1)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.global_avgpool(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.sigmoid(out)
return x * out

View File

@ -0,0 +1,259 @@
import pytest
import torch
from torch.nn.modules import AvgPool2d
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import SEResNet
from mmcls.models.backbones.resnet import ResLayer
from mmcls.models.backbones.seresnet import SEBottleneck, SELayer
def is_block(modules):
"""Check if is ResNet building block."""
if isinstance(modules, (SEBottleneck, )):
return True
return False
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (_BatchNorm, )):
return True
return False
def all_zeros(modules):
"""Check if the weight(and bias) is all zero."""
weight_zero = torch.equal(modules.weight.data,
torch.zeros_like(modules.weight.data))
if hasattr(modules, 'bias'):
bias_zero = torch.equal(modules.bias.data,
torch.zeros_like(modules.bias.data))
else:
bias_zero = True
return weight_zero and bias_zero
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_serenet_selayer():
# Test selayer forward
layer = SELayer(64)
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test selayer forward with different ratio
layer = SELayer(64, ratio=8)
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_seresnet_bottleneckse():
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
SEBottleneck(64, 64, style='tensorflow')
# Test SEBottleneck with checkpoint forward
block = SEBottleneck(64, 16, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test Bottleneck style
block = SEBottleneck(64, 64, stride=2, style='pytorch')
assert block.conv1.stride == (1, 1)
assert block.conv2.stride == (2, 2)
block = SEBottleneck(64, 64, stride=2, style='caffe')
assert block.conv1.stride == (2, 2)
assert block.conv2.stride == (1, 1)
# Test Bottleneck forward
block = SEBottleneck(64, 16)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_seresnet_res_layer():
# Test ResLayer of 3 Bottleneck w\o downsample
layer = ResLayer(SEBottleneck, 64, 16, 3, se_ratio=16)
assert len(layer) == 3
assert layer[0].conv1.in_channels == 64
assert layer[0].conv1.out_channels == 16
for i in range(1, len(layer)):
assert layer[i].conv1.in_channels == 64
assert layer[i].conv1.out_channels == 16
for i in range(len(layer)):
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# Test ResLayer of 3 SEBottleneck with downsample
layer = ResLayer(SEBottleneck, 64, 64, 3, se_ratio=16)
assert layer[0].downsample[0].out_channels == 256
for i in range(1, len(layer)):
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 56, 56])
# Test ResLayer of 3 SEBottleneck with stride=2
layer = ResLayer(SEBottleneck, 64, 64, 3, stride=2, se_ratio=8)
assert layer[0].downsample[0].out_channels == 256
assert layer[0].downsample[0].stride == (2, 2)
for i in range(1, len(layer)):
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 28, 28])
# Test ResLayer of 3 SEBottleneck with stride=2 and average downsample
layer = ResLayer(
SEBottleneck, 64, 64, 3, stride=2, avg_down=True, se_ratio=8)
assert isinstance(layer[0].downsample[0], AvgPool2d)
assert layer[0].downsample[1].out_channels == 256
assert layer[0].downsample[1].stride == (1, 1)
for i in range(1, len(layer)):
assert layer[i].downsample is None
x = torch.randn(1, 64, 56, 56)
x_out = layer(x)
assert x_out.shape == torch.Size([1, 256, 28, 28])
def test_seresnet_backbone():
"""Test resnet backbone"""
with pytest.raises(KeyError):
# SEResNet depth should be in [50, 101, 152]
SEResNet(20)
with pytest.raises(AssertionError):
# In SEResNet: 1 <= num_stages <= 4
SEResNet(50, num_stages=0)
with pytest.raises(AssertionError):
# In SEResNet: 1 <= num_stages <= 4
SEResNet(50, num_stages=5)
with pytest.raises(AssertionError):
# len(strides) == len(dilations) == num_stages
SEResNet(50, strides=(1, ), dilations=(1, 1), num_stages=3)
with pytest.raises(TypeError):
# pretrained must be a string path
model = SEResNet(50)
model.init_weights(pretrained=0)
with pytest.raises(AssertionError):
# Style must be in ['pytorch', 'caffe']
SEResNet(50, style='tensorflow')
# Test SEResNet50 norm_eval=True
model = SEResNet(50, norm_eval=True)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
# Test SEResNet50 with torchvision pretrained weight
model = SEResNet(depth=50, norm_eval=True)
model.init_weights('torchvision://resnet50')
model.train()
assert check_norm_state(model.modules(), False)
# Test SEResNet50 with first stage frozen
frozen_stages = 1
model = SEResNet(50, frozen_stages=frozen_stages)
model.init_weights()
model.train()
assert model.norm1.training is False
for layer in [model.conv1, model.norm1]:
for param in layer.parameters():
assert param.requires_grad is False
for i in range(1, frozen_stages + 1):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test SEResNet50 with BatchNorm forward
model = SEResNet(50, out_indices=(0, 1, 2, 3))
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test SEResNet50 with layers 1, 2, 3 out forward
model = SEResNet(50, out_indices=(0, 1, 2))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
# Test SEResNet50 with layers 3 (top feature maps) out forward
model = SEResNet(50, out_indices=(3, ))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat.shape == torch.Size([1, 2048, 7, 7])
# Test SEResNet50 with checkpoint forward
model = SEResNet(50, out_indices=(0, 1, 2, 3), with_cp=True)
for m in model.modules():
if is_block(m):
assert m.with_cp
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])
# Test SEResNet50 zero initialization of residual
model = SEResNet(50, out_indices=(0, 1, 2, 3), zero_init_residual=True)
model.init_weights()
for m in model.modules():
if isinstance(m, SEBottleneck):
assert all_zeros(m.norm3)
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 256, 56, 56])
assert feat[1].shape == torch.Size([1, 512, 28, 28])
assert feat[2].shape == torch.Size([1, 1024, 14, 14])
assert feat[3].shape == torch.Size([1, 2048, 7, 7])