mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add inception_v3 models via torchvision, 4 different pretrained weight choices
This commit is contained in:
parent
af1a68d2e1
commit
2da0b4dbc1
@ -110,7 +110,7 @@ Several (less common) features that I often utilize in my projects are included.
|
||||
* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
|
||||
* An inference script that dumps output to CSV is provided as an example
|
||||
|
||||
### Custom Weights
|
||||
### Self-trained Weights
|
||||
I've leveraged the training scripts in this repository to train a few of the models with missing weights to good levels of performance. These numbers are all for 224x224 training and validation image sizing with the usual 87.5% validation crop.
|
||||
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
|
||||
@ -125,8 +125,12 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
||||
### Ported Weights
|
||||
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
|
||||
|---|---|---|---|---|---|
|
||||
| MNASNet 1.00 (B1) | 72.398 (27.602) | 90.930 (9.070) | 4.36M | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
|
||||
| Gluon Inception-V3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | [MxNet Gluon](https://gluon-cv.mxnet.io/model_zoo/classification.html) |
|
||||
| Tensorflow Inception-V3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
|
||||
| Adversarially trained Inception-V3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |
|
||||
| SE-MNASNet 1.00 (A1) | 73.086 (26.914) | 91.336 (8.664) | 3.87M | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
|
||||
| MNASNet 1.00 (B1) | 72.398 (27.602) | 90.930 (9.070) | 4.36M | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
|
||||
|
||||
|
||||
NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far to leveled off in the same 72-73% range.
|
||||
|
||||
|
113
models/inception_v3.py
Normal file
113
models/inception_v3.py
Normal file
@ -0,0 +1,113 @@
|
||||
from torchvision.models import Inception3
|
||||
from models.helpers import load_pretrained
|
||||
from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
default_cfgs = {
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
'inception_v3': {
|
||||
'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults
|
||||
'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||
'tf_inception_v3': {
|
||||
'url': 'https://www.dropbox.com/s/xdh32bpdgqzpx8t/tf_inception_v3-e0069de4.pth?dl=1',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN,
|
||||
'std': IMAGENET_INCEPTION_STD,
|
||||
'num_classes': 1001,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
# my port of Tensorflow adversarially trained Inception V3 from
|
||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||
'adv_inception_v3': {
|
||||
'url': 'https://www.dropbox.com/s/b5pudqh84gtl7i8/adv_inception_v3-9e27bd63.pth?dl=1',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN,
|
||||
'std': IMAGENET_INCEPTION_STD,
|
||||
'num_classes': 1001,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||
'gluon_inception_v3': {
|
||||
'url': 'https://www.dropbox.com/s/8uv6wrl6it6394u/gluon_inception_v3-9f746940.pth?dl=1',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
||||
'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _assert_default_kwargs(kwargs):
|
||||
# for imported models (ie torchvision) without capability to change these params,
|
||||
# make sure they aren't being set to non-defaults
|
||||
assert kwargs.pop('global_pool', 'avg') == 'avg'
|
||||
assert kwargs.pop('drop_rate', 0.) == 0.
|
||||
|
||||
|
||||
def inception_v3(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
default_cfg = default_cfgs['inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
||||
|
||||
|
||||
def tf_inception_v3(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||
default_cfg = default_cfgs['tf_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
||||
|
||||
|
||||
def adv_inception_v3(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
# my port of Tensorflow adversarially trained Inception V3 from
|
||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||
default_cfg = default_cfgs['adv_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
||||
|
||||
|
||||
def gluon_inception_v3(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||
default_cfg = default_cfgs['gluon_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
return model
|
@ -13,6 +13,7 @@ from models.genmobilenet import \
|
||||
semnasnet_050, semnasnet_075, semnasnet_100, semnasnet_140, tflite_semnasnet_100, mnasnet_small,\
|
||||
mobilenetv1_100, mobilenetv2_100, mobilenetv3_050, mobilenetv3_075, mobilenetv3_100,\
|
||||
fbnetc_100, chamnetv1_100, chamnetv2_100, spnasnet_100
|
||||
from models.inception_v3 import inception_v3, gluon_inception_v3, tf_inception_v3, adv_inception_v3
|
||||
|
||||
from models.helpers import load_checkpoint
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user