Add local-dir: schema support for model loading (config + weights) from folder

pull/2476/head
Ross Wightman 2025-04-17 10:32:48 -07:00
parent ceca5efdec
commit fe353419af
3 changed files with 126 additions and 44 deletions
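For context, a minimal usage sketch of the new scheme (not part of the diff; the directory path is illustrative, and the folder is assumed to hold a config.json plus a weights file such as model.safetensors):

import timm

model = timm.create_model('local-dir:/path/to/exported_model', pretrained=True)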

timm/models/_builder.py

@@ -11,8 +11,8 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
-    load_custom_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
+    load_state_dict_from_path, load_custom_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
+    elif cfg_source == 'local-dir':
+        load_from = 'local-dir'
+        pretrained_loc = pretrained_file
    else:
        # default source == timm or unspecified
        if pretrained_sd:
@@ -211,6 +214,13 @@ def load_pretrained(
                state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
+    elif load_from == 'local-dir':
+        _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
+        pretrained_path = Path(pretrained_loc)
+        if pretrained_path.is_dir():
+            state_dict = load_state_dict_from_path(pretrained_path)
+        else:
+            raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
    else:
        model_name = pretrained_cfg.get('architecture', 'this model')
        raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

timm/models/_factory.py

@@ -5,7 +5,7 @@ from urllib.parse import urlsplit
from timm.layers import set_layer_config
from ._helpers import load_checkpoint
-from ._hub import load_model_config_from_hf
+from ._hub import load_model_config_from_hf, load_model_config_from_path
from ._pretrained import PretrainedCfg
from ._registry import is_model, model_entrypoint, split_model_name_tag
@@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
        # NOTE for backwards compat, deprecate hf_hub use
        model_name = model_name.replace('hf_hub', 'hf-hub')
    parsed = urlsplit(model_name)
-    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    assert parsed.scheme in ('', 'hf-hub', 'local-dir')
    if parsed.scheme == 'hf-hub':
        # FIXME may use fragment as revision, currently `@` in URI path
        return parsed.scheme, parsed.path
+    elif parsed.scheme == 'local-dir':
+        return parsed.scheme, parsed.path
    else:
        model_name = os.path.split(parsed.path)[-1]
-        return 'timm', model_name
+        return None, model_name


def safe_model_name(model_name: str, remove_source: bool = True):
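A quick sketch of what the updated parser returns for each scheme (the model names below are illustrative, not from the diff):

from timm.models._factory import parse_model_name

parse_model_name('hf-hub:timm/resnet50.a1_in1k')  # -> ('hf-hub', 'timm/resnet50.a1_in1k')
parse_model_name('local-dir:/models/my_export')   # -> ('local-dir', '/models/my_export')
parse_model_name('resnet50.a1_in1k')              # -> (None, 'resnet50.a1_in1k')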
@@ -100,20 +102,27 @@
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    model_source, model_name = parse_model_name(model_name)
-    if model_source == 'hf-hub':
+    model_source, model_id = parse_model_name(model_name)
+    if model_source:
        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
-            model_name,
-            cache_dir=cache_dir,
-        )
+        if model_source == 'hf-hub':
+            # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+            # load model weights + pretrained_cfg from Hugging Face hub.
+            pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+                model_id,
+                cache_dir=cache_dir,
+            )
+        elif model_source == 'local-dir':
+            pretrained_cfg, model_name, model_args = load_model_config_from_path(
+                model_id,
+            )
+        else:
+            assert False, f'Unknown model_source {model_source}'
        if model_args:
            for k, v in model_args.items():
                kwargs.setdefault(k, v)
    else:
-        model_name, pretrained_tag = split_model_name_tag(model_name)
+        model_name, pretrained_tag = split_model_name_tag(model_id)
        if pretrained_tag and not pretrained_cfg:
            # a valid pretrained_cfg argument takes priority over tag in model name
            pretrained_cfg = pretrained_tag
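A small sketch of the kwargs merge in the branch above (values illustrative): model_args loaded from the config only fill in keys the caller did not pass explicitly.

kwargs = {'num_classes': 10}                          # caller-supplied
model_args = {'drop_rate': 0.1, 'num_classes': 1000}  # e.g. from the model's config.json
for k, v in model_args.items():
    kwargs.setdefault(k, v)
# kwargs -> {'num_classes': 10, 'drop_rate': 0.1}; explicit arguments win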

timm/models/_hub.py

@@ -5,7 +5,7 @@ import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -157,42 +157,60 @@ def download_from_hf(
    )


+def _parse_model_cfg(
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
+) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
+    """Split a loaded config.json into (pretrained_cfg, model_name, model_args)."""
+    # legacy "single dict" form -> split form
+    if "pretrained_cfg" not in cfg:
+        pretrained_cfg = cfg
+        cfg = {
+            "architecture": pretrained_cfg.pop("architecture"),
+            "num_features": pretrained_cfg.pop("num_features", None),
+            "pretrained_cfg": pretrained_cfg,
+        }
+        if "labels" in pretrained_cfg:  # rename -> label_names
+            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
+
+    pretrained_cfg = cfg["pretrained_cfg"]
+    pretrained_cfg.update(extra_fields)
+
+    # top-level overrides
+    if "num_classes" in cfg:
+        pretrained_cfg["num_classes"] = cfg["num_classes"]
+    if "label_names" in cfg:
+        pretrained_cfg["label_names"] = cfg.pop("label_names")
+    if "label_descriptions" in cfg:
+        pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
+
+    model_args = cfg.get("model_args", {})
+    model_name = cfg["architecture"]
+    return pretrained_cfg, model_name, model_args


def load_model_config_from_hf(
        model_id: str,
        cache_dir: Optional[Union[str, Path]] = None,
):
    """Original HF Hub config loader: the download is unchanged, parsing is now shared via _parse_model_cfg."""
    assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
+    cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
+    cfg = load_cfg_from_json(cfg_path)
+    return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
-    hf_config = load_cfg_from_json(cached_file)
-    if 'pretrained_cfg' not in hf_config:
-        # old form, pull pretrain_cfg out of the base dict
-        pretrained_cfg = hf_config
-        hf_config = {}
-        hf_config['architecture'] = pretrained_cfg.pop('architecture')
-        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
-            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
-        hf_config['pretrained_cfg'] = pretrained_cfg
-    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
-    pretrained_cfg = hf_config['pretrained_cfg']
-    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
-    pretrained_cfg['source'] = 'hf-hub'
-    # model should be created with base config num_classes if it exists
-    if 'num_classes' in hf_config:
-        pretrained_cfg['num_classes'] = hf_config['num_classes']
-    # label meta-data in base config overrides saved pretrained_cfg on load
-    if 'label_names' in hf_config:
-        pretrained_cfg['label_names'] = hf_config.pop('label_names')
-    if 'label_descriptions' in hf_config:
-        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
-    model_args = hf_config.get('model_args', {})
-    model_name = hf_config['architecture']
-    return pretrained_cfg, model_name, model_args


+def load_model_config_from_path(
+        model_path: Union[str, Path],
+):
+    """Load from ``<model_path>/config.json`` on the local filesystem."""
+    model_path = Path(model_path)
+    cfg_file = model_path / "config.json"
+    if not cfg_file.is_file():
+        raise FileNotFoundError(f"Config file not found: {cfg_file}")
+    cfg = load_cfg_from_json(cfg_file)
+    extra_fields = {"file": str(model_path), "source": "local-dir"}
+    return _parse_model_cfg(cfg, extra_fields=extra_fields)

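For orientation, a sketch of the kind of config.json such a local directory might carry, limited to fields _parse_model_cfg above actually consumes; all names and values are illustrative:

import json
from pathlib import Path

export_dir = Path('/path/to/exported_model')  # illustrative
cfg = {
    'architecture': 'resnet50',
    'num_features': 2048,
    'num_classes': 10,  # top-level override applied to pretrained_cfg
    'pretrained_cfg': {'input_size': [3, 224, 224], 'interpolation': 'bicubic'},
    'model_args': {'drop_rate': 0.1},  # merged into create_model kwargs via setdefault
}
(export_dir / 'config.json').write_text(json.dumps(cfg, indent=2))
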
def load_state_dict_from_hf(
@@ -236,6 +254,51 @@ def load_state_dict_from_hf(
    return state_dict


+_PREFERRED_FILES = (
+    "model.safetensors",
+    "pytorch_model.bin",
+    "pytorch_model.pth",
+    "model.pth",
+    "open_clip_model.safetensors",
+    "open_clip_pytorch_model.safetensors",
+    "open_clip_pytorch_model.bin",
+    "open_clip_pytorch_model.pth",
+)
+_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
+
+
+def load_state_dict_from_path(
+        path: Union[str, Path],
+        weights_only: bool = False,
+):
+    path = Path(path)
+    found_file = None
+
+    # first pass: look for well-known checkpoint file names
+    for fname in _PREFERRED_FILES:
+        p = path / fname
+        if p.exists():
+            _logger.info(f"Found preferred checkpoint: {p.name}")
+            found_file = p
+            break
+
+    # fallback: first match per extension, in priority order
+    if not found_file:
+        for ext in _EXT_PRIORITY:
+            files = sorted(path.glob(f"*{ext}"))
+            if files:
+                if len(files) > 1:
+                    _logger.warning(
+                        f"Multiple {ext} checkpoints in {path}: {[f.name for f in files]}. "
+                        f"Using '{files[0].name}'."
+                    )
+                found_file = files[0]
+                break
+
+    if not found_file:
+        raise RuntimeError(f"No suitable checkpoint found in {path}.")
+
+    if found_file.suffix == '.safetensors':
+        # .safetensors checkpoints need the safetensors loader, torch.load handles the rest
+        from safetensors.torch import load_file
+        state_dict = load_file(found_file, device='cpu')
+    else:
+        try:
+            state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
+        except TypeError:
+            # older torch without the `weights_only` argument
+            state_dict = torch.load(found_file, map_location='cpu')
+    return state_dict


def load_custom_from_hf(
        model_id: str,
        filename: str,