From 3bef09f831434dfea770135ff5eec5911e0e03d0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 13 Nov 2024 10:12:31 -0800 Subject: [PATCH] Tweak a few docstrings --- timm/data/auto_augment.py | 43 ++++++++++++++++++--------------- timm/data/transforms_factory.py | 8 +++--- timm/utils/model.py | 7 ++++-- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 58cd1bbd..94438a0e 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -584,22 +584,22 @@ class AutoAugment: def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None): - """ - Create a AutoAugment transform + """ Create an AutoAugment transform Args: config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). - - The remaining sections: - 'mstd' - float std deviation of magnitude noise applied - Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 - + While the remaining sections define other arguments + * 'mstd' - float std deviation of magnitude noise applied hparams: Other hparams (kwargs) for the AutoAugmentation scheme Returns: A PyTorch compatible Transform + + Examples:: + + 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 """ config = config_str.split('-') policy_name = config[0] @@ -764,27 +764,30 @@ def rand_augment_transform( hparams: Optional[Dict] = None, transforms: Optional[Union[str, Dict, List]] = None, ): - """ - Create a RandAugment transform + """ Create a RandAugment transform Args: config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
- The remaining sections, not order sepecific determine - 'm' - integer magnitude of rand augment - 'n' - integer num layers (number of transform ops selected per image) - 'p' - float probability of applying each layer (default 0.5) - 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) - 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) - 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) - 't' - str name of transform set to use - Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 - 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2 - + The remaining sections, not order specific determine + * 'm' - integer magnitude of rand augment + * 'n' - integer num layers (number of transform ops selected per image) + * 'p' - float probability of applying each layer (default 0.5) + * 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) + * 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) + * 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + * 't' - str name of transform set to use hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme Returns: A PyTorch compatible Transform + + Examples:: + + 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + + 'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2 + """ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) num_layers = 2 # default to 2 ops per image diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index cef8291d..9be0e3bf 100644 --- a/timm/data/transforms_factory.py +++ 
b/timm/data/transforms_factory.py @@ -114,10 +114,10 @@ def transforms_imagenet_train( Returns: If separate==True, the transforms are returned as a tuple of 3 separate transforms - for use in a mixing dataset that passes - * all data through the first (primary) transform, called the 'clean' data - * a portion of the data through the secondary transform - * normalizes and converts the branches above with the third, final transform + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform """ train_crop_mode = train_crop_mode or 'rrc' assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'} diff --git a/timm/utils/model.py b/timm/utils/model.py index 492313cb..8c9c7734 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -96,7 +96,7 @@ def extract_spp_stats( hook_fns, input_shape=[8, 3, 224, 224]): """Extract average square channel mean and variance of activations during - forward pass to plot Signal Propogation Plots (SPP). + forward pass to plot Signal Propagation Plots (SPP). Paper: https://arxiv.org/abs/2101.08692 @@ -111,7 +111,8 @@ def extract_spp_stats( def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): """ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is - done in place. + done in place. + Args: root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. submodules (list[str]): List of modules for which the parameters will be (un)frozen. 
They are to be provided as @@ -180,6 +181,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, def freeze(root_module, submodules=[], include_bn_running_stats=True): """ Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: root_module (nn.Module): Root module relative to which `submodules` are referenced. submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as @@ -214,6 +216,7 @@ def freeze(root_module, submodules=[], include_bn_running_stats=True): def unfreeze(root_module, submodules=[], include_bn_running_stats=True): """ Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: root_module (nn.Module): Root module relative to which `submodules` are referenced. submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided