Swin and FocalNet weights on HF hub. Add model deprecation functionality w/ some registry tweaks.

Ross Wightman 2023-03-18 14:55:09 -07:00
parent 2fc5ac3d18
commit 572f05096a
22 changed files with 560 additions and 449 deletions

View File

@@ -74,12 +74,12 @@ from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet
 from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
     register_notrace_module, is_notrace_module, get_notrace_modules, \
     register_notrace_function, is_notrace_function, get_notrace_functions
-from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
+from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
 from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
 from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
     group_modules, group_parameters, checkpoint_seq, adapt_input_conv
-from ._pretrained import PretrainedCfg, DefaultCfg, \
-    filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
+from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
 from ._prune import adapt_model_from_string
-from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
-    is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
+from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
+    register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
+    is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

View File

@@ -3,10 +3,10 @@ from typing import Any, Dict, Optional, Union
 from urllib.parse import urlsplit

 from timm.layers import set_layer_config

-from ._pretrained import PretrainedCfg, split_model_name_tag
 from ._helpers import load_checkpoint
 from ._hub import load_model_config_from_hf
-from ._registry import is_model, model_entrypoint
+from ._pretrained import PretrainedCfg
+from ._registry import is_model, model_entrypoint, split_model_name_tag

 __all__ = ['parse_model_name', 'safe_model_name', 'create_model']

View File

@@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import logging
 import os
 from collections import OrderedDict
+from typing import Any, Callable, Dict, Optional, Union

 import torch

 try:
@@ -13,30 +14,32 @@ try:
 except ImportError:
     _has_safetensors = False

-import timm.models._builder

 _logger = logging.getLogger(__name__)

-__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
+__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']


-def clean_state_dict(state_dict):
+def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
     # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
-    cleaned_state_dict = OrderedDict()
+    cleaned_state_dict = {}
     for k, v in state_dict.items():
         name = k[7:] if k.startswith('module.') else k
         cleaned_state_dict[name] = v
     return cleaned_state_dict


-def load_state_dict(checkpoint_path, use_ema=True):
+def load_state_dict(
+        checkpoint_path: str,
+        use_ema: bool = True,
+        device: Union[str, torch.device] = 'cpu',
+) -> Dict[str, Any]:
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
             assert _has_safetensors, "`pip install safetensors` to use .safetensors"
-            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location='cpu')
+            checkpoint = torch.load(checkpoint_path, map_location=device)
         state_dict_key = ''
         if isinstance(checkpoint, dict):
@@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
     raise FileNotFoundError()


-def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
+def load_checkpoint(
+        model: torch.nn.Module,
+        checkpoint_path: str,
+        use_ema: bool = True,
+        device: Union[str, torch.device] = 'cpu',
+        strict: bool = True,
+        remap: bool = False,
+        filter_fn: Optional[Callable] = None,
+):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):
-            timm.models._model_builder.load_pretrained(checkpoint_path)
+            model.load_pretrained(checkpoint_path)
         else:
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
-    state_dict = load_state_dict(checkpoint_path, use_ema)
+
+    state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
     if remap:
-        state_dict = remap_checkpoint(model, state_dict)
+        state_dict = remap_state_dict(state_dict, model)
+    elif filter_fn:
+        state_dict = filter_fn(state_dict, model)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys


-def remap_checkpoint(model, state_dict, allow_reshape=True):
+def remap_state_dict(
+        state_dict: Dict[str, Any],
+        model: torch.nn.Module,
+        allow_reshape: bool = True
+):
     """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
     This assumes models (and originating state dict) were created with params registered in same order.
     """
@@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
     return out_dict


-def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
+def resume_checkpoint(
+        model: torch.nn.Module,
+        checkpoint_path: str,
+        optimizer: torch.optim.Optimizer = None,
+        loss_scaler: Any = None,
+        log_info: bool = True,
+):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
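A quick sketch of the reworked helper surface: `load_state_dict`/`load_checkpoint` now accept a `device`, and `load_checkpoint` takes an optional `filter_fn(state_dict, model)` applied before loading. The checkpoint path and filter below are illustrative only:

import timm
from timm.models import load_checkpoint

model = timm.create_model('resnet18')

def drop_classifier(state_dict, model):
    # illustrative filter_fn: strip the head so a checkpoint trained with a
    # different number of classes still loads (paired with strict=False)
    return {k: v for k, v in state_dict.items() if not k.startswith('fc.')}

load_checkpoint(model, 'checkpoint.pth', device='cpu', strict=False, filter_fn=drop_classifier)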

View File

@@ -3,7 +3,7 @@ import math
 import re
 from collections import defaultdict
 from itertools import chain
-from typing import Callable, Union, Dict
+from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union

 import torch
 from torch import nn as nn
@@ -13,7 +13,7 @@ __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_wi
     'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']


-def model_parameters(model, exclude_head=False):
+def model_parameters(model: nn.Module, exclude_head: bool = False):
     if exclude_head:
         # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
         return [p for p in model.parameters()][:-2]
@@ -21,7 +21,12 @@ def model_parameters(model, exclude_head=False):
         return model.parameters()


-def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
+def named_apply(
+        fn: Callable,
+        module: nn.Module, name='',
+        depth_first: bool = True,
+        include_root: bool = False,
+) -> nn.Module:
     if not depth_first and include_root:
         fn(module=module, name=name)
     for child_name, child_module in module.named_children():
@@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
     return module


-def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
+def named_modules(
+        module: nn.Module,
+        name: str = '',
+        depth_first: bool = True,
+        include_root: bool = False,
+):
     if not depth_first and include_root:
         yield name, module
     for child_name, child_module in module.named_children():
@@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
         yield name, module


-def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
+def named_modules_with_params(
+        module: nn.Module,
+        name: str = '',
+        depth_first: bool = True,
+        include_root: bool = False,
+):
     if module._parameters and not depth_first and include_root:
         yield name, module
     for child_name, child_module in module.named_children():
@@ -58,9 +73,9 @@ MATCH_PREV_GROUP = (99999,)


 def group_with_matcher(
-        named_objects,
+        named_objects: Iterator[Tuple[str, Any]],
         group_matcher: Union[Dict, Callable],
-        output_values: bool = False,
+        return_values: bool = False,
         reverse: bool = False
 ):
     if isinstance(group_matcher, dict):
@@ -96,7 +111,7 @@ def group_with_matcher(
     # map layers into groups via ordinals (ints or tuples of ints) from matcher
     grouping = defaultdict(list)
     for k, v in named_objects:
-        grouping[_get_grouping(k)].append(v if output_values else k)
+        grouping[_get_grouping(k)].append(v if return_values else k)

     # remap to integers
     layer_id_to_param = defaultdict(list)
@@ -107,7 +122,7 @@ def group_with_matcher(
         layer_id_to_param[lid].extend(grouping[k])

     if reverse:
-        assert not output_values, "reverse mapping only sensible for name output"
+        assert not return_values, "reverse mapping only sensible for name output"
         # output reverse mapping
         param_to_layer_id = {}
         for lid, lm in layer_id_to_param.items():
@@ -121,24 +136,29 @@ def group_with_matcher(
 def group_parameters(
         module: nn.Module,
         group_matcher,
-        output_values=False,
-        reverse=False,
+        return_values: bool = False,
+        reverse: bool = False,
 ):
     return group_with_matcher(
-        module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
+        module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)


 def group_modules(
         module: nn.Module,
         group_matcher,
-        output_values=False,
-        reverse=False,
+        return_values: bool = False,
+        reverse: bool = False,
 ):
     return group_with_matcher(
-        named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
+        named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)


-def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
+def flatten_modules(
+        named_modules: Iterator[Tuple[str, nn.Module]],
+        depth: int = 1,
+        prefix: Union[str, Tuple[str, ...]] = '',
+        module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
+):
     prefix_is_tuple = isinstance(prefix, tuple)
     if isinstance(module_types, str):
         if module_types == 'container':
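The `output_values` → `return_values` rename above is API-visible for `group_parameters` / `group_modules`; a small sketch of the updated call (model choice is arbitrary):

import timm
from timm.models import group_parameters

model = timm.create_model('resnet18')
# timm models expose a group_matcher(); with return_values=False the result
# maps layer-group ordinals to parameter names, with True to the values
groups = group_parameters(model, model.group_matcher(), return_values=False)
for group_id, param_names in sorted(groups.items()):
    print(group_id, len(param_names))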

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass, field, replace, asdict
 from typing import Any, Deque, Dict, Tuple, Optional, Union


-__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
+__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']


 @dataclass
@@ -91,41 +91,3 @@ class DefaultCfg:
     def default_with_tag(self):
         tag = self.tags[0]
         return tag, self.cfgs[tag]
-
-
-def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
-    model_name, *tag_list = model_name.split('.', 1)
-    tag = tag_list[0] if tag_list else no_tag
-    return model_name, tag
-
-
-def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
-    out = defaultdict(DefaultCfg)
-    default_set = set()  # no tag and tags ending with * are prioritized as default
-    for k, v in cfgs.items():
-        if isinstance(v, dict):
-            v = PretrainedCfg(**v)
-        has_weights = v.has_weights
-        model, tag = split_model_name_tag(k)
-        is_default_set = model in default_set
-        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
-        tag = tag.strip('*')
-
-        default_cfg = out[model]
-
-        if priority:
-            default_cfg.tags.appendleft(tag)
-            default_set.add(model)
-        elif has_weights and not default_cfg.is_pretrained:
-            default_cfg.tags.appendleft(tag)
-        else:
-            default_cfg.tags.append(tag)
-
-        if has_weights:
-            default_cfg.is_pretrained = True
-
-        default_cfg.cfgs[tag] = v
-
-    return out

View File

@@ -5,16 +5,19 @@ Hacked together by / Copyright 2020 Ross Wightman
 import fnmatch
 import re
 import sys
+import warnings
 from collections import defaultdict, deque
 from copy import deepcopy
 from dataclasses import replace
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple

-from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
+from ._pretrained import PretrainedCfg, DefaultCfg

 __all__ = [
+    'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
     'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
-    'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
+    'get_pretrained_cfg_value', 'is_model_pretrained'
+]

 _module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
 _model_to_module: Dict[str, str] = {}  # mapping of model names to module names
@@ -23,12 +26,52 @@ _model_has_pretrained: Set[str] = set()  # set of model names that have pretrain
 _model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
 _model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
 _model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names
+_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
+_deprecated_models: Dict[str, Optional[str]] = {}


+def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
+    model_name, *tag_list = model_name.split('.', 1)
+    tag = tag_list[0] if tag_list else no_tag
+    return model_name, tag
+
+
 def get_arch_name(model_name: str) -> str:
     return split_model_name_tag(model_name)[0]


+def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
+    out = defaultdict(DefaultCfg)
+    default_set = set()  # no tag and tags ending with * are prioritized as default
+    for k, v in cfgs.items():
+        if isinstance(v, dict):
+            v = PretrainedCfg(**v)
+        has_weights = v.has_weights
+        model, tag = split_model_name_tag(k)
+        is_default_set = model in default_set
+        priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
+        tag = tag.strip('*')
+
+        default_cfg = out[model]
+
+        if priority:
+            default_cfg.tags.appendleft(tag)
+            default_set.add(model)
+        elif has_weights and not default_cfg.is_pretrained:
+            default_cfg.tags.appendleft(tag)
+        else:
+            default_cfg.tags.append(tag)
+
+        if has_weights:
+            default_cfg.is_pretrained = True
+
+        default_cfg.cfgs[tag] = v
+
+    return out
+
+
 def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
     # lookup containing module
     mod = sys.modules[fn.__module__]
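`generate_default_cfgs` relocates here unchanged; as a reminder of its priority rule, an untagged cfg with weights, or a tag ending in `*`, becomes the default tag. A hypothetical example:

from timm.models import PretrainedCfg, generate_default_cfgs

cfgs = generate_default_cfgs({
    'mymodel.tag_a': PretrainedCfg(hf_hub_id='timm/'),
    'mymodel.tag_b*': PretrainedCfg(hf_hub_id='timm/'),  # '*' promotes this tag to default
})
tag, cfg = cfgs['mymodel'].default_with_tag
print(tag)  # 'tag_b'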
@@ -87,6 +130,37 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
     return fn


+def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
+    def _fn(pretrained=False, **kwargs):
+        assert current_fn is not None, f'Model {deprecated_name} has been removed with no replacement.'
+        warnings.warn(f'Mapping deprecated model {deprecated_name} to current {current_fn.__name__}', stacklevel=2)
+        pretrained_cfg = kwargs.pop('pretrained_cfg', None)
+        return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
+    return _fn
+
+
+def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
+    mod = sys.modules[module_name]
+    module_name_split = module_name.split('.')
+    module_name = module_name_split[-1] if len(module_name_split) else ''
+
+    for deprecated, current in deprecation_map.items():
+        if hasattr(mod, '__all__'):
+            mod.__all__.append(deprecated)
+        current_fn = None
+        current_tag = ''
+        if current:
+            current_name, current_tag = split_model_name_tag(current)
+            current_fn = getattr(mod, current_name)
+        deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
+        setattr(mod, deprecated, deprecated_entrypoint_fn)
+        _model_entrypoints[deprecated] = deprecated_entrypoint_fn
+        _model_to_module[deprecated] = module_name
+        _module_to_models[module_name].add(deprecated)
+        _deprecated_models[deprecated] = current
+        _module_to_deprecated_models[module_name][deprecated] = current
+
+
 def _natural_key(string_: str) -> List[Union[int, str]]:
     """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
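The intended pattern for a model module, mirroring the convnext/efficientnet/mobilenetv3 registrations later in this commit: map removed flat names to their `arch.tag` replacements at the bottom of the module. Names here are hypothetical:

from timm.models import register_model_deprecations

register_model_deprecations(__name__, {
    'old_model_name': 'new_model_name.pretrained_tag',  # shim warns, then forwards
    'removed_model_name': None,  # calling this asserts with a removal message
})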
@@ -122,16 +196,14 @@ def list_models(
         # FIXME should this be default behaviour? or default to include_tags=True?
         include_tags = pretrained

-    if module:
-        all_models: Iterable[str] = list(_module_to_models[module])
-    else:
-        all_models = _model_entrypoints.keys()
+    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
+    all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings

     if include_tags:
         # expand model names to include names w/ pretrained tags
-        models_with_tags = []
+        models_with_tags: Set[str] = set()
         for m in all_models:
-            models_with_tags.extend(_model_with_tags[m])
+            models_with_tags.update(_model_with_tags[m])
         all_models = models_with_tags

     if filter:
@@ -142,7 +214,7 @@ def list_models(
         if len(include_models):
             models = models.union(include_models)
     else:
-        models = set(all_models)
+        models = all_models

     if exclude_filters:
         if not isinstance(exclude_filters, (tuple, list)):
@@ -173,6 +245,11 @@ def list_pretrained(
     )


+def get_deprecated_models(module: str = '') -> Dict[str, str]:
+    all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
+    return deepcopy(all_deprecated)
+
+
 def is_model(model_name: str) -> bool:
     """ Check if a model name exists
     """

View File

@@ -63,8 +63,7 @@ from torch.utils.checkpoint import checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from ._builder import build_model_with_cfg
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 from .vision_transformer import checkpoint_filter_fn

 __all__ = ['Beit']

View File

@@ -50,8 +50,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this
@@ -519,6 +518,13 @@ def _cfgv2(url='', **kwargs):

 default_cfgs = generate_default_cfgs({
     # timm specific variants
+    'convnext_tiny.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_small.in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
     'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
         hf_hub_id='timm/',
@@ -558,12 +564,6 @@ default_cfgs = generate_default_cfgs({
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_tiny.in12k_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_small.in12k_ft_in1k': _cfg(
-        hf_hub_id='timm/',
-        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

     'convnext_tiny.in12k_ft_in1k_384': _cfg(
         hf_hub_id='timm/',
@@ -582,25 +582,6 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         crop_pct=0.95, num_classes=11821),

-    'convnext_tiny.fb_in1k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
-        hf_hub_id='timm/',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_small.fb_in1k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
-        hf_hub_id='timm/',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_base.fb_in1k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
-        hf_hub_id='timm/',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_large.fb_in1k': _cfg(
-        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
-        hf_hub_id='timm/',
-        test_input_size=(3, 288, 288), test_crop_pct=1.0),
-    'convnext_xlarge.untrained': _cfg(),
-    'convnext_xxlarge.untrained': _cfg(),
-
     'convnext_tiny.fb_in22k_ft_in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
         hf_hub_id='timm/',
@@ -622,6 +603,23 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=1.0),

+    'convnext_tiny.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+        hf_hub_id='timm/',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_small.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+        hf_hub_id='timm/',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_base.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+        hf_hub_id='timm/',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_large.fb_in1k': _cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+        hf_hub_id='timm/',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
     'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
         hf_hub_id='timm/',
@@ -1038,3 +1036,22 @@ def convnextv2_huge(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
     model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+register_model_deprecations(__name__, {
+    'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
+    'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
+    'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
+    'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
+    'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
+    'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
+    'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
+    'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
+    'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
+    'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
+    'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
+    'convnext_small_in22k': 'convnext_small.fb_in22k',
+    'convnext_base_in22k': 'convnext_base.fb_in22k',
+    'convnext_large_in22k': 'convnext_large.fb_in22k',
+    'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
+})
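Calling one of the old flat names still works; the shim remaps it to the tagged weights and warns:

import timm

# emits: "Mapping deprecated model convnext_tiny_in22ft1k to current convnext_tiny"
model = timm.create_model('convnext_tiny_in22ft1k')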

View File

@@ -11,8 +11,6 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # Copyright (c) 2022 Mingyu Ding
 # All rights reserved.
 # This source code is licensed under the MIT license
-import itertools
-from collections import OrderedDict
 from functools import partial
 from typing import Tuple
@@ -22,13 +20,12 @@ import torch.nn.functional as F
 from torch import Tensor

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
+from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model

 __all__ = ['DaViT']

View File

@@ -21,8 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model

 __all__ = ['EfficientFormer']  # model_registry will add each entrypoint fn to this

View File

@@ -26,8 +26,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_nor
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model


 EfficientFormer_width = {

View File

@@ -51,8 +51,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['EfficientNet', 'EfficientNetFeatures']
@@ -1064,42 +1063,46 @@ default_cfgs = generate_default_cfgs({
     'efficientnetv2_xl.untrained': _cfg(
         input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),

-    'tf_efficientnet_b0.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
+    'tf_efficientnet_b0.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
         hf_hub_id='timm/',
         input_size=(3, 224, 224)),
-    'tf_efficientnet_b1.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
+    'tf_efficientnet_b1.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
         hf_hub_id='timm/',
         input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
-    'tf_efficientnet_b2.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
+    'tf_efficientnet_b2.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
         hf_hub_id='timm/',
         input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
-    'tf_efficientnet_b3.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
+    'tf_efficientnet_b3.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
         hf_hub_id='timm/',
         input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
-    'tf_efficientnet_b4.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
+    'tf_efficientnet_b4.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
         hf_hub_id='timm/',
         input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
-    'tf_efficientnet_b5.ra_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
+    'tf_efficientnet_b5.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
         hf_hub_id='timm/',
         input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
-    'tf_efficientnet_b6.aa_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
+    'tf_efficientnet_b6.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
         hf_hub_id='timm/',
         input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
-    'tf_efficientnet_b7.ra_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
+    'tf_efficientnet_b7.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
         hf_hub_id='timm/',
         input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
-    'tf_efficientnet_b8.ra_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
+    'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
         hf_hub_id='timm/',
-        input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+        input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
+    'tf_efficientnet_l2.ns_jft_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
+        hf_hub_id='timm/',
+        input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),

     'tf_efficientnet_b0.ap_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
@@ -1146,46 +1149,42 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
         input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),

-    'tf_efficientnet_b0.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
+    'tf_efficientnet_b0.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
         hf_hub_id='timm/',
         input_size=(3, 224, 224)),
-    'tf_efficientnet_b1.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
+    'tf_efficientnet_b1.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
         hf_hub_id='timm/',
         input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
-    'tf_efficientnet_b2.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
+    'tf_efficientnet_b2.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
         hf_hub_id='timm/',
         input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
-    'tf_efficientnet_b3.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
+    'tf_efficientnet_b3.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
         hf_hub_id='timm/',
         input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
-    'tf_efficientnet_b4.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
+    'tf_efficientnet_b4.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
         hf_hub_id='timm/',
         input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
-    'tf_efficientnet_b5.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
+    'tf_efficientnet_b5.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
         hf_hub_id='timm/',
         input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
-    'tf_efficientnet_b6.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
+    'tf_efficientnet_b6.aa_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
         hf_hub_id='timm/',
         input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
-    'tf_efficientnet_b7.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
+    'tf_efficientnet_b7.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
         hf_hub_id='timm/',
         input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
-    'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
+    'tf_efficientnet_b8.ra_in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
         hf_hub_id='timm/',
-        input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
-    'tf_efficientnet_l2.ns_jft_in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
-        hf_hub_id='timm/',
-        input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
+        input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),

     'tf_efficientnet_es.in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
@@ -1248,22 +1247,6 @@ default_cfgs = generate_default_cfgs({
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
         input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),

-    'tf_efficientnetv2_s.in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
-        hf_hub_id='timm/',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-        input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
-    'tf_efficientnetv2_m.in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
-        hf_hub_id='timm/',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
-    'tf_efficientnetv2_l.in1k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
-        hf_hub_id='timm/',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

     'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
         hf_hub_id='timm/',
@@ -1285,6 +1268,22 @@ default_cfgs = generate_default_cfgs({
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
         input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

+    'tf_efficientnetv2_s.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
+        hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+    'tf_efficientnetv2_m.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
+        hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'tf_efficientnetv2_l.in1k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
+        hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
     'tf_efficientnetv2_s.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
         hf_hub_id='timm/',
@@ -2289,3 +2288,34 @@ def tinynet_d(pretrained=False, **kwargs):
 def tinynet_e(pretrained=False, **kwargs):
     model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
     return model
+
+
+register_model_deprecations(__name__, {
+    'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
+    'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
+    'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
+    'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
+    'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
+    'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
+    'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
+    'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
+    'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
+    'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
+    'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
+    'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
+    'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
+    'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
+    'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
+    'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
+    'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
+    'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
+    'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
+    'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
+    'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
+    'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
+    'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
+    'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
+    'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
+    'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
+    'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
+})

View File

@@ -28,7 +28,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model

 __all__ = ['FocalNet']
@@ -485,51 +485,51 @@ def _cfg(url='', **kwargs):
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.proj', 'classifier': 'head.fc',
-        **kwargs
+        'license': 'mit', **kwargs
     }


-default_cfgs = {
-    "focalnet_tiny_srf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'),
-    "focalnet_small_srf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'),
-    "focalnet_base_srf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'),
-    "focalnet_tiny_lrf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'),
-    "focalnet_small_lrf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'),
-    "focalnet_base_lrf": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
-    "focalnet_large_fl3": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
+default_cfgs = generate_default_cfgs({
+    "focalnet_tiny_srf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "focalnet_small_srf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "focalnet_base_srf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "focalnet_tiny_lrf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "focalnet_small_lrf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+    "focalnet_base_lrf.ms_in1k": _cfg(
+        hf_hub_id='timm/'),
+
+    "focalnet_large_fl3.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
-    "focalnet_large_fl4": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
+    "focalnet_large_fl4.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
-    "focalnet_xlarge_fl3": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
+    "focalnet_xlarge_fl3.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
-    "focalnet_xlarge_fl4": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
+    "focalnet_xlarge_fl4.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
-    "focalnet_huge_fl3": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
+    "focalnet_huge_fl3.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         num_classes=21842),
-    "focalnet_huge_fl4": _cfg(
-        url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
+    "focalnet_huge_fl4.ms_in22k": _cfg(
+        hf_hub_id='timm/',
         num_classes=0),
-}
+})


 def checkpoint_filter_fn(state_dict, model: FocalNet):
+    state_dict = state_dict.get('model', state_dict)
     if 'stem.proj.weight' in state_dict:
-        return
+        return state_dict
     import re
     out_dict = {}
-    if 'model' in state_dict:
-        state_dict = state_dict['model']
     dest_dict = model.state_dict()
     for k, v in state_dict.items():
         k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
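With the FocalNet weights now hosted on the HF hub under the `timm/` org per the cfgs above, the usual hub path applies (downloads weights on first use):

import timm

model = timm.create_model('focalnet_tiny_srf.ms_in1k', pretrained=True).eval()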

View File

@@ -24,7 +24,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 # Copyright 2020 Ross Wightman, Apache-2.0 License
 from collections import OrderedDict
-from dataclasses import dataclass
 from functools import partial
 from typing import Dict
@@ -35,9 +34,7 @@ from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
 from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model

 __all__ = ['Levit']

View File

@@ -52,8 +52,7 @@ from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model

 __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']

View File

@@ -22,8 +22,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations

 __all__ = ['MobileNetV3', 'MobileNetV3Features']
@@ -796,3 +795,9 @@ def lcnet_150(pretrained=False, **kwargs):
     """ PP-LCNet 1.5"""
     model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
     return model
+
+
+register_model_deprecations(__name__, {
+    'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
+    'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
+})

View File

@@ -23,11 +23,11 @@ import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead
+from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .vision_transformer import get_init_weights_vit
 __all__ = ['SwinTransformer']  # model_registry will add each entrypoint fn to this
@@ -302,7 +302,12 @@ class PatchMerging(nn.Module):
     """ Patch Merging Layer.
     """
-    def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
+    def __init__(
+            self,
+            dim: int,
+            out_dim: Optional[int] = None,
+            norm_layer: Callable = nn.LayerNorm,
+    ):
         """
         Args:
             dim: Number of input channels.
@@ -345,13 +350,13 @@ class SwinTransformerStage(nn.Module):
             attn_drop: float = 0.,
             drop_path: Union[List[float], float] = 0.,
             norm_layer: Callable = nn.LayerNorm,
-            output_nchw: bool = False,
     ):
         """
         Args:
             dim: Number of input channels.
             input_resolution: Input resolution.
             depth: Number of blocks.
+            downsample: Downsample layer at the end of the layer.
             num_heads: Number of attention heads.
             head_dim: Channels per head (dim // num_heads if not set)
             window_size: Local window size.
@@ -361,14 +366,12 @@ class SwinTransformerStage(nn.Module):
             attn_drop: Attention dropout rate.
             drop_path: Stochastic depth rate.
             norm_layer: Normalization layer.
-            downsample: Downsample layer at the end of the layer.
         """
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
         self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
         self.depth = depth
-        self.use_nchw = output_nchw
         self.grad_checkpointing = False
         # patch merging layer
@@ -401,18 +404,12 @@ class SwinTransformerStage(nn.Module):
             for i in range(depth)])
     def forward(self, x):
-        if self.use_nchw:
-            x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
         x = self.downsample(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
         else:
             x = self.blocks(x)
-        if self.use_nchw:
-            x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
         return x
@@ -442,7 +439,6 @@ class SwinTransformer(nn.Module):
             drop_path_rate: float = 0.1,
             norm_layer: Union[str, Callable] = nn.LayerNorm,
             weight_init: str = '',
-            output_fmt: str = 'NHWC',
             **kwargs,
     ):
         """
@@ -465,15 +461,13 @@ class SwinTransformer(nn.Module):
         """
         super().__init__()
         assert global_pool in ('', 'avg')
-        assert output_fmt in ('NCHW', 'NHWC')
         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.output_fmt = output_fmt
+        self.output_fmt = 'NHWC'
         self.num_layers = len(depths)
         self.embed_dim = embed_dim
         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
-        self.output_nchw = self.output_fmt == 'NCHW'  # bool flag for fwd
         self.feature_info = []
         if not isinstance(embed_dim, (tuple, list)):
@@ -518,7 +512,6 @@ class SwinTransformer(nn.Module):
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
-                output_nchw=self.output_nchw,
             )]
             in_dim = out_dim
             if i > 0:
@@ -577,13 +570,7 @@ class SwinTransformer(nn.Module):
     def forward_features(self, x):
         x = self.patch_embed(x)
-        if self.output_nchw:
-            # patch embed always outputs NHWC, stage layers expect NCHW input
-            x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
-        if self.output_nchw:
-            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        else:
-            x = self.norm(x)
+        x = self.norm(x)
         return x
@@ -596,14 +583,10 @@ class SwinTransformer(nn.Module):
         return x
-def checkpoint_filter_fn(
-        state_dict,
-        model,
-        adapt_layer_scale=False,
-        interpolation='bicubic',
-        antialias=True,
-):
+def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    if 'head.fc.weight' in state_dict:
+        return state_dict
     import re
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
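The simplified filter now short-circuits on checkpoints that are already in timm format: 'head.fc.weight' only exists after conversion, so its presence means no remapping is needed. A sketch of that pattern in isolation (the remap body here is hypothetical; the real conversion also handles patch embed and relative-position tables):

    def filter_checkpoint(state_dict):
        state_dict = state_dict.get('model', state_dict)  # unwrap train-time wrapper
        if 'head.fc.weight' in state_dict:
            return state_dict  # already converted, pass through untouched
        out = {}
        for k, v in state_dict.items():
            out[k.replace('head.', 'head.fc.')] = v  # illustrative remap only
        return out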
@@ -635,106 +618,83 @@ def _cfg(url='', **kwargs):
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
-        **kwargs
+        'license': 'mit', **kwargs
     }
-default_cfgs = {
-    'swin_base_patch4_window12_384': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    'swin_base_patch4_window7_224': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
-    ),
-    'swin_large_patch4_window12_384': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
-    'swin_large_patch4_window7_224': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
-    ),
-    'swin_small_patch4_window7_224': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
-    ),
-    'swin_tiny_patch4_window7_224': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
-    ),
-    'swin_base_patch4_window12_384_in22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
-    'swin_base_patch4_window7_224_in22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
-        num_classes=21841),
-    'swin_large_patch4_window12_384_in22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
-        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
-    'swin_large_patch4_window7_224_in22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
-        num_classes=21841),
-    'swin_s3_tiny_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
-    ),
-    'swin_s3_small_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
-    ),
-    'swin_s3_base_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
-    )
-}
-@register_model
-def swin_base_patch4_window12_384(pretrained=False, **kwargs):
-    """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
-    """
-    model_kwargs = dict(
-        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
-    return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
-@register_model
-def swin_base_patch4_window7_224(pretrained=False, **kwargs):
-    """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
-    """
-    model_kwargs = dict(
-        patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
-    return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
-@register_model
-def swin_large_patch4_window12_384(pretrained=False, **kwargs):
-    """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
-    """
-    model_kwargs = dict(
-        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
-@register_model
-def swin_large_patch4_window7_224(pretrained=False, **kwargs):
-    """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
-    """
-    model_kwargs = dict(
-        patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
-@register_model
-def swin_small_patch4_window7_224(pretrained=False, **kwargs):
-    """ Swin-S @ 224x224, trained ImageNet-1k
-    """
-    model_kwargs = dict(
-        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
-    return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+default_cfgs = generate_default_cfgs({
+    'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
+    'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
+    'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
+    'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
+    'swin_small_patch4_window7_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
+    'swin_base_patch4_window7_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
+    'swin_base_patch4_window12_384.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    # tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
+    'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
+    'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
+        num_classes=21841),
+    'swin_small_patch4_window7_224.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
+        num_classes=21841),
+    'swin_base_patch4_window7_224.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
+        num_classes=21841),
+    'swin_base_patch4_window12_384.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
+    'swin_large_patch4_window7_224.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
+        num_classes=21841),
+    'swin_large_patch4_window12_384.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
+    'swin_s3_tiny_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
+    'swin_s3_small_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
+    'swin_s3_base_224.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
+})
 @register_model
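Note the inline comment above on ordering: when a model name is requested without a tag, the first pretrained cfg registered for that architecture is expected to win, which is why tiny's in1k weights are listed ahead of its (worse) 22k fine-tune. A hedged illustration of the intended resolution:

    import timm

    # Untagged name: should resolve to the first tag registered above for
    # this architecture ('.ms_in1k' for swin_tiny after this change).
    m = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)

    # Fully qualified form selects a specific weight set explicitly.
    m = timm.create_model('swin_tiny_patch4_window7_224.ms_in22k', pretrained=False)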
@@ -747,44 +707,53 @@ def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
 @register_model
-def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
-    """ Swin-B @ 384x384, trained ImageNet-22k
+def swin_small_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-S @ 224x224
     """
     model_kwargs = dict(
-        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
-    return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+        patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
+    return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
 @register_model
-def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
-    """ Swin-B @ 224x224, trained ImageNet-22k
+def swin_base_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-B @ 224x224
     """
     model_kwargs = dict(
         patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
-    return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+    return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
 @register_model
-def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
-    """ Swin-L @ 384x384, trained ImageNet-22k
+def swin_base_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-B @ 384x384
     """
     model_kwargs = dict(
-        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+        patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+    return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
 @register_model
-def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
-    """ Swin-L @ 224x224, trained ImageNet-22k
+def swin_large_patch4_window7_224(pretrained=False, **kwargs):
+    """ Swin-L @ 224x224
     """
     model_kwargs = dict(
         patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+    return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+@register_model
+def swin_large_patch4_window12_384(pretrained=False, **kwargs):
+    """ Swin-L @ 384x384
+    """
+    model_kwargs = dict(
+        patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+    return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
 @register_model
 def swin_s3_tiny_224(pretrained=False, **kwargs):
-    """ Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
+    """ Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
@@ -794,7 +763,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
 @register_model
 def swin_s3_small_224(pretrained=False, **kwargs):
-    """ Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
+    """ Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
@@ -804,10 +773,17 @@ def swin_s3_small_224(pretrained=False, **kwargs):
 @register_model
 def swin_s3_base_224(pretrained=False, **kwargs):
-    """ Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
+    """ Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
         num_heads=(3, 6, 12, 24), **kwargs)
     return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)
+register_model_deprecations(__name__, {
+    'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
+    'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
+    'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
+    'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
+})
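With every weight variant now carried as a tag, the registry can enumerate them. A quick way to inspect what this file exposes after the change (output abridged and illustrative):

    import timm

    print(timm.list_pretrained('swin_base*'))
    # e.g. ['swin_base_patch4_window7_224.ms_in1k',
    #       'swin_base_patch4_window7_224.ms_in22k', ...]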

View File

@@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to this
@@ -425,9 +425,6 @@ class SwinTransformerV2Stage(nn.Module):
             for i in range(depth)])
     def forward(self, x):
-        if self.output_nchw:
-            x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
         x = self.downsample(x)
         for blk in self.blocks:
@@ -435,9 +432,6 @@ class SwinTransformerV2Stage(nn.Module):
                 x = checkpoint.checkpoint(blk, x)
             else:
                 x = blk(x)
-        if self.output_nchw:
-            x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
         return x
     def _init_respostnorm(self):
@@ -473,7 +467,6 @@ class SwinTransformerV2(nn.Module):
             drop_path_rate: float = 0.1,
             norm_layer: Callable = nn.LayerNorm,
             pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
-            output_fmt: str = 'NHWC',
             **kwargs,
     ):
         """
@@ -500,13 +493,11 @@ class SwinTransformerV2(nn.Module):
         self.num_classes = num_classes
         assert global_pool in ('', 'avg')
-        assert output_fmt in ('NCHW', 'NHWC')
         self.global_pool = global_pool
-        self.output_fmt = output_fmt
+        self.output_fmt = 'NHWC'
         self.num_layers = len(depths)
         self.embed_dim = embed_dim
         self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
-        self.output_nchw = self.output_fmt == 'NCHW'
         self.feature_info = []
         if not isinstance(embed_dim, (tuple, list)):
@@ -544,7 +535,6 @@ class SwinTransformerV2(nn.Module):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 pretrained_window_size=pretrained_window_sizes[i],
-                output_nchw=self.output_nchw,
             )]
             in_dim = out_dim
             if i > 0:
@@ -605,13 +595,7 @@ class SwinTransformerV2(nn.Module):
     def forward_features(self, x):
         x = self.patch_embed(x)
-        if self.output_nchw:
-            # patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
-            x = x.permute(0, 3, 1, 2)
         x = self.layers(x)
-        if self.output_nchw:
-            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        else:
-            x = self.norm(x)
+        x = self.norm(x)
         return x
@@ -625,10 +609,12 @@ class SwinTransformerV2(nn.Module):
 def checkpoint_filter_fn(state_dict, model):
-    import re
-    out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    if 'head.fc.weight' in state_dict:
+        return state_dict
+    out_dict = {}
+    import re
     for k, v in state_dict.items():
         if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
             continue  # skip buffers that should not be persistent
@@ -657,53 +643,66 @@ def _cfg(url='', **kwargs):
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
-        **kwargs
+        'license': 'mit', **kwargs
     }
-default_cfgs = {
-    'swinv2_tiny_window8_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
-    ),
-    'swinv2_tiny_window16_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
-    ),
-    'swinv2_small_window8_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
-    ),
-    'swinv2_small_window16_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
-    ),
-    'swinv2_base_window8_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
-    ),
-    'swinv2_base_window16_256': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
-    ),
-    'swinv2_base_window12_192_22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
-        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
-    ),
-    'swinv2_base_window12to16_192to256_22kft1k': _cfg(
+default_cfgs = generate_default_cfgs({
+    'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
     ),
-    'swinv2_base_window12to24_192to384_22kft1k': _cfg(
+    'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
     ),
-    'swinv2_large_window12_192_22k': _cfg(
-        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
-        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
-    ),
-    'swinv2_large_window12to16_192to256_22kft1k': _cfg(
+    'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
     ),
-    'swinv2_large_window12to24_192to384_22kft1k': _cfg(
+    'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
     ),
-}
+    'swinv2_tiny_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
+    ),
+    'swinv2_tiny_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
+    ),
+    'swinv2_small_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
+    ),
+    'swinv2_small_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
+    ),
+    'swinv2_base_window8_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
+    ),
+    'swinv2_base_window16_256.ms_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
+    ),
+    'swinv2_base_window12_192.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+    'swinv2_large_window12_192.ms_in22k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
+    ),
+})
 @register_model
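generate_default_cfgs turns the flat 'arch.tag' mapping above into per-architecture DefaultCfg entries (PretrainedCfg/DefaultCfg live in _pretrained, per the top of this commit). The grouping idea, as a rough standalone sketch rather than the actual implementation:

    from collections import defaultdict

    def group_by_arch(cfgs):
        # Hypothetical sketch: split 'arch.tag' keys and group tags per arch;
        # dict insertion order makes the first tag the default for that arch.
        grouped = defaultdict(dict)
        for name, cfg in cfgs.items():
            arch, _, tag = name.partition('.')
            grouped[arch][tag] = cfg
        return grouped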
@@ -761,62 +760,72 @@ def swinv2_base_window8_256(pretrained=False, **kwargs):
 @register_model
-def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
+def swinv2_base_window12_192(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
-    return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
+    return _create_swin_transformer_v2('swinv2_base_window12_192', pretrained=pretrained, **model_kwargs)
 @register_model
-def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
+def swinv2_base_window12to16_192to256(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
         pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
     return _create_swin_transformer_v2(
-        'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
+        'swinv2_base_window12to16_192to256', pretrained=pretrained, **model_kwargs)
 @register_model
-def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
+def swinv2_base_window12to24_192to384(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
         pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
     return _create_swin_transformer_v2(
-        'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
+        'swinv2_base_window12to24_192to384', pretrained=pretrained, **model_kwargs)
 @register_model
-def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
+def swinv2_large_window12_192(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
+    return _create_swin_transformer_v2('swinv2_large_window12_192', pretrained=pretrained, **model_kwargs)
 @register_model
-def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
+def swinv2_large_window12to16_192to256(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
         pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
     return _create_swin_transformer_v2(
-        'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
+        'swinv2_large_window12to16_192to256', pretrained=pretrained, **model_kwargs)
 @register_model
-def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
+def swinv2_large_window12to24_192to384(pretrained=False, **kwargs):
     """
     """
     model_kwargs = dict(
         window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
         pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
     return _create_swin_transformer_v2(
-        'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
+        'swinv2_large_window12to24_192to384', pretrained=pretrained, **model_kwargs)
+register_model_deprecations(__name__, {
+    'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
+    'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
+    'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
+    'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
+    'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
+    'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
+})
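Per the commit summary, these weights are now mirrored on the HF hub. With hf_hub_id='timm/' present on a cfg, pretrained loading is expected to prefer the hub copy when the huggingface_hub package is installed, falling back to the release URL otherwise (network access required either way):

    import timm

    m = timm.create_model(
        'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k', pretrained=True)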

View File

@@ -41,7 +41,7 @@ from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 __all__ = ['SwinTransformerV2Cr']  # model_registry will add each entrypoint fn to this
@@ -748,9 +748,11 @@ def init_weights(module: nn.Module, name: str = ''):
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    if 'head.fc.weight' in state_dict:
+        return state_dict
+    out_dict = {}
     for k, v in state_dict.items():
         if 'tau' in k:
             # convert old tau based checkpoints -> logit_scale (inverse)
@@ -791,43 +793,46 @@ def _cfg(url='', **kwargs):
     }
-default_cfgs = {
-    'swinv2_cr_tiny_384': _cfg(
+default_cfgs = generate_default_cfgs({
+    'swinv2_cr_tiny_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_tiny_224': _cfg(
+    'swinv2_cr_tiny_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_tiny_ns_224': _cfg(
+    'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_small_384': _cfg(
+    'swinv2_cr_small_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_small_224': _cfg(
+    'swinv2_cr_small_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_small_ns_224': _cfg(
+    'swinv2_cr_small_ns_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_small_ns_256': _cfg(
+    'swinv2_cr_small_ns_256.untrained': _cfg(
         url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
-    'swinv2_cr_base_384': _cfg(
+    'swinv2_cr_base_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_base_224': _cfg(
+    'swinv2_cr_base_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_base_ns_224': _cfg(
+    'swinv2_cr_base_ns_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_large_384': _cfg(
+    'swinv2_cr_large_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_large_224': _cfg(
+    'swinv2_cr_large_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_huge_384': _cfg(
+    'swinv2_cr_huge_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_huge_224': _cfg(
+    'swinv2_cr_huge_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-    'swinv2_cr_giant_384': _cfg(
+    'swinv2_cr_giant_384.untrained': _cfg(
         url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
-    'swinv2_cr_giant_224': _cfg(
+    'swinv2_cr_giant_224.untrained': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
-}
+})
 @register_model
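The '.untrained' tag above marks cfgs with no weight URL at all: the entrypoints stay registered and buildable, but only with random initialization. A hedged illustration:

    import timm

    m = timm.create_model('swinv2_cr_base_224', pretrained=False)  # OK, random init
    # timm.create_model('swinv2_cr_base_224', pretrained=True)     # should fail: no weights to load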

View File

@@ -41,9 +41,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_,
     resample_abs_pos_embed, RmsNorm
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 __all__ = ['VisionTransformer']  # model_registry will add each entrypoint fn to this

View File

@@ -20,8 +20,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .vision_transformer import _create_vision_transformer

View File

@@ -17,8 +17,7 @@ from torch.utils.checkpoint import checkpoint
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
 from ._builder import build_model_with_cfg
-from ._pretrained import generate_default_cfgs
-from ._registry import register_model
+from ._registry import generate_default_cfgs, register_model
 __all__ = ['VisionTransformerRelPos']  # model_registry will add each entrypoint fn to this