Add local-dir: schema support for model loading (config + weights) from folder

Ross Wightman 2025-04-17 10:32:48 -07:00
parent ceca5efdec
commit fe353419af
3 changed files with 126 additions and 44 deletions
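For context, a rough usage sketch of the new scheme (the directory path and file names are illustrative, not part of the commit): a local folder holding a config.json plus a weights file can now be passed to create_model() with a local-dir: prefix, mirroring the existing hf-hub: prefix.

    import timm

    # Illustrative layout:
    #   /models/resnet50_export/config.json        (architecture + pretrained_cfg)
    #   /models/resnet50_export/model.safetensors  (weights)
    model = timm.create_model(
        'local-dir:/models/resnet50_export',
        pretrained=True,
    )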


@@ -11,8 +11,8 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
-    load_custom_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
+    load_state_dict_from_path, load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
         load_from = 'hf-hub'
         assert hf_hub_id
         pretrained_loc = hf_hub_id
+    elif cfg_source == 'local-dir':
+        load_from = 'local-dir'
+        pretrained_loc = pretrained_file
     else:
         # default source == timm or unspecified
         if pretrained_sd:
@@ -211,6 +214,13 @@ def load_pretrained(
                 state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
+    elif load_from == 'local-dir':
+        _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
+        pretrained_path = Path(pretrained_loc)
+        if pretrained_path.is_dir():
+            state_dict = load_state_dict_from_path(pretrained_path)
+        else:
+            raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")


@@ -5,7 +5,7 @@ from urllib.parse import urlsplit
 
 from timm.layers import set_layer_config
 from ._helpers import load_checkpoint
-from ._hub import load_model_config_from_hf
+from ._hub import load_model_config_from_hf, load_model_config_from_path
 from ._pretrained import PretrainedCfg
 from ._registry import is_model, model_entrypoint, split_model_name_tag
@@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
     # NOTE for backwards compat, deprecate hf_hub use
     model_name = model_name.replace('hf_hub', 'hf-hub')
     parsed = urlsplit(model_name)
-    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    assert parsed.scheme in ('', 'hf-hub', 'local-dir')
     if parsed.scheme == 'hf-hub':
         # FIXME may use fragment as revision, currently `@` in URI path
         return parsed.scheme, parsed.path
+    elif parsed.scheme == 'local-dir':
+        return parsed.scheme, parsed.path
     else:
         model_name = os.path.split(parsed.path)[-1]
-        return 'timm', model_name
+        return None, model_name
 
 
 def safe_model_name(model_name: str, remove_source: bool = True):
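A quick sketch of the return values implied by the branches above (inputs are illustrative):

    from timm.models._factory import parse_model_name

    parse_model_name('hf-hub:timm/resnet50.a1_in1k')        # ('hf-hub', 'timm/resnet50.a1_in1k')
    parse_model_name('local-dir:/models/resnet50_export')   # ('local-dir', '/models/resnet50_export')
    parse_model_name('resnet50.a1_in1k')                     # (None, 'resnet50.a1_in1k')  -- was ('timm', ...) before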
@@ -100,20 +102,27 @@ def create_model(
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
-    model_source, model_name = parse_model_name(model_name)
-    if model_source == 'hf-hub':
+    model_source, model_id = parse_model_name(model_name)
+    if model_source:
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
-            model_name,
-            cache_dir=cache_dir,
-        )
+        if model_source == 'hf-hub':
+            # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+            # load model weights + pretrained_cfg from Hugging Face hub.
+            pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+                model_id,
+                cache_dir=cache_dir,
+            )
+        elif model_source == 'local-dir':
+            pretrained_cfg, model_name, model_args = load_model_config_from_path(
+                model_id,
+            )
+        else:
+            assert False, f'Unknown model_source {model_source}'
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
     else:
-        model_name, pretrained_tag = split_model_name_tag(model_name)
+        model_name, pretrained_tag = split_model_name_tag(model_id)
         if pretrained_tag and not pretrained_cfg:
             # a valid pretrained_cfg argument takes priority over tag in model name
             pretrained_cfg = pretrained_tag
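For plain names (model_source is None) the tag handling is unchanged; split_model_name_tag() splits on the first '.' (examples illustrative):

    from timm.models._registry import split_model_name_tag

    split_model_name_tag('vit_base_patch16_224.augreg_in21k')  # ('vit_base_patch16_224', 'augreg_in21k')
    split_model_name_tag('resnet50')                            # ('resnet50', '')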


@@ -5,7 +5,7 @@ import os
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -157,42 +157,60 @@ def download_from_hf(
     )
 
 
+def _parse_model_cfg(
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
+) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
+    """Split a loaded config.json dict into (pretrained_cfg, model_name, model_args)."""
+    # legacy "single-dict" form -> split into architecture / pretrained_cfg
+    if "pretrained_cfg" not in cfg:
+        pretrained_cfg = cfg
+        cfg = {
+            "architecture": pretrained_cfg.pop("architecture"),
+            "num_features": pretrained_cfg.pop("num_features", None),
+            "pretrained_cfg": pretrained_cfg,
+        }
+        if "labels" in pretrained_cfg:  # rename -> label_names
+            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
+
+    pretrained_cfg = cfg["pretrained_cfg"]
+    pretrained_cfg.update(extra_fields)
+
+    # top-level overrides
+    if "num_classes" in cfg:
+        pretrained_cfg["num_classes"] = cfg["num_classes"]
+    if "label_names" in cfg:
+        pretrained_cfg["label_names"] = cfg.pop("label_names")
+    if "label_descriptions" in cfg:
+        pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
+
+    model_args = cfg.get("model_args", {})
+    model_name = cfg["architecture"]
+    return pretrained_cfg, model_name, model_args
+
+
 def load_model_config_from_hf(
         model_id: str,
         cache_dir: Optional[Union[str, Path]] = None,
 ):
+    """Original HF Hub loader (unchanged download, shared parsing)."""
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
+    cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
+    cfg = load_cfg_from_json(cfg_path)
+    return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
 
-    hf_config = load_cfg_from_json(cached_file)
-    if 'pretrained_cfg' not in hf_config:
-        # old form, pull pretrain_cfg out of the base dict
-        pretrained_cfg = hf_config
-        hf_config = {}
-        hf_config['architecture'] = pretrained_cfg.pop('architecture')
-        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
-            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
-        hf_config['pretrained_cfg'] = pretrained_cfg
 
-    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
-    pretrained_cfg = hf_config['pretrained_cfg']
-    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
-    pretrained_cfg['source'] = 'hf-hub'
-
-    # model should be created with base config num_classes if its exist
-    if 'num_classes' in hf_config:
-        pretrained_cfg['num_classes'] = hf_config['num_classes']
-
-    # label meta-data in base config overrides saved pretrained_cfg on load
-    if 'label_names' in hf_config:
-        pretrained_cfg['label_names'] = hf_config.pop('label_names')
-    if 'label_descriptions' in hf_config:
-        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
-
-    model_args = hf_config.get('model_args', {})
-    model_name = hf_config['architecture']
-    return pretrained_cfg, model_name, model_args
+def load_model_config_from_path(
+        model_path: Union[str, Path],
+):
+    """Load from ``<model_path>/config.json`` on the local filesystem."""
+    model_path = Path(model_path)
+    cfg_file = model_path / "config.json"
+    if not cfg_file.is_file():
+        raise FileNotFoundError(f"Config file not found: {cfg_file}")
+    cfg = load_cfg_from_json(cfg_file)
+    extra_fields = {"file": str(model_path), "source": "local-dir"}
+    return _parse_model_cfg(cfg, extra_fields=extra_fields)
 
 
 def load_state_dict_from_hf(
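An illustrative local config.json (new two-level form) and the resulting call; the directory path and field values are made up for the sketch:

    from timm.models._hub import load_model_config_from_path

    # /models/resnet50_export/config.json (illustrative):
    # {
    #     "architecture": "resnet50",
    #     "num_classes": 1000,
    #     "pretrained_cfg": {"input_size": [3, 224, 224], "crop_pct": 0.95}
    # }
    pretrained_cfg, model_name, model_args = load_model_config_from_path('/models/resnet50_export')
    # pretrained_cfg now also carries source='local-dir' and file='/models/resnet50_export',
    # which _resolve_pretrained_source() uses to route weight loading to the local directory.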
@@ -236,6 +254,51 @@ def load_state_dict_from_hf(
     return state_dict
 
 
+_PREFERRED_FILES = (
+    "model.safetensors",
+    "pytorch_model.bin",
+    "pytorch_model.pth",
+    "model.pth",
+    "open_clip_model.safetensors",
+    "open_clip_pytorch_model.safetensors",
+    "open_clip_pytorch_model.bin",
+    "open_clip_pytorch_model.pth",
+)
+_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
+
+
+def load_state_dict_from_path(
+        path: Union[str, Path],
+        weights_only: bool = False,
+):
+    path = Path(path)
+    found_file = None
+    for fname in _PREFERRED_FILES:
+        p = path / fname
+        if p.exists():
+            logging.info(f"Found preferred checkpoint: {p.name}")
+            found_file = p
+            break
+    if found_file is None:
+        # fallback: first match per extension class, in priority order
+        for ext in _EXT_PRIORITY:
+            files = sorted(path.glob(f"*{ext}"))
+            if files:
+                if len(files) > 1:
+                    names = ", ".join(f.name for f in files)
+                    logging.warning(
+                        f"Multiple {ext} checkpoints in {path}: {names}. "
+                        f"Using '{files[0].name}'."
+                    )
+                found_file = files[0]
+                break
+    if not found_file:
+        raise RuntimeError(f"No suitable checkpoints found in {path}.")
+    try:
+        state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        # older torch without the weights_only argument
+        state_dict = torch.load(found_file, map_location='cpu')
+    return state_dict
+
+
 def load_custom_from_hf(
         model_id: str,
         filename: str,
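A usage sketch for the new directory-based weight loader (path and file name illustrative, not part of the diff):

    from timm.models._hub import load_state_dict_from_path

    # Illustrative folder holding a single checkpoint:
    #   /models/resnet50_export/pytorch_model.bin
    # Preferred file names are checked first; otherwise the first match per
    # extension, in _EXT_PRIORITY order, is used.
    state_dict = load_state_dict_from_path('/models/resnet50_export')
    print(f'{len(state_dict)} tensors loaded')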