mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix a few small things.
This commit is contained in:
parent
dc85e5a237
commit
b4e216e377
@ -185,7 +185,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|||||||
state_dict = filter_fn(state_dict)
|
state_dict = filter_fn(state_dict)
|
||||||
|
|
||||||
input_convs = cfg.get('first_conv', None)
|
input_convs = cfg.get('first_conv', None)
|
||||||
if input_convs is not None:
|
if input_convs is not None and in_chans != 3:
|
||||||
if isinstance(input_convs, str):
|
if isinstance(input_convs, str):
|
||||||
input_convs = (input_convs,)
|
input_convs = (input_convs,)
|
||||||
for input_conv_name in input_convs:
|
for input_conv_name in input_convs:
|
||||||
|
@ -32,12 +32,12 @@ default_cfgs = {
|
|||||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||||
'tf_inception_v3': _cfg(
|
'tf_inception_v3': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
||||||
num_classes=1001, has_aux=False),
|
num_classes=1000, has_aux=False, label_offset=1),
|
||||||
# my port of Tensorflow adversarially trained Inception V3 from
|
# my port of Tensorflow adversarially trained Inception V3 from
|
||||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||||
'adv_inception_v3': _cfg(
|
'adv_inception_v3': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
||||||
num_classes=1001, has_aux=False),
|
num_classes=1000, has_aux=False, label_offset=1),
|
||||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||||
'gluon_inception_v3': _cfg(
|
'gluon_inception_v3': _cfg(
|
||||||
|
@ -284,7 +284,7 @@ def main():
|
|||||||
if args.model == 'all':
|
if args.model == 'all':
|
||||||
# validate all models in a list of names with pretrained checkpoints
|
# validate all models in a list of names with pretrained checkpoints
|
||||||
args.pretrained = True
|
args.pretrained = True
|
||||||
model_names = list_models(pretrained=True)
|
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
|
||||||
model_cfgs = [(n, '') for n in model_names]
|
model_cfgs = [(n, '') for n in model_names]
|
||||||
elif not is_model(args.model):
|
elif not is_model(args.model):
|
||||||
# model name doesn't exist, try as wildcard filter
|
# model name doesn't exist, try as wildcard filter
|
||||||
|
Loading…
x
Reference in New Issue
Block a user