A few minor fixes and a bit more cleanup on the huggingface hub integration.
parent ead80d33c5
commit 45c048ba13
@@ -1,7 +1,7 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
-from .hub import load_config_from_hf
+from .hub import load_model_config_from_hf


 def split_model_name(model_name):
@@ -67,11 +67,9 @@ def create_model(
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

     if source_name == 'hf_hub':
-        # Load model weights + default_cfg from Hugging Face hub.
-        # For model names specified in the form `hf_hub:path/architecture_name#revision`
-        hf_default_cfg = load_config_from_hf(model_name)
-        hf_default_cfg['hf_hub'] = model_name  # insert hf_hub id for pretrained weight load during creation
-        model_name = hf_default_cfg.get('architecture')
+        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+        # load model weights + default_cfg from Hugging Face hub.
+        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
         kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday

     if is_model(model_name):
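For context, a minimal usage sketch of the hub-backed naming this hunk handles, assuming timm with the optional huggingface_hub dependency installed; the repository id below is made up for illustration:

import timm

# A `hf_hub:` prefix routes create_model through load_model_config_from_hf; the id
# follows the `hf_hub:path/architecture_name` form described in the comment above.
model = timm.create_model('hf_hub:someuser/vit_base_patch16_224', pretrained=True)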
@@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
     return default_cfg


-def overlay_external_default_cfg(kwargs, default_cfg):
-    """ Overlay 'default_cfg' in kwargs on top of default_cfg arg.
+def overlay_external_default_cfg(default_cfg, kwargs):
+    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
     """
-    default_cfg = default_cfg or {}
     external_default_cfg = kwargs.pop('external_default_cfg', None)
     if external_default_cfg:
-        default_cfg = deepcopy(default_cfg)
         default_cfg.pop('url', None)  # url should come from external cfg
         default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
         default_cfg.update(external_default_cfg)
-    return default_cfg


 def set_default_kwargs(kwargs, names, default_cfg):
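To illustrate the new in-place signature (all values below are invented, and the import path assumes the usual timm layout):

from timm.models.helpers import overlay_external_default_cfg

default_cfg = {'architecture': 'vit_base_patch16_224', 'url': 'https://example.com/old.pth'}
kwargs = {'external_default_cfg': {'hf_hub': 'someuser/vit_base_patch16_224'}, 'drop_rate': 0.1}
overlay_external_default_cfg(default_cfg, kwargs)
# 'external_default_cfg' is popped from kwargs; 'url' is dropped and 'hf_hub' merged in,
# so default_cfg is now mutated in place instead of being copied and returned.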
@@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
                 assert len(input_size) == 3
-                kwargs.setdefault(n, input_size[:-2])
+                kwargs.setdefault(n, input_size[-2:])
         elif n == 'in_chans':
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
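The slice change above fixes the default img_size derivation: timm default_cfgs store input_size as (C, H, W), so the old slice returned the channel count rather than the spatial dims:

input_size = (3, 224, 224)   # (in_chans, H, W) as stored in default_cfg['input_size']
input_size[:-2]              # (3,)        -- old slice, not a spatial size at all
input_size[-2:]              # (224, 224)  -- fixed slice, the intended (H, W) default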
@@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
         kwargs.pop(n, None)


+def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+    """ Update the default_cfg and kwargs before passing to model
+
+    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+    could/should be replaced by an improved configuration mechanism
+
+    Args:
+        default_cfg: input default_cfg (updated in-place)
+        kwargs: keyword args passed to model build fn (updated in-place)
+        kwargs_filter: keyword arg keys that must be removed before model __init__
+    """
+    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+    overlay_external_default_cfg(default_cfg, kwargs)
+    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
+    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+    filter_kwargs(kwargs, names=kwargs_filter)
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
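A rough sketch of how the new helper mutates its arguments (values invented for illustration; assumes filter_kwargs tolerates an empty filter, as the guard in timm does):

from timm.models.helpers import update_default_cfg_and_kwargs

default_cfg = {'num_classes': 1000, 'input_size': (3, 224, 224)}
kwargs = {'drop_rate': 0.1}
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter=None)
# kwargs now also carries num_classes=1000 and in_chans=3 derived from default_cfg;
# both dicts are updated in place, nothing is returned.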
@@ -399,29 +415,20 @@ def build_model_with_cfg(
     pruned = kwargs.pop('pruned', False)
     features = False
     feature_cfg = feature_cfg or {}
+    default_cfg = deepcopy(default_cfg) if default_cfg else {}
+    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+    default_cfg.setdefault('architecture', variant)

-    # Setup for featyre extraction wrapper done at end of this fn
+    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
         features = True
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')

-    # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
-    # could/should be replaced by an improved configuration mechanism
-
-    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
-
-    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
-    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
-
-    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
-    filter_kwargs(kwargs, names=kwargs_filter)
-
     # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
-    model.default_cfg = deepcopy(default_cfg)
+    model.default_cfg = default_cfg

     if pruned:
         model = adapt_model_from_file(model, variant)
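For orientation, a hedged sketch of how a typical timm entrypoint drives build_model_with_cfg after this change; ResNet is used purely as an example and the import paths are assumptions based on the repo layout of this era:

from timm.models.helpers import build_model_with_cfg
from timm.models.resnet import ResNet, default_cfgs

def _create_resnet(variant, pretrained=False, **kwargs):
    # build_model_with_cfg now deep-copies default_cfg and runs
    # update_default_cfg_and_kwargs internally before instantiating the model.
    return build_model_with_cfg(
        ResNet, variant, pretrained,
        default_cfg=default_cfgs[variant],
        **kwargs)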
@@ -23,7 +23,7 @@ except ImportError:
 _logger = logging.getLogger(__name__)


-def get_cache_dir(child=''):
+def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
     """
@@ -32,8 +32,8 @@ def get_cache_dir(child=''):
         _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

     hub_dir = get_dir()
-    children = () if not child else child,
-    model_dir = os.path.join(hub_dir, 'checkpoints', *children)
+    child_dir = () if not child_dir else (child_dir,)
+    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
     os.makedirs(model_dir, exist_ok=True)
     return model_dir
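Besides the rename, this fixes a precedence quirk: the trailing comma applied to the whole conditional expression, so the old line always built a one-element tuple and an empty argument produced ((),) rather than ():

child = ''
children = () if not child else child,      # parses as (() if not child else child,) -> ((),)
# os.path.join(hub_dir, 'checkpoints', *children) then receives a tuple, not a str, and raises TypeError
child_dir = () if not child else (child,)   # -> (), so the optional path component is simply omitted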
@@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
     return cached_download(url, cache_dir=get_cache_dir('hf'))


-def load_config_from_hf(model_id: str):
+def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    return load_cfg_from_json(cached_file)
+    default_cfg = load_cfg_from_json(cached_file)
+    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+    model_name = default_cfg.get('architecture')
+    return default_cfg, model_name


 def load_state_dict_from_hf(model_id: str):
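A small sketch of the new return contract; the repo id is hypothetical, the import path is assumed, and the call needs the huggingface_hub package installed:

from timm.models.hub import load_model_config_from_hf

default_cfg, model_name = load_model_config_from_hf('someuser/vit_base_patch16_224')
# default_cfg includes 'hf_hub': 'someuser/vit_base_patch16_224' for the later weight load,
# and model_name is taken from default_cfg.get('architecture').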
@@ -21,6 +21,7 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from copy import deepcopy

 import torch
 import torch.nn as nn
@@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
+    default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-1]
+    default_img_size = default_cfg['input_size'][-2:]

     num_classes = kwargs.pop('num_classes', default_num_classes)
     img_size = kwargs.pop('img_size', default_img_size)
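The last change means the fallback img_size is now an (H, W) tuple instead of a single int, e.g.:

default_cfg['input_size']        # typically (3, 224, 224) in the ViT default_cfgs
default_cfg['input_size'][-1]    # 224 -- old default, implicitly assumes square inputs
default_cfg['input_size'][-2:]   # (224, 224) -- new default passed through as img_size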