mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Tweak a few docstrings
This commit is contained in:
parent
015ac30a91
commit
3bef09f831
@ -584,22 +584,22 @@ class AutoAugment:
|
||||
|
||||
|
||||
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
""" Create a AutoAugment transform
|
||||
|
||||
Args:
|
||||
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-').
|
||||
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
|
||||
The remaining sections:
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
While the remaining sections define other arguments
|
||||
* 'mstd' - float std deviation of magnitude noise applied
|
||||
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
|
||||
Examples::
|
||||
|
||||
'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
@ -764,27 +764,30 @@ def rand_augment_transform(
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[str, Dict, List]] = None,
|
||||
):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
""" Create a RandAugment transform
|
||||
|
||||
Args:
|
||||
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
|
||||
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
|
||||
The remaining sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'p' - float probability of applying each layer (default 0.5)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
't' - str name of transform set to use
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
|
||||
|
||||
The remaining sections, not order specific determine
|
||||
* 'm' - integer magnitude of rand augment
|
||||
* 'n' - integer num layers (number of transform ops selected per image)
|
||||
* 'p' - float probability of applying each layer (default 0.5)
|
||||
* 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
* 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
* 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
* 't' - str name of transform set to use
|
||||
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
|
||||
Examples::
|
||||
|
||||
'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
|
||||
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
|
||||
|
||||
"""
|
||||
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
|
@ -114,10 +114,10 @@ def transforms_imagenet_train(
|
||||
|
||||
Returns:
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
for use in a mixing dataset that passes
|
||||
* all data through the first (primary) transform, called the 'clean' data
|
||||
* a portion of the data through the secondary transform
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
for use in a mixing dataset that passes
|
||||
* all data through the first (primary) transform, called the 'clean' data
|
||||
* a portion of the data through the secondary transform
|
||||
* normalizes and converts the branches above with the third, final transform
|
||||
"""
|
||||
train_crop_mode = train_crop_mode or 'rrc'
|
||||
assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
|
||||
|
@ -96,7 +96,7 @@ def extract_spp_stats(
|
||||
hook_fns,
|
||||
input_shape=[8, 3, 224, 224]):
|
||||
"""Extract average square channel mean and variance of activations during
|
||||
forward pass to plot Signal Propogation Plots (SPP).
|
||||
forward pass to plot Signal Propogation Plots (SPP).
|
||||
|
||||
Paper: https://arxiv.org/abs/2101.08692
|
||||
|
||||
@ -111,7 +111,8 @@ def extract_spp_stats(
|
||||
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
|
||||
"""
|
||||
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
|
||||
done in place.
|
||||
done in place.
|
||||
|
||||
Args:
|
||||
root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
|
||||
submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
|
||||
@ -180,6 +181,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
|
||||
def freeze(root_module, submodules=[], include_bn_running_stats=True):
|
||||
"""
|
||||
Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
|
||||
|
||||
Args:
|
||||
root_module (nn.Module): Root module relative to which `submodules` are referenced.
|
||||
submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
|
||||
@ -214,6 +216,7 @@ def freeze(root_module, submodules=[], include_bn_running_stats=True):
|
||||
def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
|
||||
"""
|
||||
Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
|
||||
|
||||
Args:
|
||||
root_module (nn.Module): Root module relative to which `submodules` are referenced.
|
||||
submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided
|
||||
|
Loading…
x
Reference in New Issue
Block a user