diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 9e06e1a5..37b5fdc0 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -345,7 +345,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='sgd',
             opt_class=optim.SGD,
-            description='Stochastic Gradient Descent with Nesterov momentum (default)',
+            description='torch.optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': True}
@@ -353,7 +353,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='momentum',
             opt_class=optim.SGD,
-            description='Stochastic Gradient Descent with classical momentum',
+            description='torch.optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': False}
@@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 
 def get_optimizer_class(
         name: str,
-        bind_defaults: bool = False,
+        bind_defaults: bool = True,
 ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.
 
@@ -821,17 +821,14 @@ def get_optimizer_class(
         ValueError: If optimizer name is not found in registry
 
     Examples:
-        >>> # Get raw optimizer class
-        >>> Adam = get_optimizer_class('adam')
-        >>> opt = Adam(model.parameters(), lr=1e-3)
-
-        >>> # Get optimizer with defaults bound
-        >>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
-        >>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
-        >>> # Get SGD with nesterov momentum default
-        >>> SGD = get_optimizer_class('sgd', bind_defaults=True)  # nesterov=True bound
+        >>> SGD = get_optimizer_class('sgd')  # nesterov=True bound
         >>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
+
+        >>> # Get raw optimizer class
+        >>> SGD = get_optimizer_class('sgd', bind_defaults=False)
+        >>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
+
     """
     return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
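
A minimal sketch of what the flipped `bind_defaults` default means for callers, assuming a timm build that already contains this patch; the import uses the private module path touched by the diff, and the toy `torch.nn.Linear` model plus the printed checks are illustrative assumptions only:

```python
import torch
from timm.optim._optim_factory import get_optimizer_class

model = torch.nn.Linear(8, 2)  # toy model standing in for a real network

# New default bind_defaults=True: 'sgd' comes back with the registry's
# OptimInfo defaults bound, i.e. nesterov=True from defaults={'nesterov': True}.
SGD = get_optimizer_class('sgd')
opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
print(opt.defaults['nesterov'])  # True -- bound from the registry defaults

# Previous behaviour on request: bind_defaults=False returns the raw
# torch.optim.SGD class, so nesterov falls back to torch's own default.
RawSGD = get_optimizer_class('sgd', bind_defaults=False)
opt = RawSGD(model.parameters(), lr=0.1, momentum=0.9)
print(opt.defaults['nesterov'])  # False -- nothing bound
```

Callers that relied on the old default and want the unbound optimizer class now need to pass `bind_defaults=False` explicitly.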