mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Xception model working
This commit is contained in:
parent
1e23727f2f
commit
183d8e4aef
@ -10,7 +10,7 @@ from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
|
|||||||
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
|
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
|
||||||
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
|
seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
|
||||||
from .resnext import resnext50, resnext101, resnext152
|
from .resnext import resnext50, resnext101, resnext152
|
||||||
|
from .xception import xception
|
||||||
|
|
||||||
model_config_dict = {
|
model_config_dict = {
|
||||||
'resnet18': {
|
'resnet18': {
|
||||||
@ -45,6 +45,8 @@ model_config_dict = {
|
|||||||
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
|
'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
|
||||||
'inception_resnet_v2': {
|
'inception_resnet_v2': {
|
||||||
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
||||||
|
'xception': {
|
||||||
|
'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -121,6 +123,8 @@ def create_model(
|
|||||||
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
model = resnext101(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
||||||
elif model_name == 'resnext152':
|
elif model_name == 'resnext152':
|
||||||
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
model = resnext152(num_classes=num_classes, pretrained=pretrained, **kwargs)
|
||||||
|
elif model_name == 'xception':
|
||||||
|
model = xception(num_classes=num_classes, pretrained=pretrained)
|
||||||
else:
|
else:
|
||||||
assert False and "Invalid model"
|
assert False and "Invalid model"
|
||||||
|
|
||||||
|
@ -162,14 +162,12 @@ class Xception(nn.Module):
|
|||||||
self.fc = nn.Linear(2048, num_classes)
|
self.fc = nn.Linear(2048, num_classes)
|
||||||
|
|
||||||
# #------- init weights --------
|
# #------- init weights --------
|
||||||
# for m in self.modules():
|
for m in self.modules():
|
||||||
# if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
# m.weight.data.normal_(0, math.sqrt(2. / n))
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
# elif isinstance(m, nn.BatchNorm2d):
|
m.weight.data.fill_(1)
|
||||||
# m.weight.data.fill_(1)
|
m.bias.data.zero_()
|
||||||
# m.bias.data.zero_()
|
|
||||||
# #-----------------------------
|
|
||||||
|
|
||||||
def forward_features(self, input):
|
def forward_features(self, input):
|
||||||
x = self.conv1(input)
|
x = self.conv1(input)
|
||||||
@ -215,10 +213,10 @@ class Xception(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def xception(num_classes=1000, pretrained='imagenet'):
|
def xception(num_classes=1000, pretrained=False):
|
||||||
model = Xception(num_classes=num_classes)
|
model = Xception(num_classes=num_classes)
|
||||||
if pretrained:
|
if pretrained:
|
||||||
config = pretrained_config['xception'][pretrained]
|
config = pretrained_config['xception']['imagenet']
|
||||||
assert num_classes == config['num_classes'], \
|
assert num_classes == config['num_classes'], \
|
||||||
"num_classes should be {}, but is {}".format(config['num_classes'], num_classes)
|
"num_classes should be {}, but is {}".format(config['num_classes'], num_classes)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user