merge resnet and resnext; add shufflenetv2

pull/201/head
kaiyangzhou 2019-05-24 15:35:09 +01:00
parent ac89f94129
commit 97ee4b423b
5 changed files with 356 additions and 60 deletions

View File

@ -4,7 +4,6 @@ import torch
from .resnet import *
from .resnetmid import *
from .resnext import *
from .senet import *
from .densenet import *
from .inceptionresnetv2 import *
@ -15,6 +14,7 @@ from .nasnet import *
from .mobilenetv2 import *
from .shufflenet import *
from .squeezenet import *
from .shufflenetv2 import *
from .mudeep import *
from .hacnn import *
@ -29,9 +29,9 @@ __model_factory = {
'resnet50': resnet50,
'resnet101': resnet101,
'resnet152': resnet152,
'resnet50_fc512': resnet50_fc512,
'resnext50_32x4d': resnext50_32x4d,
'resnext50_32x4d_fc512': resnext50_32x4d_fc512,
'resnext101_32x8d': resnext101_32x8d,
'resnet50_fc512': resnet50_fc512,
'se_resnet50': se_resnet50,
'se_resnet50_fc512': se_resnet50_fc512,
'se_resnet101': se_resnet101,
@ -53,6 +53,10 @@ __model_factory = {
'squeezenet1_0': squeezenet1_0,
'squeezenet1_0_fc512': squeezenet1_0_fc512,
'squeezenet1_1': squeezenet1_1,
'shufflenet_v2_x0_5': shufflenet_v2_x0_5,
'shufflenet_v2_x1_0': shufflenet_v2_x1_0,
'shufflenet_v2_x1_5': shufflenet_v2_x1_5,
'shufflenet_v2_x2_0': shufflenet_v2_x2_0,
# reid-specific models
'mudeep': MuDeep,
'resnet50mid': resnet50mid,

View File

@ -1,3 +1,6 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import absolute_import
from __future__ import division

View File

@ -1,7 +1,11 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import absolute_import
from __future__ import division
__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet50_fc512']
__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d',
'resnext101_32x8d', 'resnet50_fc512']
import torch
from torch import nn
@ -16,30 +20,45 @@ model_urls = {
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
identity = x
out = self.conv1(x)
out = self.bn1(out)
@ -49,9 +68,9 @@ class BasicBlock(nn.Module):
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
identity = self.downsample(x)
out += residual
out += identity
out = self.relu(out)
return out
@ -60,21 +79,25 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
identity = x
out = self.conv1(x)
out = self.bn1(out)
@ -88,9 +111,9 @@ class Bottleneck(nn.Module):
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
identity = self.downsample(x)
out += residual
out += identity
out = self.relu(out)
return out
@ -100,7 +123,8 @@ class ResNet(nn.Module):
"""Residual network.
Reference:
He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
- He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
- Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
Public keys:
- ``resnet18``: ResNet18.
@ -108,49 +132,80 @@ class ResNet(nn.Module):
- ``resnet50``: ResNet50.
- ``resnet101``: ResNet101.
- ``resnet152``: ResNet152.
- ``resnext50_32x4d``: ResNeXt50.
- ``resnext101_32x8d``: ResNeXt101.
- ``resnet50_fc512``: ResNet50 + FC.
"""
def __init__(self, num_classes, loss, block, layers,
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs):
self.inplanes = 64
def __init__(self, num_classes, loss, block, layers, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, last_stride=2, fc_dims=None, dropout_p=None, **kwargs):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.loss = loss
self.feature_dim = 512 * block.expansion
# backbone network
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride,
dilate=replace_stride_with_dilation[2])
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, planes, blocks, stride=1):
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
@ -242,17 +297,7 @@ def init_pretrained_weights(model, model_url):
model.load_state_dict(model_dict)
"""
Residual network configurations:
--
resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
"""
"""ResNet"""
def resnet18(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
@ -333,8 +378,45 @@ def resnet152(num_classes, loss='softmax', pretrained=True, **kwargs):
return model
"""ResNeXt"""
def resnext50_32x4d(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 6, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
groups=32,
width_per_group=4,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
return model
def resnext101_32x8d(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
num_classes=num_classes,
loss=loss,
block=Bottleneck,
layers=[3, 4, 23, 3],
last_stride=2,
fc_dims=None,
dropout_p=None,
groups=32,
width_per_group=8,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['resnext101_32x8d'])
return model
"""
resnet + fc
ResNet + FC
"""
def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ResNet(
@ -349,4 +431,4 @@ def resnet50_fc512(num_classes, loss='softmax', pretrained=True, **kwargs):
)
if pretrained:
init_pretrained_weights(model, model_urls['resnet50'])
return model
return model

View File

@ -0,0 +1,204 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import absolute_import
from __future__ import division
__all__ = ['shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torch.utils.model_zoo as model_zoo
model_urls = {
'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
'shufflenetv2_x1.5': None,
'shufflenetv2_x2.0': None,
}
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride):
super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3):
raise ValueError('illegal stride value')
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
self.branch2 = nn.Sequential(
nn.Conv2d(inp if (self.stride > 1) else branch_features,
branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(nn.Module):
"""ShuffleNetV2.
Reference:
Ma et al. ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design. ECCV 2018.
"""
def __init__(self, num_classes, loss, stages_repeats, stages_out_channels, **kwargs):
super(ShuffleNetV2, self).__init__()
self.loss = loss
if len(stages_repeats) != 3:
raise ValueError('expected stages_repeats as list of 3 positive ints')
if len(stages_out_channels) != 5:
raise ValueError('expected stages_out_channels as list of 5 positive ints')
self._stage_out_channels = stages_out_channels
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]):
seq = [InvertedResidual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(InvertedResidual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(output_channels, num_classes)
def featuremaps(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if not self.training:
return v
y = self.classifier(v)
if self.loss == 'softmax':
return y
elif self.loss == 'triplet':
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""Initializes model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
if model_url is None:
import warnings
warnings.warn('ImageNet pretrained weights are unavailable for this model')
return
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
def shufflenet_v2_x0_5(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls['shufflenetv2_x0.5'])
return model
def shufflenet_v2_x1_0(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls['shufflenetv2_x1.0'])
return model
def shufflenet_v2_x1_5(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls['shufflenetv2_x1.5'])
return model
def shufflenet_v2_x2_0(num_classes, loss='softmax', pretrained=True, **kwargs):
model = ShuffleNetV2(num_classes, loss, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
if pretrained:
init_pretrained_weights(model, model_urls['shufflenetv2_x2.0'])
return model

View File

@ -1,3 +1,6 @@
"""
Code source: https://github.com/pytorch/vision
"""
from __future__ import absolute_import
from __future__ import division