Try to fix documentation build, add better docstrings to public optimizer api

small_384_weights
Ross Wightman 2024-11-12 16:45:01 -08:00 committed by Ross Wightman
parent ee5f6e76bb
commit 53657a31b7
2 changed files with 149 additions and 7 deletions

View File

@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers
### Factory functions
[[autodoc]] timm.optim.optim_factory.create_optimizer
[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
[[autodoc]] timm.optim.create_optimizer_v2
[[autodoc]] timm.optim.list_optimizers
[[autodoc]] timm.optim.get_optimizer_class
### Optimizer Classes
[[autodoc]] timm.optim.adabelief.AdaBelief
[[autodoc]] timm.optim.adafactor.Adafactor
[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion,Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.nadamw.NadamW
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP
[[autodoc]] timm.optim.sgdw.SGDW

View File

@ -124,7 +124,7 @@ class OptimizerRegistry:
def list_optimizers(
self,
filter: str = '',
filter: Union[str, List[str]] = '',
exclude_filters: Optional[List[str]] = None,
with_description: bool = False
) -> List[Union[str, Tuple[str, str]]]:
@ -141,7 +141,14 @@ class OptimizerRegistry:
names = sorted(self._optimizers.keys())
if filter:
names = [n for n in names if fnmatch(n, filter)]
if isinstance(filter, str):
filters = [filter]
else:
filters = filter
filtered_names = set()
for f in filters:
filtered_names.update(n for n in names if fnmatch(n, f))
names = sorted(filtered_names)
if exclude_filters:
for exclude_filter in exclude_filters:
@ -149,6 +156,7 @@ class OptimizerRegistry:
if with_description:
return [(name, self._optimizers[name].description) for name in names]
return names
def get_optimizer_info(self, name: str) -> OptimInfo:
@ -718,11 +726,46 @@ _register_default_optimizers()
# Public API
def list_optimizers(
filter: str = '',
filter: Union[str, List[str]] = '',
exclude_filters: Optional[List[str]] = None,
with_description: bool = False,
) -> List[Union[str, Tuple[str, str]]]:
"""List available optimizer names, optionally filtered.
List all registered optimizers, with optional filtering using wildcard patterns.
Optimizers can be filtered using include and exclude patterns, and can optionally
return descriptions with each optimizer name.
Args:
filter: Wildcard style filter string or list of filter strings
(e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
Adam variants and 8-bit optimizers). Empty string means no filtering.
exclude_filters: Optional list of wildcard patterns to exclude. For example,
['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
with_description: If True, returns tuples of (name, description) instead of
just names. Descriptions provide brief explanations of optimizer characteristics.
Returns:
If with_description is False:
List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
If with_description is True:
List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
Examples:
>>> list_optimizers()
['adam', 'adamw', 'sgd', ...]
>>> list_optimizers(['la*', 'nla*']) # List lamb & lars
['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
>>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*']) # Exclude bnb & apex adam optimizers
['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
>>> list_optimizers(with_description=True) # Get descriptions
[('adabelief', 'Adapts learning rate based on gradient prediction error'),
('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
...]
"""
return default_registry.list_optimizers(filter, exclude_filters, with_description)
@ -731,7 +774,38 @@ def get_optimizer_class(
name: str,
bind_defaults: bool = False,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
"""Get optimizer class by name with any defaults applied.
"""Get optimizer class by name with option to bind default arguments.
Retrieves the optimizer class or a partial function with default arguments bound.
This allows direct instantiation of optimizers with their default configurations
without going through the full factory.
Args:
name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
If False, returns the raw optimizer class.
Returns:
If bind_defaults is False:
The optimizer class (e.g., torch.optim.Adam)
If bind_defaults is True:
A partial function with default arguments bound
Raises:
ValueError: If optimizer name is not found in registry
Examples:
>>> # Get raw optimizer class
>>> Adam = get_optimizer_class('adam')
>>> opt = Adam(model.parameters(), lr=1e-3)
>>> # Get optimizer with defaults bound
>>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
>>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
>>> # Get SGD with nesterov momentum default
>>> SGD = get_optimizer_class('sgd', bind_defaults=True) # nesterov=True bound
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
"""
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
@ -748,7 +822,69 @@ def create_optimizer_v2(
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
**kwargs: Any,
) -> optim.Optimizer:
"""Create an optimizer instance using the default registry."""
"""Create an optimizer instance via timm registry.
Creates and configures an optimizer with appropriate parameter groups and settings.
Supports automatic parameter group creation for weight decay and layer-wise learning
rates, as well as custom parameter grouping.
Args:
model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
If a model is provided, parameters will be automatically extracted and grouped
based on the other arguments.
opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
Use list_optimizers() to see available options.
lr: Learning rate. If None, will use the optimizer's default.
weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
momentum: Momentum factor for optimizers that support it. Only used if the
chosen optimizer accepts a momentum parameter.
foreach: Enable/disable foreach (multi-tensor) implementation if available.
If None, will use optimizer-specific defaults.
filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
weight decay applied. Only used when model_or_params is a model and
weight_decay > 0.
layer_decay: Optional layer-wise learning rate decay factor. If provided,
learning rates will be scaled by layer_decay^(max_depth - layer_depth).
Only used when model_or_params is a model.
param_group_fn: Optional function to create custom parameter groups.
If provided, other parameter grouping options will be ignored.
**kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
Returns:
Configured optimizer instance.
Examples:
>>> # Basic usage with a model
>>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
>>> # SGD with momentum and weight decay
>>> optimizer = create_optimizer_v2(
... model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
... )
>>> # Adam with layer-wise learning rate decay
>>> optimizer = create_optimizer_v2(
... model, 'adam', lr=1e-3, layer_decay=0.7
... )
>>> # Custom parameter groups
>>> def group_fn(model):
... return [
... {'params': model.backbone.parameters(), 'lr': 1e-4},
... {'params': model.head.parameters(), 'lr': 1e-3}
... ]
>>> optimizer = create_optimizer_v2(
... model, 'sgd', param_group_fn=group_fn
... )
Note:
Parameter group handling precedence:
1. If param_group_fn is provided, it will be used exclusively
2. If layer_decay is provided, layer-wise groups will be created
3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
4. Otherwise, all parameters will be in a single group
"""
return default_registry.create_optimizer(
model_or_params,
opt=opt,