From 320bf9c4693c808c15569fe7a080dc93dd07aaa9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 1 May 2023 14:21:48 -0700 Subject: [PATCH] Remove redundant types, kwargs back in own section (lesser of many evils?) --- timm/models/_factory.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/timm/models/_factory.py b/timm/models/_factory.py index 0b12f9dd..16c02cb5 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -12,7 +12,7 @@ from ._registry import is_model, model_entrypoint __all__ = ['parse_model_name', 'safe_model_name', 'create_model'] -def parse_model_name(model_name): +def parse_model_name(model_name: str): if model_name.startswith('hf_hub'): # NOTE for backwards compat, deprecate hf_hub use model_name = model_name.replace('hf_hub', 'hf-hub') @@ -26,7 +26,7 @@ def parse_model_name(model_name): return 'timm', model_name -def safe_model_name(model_name, remove_source=True): +def safe_model_name(model_name: str, remove_source: bool = True): # return a filename / path safe model name def make_safe(name): return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') @@ -56,16 +56,19 @@ def create_model( Args: - model_name (`str`): Name of model to instantiate. - pretrained (`bool`): If set to `True`, load pretrained ImageNet-1k weights. - pretrained_cfg (`Union[str, dict, PretrainedCfg]`): Pass in an external pretrained_cfg for model. - pretrained_cfg_overlay (`dict`): Replace key-values in base pretrained_cfg with these. - checkpoint_path (`str`): Path of checkpoint to load _after_ the model is initialized. - scriptable (`bool`): Set layer config so that model is jit scriptable (not working for all models yet). - exportable (`bool`): Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet). - no_jit (`bool`): Set layer config so that model doesn't utilize jit scripted layers (so far activations only). - **drop_rate (`float`): Dropout rate for training. Defaults to `0.0`. - **global_pool (`str`): Global pooling type. Defaults to `'avg'`. + model_name: Name of model to instantiate. + pretrained: If set to `True`, load pretrained ImageNet-1k weights. + pretrained_cfg: Pass in an external pretrained_cfg for model. + pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these. + checkpoint_path: Path of checkpoint to load _after_ the model is initialized. + scriptable: Set layer config so that model is jit scriptable (not working for all models yet). + exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet). + no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only). + + Keyword Args: + drop_rate (float): Classifier dropout rate for training. + drop_path_rate (float): Stochastic depth drop rate for training. + global_pool (str): Classifier global pooling type. Example: