Remove redundant types, kwargs back in own section (lesser of many evils?)

This commit is contained in:
Ross Wightman 2023-05-01 14:21:48 -07:00
parent 14b84e8895
commit 320bf9c469

View File

@ -12,7 +12,7 @@ from ._registry import is_model, model_entrypoint
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
def parse_model_name(model_name):
def parse_model_name(model_name: str):
if model_name.startswith('hf_hub'):
# NOTE for backwards compat, deprecate hf_hub use
model_name = model_name.replace('hf_hub', 'hf-hub')
@ -26,7 +26,7 @@ def parse_model_name(model_name):
return 'timm', model_name
def safe_model_name(model_name, remove_source=True):
def safe_model_name(model_name: str, remove_source: bool = True):
# return a filename / path safe model name
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
@ -56,16 +56,19 @@ def create_model(
</Tip>
Args:
model_name (`str`): Name of model to instantiate.
pretrained (`bool`): If set to `True`, load pretrained ImageNet-1k weights.
pretrained_cfg (`Union[str, dict, PretrainedCfg]`): Pass in an external pretrained_cfg for model.
pretrained_cfg_overlay (`dict`): Replace key-values in base pretrained_cfg with these.
checkpoint_path (`str`): Path of checkpoint to load _after_ the model is initialized.
scriptable (`bool`): Set layer config so that model is jit scriptable (not working for all models yet).
exportable (`bool`): Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
no_jit (`bool`): Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
**drop_rate (`float`): Dropout rate for training. Defaults to `0.0`.
**global_pool (`str`): Global pooling type. Defaults to `'avg'`.
model_name: Name of model to instantiate.
pretrained: If set to `True`, load pretrained ImageNet-1k weights.
pretrained_cfg: Pass in an external pretrained_cfg for model.
pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
Keyword Args:
drop_rate (float): Classifier dropout rate for training.
drop_path_rate (float): Stochastic depth drop rate for training.
global_pool (str): Classifier global pooling type.
Example: