Tweak a few docstrings

This commit is contained in:
Ross Wightman 2024-11-13 10:12:31 -08:00
parent 015ac30a91
commit 3bef09f831
3 changed files with 32 additions and 26 deletions

View File

@ -584,22 +584,22 @@ class AutoAugment:
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None): def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
""" """ Create a AutoAugment transform
Create a AutoAugment transform
Args: Args:
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
dashes ('-'). dashes ('-').
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
While the remaining sections define other arguments
The remaining sections: * 'mstd' - float std deviation of magnitude noise applied
'mstd' - float std deviation of magnitude noise applied
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
hparams: Other hparams (kwargs) for the AutoAugmentation scheme hparams: Other hparams (kwargs) for the AutoAugmentation scheme
Returns: Returns:
A PyTorch compatible Transform A PyTorch compatible Transform
Examples::
'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
""" """
config = config_str.split('-') config = config_str.split('-')
policy_name = config[0] policy_name = config[0]
@ -764,27 +764,30 @@ def rand_augment_transform(
hparams: Optional[Dict] = None, hparams: Optional[Dict] = None,
transforms: Optional[Union[str, Dict, List]] = None, transforms: Optional[Union[str, Dict, List]] = None,
): ):
""" """ Create a RandAugment transform
Create a RandAugment transform
Args: Args:
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
The remaining sections, not order sepecific determine The remaining sections, not order specific determine
'm' - integer magnitude of rand augment * 'm' - integer magnitude of rand augment
'n' - integer num layers (number of transform ops selected per image) * 'n' - integer num layers (number of transform ops selected per image)
'p' - float probability of applying each layer (default 0.5) * 'p' - float probability of applying each layer (default 0.5)
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100) * 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10) * 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) * 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
't' - str name of transform set to use * 't' - str name of transform set to use
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
Returns: Returns:
A PyTorch compatible Transform A PyTorch compatible Transform
Examples::
'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
""" """
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10) magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
num_layers = 2 # default to 2 ops per image num_layers = 2 # default to 2 ops per image

View File

@ -112,6 +112,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
""" """
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
done in place. done in place.
Args: Args:
root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
@ -180,6 +181,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
def freeze(root_module, submodules=[], include_bn_running_stats=True): def freeze(root_module, submodules=[], include_bn_running_stats=True):
""" """
Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
Args: Args:
root_module (nn.Module): Root module relative to which `submodules` are referenced. root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
@ -214,6 +216,7 @@ def freeze(root_module, submodules=[], include_bn_running_stats=True):
def unfreeze(root_module, submodules=[], include_bn_running_stats=True): def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
""" """
Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
Args: Args:
root_module (nn.Module): Root module relative to which `submodules` are referenced. root_module (nn.Module): Root module relative to which `submodules` are referenced.
submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided