Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add common model interface to pnasnet and xception, update factory
This commit is contained in:
parent f2029dfb65
commit c0e6e5f3db
@@ -20,7 +20,7 @@ def get_model_meanstd(model_name):
|
|||||||
model_name = model_name.lower()
|
model_name = model_name.lower()
|
||||||
if 'dpn' in model_name:
|
if 'dpn' in model_name:
|
||||||
return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
|
return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
|
||||||
elif 'ception' in model_name:
|
elif 'ception' in model_name or 'nasnet' in model_name:
|
||||||
return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
else:
|
else:
|
||||||
return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
@@ -30,7 +30,7 @@ def get_model_mean(model_name):
|
|||||||
model_name = model_name.lower()
|
model_name = model_name.lower()
|
||||||
if 'dpn' in model_name:
|
if 'dpn' in model_name:
|
||||||
return IMAGENET_DPN_STD
|
return IMAGENET_DPN_STD
|
||||||
elif 'ception' in model_name:
|
elif 'ception' in model_name or 'nasnet' in model_name:
|
||||||
return IMAGENET_INCEPTION_MEAN
|
return IMAGENET_INCEPTION_MEAN
|
||||||
else:
|
else:
|
||||||
return IMAGENET_DEFAULT_MEAN
|
return IMAGENET_DEFAULT_MEAN
|
||||||
@@ -40,7 +40,7 @@ def get_model_std(model_name):
|
|||||||
model_name = model_name.lower()
|
model_name = model_name.lower()
|
||||||
if 'dpn' in model_name:
|
if 'dpn' in model_name:
|
||||||
return IMAGENET_DEFAULT_STD
|
return IMAGENET_DEFAULT_STD
|
||||||
elif 'ception' in model_name:
|
elif 'ception' in model_name or 'nasnet' in model_name:
|
||||||
return IMAGENET_INCEPTION_STD
|
return IMAGENET_INCEPTION_STD
|
||||||
else:
|
else:
|
||||||
return IMAGENET_DEFAULT_STD
|
return IMAGENET_DEFAULT_STD
|
||||||
|
@@ -11,6 +11,7 @@ from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152,
|
|||||||
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
|
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
|
||||||
from .resnext import resnext50, resnext101, resnext152
|
from .resnext import resnext50, resnext101, resnext152
|
||||||
from .xception import xception
|
from .xception import xception
|
||||||
|
from .pnasnet import pnasnet5large
|
||||||
|
|
||||||
model_config_dict = {
|
model_config_dict = {
|
||||||
'resnet18': {
|
'resnet18': {
|
||||||
@@ -47,6 +48,8 @@ model_config_dict = {
|
|||||||
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
||||||
'xception': {
|
'xception': {
|
||||||
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
||||||
|
'pnasnet5large': {
|
||||||
|
'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -125,6 +128,8 @@ def create_model(
|
|||||||
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
||||||
elif model_name == 'xception':
|
elif model_name == 'xception':
|
||||||
model = xception(num_classes=num_classes, pretrained=pretrained)
|
model = xception(num_classes=num_classes, pretrained=pretrained)
|
||||||
|
elif model_name == 'pnasnet5large':
|
||||||
|
model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
|
||||||
else:
|
else:
|
||||||
assert False and "Invalid model"
|
assert False and "Invalid model"
|
||||||
|
|
||||||
|
@@ -5,7 +5,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.model_zoo as model_zoo
|
import torch.utils.model_zoo as model_zoo
|
||||||
|
|
||||||
|
|
||||||
pretrained_settings = {
|
pretrained_settings = {
|
||||||
'pnasnet5large': {
|
'pnasnet5large': {
|
||||||
'imagenet': {
|
'imagenet': {
|
||||||
@@ -292,6 +291,8 @@ class PNASNet5Large(nn.Module):
|
|||||||
def __init__(self, num_classes=1001):
|
def __init__(self, num_classes=1001):
|
||||||
super(PNASNet5Large, self).__init__()
|
super(PNASNet5Large, self).__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
self.num_features = 4320
|
||||||
|
|
||||||
self.conv_0 = nn.Sequential(OrderedDict([
|
self.conv_0 = nn.Sequential(OrderedDict([
|
||||||
('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
|
('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
|
||||||
('bn', nn.BatchNorm2d(96, eps=0.001))
|
('bn', nn.BatchNorm2d(96, eps=0.001))
|
||||||
@@ -335,9 +336,20 @@ class PNASNet5Large(nn.Module):
|
|||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
|
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
|
||||||
self.dropout = nn.Dropout(0.5)
|
self.dropout = nn.Dropout(0.5)
|
||||||
self.last_linear = nn.Linear(4320, num_classes)
|
self.last_linear = nn.Linear(self.num_features, num_classes)
|
||||||
|
|
||||||
def features(self, x):
|
def get_classifier(self):
|
||||||
|
return self.last_linear
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
del self.last_linear
|
||||||
|
if num_classes:
|
||||||
|
self.last_linear = nn.Linear(self.num_features, num_classes)
|
||||||
|
else:
|
||||||
|
self.last_linear = None
|
||||||
|
|
||||||
|
def forward_features(self, x, pool=True):
|
||||||
x_conv_0 = self.conv_0(x)
|
x_conv_0 = self.conv_0(x)
|
||||||
x_stem_0 = self.cell_stem_0(x_conv_0)
|
x_stem_0 = self.cell_stem_0(x_conv_0)
|
||||||
x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
|
x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
|
||||||
@@ -353,19 +365,16 @@ class PNASNet5Large(nn.Module):
|
|||||||
x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
|
x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
|
||||||
x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
|
x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
|
||||||
x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
|
x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
|
||||||
return x_cell_11
|
x = self.relu(x_cell_11)
|
||||||
|
if pool:
|
||||||
def logits(self, features):
|
|
||||||
x = self.relu(features)
|
|
||||||
x = self.avg_pool(x)
|
x = self.avg_pool(x)
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
x = self.dropout(x)
|
|
||||||
x = self.last_linear(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
x = self.features(input)
|
x = self.forward_features(input)
|
||||||
x = self.logits(x)
|
x = self.dropout(x)
|
||||||
|
x = self.last_linear(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -375,7 +384,7 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
|
|||||||
<https://arxiv.org/abs/1712.00559>`_ paper.
|
<https://arxiv.org/abs/1712.00559>`_ paper.
|
||||||
"""
|
"""
|
||||||
if pretrained:
|
if pretrained:
|
||||||
settings = pretrained_settings['pnasnet5large'][pretrained]
|
settings = pretrained_settings['pnasnet5large']['imagenet']
|
||||||
assert num_classes == settings[
|
assert num_classes == settings[
|
||||||
'num_classes'], 'num_classes should be {}, but is {}'.format(
|
'num_classes'], 'num_classes should be {}, but is {}'.format(
|
||||||
settings['num_classes'], num_classes)
|
settings['num_classes'], num_classes)
|
||||||
@@ -384,18 +393,12 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
|
|||||||
model = PNASNet5Large(num_classes=1001)
|
model = PNASNet5Large(num_classes=1001)
|
||||||
model.load_state_dict(model_zoo.load_url(settings['url']))
|
model.load_state_dict(model_zoo.load_url(settings['url']))
|
||||||
|
|
||||||
if pretrained == 'imagenet':
|
#if pretrained == 'imagenet':
|
||||||
new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
|
new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
|
||||||
new_last_linear.weight.data = model.last_linear.weight.data[1:]
|
new_last_linear.weight.data = model.last_linear.weight.data[1:]
|
||||||
new_last_linear.bias.data = model.last_linear.bias.data[1:]
|
new_last_linear.bias.data = model.last_linear.bias.data[1:]
|
||||||
model.last_linear = new_last_linear
|
model.last_linear = new_last_linear
|
||||||
|
|
||||||
model.input_space = settings['input_space']
|
|
||||||
model.input_size = settings['input_size']
|
|
||||||
model.input_range = settings['input_range']
|
|
||||||
|
|
||||||
model.mean = settings['mean']
|
|
||||||
model.std = settings['std']
|
|
||||||
else:
|
else:
|
||||||
model = PNASNet5Large(num_classes=num_classes)
|
model = PNASNet5Large(num_classes=num_classes)
|
||||||
return model
|
return model
|
||||||
|
@@ -142,7 +142,6 @@ def resnext50(cardinality=32, base_width=4, pretrained=False, **kwargs):
|
|||||||
Args:
|
Args:
|
||||||
cardinality (int): Cardinality of the aggregated transform
|
cardinality (int): Cardinality of the aggregated transform
|
||||||
base_width (int): Base width of the grouped convolution
|
base_width (int): Base width of the grouped convolution
|
||||||
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
|
|
||||||
"""
|
"""
|
||||||
model = ResNeXt(
|
model = ResNeXt(
|
||||||
ResNeXtBottleneckC, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
ResNeXtBottleneckC, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
||||||
@@ -155,7 +154,6 @@ def resnext101(cardinality=32, base_width=4, pretrained=False, **kwargs):
|
|||||||
Args:
|
Args:
|
||||||
cardinality (int): Cardinality of the aggregated transform
|
cardinality (int): Cardinality of the aggregated transform
|
||||||
base_width (int): Base width of the grouped convolution
|
base_width (int): Base width of the grouped convolution
|
||||||
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
|
|
||||||
"""
|
"""
|
||||||
model = ResNeXt(
|
model = ResNeXt(
|
||||||
ResNeXtBottleneckC, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
ResNeXtBottleneckC, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
||||||
@@ -168,7 +166,6 @@ def resnext152(cardinality=32, base_width=4, pretrained=False, **kwargs):
|
|||||||
Args:
|
Args:
|
||||||
cardinality (int): Cardinality of the aggregated transform
|
cardinality (int): Cardinality of the aggregated transform
|
||||||
base_width (int): Base width of the grouped convolution
|
base_width (int): Base width of the grouped convolution
|
||||||
shortcut ('A'|'B'|'C'): 'B' use 1x1 conv to downsample, 'C' use 1x1 conv on every residual connection
|
|
||||||
"""
|
"""
|
||||||
model = ResNeXt(
|
model = ResNeXt(
|
||||||
ResNeXtBottleneckC, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
ResNeXtBottleneckC, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
|
||||||
|
@@ -127,6 +127,7 @@ class Xception(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super(Xception, self).__init__()
|
super(Xception, self).__init__()
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
|
self.num_features = 2048
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
|
self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
|
||||||
self.bn1 = nn.BatchNorm2d(32)
|
self.bn1 = nn.BatchNorm2d(32)
|
||||||
@@ -156,10 +157,10 @@ class Xception(nn.Module):
|
|||||||
self.bn3 = nn.BatchNorm2d(1536)
|
self.bn3 = nn.BatchNorm2d(1536)
|
||||||
|
|
||||||
# do relu here
|
# do relu here
|
||||||
self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
|
self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
|
||||||
self.bn4 = nn.BatchNorm2d(2048)
|
self.bn4 = nn.BatchNorm2d(self.num_features)
|
||||||
|
|
||||||
self.fc = nn.Linear(2048, num_classes)
|
self.fc = nn.Linear(self.num_features, num_classes)
|
||||||
|
|
||||||
# #------- init weights --------
|
# #------- init weights --------
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
@@ -169,7 +170,18 @@ class Xception(nn.Module):
|
|||||||
m.weight.data.fill_(1)
|
m.weight.data.fill_(1)
|
||||||
m.bias.data.zero_()
|
m.bias.data.zero_()
|
||||||
|
|
||||||
def forward_features(self, input):
|
def get_classifier(self):
|
||||||
|
return self.fc
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
del self.fc
|
||||||
|
if num_classes:
|
||||||
|
self.fc = nn.Linear(self.num_features, num_classes)
|
||||||
|
else:
|
||||||
|
self.fc = None
|
||||||
|
|
||||||
|
def forward_features(self, input, pool=True):
|
||||||
x = self.conv1(input)
|
x = self.conv1(input)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
x = self.relu(x)
|
x = self.relu(x)
|
||||||
@@ -197,19 +209,16 @@ class Xception(nn.Module):
|
|||||||
|
|
||||||
x = self.conv4(x)
|
x = self.conv4(x)
|
||||||
x = self.bn4(x)
|
x = self.bn4(x)
|
||||||
return x
|
x = self.relu(x)
|
||||||
|
|
||||||
def logits(self, features):
|
|
||||||
x = self.relu(features)
|
|
||||||
|
|
||||||
|
if pool:
|
||||||
x = F.adaptive_avg_pool2d(x, (1, 1))
|
x = F.adaptive_avg_pool2d(x, (1, 1))
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
x = self.last_linear(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
x = self.forward_features(input)
|
x = self.forward_features(input)
|
||||||
x = self.logits(x)
|
x = self.fc(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -223,13 +232,4 @@ def xception(num_classes=1000, pretrained=False):
|
|||||||
model = Xception(num_classes=num_classes)
|
model = Xception(num_classes=num_classes)
|
||||||
model.load_state_dict(model_zoo.load_url(config['url']))
|
model.load_state_dict(model_zoo.load_url(config['url']))
|
||||||
|
|
||||||
model.input_space = config['input_space']
|
|
||||||
model.input_size = config['input_size']
|
|
||||||
model.input_range = config['input_range']
|
|
||||||
model.mean = config['mean']
|
|
||||||
model.std = config['std']
|
|
||||||
|
|
||||||
# TODO: ugly
|
|
||||||
model.last_linear = model.fc
|
|
||||||
del model.fc
|
|
||||||
return model
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user