A few minor fixes and bit more cleanup on the huggingface hub integration.
parent
ead80d33c5
commit
45c048ba13
|
@ -1,7 +1,7 @@
|
|||
from .registry import is_model, is_model_in_modules, model_entrypoint
|
||||
from .helpers import load_checkpoint
|
||||
from .layers import set_layer_config
|
||||
from .hub import load_config_from_hf
|
||||
from .hub import load_model_config_from_hf
|
||||
|
||||
|
||||
def split_model_name(model_name):
|
||||
|
@ -67,11 +67,9 @@ def create_model(
|
|||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if source_name == 'hf_hub':
|
||||
# Load model weights + default_cfg from Hugging Face hub.
|
||||
# For model names specified in the form `hf_hub:path/architecture_name#revision`
|
||||
hf_default_cfg = load_config_from_hf(model_name)
|
||||
hf_default_cfg['hf_hub'] = model_name # insert hf_hub id for pretrained weight load during creation
|
||||
model_name = hf_default_cfg.get('architecture')
|
||||
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
|
||||
# load model weights + default_cfg from Hugging Face hub.
|
||||
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
|
||||
|
||||
if is_model(model_name):
|
||||
|
|
|
@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
|
|||
return default_cfg
|
||||
|
||||
|
||||
def overlay_external_default_cfg(kwargs, default_cfg):
|
||||
""" Overlay 'default_cfg' in kwargs on top of default_cfg arg.
|
||||
def overlay_external_default_cfg(default_cfg, kwargs):
|
||||
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
|
||||
"""
|
||||
default_cfg = default_cfg or {}
|
||||
external_default_cfg = kwargs.pop('external_default_cfg', None)
|
||||
if external_default_cfg:
|
||||
default_cfg = deepcopy(default_cfg)
|
||||
default_cfg.pop('url', None) # url should come from external cfg
|
||||
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
|
||||
default_cfg.update(external_default_cfg)
|
||||
return default_cfg
|
||||
|
||||
|
||||
def set_default_kwargs(kwargs, names, default_cfg):
|
||||
|
@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
|
|||
input_size = default_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[:-2])
|
||||
kwargs.setdefault(n, input_size[-2:])
|
||||
elif n == 'in_chans':
|
||||
input_size = default_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
|
@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
|
|||
kwargs.pop(n, None)
|
||||
|
||||
|
||||
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
|
||||
""" Update the default_cfg and kwargs before passing to model
|
||||
|
||||
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
|
||||
could/should be replaced by an improved configuration mechanism
|
||||
|
||||
Args:
|
||||
default_cfg: input default_cfg (updated in-place)
|
||||
kwargs: keyword args passed to model build fn (updated in-place)
|
||||
kwargs_filter: keyword arg keys that must be removed before model __init__
|
||||
"""
|
||||
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
||||
overlay_external_default_cfg(default_cfg, kwargs)
|
||||
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
||||
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
|
||||
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
||||
filter_kwargs(kwargs, names=kwargs_filter)
|
||||
|
||||
|
||||
def build_model_with_cfg(
|
||||
model_cls: Callable,
|
||||
variant: str,
|
||||
|
@ -399,29 +415,20 @@ def build_model_with_cfg(
|
|||
pruned = kwargs.pop('pruned', False)
|
||||
features = False
|
||||
feature_cfg = feature_cfg or {}
|
||||
default_cfg = deepcopy(default_cfg) if default_cfg else {}
|
||||
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
|
||||
default_cfg.setdefault('architecture', variant)
|
||||
|
||||
# Setup for featyre extraction wrapper done at end of this fn
|
||||
# Setup for feature extraction wrapper done at end of this fn
|
||||
if kwargs.pop('features_only', False):
|
||||
features = True
|
||||
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
||||
if 'out_indices' in kwargs:
|
||||
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
||||
|
||||
# FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
|
||||
# could/should be replaced by an improved configuration mechanism
|
||||
|
||||
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
||||
default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
|
||||
|
||||
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
||||
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
|
||||
|
||||
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
||||
filter_kwargs(kwargs, names=kwargs_filter)
|
||||
|
||||
# Build the model
|
||||
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
||||
model.default_cfg = deepcopy(default_cfg)
|
||||
model.default_cfg = default_cfg
|
||||
|
||||
if pruned:
|
||||
model = adapt_model_from_file(model, variant)
|
||||
|
|
|
@ -23,7 +23,7 @@ except ImportError:
|
|||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_cache_dir(child=''):
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
Returns the location of the directory where models are cached (and creates it if necessary).
|
||||
"""
|
||||
|
@ -32,8 +32,8 @@ def get_cache_dir(child=''):
|
|||
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
||||
|
||||
hub_dir = get_dir()
|
||||
children = () if not child else child,
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints', *children)
|
||||
child_dir = () if not child_dir else (child_dir,)
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
return model_dir
|
||||
|
||||
|
@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
|
|||
return cached_download(url, cache_dir=get_cache_dir('hf'))
|
||||
|
||||
|
||||
def load_config_from_hf(model_id: str):
|
||||
def load_model_config_from_hf(model_id: str):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, 'config.json')
|
||||
return load_cfg_from_json(cached_file)
|
||||
default_cfg = load_cfg_from_json(cached_file)
|
||||
default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
||||
model_name = default_cfg.get('architecture')
|
||||
return default_cfg, model_name
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str):
|
||||
|
|
|
@ -21,6 +21,7 @@ import math
|
|||
import logging
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
|
||||
|
||||
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
|
||||
default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
|
||||
default_cfg = deepcopy(default_cfgs[variant])
|
||||
overlay_external_default_cfg(default_cfg, kwargs)
|
||||
default_num_classes = default_cfg['num_classes']
|
||||
default_img_size = default_cfg['input_size'][-1]
|
||||
default_img_size = default_cfg['input_size'][-2:]
|
||||
|
||||
num_classes = kwargs.pop('num_classes', default_num_classes)
|
||||
img_size = kwargs.pop('img_size', default_img_size)
|
||||
|
|
Loading…
Reference in New Issue