Small optim factory tweak. default bind_defaults=True for get_optimizer_class

Ross Wightman 2024-11-13 10:44:53 -08:00
parent ef062eefe3
commit d0161f303a


@@ -345,7 +345,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='sgd',
             opt_class=optim.SGD,
-            description='Stochastic Gradient Descent with Nesterov momentum (default)',
+            description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': True}
@@ -353,7 +353,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='momentum',
             opt_class=optim.SGD,
-            description='Stochastic Gradient Descent with classical momentum',
+            description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': False}
@@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
         name: str,
-        bind_defaults: bool = False,
+        bind_defaults: bool = True,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.
@@ -821,17 +821,14 @@ def get_optimizer_class(
         ValueError: If optimizer name is not found in registry
     Examples:
-        >>> # Get raw optimizer class
-        >>> Adam = get_optimizer_class('adam')
-        >>> opt = Adam(model.parameters(), lr=1e-3)
-        >>> # Get optimizer with defaults bound
-        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
-        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
         >>> # Get SGD with nesterov momentum default
-        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> SGD = get_optimizer_class('sgd')  # nesterov=True bound
         >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> # Get raw optimizer class
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=False)
+        >>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
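
Below is a minimal usage sketch of the new default behavior (not part of the commit). It assumes a timm build that includes this change and that get_optimizer_class is importable from timm.optim; with bind_defaults=True the 'sgd' entry comes back as a partial with the registry default nesterov=True already bound, so momentum must be non-zero.

import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(10, 2)

# New default: bind_defaults=True, so 'sgd' returns a functools.partial of
# torch.optim.SGD with the registry default nesterov=True pre-bound.
SGD = get_optimizer_class('sgd')
opt = SGD(model.parameters(), lr=0.1, momentum=0.9)  # momentum > 0 required with nesterov

# Old behavior: pass bind_defaults=False to get the raw torch.optim.SGD class.
RawSGD = get_optimizer_class('sgd', bind_defaults=False)
opt_raw = RawSGD(model.parameters(), lr=0.1, momentum=0.9)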