mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Allow passing state_dict directly via pretrained cfg mechanism as an override
This commit is contained in:
parent
e4e43190ce
commit
8c6fccb879
@ -32,6 +32,7 @@ def _resolve_pretrained_source(pretrained_cfg):
|
||||
cfg_source = pretrained_cfg.get('source', '')
|
||||
pretrained_url = pretrained_cfg.get('url', None)
|
||||
pretrained_file = pretrained_cfg.get('file', None)
|
||||
pretrained_sd = pretrained_cfg.get('state_dict', None)
|
||||
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
|
||||
|
||||
# resolve where to load pretrained weights from
|
||||
@ -44,8 +45,13 @@ def _resolve_pretrained_source(pretrained_cfg):
|
||||
pretrained_loc = hf_hub_id
|
||||
else:
|
||||
# default source == timm or unspecified
|
||||
if pretrained_file:
|
||||
# file load override is the highest priority if set
|
||||
if pretrained_sd:
|
||||
# direct state_dict pass through is the highest priority
|
||||
load_from = 'state_dict'
|
||||
pretrained_loc = pretrained_sd
|
||||
assert isinstance(pretrained_loc, dict)
|
||||
elif pretrained_file:
|
||||
# file load override is the second-highest priority if set
|
||||
load_from = 'file'
|
||||
pretrained_loc = pretrained_file
|
||||
else:
|
||||
@ -108,7 +114,7 @@ def load_custom_pretrained(
|
||||
if not load_from:
|
||||
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
||||
return
|
||||
if load_from == 'hf-hub': # FIXME
|
||||
if load_from == 'hf-hub':
|
||||
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
|
||||
elif load_from == 'url':
|
||||
pretrained_loc = download_cached_file(
|
||||
@ -150,7 +156,10 @@ def load_pretrained(
|
||||
return
|
||||
|
||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||
if load_from == 'file':
|
||||
if load_from == 'state_dict':
|
||||
_logger.info(f'Loading pretrained weights from state dict')
|
||||
state_dict = pretrained_loc # pretrained_loc is the actual state dict for this override
|
||||
elif load_from == 'file':
|
||||
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
|
||||
state_dict = load_state_dict(pretrained_loc)
|
||||
elif load_from == 'url':
|
||||
|
@ -11,11 +11,12 @@ __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
|
||||
class PretrainedCfg:
|
||||
"""
|
||||
"""
|
||||
# weight locations
|
||||
url: Optional[Union[str, Tuple[str, str]]] = None
|
||||
file: Optional[str] = None
|
||||
hf_hub_id: Optional[str] = None
|
||||
hf_hub_filename: Optional[str] = None
|
||||
# weight source locations
|
||||
url: Optional[Union[str, Tuple[str, str]]] = None # remote URL
|
||||
file: Optional[str] = None # local / shared filesystem path
|
||||
state_dict: Optional[Dict[str, Any]] = None # in-memory state dict
|
||||
hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model')
|
||||
hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default)
|
||||
|
||||
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
|
||||
architecture: Optional[str] = None # architecture variant can be set when not implicit
|
||||
|
Loading…
x
Reference in New Issue
Block a user