mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Small optim factory tweak. default bind_defaults=True for get_optimizer_class
This commit is contained in:
parent
ef062eefe3
commit
d0161f303a
@ -345,7 +345,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
|
|||||||
OptimInfo(
|
OptimInfo(
|
||||||
name='sgd',
|
name='sgd',
|
||||||
opt_class=optim.SGD,
|
opt_class=optim.SGD,
|
||||||
description='Stochastic Gradient Descent with Nesterov momentum (default)',
|
description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
|
||||||
has_eps=False,
|
has_eps=False,
|
||||||
has_momentum=True,
|
has_momentum=True,
|
||||||
defaults={'nesterov': True}
|
defaults={'nesterov': True}
|
||||||
@ -353,7 +353,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
|
|||||||
OptimInfo(
|
OptimInfo(
|
||||||
name='momentum',
|
name='momentum',
|
||||||
opt_class=optim.SGD,
|
opt_class=optim.SGD,
|
||||||
description='Stochastic Gradient Descent with classical momentum',
|
description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
|
||||||
has_eps=False,
|
has_eps=False,
|
||||||
has_momentum=True,
|
has_momentum=True,
|
||||||
defaults={'nesterov': False}
|
defaults={'nesterov': False}
|
||||||
@ -798,7 +798,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
|
|||||||
|
|
||||||
def get_optimizer_class(
|
def get_optimizer_class(
|
||||||
name: str,
|
name: str,
|
||||||
bind_defaults: bool = False,
|
bind_defaults: bool = True,
|
||||||
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
|
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
|
||||||
"""Get optimizer class by name with option to bind default arguments.
|
"""Get optimizer class by name with option to bind default arguments.
|
||||||
|
|
||||||
@ -821,17 +821,14 @@ def get_optimizer_class(
|
|||||||
ValueError: If optimizer name is not found in registry
|
ValueError: If optimizer name is not found in registry
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # Get raw optimizer class
|
|
||||||
>>> Adam = get_optimizer_class('adam')
|
|
||||||
>>> opt = Adam(model.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
>>> # Get optimizer with defaults bound
|
|
||||||
>>> AdamWithDefaults = get_optimizer_class('adam', bind_defaults=True)
|
|
||||||
>>> opt = AdamWithDefaults(model.parameters(), lr=1e-3)
|
|
||||||
|
|
||||||
>>> # Get SGD with nesterov momentum default
|
>>> # Get SGD with nesterov momentum default
|
||||||
>>> SGD = get_optimizer_class('sgd', bind_defaults=True) # nesterov=True bound
|
>>> SGD = get_optimizer_class('sgd') # nesterov=True bound
|
||||||
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
|
>>> opt = SGD(model.parameters(), lr=0.1, momentum=0.9)
|
||||||
|
|
||||||
|
>>> # Get raw optimizer class
|
||||||
|
>>> SGD = get_optimizer_class('sgd')
|
||||||
|
>>> opt = SGD(model.parameters(), lr=1e-3, momentum=0.9)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
|
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user