mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove redundant types, kwargs back in own section (lesser of many evils?)
This commit is contained in:
parent
14b84e8895
commit
320bf9c469
@ -12,7 +12,7 @@ from ._registry import is_model, model_entrypoint
|
|||||||
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||||
|
|
||||||
|
|
||||||
def parse_model_name(model_name):
|
def parse_model_name(model_name: str):
|
||||||
if model_name.startswith('hf_hub'):
|
if model_name.startswith('hf_hub'):
|
||||||
# NOTE for backwards compat, deprecate hf_hub use
|
# NOTE for backwards compat, deprecate hf_hub use
|
||||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||||
@ -26,7 +26,7 @@ def parse_model_name(model_name):
|
|||||||
return 'timm', model_name
|
return 'timm', model_name
|
||||||
|
|
||||||
|
|
||||||
def safe_model_name(model_name, remove_source=True):
|
def safe_model_name(model_name: str, remove_source: bool = True):
|
||||||
# return a filename / path safe model name
|
# return a filename / path safe model name
|
||||||
def make_safe(name):
|
def make_safe(name):
|
||||||
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
||||||
@ -56,16 +56,19 @@ def create_model(
|
|||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_name (`str`): Name of model to instantiate.
|
model_name: Name of model to instantiate.
|
||||||
pretrained (`bool`): If set to `True`, load pretrained ImageNet-1k weights.
|
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
|
||||||
pretrained_cfg (`Union[str, dict, PretrainedCfg]`): Pass in an external pretrained_cfg for model.
|
pretrained_cfg: Pass in an external pretrained_cfg for model.
|
||||||
pretrained_cfg_overlay (`dict`): Replace key-values in base pretrained_cfg with these.
|
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
|
||||||
checkpoint_path (`str`): Path of checkpoint to load _after_ the model is initialized.
|
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
|
||||||
scriptable (`bool`): Set layer config so that model is jit scriptable (not working for all models yet).
|
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
|
||||||
exportable (`bool`): Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
|
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
|
||||||
no_jit (`bool`): Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
|
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
|
||||||
**drop_rate (`float`): Dropout rate for training. Defaults to `0.0`.
|
|
||||||
**global_pool (`str`): Global pooling type. Defaults to `'avg'`.
|
Keyword Args:
|
||||||
|
drop_rate (float): Classifier dropout rate for training.
|
||||||
|
drop_path_rate (float): Stochastic depth drop rate for training.
|
||||||
|
global_pool (str): Classifier global pooling type.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user