Add local-dir: schema support for model loading (config + weights) from folder
parent
ceca5efdec
commit
fe353419af
|
@ -11,8 +11,8 @@ from torch.hub import load_state_dict_from_url
|
|||
from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
|
||||
from timm.models._features_fx import FeatureGraphNet
|
||||
from timm.models._helpers import load_state_dict
|
||||
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
|
||||
load_custom_from_hf
|
||||
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
|
||||
load_state_dict_from_path, load_custom_from_hf
|
||||
from timm.models._manipulate import adapt_input_conv
|
||||
from timm.models._pretrained import PretrainedCfg
|
||||
from timm.models._prune import adapt_model_from_file
|
||||
|
@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
|
|||
load_from = 'hf-hub'
|
||||
assert hf_hub_id
|
||||
pretrained_loc = hf_hub_id
|
||||
elif cfg_source == 'local-dir':
|
||||
load_from = 'local-dir'
|
||||
pretrained_loc = pretrained_file
|
||||
else:
|
||||
# default source == timm or unspecified
|
||||
if pretrained_sd:
|
||||
|
@ -211,6 +214,13 @@ def load_pretrained(
|
|||
state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
|
||||
elif load_from == 'local-dir':
|
||||
_logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
|
||||
pretrained_path = Path(pretrained_loc)
|
||||
if pretrained_path.is_dir():
|
||||
state_dict = load_state_dict_from_path(pretrained_path)
|
||||
else:
|
||||
RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
|
||||
else:
|
||||
model_name = pretrained_cfg.get('architecture', 'this model')
|
||||
raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
|
||||
|
|
|
@ -5,7 +5,7 @@ from urllib.parse import urlsplit
|
|||
|
||||
from timm.layers import set_layer_config
|
||||
from ._helpers import load_checkpoint
|
||||
from ._hub import load_model_config_from_hf
|
||||
from ._hub import load_model_config_from_hf, load_model_config_from_path
|
||||
from ._pretrained import PretrainedCfg
|
||||
from ._registry import is_model, model_entrypoint, split_model_name_tag
|
||||
|
||||
|
@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
|
|||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||
parsed = urlsplit(model_name)
|
||||
assert parsed.scheme in ('', 'timm', 'hf-hub')
|
||||
assert parsed.scheme in ('', 'hf-hub', 'local-dir')
|
||||
if parsed.scheme == 'hf-hub':
|
||||
# FIXME may use fragment as revision, currently `@` in URI path
|
||||
return parsed.scheme, parsed.path
|
||||
elif parsed.scheme == 'local-dir':
|
||||
return parsed.scheme, parsed.path
|
||||
else:
|
||||
model_name = os.path.split(parsed.path)[-1]
|
||||
return 'timm', model_name
|
||||
return None, model_name
|
||||
|
||||
|
||||
def safe_model_name(model_name: str, remove_source: bool = True):
|
||||
|
@ -100,20 +102,27 @@ def create_model(
|
|||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == 'hf-hub':
|
||||
model_source, model_id = parse_model_name(model_name)
|
||||
if model_source:
|
||||
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name, model_args = load_model_config_from_hf(
|
||||
model_name,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
if model_source == 'hf-hub':
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name, model_args = load_model_config_from_hf(
|
||||
model_id,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
elif model_source == 'local-dir':
|
||||
pretrained_cfg, model_name, model_args = load_model_config_from_path(
|
||||
model_id,
|
||||
)
|
||||
else:
|
||||
assert False, f'Unknown model_source {model_source}'
|
||||
if model_args:
|
||||
for k, v in model_args.items():
|
||||
kwargs.setdefault(k, v)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
model_name, pretrained_tag = split_model_name_tag(model_id)
|
||||
if pretrained_tag and not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
|
@ -157,42 +157,60 @@ def download_from_hf(
|
|||
)
|
||||
|
||||
|
||||
def _parse_model_cfg(
|
||||
cfg: Dict[str, Any],
|
||||
extra_fields: Dict[str, Any],
|
||||
) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
|
||||
""""""
|
||||
# legacy "single‑dict" → split
|
||||
if "pretrained_cfg" not in cfg:
|
||||
pretrained_cfg = cfg
|
||||
cfg = {
|
||||
"architecture": pretrained_cfg.pop("architecture"),
|
||||
"num_features": pretrained_cfg.pop("num_features", None),
|
||||
"pretrained_cfg": pretrained_cfg,
|
||||
}
|
||||
if "labels" in pretrained_cfg: # rename ‑‑> label_names
|
||||
pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
|
||||
|
||||
pretrained_cfg = cfg["pretrained_cfg"]
|
||||
pretrained_cfg.update(extra_fields)
|
||||
|
||||
# top‑level overrides
|
||||
if "num_classes" in cfg:
|
||||
pretrained_cfg["num_classes"] = cfg["num_classes"]
|
||||
if "label_names" in cfg:
|
||||
pretrained_cfg["label_names"] = cfg.pop("label_names")
|
||||
if "label_descriptions" in cfg:
|
||||
pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
|
||||
|
||||
model_args = cfg.get("model_args", {})
|
||||
model_name = cfg["architecture"]
|
||||
return pretrained_cfg, model_name, model_args
|
||||
|
||||
|
||||
def load_model_config_from_hf(
|
||||
model_id: str,
|
||||
cache_dir: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""Original HF‑Hub loader (unchanged download, shared parsing)."""
|
||||
assert has_hf_hub(True)
|
||||
cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
|
||||
cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
|
||||
cfg = load_cfg_from_json(cfg_path)
|
||||
return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
|
||||
|
||||
hf_config = load_cfg_from_json(cached_file)
|
||||
if 'pretrained_cfg' not in hf_config:
|
||||
# old form, pull pretrain_cfg out of the base dict
|
||||
pretrained_cfg = hf_config
|
||||
hf_config = {}
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
|
||||
if 'labels' in pretrained_cfg: # deprecated name for 'label_names'
|
||||
pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
|
||||
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
|
||||
pretrained_cfg = hf_config['pretrained_cfg']
|
||||
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
||||
pretrained_cfg['source'] = 'hf-hub'
|
||||
|
||||
# model should be created with base config num_classes if its exist
|
||||
if 'num_classes' in hf_config:
|
||||
pretrained_cfg['num_classes'] = hf_config['num_classes']
|
||||
|
||||
# label meta-data in base config overrides saved pretrained_cfg on load
|
||||
if 'label_names' in hf_config:
|
||||
pretrained_cfg['label_names'] = hf_config.pop('label_names')
|
||||
if 'label_descriptions' in hf_config:
|
||||
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
|
||||
|
||||
model_args = hf_config.get('model_args', {})
|
||||
model_name = hf_config['architecture']
|
||||
return pretrained_cfg, model_name, model_args
|
||||
def load_model_config_from_path(
|
||||
model_path: Union[str, Path],
|
||||
):
|
||||
"""Load from ``<model_path>/config.json`` on the local filesystem."""
|
||||
model_path = Path(model_path)
|
||||
cfg_file = model_path / "config.json"
|
||||
if not cfg_file.is_file():
|
||||
raise FileNotFoundError(f"Config file not found: {cfg_file}")
|
||||
cfg = load_cfg_from_json(cfg_file)
|
||||
extra_fields = {"file": str(model_path), "source": "local-dir"}
|
||||
return _parse_model_cfg(cfg, extra_fields=extra_fields)
|
||||
|
||||
|
||||
def load_state_dict_from_hf(
|
||||
|
@ -236,6 +254,51 @@ def load_state_dict_from_hf(
|
|||
return state_dict
|
||||
|
||||
|
||||
_PREFERRED_FILES = (
|
||||
"model.safetensors",
|
||||
"pytorch_model.bin",
|
||||
"pytorch_model.pth",
|
||||
"model.pth",
|
||||
"open_clip_model.safetensors",
|
||||
"open_clip_pytorch_model.safetensors",
|
||||
"open_clip_pytorch_model.bin",
|
||||
"open_clip_pytorch_model.pth",
|
||||
)
|
||||
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
|
||||
|
||||
def load_state_dict_from_path(
|
||||
path: str,
|
||||
weights_only: bool = False,
|
||||
):
|
||||
found_file = None
|
||||
for fname in _PREFERRED_FILES:
|
||||
p = path / fname
|
||||
if p.exists():
|
||||
logging.info(f"Found preferred checkpoint: {p.name}")
|
||||
found_file = p
|
||||
break
|
||||
|
||||
# fallback: first match per‑extension class
|
||||
for ext in _EXT_PRIORITY:
|
||||
files = sorted(path.glob(f"*{ext}"))
|
||||
if files:
|
||||
if len(files) > 1:
|
||||
logging.warning(
|
||||
f"Multiple {ext} checkpoints in {path}: {names}. "
|
||||
f"Using '{files[0].name}'."
|
||||
)
|
||||
found_file = files[0]
|
||||
|
||||
if not found_file:
|
||||
raise RuntimeError(f"No suitable checkpoints found in {path}.")
|
||||
|
||||
try:
|
||||
state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
|
||||
except TypeError:
|
||||
state_dict = torch.load(found_file, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_custom_from_hf(
|
||||
model_id: str,
|
||||
filename: str,
|
||||
|
|
Loading…
Reference in New Issue