mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add type annotations to _registry.py
Description Add type annotations to _registry.py so that they will pass mypy --strict. Comment I was reading the code and felt that this module would be easier to understand with type annotations. Therefore, I went ahead and added the annotations. The idea with this PR is to start small to see if we can align on _how_ to annotate types. I've seen people in the past disagree on how strictly to annotate the code base, so before spending too much time on this, I wanted to check if you agree, Ross. Most of the added types should be straightforward. Some notes on the non-trivial changes: - I made no assumption about the fn passed to register_model, but maybe the type could be stricter. Are all models nn.Modules? - If I'm not mistaken, the type hint for get_arch_name was incorrect - I had to add a # type: ignore to model.__all__ = ... - I made some minor code changes to list_models to facilitate the typing. I think the changes should not affect the logic of the function. - I removed list from list(sorted(...)) because sorted returns always a list.
This commit is contained in:
parent
c9406ce608
commit
a5b01ec04e
@ -93,7 +93,7 @@ class DefaultCfg:
|
||||
return tag, self.cfgs[tag]
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag=''):
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
@ -8,7 +8,7 @@ import sys
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from typing import List, Optional, Union, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
@ -16,20 +16,20 @@ __all__ = [
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module = {} # mapping of model names to module names
|
||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
||||
def get_arch_name(model_name: str) -> str:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
@ -40,7 +40,7 @@ def register_model(fn):
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(model_name)
|
||||
else:
|
||||
mod.__all__ = [model_name]
|
||||
mod.__all__ = [model_name] # type: ignore
|
||||
|
||||
# add entries to registry dict/sets
|
||||
_model_entrypoints[model_name] = fn
|
||||
@ -87,28 +87,33 @@ def register_model(fn):
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
def _natural_key(string_: str) -> List[Union[int, str]]:
|
||||
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(
|
||||
filter: Union[str, List[str]] = '',
|
||||
module: str = '',
|
||||
pretrained=False,
|
||||
exclude_filters: str = '',
|
||||
pretrained: bool = False,
|
||||
exclude_filters: Union[str, List[str]] = '',
|
||||
name_matches_cfg: bool = False,
|
||||
include_tags: Optional[bool] = None,
|
||||
):
|
||||
) -> List[str]:
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained (bool) - Include only models with valid pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
filter - Wildcard filter string that works with fnmatch
|
||||
module - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained - Include only models with valid pretrained weights if True
|
||||
exclude_filters - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
set to True when pretrained=True else False (default: None)
|
||||
|
||||
Returns:
|
||||
models - The sorted list of models
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
@ -118,7 +123,7 @@ def list_models(
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models = list(_module_to_models[module])
|
||||
all_models: Iterable[str] = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
|
||||
@ -130,14 +135,14 @@ def list_models(
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
models = []
|
||||
models: Set[str] = set()
|
||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||
for f in include_filters:
|
||||
include_models = fnmatch.filter(all_models, f) # include these models
|
||||
if len(include_models):
|
||||
models = set(models).union(include_models)
|
||||
models = models.union(include_models)
|
||||
else:
|
||||
models = all_models
|
||||
models = set(all_models)
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
@ -145,7 +150,7 @@ def list_models(
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
if len(exclude_models):
|
||||
models = set(models).difference(exclude_models)
|
||||
models = models.difference(exclude_models)
|
||||
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
@ -153,13 +158,13 @@ def list_models(
|
||||
if name_matches_cfg:
|
||||
models = set(_model_pretrained_cfgs).intersection(models)
|
||||
|
||||
return list(sorted(models, key=_natural_key))
|
||||
return sorted(models, key=_natural_key)
|
||||
|
||||
|
||||
def list_pretrained(
|
||||
filter: Union[str, List[str]] = '',
|
||||
exclude_filters: str = '',
|
||||
):
|
||||
) -> List[str]:
|
||||
return list_models(
|
||||
filter=filter,
|
||||
pretrained=True,
|
||||
@ -168,14 +173,14 @@ def list_pretrained(
|
||||
)
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
def is_model(model_name: str) -> bool:
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
return arch_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
return _model_entrypoints[arch_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
def list_modules() -> List[str]:
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
return sorted(modules)
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
def is_model_in_modules(
|
||||
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
|
||||
) -> bool:
|
||||
"""Check if a model exists within a subset of modules
|
||||
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
model_name - name of model to check
|
||||
module_names - names of modules to search in
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(arch_name in _module_to_models[n] for n in module_names)
|
||||
|
||||
|
||||
def is_model_pretrained(model_name):
|
||||
def is_model_pretrained(model_name: str) -> bool:
|
||||
return model_name in _model_has_pretrained
|
||||
|
||||
|
||||
def get_pretrained_cfg(model_name, allow_unregistered=True):
|
||||
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return deepcopy(_model_pretrained_cfgs[model_name])
|
||||
arch_name, tag = split_model_name_tag(model_name)
|
||||
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
|
||||
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
|
||||
|
||||
|
||||
def get_pretrained_cfg_value(model_name, cfg_key):
|
||||
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
|
||||
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
||||
"""
|
||||
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user