Try to fix documentation build, add better docstrings to public optimizer api
parent ee5f6e76bb
commit 53657a31b7

@@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers

### Factory functions

[[autodoc]] timm.optim.optim_factory.create_optimizer
[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
[[autodoc]] timm.optim.create_optimizer_v2
[[autodoc]] timm.optim.list_optimizers
[[autodoc]] timm.optim.get_optimizer_class

### Optimizer Classes

[[autodoc]] timm.optim.adabelief.AdaBelief
[[autodoc]] timm.optim.adafactor.Adafactor
[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.nadamw.NadamW
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP
[[autodoc]] timm.optim.sgdw.SGDW
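The factory functions listed above are the usual entry point for building optimizers from this page. As a minimal usage sketch (not part of the diff), assuming a timm build that exposes these functions at the `timm.optim` top level as the updated autodoc paths imply:

```python
import timm
import timm.optim

# Any registered model works here; resnet18 is just a small stand-in.
model = timm.create_model('resnet18', pretrained=False)

# Enumerate registered optimizer names (wildcard filtering is shown further below).
print(timm.optim.list_optimizers()[:5])

# Build an optimizer through the v2 factory documented on this page.
optimizer = timm.optim.create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)
```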

@@ -124,7 +124,7 @@ class OptimizerRegistry:

    def list_optimizers(
            self,
            filter: str = '',
            filter: Union[str, List[str]] = '',
            exclude_filters: Optional[List[str]] = None,
            with_description: bool = False
    ) -> List[Union[str, Tuple[str, str]]]:

@@ -141,7 +141,14 @@ class OptimizerRegistry:
        names = sorted(self._optimizers.keys())

        if filter:
            names = [n for n in names if fnmatch(n, filter)]
            if isinstance(filter, str):
                filters = [filter]
            else:
                filters = filter
            filtered_names = set()
            for f in filters:
                filtered_names.update(n for n in names if fnmatch(n, f))
            names = sorted(filtered_names)

        if exclude_filters:
            for exclude_filter in exclude_filters:

@@ -149,6 +156,7 @@ class OptimizerRegistry:

        if with_description:
            return [(name, self._optimizers[name].description) for name in names]

        return names

    def get_optimizer_info(self, name: str) -> OptimInfo:
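The change above lets `filter` be either a single wildcard pattern or a list of patterns, with matches collected as a set union. Below is a small self-contained sketch of that include/exclude matching using `fnmatch` and a stand-in name list; the exclusion step is an assumption based on the truncated `exclude_filters` branch.

```python
from fnmatch import fnmatch
from typing import List, Sequence, Union

def filter_names(
        names: List[str],
        filter: Union[str, List[str]] = '',
        exclude_filters: Sequence[str] = (),
) -> List[str]:
    # Include step: union of all names matching any wildcard pattern.
    if filter:
        filters = [filter] if isinstance(filter, str) else filter
        matched = set()
        for f in filters:
            matched.update(n for n in names if fnmatch(n, f))
        names = sorted(matched)
    # Exclude step: drop names matching any exclude pattern (assumed behaviour,
    # the exclude branch is truncated in the hunk above).
    for exclude_filter in exclude_filters:
        names = [n for n in names if not fnmatch(n, exclude_filter)]
    return names

names = ['adam', 'adamw', 'adamw8bit', 'lamb', 'lars', 'sgd']
print(filter_names(names, 'adam*', exclude_filters=['*8bit']))  # ['adam', 'adamw']
```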

@@ -718,11 +726,46 @@ _register_default_optimizers()
# Public API

def list_optimizers(
        filter: str = '',
        filter: Union[str, List[str]] = '',
        exclude_filters: Optional[List[str]] = None,
        with_description: bool = False,
) -> List[Union[str, Tuple[str, str]]]:
    """List available optimizer names, optionally filtered.

    List all registered optimizers, with optional filtering using wildcard patterns.
    Optimizers can be filtered using include and exclude patterns, and can optionally
    return descriptions with each optimizer name.

    Args:
        filter: Wildcard style filter string or list of filter strings
            (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
            Adam variants and 8-bit optimizers). Empty string means no filtering.
        exclude_filters: Optional list of wildcard patterns to exclude. For example,
            ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
        with_description: If True, returns tuples of (name, description) instead of
            just names. Descriptions provide brief explanations of optimizer characteristics.

    Returns:
        If with_description is False:
            List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
        If with_description is True:
            List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])

    Examples:
        >>> list_optimizers()
        ['adam', 'adamw', 'sgd', ...]

        >>> list_optimizers(['la*', 'nla*'])  # List lamb & lars
        ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']

        >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*'])  # Exclude bnb & apex adam optimizers
        ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']

        >>> list_optimizers(with_description=True)  # Get descriptions
        [('adabelief', 'Adapts learning rate based on gradient prediction error'),
         ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
         ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
         ...]
    """
    return default_registry.list_optimizers(filter, exclude_filters, with_description)

@@ -731,7 +774,38 @@ def get_optimizer_class(
        name: str,
        bind_defaults: bool = False,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
    """Get optimizer class by name with any defaults applied.
    """Get optimizer class by name with option to bind default arguments.

    Retrieves the optimizer class or a partial function with default arguments bound.
    This allows direct instantiation of optimizers with their default configurations
    without going through the full factory.

    Args:
        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
            If False, returns the raw optimizer class.

    Returns:
        If bind_defaults is False:
            The optimizer class (e.g., torch.optim.Adam)
        If bind_defaults is True:
            A partial function with default arguments bound

    Raises:
        ValueError: If optimizer name is not found in registry

    Examples:
        >>> # Get raw optimizer class
        >>> Adam = get_optimizer_class('adam')
        >>> opt = Adam(model.parameters(), lr=1e-3)

        >>> # Get optimizer with defaults bound
        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)

        >>> # Get SGD with nesterov momentum default
        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
    """
    return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
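As a rough mental model (an approximation, not the registry code), `bind_defaults=True` behaves like wrapping the optimizer class in `functools.partial` with the per-optimizer defaults from its `OptimInfo`, e.g. `nesterov=True` for `'sgd'` as the docstring example notes:

```python
from functools import partial
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# bind_defaults=False: the raw class.
SGD = torch.optim.SGD
opt_plain = SGD(model.parameters(), lr=0.1, momentum=0.9)

# bind_defaults=True (approximation): a partial with the registry's defaults bound,
# e.g. nesterov=True for 'sgd'; the real defaults come from each OptimInfo entry.
SGDBound = partial(torch.optim.SGD, nesterov=True)
opt_bound = SGDBound(model.parameters(), lr=0.1, momentum=0.9)
```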

@@ -748,7 +822,69 @@ def create_optimizer_v2(
        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
        **kwargs: Any,
) -> optim.Optimizer:
    """Create an optimizer instance using the default registry."""
    """Create an optimizer instance via timm registry.

    Creates and configures an optimizer with appropriate parameter groups and settings.
    Supports automatic parameter group creation for weight decay and layer-wise learning
    rates, as well as custom parameter grouping.

    Args:
        model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
            If a model is provided, parameters will be automatically extracted and grouped
            based on the other arguments.
        opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
            Use list_optimizers() to see available options.
        lr: Learning rate. If None, will use the optimizer's default.
        weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
        momentum: Momentum factor for optimizers that support it. Only used if the
            chosen optimizer accepts a momentum parameter.
        foreach: Enable/disable foreach (multi-tensor) implementation if available.
            If None, will use optimizer-specific defaults.
        filter_bias_and_bn: If True, bias and norm layer parameters (all 1d params) will not have
            weight decay applied. Only used when model_or_params is a model and
            weight_decay > 0.
        layer_decay: Optional layer-wise learning rate decay factor. If provided,
            learning rates will be scaled by layer_decay^(max_depth - layer_depth).
            Only used when model_or_params is a model.
        param_group_fn: Optional function to create custom parameter groups.
            If provided, other parameter grouping options will be ignored.
        **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).

    Returns:
        Configured optimizer instance.

    Examples:
        >>> # Basic usage with a model
        >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)

        >>> # SGD with momentum and weight decay
        >>> optimizer = create_optimizer_v2(
        ...     model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
        ... )

        >>> # Adam with layer-wise learning rate decay
        >>> optimizer = create_optimizer_v2(
        ...     model, 'adam', lr=1e-3, layer_decay=0.7
        ... )

        >>> # Custom parameter groups
        >>> def group_fn(model):
        ...     return [
        ...         {'params': model.backbone.parameters(), 'lr': 1e-4},
        ...         {'params': model.head.parameters(), 'lr': 1e-3}
        ...     ]
        >>> optimizer = create_optimizer_v2(
        ...     model, 'sgd', param_group_fn=group_fn
        ... )

    Note:
        Parameter group handling precedence:
        1. If param_group_fn is provided, it will be used exclusively
        2. If layer_decay is provided, layer-wise groups will be created
        3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
        4. Otherwise, all parameters will be in a single group
    """

    return default_registry.create_optimizer(
        model_or_params,
        opt=opt,
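To close the loop on the precedence note above, here is a hedged end-to-end sketch of the `param_group_fn` path with a toy model (the `backbone`/`head` attributes are illustrative, mirroring the docstring example); per-group learning rates set by the callable take precedence over the top-level `lr`:

```python
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

class ToyNet(nn.Module):
    """Tiny stand-in model with the backbone/head split used in the docstring example."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))

def group_fn(model):
    # Custom groups: when param_group_fn is given it is used exclusively.
    return [
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.head.parameters(), 'lr': 1e-3},
    ]

model = ToyNet()
optimizer = create_optimizer_v2(model, 'sgd', lr=1e-2, momentum=0.9, param_group_fn=group_fn)

# One dummy step; the per-group lrs from group_fn override the top-level lr.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
print([g['lr'] for g in optimizer.param_groups])  # [0.0001, 0.001]
```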