mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More tweaks to docstrings for hub/builder
This commit is contained in:
parent
dc1bb05e8e
commit
7ab2b938e5
@ -107,6 +107,7 @@ def load_custom_pretrained(
|
|||||||
pretrained_cfg: Default pretrained model cfg
|
pretrained_cfg: Default pretrained model cfg
|
||||||
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
|
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
|
||||||
'load_pretrained' on the model will be called if it exists
|
'load_pretrained' on the model will be called if it exists
|
||||||
|
cache_dir: Override model checkpoint cache dir for this load
|
||||||
"""
|
"""
|
||||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||||
if not pretrained_cfg:
|
if not pretrained_cfg:
|
||||||
@ -148,12 +149,12 @@ def load_pretrained(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: PyTorch module
|
model: PyTorch module
|
||||||
pretrained_cfg: configuration for pretrained weights / target dataset
|
pretrained_cfg: Configuration for pretrained weights / target dataset
|
||||||
num_classes: number of classes for target model
|
num_classes: Number of classes for target model. Will adapt pretrained if different.
|
||||||
in_chans: number of input chans for target model
|
in_chans: Number of input chans for target model. Will adapt pretrained if different.
|
||||||
filter_fn: state_dict filter fn for load (takes state_dict, model as args)
|
filter_fn: state_dict filter fn for load (takes state_dict, model as args)
|
||||||
strict: strict load of checkpoint
|
strict: Strict load of checkpoint
|
||||||
cache_dir: override path to cache dir for this load
|
cache_dir: Override model checkpoint cache dir for this load
|
||||||
"""
|
"""
|
||||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||||
if not pretrained_cfg:
|
if not pretrained_cfg:
|
||||||
@ -326,8 +327,8 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
|||||||
|
|
||||||
def resolve_pretrained_cfg(
|
def resolve_pretrained_cfg(
|
||||||
variant: str,
|
variant: str,
|
||||||
pretrained_cfg=None,
|
pretrained_cfg: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
pretrained_cfg_overlay=None,
|
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||||
) -> PretrainedCfg:
|
) -> PretrainedCfg:
|
||||||
model_with_tag = variant
|
model_with_tag = variant
|
||||||
pretrained_tag = None
|
pretrained_tag = None
|
||||||
@ -382,17 +383,18 @@ def build_model_with_cfg(
|
|||||||
* pruning config / model adaptation
|
* pruning config / model adaptation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_cls: model class
|
model_cls: Model class
|
||||||
variant: model variant name
|
variant: Model variant name
|
||||||
pretrained: load pretrained weights
|
pretrained: Load the pretrained weights
|
||||||
pretrained_cfg: model's pretrained weight/task config
|
pretrained_cfg: Model's pretrained weight/task config
|
||||||
model_cfg: model's architecture config
|
pretrained_cfg_overlay: Entries that will override those in pretrained_cfg
|
||||||
feature_cfg: feature extraction adapter config
|
model_cfg: Model's architecture config
|
||||||
pretrained_strict: load pretrained weights strictly
|
feature_cfg: Feature extraction adapter config
|
||||||
pretrained_filter_fn: filter callable for pretrained weights
|
pretrained_strict: Load pretrained weights strictly
|
||||||
cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
|
pretrained_filter_fn: Filter callable for pretrained weights
|
||||||
kwargs_filter: kwargs to filter before passing to model
|
cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints
|
||||||
**kwargs: model args passed through to model __init__
|
kwargs_filter: Kwargs keys to filter (remove) before passing to model
|
||||||
|
**kwargs: Model args passed through to model __init__
|
||||||
"""
|
"""
|
||||||
pruned = kwargs.pop('pruned', False)
|
pruned = kwargs.pop('pruned', False)
|
||||||
features = False
|
features = False
|
||||||
@ -404,8 +406,6 @@ def build_model_with_cfg(
|
|||||||
pretrained_cfg=pretrained_cfg,
|
pretrained_cfg=pretrained_cfg,
|
||||||
pretrained_cfg_overlay=pretrained_cfg_overlay
|
pretrained_cfg_overlay=pretrained_cfg_overlay
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
|
|
||||||
pretrained_cfg = pretrained_cfg.to_dict()
|
pretrained_cfg = pretrained_cfg.to_dict()
|
||||||
|
|
||||||
_update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
_update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
||||||
|
@ -62,7 +62,7 @@ def create_model(
|
|||||||
pretrained_cfg: Pass in an external pretrained_cfg for model.
|
pretrained_cfg: Pass in an external pretrained_cfg for model.
|
||||||
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
|
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
|
||||||
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
|
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
|
||||||
cache_dir: Override system cache dir for Hugging Face Hub and Torch checkpoint locations
|
cache_dir: Override model cache dir for Hugging Face Hub and Torch checkpoints.
|
||||||
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
|
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
|
||||||
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
|
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
|
||||||
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
|
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
|
||||||
|
Loading…
x
Reference in New Issue
Block a user