mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add type annotations to _registry.py
Description Add type annotations to _registry.py so that they will pass mypy --strict. Comment I was reading the code and felt that this module would be easier to understand with type annotations. Therefore, I went ahead and added the annotations. The idea with this PR is to start small to see if we can align on _how_ to annotate types. I've seen people in the past disagree on how strictly to annotate the code base, so before spending too much time on this, I wanted to check if you agree, Ross. Most of the added types should be straightforward. Some notes on the non-trivial changes: - I made no assumption about the fn passed to register_model, but maybe the type could be stricter. Are all models nn.Modules? - If I'm not mistaken, the type hint for get_arch_name was incorrect - I had to add a # type: ignore to model.__all__ = ... - I made some minor code changes to list_models to facilitate the typing. I think the changes should not affect the logic of the function. - I removed list from list(sorted(...)) because sorted returns always a list.
This commit is contained in:
parent
c9406ce608
commit
a5b01ec04e
@ -93,7 +93,7 @@ class DefaultCfg:
|
|||||||
return tag, self.cfgs[tag]
|
return tag, self.cfgs[tag]
|
||||||
|
|
||||||
|
|
||||||
def split_model_name_tag(model_name: str, no_tag=''):
|
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||||
model_name, *tag_list = model_name.split('.', 1)
|
model_name, *tag_list = model_name.split('.', 1)
|
||||||
tag = tag_list[0] if tag_list else no_tag
|
tag = tag_list[0] if tag_list else no_tag
|
||||||
return model_name, tag
|
return model_name, tag
|
||||||
|
@ -8,7 +8,7 @@ import sys
|
|||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import List, Optional, Union, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||||
|
|
||||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||||
|
|
||||||
@ -16,20 +16,20 @@ __all__ = [
|
|||||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||||
|
|
||||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||||
_model_to_module = {} # mapping of model names to module names
|
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
|
||||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
|
||||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||||
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
|
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||||
|
|
||||||
|
|
||||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
def get_arch_name(model_name: str) -> str:
|
||||||
return split_model_name_tag(model_name)[0]
|
return split_model_name_tag(model_name)[0]
|
||||||
|
|
||||||
|
|
||||||
def register_model(fn):
|
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
# lookup containing module
|
# lookup containing module
|
||||||
mod = sys.modules[fn.__module__]
|
mod = sys.modules[fn.__module__]
|
||||||
module_name_split = fn.__module__.split('.')
|
module_name_split = fn.__module__.split('.')
|
||||||
@ -40,7 +40,7 @@ def register_model(fn):
|
|||||||
if hasattr(mod, '__all__'):
|
if hasattr(mod, '__all__'):
|
||||||
mod.__all__.append(model_name)
|
mod.__all__.append(model_name)
|
||||||
else:
|
else:
|
||||||
mod.__all__ = [model_name]
|
mod.__all__ = [model_name] # type: ignore
|
||||||
|
|
||||||
# add entries to registry dict/sets
|
# add entries to registry dict/sets
|
||||||
_model_entrypoints[model_name] = fn
|
_model_entrypoints[model_name] = fn
|
||||||
@ -87,28 +87,33 @@ def register_model(fn):
|
|||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
def _natural_key(string_):
|
def _natural_key(string_: str) -> List[Union[int, str]]:
|
||||||
|
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
|
||||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||||
|
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
filter: Union[str, List[str]] = '',
|
filter: Union[str, List[str]] = '',
|
||||||
module: str = '',
|
module: str = '',
|
||||||
pretrained=False,
|
pretrained: bool = False,
|
||||||
exclude_filters: str = '',
|
exclude_filters: Union[str, List[str]] = '',
|
||||||
name_matches_cfg: bool = False,
|
name_matches_cfg: bool = False,
|
||||||
include_tags: Optional[bool] = None,
|
include_tags: Optional[bool] = None,
|
||||||
):
|
) -> List[str]:
|
||||||
""" Return list of available model names, sorted alphabetically
|
""" Return list of available model names, sorted alphabetically
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
filter (str) - Wildcard filter string that works with fnmatch
|
filter - Wildcard filter string that works with fnmatch
|
||||||
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
|
module - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||||
pretrained (bool) - Include only models with valid pretrained weights if True
|
pretrained - Include only models with valid pretrained weights if True
|
||||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
exclude_filters - Wildcard filters to exclude models after including them with filter
|
||||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||||
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
|
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
|
||||||
set to True when pretrained=True else False (default: None)
|
set to True when pretrained=True else False (default: None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
models - The sorted list of models
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||||
@ -118,7 +123,7 @@ def list_models(
|
|||||||
include_tags = pretrained
|
include_tags = pretrained
|
||||||
|
|
||||||
if module:
|
if module:
|
||||||
all_models = list(_module_to_models[module])
|
all_models: Iterable[str] = list(_module_to_models[module])
|
||||||
else:
|
else:
|
||||||
all_models = _model_entrypoints.keys()
|
all_models = _model_entrypoints.keys()
|
||||||
|
|
||||||
@ -130,14 +135,14 @@ def list_models(
|
|||||||
all_models = models_with_tags
|
all_models = models_with_tags
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
models = []
|
models: Set[str] = set()
|
||||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||||
for f in include_filters:
|
for f in include_filters:
|
||||||
include_models = fnmatch.filter(all_models, f) # include these models
|
include_models = fnmatch.filter(all_models, f) # include these models
|
||||||
if len(include_models):
|
if len(include_models):
|
||||||
models = set(models).union(include_models)
|
models = models.union(include_models)
|
||||||
else:
|
else:
|
||||||
models = all_models
|
models = set(all_models)
|
||||||
|
|
||||||
if exclude_filters:
|
if exclude_filters:
|
||||||
if not isinstance(exclude_filters, (tuple, list)):
|
if not isinstance(exclude_filters, (tuple, list)):
|
||||||
@ -145,7 +150,7 @@ def list_models(
|
|||||||
for xf in exclude_filters:
|
for xf in exclude_filters:
|
||||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||||
if len(exclude_models):
|
if len(exclude_models):
|
||||||
models = set(models).difference(exclude_models)
|
models = models.difference(exclude_models)
|
||||||
|
|
||||||
if pretrained:
|
if pretrained:
|
||||||
models = _model_has_pretrained.intersection(models)
|
models = _model_has_pretrained.intersection(models)
|
||||||
@ -153,13 +158,13 @@ def list_models(
|
|||||||
if name_matches_cfg:
|
if name_matches_cfg:
|
||||||
models = set(_model_pretrained_cfgs).intersection(models)
|
models = set(_model_pretrained_cfgs).intersection(models)
|
||||||
|
|
||||||
return list(sorted(models, key=_natural_key))
|
return sorted(models, key=_natural_key)
|
||||||
|
|
||||||
|
|
||||||
def list_pretrained(
|
def list_pretrained(
|
||||||
filter: Union[str, List[str]] = '',
|
filter: Union[str, List[str]] = '',
|
||||||
exclude_filters: str = '',
|
exclude_filters: str = '',
|
||||||
):
|
) -> List[str]:
|
||||||
return list_models(
|
return list_models(
|
||||||
filter=filter,
|
filter=filter,
|
||||||
pretrained=True,
|
pretrained=True,
|
||||||
@ -168,14 +173,14 @@ def list_pretrained(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_model(model_name):
|
def is_model(model_name: str) -> bool:
|
||||||
""" Check if a model name exists
|
""" Check if a model name exists
|
||||||
"""
|
"""
|
||||||
arch_name = get_arch_name(model_name)
|
arch_name = get_arch_name(model_name)
|
||||||
return arch_name in _model_entrypoints
|
return arch_name in _model_entrypoints
|
||||||
|
|
||||||
|
|
||||||
def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
|
||||||
"""Fetch a model entrypoint for specified model name
|
"""Fetch a model entrypoint for specified model name
|
||||||
"""
|
"""
|
||||||
arch_name = get_arch_name(model_name)
|
arch_name = get_arch_name(model_name)
|
||||||
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
|||||||
return _model_entrypoints[arch_name]
|
return _model_entrypoints[arch_name]
|
||||||
|
|
||||||
|
|
||||||
def list_modules():
|
def list_modules() -> List[str]:
|
||||||
""" Return list of module names that contain models / model entrypoints
|
""" Return list of module names that contain models / model entrypoints
|
||||||
"""
|
"""
|
||||||
modules = _module_to_models.keys()
|
modules = _module_to_models.keys()
|
||||||
return list(sorted(modules))
|
return sorted(modules)
|
||||||
|
|
||||||
|
|
||||||
def is_model_in_modules(model_name, module_names):
|
def is_model_in_modules(
|
||||||
|
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
|
||||||
|
) -> bool:
|
||||||
"""Check if a model exists within a subset of modules
|
"""Check if a model exists within a subset of modules
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (str) - name of model to check
|
model_name - name of model to check
|
||||||
module_names (tuple, list, set) - names of modules to search in
|
module_names - names of modules to search in
|
||||||
"""
|
"""
|
||||||
arch_name = get_arch_name(model_name)
|
arch_name = get_arch_name(model_name)
|
||||||
assert isinstance(module_names, (tuple, list, set))
|
assert isinstance(module_names, (tuple, list, set))
|
||||||
return any(arch_name in _module_to_models[n] for n in module_names)
|
return any(arch_name in _module_to_models[n] for n in module_names)
|
||||||
|
|
||||||
|
|
||||||
def is_model_pretrained(model_name):
|
def is_model_pretrained(model_name: str) -> bool:
|
||||||
return model_name in _model_has_pretrained
|
return model_name in _model_has_pretrained
|
||||||
|
|
||||||
|
|
||||||
def get_pretrained_cfg(model_name, allow_unregistered=True):
|
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
|
||||||
if model_name in _model_pretrained_cfgs:
|
if model_name in _model_pretrained_cfgs:
|
||||||
return deepcopy(_model_pretrained_cfgs[model_name])
|
return deepcopy(_model_pretrained_cfgs[model_name])
|
||||||
arch_name, tag = split_model_name_tag(model_name)
|
arch_name, tag = split_model_name_tag(model_name)
|
||||||
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
|
|||||||
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
|
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
|
||||||
|
|
||||||
|
|
||||||
def get_pretrained_cfg_value(model_name, cfg_key):
|
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
|
||||||
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
||||||
"""
|
"""
|
||||||
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
|
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user