Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add ported Tensorflow MaxVit weights. Add a few more CLIP ViT fine-tunes. Tweak some model tag names. Improve model tag name sorting. Update HF hub push config layout.
This commit is contained in:
parent dbe7531aa3
commit 72cfa57761
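For orientation, a minimal usage sketch of what this commit enables from the top-level API (assuming an installed timm that includes these changes; the tag names are taken from the cfgs further down):

import timm

# list_pretrained() is new in this commit and returns tagged weight names (model.tag form)
print(timm.list_pretrained('maxvit*tf*'))

# the ported TF MaxViT weights resolve via their hf_hub_id entries below
model = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True)
model.eval()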
@@ -1,4 +1,4 @@
 from .version import __version__
-from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
     is_scriptable, is_exportable, set_scriptable, set_exportable, \
     is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
@@ -70,5 +70,6 @@ from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .layers import set_fast_norm
-from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
-    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+from ._pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
+from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
+    is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
@@ -1,5 +1,6 @@
 import copy
 from collections import deque, defaultdict
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field, replace, asdict
 from typing import Any, Deque, Dict, Tuple, Optional, Union

@@ -8,13 +9,13 @@ class PretrainedCfg:
     """
     """
     # weight locations
-    url: str = ''
-    file: str = ''
-    hf_hub_id: str = ''
-    hf_hub_filename: str = ''
+    url: Optional[Union[str, Tuple[str, str]]] = None
+    file: Optional[str] = None
+    hf_hub_id: Optional[str] = None
+    hf_hub_filename: Optional[str] = None

-    source: str = ''  # source of cfg / weight location used (url, file, hf-hub)
-    architecture: str = ''  # architecture variant can be set when not implicit
+    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
+    architecture: Optional[str] = None  # architecture variant can be set when not implicit
     custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

     # input / data config
@@ -31,22 +32,40 @@ class PretrainedCfg:

     # head config
     num_classes: int = 1000
-    label_offset: int = 0
+    label_offset: Optional[int] = None

     # model attributes that vary with above or required for pretrained adaptation
     pool_size: Optional[Tuple[int, ...]] = None
     test_pool_size: Optional[Tuple[int, ...]] = None
-    first_conv: str = ''
-    classifier: str = ''
+    first_conv: Optional[str] = None
+    classifier: Optional[str] = None

-    license: str = ''
-    source_url: str = ''
-    paper: str = ''
-    notes: str = ''
+    license: Optional[str] = None
+    source_url: Optional[str] = None
+    paper: Optional[str] = None
+    notes: Optional[str] = None

     @property
     def has_weights(self):
-        return self.url.startswith('http') or self.file or self.hf_hub_id
+        return self.url or self.file or self.hf_hub_id
+
+    def to_dict(self, remove_source=False, remove_null=True):
+        return filter_pretrained_cfg(
+            asdict(self),
+            remove_source=remove_source,
+            remove_null=remove_null
+        )
+
+
+def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
+    filtered_cfg = {}
+    for k, v in cfg.items():
+        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
+            continue
+        if remove_null and v is None:
+            continue
+        filtered_cfg[k] = v
+    return filtered_cfg


 @dataclass
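To illustrate the nullable-field layout above, a small sketch of how a PretrainedCfg serializes through to_dict() / filter_pretrained_cfg(); the cfg values here are made up for illustration:

from timm.models._pretrained import PretrainedCfg, filter_pretrained_cfg

cfg = PretrainedCfg(hf_hub_id='timm/maxvit_tiny_tf_224.in1k', num_classes=1000)
assert cfg.has_weights  # any truthy url / file / hf_hub_id now counts as having weights

full = cfg.to_dict(remove_source=False, remove_null=True)    # None-valued fields are dropped
no_source = filter_pretrained_cfg(full, remove_source=True)  # url / file / hf_hub_id / source stripped
print('hf_hub_id' in full, 'hf_hub_id' in no_source)         # True False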
@@ -71,7 +90,7 @@ def split_model_name_tag(model_name: str, no_tag=''):
     return model_name, tag


-def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
+def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
     out = defaultdict(DefaultCfg)
     default_set = set()  # no tag and tags ending with * are prioritized as default

@@ -82,21 +101,22 @@ def generate_defaults(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):

         model, tag = split_model_name_tag(k)
         is_default_set = model in default_set
-        priority = not tag or (tag.endswith('*') and not is_default_set)
+        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
         tag = tag.strip('*')

         default_cfg = out[model]
-        if has_weights:
-            default_cfg.is_pretrained = True

         if priority:
             default_cfg.tags.appendleft(tag)
             default_set.add(model)
-        elif has_weights and not default_set:
+        elif has_weights and not default_cfg.is_pretrained:
             default_cfg.tags.appendleft(tag)
         else:
             default_cfg.tags.append(tag)

+        if has_weights:
+            default_cfg.is_pretrained = True
+
         default_cfg.cfgs[tag] = v

     return out
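A sketch of how the revised priority logic picks the default tag: a tagged cfg that actually has weights is now preferred over an untagged cfg without weights (model and tag names below are hypothetical):

from timm.models._pretrained import PretrainedCfg, generate_default_cfgs

defaults = generate_default_cfgs({
    'somenet_a': PretrainedCfg(),  # untagged, no weights -> no longer wins by default
    'somenet_a.fancy_in1k': PretrainedCfg(hf_hub_id='org/somenet_a.fancy_in1k'),  # first pretrained tag
})
cfg = defaults['somenet_a']
print(cfg.is_pretrained)  # True, at least one tag carries weights
print(list(cfg.tags))     # pretrained 'fancy_in1k' ordered ahead of the weightless '' tag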
@@ -21,7 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
 from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
     create_conv2d, get_act_layer, make_divisible, to_ntuple
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model
@@ -373,7 +373,7 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # timm specific variants
     'convnext_atto.timm_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
@@ -575,7 +575,7 @@ def build_model_with_cfg(
         )

     # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
-    pretrained_cfg = dataclasses.asdict(pretrained_cfg)
+    pretrained_cfg = pretrained_cfg.to_dict()

     _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
@@ -15,11 +15,13 @@ except ImportError:
     from torch.hub import _get_torch_home as get_dir

 from timm import __version__
+from timm.models._pretrained import filter_pretrained_cfg

 try:
-    from huggingface_hub import (create_repo, get_hf_file_metadata,
-                                 hf_hub_download, hf_hub_url,
-                                 repo_type_and_id_from_hf_id, upload_folder)
+    from huggingface_hub import (
+        create_repo, get_hf_file_metadata,
+        hf_hub_download, hf_hub_url,
+        repo_type_and_id_from_hf_id, upload_folder)
     from huggingface_hub.utils import EntryNotFoundError
     hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
     _has_hf_hub = True
@@ -46,8 +48,11 @@ def get_cache_dir(child_dir=''):


 def download_cached_file(url, check_hash=True, progress=False):
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
+    if isinstance(url, (list, tuple)):
+        url, filename = url
+    else:
+        parts = urlparse(url)
+        filename = os.path.basename(parts.path)
     cached_file = os.path.join(get_cache_dir(), filename)
     if not os.path.exists(cached_file):
         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
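The tuple form lets a pretrained cfg carry an explicit destination filename along with its URL, for URLs that do not end in a usable filename. A standalone sketch mirroring the new branch (the helper name and URLs here are illustrative, not part of timm):

import os
from urllib.parse import urlparse

def resolve_cache_filename(url):
    # mirrors the new download_cached_file() branch: accept either a plain URL
    # or a (url, filename) pair with an explicit destination name
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        filename = os.path.basename(urlparse(url).path)
    return url, filename

print(resolve_cache_filename('https://example.com/weights/model_a.pth'))
print(resolve_cache_filename(('https://example.com/get?id=123', 'model_b.pth')))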
@@ -90,10 +95,27 @@ def _download_from_hf(model_id: str, filename: str):
 def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    pretrained_cfg = load_cfg_from_json(cached_file)
+
+    hf_config = load_cfg_from_json(cached_file)
+    if 'pretrained_cfg' not in hf_config:
+        # old form, pull pretrain_cfg out of the base dict
+        pretrained_cfg = hf_config
+        hf_config = {}
+        hf_config['architecture'] = pretrained_cfg.pop('architecture')
+        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
+        if 'labels' in pretrained_cfg:
+            hf_config['label_name'] = pretrained_cfg.pop('labels')
+        hf_config['pretrained_cfg'] = pretrained_cfg
+
+    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
+    pretrained_cfg = hf_config['pretrained_cfg']
     pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
     pretrained_cfg['source'] = 'hf-hub'
-    model_name = pretrained_cfg.get('architecture')
+    if 'num_classes' in hf_config:
+        # model should be created with parent num_classes if they exist
+        pretrained_cfg['num_classes'] = hf_config['num_classes']
+    model_name = hf_config['architecture']

     return pretrained_cfg, model_name
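The upshot is a new config.json layout: the timm pretrained cfg is nested under a 'pretrained_cfg' key with the architecture and a few display/label fields at the root, while older flat configs are up-converted by the branch above. A hand-written sketch of the new layout (field values are illustrative only):

# rough shape of a hub config.json after this change (values illustrative)
new_style_hf_config = {
    'architecture': 'maxvit_tiny_tf_224',
    'num_classes': 1000,
    'num_features': 512,
    'global_pool': 'avg',
    'pretrained_cfg': {
        'input_size': (3, 224, 224),
        'crop_pct': 0.95,
        'interpolation': 'bicubic',
        'num_classes': 1000,
    },
}
# load_model_config_from_hf() returns (pretrained_cfg, model_name), injecting
# hf_hub_id / source and copying num_classes down from the root if present.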
@@ -114,10 +136,34 @@ def save_for_hf(model, save_directory, model_config=None):
     torch.save(model.state_dict(), weights_path)

     config_path = save_directory / 'config.json'
-    hf_config = model.pretrained_cfg
-    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
-    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
-    hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
+    hf_config = {}
+    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
+    # set some values at root config level
+    hf_config['architecture'] = pretrained_cfg.pop('architecture')
+    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
+    hf_config['num_features'] = model_config.get('num_features', model.num_features)
+    hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
+
+    if 'label' in model_config:
+        _logger.warning(
+            "'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
+            "Using provided 'label' field as 'label_name'.")
+        model_config['label_name'] = model_config.pop('label')
+
+    label_name = model_config.pop('label_name', None)
+    if label_name:
+        assert isinstance(label_name, (dict, list, tuple))
+        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
+        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
+        hf_config['label_name'] = model_config['label_name']
+
+    display_name = model_config.pop('display_name', None)
+    if display_name:
+        assert isinstance(display_name, dict)
+        # map label_name -> user interface display name
+        hf_config['display_name'] = model_config['display_name']
+
+    hf_config['pretrained_cfg'] = pretrained_cfg
     hf_config.update(model_config)

     with config_path.open('w') as f:
@@ -127,14 +173,14 @@ def save_for_hf(model, save_directory, model_config=None):
 def push_to_hf_hub(
         model,
         repo_id: str,
-        commit_message: str ='Add model',
+        commit_message: str = 'Add model',
         token: Optional[str] = None,
         revision: Optional[str] = None,
         private: bool = False,
         create_pr: bool = False,
         model_config: Optional[dict] = None,
 ):
-    # Create repo if doesn't exist yet
+    # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

     # Infer complete repo_id from repo_url
@@ -154,10 +200,11 @@ def push_to_hf_hub(
         # Save model weights and config.
         save_for_hf(model, tmpdir, model_config=model_config)

-        # Add readme if does not exist
+        # Add readme if it does not exist
         if not has_readme:
+            model_name = repo_id.split('/')[-1]
             readme_path = Path(tmpdir) / "README.md"
-            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
+            readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
            readme_path.write_text(readme_text)

         # Upload model and return
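A hedged usage sketch of the updated push path with the new model_config fields (the repo id and label values are placeholders, and huggingface_hub credentials are assumed):

import timm
from timm.models.hub import push_to_hf_hub

model = timm.create_model('resnet18', pretrained=True, num_classes=2)

push_to_hf_hub(
    model,
    repo_id='your-user/resnet18-two-class',           # placeholder repo
    commit_message='Add model',
    model_config={
        'label_name': ['cat', 'dog'],                  # class index -> stable label id
        'display_name': {'cat': 'Cat', 'dog': 'Dog'},  # label id -> UI display string
    },
)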
@@ -54,7 +54,7 @@ from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, La
 from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
 from .layers import SelectAdaptivePool2d, create_pool2d
 from .layers import to_2tuple, extend_tuple, make_divisible, _assert
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model
 from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move these to common location

@@ -1859,7 +1859,7 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # Fiddling with configs / defaults / still pretraining
     'coatnet_pico_rw_224': _cfg(url=''),
     'coatnet_nano_rw_224': _cfg(
@@ -1941,86 +1941,67 @@ default_cfgs = generate_defaults({
     'maxxvit_rmlp_large_rw_224': _cfg(url=''),


-    # Trying to be like the MaxViT paper configs
+    # MaxViT models ported from official Tensorflow impl
     'maxvit_tiny_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_tiny_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_tiny_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_tiny_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_small_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_small_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_small_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_base_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_base_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_224_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_224.in1k',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_large_tf_384.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_384_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_384.in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_512_in1k.pth',
+        hf_hub_id='timm/maxvit_large_tf_512.in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),

     'maxvit_base_tf_224.in21k': _cfg(
-        url='',
-        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
-    'maxvit_base_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_384_in21k_ft_in1k.pth',
+        url=''),
+    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_base_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_base_tf_512_in21k_ft_in1k.pth',
+    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in21k': _cfg(
         url=''),
-    'maxvit_large_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_384_in21k_ft_in1k.pth',
+    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_large_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_large_tf_512_in21k_ft_in1k.pth',
+    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_224.in21k': _cfg(
         url=''),
-    'maxvit_xlarge_tf_384.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth',
+    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
         input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
-    'maxvit_xlarge_tf_512.in21k_ft1k': _cfg(
-        url='',
-        #file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth',
+    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
 })
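With the hf_hub_id entries above, the ported MaxViT weights resolve from the HF Hub at creation time. A sketch of building one of the higher-resolution variants and its eval transform, where input size, crop_pct and the 'squash' crop mode come from the pretrained cfg (assumes a timm install with these cfgs and hub access):

import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('maxvit_small_tf_384.in1k', pretrained=True).eval()

data_cfg = resolve_data_config({}, model=model)   # picks up input_size=(3, 384, 384), crop_pct=1.0 from the cfg
transform = create_transform(**data_cfg)
print(data_cfg['input_size'], data_cfg['crop_pct'])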
@@ -7,7 +7,7 @@ import re
 import sys
 from collections import defaultdict, deque
 from copy import deepcopy
-from typing import Optional, Tuple
+from typing import List, Optional, Union, Tuple

 from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag

@@ -84,7 +84,7 @@ def _natural_key(string_):


 def list_models(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
         module: str = '',
         pretrained=False,
         exclude_filters: str = '',
@@ -114,7 +114,12 @@ def list_models(
     else:
         all_models = _model_entrypoints.keys()

+    # FIXME wildcard filter tag as well as model arch name
+    if include_tags:
+        # expand model names to include names w/ pretrained tags
+        models_with_tags = []
+        for m in all_models:
+            models_with_tags.extend(_model_with_tags[m])
+        all_models = models_with_tags
+
     if filter:
         models = []
@@ -134,13 +139,6 @@ def list_models(
     if len(exclude_models):
         models = set(models).difference(exclude_models)

-    if include_tags:
-        # expand model names to include names w/ pretrained tags
-        models_with_tags = []
-        for m in models:
-            models_with_tags.extend(_model_with_tags[m])
-        models = models_with_tags
-
     if pretrained:
         models = _model_has_pretrained.intersection(models)
@@ -150,6 +148,18 @@ def list_models(
     return list(sorted(models, key=_natural_key))


+def list_pretrained(
+        filter: Union[str, List[str]] = '',
+        exclude_filters: str = '',
+):
+    return list_models(
+        filter=filter,
+        pretrained=True,
+        exclude_filters=exclude_filters,
+        include_tags=True,
+    )
+
+
 def is_model(model_name):
     """ Check if a model name exists
     """
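A sketch of the new convenience wrapper: unlike a default list_models() call, it returns tagged weight names (model.tag) for everything that has pretrained weights:

import timm

# e.g. ['vit_tiny_patch16_224.augreg_in21k', 'vit_tiny_patch16_224.augreg_in21k_ft_in1k', ...]
print(timm.list_pretrained('vit_tiny_*'))

# roughly equivalent spelled out through list_models()
print(timm.list_models('vit_tiny_*', pretrained=True, include_tags=True))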
@@ -34,7 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .registry import register_model

 _logger = logging.getLogger(__name__)

@@ -492,7 +492,8 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
             model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
     model.patch_embed.proj.weight.copy_(embed_conv_w)
     model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+    if model.cls_token is not None:
+        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
         pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
@@ -630,51 +631,74 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
-    # patch models (weights from official Google JAX impl)
-    'vit_tiny_patch16_224.augreg_in21k_ft_1k': _cfg(
+default_cfgs = generate_default_cfgs({
+
+    # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
+    'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_tiny_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_small_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_small_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True),
-    'vit_base_patch32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_base_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_patch8_224.augreg_in21k_ft_1k': _cfg(
+    'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_large_patch32_384.v1_in21k_ft_1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_patch16_224.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True),
-    'vit_large_patch16_384.augreg_in21k_ft_1k': _cfg(
+    'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
+
+    # re-finetuned augreg 21k FT on in1k weights
+    'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
+        file='b16_augreg-a-8.pth'),
+    'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+    'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
+        url=''),
+
+    # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
+    'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth'),
+    'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth'),
+    'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+
+    # How to train your ViT (augreg) weights trained on in1k
+    'vit_base_patch16_224.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
+        custom_load=True),
+    'vit_base_patch16_384.augreg_in1k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),

     'vit_large_patch14_224.untrained': _cfg(url=''),
     'vit_huge_patch14_224.untrained': _cfg(url=''),
     'vit_giant_patch14_224.untrained': _cfg(url=''),
@@ -682,6 +706,15 @@ default_cfgs = generate_defaults({


     # patch models, imagenet21k (weights from official Google JAX impl)
+    'vit_large_patch32_224.v1_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+        num_classes=21843),
+    'vit_huge_patch14_224.v1_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
+        hf_hub_id='timm/vit_huge_patch14_224_in21k',
+        custom_load=True, num_classes=21843),
+
+    # How to train your ViT (augreg) weights, pretrained on in21k
     'vit_tiny_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
         custom_load=True, num_classes=21843),
@@ -700,16 +733,9 @@ default_cfgs = generate_defaults({
     'vit_base_patch8_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
         custom_load=True, num_classes=21843),
-    'vit_large_patch32_224.v1_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
-        num_classes=21843),
     'vit_large_patch16_224.augreg_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
         custom_load=True, num_classes=21843),
-    'vit_huge_patch14_224.v1_in21k': _cfg(
-        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
-        hf_hub_id='timm/vit_huge_patch14_224_in21k',
-        custom_load=True, num_classes=21843),

     # SAM trained models (https://arxiv.org/abs/2106.01548)
     'vit_base_patch32_224.sam': _cfg(
@@ -736,7 +762,7 @@ default_cfgs = generate_defaults({
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
-    'vit_base_patch16_224_miil.in21k_ft_1k': _cfg(
+    'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
         mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
@@ -744,14 +770,15 @@ default_cfgs = generate_defaults({
     'vit_base_patch16_rpn_224.in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
     'vit_medium_patch16_gap_240.in12k': _cfg(
-        url='',
+        hf_hub_id='timm/vit_medium_patch16_gap_240.in12k',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
-    'vit_medium_patch16_gap_256.in12k_ft_1k': _cfg(
-        url='',
+    'vit_medium_patch16_gap_256.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_256.in12k_ft_in1k',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_gap_384.in12k_ft_1k': _cfg(
-        url='',
-        input_size=(3, 384, 384), crop_pct=0.95),
+    'vit_medium_patch16_gap_384.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_medium_patch16_gap_384.in12k_ft_in1k',
+        input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
+    'vit_base_patch16_gap_224': _cfg(),

     # CLIP pretrained image tower and related fine-tuned weights
     'vit_base_patch32_clip_224.laion2b': _cfg(
@@ -781,15 +808,16 @@ default_cfgs = generate_defaults({
     'vit_base_patch32_clip_384.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch32_clip_384.laion2b_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
     'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
     'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
-    'vit_base_patch32_clip_448.laion2b_ft_in1k': _cfg(
-        hf_hub_id='timm/vit_base_patch32_clip_448.laion2b_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in1k',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@@ -816,10 +844,11 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
     'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
     'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k',
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
@@ -866,7 +895,8 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
@@ -876,10 +906,15 @@ default_cfgs = generate_defaults({
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
     'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
-        #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
-        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
+        hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
+    'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
+        hf_hub_id='timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
     'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
         hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
@@ -1118,37 +1153,48 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):

 @register_model
 def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
 def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
 def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
+    """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
+        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
     model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
     return model


+@register_model
+def vit_base_patch16_gap_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
+        global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs):
     """ ViT-B/32 CLIP image tower @ 224x224
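The renamed in12k_ft_in1k tags above pair with these GAP (no class token, average-pooled) entry points; a short sketch of instantiating one by its tagged weight name (assumes hub access for the weights):

import torch
import timm

model = timm.create_model('vit_medium_patch16_gap_256.in12k_ft_in1k', pretrained=True).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1000])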
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._pretrained import generate_defaults
+from ._pretrained import generate_default_cfgs
 from .layers import StdConv2dSame, StdConv2d, to_2tuple
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -39,31 +39,31 @@ def _cfg(url='', **kwargs):
     }


-default_cfgs = generate_defaults({
+default_cfgs = generate_default_cfgs({
     # hybrid in-1k models (weights from official JAX impl where they exist)
-    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True,
         first_conv='patch_embed.backbone.conv'),
-    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_1k': _cfg(
+    'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
-    'vit_small_r26_s32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
         custom_load=True,
     ),
-    'vit_small_r26_s32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
     'vit_base_r26_s32_224.untrained': _cfg(),
-    'vit_base_r50_s16_384.v1_in21k_ft_1k': _cfg(
+    'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_r50_s32_224.augreg_in21k_ft_1k': _cfg(
+    'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
         custom_load=True,
     ),
-    'vit_large_r50_s32_384.augreg_in21k_ft_1k': _cfg(
+    'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
         url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
     ),