Punch cache_dir through model factory / builder / pretrain helpers. Improve some annotations in related code.

Ross Wightman 2024-12-04 22:02:40 -08:00 committed by Ross Wightman
parent 553ded5c6b
commit dc1bb05e8e
4 changed files with 99 additions and 36 deletions
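
The net effect: a single `cache_dir` argument now flows from `create_model` through `build_model_with_cfg` and `load_pretrained` into both the torch.hub and Hugging Face Hub download helpers. A minimal usage sketch based on the new signatures (model name and path are illustrative):

```python
import timm

# Pretrained weights for this load are downloaded and cached under
# ./weights instead of the default torch.hub / HF Hub cache locations.
model = timm.create_model(
    'resnet50',
    pretrained=True,
    cache_dir='./weights',
)
```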

timm/models/_builder.py

@@ -2,7 +2,8 @@ import dataclasses
import logging
import os
from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch import nn as nn
from torch.hub import load_state_dict_from_url
@@ -90,6 +91,7 @@ def load_custom_pretrained(
model: nn.Module,
pretrained_cfg: Optional[Dict] = None,
load_fn: Optional[Callable] = None,
+cache_dir: Optional[Union[str, Path]] = None,
):
r"""Loads a custom (read non .pth) weight file
@@ -102,9 +104,9 @@ def load_custom_pretrained(
Args:
model: The instantiated model to load weights into
-pretrained_cfg (dict): Default pretrained model cfg
+pretrained_cfg: Default pretrained model cfg
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
-'laod_pretrained' on the model will be called if it exists
+'load_pretrained' on the model will be called if it exists
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -122,6 +124,7 @@ def load_custom_pretrained(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS,
+cache_dir=cache_dir,
)
if load_fn is not None:
@@ -139,17 +142,18 @@ def load_pretrained(
in_chans: int = 3,
filter_fn: Optional[Callable] = None,
strict: bool = True,
+cache_dir: Optional[Union[str, Path]] = None,
):
""" Load pretrained checkpoint
Args:
-model (nn.Module) : PyTorch model module
-pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
-num_classes (int): num_classes for target model
-in_chans (int): in_chans for target model
-filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
-strict (bool): strict load of checkpoint
+model: PyTorch module
+pretrained_cfg: configuration for pretrained weights / target dataset
+num_classes: number of classes for target model
+in_chans: number of input chans for target model
+filter_fn: state_dict filter fn for load (takes state_dict, model as args)
+strict: strict load of checkpoint
+cache_dir: override path to cache dir for this load
"""
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
if not pretrained_cfg:
@@ -173,6 +177,7 @@ def load_pretrained(
pretrained_loc,
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
+cache_dir=cache_dir,
)
model.load_pretrained(pretrained_loc)
return
@@ -184,6 +189,7 @@ def load_pretrained(
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
weights_only=True,
+model_dir=cache_dir,
)
except TypeError:
state_dict = load_state_dict_from_url(
@@ -191,18 +197,19 @@
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
+model_dir=cache_dir,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(custom_load, str) and custom_load == 'hf':
-load_custom_from_hf(*pretrained_loc, model)
+load_custom_from_hf(*pretrained_loc, model, cache_dir=cache_dir)
return
else:
-state_dict = load_state_dict_from_hf(*pretrained_loc)
+state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
else:
-state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True)
+state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
else:
model_name = pretrained_cfg.get('architecture', 'this model')
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
@@ -362,6 +369,7 @@ def build_model_with_cfg(
feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None,
+cache_dir: Optional[Union[str, Path]] = None,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs,
):
@@ -382,6 +390,7 @@
feature_cfg: feature extraction adapter config
pretrained_strict: load pretrained weights strictly
pretrained_filter_fn: filter callable for pretrained weights
+cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
kwargs_filter: kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
@@ -431,6 +440,7 @@
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
+cache_dir=cache_dir,
)
# Wrap the model in a feature extraction module if enabled
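
For model implementers nothing extra is required: `cache_dir` arrives at the entrypoint via `**kwargs` and binds to the new explicit parameter of `build_model_with_cfg`, which forwards it to `load_pretrained`. A toy sketch assuming a hypothetical `TinyNet` architecture and variant name:

```python
import torch.nn as nn
from timm.models import build_model_with_cfg

class TinyNet(nn.Module):  # hypothetical stand-in for a real architecture
    def __init__(self, num_classes: int = 1000, in_chans: int = 3):
        super().__init__()
        self.head = nn.Linear(in_chans, num_classes)

def tinynet(pretrained: bool = False, **kwargs):
    # a cache_dir passed down by create_model rides along in **kwargs
    # and binds to build_model_with_cfg's new cache_dir parameter
    return build_model_with_cfg(TinyNet, 'tinynet', pretrained, **kwargs)
```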

timm/models/_factory.py

@@ -1,4 +1,5 @@
import os
+from pathlib import Path
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit
@@ -40,7 +41,8 @@ def create_model(
pretrained: bool = False,
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
-checkpoint_path: str = '',
+checkpoint_path: Optional[Union[str, Path]] = None,
+cache_dir: Optional[Union[str, Path]] = None,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
@@ -50,10 +52,9 @@
Lookup model's entrypoint function and pass relevant args to create a new model.
-<Tip>
+Tip:
**kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
and then the model class __init__(). kwargs values set to None are pruned before passing.
-</Tip>
Args:
model_name: Name of model to instantiate.
@@ -61,6 +62,7 @@
pretrained_cfg: Pass in an external pretrained_cfg for model.
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
@@ -99,7 +101,10 @@
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
# load model weights + pretrained_cfg from Hugging Face hub.
-pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+model_name,
+cache_dir=cache_dir,
+)
if model_args:
for k, v in model_args.items():
kwargs.setdefault(k, v)
@@ -118,6 +123,7 @@
pretrained=pretrained,
pretrained_cfg=pretrained_cfg,
pretrained_cfg_overlay=pretrained_cfg_overlay,
+cache_dir=cache_dir,
**kwargs,
)
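
On the `hf-hub:` path the override now covers the `config.json` fetch as well as the weights. A sketch using one of timm's published checkpoints:

```python
import timm

# Both config.json and the weight file are fetched via hf_hub_download
# with cache_dir set, bypassing the default ~/.cache/huggingface cache.
model = timm.create_model(
    'hf-hub:timm/resnet50.a1_in1k',
    pretrained=True,
    cache_dir='/data/hf-cache',
)
```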

timm/models/_helpers.py

@@ -4,7 +4,6 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
-from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Union
import torch

timm/models/_hub.py

@@ -5,7 +5,7 @@ import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -53,7 +53,7 @@ HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
-def get_cache_dir(child_dir=''):
+def get_cache_dir(child_dir: str = ''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
@@ -68,13 +68,22 @@ def get_cache_dir(child_dir=''):
return model_dir
-def download_cached_file(url, check_hash=True, progress=False):
+def download_cached_file(
+url: Union[str, List[str], Tuple[str, str]],
+check_hash: bool = True,
+progress: bool = False,
+cache_dir: Optional[Union[str, Path]] = None,
+):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
-cached_file = os.path.join(get_cache_dir(), filename)
+if cache_dir:
+os.makedirs(cache_dir, exist_ok=True)
+else:
+cache_dir = get_cache_dir()
+cached_file = os.path.join(cache_dir, filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
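
Called directly, the updated helper creates `cache_dir` on demand, where previously files always landed under `get_cache_dir()`. A sketch with a placeholder URL, assuming the module path `timm.models._hub`:

```python
from timm.models._hub import download_cached_file

# check_hash=False because the placeholder filename carries no hash suffix
path = download_cached_file(
    'https://example.com/checkpoints/model.pth',  # placeholder URL
    check_hash=False,
    cache_dir='./local-weights',  # created via os.makedirs if missing
)
```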
@@ -85,13 +94,19 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file
-def check_cached_file(url, check_hash=True):
+def check_cached_file(
+url: Union[str, List[str], Tuple[str, str]],
+check_hash: bool = True,
+cache_dir: Optional[Union[str, Path]] = None,
+):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
-cached_file = os.path.join(get_cache_dir(), filename)
+if not cache_dir:
+cache_dir = get_cache_dir()
+cached_file = os.path.join(cache_dir, filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
@@ -105,7 +120,7 @@ def check_cached_file(url, check_hash=True):
return False
-def has_hf_hub(necessary=False):
+def has_hf_hub(necessary: bool = False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
raise RuntimeError(
@@ -122,20 +137,32 @@ def hf_split(hf_id: str):
return hf_model_id, hf_revision
-def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+def load_cfg_from_json(json_file: Union[str, Path]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
-def download_from_hf(model_id: str, filename: str):
+def download_from_hf(
+model_id: str,
+filename: str,
+cache_dir: Optional[Union[str, Path]] = None,
+):
hf_model_id, hf_revision = hf_split(model_id)
-return hf_hub_download(hf_model_id, filename, revision=hf_revision)
+return hf_hub_download(
+hf_model_id,
+filename,
+revision=hf_revision,
+cache_dir=cache_dir,
+)
-def load_model_config_from_hf(model_id: str):
+def load_model_config_from_hf(
+model_id: str,
+cache_dir: Optional[Union[str, Path]] = None,
+):
assert has_hf_hub(True)
-cached_file = download_from_hf(model_id, 'config.json')
+cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
@@ -172,6 +199,7 @@ def load_state_dict_from_hf(
model_id: str,
filename: str = HF_WEIGHTS_NAME,
weights_only: bool = False,
+cache_dir: Optional[Union[str, Path]] = None,
):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
@@ -180,7 +208,12 @@
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
-cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+cached_safe_file = hf_hub_download(
+repo_id=hf_model_id,
+filename=safe_filename,
+revision=hf_revision,
+cache_dir=cache_dir,
+)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
@@ -189,7 +222,12 @@
pass
# Otherwise, load using pytorch.load
-cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+cached_file = hf_hub_download(
+hf_model_id,
+filename=filename,
+revision=hf_revision,
+cache_dir=cache_dir,
+)
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
try:
state_dict = torch.load(cached_file, map_location='cpu', weights_only=weights_only)
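
The hub loader likewise hands the override straight to `hf_hub_download`, still preferring a safetensors alternative when one exists. A sketch (repo name is a published timm checkpoint, module path assumed as above):

```python
from timm.models._hub import load_state_dict_from_hf

# downloaded into cache_dir; a .safetensors file is used when available
state_dict = load_state_dict_from_hf(
    'timm/resnet50.a1_in1k',
    cache_dir='/data/hf-cache',
)
```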
@@ -198,15 +236,25 @@
return state_dict
-def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+def load_custom_from_hf(
+model_id: str,
+filename: str,
+model: torch.nn.Module,
+cache_dir: Optional[Union[str, Path]] = None,
+):
assert has_hf_hub(True)
hf_model_id, hf_revision = hf_split(model_id)
-cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+cached_file = hf_hub_download(
+hf_model_id,
+filename=filename,
+revision=hf_revision,
+cache_dir=cache_dir,
+)
return model.load_pretrained(cached_file)
def save_config_for_hf(
-model,
+model: torch.nn.Module,
config_path: str,
model_config: Optional[dict] = None,
model_args: Optional[dict] = None
@@ -255,7 +303,7 @@ def save_config_for_hf(
def save_for_hf(
-model,
+model: torch.nn.Module,
save_directory: str,
model_config: Optional[dict] = None,
model_args: Optional[dict] = None,