Mirror of https://github.com/huggingface/pytorch-image-models.git

Try to fix documentation build, add better docstrings to public optimizer api

commit 53657a31b7 (parent ee5f6e76bb)
@@ -6,22 +6,28 @@ This page contains the API reference documentation for learning rate optimizers
 
 ### Factory functions
 
-[[autodoc]] timm.optim.optim_factory.create_optimizer
-[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
+[[autodoc]] timm.optim.create_optimizer_v2
+[[autodoc]] timm.optim.list_optimizers
+[[autodoc]] timm.optim.get_optimizer_class
 
 ### Optimizer Classes
 
 [[autodoc]] timm.optim.adabelief.AdaBelief
 [[autodoc]] timm.optim.adafactor.Adafactor
+[[autodoc]] timm.optim.adafactor_bv.AdafactorBigVision
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
+[[autodoc]] timm.optim.nadamw.NadamW
 [[autodoc]] timm.optim.nvnovograd.NvNovoGrad
 [[autodoc]] timm.optim.radam.RAdam
 [[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
 [[autodoc]] timm.optim.sgdp.SGDP
+[[autodoc]] timm.optim.sgdw.SGDW
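For reference, the new doc entries above correspond to public imports. A quick smoke-test sketch (assuming a timm version that includes this commit and re-exports the factory functions from timm.optim):

    # Smoke-test sketch for the doc entries added above; module paths are taken
    # from the [[autodoc]] directives in this hunk.
    from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
    from timm.optim.adopt import Adopt
    from timm.optim.sgdw import SGDW
    from timm.optim.adafactor_bv import AdafactorBigVision

    print(len(list_optimizers()))  # number of registered optimizers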
@@ -124,7 +124,7 @@ class OptimizerRegistry:
 
     def list_optimizers(
             self,
-            filter: str = '',
+            filter: Union[str, List[str]] = '',
             exclude_filters: Optional[List[str]] = None,
             with_description: bool = False
     ) -> List[Union[str, Tuple[str, str]]]:
@@ -141,7 +141,14 @@ class OptimizerRegistry:
         names = sorted(self._optimizers.keys())
 
         if filter:
-            names = [n for n in names if fnmatch(n, filter)]
+            if isinstance(filter, str):
+                filters = [filter]
+            else:
+                filters = filter
+            filtered_names = set()
+            for f in filters:
+                filtered_names.update(n for n in names if fnmatch(n, f))
+            names = sorted(filtered_names)
 
         if exclude_filters:
            for exclude_filter in exclude_filters:
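The filtering change above is self-contained enough to illustrate outside the registry. A minimal standalone sketch of the same wildcard logic (the helper name filter_names and the sample inputs are illustrative, not timm code):

    # Standalone sketch of the multi-pattern wildcard filtering added above:
    # a single string or a list of patterns is normalized to a list, matches from
    # every pattern are collected into a set, and the result is returned sorted.
    from fnmatch import fnmatch
    from typing import List, Union

    def filter_names(names: List[str], filter: Union[str, List[str]] = '') -> List[str]:
        if not filter:
            return sorted(names)
        filters = [filter] if isinstance(filter, str) else filter
        matched = set()
        for f in filters:
            matched.update(n for n in names if fnmatch(n, f))
        return sorted(matched)

    print(filter_names(['adam', 'adamw', 'lamb', 'lars'], ['la*', 'adamw']))
    # ['adamw', 'lamb', 'lars']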
@@ -149,6 +156,7 @@ class OptimizerRegistry:
 
         if with_description:
             return [(name, self._optimizers[name].description) for name in names]
 
         return names
 
     def get_optimizer_info(self, name: str) -> OptimInfo:
@@ -718,11 +726,46 @@ _register_default_optimizers()
 # Public API
 
 def list_optimizers(
-        filter: str = '',
+        filter: Union[str, List[str]] = '',
         exclude_filters: Optional[List[str]] = None,
         with_description: bool = False,
 ) -> List[Union[str, Tuple[str, str]]]:
     """List available optimizer names, optionally filtered.
+
+    List all registered optimizers, with optional filtering using wildcard patterns.
+    Optimizers can be filtered using include and exclude patterns, and can optionally
+    return descriptions with each optimizer name.
+
+    Args:
+        filter: Wildcard style filter string or list of filter strings
+            (e.g., 'adam*' for all Adam variants, or ['adam*', '*8bit'] for
+            Adam variants and 8-bit optimizers). Empty string means no filtering.
+        exclude_filters: Optional list of wildcard patterns to exclude. For example,
+            ['*8bit', 'fused*'] would exclude 8-bit and fused implementations.
+        with_description: If True, returns tuples of (name, description) instead of
+            just names. Descriptions provide brief explanations of optimizer characteristics.
+
+    Returns:
+        If with_description is False:
+            List of optimizer names as strings (e.g., ['adam', 'adamw', ...])
+        If with_description is True:
+            List of tuples of (name, description) (e.g., [('adam', 'Adaptive Moment...'), ...])
+
+    Examples:
+        >>> list_optimizers()
+        ['adam', 'adamw', 'sgd', ...]
+
+        >>> list_optimizers(['la*', 'nla*'])  # List lamb & lars
+        ['lamb', 'lambc', 'larc', 'lars', 'nlarc', 'nlars']
+
+        >>> list_optimizers('*adam*', exclude_filters=['bnb*', 'fused*'])  # Exclude bnb & apex adam optimizers
+        ['adam', 'adamax', 'adamp', 'adamw', 'nadam', 'nadamw', 'radam']
+
+        >>> list_optimizers(with_description=True)  # Get descriptions
+        [('adabelief', 'Adapts learning rate based on gradient prediction error'),
+        ('adadelta', 'torch.optim Adadelta, Adapts learning rates based on running windows of gradients'),
+        ('adafactor', 'Memory-efficient implementation of Adam with factored gradients'),
+        ...]
     """
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
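A small usage sketch of the with_description option documented above (the exact names and description strings depend on the installed timm version and come from the registry at runtime):

    # Print a simple name/description listing using the documented public API.
    from timm.optim import list_optimizers

    for name, description in list_optimizers('lamb*', with_description=True):
        print(f'{name:<12} {description}')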
@@ -731,7 +774,38 @@ def get_optimizer_class(
         name: str,
         bind_defaults: bool = False,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
-    """Get optimizer class by name with any defaults applied.
+    """Get optimizer class by name with option to bind default arguments.
+
+    Retrieves the optimizer class or a partial function with default arguments bound.
+    This allows direct instantiation of optimizers with their default configurations
+    without going through the full factory.
+
+    Args:
+        name: Name of the optimizer to retrieve (e.g., 'adam', 'sgd')
+        bind_defaults: If True, returns a partial function with default arguments from OptimInfo bound.
+            If False, returns the raw optimizer class.
+
+    Returns:
+        If bind_defaults is False:
+            The optimizer class (e.g., torch.optim.Adam)
+        If bind_defaults is True:
+            A partial function with default arguments bound
+
+    Raises:
+        ValueError: If optimizer name is not found in registry
+
+    Examples:
+        >>> # Get raw optimizer class
+        >>> Adam = get_optimizer_class('adam')
+        >>> opt = Adam(model.parameters(), lr=1e-3)
+
+        >>> # Get optimizer with defaults bound
+        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
+        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
+
+        >>> # Get SGD with nesterov momentum default
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
 
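Conceptually, "a partial function with default arguments bound" resembles wrapping the optimizer class in functools.partial. A hedged sketch of that idea only, not the actual registry implementation (sgd_defaults here is an illustrative stand-in for whatever defaults an OptimInfo carries):

    # Illustrative sketch of what bind_defaults=True means: registry defaults are
    # pre-applied so the returned callable only needs parameters plus overrides.
    from functools import partial
    import torch

    sgd_defaults = {'nesterov': True}  # stand-in for OptimInfo defaults
    SGDWithDefaults = partial(torch.optim.SGD, **sgd_defaults)

    # Later: SGDWithDefaults(model.parameters(), lr=0.1, momentum=0.9)
    # behaves like torch.optim.SGD(..., momentum=0.9, nesterov=True).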
@@ -748,7 +822,69 @@ def create_optimizer_v2(
         param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
         **kwargs: Any,
 ) -> optim.Optimizer:
-    """Create an optimizer instance using the default registry."""
+    """Create an optimizer instance via timm registry.
+
+    Creates and configures an optimizer with appropriate parameter groups and settings.
+    Supports automatic parameter group creation for weight decay and layer-wise learning
+    rates, as well as custom parameter grouping.
+
+    Args:
+        model_or_params: A PyTorch model or an iterable of parameters/parameter groups.
+            If a model is provided, parameters will be automatically extracted and grouped
+            based on the other arguments.
+        opt: Name of the optimizer to create (e.g., 'adam', 'adamw', 'sgd').
+            Use list_optimizers() to see available options.
+        lr: Learning rate. If None, will use the optimizer's default.
+        weight_decay: Weight decay factor. Will be used to create param groups if model_or_params is a model.
+        momentum: Momentum factor for optimizers that support it. Only used if the
+            chosen optimizer accepts a momentum parameter.
+        foreach: Enable/disable foreach (multi-tensor) implementation if available.
+            If None, will use optimizer-specific defaults.
+        filter_bias_and_bn: If True, bias, norm layer parameters (all 1d params) will not have
+            weight decay applied. Only used when model_or_params is a model and
+            weight_decay > 0.
+        layer_decay: Optional layer-wise learning rate decay factor. If provided,
+            learning rates will be scaled by layer_decay^(max_depth - layer_depth).
+            Only used when model_or_params is a model.
+        param_group_fn: Optional function to create custom parameter groups.
+            If provided, other parameter grouping options will be ignored.
+        **kwargs: Additional optimizer-specific arguments (e.g., betas for Adam).
+
+    Returns:
+        Configured optimizer instance.
+
+    Examples:
+        >>> # Basic usage with a model
+        >>> optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3)
+
+        >>> # SGD with momentum and weight decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', lr=0.1, momentum=0.9, weight_decay=1e-4
+        ... )
+
+        >>> # Adam with layer-wise learning rate decay
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'adam', lr=1e-3, layer_decay=0.7
+        ... )
+
+        >>> # Custom parameter groups
+        >>> def group_fn(model):
+        ...     return [
+        ...         {'params': model.backbone.parameters(), 'lr': 1e-4},
+        ...         {'params': model.head.parameters(), 'lr': 1e-3}
+        ...     ]
+        >>> optimizer = create_optimizer_v2(
+        ...     model, 'sgd', param_group_fn=group_fn
+        ... )
+
+    Note:
+        Parameter group handling precedence:
+        1. If param_group_fn is provided, it will be used exclusively
+        2. If layer_decay is provided, layer-wise groups will be created
+        3. If weight_decay > 0 and filter_bias_and_bn is True, weight decay groups will be created
+        4. Otherwise, all parameters will be in a single group
+    """
+
     return default_registry.create_optimizer(
         model_or_params,
         opt=opt,
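Taken together, the public API documented in this commit can be exercised end to end. A short sketch (model choice and hyperparameters are illustrative; assumes timm is installed and the listed optimizer names exist in the local registry):

    # End-to-end sketch of list_optimizers, get_optimizer_class, and create_optimizer_v2.
    import timm
    from timm.optim import create_optimizer_v2, get_optimizer_class, list_optimizers

    model = timm.create_model('resnet18', pretrained=False)

    print(list_optimizers('adam*'))                    # wildcard-filtered optimizer names
    AdamW = get_optimizer_class('adamw')               # raw optimizer class
    optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05)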