add new model
parent 91c33c7719
commit 6972dbafd0
@@ -232,9 +232,38 @@ class Block8(nn.Module):
        out = self.relu(out)
        return out


def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
    r"""InceptionResNetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
    """
    if pretrained:
        settings = pretrained_settings['inceptionresnetv2'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        # both 'imagenet'&'imagenet+background' are loaded from same parameters
        model = InceptionResNetV2(num_classes=1001)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        if pretrained == 'imagenet':
            new_last_linear = nn.Linear(1536, 1000)
            new_last_linear.weight.data = model.last_linear.weight.data[1:]
            new_last_linear.bias.data = model.last_linear.bias.data[1:]
            model.last_linear = new_last_linear

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']

        model.mean = settings['mean']
        model.std = settings['std']
    else:
        model = InceptionResNetV2(num_classes=num_classes)
    return model


##################### Model Definition #########################


class InceptionResNetV2(nn.Module):

    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(InceptionResNetV2, self).__init__()
        self.loss = loss
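The `pretrained='imagenet'` branch above loads the 1001-way 'imagenet+background' checkpoint and then drops the background class (row 0) from the final linear layer. A minimal sketch of that head surgery (not part of this commit), done on a toy `nn.Linear` so it runs without downloading weights; the 8-in/5-out sizes are assumptions, the real head is 1536 -> 1001 -> 1000:

# Editorial sketch of the "drop the background class" trick used above.
import torch
import torch.nn as nn

old_head = nn.Linear(8, 5)          # pretend: 5 classes, class 0 = background
new_head = nn.Linear(8, 4)          # keep only the 4 foreground classes
new_head.weight.data = old_head.weight.data[1:]   # drop row 0 of the weights
new_head.bias.data = old_head.bias.data[1:]       # drop entry 0 of the bias

x = torch.randn(2, 8)
# logits for the foreground classes are identical before and after the surgery
assert torch.allclose(old_head(x)[:, 1:], new_head(x))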
@@ -347,33 +376,4 @@ class InceptionResNetV2(nn.Module):
        elif self.loss == {'ring'}:
            return y, x
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
    r"""InceptionResNetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
    """
    if pretrained:
        settings = pretrained_settings['inceptionresnetv2'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        # both 'imagenet'&'imagenet+background' are loaded from same parameters
        model = InceptionResNetV2(num_classes=1001)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        if pretrained == 'imagenet':
            new_last_linear = nn.Linear(1536, 1000)
            new_last_linear.weight.data = model.last_linear.weight.data[1:]
            new_last_linear.bias.data = model.last_linear.bias.data[1:]
            model.last_linear = new_last_linear

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']

        model.mean = settings['mean']
        model.std = settings['std']
    else:
        model = InceptionResNetV2(num_classes=num_classes)
    return model
File diff suppressed because it is too large
@@ -5,7 +5,7 @@ from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['ResNet50', 'ResNet50M']
__all__ = ['ResNet50', 'ResNet101', 'ResNet50M']


class ResNet50(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):

@@ -35,6 +35,34 @@ class ResNet50(nn.Module):
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class ResNet101(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(ResNet101, self).__init__()
        self.loss = loss
        resnet101 = torchvision.models.resnet101(pretrained=True)
        self.base = nn.Sequential(*list(resnet101.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class ResNet50M(nn.Module):
    """ResNet50 + mid-level features.
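The new `ResNet101` follows the same wrapper pattern as `ResNet50`: keep everything except torchvision's final pooling and fc layers, global-average-pool the last feature map, and return the flattened embedding at test time. A minimal sketch of that pattern (not part of the commit), using a randomly initialised `resnet18` so it runs quickly; the 256x128 input size and the resulting 512-d feature are illustrative:

# Editorial sketch: the backbone-truncation pattern used by ResNet101 above.
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision

backbone = torchvision.models.resnet18(pretrained=False)
base = nn.Sequential(*list(backbone.children())[:-2])   # drop avgpool + fc

x = torch.randn(2, 3, 256, 128)       # a typical re-id input size (assumed)
feat_map = base(x)                    # (2, 512, 8, 4) last conv feature map
f = F.avg_pool2d(feat_map, feat_map.size()[2:]).view(2, -1)
print(f.shape)                        # torch.Size([2, 512]): the embedding used at test time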
@@ -13,7 +13,7 @@ import torchvision
Code imported from https://github.com/Cadene/pretrained-models.pytorch
"""

__all__ = ['SEResNet50']
__all__ = ['SEResNet50', 'SEResNet101', 'SEResNeXt50', 'SEResNeXt101']

pretrained_settings = {
    'senet154': {
@@ -190,7 +190,7 @@ class SEResNeXtBottleneck(Bottleneck):
    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        width = int(math.floor(planes * (base_width / 64.)) * groups)
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
                               stride=1)
        self.bn1 = nn.BatchNorm2d(width)
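The one-line change above hardens the group-width computation: under Python 2, `base_width / 64` uses integer division and truncates to 0, and `math.floor` can return a float that `nn.Conv2d` rejects as a channel count. A quick worked check, not part of the diff, with typical SE-ResNeXt bottleneck settings assumed:

# Editorial sketch: what the width fix above guards against.
import math

planes, base_width, groups = 256, 4, 32   # typical SEResNeXt bottleneck settings

# New form: the float literal keeps the division fractional even under Python 2,
# and int() guarantees an integer channel count for nn.Conv2d.
width = int(math.floor(planes * (base_width / 64.)) * groups)
print(width)   # 256 * 0.0625 = 16, times 32 groups -> 512

# Under Python 2 the old form evaluated base_width / 64 with integer division,
# i.e. 4 / 64 == 0, collapsing the width to 0.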
@@ -442,6 +442,8 @@ def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
    initialize_pretrained_model(model, num_classes, settings)
    return model


##################### Model Definition #########################


class SEResNet50(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(SEResNet50, self).__init__()
@@ -451,6 +453,90 @@ class SEResNet50(nn.Module):
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class SEResNet101(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(SEResNet101, self).__init__()
        self.loss = loss
        base = se_resnet101()
        self.base = nn.Sequential(*list(base.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class SEResNeXt50(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(SEResNeXt50, self).__init__()
        self.loss = loss
        base = se_resnext50_32x4d()
        self.base = nn.Sequential(*list(base.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class SEResNeXt101(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(SEResNeXt101, self).__init__()
        self.loss = loss
        base = se_resnext101_32x4d()
        self.base = nn.Sequential(*list(base.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
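All four SE wrappers above share one contract: strip the backbone's classifier with `nn.Sequential(*list(base.children())[:-2])`, return the pooled 2048-d feature in eval mode, and return logits (plus the feature when the loss set requires it) in training mode. A minimal sketch of that contract (not part of the commit), with a tiny stand-in backbone so nothing has to be downloaded:

# Editorial sketch: the train/eval contract shared by the SE wrappers above.
import torch
import torch.nn as nn
from torch.nn import functional as F

class TinyWrapper(nn.Module):
    def __init__(self, num_classes, loss={'xent'}):
        super(TinyWrapper, self).__init__()
        self.loss = loss
        self.base = nn.Conv2d(3, 16, 3, padding=1)   # stand-in for the SE backbone
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f                   # eval: feature only
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return y                   # train with cross-entropy: logits only
        return y, f                    # e.g. {'xent', 'htri'}: logits + feature

m = TinyWrapper(num_classes=10, loss={'xent', 'htri'})
y, f = m(torch.randn(4, 3, 32, 32))          # modules start in training mode
feats = m.eval()(torch.randn(4, 3, 32, 32))  # (4, 16) features at test time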
@@ -1,6 +1,8 @@
from __future__ import absolute_import

from .ResNet import *
from .ResNeXt import *
from .SEResNet import *
from .DenseNet import *
from .MuDeep import *
from .HACNN import *

@@ -9,26 +11,30 @@ from .MobileNet import *
from .ShuffleNet import *
from .Xception import *
from .InceptionV4 import *
from .SEResNet import *
from .NASNet import *
from .DPN import *
from .InceptionResNetV2 import *

__factory = {
    'resnet50': ResNet50,
    'resnet101': ResNet101,
    'seresnet50': SEResNet50,
    'seresnet101': SEResNet101,
    'seresnext50': SEResNeXt50,
    'seresnext101': SEResNeXt101,
    'resnext101': ResNeXt101_32x4d,
    'resnet50m': ResNet50M,
    'densenet121': DenseNet121,
    'mudeep': MuDeep,
    'hacnn': HACNN,
    'squeezenet': SqueezeNet,
    'mobilenet': MobileNetV2,
    'shufflenet': ShuffleNet,
    'xception': Xception,
    'inceptionv4': InceptionV4ReID,
    'seresnet50': SEResNet50,
    'nasnet': NASNetAMobile,
    'dpn92': DPN,
    'inceptionresnetv2': InceptionResNetV2,
    'mudeep': MuDeep,
    'hacnn': HACNN,
}

def get_names():
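The `__factory` dict above maps CLI-friendly names to the model classes; `get_names()`, whose definition is truncated above, exposes the keys. A hedged sketch of how such a factory is typically consumed inside `models/__init__.py` — the `init_model` helper here is illustrative, not necessarily this repo's exact function:

# Editorial sketch: consuming a name -> class factory like __factory above.
# Only __factory and get_names() are visible in this diff; init_model is assumed.
def init_model(name, num_classes, loss={'xent'}, **kwargs):
    if name not in __factory:
        raise KeyError("Unknown model: {} (choose from {})".format(name, list(__factory)))
    return __factory[name](num_classes=num_classes, loss=loss, **kwargs)

# e.g. build one of the models added in this commit:
# model = init_model('seresnext50', num_classes=751, loss={'xent', 'htri'})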