mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update efficientnet.py and convnext.py to multi-weight, add ImageNet-12k pretrained EfficientNet-B5 and ConvNeXt-Nano.
This commit is contained in:
parent e7da205345
commit 6a01101905
timm/models/_builder.py

@@ -1,5 +1,6 @@
+import dataclasses
 import logging
 import os
 from copy import deepcopy
 from typing import Optional, Dict, Callable, Any, Tuple
 
@@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureHookNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
     pretrained_url = pretrained_cfg.get('url', None)
     pretrained_file = pretrained_cfg.get('file', None)
     hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
+
     # resolve where to load pretrained weights from
     load_from = ''
     pretrained_loc = ''
@@ -43,15 +45,20 @@
     else:
         # default source == timm or unspecified
         if pretrained_file:
             # file load override is the highest priority if set
             load_from = 'file'
             pretrained_loc = pretrained_file
-        elif pretrained_url:
-            load_from = 'url'
-            pretrained_loc = pretrained_url
-        elif hf_hub_id and has_hf_hub(necessary=True):
-            # hf-hub available as alternate weight source in default_cfg
-            load_from = 'hf-hub'
-            pretrained_loc = hf_hub_id
+        else:
+            # next, HF hub is prioritized unless a valid cached version of weights exists already
+            cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
+            if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
+                # hf-hub available as alternate weight source in default_cfg
+                load_from = 'hf-hub'
+                pretrained_loc = hf_hub_id
+            elif pretrained_url:
+                load_from = 'url'
+                pretrained_loc = pretrained_url
 
     if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
         # if a filename override is set, return tuple for location w/ (hub_id, filename)
         pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
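
The behavioral change in this hunk: an explicit local file still wins, but the HF hub now outranks a fresh url download unless a hash-valid copy of the url weights is already in the local cache. A condensed, standalone sketch of the new priority order; resolve_source() is a hypothetical stand-in for timm's _resolve_pretrained_source(), with the has_hf_hub() check folded into the hub_id argument:

# Condensed sketch of the new source priority (hypothetical helper, not timm's API).
def resolve_source(pretrained_file, pretrained_url, hf_hub_id, cached_url_valid):
    if pretrained_file:
        return 'file', pretrained_file   # explicit local file: highest priority
    if pretrained_url and cached_url_valid:
        return 'url', pretrained_url     # keep using a verified cached download
    if hf_hub_id:
        return 'hf-hub', hf_hub_id       # otherwise prefer the HF hub
    if pretrained_url:
        return 'url', pretrained_url     # fall back to a fresh url download
    return '', ''

assert resolve_source(None, 'https://x/w.pth', 'timm/m', cached_url_valid=False) == ('hf-hub', 'timm/m')
assert resolve_source(None, 'https://x/w.pth', 'timm/m', cached_url_valid=True) == ('url', 'https://x/w.pth')
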
@@ -105,7 +112,7 @@ def load_custom_pretrained(
         pretrained_loc = download_cached_file(
             pretrained_loc,
             check_hash=_CHECK_HASH,
-            progress=_DOWNLOAD_PROGRESS
+            progress=_DOWNLOAD_PROGRESS,
         )
 
     if load_fn is not None:
timm/models/_hub.py

@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
     return cached_file
 
 
+def check_cached_file(url, check_hash=True):
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
+    cached_file = os.path.join(get_cache_dir(), filename)
+    if os.path.exists(cached_file):
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+            if hash_prefix:
+                with open(cached_file, 'rb') as f:
+                    hd = hashlib.sha256(f.read()).hexdigest()
+                if hd[:len(hash_prefix)] != hash_prefix:
+                    return False
+        return True
+    return False
+
+
 def has_hf_hub(necessary=False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
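
The new check_cached_file() leans on the torch.hub filename convention, where a sha256 prefix follows the final '-' in the file name (e.g. convnext_atto_d2-01bb0f51.pth): it returns True only if the file already sits in the cache dir and, when check_hash is set, its digest starts with that prefix. A usage sketch, assuming only the import path shown in this diff:

# Sketch: probe the local cache before deciding where to pull weights from.
from timm.models._hub import check_cached_file

url = ('https://github.com/rwightman/pytorch-image-models/releases/download/'
       'v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth')
if check_cached_file(url):
    # file exists in the cache dir and its sha256 begins with '01bb0f51'
    print('valid cached copy; _resolve_pretrained_source() will skip the HF hub')
else:
    print('no valid cache; HF hub (if configured) takes priority')
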
@@ -145,7 +166,9 @@ def save_for_hf(model, save_directory, model_config=None):
     hf_config['architecture'] = pretrained_cfg.pop('architecture')
     hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
     hf_config['num_features'] = model_config.get('num_features', model.num_features)
-    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
+    if isinstance(global_pool_type, str) and global_pool_type:
+        hf_config['global_pool'] = global_pool_type
 
     if 'label' in model_config:
         _logger.warning(
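
The save_for_hf change guards against non-string pool types: model.global_pool can be a pooling module rather than a name like 'avg', and presumably only the string form belongs in the serialized hub config. A minimal illustration of the guard in isolation (not timm code):

import torch.nn as nn

hf_config = {}
for global_pool_type in ('avg', nn.AdaptiveAvgPool2d(1), None):
    if isinstance(global_pool_type, str) and global_pool_type:
        hf_config['global_pool'] = global_pool_type  # only the string form is kept
print(hf_config)  # {'global_pool': 'avg'}
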
timm/models/_pretrained.py

@@ -19,6 +19,7 @@ class PretrainedCfg:
 
     source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
     architecture: Optional[str] = None  # architecture variant can be set when not implicit
     tag: Optional[str] = None  # pretrained tag of source
+    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)
 
     # input / data config
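
The new custom_load field marks cfgs whose checkpoints are not plain torch state dicts, such as the original JAX .npz ViT weights further down this diff, so loading routes through the model's own load_pretrained() instead of a standard state-dict load. A construction sketch, assuming PretrainedCfg accepts these fields as keyword arguments (url, custom_load, and num_classes all appear in this diff):

from timm.models._pretrained import PretrainedCfg

cfg = PretrainedCfg(
    url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
    custom_load=True,   # npz checkpoint: use model.load_pretrained()
    num_classes=21843,
)
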
timm/models/_registry.py

@@ -7,6 +7,7 @@ import re
 import sys
 from collections import defaultdict, deque
 from copy import deepcopy
+from dataclasses import replace
 from typing import List, Optional, Union, Tuple
 
 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@@ -20,7 +21,7 @@ _model_to_module = {}  # mapping of model names to module names
 _model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
 _model_has_pretrained = set()  # set of model names that have pretrained weight url present
 _model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
-_model_pretrained_cfgs = dict()  # central repo for model arch + tag -> pretrained cfgs
+_model_pretrained_cfgs = dict()  # central repo for model arch.tag -> pretrained cfgs
 _model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names
 
 
@@ -48,24 +49,31 @@ def register_model(fn):
     if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
-        cfg = mod.default_cfgs[model_name]
-        if not isinstance(cfg, DefaultCfg):
+        default_cfg = mod.default_cfgs[model_name]
+        if not isinstance(default_cfg, DefaultCfg):
             # new style default cfg dataclass w/ multiple entries per model-arch
-            assert isinstance(cfg, dict)
+            assert isinstance(default_cfg, dict)
             # old style cfg dict per model-arch
-            cfg = PretrainedCfg(**cfg)
-            cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
+            pretrained_cfg = PretrainedCfg(**default_cfg)
+            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
 
-        for tag_idx, tag in enumerate(cfg.tags):
+        for tag_idx, tag in enumerate(default_cfg.tags):
             is_default = tag_idx == 0
-            pretrained_cfg = cfg.cfgs[tag]
+            pretrained_cfg = default_cfg.cfgs[tag]
+            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
+            replace_items = dict(architecture=model_name, tag=tag if tag else None)
+            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
+                # auto-complete hub name w/ architecture.tag
+                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
+            pretrained_cfg = replace(pretrained_cfg, **replace_items)
 
             if is_default:
                 _model_pretrained_cfgs[model_name] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add tagless entry if it's default and has weights
                     _model_has_pretrained.add(model_name)
 
             if tag:
                 model_name_tag = '.'.join([model_name, tag])
                 _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                 if pretrained_cfg.has_weights:
                     # add model w/ tag if tag is valid
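
Two registry behaviors land in this hunk: each tagged cfg is rewritten via dataclasses.replace() to carry its own architecture and tag, and a bare hf_hub_id='timm/' is auto-completed to 'timm/<architecture>.<tag>'. A self-contained sketch of the auto-complete step, using a stand-in dataclass rather than timm's full PretrainedCfg:

from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class Cfg:
    hf_hub_id: Optional[str] = None  # stand-in; only this field matters here

model_name, tag = 'convnext_nano', 'in12k_ft_in1k'
model_name_tag = '.'.join([model_name, tag]) if tag else model_name

cfg = Cfg(hf_hub_id='timm/')
if cfg.hf_hub_id and cfg.hf_hub_id == 'timm/':
    # auto-complete hub name w/ architecture.tag, as in the hunk above
    cfg = replace(cfg, hf_hub_id=cfg.hf_hub_id + model_name_tag)

assert cfg.hf_hub_id == 'timm/convnext_nano.in12k_ft_in1k'
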
@@ -74,7 +82,7 @@ def register_model(fn):
             else:
                 _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)
 
-        _model_default_cfgs[model_name] = cfg
+        _model_default_cfgs[model_name] = default_cfg
 
     return fn
 
timm/models/convnext.py

@@ -361,7 +361,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
     return model
 
 
-
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -375,90 +374,130 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
     # timm specific variants
-    'convnext_atto.timm_in1k': _cfg(
+    'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_atto_ols.timm_in1k': _cfg(
+    'convnext_atto_ols.a2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto.timm_in1k': _cfg(
+    'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_femto_ols.timm_in1k': _cfg(
+    'convnext_femto_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico.timm_in1k': _cfg(
+    'convnext_pico.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
-    'convnext_pico_ols.timm_in1k': _cfg(
+    'convnext_pico_ols.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano.timm_in1k': _cfg(
+    'convnext_nano.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_nano.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_nano_ols.timm_in1k': _cfg(
+    'convnext_nano_ols.d1h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_tiny_hnf.timm_in1k': _cfg(
+    'convnext_tiny_hnf.a2h_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
+        hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    'convnext_nano.in12k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, num_classes=11821),
+
     'convnext_tiny.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in1k': _cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.untrained': _cfg(),
 
     'convnext_tiny.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_small.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_base.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_large.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
     'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
+        hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
     'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
-    'convnext_small..fb_in22k_ft_in1k_384': _cfg(
+    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_base.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_large.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
 
-    'convnext_tiny_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
-    'convnext_small_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
-    'convnext_base_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
-    'convnext_large_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
-    'convnext_xlarge_in22k.fb_in22k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
+    'convnext_tiny.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_small.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_base.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_large.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
+    'convnext_xlarge.fb_in22k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+        hf_hub_id='timm/',
+        num_classes=21841),
 })
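
With these cfgs in place, weights are selected as '<arch>.<tag>', and an untagged name resolves to the first (default) tag registered for that architecture. Usage sketch with names taken directly from this diff (downloads weights when run):

import timm

# new in this commit: ImageNet-12k pretrain fine-tuned on ImageNet-1k
model = timm.create_model('convnext_nano.in12k_ft_in1k', pretrained=True)

# the raw ImageNet-12k pretrain keeps its 11821-class head
model_in12k = timm.create_model('convnext_nano.in12k', pretrained=True)
assert model_in12k.num_classes == 11821
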
timm/models/efficientnet.py

(file diff suppressed because it is too large)
timm/models/vision_transformer.py

@@ -711,10 +711,10 @@ default_cfgs = generate_default_cfgs({
 
     # patch models, imagenet21k (weights from official Google JAX impl)
-    'vit_large_patch32_224.v1_in21k': _cfg(
+    'vit_large_patch32_224.orig_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843),
-    'vit_huge_patch14_224.v1_in21k': _cfg(
+    'vit_huge_patch14_224.orig_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
         hf_hub_id='timm/vit_huge_patch14_224_in21k',
         custom_load=True, num_classes=21843),