Allow passing state_dict directly via pretrained cfg mechanism as an override

This commit is contained in:
Ross Wightman 2023-05-08 15:15:44 -07:00
parent e4e43190ce
commit 8c6fccb879
2 changed files with 19 additions and 9 deletions

View File

@ -32,6 +32,7 @@ def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
pretrained_sd = pretrained_cfg.get('state_dict', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
@ -44,8 +45,13 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
# file load override is the highest priority if set
if pretrained_sd:
# direct state_dict pass through is the highest priority
load_from = 'state_dict'
pretrained_loc = pretrained_sd
assert isinstance(pretrained_loc, dict)
elif pretrained_file:
# file load override is the second-highest priority if set
load_from = 'file'
pretrained_loc = pretrained_file
else:
@ -108,7 +114,7 @@ def load_custom_pretrained(
if not load_from:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if load_from == 'hf-hub': # FIXME
if load_from == 'hf-hub':
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
elif load_from == 'url':
pretrained_loc = download_cached_file(
@ -150,7 +156,10 @@ def load_pretrained(
return
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
if load_from == 'file':
if load_from == 'state_dict':
_logger.info(f'Loading pretrained weights from state dict')
state_dict = pretrained_loc # pretrained_loc is the actual state dict for this override
elif load_from == 'file':
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':

View File

@ -11,11 +11,12 @@ __all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']
class PretrainedCfg:
"""
"""
# weight locations
url: Optional[Union[str, Tuple[str, str]]] = None
file: Optional[str] = None
hf_hub_id: Optional[str] = None
hf_hub_filename: Optional[str] = None
# weight source locations
url: Optional[Union[str, Tuple[str, str]]] = None # remote URL
file: Optional[str] = None # local / shared filesystem path
state_dict: Optional[Dict[str, Any]] = None # in-memory state dict
hf_hub_id: Optional[str] = None # Hugging Face Hub model id ('organization/model')
hf_hub_filename: Optional[str] = None # Hugging Face Hub filename (overrides default)
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit