Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Punch cache_dir through model factory / builder / pretrain helpers. Improve some annotations in related code.
parent 553ded5c6b
commit dc1bb05e8e
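In practical terms, the change lets a caller redirect where pretrained weights and Hub config files are cached. A minimal usage sketch under the new signatures (the model name and cache path below are illustrative only, not part of this commit):

from pathlib import Path
import timm

# cache_dir is forwarded from create_model() through build_model_with_cfg()
# and load_pretrained(), ending up as model_dir= for torch.hub downloads and
# cache_dir= for huggingface_hub downloads.
model = timm.create_model(
    'resnet50',                          # illustrative model name
    pretrained=True,
    cache_dir=Path('/tmp/timm-cache'),   # hypothetical override location
)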
@@ -2,7 +2,8 @@ import dataclasses
 import logging
 import os
 from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -90,6 +91,7 @@ def load_custom_pretrained(
         model: nn.Module,
         pretrained_cfg: Optional[Dict] = None,
         load_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     r"""Loads a custom (read non .pth) weight file

@@ -102,9 +104,9 @@ def load_custom_pretrained(

     Args:
         model: The instantiated model to load weights into
-        pretrained_cfg (dict): Default pretrained model cfg
+        pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
-            'laod_pretrained' on the model will be called if it exists
+            'load_pretrained' on the model will be called if it exists
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -122,6 +124,7 @@ def load_custom_pretrained(
             pretrained_loc,
             check_hash=_CHECK_HASH,
             progress=_DOWNLOAD_PROGRESS,
+            cache_dir=cache_dir,
         )

     if load_fn is not None:
@@ -139,17 +142,18 @@ def load_pretrained(
         in_chans: int = 3,
         filter_fn: Optional[Callable] = None,
         strict: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     """ Load pretrained checkpoint

     Args:
-        model (nn.Module) : PyTorch model module
-        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-        num_classes (int): num_classes for target model
-        in_chans (int): in_chans for target model
-        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
-        strict (bool): strict load of checkpoint
-
+        model: PyTorch module
+        pretrained_cfg: configuration for pretrained weights / target dataset
+        num_classes: number of classes for target model
+        in_chans: number of input chans for target model
+        filter_fn: state_dict filter fn for load (takes state_dict, model as args)
+        strict: strict load of checkpoint
+        cache_dir: override path to cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -173,6 +177,7 @@ def load_pretrained(
                 pretrained_loc,
                 progress=_DOWNLOAD_PROGRESS,
                 check_hash=_CHECK_HASH,
+                cache_dir=cache_dir,
             )
             model.load_pretrained(pretrained_loc)
             return
@@ -184,6 +189,7 @@ def load_pretrained(
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
                     weights_only=True,
+                    model_dir=cache_dir,
                 )
             except TypeError:
                 state_dict = load_state_dict_from_url(
@@ -191,18 +197,19 @@ def load_pretrained(
                     map_location='cpu',
                     progress=_DOWNLOAD_PROGRESS,
                     check_hash=_CHECK_HASH,
+                    model_dir=cache_dir,
                 )
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
             custom_load = pretrained_cfg.get('custom_load', False)
             if isinstance(custom_load, str) and custom_load == 'hf':
-                load_custom_from_hf(*pretrained_loc, model)
+                load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
                 return
             else:
-                state_dict = load_state_dict_from_hf(*pretrained_loc)
+                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
-            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
+            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -362,6 +369,7 @@ def build_model_with_cfg(
         feature_cfg: Optional[Dict] = None,
         pretrained_strict: bool = True,
         pretrained_filter_fn: Optional[Callable] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         kwargs_filter: Optional[Tuple[str]] = None,
         **kwargs,
 ):
@@ -382,6 +390,7 @@ def build_model_with_cfg(
         feature_cfg: feature extraction adapter config
         pretrained_strict: load pretrained weights strictly
         pretrained_filter_fn: filter callable for pretrained weights
+        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
         kwargs_filter: kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
@@ -431,6 +440,7 @@ def build_model_with_cfg(
             in_chans=kwargs.get('in_chans', 3),
             filter_fn=pretrained_filter_fn,
             strict=pretrained_strict,
+            cache_dir=cache_dir,
         )

     # Wrap the model in a feature extraction module if enabled
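The builder-level helper can also be driven directly. A rough sketch, assuming the import path and cache location shown here (both illustrative, not taken from this commit):

from timm.models import create_model
from timm.models._builder import load_pretrained

# Build the architecture without weights, then pull pretrained weights into it.
# pretrained_cfg falls back to model.pretrained_cfg (see the getattr above);
# cache_dir redirects any URL / Hub download performed during the load.
model = create_model('resnet50', pretrained=False)   # illustrative model name
load_pretrained(
    model,
    num_classes=1000,
    in_chans=3,
    cache_dir='/data/model-cache',                    # hypothetical path
)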
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 from typing import Any, Dict, Optional, Union
 from urllib.parse import urlsplit

@@ -40,7 +41,8 @@ def create_model(
         pretrained: bool = False,
         pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
-        checkpoint_path: str = '',
+        checkpoint_path: Optional[Union[str, Path]] = None,
+        cache_dir: Optional[Union[str, Path]] = None,
         scriptable: Optional[bool] = None,
         exportable: Optional[bool] = None,
         no_jit: Optional[bool] = None,
@@ -50,10 +52,9 @@ def create_model(

     Lookup model's entrypoint function and pass relevant args to create a new model.

-    <Tip>
+    Tip:
         **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
         and then the model class __init__(). kwargs values set to None are pruned before passing.
-    </Tip>

     Args:
         model_name: Name of model to instantiate.
@@ -61,6 +62,7 @@ def create_model(
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -99,7 +101,10 @@ def create_model(
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
         # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+            model_name,
+            cache_dir=cache_dir,
+        )
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
@@ -118,6 +123,7 @@ def create_model(
             pretrained=pretrained,
             pretrained_cfg=pretrained_cfg,
             pretrained_cfg_overlay=pretrained_cfg_overlay,
+            cache_dir=cache_dir,
             **kwargs,
         )

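For models sourced from the Hub with an 'hf-hub:' prefix, the same override now also reaches load_model_config_from_hf(), so the config.json download honors it as well. A sketch (the repo name and cache path are illustrative):

import timm

# Both config.json and the weight file are materialized under cache_dir
# instead of the default Hugging Face Hub cache.
model = timm.create_model(
    'hf-hub:timm/resnet50.a1_in1k',     # illustrative hub repo
    pretrained=True,
    cache_dir='/mnt/shared/hub-cache',  # hypothetical shared cache
)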
@@ -4,7 +4,6 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
 import os
 from collections import OrderedDict
 from typing import Any, Callable, Dict, Optional, Union

 import torch
@@ -5,7 +5,7 @@ import os
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -53,7 +53,7 @@ HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
 HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version


-def get_cache_dir(child_dir=''):
+def get_cache_dir(child_dir: str = ''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
     """
@@ -68,13 +68,22 @@ def get_cache_dir(child_dir=''):
     return model_dir


-def download_cached_file(url, check_hash=True, progress=False):
+def download_cached_file(
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        progress: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     if isinstance(url, (list, tuple)):
         url, filename = url
     else:
         parts = urlparse(url)
         filename = os.path.basename(parts.path)
-    cached_file = os.path.join(get_cache_dir(), filename)
+    if cache_dir:
+        os.makedirs(cache_dir, exist_ok=True)
+    else:
+        cache_dir = get_cache_dir()
+    cached_file = os.path.join(cache_dir, filename)
     if not os.path.exists(cached_file):
         _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
         hash_prefix = None
@@ -85,13 +94,19 @@ def download_cached_file(url, check_hash=True, progress=False):
     return cached_file


-def check_cached_file(url, check_hash=True):
+def check_cached_file(
+        url: Union[str, List[str], Tuple[str, str]],
+        check_hash: bool = True,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     if isinstance(url, (list, tuple)):
         url, filename = url
     else:
         parts = urlparse(url)
         filename = os.path.basename(parts.path)
-    cached_file = os.path.join(get_cache_dir(), filename)
+    if not cache_dir:
+        cache_dir = get_cache_dir()
+    cached_file = os.path.join(cache_dir, filename)
     if os.path.exists(cached_file):
         if check_hash:
             r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
@@ -105,7 +120,7 @@ def check_cached_file(url, check_hash=True):
     return False


-def has_hf_hub(necessary=False):
+def has_hf_hub(necessary: bool = False):
     if not _has_hf_hub and necessary:
         # if no HF Hub module installed, and it is necessary to continue, raise error
         raise RuntimeError(
@@ -122,20 +137,32 @@ def hf_split(hf_id: str):
     return hf_model_id, hf_revision


-def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+def load_cfg_from_json(json_file: Union[str, Path]):
     with open(json_file, "r", encoding="utf-8") as reader:
         text = reader.read()
     return json.loads(text)


-def download_from_hf(model_id: str, filename: str):
+def download_from_hf(
+        model_id: str,
+        filename: str,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     hf_model_id, hf_revision = hf_split(model_id)
-    return hf_hub_download(hf_model_id, filename, revision=hf_revision)
+    return hf_hub_download(
+        hf_model_id,
+        filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )


-def load_model_config_from_hf(model_id: str):
+def load_model_config_from_hf(
+        model_id: str,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json')
+    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)

     hf_config = load_cfg_from_json(cached_file)
     if 'pretrained_cfg' not in hf_config:
@@ -172,6 +199,7 @@ def load_state_dict_from_hf(
         model_id: str,
         filename: str = HF_WEIGHTS_NAME,
         weights_only: bool = False,
+        cache_dir: Optional[Union[str, Path]] = None,
 ):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
@@ -180,7 +208,12 @@ def load_state_dict_from_hf(
     if _has_safetensors:
         for safe_filename in _get_safe_alternatives(filename):
             try:
-                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                cached_safe_file = hf_hub_download(
+                    repo_id=hf_model_id,
+                    filename=safe_filename,
+                    revision=hf_revision,
+                    cache_dir=cache_dir,
+                )
                 _logger.info(
                     f"[{model_id}] Safe alternative available for '{filename}' "
                     f"(as '{safe_filename}'). Loading weights using safetensors.")
@@ -189,7 +222,12 @@ def load_state_dict_from_hf(
                 pass

     # Otherwise, load using pytorch.load
-    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    cached_file = hf_hub_download(
+        hf_model_id,
+        filename=filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )
     _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
     try:
         state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
@@ -198,15 +236,25 @@ def load_state_dict_from_hf(
     return state_dict


-def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+def load_custom_from_hf(
+        model_id: str,
+        filename: str,
+        model: torch.nn.Module,
+        cache_dir: Optional[Union[str, Path]] = None,
+):
     assert has_hf_hub(True)
     hf_model_id, hf_revision = hf_split(model_id)
-    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    cached_file = hf_hub_download(
+        hf_model_id,
+        filename=filename,
+        revision=hf_revision,
+        cache_dir=cache_dir,
+    )
     return model.load_pretrained(cached_file)


 def save_config_for_hf(
-        model,
+        model: torch.nn.Module,
         config_path: str,
         model_config: Optional[dict] = None,
         model_args: Optional[dict] = None
@@ -255,7 +303,7 @@ def save_config_for_hf(


 def save_for_hf(
-        model,
+        model: torch.nn.Module,
         save_directory: str,
         model_config: Optional[dict] = None,
         model_args: Optional[dict] = None,
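The lower-level hub helpers accept the same override and fall back to the defaults (get_cache_dir() for plain URLs, the Hugging Face Hub cache for repo files) when it is None. A sketch with illustrative names and paths, assuming the timm.models._hub import path:

from timm.models._hub import download_cached_file, load_state_dict_from_hf

# Plain-URL checkpoint cached under an explicit directory (URL is hypothetical).
ckpt_path = download_cached_file(
    'https://example.com/weights/model.pth',
    check_hash=False,
    cache_dir='/tmp/timm-downloads',
)

# Hub state dict resolved through hf_hub_download() with the same override.
state_dict = load_state_dict_from_hf(
    'timm/resnet50.a1_in1k',          # illustrative hub repo
    weights_only=True,
    cache_dir='/tmp/hf-cache',
)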