Tweak a few docstrings

This commit is contained in:
Ross Wightman 2024-11-13 10:12:31 -08:00
parent 015ac30a91
commit 3bef09f831
3 changed files with 32 additions and 26 deletions

View File

@ -584,22 +584,22 @@ class AutoAugment:
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
"""
Create a AutoAugment transform
""" Create a AutoAugment transform
Args:
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-').
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
The remaining sections:
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
While the remaining sections define other arguments
* 'mstd' - float std deviation of magnitude noise applied
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
Returns:
A PyTorch compatible Transform
Examples::
'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
"""
config = config_str.split('-')
policy_name = config[0]
@ -764,27 +764,30 @@ def rand_augment_transform(
hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None,
):
"""
Create a RandAugment transform
""" Create a RandAugment transform
Args:
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine
'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image)
'p' - float probability of applying each layer (default 0.5)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
't' - str name of transform set to use
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
The remaining sections, not order specific determine
* 'm' - integer magnitude of rand augment
* 'n' - integer num layers (number of transform ops selected per image)
* 'p' - float probability of applying each layer (default 0.5)
* 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
* 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
* 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
* 't' - str name of transform set to use
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
Returns:
A PyTorch compatible Transform
Examples::
'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
"""
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image

View File

@ -114,10 +114,10 @@ def transforms_imagenet_train(
Returns:
If separate==True, the transforms are returned as a tuple of 3 separate transforms
for use in a mixing dataset that passes
* all data through the first (primary) transform, called the 'clean' data
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
for use in a mixing dataset that passes
* all data through the first (primary) transform, called the 'clean' data
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
train_crop_mode = train_crop_mode or 'rrc'
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}

View File

@ -96,7 +96,7 @@ def extract_spp_stats(
hook_fns,
input_shape=[8, 3, 224, 224]):
"""Extract average square channel mean and variance of activations during
forward pass to plot Signal Propogation Plots (SPP).
forward pass to plot Signal Propogation Plots (SPP).
Paper: https://arxiv.org/abs/2101.08692
@ -111,7 +111,8 @@ def extract_spp_stats(
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
"""
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
done in place.
done in place.
Args:
root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
@ -180,6 +181,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
def freeze(root_module, submodules=[], include_bn_running_stats=True):
"""
Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
Args:
root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
@ -214,6 +216,7 @@ def freeze(root_module, submodules=[], include_bn_running_stats=True):
def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
"""
Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
Args:
root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided