mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some pretrianed URL changes
* host some of Cadene's weights on github instead of .fr for speed * add my old port of ensemble adversarial inception resnet v2 * switch to my TF port of normal inception res v2 and change FC layer back to 'classif' for compat with ens_adv
This commit is contained in:
parent
827a3d6010
commit
87b92c528e
@ -6,16 +6,25 @@ from .helpers import load_pretrained
|
||||
from .adaptive_avgmax_pool import *
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
_models = ['inception_resnet_v2']
|
||||
_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2']
|
||||
__all__ = ['InceptionResnetV2'] + _models
|
||||
|
||||
default_cfgs = {
|
||||
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
|
||||
'inception_resnet_v2': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
|
||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||
},
|
||||
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
|
||||
'ens_adv_inception_resnet_v2': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.8975, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
|
||||
}
|
||||
}
|
||||
|
||||
@ -274,19 +283,20 @@ class InceptionResnetV2(nn.Module):
|
||||
)
|
||||
self.block8 = Block8(noReLU=True)
|
||||
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
|
||||
self.last_linear = nn.Linear(self.num_features, num_classes)
|
||||
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
|
||||
self.classif = nn.Linear(self.num_features, num_classes)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.last_linear
|
||||
return self.classif
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.global_pool = global_pool
|
||||
self.num_classes = num_classes
|
||||
del self.last_linear
|
||||
del self.classif
|
||||
if num_classes:
|
||||
self.last_linear = torch.nn.Linear(self.num_features, num_classes)
|
||||
self.classif = torch.nn.Linear(self.num_features, num_classes)
|
||||
else:
|
||||
self.last_linear = None
|
||||
self.classif = None
|
||||
|
||||
def forward_features(self, x, pool=True):
|
||||
x = self.conv2d_1a(x)
|
||||
@ -314,13 +324,13 @@ class InceptionResnetV2(nn.Module):
|
||||
x = self.forward_features(x, pool=True)
|
||||
if self.drop_rate > 0:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.last_linear(x)
|
||||
x = self.classif(x)
|
||||
return x
|
||||
|
||||
|
||||
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r"""InceptionResnetV2 model architecture from the
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
|
||||
"""
|
||||
default_cfg = default_cfgs['inception_resnet_v2']
|
||||
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
|
||||
As per https://arxiv.org/abs/1705.07204 and
|
||||
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
|
||||
"""
|
||||
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
|
||||
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
|
||||
return model
|
||||
|
@ -11,7 +11,7 @@ __all__ = ['InceptionV4'] + _models
|
||||
|
||||
default_cfgs = {
|
||||
'inception_v4': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
|
||||
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
|
@ -20,7 +20,7 @@ __all__ = ['PNASNet5Large'] + _models
|
||||
|
||||
default_cfgs = {
|
||||
'pnasnet5large': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
|
||||
'input_size': (3, 331, 331),
|
||||
'pool_size': (11, 11),
|
||||
'crop_pct': 0.875,
|
||||
|
@ -37,20 +37,20 @@ def _cfg(url='', **kwargs):
|
||||
default_cfgs = {
|
||||
'senet154':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
|
||||
'seresnet18':
|
||||
_cfg(url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
|
||||
interpolation='bicubic'),
|
||||
'seresnet34':
|
||||
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
|
||||
'seresnet50':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
|
||||
'seresnet101':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
|
||||
'seresnet152':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
|
||||
'seresnext26_32x4d':
|
||||
_cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1',
|
||||
interpolation='bicubic'),
|
||||
'seresnet18': _cfg(
|
||||
url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
|
||||
interpolation='bicubic'),
|
||||
'seresnet34': _cfg(
|
||||
url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
|
||||
'seresnet50': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
|
||||
'seresnet101': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
|
||||
'seresnet152': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
|
||||
'seresnext26_32x4d': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
|
||||
interpolation='bicubic'),
|
||||
'seresnext50_32x4d':
|
||||
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
|
||||
'seresnext101_32x4d':
|
||||
|
@ -35,7 +35,7 @@ __all__ = ['Xception'] + _models
|
||||
|
||||
default_cfgs = {
|
||||
'xception': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.8975,
|
||||
'interpolation': 'bicubic',
|
||||
|
Loading…
x
Reference in New Issue
Block a user