More tweaks to docstrings for hub/builder

Ross Wightman 2024-12-06 08:58:02 -08:00 committed by Ross Wightman
parent dc1bb05e8e
commit 7ab2b938e5
2 changed files with 21 additions and 21 deletions


@@ -107,6 +107,7 @@ def load_custom_pretrained(
         pretrained_cfg: Default pretrained model cfg
         load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
             'load_pretrained' on the model will be called if it exists
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
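For context, a minimal sketch of the new per-call `cache_dir` override on `load_custom_pretrained`; the import location, the `load_fn` callback signature, and the paths are assumptions for illustration:

```python
import timm
from timm.models import load_custom_pretrained  # import path assumed

def my_load_fn(model, checkpoint_location):
    # hypothetical loader for a non-standard checkpoint format;
    # assumed to be invoked as load_fn(model, checkpoint_location)
    ...

model = timm.create_model('resnet50', pretrained=False)
load_custom_pretrained(
    model,
    load_fn=my_load_fn,
    cache_dir='/data/timm-cache',  # the new per-call cache dir override (hypothetical path)
)
```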
@@ -148,12 +149,12 @@ def load_pretrained(
     Args:
         model: PyTorch module
-        pretrained_cfg: configuration for pretrained weights / target dataset
-        num_classes: number of classes for target model
-        in_chans: number of input chans for target model
+        pretrained_cfg: Configuration for pretrained weights / target dataset
+        num_classes: Number of classes for target model. Will adapt pretrained if different.
+        in_chans: Number of input chans for target model. Will adapt pretrained if different.
         filter_fn: state_dict filter fn for load (takes state_dict, model as args)
-        strict: strict load of checkpoint
-        cache_dir: override path to cache dir for this load
+        strict: Strict load of checkpoint
+        cache_dir: Override model checkpoint cache dir for this load
     """
     pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
     if not pretrained_cfg:
@@ -326,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
 def resolve_pretrained_cfg(
         variant: str,
-        pretrained_cfg=None,
-        pretrained_cfg_overlay=None,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
 ) -> PretrainedCfg:
     model_with_tag = variant
     pretrained_tag = None
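Per the new annotations, `pretrained_cfg` accepts either a tag string or a dict; a hedged sketch (import path assumed, tag and overlay values illustrative):

```python
from timm.models import resolve_pretrained_cfg  # import path assumed

# a str pretrained_cfg selects a pretrained tag for the variant,
# pretrained_cfg_overlay patches individual entries on top of the resolved cfg
cfg = resolve_pretrained_cfg(
    'resnet50',
    pretrained_cfg='a1_in1k',                   # tag, per the Union[str, Dict] annotation
    pretrained_cfg_overlay=dict(crop_pct=1.0),  # override one entry
)
print(cfg.to_dict())
```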
@@ -382,17 +383,18 @@ def build_model_with_cfg(
     * pruning config / model adaptation
 
     Args:
-        model_cls: model class
-        variant: model variant name
-        pretrained: load pretrained weights
-        pretrained_cfg: model's pretrained weight/task config
-        model_cfg: model's architecture config
-        feature_cfg: feature extraction adapter config
-        pretrained_strict: load pretrained weights strictly
-        pretrained_filter_fn: filter callable for pretrained weights
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
-        kwargs_filter: kwargs to filter before passing to model
-        **kwargs: model args passed through to model __init__
+        model_cls: Model class
+        variant: Model variant name
+        pretrained: Load the pretrained weights
+        pretrained_cfg: Model's pretrained weight/task config
+        pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
+        model_cfg: Model's architecture config
+        feature_cfg: Feature extraction adapter config
+        pretrained_strict: Load pretrained weights strictly
+        pretrained_filter_fn: Filter callable for pretrained weights
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
+        kwargs_filter: Kwargs keys to filter (remove) before passing to model
+        **kwargs: Model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
     features = False
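`build_model_with_cfg` is normally called from a model entrypoint function; a minimal sketch in that style, with `MyNet` and its overlay entries purely hypothetical:

```python
import torch.nn as nn
from timm.models import build_model_with_cfg  # import path assumed

class MyNet(nn.Module):  # hypothetical toy architecture
    def __init__(self, num_classes: int = 1000, in_chans: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, 8, 3)
        self.head = nn.Linear(8, num_classes)

def _create_mynet(variant, pretrained=False, **kwargs):
    # typical entrypoint pattern: model class + variant + pretrained flag,
    # with remaining kwargs passed through to MyNet.__init__
    return build_model_with_cfg(
        MyNet,
        variant,
        pretrained,
        pretrained_cfg_overlay=dict(num_classes=10),  # overrides the resolved cfg entry
        **kwargs,
    )

model = _create_mynet('mynet_tiny', pretrained=False)
```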
@@ -404,8 +406,6 @@ def build_model_with_cfg(
         pretrained_cfg=pretrained_cfg,
         pretrained_cfg_overlay=pretrained_cfg_overlay
     )
-    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
     _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
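For reference, the `PretrainedCfg` to dict conversion above looks roughly like this (import path assumed, field values illustrative):

```python
from timm.models import PretrainedCfg  # import path assumed

cfg = PretrainedCfg(url='https://example.com/weights.pth', num_classes=1000)
cfg_dict = cfg.to_dict()  # the plain-dict form used downstream of this point
```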


@@ -62,7 +62,7 @@ def create_model(
         pretrained_cfg: Pass in an external pretrained_cfg for model.
         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
-        cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
+        cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
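A sketch of the documented `cache_dir` override on `create_model`; the model name and path are illustrative:

```python
import timm

# Hub and torch checkpoint downloads for this call land under the given root
model = timm.create_model(
    'resnet50.a1_in1k',
    pretrained=True,
    cache_dir='/data/timm-cache',  # hypothetical path
)
```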