diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index 482d370a..2b34ead6 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -11,8 +11,8 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
-    load_custom_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
+    load_state_dict_from_path, load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
         load_from = 'hf-hub'
         assert hf_hub_id
         pretrained_loc = hf_hub_id
+    elif cfg_source == 'local-dir':
+        load_from = 'local-dir'
+        pretrained_loc = pretrained_file
     else:
         # default source == timm or unspecified
         if pretrained_sd:
@@ -211,6 +214,13 @@ def load_pretrained(
             state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
+    elif load_from == 'local-dir':
+        _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
+        pretrained_path = Path(pretrained_loc)
+        if pretrained_path.is_dir():
+            state_dict = load_state_dict_from_path(pretrained_path)
+        else:
+            raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
     else:
         model_name = pretrained_cfg.get('architecture', 'this model')
         raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")
diff --git a/timm/models/_factory.py b/timm/models/_factory.py
index b347bc4d..63e897e5 100644
--- a/timm/models/_factory.py
+++ b/timm/models/_factory.py
@@ -5,7 +5,7 @@ from urllib.parse import urlsplit
 
 from timm.layers import set_layer_config
 from ._helpers import load_checkpoint
-from ._hub import load_model_config_from_hf
+from ._hub import load_model_config_from_hf, load_model_config_from_path
 from ._pretrained import PretrainedCfg
 from ._registry import is_model, model_entrypoint, split_model_name_tag
 
@@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
         # NOTE for backwards compat, deprecate hf_hub use
         model_name = model_name.replace('hf_hub', 'hf-hub')
     parsed = urlsplit(model_name)
-    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    assert parsed.scheme in ('', 'hf-hub', 'local-dir')
     if parsed.scheme == 'hf-hub':
         # FIXME may use fragment as revision, currently `@` in URI path
         return parsed.scheme, parsed.path
+    elif parsed.scheme == 'local-dir':
+        return parsed.scheme, parsed.path
     else:
         model_name = os.path.split(parsed.path)[-1]
-        return 'timm', model_name
+        return None, model_name
 
 
 def safe_model_name(model_name: str, remove_source: bool = True):
@@ -100,20 +102,27 @@
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
-    model_source, model_name = parse_model_name(model_name)
-    if model_source == 'hf-hub':
+    model_source, model_id = parse_model_name(model_name)
+    if model_source:
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
-            model_name,
-            cache_dir=cache_dir,
-        )
+        if model_source == 'hf-hub':
+            # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+            # load model weights + pretrained_cfg from Hugging Face hub.
+            pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+                model_id,
+                cache_dir=cache_dir,
+            )
+        elif model_source == 'local-dir':
+            pretrained_cfg, model_name, model_args = load_model_config_from_path(
+                model_id,
+            )
+        else:
+            assert False, f'Unknown model_source {model_source}'
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
     else:
-        model_name, pretrained_tag = split_model_name_tag(model_name)
+        model_name, pretrained_tag = split_model_name_tag(model_id)
         if pretrained_tag and not pretrained_cfg:
             # a valid pretrained_cfg argument takes priority over tag in model name
             pretrained_cfg = pretrained_tag
diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 1da0942b..408d2b8f 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -5,7 +5,7 @@
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -157,42 +157,60 @@ def download_from_hf(
     )
 
 
+def _parse_model_cfg(
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
+) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
+    """Split a loaded model config dict into (pretrained_cfg, model_name, model_args)."""
+    # legacy single-dict config: pull architecture / num_features out of the flat dict
+    if "pretrained_cfg" not in cfg:
+        pretrained_cfg = cfg
+        cfg = {
+            "architecture": pretrained_cfg.pop("architecture"),
+            "num_features": pretrained_cfg.pop("num_features", None),
+            "pretrained_cfg": pretrained_cfg,
+        }
+        if "labels" in pretrained_cfg:  # deprecated name, rename to label_names
+            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
+
+    pretrained_cfg = cfg["pretrained_cfg"]
+    pretrained_cfg.update(extra_fields)
+
+    # top-level config fields override those in pretrained_cfg
+    if "num_classes" in cfg:
+        pretrained_cfg["num_classes"] = cfg["num_classes"]
+    if "label_names" in cfg:
+        pretrained_cfg["label_names"] = cfg.pop("label_names")
+    if "label_descriptions" in cfg:
+        pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
+
+    model_args = cfg.get("model_args", {})
+    model_name = cfg["architecture"]
+    return pretrained_cfg, model_name, model_args
+
+
 def load_model_config_from_hf(
         model_id: str,
         cache_dir: Optional[Union[str, Path]] = None,
 ):
+    """Load model config from a model repo's config.json on the Hugging Face Hub."""
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
+    cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
+    cfg = load_cfg_from_json(cfg_path)
+    return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
 
-    hf_config = load_cfg_from_json(cached_file)
-    if 'pretrained_cfg' not in hf_config:
-        # old form, pull pretrain_cfg out of the base dict
-        pretrained_cfg = hf_config
-        hf_config = {}
-        hf_config['architecture'] = pretrained_cfg.pop('architecture')
-        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
-            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
-        hf_config['pretrained_cfg'] = pretrained_cfg
-    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
-    pretrained_cfg = hf_config['pretrained_cfg']
-    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
-    pretrained_cfg['source'] = 'hf-hub'
-
-    # model should be created with base config num_classes if its exist
-    if 'num_classes' in hf_config:
-        pretrained_cfg['num_classes'] = hf_config['num_classes']
-
-    # label meta-data in base config overrides saved pretrained_cfg on load
-    if 'label_names' in hf_config:
-        pretrained_cfg['label_names'] = hf_config.pop('label_names')
-    if 'label_descriptions' in hf_config:
-        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
-
-    model_args = hf_config.get('model_args', {})
-    model_name = hf_config['architecture']
-    return pretrained_cfg, model_name, model_args
+
+def load_model_config_from_path(
+        model_path: Union[str, Path],
+):
+    """Load model config from a config.json file in a local directory."""
+    model_path = Path(model_path)
+    cfg_file = model_path / "config.json"
+    if not cfg_file.is_file():
+        raise FileNotFoundError(f"Config file not found: {cfg_file}")
+    cfg = load_cfg_from_json(cfg_file)
+    extra_fields = {"file": str(model_path), "source": "local-dir"}
+    return _parse_model_cfg(cfg, extra_fields=extra_fields)
 
 
 def load_state_dict_from_hf(
@@ -236,6 +254,56 @@
     return state_dict
 
 
+_PREFERRED_FILES = (
+    "model.safetensors",
+    "pytorch_model.bin",
+    "pytorch_model.pth",
+    "model.pth",
+    "open_clip_model.safetensors",
+    "open_clip_pytorch_model.safetensors",
+    "open_clip_pytorch_model.bin",
+    "open_clip_pytorch_model.pth",
+)
+_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
+
+
+def load_state_dict_from_path(
+        path: Union[str, Path],
+        weights_only: bool = False,
+):
+    path = Path(path)
+    found_file = None
+    for fname in _PREFERRED_FILES:
+        p = path / fname
+        if p.exists():
+            _logger.info(f"Found preferred checkpoint: {p.name}")
+            found_file = p
+            break
+
+    # fallback: first match per extension, in priority order
+    if not found_file:
+        for ext in _EXT_PRIORITY:
+            files = sorted(path.glob(f"*{ext}"))
+            if files:
+                if len(files) > 1:
+                    _logger.warning(
+                        f"Multiple {ext} checkpoints in {path}: {[f.name for f in files]}. "
+                        f"Using '{files[0].name}'."
+                    )
+                found_file = files[0]
+                break
+
+    if not found_file:
+        raise RuntimeError(f"No suitable checkpoints found in {path}.")
+
+    try:
+        state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        # older torch versions don't accept the `weights_only` argument
+        state_dict = torch.load(found_file, map_location='cpu')
+    return state_dict
+
+
 def load_custom_from_hf(
         model_id: str,
         filename: str,
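
Usage sketch (not part of the patch): with these changes, create_model accepts a 'local-dir:' scheme that reads config.json and a checkpoint (e.g. model.safetensors or pytorch_model.bin) from a local folder, then loads the weights from that same folder when pretrained=True. The directory path below is hypothetical.

    import timm

    # hypothetical local folder containing config.json + model.safetensors
    model = timm.create_model('local-dir:./exported/vit_base_patch16_224', pretrained=True)
    model.eval()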