Swin and FocalNet weights on HF hub. Add model deprecation functionality w/ some registry tweaks.

pull/1628/head
Ross Wightman 2023-03-18 14:55:09 -07:00
parent 2fc5ac3d18
commit 572f05096a
22 changed files with 560 additions and 449 deletions

View File

@ -74,12 +74,12 @@ from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
register_notrace_module, is_notrace_module, get_notrace_modules, \
register_notrace_function, is_notrace_function, get_notrace_functions
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
from ._pretrained import PretrainedCfg, DefaultCfg, \
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
from ._prune import adapt_model_from_string
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

View File

@ -3,10 +3,10 @@ from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint
from ._pretrained import PretrainedCfg
from ._registry import is_model, model_entrypoint, split_model_name_tag
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']

View File

@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import logging
import os
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union
import torch
try:
@ -13,30 +14,32 @@ try:
except ImportError:
_has_safetensors = False
import timm.models._builder
_logger = logging.getLogger(__name__)
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
def clean_state_dict(state_dict):
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
cleaned_state_dict = OrderedDict()
cleaned_state_dict = {}
for k, v in state_dict.items():
name = k[7:] if k.startswith('module.') else k
cleaned_state_dict[name] = v
return cleaned_state_dict
def load_state_dict(checkpoint_path, use_ema=True):
def load_state_dict(
checkpoint_path: str,
use_ema: bool = True,
device: Union[str, torch.device] = 'cpu',
) -> Dict[str, Any]:
if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict_key = ''
if isinstance(checkpoint, dict):
@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
raise FileNotFoundError()
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
def load_checkpoint(
model: torch.nn.Module,
checkpoint_path: str,
use_ema: bool = True,
device: Union[str, torch.device] = 'cpu',
strict: bool = True,
remap: bool = False,
filter_fn: Optional[Callable] = None,
):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
timm.models._model_builder.load_pretrained(checkpoint_path)
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema)
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
if remap:
state_dict = remap_checkpoint(model, state_dict)
state_dict = remap_state_dict(state_dict, model)
elif filter_fn:
state_dict = filter_fn(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
return incompatible_keys
def remap_checkpoint(model, state_dict, allow_reshape=True):
def remap_state_dict(
state_dict: Dict[str, Any],
model: torch.nn.Module,
allow_reshape: bool = True
):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
This assumes models (and originating state dict) were created with params registered in same order.
"""
@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
return out_dict
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
def resume_checkpoint(
model: torch.nn.Module,
checkpoint_path: str,
optimizer: torch.optim.Optimizer = None,
loss_scaler: Any = None,
log_info: bool = True,
):
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')

View File

@ -3,7 +3,7 @@ import math
import re
from collections import defaultdict
from itertools import chain
from typing import Callable, Union, Dict
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
import torch
from torch import nn as nn
@ -13,7 +13,7 @@ __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_wi
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
def model_parameters(model, exclude_head=False):
def model_parameters(model: nn.Module, exclude_head: bool = False):
if exclude_head:
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
return [p for p in model.parameters()][:-2]
@ -21,7 +21,12 @@ def model_parameters(model, exclude_head=False):
return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
def named_apply(
fn: Callable,
module: nn.Module, name='',
depth_first: bool = True,
include_root: bool = False,
) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
def named_modules(
module: nn.Module,
name: str = '',
depth_first: bool = True,
include_root: bool = False,
):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
yield name, module
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
def named_modules_with_params(
module: nn.Module,
name: str = '',
depth_first: bool = True,
include_root: bool = False,
):
if module._parameters and not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
@ -58,9 +73,9 @@ MATCH_PREV_GROUP = (99999,)
def group_with_matcher(
named_objects,
named_objects: Iterator[Tuple[str, Any]],
group_matcher: Union[Dict, Callable],
output_values: bool = False,
return_values: bool = False,
reverse: bool = False
):
if isinstance(group_matcher, dict):
@ -96,7 +111,7 @@ def group_with_matcher(
# map layers into groups via ordinals (ints or tuples of ints) from matcher
grouping = defaultdict(list)
for k, v in named_objects:
grouping[_get_grouping(k)].append(v if output_values else k)
grouping[_get_grouping(k)].append(v if return_values else k)
# remap to integers
layer_id_to_param = defaultdict(list)
@ -107,7 +122,7 @@ def group_with_matcher(
layer_id_to_param[lid].extend(grouping[k])
if reverse:
assert not output_values, "reverse mapping only sensible for name output"
assert not return_values, "reverse mapping only sensible for name output"
# output reverse mapping
param_to_layer_id = {}
for lid, lm in layer_id_to_param.items():
@ -121,24 +136,29 @@ def group_with_matcher(
def group_parameters(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
return_values: bool = False,
reverse: bool = False,
):
return group_with_matcher(
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)
def group_modules(
module: nn.Module,
group_matcher,
output_values=False,
reverse=False,
return_values: bool = False,
reverse: bool = False,
):
return group_with_matcher(
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
def flatten_modules(
named_modules: Iterator[Tuple[str, nn.Module]],
depth: int = 1,
prefix: Union[str, Tuple[str, ...]] = '',
module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
):
prefix_is_tuple = isinstance(prefix, tuple)
if isinstance(module_types, str):
if module_types == 'container':

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
@dataclass
@ -91,41 +91,3 @@ class DefaultCfg:
def default_with_tag(self):
tag = self.tags[0]
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
for k, v in cfgs.items():
if isinstance(v, dict):
v = PretrainedCfg(**v)
has_weights = v.has_weights
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_cfg.is_pretrained:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
if has_weights:
default_cfg.is_pretrained = True
default_cfg.cfgs[tag] = v
return out

View File

@ -5,16 +5,19 @@ Hacked together by / Copyright 2020 Ross Wightman
import fnmatch
import re
import sys
import warnings
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
from ._pretrained import PretrainedCfg, DefaultCfg
__all__ = [
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
'get_pretrained_cfg_value', 'is_model_pretrained'
]
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
@ -23,12 +26,52 @@ _model_has_pretrained: Set[str] = set() # set of model names that have pretrain
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
_deprecated_models: Dict[str, Optional[str]] = {}
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0]
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
out = defaultdict(DefaultCfg)
default_set = set() # no tag and tags ending with * are prioritized as default
for k, v in cfgs.items():
if isinstance(v, dict):
v = PretrainedCfg(**v)
has_weights = v.has_weights
model, tag = split_model_name_tag(k)
is_default_set = model in default_set
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
tag = tag.strip('*')
default_cfg = out[model]
if priority:
default_cfg.tags.appendleft(tag)
default_set.add(model)
elif has_weights and not default_cfg.is_pretrained:
default_cfg.tags.appendleft(tag)
else:
default_cfg.tags.append(tag)
if has_weights:
default_cfg.is_pretrained = True
default_cfg.cfgs[tag] = v
return out
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module
mod = sys.modules[fn.__module__]
@ -87,6 +130,37 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
return fn
def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
def _fn(pretrained=False, **kwargs):
assert current_fn is not None, f'Model {deprecated_name} has been removed with no replacement.'
warnings.warn(f'Mapping deprecated model {deprecated_name} to current {current_fn.__name__}', stacklevel=2)
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
return _fn
def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
mod = sys.modules[module_name]
module_name_split = module_name.split('.')
module_name = module_name_split[-1] if len(module_name_split) else ''
for deprecated, current in deprecation_map.items():
if hasattr(mod, '__all__'):
mod.__all__.append(deprecated)
current_fn = None
current_tag = ''
if current:
current_name, current_tag = split_model_name_tag(current)
current_fn = getattr(mod, current_name)
deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
setattr(mod, deprecated, deprecated_entrypoint_fn)
_model_entrypoints[deprecated] = deprecated_entrypoint_fn
_model_to_module[deprecated] = module_name
_module_to_models[module_name].add(deprecated)
_deprecated_models[deprecated] = current
_module_to_deprecated_models[module_name][deprecated] = current
def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
@ -122,16 +196,14 @@ def list_models(
# FIXME should this be default behaviour? or default to include_tags=True?
include_tags = pretrained
if module:
all_models: Iterable[str] = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings
if include_tags:
# expand model names to include names w/ pretrained tags
models_with_tags = []
models_with_tags: Set[str] = set()
for m in all_models:
models_with_tags.extend(_model_with_tags[m])
models_with_tags.update(_model_with_tags[m])
all_models = models_with_tags
if filter:
@ -142,7 +214,7 @@ def list_models(
if len(include_models):
models = models.union(include_models)
else:
models = set(all_models)
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
@ -173,6 +245,11 @@ def list_pretrained(
)
def get_deprecated_models(module: str = '') -> Dict[str, str]:
all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
return deepcopy(all_deprecated)
def is_model(model_name: str) -> bool:
""" Check if a model name exists
"""

View File

@ -63,8 +63,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn
__all__ = ['Beit']

View File

@ -50,8 +50,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
@ -519,6 +518,13 @@ def _cfgv2(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
@ -558,12 +564,6 @@ default_cfgs = generate_default_cfgs({
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
@ -582,25 +582,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(),
'convnext_xxlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
@ -622,6 +603,23 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
@ -1038,3 +1036,22 @@ def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, {
'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
'convnext_small_in22k': 'convnext_small.fb_in22k',
'convnext_base_in22k': 'convnext_base.fb_in22k',
'convnext_large_in22k': 'convnext_large.fb_in22k',
'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
})

View File

@ -11,8 +11,6 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
@ -22,13 +20,12 @@ import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['DaViT']

View File

@ -21,8 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this

View File

@ -26,8 +26,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_nor
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
EfficientFormer_width = {

View File

@ -51,8 +51,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['EfficientNet', 'EfficientNetFeatures']
@ -1064,42 +1063,46 @@ default_cfgs = generate_default_cfgs({
'efficientnetv2_xl.untrained': _cfg(
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnet_b0.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
hf_hub_id='timm/',
input_size=(3, 224, 224)),
'tf_efficientnet_b1.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
hf_hub_id='timm/',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
hf_hub_id='timm/',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
hf_hub_id='timm/',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
hf_hub_id='timm/',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
hf_hub_id='timm/',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
hf_hub_id='timm/',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_b8.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
hf_hub_id='timm/',
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
hf_hub_id='timm/',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
'tf_efficientnet_b0.ap_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
@ -1146,46 +1149,42 @@ default_cfgs = generate_default_cfgs({
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
'tf_efficientnet_b0.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
hf_hub_id='timm/',
input_size=(3, 224, 224)),
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
'tf_efficientnet_b1.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
'tf_efficientnet_b2.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
hf_hub_id='timm/',
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
'tf_efficientnet_b3.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
hf_hub_id='timm/',
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
'tf_efficientnet_b4.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
hf_hub_id='timm/',
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
'tf_efficientnet_b5.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
hf_hub_id='timm/',
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
'tf_efficientnet_b6.aa_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
hf_hub_id='timm/',
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
'tf_efficientnet_b7.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
hf_hub_id='timm/',
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
'tf_efficientnet_b8.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
hf_hub_id='timm/',
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
hf_hub_id='timm/',
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'tf_efficientnet_es.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
@ -1248,22 +1247,6 @@ default_cfgs = generate_default_cfgs({
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
'tf_efficientnetv2_s.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_l.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
hf_hub_id='timm/',
@ -1285,6 +1268,22 @@ default_cfgs = generate_default_cfgs({
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
'tf_efficientnetv2_m.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_l.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
hf_hub_id='timm/',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'tf_efficientnetv2_s.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
hf_hub_id='timm/',
@ -2289,3 +2288,34 @@ def tinynet_d(pretrained=False, **kwargs):
def tinynet_e(pretrained=False, **kwargs):
model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
return model
register_model_deprecations(__name__, {
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
})

View File

@ -28,7 +28,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['FocalNet']
@ -485,51 +485,51 @@ def _cfg(url='', **kwargs):
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.proj', 'classifier': 'head.fc',
**kwargs
'license': 'mit', **kwargs
}
default_cfgs = {
"focalnet_tiny_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'),
"focalnet_small_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'),
"focalnet_base_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'),
"focalnet_tiny_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'),
"focalnet_small_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'),
"focalnet_base_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
"focalnet_large_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
default_cfgs = generate_default_cfgs({
"focalnet_tiny_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_small_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_base_srf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_tiny_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_small_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_base_lrf.ms_in1k": _cfg(
hf_hub_id='timm/'),
"focalnet_large_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_large_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
"focalnet_large_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
"focalnet_xlarge_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
"focalnet_xlarge_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
"focalnet_huge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
"focalnet_huge_fl3.ms_in22k": _cfg(
hf_hub_id='timm/',
num_classes=21842),
"focalnet_huge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
"focalnet_huge_fl4.ms_in22k": _cfg(
hf_hub_id='timm/',
num_classes=0),
}
})
def checkpoint_filter_fn(state_dict, model: FocalNet):
state_dict = state_dict.get('model', state_dict)
if 'stem.proj.weight' in state_dict:
return
return state_dict
import re
out_dict = {}
if 'model' in state_dict:
state_dict = state_dict['model']
dest_dict = model.state_dict()
for k, v in state_dict.items():
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)

View File

@ -24,7 +24,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import Dict
@ -35,9 +34,7 @@ from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['Levit']

View File

@ -52,8 +52,7 @@ from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']

View File

@ -22,8 +22,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['MobileNetV3', 'MobileNetV3Features']
@ -796,3 +795,9 @@ def lcnet_150(pretrained=False, **kwargs):
""" PP-LCNet 1.5"""
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
return model
register_model_deprecations(__name__, {
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
})

View File

@ -23,11 +23,11 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq, named_apply
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from .vision_transformer import get_init_weights_vit
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
@ -302,7 +302,12 @@ class PatchMerging(nn.Module):
""" Patch Merging Layer.
"""
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
def __init__(
self,
dim: int,
out_dim: Optional[int] = None,
norm_layer: Callable = nn.LayerNorm,
):
"""
Args:
dim: Number of input channels.
@ -345,13 +350,13 @@ class SwinTransformerStage(nn.Module):
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.,
norm_layer: Callable = nn.LayerNorm,
output_nchw: bool = False,
):
"""
Args:
dim: Number of input channels.
input_resolution: Input resolution.
depth: Number of blocks.
downsample: Downsample layer at the end of the layer.
num_heads: Number of attention heads.
head_dim: Channels per head (dim // num_heads if not set)
window_size: Local window size.
@ -361,14 +366,12 @@ class SwinTransformerStage(nn.Module):
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
downsample: Downsample layer at the end of the layer.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
self.depth = depth
self.use_nchw = output_nchw
self.grad_checkpointing = False
# patch merging layer
@ -401,18 +404,12 @@ class SwinTransformerStage(nn.Module):
for i in range(depth)])
def forward(self, x):
if self.use_nchw:
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
if self.use_nchw:
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
return x
@ -442,7 +439,6 @@ class SwinTransformer(nn.Module):
drop_path_rate: float = 0.1,
norm_layer: Union[str, Callable] = nn.LayerNorm,
weight_init: str = '',
output_fmt: str = 'NHWC',
**kwargs,
):
"""
@ -465,15 +461,13 @@ class SwinTransformer(nn.Module):
"""
super().__init__()
assert global_pool in ('', 'avg')
assert output_fmt in ('NCHW', 'NHWC')
self.num_classes = num_classes
self.global_pool = global_pool
self.output_fmt = output_fmt
self.output_fmt = 'NHWC'
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.output_nchw = self.output_fmt == 'NCHW' # bool flag for fwd
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
@ -518,7 +512,6 @@ class SwinTransformer(nn.Module):
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
output_nchw=self.output_nchw,
)]
in_dim = out_dim
if i > 0:
@ -577,14 +570,8 @@ class SwinTransformer(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
if self.output_nchw:
# patch embed always outputs NHWC, stage layers expect NCHW input
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
if self.output_nchw:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
else:
x = self.norm(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
@ -596,14 +583,10 @@ class SwinTransformer(nn.Module):
return x
def checkpoint_filter_fn(
state_dict,
model,
adapt_layer_scale=False,
interpolation='bicubic',
antialias=True,
):
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'head.fc.weight' in state_dict:
return state_dict
import re
out_dict = {}
state_dict = state_dict.get('model', state_dict)
@ -635,106 +618,83 @@ def _cfg(url='', **kwargs):
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
'license': 'mit', **kwargs
}
default_cfgs = {
'swin_base_patch4_window12_384': _cfg(
default_cfgs = generate_default_cfgs({
'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_base_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
),
'swin_large_patch4_window12_384': _cfg(
'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_large_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
),
'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
'swin_small_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
'swin_base_patch4_window7_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
'swin_base_patch4_window12_384.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'swin_small_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
),
# tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
'swin_tiny_patch4_window7_224': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
),
'swin_base_patch4_window12_384_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_base_patch4_window7_224_in22k': _cfg(
'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_small_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_base_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384_in22k': _cfg(
'swin_base_patch4_window12_384.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_large_patch4_window12_384.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
'swin_large_patch4_window7_224_in22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
num_classes=21841),
'swin_s3_tiny_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
),
'swin_s3_small_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
),
'swin_s3_base_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
)
}
@register_model
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-S @ 224x224, trained ImageNet-1k
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
'swin_s3_tiny_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
'swin_s3_small_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
'swin_s3_base_224.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
})
@register_model
@ -747,44 +707,53 @@ def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
@register_model
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
""" Swin-B @ 384x384, trained ImageNet-22k
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-S @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
""" Swin-B @ 224x224, trained ImageNet-22k
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-B @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
""" Swin-L @ 384x384, trained ImageNet-22k
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-B @ 384x384
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
""" Swin-L @ 224x224, trained ImageNet-22k
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
""" Swin-L @ 224x224
"""
model_kwargs = dict(
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
""" Swin-L @ 384x384
"""
model_kwargs = dict(
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
@register_model
def swin_s3_tiny_224(pretrained=False, **kwargs):
""" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
@ -794,7 +763,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
@register_model
def swin_s3_small_224(pretrained=False, **kwargs):
""" Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
@ -804,10 +773,17 @@ def swin_s3_small_224(pretrained=False, **kwargs):
@register_model
def swin_s3_base_224(pretrained=False, **kwargs):
""" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
""" Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)
register_model_deprecations(__name__, {
'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
})

View File

@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
@ -425,9 +425,6 @@ class SwinTransformerV2Stage(nn.Module):
for i in range(depth)])
def forward(self, x):
if self.output_nchw:
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
x = self.downsample(x)
for blk in self.blocks:
@ -435,9 +432,6 @@ class SwinTransformerV2Stage(nn.Module):
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.output_nchw:
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
return x
def _init_respostnorm(self):
@ -473,7 +467,6 @@ class SwinTransformerV2(nn.Module):
drop_path_rate: float = 0.1,
norm_layer: Callable = nn.LayerNorm,
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
output_fmt: str = 'NHWC',
**kwargs,
):
"""
@ -500,13 +493,11 @@ class SwinTransformerV2(nn.Module):
self.num_classes = num_classes
assert global_pool in ('', 'avg')
assert output_fmt in ('NCHW', 'NHWC')
self.global_pool = global_pool
self.output_fmt = output_fmt
self.output_fmt = 'NHWC'
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.output_nchw = self.output_fmt == 'NCHW'
self.feature_info = []
if not isinstance(embed_dim, (tuple, list)):
@ -544,7 +535,6 @@ class SwinTransformerV2(nn.Module):
drop_path=dpr[i],
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_sizes[i],
output_nchw=self.output_nchw,
)]
in_dim = out_dim
if i > 0:
@ -605,14 +595,8 @@ class SwinTransformerV2(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
if self.output_nchw:
# patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
if self.output_nchw:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
else:
x = self.norm(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
@ -625,10 +609,12 @@ class SwinTransformerV2(nn.Module):
def checkpoint_filter_fn(state_dict, model):
import re
out_dict = {}
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
if 'head.fc.weight' in state_dict:
return state_dict
out_dict = {}
import re
for k, v in state_dict.items():
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
continue # skip buffers that should not be persistent
@ -657,53 +643,66 @@ def _cfg(url='', **kwargs):
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
'license': 'mit', **kwargs
}
default_cfgs = {
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
),
'swinv2_tiny_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
),
'swinv2_small_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
),
'swinv2_small_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
),
'swinv2_base_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
),
'swinv2_base_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
),
'swinv2_base_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
default_cfgs = generate_default_cfgs({
'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'swinv2_large_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
),
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
}
'swinv2_tiny_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
),
'swinv2_tiny_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
),
'swinv2_small_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
),
'swinv2_small_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
),
'swinv2_base_window8_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
),
'swinv2_base_window16_256.ms_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
),
'swinv2_base_window12_192.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
'swinv2_large_window12_192.ms_in22k': _cfg(
hf_hub_id='timm/',
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
),
})
@register_model
@ -761,62 +760,72 @@ def swinv2_base_window8_256(pretrained=False, **kwargs):
@register_model
def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
def swinv2_base_window12_192(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer_v2('swinv2_base_window12_192', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
def swinv2_base_window12to16_192to256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_base_window12to16_192to256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
def swinv2_base_window12to24_192to384(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_base_window12to24_192to384', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
def swinv2_large_window12_192(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
return _create_swin_transformer_v2('swinv2_large_window12_192', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
def swinv2_large_window12to16_192to256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_large_window12to16_192to256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
def swinv2_large_window12to24_192to384(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
'swinv2_large_window12to24_192to384', pretrained=pretrained, **model_kwargs)
register_model_deprecations(__name__, {
'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
})

View File

@ -41,7 +41,7 @@ from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
@ -748,9 +748,11 @@ def init_weights(module: nn.Module, name: str = ''):
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
if 'head.fc.weight' in state_dict:
return state_dict
out_dict = {}
for k, v in state_dict.items():
if 'tau' in k:
# convert old tau based checkpoints -> logit_scale (inverse)
@ -791,43 +793,46 @@ def _cfg(url='', **kwargs):
}
default_cfgs = {
'swinv2_cr_tiny_384': _cfg(
default_cfgs = generate_default_cfgs({
'swinv2_cr_tiny_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224': _cfg(
'swinv2_cr_tiny_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224': _cfg(
'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384': _cfg(
'swinv2_cr_small_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224': _cfg(
'swinv2_cr_small_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_224': _cfg(
'swinv2_cr_small_ns_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_ns_256': _cfg(
'swinv2_cr_small_ns_256.untrained': _cfg(
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
'swinv2_cr_base_384': _cfg(
'swinv2_cr_base_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
'swinv2_cr_base_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224': _cfg(
'swinv2_cr_base_ns_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384': _cfg(
'swinv2_cr_large_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224': _cfg(
'swinv2_cr_large_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384': _cfg(
'swinv2_cr_huge_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224': _cfg(
'swinv2_cr_huge_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
'swinv2_cr_giant_384.untrained': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
'swinv2_cr_giant_224.untrained': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}
})
@register_model

View File

@ -41,9 +41,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_,
resample_abs_pos_embed, RmsNorm
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this

View File

@ -20,8 +20,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer

View File

@ -17,8 +17,7 @@ from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from ._registry import generate_default_cfgs, register_model
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this