add resnext

pull/119/head
KaiyangZhou 2018-10-27 16:52:05 +01:00
parent 2c2a9c0098
commit d187902ed4
2 changed files with 240 additions and 1 deletions

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import
from .resnet import *
from .resnetmid import *
from .resnext import *
from .senet import *
from .densenet import *
from .mudeep import *
@ -27,6 +28,8 @@ __model_factory = {
'resnet50': resnet50,
'resnet50_fc512': resnet50_fc512,
'resnet50mid': resnet50mid,
'resnext50_32x4d': resnext50_32x4d,
'resnext101_32x4d': resnext101_32x4d,
'se_resnet50': se_resnet50,
'se_resnet50_fc512': se_resnet50_fc512,
'se_resnet101': se_resnet101,
@ -34,7 +37,6 @@ __model_factory = {
'se_resnext101_32x4d': se_resnext101_32x4d,
'densenet121': densenet121,
'densenet121_fc512': densenet121_fc512,
#'resnext101': ResNeXt101_32x4d,
#'squeezenet': SqueezeNet, # https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py
'mobilenetv2': MobileNetV2,
'shufflenet': ShuffleNet,

View File

@ -0,0 +1,237 @@
from __future__ import absolute_import
from __future__ import division
import math
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torch.utils.model_zoo as model_zoo
__all__ = ['resnext50_32x4d', 'resnext101_32x4d']
model_urls = {
'resnext50_32x4d': None,
'resnext101_32x4d': None,
}
class ResNeXtBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, groups=32, base_width=4, stride=1, downsample=None):
super(ResNeXtBottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64.)) * groups)
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNeXt(nn.Module):
"""
ResNeXt
Reference:
Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
"""
def __init__(self, num_classes, loss, block, layers,
groups=32,
base_width=4,
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs):
self.inplanes = 64
super(ResNeXt, self).__init__()
self.loss = loss
self.feature_dim = 512 * block.expansion
# backbone network
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], groups, base_width)
self.layer2 = self._make_layer(block, 128, layers[1], groups, base_width, stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], groups, base_width, stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], groups, base_width, stride=last_stride)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, planes, blocks, groups, base_width, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, groups, base_width, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, groups, base_width))
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""
Construct fully connected layer
- fc_dims (list or tuple): dimensions of fc layers, if None,
no fc layers are constructed
- input_dim (int): input dimension
- dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(fc_dims, (list, tuple)), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims))
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def forward(self, x):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return v
y = self.classifier(v)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, v
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""
Initialize model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
"""pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Initialized model with pretrained weights from {}".format(model_url))"""
print("Imagenet weights unavailable")
def resnext50_32x4d(num_classes, loss, pretrained='imagenet', **kwargs):
model = ResNeXt(
num_classes=num_classes,
loss=loss,
block=ResNeXtBottleneck,
layers=[3, 4, 6, 3],
groups=32,
base_width=4,
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained == 'imagenet':
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
return model
def resnext101_32x4d(num_classes, loss, pretrained='imagenet', **kwargs):
model = ResNeXt(
num_classes=num_classes,
loss=loss,
block=ResNeXtBottleneck,
layers=[3, 4, 23, 3],
groups=32,
base_width=4,
last_stride=2,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained == 'imagenet':
init_pretrained_weights(model, model_urls['resnext50_32x4d'])
return model