A few more additions to the Gluon Xception models to match the interface of the other models.
parent 4d505e0785 · commit 3b4868f6dc
README.md
@@ -120,6 +120,7 @@ I've leveraged the training scripts in this repository to train a few of the models
 | gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | 224 |
 | gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | 224 |
 | gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | 224 |
+| gluon_xception65 | 79.604 (20.396) | 94.748 (5.252) | 39.92 | bicubic | 299 |
 | gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | 224 |
 | gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | 224 |
 | gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | 224 |
gluon_xception.py
@@ -23,6 +23,7 @@ default_cfgs = {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
@@ -35,6 +36,7 @@ default_cfgs = {
         'url': '',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
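For context on the new `pool_size` entry: it records the spatial size of the final feature map at the default `input_size`. At the default output stride of 32, a 299×299 input comes out as a 10×10 map. A quick sketch of that arithmetic (the helper name is mine, not from the codebase):

```python
import math

def expected_pool_size(input_size=(3, 299, 299), output_stride=32):
    # Spatial dims of the last feature map: input resolution over total
    # network stride, rounded up ('same'-style padding keeps coverage).
    _, h, w = input_size
    return math.ceil(h / output_stride), math.ceil(w / output_stride)

print(expected_pool_size())  # (10, 10) -> matches the new 'pool_size' entry
```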
@@ -181,7 +183,9 @@ class Xception65(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception65, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
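As an aside, the constructor signature above makes the normalization layer pluggable. A minimal usage sketch under assumed imports, not part of this commit:

```python
import torch.nn as nn
from gluon_xception import Xception65  # assumed import path for this module

# norm_kwargs is forwarded to every norm_layer(num_features=..., **norm_kwargs)
# call, so BatchNorm hyperparameters can be tuned at construction time.
model = Xception65(num_classes=1000, norm_layer=nn.BatchNorm2d,
                   norm_kwargs={'momentum': 0.05, 'eps': 1e-5})
```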
@@ -240,14 +244,26 @@ class Xception65(nn.Module):
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)

+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)

-    def forward(self, x):
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
+
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
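The point of the new methods is interface parity with the other models in the repo: a standard way to read and replace the classifier head. A usage sketch, assuming the class is importable as below:

```python
import torch.nn as nn
from gluon_xception import Xception65  # assumed import path

model = Xception65(num_classes=1000)
head = model.get_classifier()            # the original 1000-way nn.Linear
model.reset_classifier(num_classes=10)   # rebuilds fc from self.num_features
assert isinstance(model.get_classifier(), nn.Linear)
assert model.get_classifier().out_features == 10
```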
@@ -284,10 +300,15 @@ class Xception65(nn.Module):
         x = self.bn5(x)
         x = self.relu(x)

-        x = self.avgpool(x)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
         x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x

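With `forward_features` split out from `forward`, the model doubles as a feature extractor without touching the head. A sketch of hypothetical usage, not from the commit:

```python
import torch
from gluon_xception import Xception65  # assumed import path

model = Xception65(num_classes=1000)
model.eval()
x = torch.randn(1, 3, 299, 299)
with torch.no_grad():
    feats = model.forward_features(x)  # pooled + flattened features, (1, 2048)
    logits = model(x)                  # full forward through fc, (1, 1000)
```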
@@ -299,7 +320,9 @@ class Xception71(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception71, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
@@ -365,14 +388,26 @@ class Xception71(nn.Module):
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)

+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)

-    def forward(self, x):
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
+
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
@@ -409,16 +444,23 @@ class Xception71(nn.Module):
         x = self.bn5(x)
         x = self.relu(x)

-        x = self.avgpool(x)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
         x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x


 @register_model
 def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-65
+    """
     default_cfg = default_cfgs['gluon_xception65']
     model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
@@ -429,6 +471,8 @@ def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):

 @register_model
 def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-71
+    """
     default_cfg = default_cfgs['gluon_xception71']
     model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
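The registered entrypoints are the intended construction path. A usage sketch (hypothetical; note that `gluon_xception71` still has an empty weight `url`, so only `gluon_xception65` can load pretrained weights at this point):

```python
from gluon_xception import gluon_xception65  # assumed import path

# pretrained=True would pull the v0.1-weights checkpoint from default_cfg's url
model = gluon_xception65(pretrained=False, num_classes=1000)
print(model.default_cfg['pool_size'])   # (10, 10)
print(model.default_cfg['input_size'])  # (3, 299, 299)
```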