Swin and FocalNet weights on HF hub. Add model deprecation functionality w/ some registry tweaks.
parent
2fc5ac3d18
commit
572f05096a
|
@ -74,12 +74,12 @@ from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet
|
|||
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
|
||||
register_notrace_module, is_notrace_module, get_notrace_modules, \
|
||||
register_notrace_function, is_notrace_function, get_notrace_functions
|
||||
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
|
||||
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
|
||||
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
|
||||
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
|
||||
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, \
|
||||
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
|
||||
from ._prune import adapt_model_from_string
|
||||
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
|
||||
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
|
||||
register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
|
||||
is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
|
|
|
@ -3,10 +3,10 @@ from typing import Any, Dict, Optional, Union
|
|||
from urllib.parse import urlsplit
|
||||
|
||||
from timm.layers import set_layer_config
|
||||
from ._pretrained import PretrainedCfg, split_model_name_tag
|
||||
from ._helpers import load_checkpoint
|
||||
from ._hub import load_model_config_from_hf
|
||||
from ._registry import is_model, model_entrypoint
|
||||
from ._pretrained import PretrainedCfg
|
||||
from ._registry import is_model, model_entrypoint, split_model_name_tag
|
||||
|
||||
|
||||
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||
|
|
|
@ -5,6 +5,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
try:
|
||||
|
@ -13,30 +14,32 @@ try:
|
|||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||
cleaned_state_dict = OrderedDict()
|
||||
cleaned_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k
|
||||
cleaned_state_dict[name] = v
|
||||
return cleaned_state_dict
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path, use_ema=True):
|
||||
def load_state_dict(
|
||||
checkpoint_path: str,
|
||||
use_ema: bool = True,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
) -> Dict[str, Any]:
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
|
@ -56,22 +59,37 @@ def load_state_dict(checkpoint_path, use_ema=True):
|
|||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
|
||||
def load_checkpoint(
|
||||
model: torch.nn.Module,
|
||||
checkpoint_path: str,
|
||||
use_ema: bool = True,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
strict: bool = True,
|
||||
remap: bool = False,
|
||||
filter_fn: Optional[Callable] = None,
|
||||
):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
timm.models._model_builder.load_pretrained(checkpoint_path)
|
||||
model.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema, device=device)
|
||||
if remap:
|
||||
state_dict = remap_checkpoint(model, state_dict)
|
||||
state_dict = remap_state_dict(state_dict, model)
|
||||
elif filter_fn:
|
||||
state_dict = filter_fn(state_dict, model)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def remap_checkpoint(model, state_dict, allow_reshape=True):
|
||||
def remap_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
model: torch.nn.Module,
|
||||
allow_reshape: bool = True
|
||||
):
|
||||
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
||||
This assumes models (and originating state dict) were created with params registered in same order.
|
||||
"""
|
||||
|
@ -87,7 +105,13 @@ def remap_checkpoint(model, state_dict, allow_reshape=True):
|
|||
return out_dict
|
||||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||
def resume_checkpoint(
|
||||
model: torch.nn.Module,
|
||||
checkpoint_path: str,
|
||||
optimizer: torch.optim.Optimizer = None,
|
||||
loss_scaler: Any = None,
|
||||
log_info: bool = True,
|
||||
):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Callable, Union, Dict
|
||||
from typing import Any, Callable, Dict, Iterator, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
@ -13,7 +13,7 @@ __all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_wi
|
|||
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
def model_parameters(model: nn.Module, exclude_head: bool = False):
|
||||
if exclude_head:
|
||||
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
||||
return [p for p in model.parameters()][:-2]
|
||||
|
@ -21,7 +21,12 @@ def model_parameters(model, exclude_head=False):
|
|||
return model.parameters()
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
||||
def named_apply(
|
||||
fn: Callable,
|
||||
module: nn.Module, name='',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -32,7 +37,12 @@ def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, incl
|
|||
return module
|
||||
|
||||
|
||||
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
def named_modules(
|
||||
module: nn.Module,
|
||||
name: str = '',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
):
|
||||
if not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -43,7 +53,12 @@ def named_modules(module: nn.Module, name='', depth_first=True, include_root=Fal
|
|||
yield name, module
|
||||
|
||||
|
||||
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
def named_modules_with_params(
|
||||
module: nn.Module,
|
||||
name: str = '',
|
||||
depth_first: bool = True,
|
||||
include_root: bool = False,
|
||||
):
|
||||
if module._parameters and not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
|
@ -58,9 +73,9 @@ MATCH_PREV_GROUP = (99999,)
|
|||
|
||||
|
||||
def group_with_matcher(
|
||||
named_objects,
|
||||
named_objects: Iterator[Tuple[str, Any]],
|
||||
group_matcher: Union[Dict, Callable],
|
||||
output_values: bool = False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False
|
||||
):
|
||||
if isinstance(group_matcher, dict):
|
||||
|
@ -96,7 +111,7 @@ def group_with_matcher(
|
|||
# map layers into groups via ordinals (ints or tuples of ints) from matcher
|
||||
grouping = defaultdict(list)
|
||||
for k, v in named_objects:
|
||||
grouping[_get_grouping(k)].append(v if output_values else k)
|
||||
grouping[_get_grouping(k)].append(v if return_values else k)
|
||||
|
||||
# remap to integers
|
||||
layer_id_to_param = defaultdict(list)
|
||||
|
@ -107,7 +122,7 @@ def group_with_matcher(
|
|||
layer_id_to_param[lid].extend(grouping[k])
|
||||
|
||||
if reverse:
|
||||
assert not output_values, "reverse mapping only sensible for name output"
|
||||
assert not return_values, "reverse mapping only sensible for name output"
|
||||
# output reverse mapping
|
||||
param_to_layer_id = {}
|
||||
for lid, lm in layer_id_to_param.items():
|
||||
|
@ -121,24 +136,29 @@ def group_with_matcher(
|
|||
def group_parameters(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
|
||||
module.named_parameters(), group_matcher, return_values=return_values, reverse=reverse)
|
||||
|
||||
|
||||
def group_modules(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
return_values: bool = False,
|
||||
reverse: bool = False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
|
||||
named_modules_with_params(module), group_matcher, return_values=return_values, reverse=reverse)
|
||||
|
||||
|
||||
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
|
||||
def flatten_modules(
|
||||
named_modules: Iterator[Tuple[str, nn.Module]],
|
||||
depth: int = 1,
|
||||
prefix: Union[str, Tuple[str, ...]] = '',
|
||||
module_types: Union[str, Tuple[Type[nn.Module]]] = 'sequential',
|
||||
):
|
||||
prefix_is_tuple = isinstance(prefix, tuple)
|
||||
if isinstance(module_types, str):
|
||||
if module_types == 'container':
|
||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, field, replace, asdict
|
|||
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
||||
|
||||
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -91,41 +91,3 @@ class DefaultCfg:
|
|||
def default_with_tag(self):
|
||||
tag = self.tags[0]
|
||||
return tag, self.cfgs[tag]
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
||||
|
||||
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
|
||||
out = defaultdict(DefaultCfg)
|
||||
default_set = set() # no tag and tags ending with * are prioritized as default
|
||||
|
||||
for k, v in cfgs.items():
|
||||
if isinstance(v, dict):
|
||||
v = PretrainedCfg(**v)
|
||||
has_weights = v.has_weights
|
||||
|
||||
model, tag = split_model_name_tag(k)
|
||||
is_default_set = model in default_set
|
||||
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
|
||||
tag = tag.strip('*')
|
||||
|
||||
default_cfg = out[model]
|
||||
|
||||
if priority:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
default_set.add(model)
|
||||
elif has_weights and not default_cfg.is_pretrained:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
else:
|
||||
default_cfg.tags.append(tag)
|
||||
|
||||
if has_weights:
|
||||
default_cfg.is_pretrained = True
|
||||
|
||||
default_cfg.cfgs[tag] = v
|
||||
|
||||
return out
|
||||
|
|
|
@ -5,16 +5,19 @@ Hacked together by / Copyright 2020 Ross Wightman
|
|||
import fnmatch
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg
|
||||
|
||||
__all__ = [
|
||||
'split_model_name_tag', 'get_arch_name', 'register_model', 'generate_default_cfgs',
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained'
|
||||
]
|
||||
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||
|
@ -23,12 +26,52 @@ _model_has_pretrained: Set[str] = set() # set of model names that have pretrain
|
|||
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
_module_to_deprecated_models: Dict[str, Dict[str, Optional[str]]] = defaultdict(dict)
|
||||
_deprecated_models: Dict[str, Optional[str]] = {}
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> str:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
|
||||
out = defaultdict(DefaultCfg)
|
||||
default_set = set() # no tag and tags ending with * are prioritized as default
|
||||
|
||||
for k, v in cfgs.items():
|
||||
if isinstance(v, dict):
|
||||
v = PretrainedCfg(**v)
|
||||
has_weights = v.has_weights
|
||||
|
||||
model, tag = split_model_name_tag(k)
|
||||
is_default_set = model in default_set
|
||||
priority = (has_weights and not tag) or (tag.endswith('*') and not is_default_set)
|
||||
tag = tag.strip('*')
|
||||
|
||||
default_cfg = out[model]
|
||||
|
||||
if priority:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
default_set.add(model)
|
||||
elif has_weights and not default_cfg.is_pretrained:
|
||||
default_cfg.tags.appendleft(tag)
|
||||
else:
|
||||
default_cfg.tags.append(tag)
|
||||
|
||||
if has_weights:
|
||||
default_cfg.is_pretrained = True
|
||||
|
||||
default_cfg.cfgs[tag] = v
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
|
@ -87,6 +130,37 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
|||
return fn
|
||||
|
||||
|
||||
def _deprecated_model_shim(deprecated_name: str, current_fn: Callable = None, current_tag: str = ''):
|
||||
def _fn(pretrained=False, **kwargs):
|
||||
assert current_fn is not None, f'Model {deprecated_name} has been removed with no replacement.'
|
||||
warnings.warn(f'Mapping deprecated model {deprecated_name} to current {current_fn.__name__}', stacklevel=2)
|
||||
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
|
||||
return current_fn(pretrained=pretrained, pretrained_cfg=pretrained_cfg or current_tag, **kwargs)
|
||||
return _fn
|
||||
|
||||
|
||||
def register_model_deprecations(module_name: str, deprecation_map: Dict[str, Optional[str]]):
|
||||
mod = sys.modules[module_name]
|
||||
module_name_split = module_name.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
|
||||
for deprecated, current in deprecation_map.items():
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(deprecated)
|
||||
current_fn = None
|
||||
current_tag = ''
|
||||
if current:
|
||||
current_name, current_tag = split_model_name_tag(current)
|
||||
current_fn = getattr(mod, current_name)
|
||||
deprecated_entrypoint_fn = _deprecated_model_shim(deprecated, current_fn, current_tag)
|
||||
setattr(mod, deprecated, deprecated_entrypoint_fn)
|
||||
_model_entrypoints[deprecated] = deprecated_entrypoint_fn
|
||||
_model_to_module[deprecated] = module_name
|
||||
_module_to_models[module_name].add(deprecated)
|
||||
_deprecated_models[deprecated] = current
|
||||
_module_to_deprecated_models[module_name][deprecated] = current
|
||||
|
||||
|
||||
def _natural_key(string_: str) -> List[Union[int, str]]:
|
||||
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
@ -122,16 +196,14 @@ def list_models(
|
|||
# FIXME should this be default behaviour? or default to include_tags=True?
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models: Iterable[str] = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
|
||||
all_models = all_models - _deprecated_models.keys() # remove deprecated models from listings
|
||||
|
||||
if include_tags:
|
||||
# expand model names to include names w/ pretrained tags
|
||||
models_with_tags = []
|
||||
models_with_tags: Set[str] = set()
|
||||
for m in all_models:
|
||||
models_with_tags.extend(_model_with_tags[m])
|
||||
models_with_tags.update(_model_with_tags[m])
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
|
@ -142,7 +214,7 @@ def list_models(
|
|||
if len(include_models):
|
||||
models = models.union(include_models)
|
||||
else:
|
||||
models = set(all_models)
|
||||
models = all_models
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
|
@ -173,6 +245,11 @@ def list_pretrained(
|
|||
)
|
||||
|
||||
|
||||
def get_deprecated_models(module: str = '') -> Dict[str, str]:
|
||||
all_deprecated = _module_to_deprecated_models[module] if module else _deprecated_models
|
||||
return deepcopy(all_deprecated)
|
||||
|
||||
|
||||
def is_model(model_name: str) -> bool:
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
|
|
|
@ -63,8 +63,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
|
||||
__all__ = ['Beit']
|
||||
|
|
|
@ -50,8 +50,7 @@ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalRespo
|
|||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -519,6 +518,13 @@ def _cfgv2(url='', **kwargs):
|
|||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
# timm specific variants
|
||||
'convnext_tiny.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_atto.d2_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -558,12 +564,6 @@ default_cfgs = generate_default_cfgs({
|
|||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_tiny.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.in12k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.in12k_ft_in1k_384': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
|
@ -582,25 +582,6 @@ default_cfgs = generate_default_cfgs({
|
|||
hf_hub_id='timm/',
|
||||
crop_pct=0.95, num_classes=11821),
|
||||
|
||||
'convnext_tiny.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_base.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_large.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_xlarge.untrained': _cfg(),
|
||||
'convnext_xxlarge.untrained': _cfg(),
|
||||
|
||||
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -622,6 +603,23 @@ default_cfgs = generate_default_cfgs({
|
|||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_small.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_base.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'convnext_large.fb_in1k': _cfg(
|
||||
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
|
||||
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -1038,3 +1036,22 @@ def convnextv2_huge(pretrained=False, **kwargs):
|
|||
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
|
||||
'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
|
||||
'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
|
||||
'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
|
||||
'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
|
||||
'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
|
||||
'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
|
||||
'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
|
||||
'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
|
||||
'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
|
||||
'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
|
||||
'convnext_small_in22k': 'convnext_small.fb_in22k',
|
||||
'convnext_base_in22k': 'convnext_base.fb_in22k',
|
||||
'convnext_large_in22k': 'convnext_large.fb_in22k',
|
||||
'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
|
||||
})
|
||||
|
|
|
@ -11,8 +11,6 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|||
# Copyright (c) 2022 Mingyu Ding
|
||||
# All rights reserved.
|
||||
# This source code is licensed under the MIT license
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -22,13 +20,12 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
|
||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['DaViT']
|
||||
|
||||
|
|
|
@ -21,8 +21,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
|
@ -26,8 +26,7 @@ from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_nor
|
|||
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
|
||||
EfficientFormer_width = {
|
||||
|
|
|
@ -51,8 +51,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
|
|||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['EfficientNet', 'EfficientNetFeatures']
|
||||
|
||||
|
@ -1064,42 +1063,46 @@ default_cfgs = generate_default_cfgs({
|
|||
'efficientnetv2_xl.untrained': _cfg(
|
||||
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'tf_efficientnet_b0.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 224, 224)),
|
||||
'tf_efficientnet_b1.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
||||
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'tf_efficientnet_b2.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
||||
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'tf_efficientnet_b3.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
||||
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'tf_efficientnet_b4.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
||||
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'tf_efficientnet_b5.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
|
||||
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'tf_efficientnet_b6.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
||||
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'tf_efficientnet_b7.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
|
||||
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'tf_efficientnet_b8.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
||||
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
|
||||
'tf_efficientnet_b0.ap_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
|
||||
|
@ -1146,46 +1149,42 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_b0.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
|
||||
'tf_efficientnet_b0.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 224, 224)),
|
||||
'tf_efficientnet_b1.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
|
||||
'tf_efficientnet_b1.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
|
||||
'tf_efficientnet_b2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
|
||||
'tf_efficientnet_b2.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
|
||||
'tf_efficientnet_b3.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
|
||||
'tf_efficientnet_b3.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
|
||||
'tf_efficientnet_b4.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
|
||||
'tf_efficientnet_b4.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
|
||||
'tf_efficientnet_b5.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
|
||||
'tf_efficientnet_b5.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
|
||||
'tf_efficientnet_b6.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
|
||||
'tf_efficientnet_b6.aa_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
|
||||
'tf_efficientnet_b7.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
|
||||
'tf_efficientnet_b7.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
|
||||
'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
|
||||
'tf_efficientnet_b8.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
|
||||
'tf_efficientnet_l2.ns_jft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
|
||||
input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
|
||||
|
||||
'tf_efficientnet_es.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
|
||||
|
@ -1248,22 +1247,6 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
|
||||
|
||||
'tf_efficientnetv2_s.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
||||
'tf_efficientnetv2_m.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'tf_efficientnetv2_l.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -1285,6 +1268,22 @@ default_cfgs = generate_default_cfgs({
|
|||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
|
||||
'tf_efficientnetv2_m.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
'tf_efficientnetv2_l.in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
|
||||
|
||||
'tf_efficientnetv2_s.in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
|
||||
hf_hub_id='timm/',
|
||||
|
@ -2289,3 +2288,34 @@ def tinynet_d(pretrained=False, **kwargs):
|
|||
def tinynet_e(pretrained=False, **kwargs):
|
||||
model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
|
||||
'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
|
||||
'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
|
||||
'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
|
||||
'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
|
||||
'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
|
||||
'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
|
||||
'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
|
||||
'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
|
||||
'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
|
||||
'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
|
||||
'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
|
||||
'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
|
||||
'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
|
||||
'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
|
||||
'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
|
||||
'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
|
||||
'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
|
||||
'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
|
||||
'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
|
||||
'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
|
||||
'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
|
||||
'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
|
||||
'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
|
||||
})
|
||||
|
|
|
@ -28,7 +28,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['FocalNet']
|
||||
|
||||
|
@ -485,51 +485,51 @@ def _cfg(url='', **kwargs):
|
|||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
"focalnet_tiny_srf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'),
|
||||
"focalnet_small_srf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'),
|
||||
"focalnet_base_srf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'),
|
||||
"focalnet_tiny_lrf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'),
|
||||
"focalnet_small_lrf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'),
|
||||
"focalnet_base_lrf": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
|
||||
"focalnet_large_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
|
||||
default_cfgs = generate_default_cfgs({
|
||||
"focalnet_tiny_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_small_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_base_srf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_tiny_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_small_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
"focalnet_base_lrf.ms_in1k": _cfg(
|
||||
hf_hub_id='timm/'),
|
||||
|
||||
"focalnet_large_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_large_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
|
||||
"focalnet_large_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
|
||||
"focalnet_xlarge_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_xlarge_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
|
||||
"focalnet_xlarge_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21842),
|
||||
"focalnet_huge_fl3": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
|
||||
"focalnet_huge_fl3.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=21842),
|
||||
"focalnet_huge_fl4": _cfg(
|
||||
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
|
||||
"focalnet_huge_fl4.ms_in22k": _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model: FocalNet):
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
if 'stem.proj.weight' in state_dict:
|
||||
return
|
||||
return state_dict
|
||||
import re
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
dest_dict = model.state_dict()
|
||||
for k, v in state_dict.items():
|
||||
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
||||
|
|
|
@ -24,7 +24,6 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
|||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
|
||||
|
@ -35,9 +34,7 @@ from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
|
|||
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['Levit']
|
||||
|
||||
|
|
|
@ -52,8 +52,7 @@ from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
|
|||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
|
||||
|
||||
|
|
|
@ -22,8 +22,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
|
|||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['MobileNetV3', 'MobileNetV3Features']
|
||||
|
||||
|
@ -796,3 +795,9 @@ def lcnet_150(pretrained=False, **kwargs):
|
|||
""" PP-LCNet 1.5"""
|
||||
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
|
||||
'mobilenetv3_large_100_miil_in21k': 'mobilenetv3_large_100.miil_in21k',
|
||||
})
|
||||
|
|
|
@ -23,11 +23,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert, ClassifierHead
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import checkpoint_seq, named_apply
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
from .vision_transformer import get_init_weights_vit
|
||||
|
||||
__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
|
||||
|
@ -302,7 +302,12 @@ class PatchMerging(nn.Module):
|
|||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: Callable = nn.LayerNorm):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
out_dim: Optional[int] = None,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
|
@ -345,13 +350,13 @@ class SwinTransformerStage(nn.Module):
|
|||
attn_drop: float = 0.,
|
||||
drop_path: Union[List[float], float] = 0.,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
output_nchw: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
input_resolution: Input resolution.
|
||||
depth: Number of blocks.
|
||||
downsample: Downsample layer at the end of the layer.
|
||||
num_heads: Number of attention heads.
|
||||
head_dim: Channels per head (dim // num_heads if not set)
|
||||
window_size: Local window size.
|
||||
|
@ -361,14 +366,12 @@ class SwinTransformerStage(nn.Module):
|
|||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
downsample: Downsample layer at the end of the layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
|
||||
self.depth = depth
|
||||
self.use_nchw = output_nchw
|
||||
self.grad_checkpointing = False
|
||||
|
||||
# patch merging layer
|
||||
|
@ -401,18 +404,12 @@ class SwinTransformerStage(nn.Module):
|
|||
for i in range(depth)])
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_nchw:
|
||||
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
x = self.blocks(x)
|
||||
|
||||
if self.use_nchw:
|
||||
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
return x
|
||||
|
||||
|
||||
|
@ -442,7 +439,6 @@ class SwinTransformer(nn.Module):
|
|||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||
weight_init: str = '',
|
||||
output_fmt: str = 'NHWC',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -465,15 +461,13 @@ class SwinTransformer(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'avg')
|
||||
assert output_fmt in ('NCHW', 'NHWC')
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = output_fmt
|
||||
self.output_fmt = 'NHWC'
|
||||
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.output_nchw = self.output_fmt == 'NCHW' # bool flag for fwd
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
|
@ -518,7 +512,6 @@ class SwinTransformer(nn.Module):
|
|||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
output_nchw=self.output_nchw,
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
|
@ -577,14 +570,8 @@ class SwinTransformer(nn.Module):
|
|||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.output_nchw:
|
||||
# patch embed always outputs NHWC, stage layers expect NCHW input
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
if self.output_nchw:
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
|
@ -596,14 +583,10 @@ class SwinTransformer(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
def checkpoint_filter_fn(
|
||||
state_dict,
|
||||
model,
|
||||
adapt_layer_scale=False,
|
||||
interpolation='bicubic',
|
||||
antialias=True,
|
||||
):
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
import re
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
|
@ -635,106 +618,83 @@ def _cfg(url='', **kwargs):
|
|||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swin_base_patch4_window12_384': _cfg(
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
|
||||
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
|
||||
'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_base_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
|
||||
'swin_large_patch4_window12_384': _cfg(
|
||||
'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
|
||||
'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_large_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
|
||||
),
|
||||
'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
|
||||
'swin_small_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
|
||||
'swin_base_patch4_window7_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
|
||||
'swin_base_patch4_window12_384.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
'swin_small_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
|
||||
),
|
||||
# tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
|
||||
'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),
|
||||
|
||||
'swin_tiny_patch4_window7_224': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
|
||||
),
|
||||
|
||||
'swin_base_patch4_window12_384_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_base_patch4_window7_224_in22k': _cfg(
|
||||
'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_small_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_base_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window12_384_in22k': _cfg(
|
||||
'swin_base_patch4_window12_384.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
'swin_large_patch4_window7_224.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
'swin_large_patch4_window12_384.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
|
||||
|
||||
'swin_large_patch4_window7_224_in22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
||||
num_classes=21841),
|
||||
|
||||
'swin_s3_tiny_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'
|
||||
),
|
||||
'swin_s3_small_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'
|
||||
),
|
||||
'swin_s3_base_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-S @ 224x224, trained ImageNet-1k
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
'swin_s3_tiny_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
|
||||
'swin_s3_small_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
|
||||
'swin_s3_base_224.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -747,44 +707,53 @@ def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384, trained ImageNet-22k
|
||||
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-S @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
||||
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224, trained ImageNet-22k
|
||||
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384, trained ImageNet-22k
|
||||
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-B @ 384x384
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
||||
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224, trained ImageNet-22k
|
||||
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 224x224
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
|
||||
""" Swin-L @ 384x384
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swin_s3_tiny_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
|
||||
|
@ -794,7 +763,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def swin_s3_small_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
|
||||
|
@ -804,10 +773,17 @@ def swin_s3_small_224(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def swin_s3_base_224(pretrained=False, **kwargs):
|
||||
""" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
|
||||
""" Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
|
||||
num_heads=(3, 6, 12, 24), **kwargs)
|
||||
return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
|
||||
'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
|
||||
'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
|
||||
'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
|
||||
})
|
||||
|
|
|
@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
||||
__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -425,9 +425,6 @@ class SwinTransformerV2Stage(nn.Module):
|
|||
for i in range(depth)])
|
||||
|
||||
def forward(self, x):
|
||||
if self.output_nchw:
|
||||
x = x.permute(0, 2, 3, 1) # NCHW -> NHWC
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
|
@ -435,9 +432,6 @@ class SwinTransformerV2Stage(nn.Module):
|
|||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
|
||||
if self.output_nchw:
|
||||
x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
|
||||
return x
|
||||
|
||||
def _init_respostnorm(self):
|
||||
|
@ -473,7 +467,6 @@ class SwinTransformerV2(nn.Module):
|
|||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0, 0),
|
||||
output_fmt: str = 'NHWC',
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -500,13 +493,11 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
self.num_classes = num_classes
|
||||
assert global_pool in ('', 'avg')
|
||||
assert output_fmt in ('NCHW', 'NHWC')
|
||||
self.global_pool = global_pool
|
||||
self.output_fmt = output_fmt
|
||||
self.output_fmt = 'NHWC'
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
||||
self.output_nchw = self.output_fmt == 'NCHW'
|
||||
self.feature_info = []
|
||||
|
||||
if not isinstance(embed_dim, (tuple, list)):
|
||||
|
@ -544,7 +535,6 @@ class SwinTransformerV2(nn.Module):
|
|||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
pretrained_window_size=pretrained_window_sizes[i],
|
||||
output_nchw=self.output_nchw,
|
||||
)]
|
||||
in_dim = out_dim
|
||||
if i > 0:
|
||||
|
@ -605,14 +595,8 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.output_nchw:
|
||||
# patch embed always outputs NHWC, stage layers expect NCHW input if output_nchw = True
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
if self.output_nchw:
|
||||
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = self.norm(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
|
@ -625,10 +609,12 @@ class SwinTransformerV2(nn.Module):
|
|||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
import re
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
out_dict = {}
|
||||
import re
|
||||
for k, v in state_dict.items():
|
||||
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
|
||||
continue # skip buffers that should not be persistent
|
||||
|
@ -657,53 +643,66 @@ def _cfg(url='', **kwargs):
|
|||
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
'license': 'mit', **kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_tiny_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_tiny_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_small_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_small_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_base_window8_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_base_window16_256': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
|
||||
'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
'swinv2_large_window12_192_22k': _cfg(
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
|
||||
'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
|
||||
),
|
||||
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
|
||||
'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
|
||||
),
|
||||
}
|
||||
|
||||
'swinv2_tiny_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_tiny_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_small_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_small_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
|
||||
),
|
||||
'swinv2_base_window8_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
|
||||
),
|
||||
'swinv2_base_window16_256.ms_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
|
||||
),
|
||||
|
||||
'swinv2_base_window12_192.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
'swinv2_large_window12_192.ms_in22k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
|
||||
num_classes=21841, input_size=(3, 192, 192), pool_size=(6, 6)
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -761,62 +760,72 @@ def swinv2_base_window8_256(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12_192(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_base_window12_192', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12to16_192to256(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_base_window12to16_192to256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_base_window12to24_192to384(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_base_window12to24_192to384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12_192(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer_v2('swinv2_large_window12_192', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12to16_192to256(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_large_window12to16_192to256', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
|
||||
def swinv2_large_window12to24_192to384(pretrained=False, **kwargs):
|
||||
"""
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
|
||||
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
|
||||
return _create_swin_transformer_v2(
|
||||
'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
|
||||
'swinv2_large_window12to24_192to384', pretrained=pretrained, **model_kwargs)
|
||||
|
||||
|
||||
register_model_deprecations(__name__, {
|
||||
'swinv2_base_window12_192_22k': 'swinv2_base_window12_192.ms_in22k',
|
||||
'swinv2_base_window12to16_192to256_22kft1k': 'swinv2_base_window12to16_192to256.ms_in22k_ft_in1k',
|
||||
'swinv2_base_window12to24_192to384_22kft1k': 'swinv2_base_window12to24_192to384.ms_in22k_ft_in1k',
|
||||
'swinv2_large_window12_192_22k': 'swinv2_large_window12_192.ms_in22k',
|
||||
'swinv2_large_window12to16_192to256_22kft1k': 'swinv2_large_window12to16_192to256.ms_in22k_ft_in1k',
|
||||
'swinv2_large_window12to24_192to384_22kft1k': 'swinv2_large_window12to24_192to384.ms_in22k_ft_in1k',
|
||||
})
|
||||
|
|
|
@ -41,7 +41,7 @@ from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
|
|||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._manipulate import named_apply
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -748,9 +748,11 @@ def init_weights(module: nn.Module, name: str = ''):
|
|||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
state_dict = state_dict.get('model', state_dict)
|
||||
state_dict = state_dict.get('state_dict', state_dict)
|
||||
if 'head.fc.weight' in state_dict:
|
||||
return state_dict
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'tau' in k:
|
||||
# convert old tau based checkpoints -> logit_scale (inverse)
|
||||
|
@ -791,43 +793,46 @@ def _cfg(url='', **kwargs):
|
|||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'swinv2_cr_tiny_384': _cfg(
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'swinv2_cr_tiny_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_tiny_224': _cfg(
|
||||
'swinv2_cr_tiny_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_tiny_ns_224': _cfg(
|
||||
'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_384': _cfg(
|
||||
'swinv2_cr_small_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_small_224': _cfg(
|
||||
'swinv2_cr_small_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_224': _cfg(
|
||||
'swinv2_cr_small_ns_224.sw_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
||||
input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_small_ns_256': _cfg(
|
||||
'swinv2_cr_small_ns_256.untrained': _cfg(
|
||||
url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
|
||||
'swinv2_cr_base_384': _cfg(
|
||||
'swinv2_cr_base_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_base_224': _cfg(
|
||||
'swinv2_cr_base_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_base_ns_224': _cfg(
|
||||
'swinv2_cr_base_ns_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_large_384': _cfg(
|
||||
'swinv2_cr_large_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_large_224': _cfg(
|
||||
'swinv2_cr_large_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_huge_384': _cfg(
|
||||
'swinv2_cr_huge_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_huge_224': _cfg(
|
||||
'swinv2_cr_huge_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
'swinv2_cr_giant_384': _cfg(
|
||||
'swinv2_cr_giant_384.untrained': _cfg(
|
||||
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
|
||||
'swinv2_cr_giant_224': _cfg(
|
||||
'swinv2_cr_giant_224.untrained': _cfg(
|
||||
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
|
|
|
@ -41,9 +41,7 @@ from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_,
|
|||
resample_abs_pos_embed, RmsNorm
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
|
@ -20,8 +20,7 @@ import torch.nn as nn
|
|||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import StdConv2dSame, StdConv2d, to_2tuple
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
from .resnet import resnet26d, resnet50d
|
||||
from .resnetv2 import ResNetV2, create_resnetv2_stem
|
||||
from .vision_transformer import _create_vision_transformer
|
||||
|
|
|
@ -17,8 +17,7 @@ from torch.utils.checkpoint import checkpoint
|
|||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
||||
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
|
Loading…
Reference in New Issue