mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some pretrianed URL changes
* host some of Cadene's weights on github instead of .fr for speed * add my old port of ensemble adversarial inception resnet v2 * switch to my TF port of normal inception res v2 and change FC layer back to 'classif' for compat with ens_adv
This commit is contained in:
parent
827a3d6010
commit
87b92c528e
@ -6,16 +6,25 @@ from .helpers import load_pretrained
|
|||||||
from .adaptive_avgmax_pool import *
|
from .adaptive_avgmax_pool import *
|
||||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
|
|
||||||
_models = ['inception_resnet_v2']
|
_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2']
|
||||||
__all__ = ['InceptionResnetV2'] + _models
|
__all__ = ['InceptionResnetV2'] + _models
|
||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
|
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
|
||||||
'inception_resnet_v2': {
|
'inception_resnet_v2': {
|
||||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
|
||||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||||
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
|
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||||
|
},
|
||||||
|
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
|
||||||
|
'ens_adv_inception_resnet_v2': {
|
||||||
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
|
||||||
|
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||||
|
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||||
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
|
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -274,19 +283,20 @@ class InceptionResnetV2(nn.Module):
|
|||||||
)
|
)
|
||||||
self.block8 = Block8(noReLU=True)
|
self.block8 = Block8(noReLU=True)
|
||||||
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
|
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
|
||||||
self.last_linear = nn.Linear(self.num_features, num_classes)
|
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
|
||||||
|
self.classif = nn.Linear(self.num_features, num_classes)
|
||||||
|
|
||||||
def get_classifier(self):
|
def get_classifier(self):
|
||||||
return self.last_linear
|
return self.classif
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
del self.last_linear
|
del self.classif
|
||||||
if num_classes:
|
if num_classes:
|
||||||
self.last_linear = torch.nn.Linear(self.num_features, num_classes)
|
self.classif = torch.nn.Linear(self.num_features, num_classes)
|
||||||
else:
|
else:
|
||||||
self.last_linear = None
|
self.classif = None
|
||||||
|
|
||||||
def forward_features(self, x, pool=True):
|
def forward_features(self, x, pool=True):
|
||||||
x = self.conv2d_1a(x)
|
x = self.conv2d_1a(x)
|
||||||
@ -314,13 +324,13 @@ class InceptionResnetV2(nn.Module):
|
|||||||
x = self.forward_features(x, pool=True)
|
x = self.forward_features(x, pool=True)
|
||||||
if self.drop_rate > 0:
|
if self.drop_rate > 0:
|
||||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||||
x = self.last_linear(x)
|
x = self.classif(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
r"""InceptionResnetV2 model architecture from the
|
r"""InceptionResnetV2 model architecture from the
|
||||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
|
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['inception_resnet_v2']
|
default_cfg = default_cfgs['inception_resnet_v2']
|
||||||
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
|
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
|
||||||
|
As per https://arxiv.org/abs/1705.07204 and
|
||||||
|
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
|
||||||
|
"""
|
||||||
|
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
|
||||||
|
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||||
|
|
||||||
|
return model
|
||||||
|
@ -11,7 +11,7 @@ __all__ = ['InceptionV4'] + _models
|
|||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
'inception_v4': {
|
'inception_v4': {
|
||||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
|
||||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
|
@ -20,7 +20,7 @@ __all__ = ['PNASNet5Large'] + _models
|
|||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
'pnasnet5large': {
|
'pnasnet5large': {
|
||||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
|
||||||
'input_size': (3, 331, 331),
|
'input_size': (3, 331, 331),
|
||||||
'pool_size': (11, 11),
|
'pool_size': (11, 11),
|
||||||
'crop_pct': 0.875,
|
'crop_pct': 0.875,
|
||||||
|
@ -37,19 +37,19 @@ def _cfg(url='', **kwargs):
|
|||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
'senet154':
|
'senet154':
|
||||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
|
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
|
||||||
'seresnet18':
|
'seresnet18': _cfg(
|
||||||
_cfg(url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
|
url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
'seresnet34':
|
'seresnet34': _cfg(
|
||||||
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
|
url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
|
||||||
'seresnet50':
|
'seresnet50': _cfg(
|
||||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
|
||||||
'seresnet101':
|
'seresnet101': _cfg(
|
||||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
|
||||||
'seresnet152':
|
'seresnet152': _cfg(
|
||||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
|
||||||
'seresnext26_32x4d':
|
'seresnext26_32x4d': _cfg(
|
||||||
_cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
|
||||||
interpolation='bicubic'),
|
interpolation='bicubic'),
|
||||||
'seresnext50_32x4d':
|
'seresnext50_32x4d':
|
||||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
|
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
|
||||||
|
@ -35,7 +35,7 @@ __all__ = ['Xception'] + _models
|
|||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
'xception': {
|
'xception': {
|
||||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
|
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
|
||||||
'input_size': (3, 299, 299),
|
'input_size': (3, 299, 299),
|
||||||
'crop_pct': 0.8975,
|
'crop_pct': 0.8975,
|
||||||
'interpolation': 'bicubic',
|
'interpolation': 'bicubic',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user