add new model

pull/17/head
KaiyangZhou 2018-05-10 11:46:59 +01:00
parent 91c33c7719
commit 6972dbafd0
5 changed files with 1647 additions and 38 deletions

View File

@ -232,9 +232,38 @@ class Block8(nn.Module):
out = self.relu(out)
return out
def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
r"""InceptionResNetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
"""
if pretrained:
settings = pretrained_settings['inceptionresnetv2'][pretrained]
assert num_classes == settings['num_classes'], \
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
# both 'imagenet'&'imagenet+background' are loaded from same parameters
model = InceptionResNetV2(num_classes=1001)
model.load_state_dict(model_zoo.load_url(settings['url']))
if pretrained == 'imagenet':
new_last_linear = nn.Linear(1536, 1000)
new_last_linear.weight.data = model.last_linear.weight.data[1:]
new_last_linear.bias.data = model.last_linear.bias.data[1:]
model.last_linear = new_last_linear
model.input_space = settings['input_space']
model.input_size = settings['input_size']
model.input_range = settings['input_range']
model.mean = settings['mean']
model.std = settings['std']
else:
model = InceptionResNetV2(num_classes=num_classes)
return model
##################### Model Definition #########################
class InceptionResNetV2(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(InceptionResNetV2, self).__init__()
self.loss = loss
@ -348,32 +377,3 @@ class InceptionResNetV2(nn.Module):
return y, x
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
r"""InceptionResNetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
"""
if pretrained:
settings = pretrained_settings['inceptionresnetv2'][pretrained]
assert num_classes == settings['num_classes'], \
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
# both 'imagenet'&'imagenet+background' are loaded from same parameters
model = InceptionResNetV2(num_classes=1001)
model.load_state_dict(model_zoo.load_url(settings['url']))
if pretrained == 'imagenet':
new_last_linear = nn.Linear(1536, 1000)
new_last_linear.weight.data = model.last_linear.weight.data[1:]
new_last_linear.bias.data = model.last_linear.bias.data[1:]
model.last_linear = new_last_linear
model.input_space = settings['input_space']
model.input_size = settings['input_size']
model.input_range = settings['input_range']
model.mean = settings['mean']
model.std = settings['std']
else:
model = InceptionResNetV2(num_classes=num_classes)
return model

1489
models/ResNeXt.py 100644

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@ from torch import nn
from torch.nn import functional as F
import torchvision
__all__ = ['ResNet50', 'ResNet50M']
__all__ = ['ResNet50', 'ResNet101', 'ResNet50M']
class ResNet50(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
@ -35,6 +35,34 @@ class ResNet50(nn.Module):
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet101(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(ResNet101, self).__init__()
self.loss = loss
resnet101 = torchvision.models.resnet101(pretrained=True)
self.base = nn.Sequential(*list(resnet101.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
self.feat_dim = 2048 # feature dimension
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
f = x.view(x.size(0), -1)
if not self.training:
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet50M(nn.Module):
"""ResNet50 + mid-level features.

View File

@ -13,7 +13,7 @@ import torchvision
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""
__all__ = ['SEResNet50']
__all__ = ['SEResNet50', 'SEResNet101', 'SEResNeXt50', 'SEResNeXt101']
pretrained_settings = {
'senet154': {
@ -190,7 +190,7 @@ class SEResNeXtBottleneck(Bottleneck):
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None, base_width=4):
super(SEResNeXtBottleneck, self).__init__()
width = math.floor(planes * (base_width / 64)) * groups
width = int(math.floor(planes * (base_width / 64.)) * groups)
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
stride=1)
self.bn1 = nn.BatchNorm2d(width)
@ -442,6 +442,8 @@ def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
initialize_pretrained_model(model, num_classes, settings)
return model
##################### Model Definition #########################
class SEResNet50(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(SEResNet50, self).__init__()
@ -469,3 +471,87 @@ class SEResNet50(nn.Module):
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
class SEResNet101(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(SEResNet101, self).__init__()
self.loss = loss
base = se_resnet101()
self.base = nn.Sequential(*list(base.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
self.feat_dim = 2048 # feature dimension
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
f = x.view(x.size(0), -1)
if not self.training:
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
class SEResNeXt50(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(SEResNeXt50, self).__init__()
self.loss = loss
base = se_resnext50_32x4d()
self.base = nn.Sequential(*list(base.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
self.feat_dim = 2048 # feature dimension
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
f = x.view(x.size(0), -1)
if not self.training:
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
class SEResNeXt101(nn.Module):
def __init__(self, num_classes, loss={'xent'}, **kwargs):
super(SEResNeXt101, self).__init__()
self.loss = loss
base = se_resnext101_32x4d()
self.base = nn.Sequential(*list(base.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
self.feat_dim = 2048 # feature dimension
def forward(self, x):
x = self.base(x)
x = F.avg_pool2d(x, x.size()[2:])
f = x.view(x.size(0), -1)
if not self.training:
return f
y = self.classifier(f)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))

View File

@ -1,6 +1,8 @@
from __future__ import absolute_import
from .ResNet import *
from .ResNeXt import *
from .SEResNet import *
from .DenseNet import *
from .MuDeep import *
from .HACNN import *
@ -9,26 +11,30 @@ from .MobileNet import *
from .ShuffleNet import *
from .Xception import *
from .InceptionV4 import *
from .SEResNet import *
from .NASNet import *
from .DPN import *
from .InceptionResNetV2 import *
__factory = {
'resnet50': ResNet50,
'resnet101': ResNet101,
'seresnet50': SEResNet50,
'seresnet101': SEResNet101,
'seresnext50': SEResNeXt50,
'seresnext101': SEResNeXt101,
'resnext101': ResNeXt101_32x4d,
'resnet50m': ResNet50M,
'densenet121': DenseNet121,
'mudeep': MuDeep,
'hacnn': HACNN,
'squeezenet': SqueezeNet,
'mobilenet': MobileNetV2,
'shufflenet': ShuffleNet,
'xception': Xception,
'inceptionv4': InceptionV4ReID,
'seresnet50': SEResNet50,
'nasnet': NASNetAMobile,
'dpn92': DPN,
'inceptionresnetv2': InceptionResNetV2,
'mudeep': MuDeep,
'hacnn': HACNN,
}
def get_names():