A bit of an optimizer overhaul: added an improved factory, list_optimizers and get_optimizer_class helpers, and OptimInfo classes with descriptions and arg configs

Ross Wightman 2024-11-12 16:13:17 -08:00 committed by Ross Wightman
parent c1cf8c52b9
commit ee5f6e76bb
19 changed files with 1127 additions and 770 deletions
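For orientation, a minimal usage sketch of the new factory surface (list_optimizers, get_optimizer_class, create_optimizer_v2), assuming the signatures and import paths introduced in the diffs below:

import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class, list_optimizers

# List everything the registry knows about, excluding entries that need apex / bitsandbytes installed.
for name, desc in list_optimizers(exclude_filters=('fused*', 'bnb*'), with_description=True):
    print(f'{name}: {desc}')

# Resolve a registered optimizer class directly (pass bind_defaults=True to bake in registry defaults).
adamw_cls = get_optimizer_class('adamw')

# Full factory path: builds weight-decay aware parameter groups from an nn.Module.
model = nn.Linear(16, 8)
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)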


@@ -15,7 +15,7 @@ from torch.nn import Parameter
from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
import importlib
import os
@@ -293,10 +293,11 @@ def _build_params_dict_single(weight, bias, **kwargs):
return [dict(params=bias, **kwargs)]
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
def test_optim_factory(optimizer):
get_optimizer_class(optimizer)
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
@@ -316,6 +317,12 @@ def test_sgd(optimizer):
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
@pytest.mark.parametrize('optimizer', ['sgd'])
def test_sgd(optimizer):
# _test_basic_cases(
#     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
#     [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
@@ -358,21 +365,6 @@ def test_sgd(optimizer):
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw'])
def test_adam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
@@ -381,21 +373,6 @@ def test_adam(optimizer):
@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
# FIXME rosenbrock is not passing for ADOPT
# _test_rosenbrock(
#     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
@@ -405,25 +382,6 @@ def test_adopt(optimizer):
@pytest.mark.parametrize('optimizer', ['adabelief'])
def test_adabelief(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
@@ -435,21 +393,6 @@ def test_adabelief(optimizer):
@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
def test_rectified(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -458,25 +401,6 @@ def test_rectified(optimizer):
@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
def test_adaother(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
@@ -488,24 +412,6 @@ def test_adaother(optimizer):
@pytest.mark.parametrize('optimizer', ['adafactor', 'adafactorbv'])
def test_adafactor(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
)
@@ -517,25 +423,6 @@ def test_adafactor(optimizer):
@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
def test_lamb(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -544,25 +431,6 @@ def test_lamb(optimizer):
@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
def test_lars(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -571,25 +439,6 @@ def test_lars(optimizer):
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
def test_madgrad(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
)
@@ -598,25 +447,6 @@ def test_madgrad(optimizer):
@pytest.mark.parametrize('optimizer', ['novograd'])
def test_novograd(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -625,25 +455,6 @@ def test_novograd(optimizer):
@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
def test_rmsprop(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
)
@@ -652,25 +463,6 @@ def test_rmsprop(optimizer):
@pytest.mark.parametrize('optimizer', ['adamp'])
def test_adamp(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
@@ -679,25 +471,6 @@ def test_adamp(optimizer):
@pytest.mark.parametrize('optimizer', ['sgdp'])
def test_sgdp(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -706,25 +479,6 @@ def test_sgdp(optimizer):
@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
def test_lookahead_sgd(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
@@ -732,25 +486,6 @@ def test_lookahead_sgd(optimizer):
@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
def test_lookahead_adam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
)
@@ -758,25 +493,6 @@ def test_lookahead_adam(optimizer):
@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
def test_lookahead_radam(optimizer):
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
)
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
)


@@ -17,4 +17,5 @@ from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry


@@ -0,0 +1,798 @@
""" Optimizer Factory w/ custom Weight Decay & Layer Decay support
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
from fnmatch import fnmatch
import importlib
import torch
import torch.nn as nn
import torch.optim as optim
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
_logger = logging.getLogger(__name__)
# Type variables
T = TypeVar('T')
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
OptimType = TypeVar('OptimType', bound='optim.Optimizer')
def _import_class(class_string: str) -> Type:
"""Dynamically import a class from a string."""
try:
module_name, class_name = class_string.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
raise ImportError(f"Could not import {class_string}: {e}")
class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""
def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
@dataclass(frozen=True)
class OptimInfo:
"""Immutable configuration for an optimizer.
Attributes:
name: Unique identifier for the optimizer
opt_class: The optimizer class
description: Brief description of the optimizer's characteristics and behavior
has_eps: Whether the optimizer accepts epsilon parameter
has_momentum: Whether the optimizer accepts momentum parameter
has_betas: Whether the optimizer accepts a tuple of beta parameters
num_betas: number of betas in tuple (valid IFF has_betas = True)
defaults: Optional default parameters for the optimizer
"""
name: str
opt_class: Union[str, Type[optim.Optimizer]]
description: str = ''
has_eps: bool = True
has_momentum: bool = False
has_betas: bool = False
num_betas: int = 2
defaults: Optional[Dict[str, Any]] = None
class OptimizerRegistry:
"""Registry managing optimizer configurations and instantiation.
This class provides a central registry for optimizer configurations and handles
their instantiation with appropriate parameter groups and settings.
"""
def __init__(self) -> None:
self._optimizers: Dict[str, OptimInfo] = {}
self._foreach_defaults: Set[str] = {'lion'}
def register(self, info: OptimInfo) -> None:
"""Register an optimizer configuration.
Args:
info: The OptimInfo configuration containing name, type and description
"""
name = info.name.lower()
if name in self._optimizers:
_logger.warning(f'Optimizer {name} already registered, overwriting')
self._optimizers[name] = info
def register_alias(self, alias: str, target: str) -> None:
"""Register an alias for an existing optimizer.
Args:
alias: The alias name
target: The target optimizer name
Raises:
KeyError: If target optimizer doesn't exist
"""
target = target.lower()
if target not in self._optimizers:
raise KeyError(f'Cannot create alias for non-existent optimizer {target}')
self._optimizers[alias.lower()] = self._optimizers[target]
def register_foreach_default(self, name: str) -> None:
"""Register an optimizer as defaulting to foreach=True."""
self._foreach_defaults.add(name.lower())
def list_optimizers(
self,
filter: str = '',
exclude_filters: Optional[List[str]] = None,
with_description: bool = False
) -> List[Union[str, Tuple[str, str]]]:
"""List available optimizer names, optionally filtered.
Args:
filter: Wildcard style filter string (e.g., 'adam*')
exclude_filters: Optional list of wildcard patterns to exclude
with_description: If True, return tuples of (name, description)
Returns:
List of either optimizer names or (name, description) tuples
"""
names = sorted(self._optimizers.keys())
if filter:
names = [n for n in names if fnmatch(n, filter)]
if exclude_filters:
for exclude_filter in exclude_filters:
names = [n for n in names if not fnmatch(n, exclude_filter)]
if with_description:
return [(name, self._optimizers[name].description) for name in names]
return names
def get_optimizer_info(self, name: str) -> OptimInfo:
"""Get the OptimInfo for an optimizer.
Args:
name: Name of the optimizer
Returns:
OptimInfo configuration
Raises:
ValueError: If optimizer is not found
"""
name = name.lower()
if name not in self._optimizers:
raise ValueError(f'Optimizer {name} not found in registry')
return self._optimizers[name]
def get_optimizer_class(
self,
name_or_info: Union[str, OptimInfo],
bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
"""Get the optimizer class with any default arguments applied.
This allows direct instantiation of optimizers with their default configs
without going through the full factory.
Args:
name_or_info: Name of the optimizer
bind_defaults: Bind default arguments to optimizer class via `partial` before returning
Returns:
Optimizer class or partial with defaults applied
Raises:
ValueError: If optimizer not found
"""
if isinstance(name_or_info, str):
opt_info = self.get_optimizer_info(name_or_info)
else:
assert isinstance(name_or_info, OptimInfo)
opt_info = name_or_info
if isinstance(opt_info.opt_class, str):
# Special handling for APEX and BNB optimizers
if opt_info.opt_class.startswith('apex.'):
assert torch.cuda.is_available(), 'CUDA required for APEX optimizers'
try:
opt_class = _import_class(opt_info.opt_class)
except ImportError as e:
raise ImportError('APEX optimizers require apex to be installed') from e
elif opt_info.opt_class.startswith('bitsandbytes.'):
assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers'
try:
opt_class = _import_class(opt_info.opt_class)
except ImportError as e:
raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e
else:
opt_class = _import_class(opt_info.opt_class)
else:
opt_class = opt_info.opt_class
# Return class or partial with defaults
if bind_defaults and opt_info.defaults:
opt_class = partial(opt_class, **opt_info.defaults)
return opt_class
def create_optimizer(
self,
model_or_params: Union[nn.Module, Params],
opt: str,
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
**kwargs: Any,
) -> optim.Optimizer:
"""Create an optimizer instance.
Args:
model_or_params: Model or parameters to optimize
opt: Name of optimizer to create
lr: Learning rate
weight_decay: Weight decay factor
momentum: Momentum factor for applicable optimizers
foreach: Enable/disable foreach operation
weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
layer_decay: Layer-wise learning rate decay
param_group_fn: Optional custom parameter grouping function
**kwargs: Additional optimizer-specific arguments
Returns:
Configured optimizer instance
Raises:
ValueError: If optimizer not found or configuration invalid
"""
# Get parameters to optimize
if isinstance(model_or_params, nn.Module):
# Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied
no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())()
if param_group_fn:
# run custom fn to generate param groups from nn.Module
parameters = param_group_fn(model_or_params)
elif layer_decay is not None:
parameters = param_groups_layer_decay(
model_or_params,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
weight_decay_exclude_1d=weight_decay_exclude_1d,
)
weight_decay = 0.
elif weight_decay and weight_decay_exclude_1d:
parameters = param_groups_weight_decay(
model_or_params,
weight_decay=weight_decay,
no_weight_decay_list=no_weight_decay,
)
weight_decay = 0.
else:
parameters = model_or_params.parameters()
else:
# pass parameters / parameter groups through to optimizer
parameters = model_or_params
# Parse optimizer name
opt_split = opt.lower().split('_')
opt_name = opt_split[-1]
use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False
opt_info = self.get_optimizer_info(opt_name)
# Build optimizer arguments
opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs}
# Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None.
if lr is not None:
opt_args['lr'] = lr
# Apply optimizer-specific settings
if opt_info.defaults:
for k, v in opt_info.defaults.items():
opt_args.setdefault(k, v)
# timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat.
if opt_info.has_momentum:
opt_args.setdefault('momentum', momentum)
# Remove commonly used kwargs that aren't always supported
if not opt_info.has_eps:
opt_args.pop('eps', None)
if not opt_info.has_betas:
opt_args.pop('betas', None)
if foreach is not None:
# Explicitly activate or deactivate multi-tensor foreach impl.
# Not all optimizers support this, and those that do usually pick the multi-tensor
# impl automatically when foreach is left at its default of None.
opt_args.setdefault('foreach', foreach)
# Create optimizer
opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
optimizer = opt_class(parameters, **opt_args)
# Apply Lookahead if requested
if use_lookahead:
optimizer = Lookahead(optimizer)
return optimizer
def _register_sgd_variants(registry: OptimizerRegistry) -> None:
"""Register SGD-based optimizers"""
sgd_optimizers = [
OptimInfo(
name='sgd',
opt_class=optim.SGD,
description='Stochastic Gradient Descent with Nesterov momentum (default)',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='momentum',
opt_class=optim.SGD,
description='Stochastic Gradient Descent with classical momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': False}
),
OptimInfo(
name='sgdp',
opt_class=SGDP,
description='SGD with built-in projection to unit norm sphere',
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='sgdw',
opt_class=SGDW,
description='SGD with decoupled weight decay and Nesterov momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
]
for opt in sgd_optimizers:
registry.register(opt)
def _register_adam_variants(registry: OptimizerRegistry) -> None:
"""Register Adam-based optimizers"""
adam_optimizers = [
OptimInfo(
name='adam',
opt_class=optim.Adam,
description='torch.optim Adam (Adaptive Moment Estimation)',
has_betas=True
),
OptimInfo(
name='adamw',
opt_class=optim.AdamW,
description='torch.optim Adam with decoupled weight decay regularization',
has_betas=True
),
OptimInfo(
name='adamp',
opt_class=AdamP,
description='Adam with built-in projection to unit norm sphere',
has_betas=True,
defaults={'wd_ratio': 0.01, 'nesterov': True}
),
OptimInfo(
name='nadam',
opt_class=Nadam,
description='Adam with Nesterov momentum',
has_betas=True
),
OptimInfo(
name='nadamw',
opt_class=NAdamW,
description='Adam with Nesterov momentum and decoupled weight decay',
has_betas=True
),
OptimInfo(
name='radam',
opt_class=RAdam,
description='Rectified Adam with variance adaptation',
has_betas=True
),
OptimInfo(
name='adamax',
opt_class=optim.Adamax,
description='torch.optim Adamax, Adam with infinity norm for more stable updates',
has_betas=True
),
OptimInfo(
name='adafactor',
opt_class=Adafactor,
description='Memory-efficient implementation of Adam with factored gradients',
),
OptimInfo(
name='adafactorbv',
opt_class=AdafactorBigVision,
description='Big Vision variant of Adafactor with factored gradients, half precision momentum.',
),
OptimInfo(
name='adopt',
opt_class=Adopt,
description='Modified Adam that can converge with any beta2 with the optimal rate',
),
OptimInfo(
name='adoptw',
opt_class=Adopt,
description='Modified AdamW (decoupled decay) that can converge with any beta2 with the optimal rate',
defaults={'decoupled': True}
),
]
for opt in adam_optimizers:
registry.register(opt)
def _register_lamb_lars(registry: OptimizerRegistry) -> None:
"""Register LAMB and LARS variants"""
lamb_lars_optimizers = [
OptimInfo(
name='lamb',
opt_class=Lamb,
description='Layer-wise Adaptive Moments for batch optimization',
has_betas=True
),
OptimInfo(
name='lambc',
opt_class=Lamb,
description='LAMB with trust ratio clipping for stability',
has_betas=True,
defaults={'trust_clip': True}
),
OptimInfo(
name='lars',
opt_class=Lars,
description='Layer-wise Adaptive Rate Scaling',
has_momentum=True
),
OptimInfo(
name='larc',
opt_class=Lars,
description='LARS with trust ratio clipping for stability',
has_momentum=True,
defaults={'trust_clip': True}
),
OptimInfo(
name='nlars',
opt_class=Lars,
description='LARS with Nesterov momentum',
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='nlarc',
opt_class=Lars,
description='LARS with Nesterov momentum & trust ratio clipping',
has_momentum=True,
defaults={'nesterov': True, 'trust_clip': True}
),
]
for opt in lamb_lars_optimizers:
registry.register(opt)
def _register_other_optimizers(registry: OptimizerRegistry) -> None:
"""Register miscellaneous optimizers"""
other_optimizers = [
OptimInfo(
name='adabelief',
opt_class=AdaBelief,
description='Adapts learning rate based on gradient prediction error',
has_betas=True,
defaults={'rectify': False}
),
OptimInfo(
name='radabelief',
opt_class=AdaBelief,
description='Rectified AdaBelief with variance adaptation',
has_betas=True,
defaults={'rectify': True}
),
OptimInfo(
name='adadelta',
opt_class=optim.Adadelta,
description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
),
OptimInfo(
name='adagrad',
opt_class=optim.Adagrad,
description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
defaults={'eps': 1e-8}
),
OptimInfo(
name='adan',
opt_class=Adan,
description='Adaptive Nesterov Momentum Algorithm',
defaults={'no_prox': False},
has_betas=True,
num_betas=3
),
OptimInfo(
name='adanw',
opt_class=Adan,
description='Adaptive Nesterov Momentum with decoupled weight decay',
defaults={'no_prox': True},
has_betas=True,
num_betas=3
),
OptimInfo(
name='lion',
opt_class=Lion,
description='Evolved Sign Momentum optimizer for improved convergence',
has_eps=False,
has_betas=True
),
OptimInfo(
name='madgrad',
opt_class=MADGRAD,
description='Momentum-based Adaptive gradient method',
has_momentum=True
),
OptimInfo(
name='madgradw',
opt_class=MADGRAD,
description='MADGRAD with decoupled weight decay',
has_momentum=True,
defaults={'decoupled_decay': True}
),
OptimInfo(
name='novograd',
opt_class=NvNovoGrad,
description='Normalized Adam with L2 norm gradient normalization',
has_betas=True
),
OptimInfo(
name='rmsprop',
opt_class=optim.RMSprop,
description='torch.optim RMSprop, Root Mean Square Propagation',
has_momentum=True,
defaults={'alpha': 0.9}
),
OptimInfo(
name='rmsproptf',
opt_class=RMSpropTF,
description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation',
has_momentum=True,
defaults={'alpha': 0.9}
),
]
for opt in other_optimizers:
registry.register(opt)
registry.register_foreach_default('lion')
def _register_apex_optimizers(registry: OptimizerRegistry) -> None:
"""Register APEX optimizers (lazy import)"""
apex_optimizers = [
OptimInfo(
name='fusedsgd',
opt_class='apex.optimizers.FusedSGD',
description='NVIDIA APEX fused SGD implementation for faster training',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='fusedadam',
opt_class='apex.optimizers.FusedAdam',
description='NVIDIA APEX fused Adam implementation',
has_betas=True,
defaults={'adam_w_mode': False}
),
OptimInfo(
name='fusedadamw',
opt_class='apex.optimizers.FusedAdam',
description='NVIDIA APEX fused AdamW implementation',
has_betas=True,
defaults={'adam_w_mode': True}
),
OptimInfo(
name='fusedlamb',
opt_class='apex.optimizers.FusedLAMB',
description='NVIDIA APEX fused LAMB implementation',
has_betas=True
),
OptimInfo(
name='fusednovograd',
opt_class='apex.optimizers.FusedNovoGrad',
description='NVIDIA APEX fused NovoGrad implementation',
has_betas=True,
defaults={'betas': (0.95, 0.98)}
),
]
for opt in apex_optimizers:
registry.register(opt)
def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
"""Register bitsandbytes optimizers (lazy import)"""
bnb_optimizers = [
OptimInfo(
name='bnbsgd',
opt_class='bitsandbytes.optim.SGD',
description='bitsandbytes SGD',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='bnbsgd8bit',
opt_class='bitsandbytes.optim.SGD8bit',
description='bitsandbytes 8-bit SGD with dynamic quantization',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='bnbadam',
opt_class='bitsandbytes.optim.Adam',
description='bitsandbytes Adam',
has_betas=True
),
OptimInfo(
name='bnbadam8bit',
opt_class='bitsandbytes.optim.Adam8bit',
description='bitsandbytes 8-bit Adam with dynamic quantization',
has_betas=True
),
OptimInfo(
name='bnbadamw',
opt_class='bitsandbytes.optim.AdamW',
description='bitsandbytes AdamW',
has_betas=True
),
OptimInfo(
name='bnbadamw8bit',
opt_class='bitsandbytes.optim.AdamW8bit',
description='bitsandbytes 8-bit AdamW with dynamic quantization',
has_betas=True
),
OptimInfo(
name='bnblion',
opt_class='bitsandbytes.optim.Lion',
description='bitsandbytes Lion',
has_betas=True
),
OptimInfo(
name='bnblion8bit',
opt_class='bitsandbytes.optim.Lion8bit',
description='bitsandbytes 8-bit Lion with dynamic quantization',
has_betas=True
),
OptimInfo(
name='bnbademamix',
opt_class='bitsandbytes.optim.AdEMAMix',
description='bitsandbytes AdEMAMix',
has_betas=True,
num_betas=3,
),
OptimInfo(
name='bnbademamix8bit',
opt_class='bitsandbytes.optim.AdEMAMix8bit',
description='bitsandbytes 8-bit AdEMAMix with dynamic quantization',
has_betas=True,
num_betas=3,
),
]
for opt in bnb_optimizers:
registry.register(opt)
default_registry = OptimizerRegistry()
def _register_default_optimizers() -> None:
"""Register all default optimizers to the global registry."""
# Register all optimizer groups
_register_sgd_variants(default_registry)
_register_adam_variants(default_registry)
_register_lamb_lars(default_registry)
_register_other_optimizers(default_registry)
_register_apex_optimizers(default_registry)
_register_bnb_optimizers(default_registry)
# Register aliases
default_registry.register_alias('nesterov', 'sgd')
default_registry.register_alias('nesterovw', 'sgdw')
# Initialize default registry
_register_default_optimizers()
# Public API
def list_optimizers(
filter: str = '',
exclude_filters: Optional[List[str]] = None,
with_description: bool = False,
) -> List[Union[str, Tuple[str, str]]]:
"""List available optimizer names, optionally filtered.
"""
return default_registry.list_optimizers(filter, exclude_filters, with_description)
def get_optimizer_class(
name: str,
bind_defaults: bool = False,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
"""Get optimizer class by name with any defaults applied.
"""
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
def create_optimizer_v2(
model_or_params: Union[nn.Module, Params],
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
**kwargs: Any,
) -> optim.Optimizer:
"""Create an optimizer instance using the default registry."""
return default_registry.create_optimizer(
model_or_params,
opt=opt,
lr=lr,
weight_decay=weight_decay,
momentum=momentum,
foreach=foreach,
weight_decay_exclude_1d=filter_bias_and_bn,
layer_decay=layer_decay,
param_group_fn=param_group_fn,
**kwargs
)
def optimizer_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum,
)
if getattr(cfg, 'opt_eps', None) is not None:
kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None:
kwargs['betas'] = cfg.opt_betas
if getattr(cfg, 'layer_decay', None) is not None:
kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args)
if getattr(cfg, 'opt_foreach', None) is not None:
kwargs['foreach'] = cfg.opt_foreach
return kwargs
def create_optimizer(args, model, filter_bias_and_bn=True):
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""
return create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
filter_bias_and_bn=filter_bias_and_bn,
)
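A short sketch of how the registry pieces above compose. It deliberately builds a standalone OptimizerRegistry rather than touching the module-level default_registry, and the 'plainsgd' / 'vanilla' names are made up for the example:

import torch.nn as nn
import torch.optim as optim
from timm.optim import OptimInfo, OptimizerRegistry

registry = OptimizerRegistry()
registry.register(OptimInfo(
    name='plainsgd',
    opt_class=optim.SGD,
    description='torch.optim SGD without Nesterov momentum',
    has_eps=False,
    has_momentum=True,
))
registry.register_alias('vanilla', 'plainsgd')

# The registry handles param-group construction (weight-decay exclusion of 1d params) and kwargs.
model = nn.Linear(16, 8)
opt = registry.create_optimizer(model, opt='vanilla', lr=0.1, momentum=0.9, weight_decay=1e-4)
print(registry.list_optimizers(with_description=True))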


@@ -0,0 +1,129 @@
import logging
from itertools import islice
from typing import Collection, Optional, Tuple
from torch import nn as nn
from timm.models import group_parameters
_logger = logging.getLogger(__name__)
def param_groups_weight_decay(
model: nn.Module,
weight_decay: float = 1e-5,
no_weight_decay_list: Collection[str] = (),
):
no_weight_decay_list = set(no_weight_decay_list)
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
names_trunk = []
names_head = []
for n, _ in model.named_parameters():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Collection[str] = (),
weight_decay_exclude_1d: bool = True,
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
verbose: bool = False,
):
"""
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, 'group_matcher'):
# FIXME interface needs more work
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# no decay: all 1D parameters and model specific ones
if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.
else:
g_decay = "decay"
this_decay = weight_decay
layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
if verbose:
import json
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values())
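A quick illustration of the weight-decay grouping helper above (a sketch; it assumes this file lands at timm/optim/_param_groups.py, matching the factory's relative import):

import torch.nn as nn
import torch.optim as optim
from timm.optim._param_groups import param_groups_weight_decay

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 8))
groups = param_groups_weight_decay(model, weight_decay=0.05)
# groups[0] holds biases / 1d params with weight_decay=0., groups[1] the remaining weights with 0.05.
opt = optim.AdamW(groups, lr=1e-3)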


@@ -40,9 +40,18 @@ class AdaBelief(Optimizer):
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-16,
weight_decay=0,
amsgrad=False,
decoupled_decay=True,
fixed_decay=False,
rectify=True,
degenerated_to_sgd=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -58,9 +67,17 @@ class AdaBelief(Optimizer):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
degenerated_to_sgd=degenerated_to_sgd,
decoupled_decay=decoupled_decay,
rectify=rectify,
fixed_decay=fixed_decay,
buffer=[[None, None, None] for _ in range(10)]
)
super(AdaBelief, self).__init__(params, defaults)
def __setstate__(self, state):


@@ -16,6 +16,7 @@ import math
class Adafactor(torch.optim.Optimizer):
"""Implements Adafactor algorithm.
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
(see https://arxiv.org/abs/1804.04235)


@@ -23,8 +23,18 @@ class Adahessian(torch.optim.Optimizer):
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
"""
def __init__(
self,
params,
lr=0.1,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
hessian_power=1.0,
update_each=1,
n_samples=1,
avg_conv_kernel=False,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
@@ -44,7 +54,13 @@ class Adahessian(torch.optim.Optimizer):
self.seed = 2147483647
self.generator = torch.Generator().manual_seed(self.seed)
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
hessian_power=hessian_power,
)
super(Adahessian, self).__init__(params, defaults)
for p in self.get_params():


@@ -41,11 +41,26 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
class AdamP(Optimizer):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
delta=0.1,
wd_ratio=0.1,
nesterov=False,
):
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
delta=delta,
wd_ratio=wd_ratio,
nesterov=nesterov,
)
super(AdamP, self).__init__(params, defaults)
@torch.no_grad()


@@ -36,8 +36,16 @@ class AdamW(Optimizer):
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
):
# NOTE: deprecated in favour of builtin torch.optim.AdamW
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -46,8 +54,13 @@ class AdamW(Optimizer):
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):


@@ -137,7 +137,7 @@ def lion(
"""
if foreach is None:
# Placeholder for more complex foreach logic to be added when value is not set
foreach = True
if foreach and torch.jit.is_scripting():
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
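The hunk above flips the fallback in lion() from foreach = False to foreach = True, so lion now prefers the multi-tensor path when the caller leaves foreach unset. A sketch of opting out explicitly through the factory:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 8)
# foreach=None keeps the new multi-tensor default; foreach=False forces the single-tensor path.
opt = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.01, foreach=False)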


@@ -71,7 +71,12 @@ class MADGRAD(torch.optim.Optimizer):
raise ValueError(f"Eps must be non-negative")
defaults = dict(
lr=lr,
eps=eps,
momentum=momentum,
weight_decay=weight_decay,
decoupled_decay=decoupled_decay,
)
super().__init__(params, defaults)
@property


@@ -27,8 +27,15 @@ class Nadam(Optimizer):
NOTE: Has potential issues but does work well on some problems.
"""
def __init__(
self,
params,
lr=2e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
schedule_decay=4e-3,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
defaults = dict(


@@ -29,8 +29,16 @@ class NvNovoGrad(Optimizer):
(default: False)
"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.95, 0.98),
eps=1e-8,
weight_decay=0,
grad_averaging=False,
amsgrad=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -39,10 +47,14 @@ class NvNovoGrad(Optimizer):
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad,
)
super(NvNovoGrad, self).__init__(params, defaults)


@@ -1,431 +0,0 @@
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
from itertools import islice
from typing import Optional, Callable, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from timm.models import group_parameters
from . import AdafactorBigVision
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW
_logger = logging.getLogger(__name__)
# optimizers to default to multi-tensor
_DEFAULT_FOREACH = {
'lion',
}
def param_groups_weight_decay(
model: nn.Module,
weight_decay=1e-5,
no_weight_decay_list=()
):
no_weight_decay_list = set(no_weight_decay_list)
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
def _group(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def _layer_map(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
elif isinstance(hp, (tuple, list)):
return any([n.startswith(hpi) for hpi in hp])
else:
return n.startswith(hp)
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
names_trunk = []
names_head = []
for n, _ in model.named_parameters():
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
# group non-head layers
num_trunk_layers = len(names_trunk)
if num_groups is not None:
layers_per_group = -(num_trunk_layers // -num_groups)
names_trunk = list(_group(names_trunk, layers_per_group))
num_trunk_groups = len(names_trunk)
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
def param_groups_layer_decay(
model: nn.Module,
weight_decay: float = 0.05,
no_weight_decay_list: Tuple[str] = (),
layer_decay: float = .75,
end_layer_decay: Optional[float] = None,
verbose: bool = False,
):
"""
Parameter groups for layer-wise lr decay & weight decay
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
"""
no_weight_decay_list = set(no_weight_decay_list)
param_group_names = {} # NOTE for debugging
param_groups = {}
if hasattr(model, 'group_matcher'):
# FIXME interface needs more work
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
for name, param in model.named_parameters():
if not param.requires_grad:
continue
# no decay: all 1D parameters and model specific ones
if param.ndim == 1 or name in no_weight_decay_list:
g_decay = "no_decay"
this_decay = 0.
else:
g_decay = "decay"
this_decay = weight_decay
layer_id = layer_map.get(name, layer_max)
group_name = "layer_%d_%s" % (layer_id, g_decay)
if group_name not in param_groups:
this_scale = layer_scales[layer_id]
param_group_names[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"param_names": [],
}
param_groups[group_name] = {
"lr_scale": this_scale,
"weight_decay": this_decay,
"params": [],
}
param_group_names[group_name]["param_names"].append(name)
param_groups[group_name]["params"].append(param)
if verbose:
import json
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
return list(param_groups.values())
def optimizer_kwargs(cfg):
""" cfg/argparse to kwargs helper
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
"""
kwargs = dict(
opt=cfg.opt,
lr=cfg.lr,
weight_decay=cfg.weight_decay,
momentum=cfg.momentum,
)
if getattr(cfg, 'opt_eps', None) is not None:
kwargs['eps'] = cfg.opt_eps
if getattr(cfg, 'opt_betas', None) is not None:
kwargs['betas'] = cfg.opt_betas
if getattr(cfg, 'layer_decay', None) is not None:
kwargs['layer_decay'] = cfg.layer_decay
if getattr(cfg, 'opt_args', None) is not None:
kwargs.update(cfg.opt_args)
if getattr(cfg, 'opt_foreach', None) is not None:
kwargs['foreach'] = cfg.opt_foreach
return kwargs
def create_optimizer(args, model, filter_bias_and_bn=True):
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""
return create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
filter_bias_and_bn=filter_bias_and_bn,
)
def create_optimizer_v2(
model_or_params,
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable] = None,
**kwargs,
):
""" Create an optimizer.
TODO currently the model is passed in and all parameters are selected for optimization.
For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
* a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
* expose the parameters interface and leave it up to caller
Args:
model_or_params (nn.Module): model containing parameters to optimize
opt: name of optimizer to create
lr: initial learning rate
weight_decay: weight decay to apply in optimizer
momentum: momentum for momentum based optimizers (others may use betas via kwargs)
foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
**kwargs: extra optimizer specific kwargs to pass through
Returns:
Optimizer
"""
if isinstance(model_or_params, nn.Module):
# a model was passed in, extract parameters and add weight decays to appropriate layers
no_weight_decay = {}
if hasattr(model_or_params, 'no_weight_decay'):
no_weight_decay = model_or_params.no_weight_decay()
if param_group_fn:
parameters = param_group_fn(model_or_params)
elif layer_decay is not None:
parameters = param_groups_layer_decay(
model_or_params,
weight_decay=weight_decay,
layer_decay=layer_decay,
no_weight_decay_list=no_weight_decay,
)
weight_decay = 0.
elif weight_decay and filter_bias_and_bn:
parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
weight_decay = 0.
else:
parameters = model_or_params.parameters()
else:
# iterable of parameters or param groups passed in
parameters = model_or_params
opt_lower = opt.lower()
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower.startswith('fused'):
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
if opt_lower.startswith('bnb'):
try:
import bitsandbytes as bnb
has_bnb = True
except ImportError:
has_bnb = False
assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None:
opt_args.setdefault('lr', lr)
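    # enable the multi-tensor (foreach) path by default only for optimizers listed in
    # _DEFAULT_FOREACH; an explicit foreach=True/False from the caller takes precedence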
if foreach is None:
if opt in _DEFAULT_FOREACH:
opt_args.setdefault('foreach', True)
else:
opt_args['foreach'] = foreach
# basic SGD & related
if opt_lower == 'sgd' or opt_lower == 'nesterov':
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'sgdp':
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
        # NOTE 'sgdw' refers to SGDW + nesterov momentum for legacy / backwards compat reasons, like 'sgd' above
opt_args.pop('eps', None)
optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentumw':
opt_args.pop('eps', None)
optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)
# adaptive
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'adamp':
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
elif opt_lower == 'nadam':
try:
# NOTE PyTorch >= 1.10 should have native NAdam
            optimizer = optim.NAdam(parameters, **opt_args)
except AttributeError:
optimizer = Nadam(parameters, **opt_args)
elif opt_lower == 'nadamw':
optimizer = NAdamW(parameters, **opt_args)
elif opt_lower == 'radam':
optimizer = RAdam(parameters, **opt_args)
elif opt_lower == 'adamax':
optimizer = optim.Adamax(parameters, **opt_args)
elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'radabelief':
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
elif opt_lower == 'adadelta':
optimizer = optim.Adadelta(parameters, **opt_args)
elif opt_lower == 'adagrad':
opt_args.setdefault('eps', 1e-8)
optimizer = optim.Adagrad(parameters, **opt_args)
elif opt_lower == 'adafactor':
optimizer = Adafactor(parameters, **opt_args)
elif opt_lower == 'adanp':
optimizer = Adan(parameters, no_prox=False, **opt_args)
elif opt_lower == 'adanw':
optimizer = Adan(parameters, no_prox=True, **opt_args)
elif opt_lower == 'lamb':
optimizer = Lamb(parameters, **opt_args)
elif opt_lower == 'lambc':
optimizer = Lamb(parameters, trust_clip=True, **opt_args)
elif opt_lower == 'larc':
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
elif opt_lower == 'lars':
optimizer = Lars(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'nlarc':
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
elif opt_lower == 'nlars':
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'madgrad':
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'madgradw':
optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'rmsprop':
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'rmsproptf':
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
elif opt_lower == 'lion':
opt_args.pop('eps', None)
optimizer = Lion(parameters, **opt_args)
elif opt_lower == 'adafactorbv':
optimizer = AdafactorBigVision(parameters, **opt_args)
elif opt_lower == 'adopt':
optimizer = Adopt(parameters, **opt_args)
elif opt_lower == 'adoptw':
optimizer = Adopt(parameters, decoupled=True, **opt_args)
# second order
elif opt_lower == 'adahessian':
optimizer = Adahessian(parameters, **opt_args)
# NVIDIA fused optimizers, require APEX to be installed
elif opt_lower == 'fusedsgd':
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'fusedmomentum':
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
elif opt_lower == 'fusedadam':
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
elif opt_lower == 'fusedadamw':
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
elif opt_lower == 'fusedlamb':
optimizer = FusedLAMB(parameters, **opt_args)
elif opt_lower == 'fusednovograd':
opt_args.setdefault('betas', (0.95, 0.98))
optimizer = FusedNovoGrad(parameters, **opt_args)
# bitsandbytes optimizers, require bitsandbytes to be installed
elif opt_lower == 'bnbsgd':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbsgd8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbmomentum':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbmomentum8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbadam':
optimizer = bnb.optim.Adam(parameters, **opt_args)
elif opt_lower == 'bnbadam8bit':
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
elif opt_lower == 'bnbadamw':
optimizer = bnb.optim.AdamW(parameters, **opt_args)
elif opt_lower == 'bnbadamw8bit':
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
elif opt_lower == 'bnblamb':
optimizer = bnb.optim.LAMB(parameters, **opt_args)
elif opt_lower == 'bnblamb8bit':
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
elif opt_lower == 'bnblars':
optimizer = bnb.optim.LARS(parameters, **opt_args)
    elif opt_lower == 'bnblars8bit':
        optimizer = bnb.optim.LARS8bit(parameters, **opt_args)
elif opt_lower == 'bnblion':
optimizer = bnb.optim.Lion(parameters, **opt_args)
elif opt_lower == 'bnblion8bit':
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
    else:
        raise ValueError(f'Invalid optimizer: {opt}')
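    # a 'lookahead_' prefix on the optimizer name (e.g. 'lookahead_adamw') wraps the created optimizer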
if len(opt_split) > 1:
if opt_split[0] == 'lookahead':
optimizer = Lookahead(optimizer)
return optimizer
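
# Usage sketch (illustrative, not part of this module; model name, lr and decay values are
# placeholders). The behaviours shown follow from the factory code above.
import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet50', pretrained=False)

# plain AdamW; bias / norm (1d) params are excluded from weight decay via filter_bias_and_bn=True
opt = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)

# layer-wise lr decay builds param groups internally and zeroes the top-level decay
opt_ld = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)

# a 'lookahead_' prefix wraps the created optimizer in Lookahead
opt_la = create_optimizer_v2(model, opt='lookahead_adamw', lr=1e-3, weight_decay=0.05)

# a raw parameter iterable (or list of param group dicts) is accepted in place of a model
opt_sgd = create_optimizer_v2(model.parameters(), opt='sgd', lr=0.1, momentum=0.9)
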

View File

@@ -9,10 +9,21 @@ from torch.optim.optimizer import Optimizer
 class RAdam(Optimizer):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0,
+    ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-            buffer=[[None, None, None] for _ in range(10)])
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)]
+        )
         super(RAdam, self).__init__(params, defaults)
 
     def __setstate__(self, state):

View File

@@ -45,8 +45,18 @@ class RMSpropTF(Optimizer):
     """
-    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
-                 decoupled_decay=False, lr_in_momentum=True):
+    def __init__(
+            self,
+            params,
+            lr=1e-2,
+            alpha=0.9,
+            eps=1e-10,
+            weight_decay=0,
+            momentum=0.,
+            centered=False,
+            decoupled_decay=False,
+            lr_in_momentum=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -59,8 +69,15 @@ class RMSpropTF(Optimizer):
             raise ValueError("Invalid alpha value: {}".format(alpha))
         defaults = dict(
-            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
-            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
+            lr=lr,
+            momentum=momentum,
+            alpha=alpha,
+            eps=eps,
+            centered=centered,
+            weight_decay=weight_decay,
+            decoupled_decay=decoupled_decay,
+            lr_in_momentum=lr_in_momentum,
+        )
         super(RMSpropTF, self).__init__(params, defaults)
 
     def __setstate__(self, state):

View File

@@ -17,11 +17,28 @@ from .adamp import projection
 class SGDP(Optimizer):
-    def __init__(self, params, lr=required, momentum=0, dampening=0,
-                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
+    def __init__(
+            self,
+            params,
+            lr=required,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            eps=1e-8,
+            delta=0.1,
+            wd_ratio=0.1
+    ):
         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
-            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            eps=eps,
+            delta=delta,
+            wd_ratio=wd_ratio,
+        )
         super(SGDP, self).__init__(params, defaults)
 
     @torch.no_grad()

View File

@@ -35,10 +35,15 @@ class SGDW(Optimizer):
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening,
-            weight_decay=weight_decay, nesterov=nesterov,
-            maximize=maximize, foreach=foreach,
-            differentiable=differentiable)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            maximize=maximize,
+            foreach=foreach,
+            differentiable=differentiable,
+        )
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)

View File

@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import copy
 import importlib
 import json
 import logging
@@ -554,6 +555,13 @@ def main():
         **optimizer_kwargs(cfg=args),
         **args.opt_kwargs,
     )
+    if utils.is_primary(args):
+        defaults = copy.deepcopy(optimizer.defaults)
+        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
+        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
+        logging.info(
+            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
+        )
 
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing