Small optim factory tweak. default bind_defaults=True for get_optimizer_class

This commit is contained in:
Ross Wightman 2024-11-13 10:44:53 -08:00
parent ef062eefe3
commit d0161f303a

View File

@ -345,7 +345,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
OptimInfo( OptimInfo(
name='sgd', name='sgd',
opt_class=optim.SGD, opt_class=optim.SGD,
description='Stochastic Gradient Descent with Nesterov momentum (default)', description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
has_eps=False, has_eps=False,
has_momentum=True, has_momentum=True,
defaults={'nesterov': True} defaults={'nesterov': True}
@ -353,7 +353,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
OptimInfo( OptimInfo(
name='momentum', name='momentum',
opt_class=optim.SGD, opt_class=optim.SGD,
description='Stochastic Gradient Descent with classical momentum', description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
has_eps=False, has_eps=False,
has_momentum=True, has_momentum=True,
defaults={'nesterov': False} defaults={'nesterov': False}
@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
def get_optimizer_class( def get_optimizer_class(
name: str, name: str,
bind_defaults: bool = False, bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]: ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
"""Get optimizer class by name with option to bind default arguments. """Get optimizer class by name with option to bind default arguments.
@ -821,17 +821,14 @@ def get_optimizer_class(
ValueError: If optimizer name is not found in registry ValueError: If optimizer name is not found in registry
Examples: Examples:
>>> # Get raw optimizer class
>>> Adam = get_optimizer_class('adam')
>>> opt = Adam(model.parameters(), lr=1e-3)
>>> # Get optimizer with defaults bound
>>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
>>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
>>> # Get SGD with nesterov momentum default >>> # Get SGD with nesterov momentum default
>>> SGD = get_optimizer_class('sgd', bind_defaults=True) # nesterov=True bound >>> SGD = get_optimizer_class('sgd') # nesterov=True bound
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9) >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> # Get raw optimizer class
>>> SGD = get_optimizer_class('sgd')
>>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
""" """
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults) return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)