A bit of an optimizer overhaul: added an improved factory, list_optimizers, a get_optimizer_class helper, and info classes with descriptions and arg configs
parent c1cf8c52b9
commit ee5f6e76bb
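
A minimal usage sketch of the new factory API introduced here (names and signatures as in the _optim_factory.py diff below; the model and hyperparameter values are illustrative assumptions only):

    import torch.nn as nn
    from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class

    model = nn.Linear(10, 2)  # stand-in module, purely for illustration

    # enumerate registered optimizer names, excluding fused/bitsandbytes variants as the tests do
    names = list_optimizers(exclude_filters=('fused*', 'bnb*'))

    # fetch an optimizer class directly (registry defaults can optionally be bound via partial)
    opt_cls = get_optimizer_class('adamw')
    optimizer = opt_cls(model.parameters(), lr=1e-3)

    # or go through the factory, which builds weight-decay param groups for nn.Module inputs
    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)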
@@ -15,7 +15,7 @@ from torch.nn import Parameter
from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler

from timm.optim import create_optimizer_v2
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class

import importlib
import os

@@ -293,10 +293,11 @@ def _build_params_dict_single(weight, bias, **kwargs):
    return [dict(params=bias, **kwargs)]


#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
@pytest.mark.parametrize('optimizer', ['sgd'])
def test_sgd(optimizer):
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
def test_optim_factory(optimizer):
    get_optimizer_class(optimizer)

    # test basic cases that don't need specific tuning via factory test
    _test_basic_cases(
        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
    )

@@ -316,6 +317,12 @@ def test_sgd(optimizer):
        lambda weight, bias: create_optimizer_v2(
            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
    )


#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
@pytest.mark.parametrize('optimizer', ['sgd'])
def test_sgd(optimizer):
    # _test_basic_cases(
    #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
    #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]

@ -358,21 +365,6 @@ def test_sgd(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw'])
|
||||
def test_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
|
@ -381,21 +373,6 @@ def test_adam(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
|
||||
def test_adopt(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
# FIXME rosenbrock is not passing for ADOPT
|
||||
# _test_rosenbrock(
|
||||
# lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
|
@ -405,25 +382,6 @@ def test_adopt(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adabelief'])
|
||||
def test_adabelief(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
|
@ -435,21 +393,6 @@ def test_adabelief(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
|
||||
def test_rectified(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -458,25 +401,6 @@ def test_rectified(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
|
||||
def test_adaother(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
|
@ -488,24 +412,6 @@ def test_adaother(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adafactor', 'adafactorbv'])
|
||||
def test_adafactor(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
|
||||
)
|
||||
|
@ -517,25 +423,6 @@ def test_adafactor(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
|
||||
def test_lamb(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=1e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -544,25 +431,6 @@ def test_lamb(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
|
||||
def test_lars(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=1e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -571,25 +439,6 @@ def test_lars(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
|
||||
def test_madgrad(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
|
@ -598,25 +447,6 @@ def test_madgrad(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['novograd'])
|
||||
def test_novograd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -625,25 +455,6 @@ def test_novograd(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
|
||||
def test_rmsprop(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
|
||||
)
|
||||
|
@ -652,25 +463,6 @@ def test_rmsprop(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['adamp'])
|
||||
def test_adamp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
|
@ -679,25 +471,6 @@ def test_adamp(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['sgdp'])
|
||||
def test_sgdp(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -706,25 +479,6 @@ def test_sgdp(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
|
||||
def test_lookahead_sgd(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
|
||||
)
|
||||
|
@ -732,25 +486,6 @@ def test_lookahead_sgd(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
|
||||
def test_lookahead_adam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
|
||||
)
|
||||
|
@ -758,25 +493,6 @@ def test_lookahead_adam(optimizer):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', ['lookahead_radam'])
|
||||
def test_lookahead_radam(optimizer):
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=3e-3), optimizer)
|
||||
)
|
||||
_test_rosenbrock(
|
||||
lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
|
||||
)
|
||||
|

@@ -17,4 +17,5 @@ from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP

from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry

@@ -0,0 +1,798 @@
""" Optimizer Factory w/ custom Weight Decay & Layer Decay support

Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
from fnmatch import fnmatch
import importlib

import torch
import torch.nn as nn
import torch.optim as optim

from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW

_logger = logging.getLogger(__name__)

# Type variables
T = TypeVar('T')
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
OptimType = TypeVar('OptimType', bound='optim.Optimizer')


def _import_class(class_string: str) -> Type:
    """Dynamically import a class from a string."""
    try:
        module_name, class_name = class_string.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Could not import {class_string}: {e}")


class OptimizerCallable(Protocol):
    """Protocol for optimizer constructor signatures."""

    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...


@dataclass(frozen=True)
class OptimInfo:
    """Immutable configuration for an optimizer.

    Attributes:
        name: Unique identifier for the optimizer
        opt_class: The optimizer class
        description: Brief description of the optimizer's characteristics and behavior
        has_eps: Whether the optimizer accepts epsilon parameter
        has_momentum: Whether the optimizer accepts momentum parameter
        has_betas: Whether the optimizer accepts a tuple of beta parameters
        num_betas: number of betas in tuple (valid IFF has_betas = True)
        defaults: Optional default parameters for the optimizer
    """
    name: str
    opt_class: Union[str, Type[optim.Optimizer]]
    description: str = ''
    has_eps: bool = True
    has_momentum: bool = False
    has_betas: bool = False
    num_betas: int = 2
    defaults: Optional[Dict[str, Any]] = None


class OptimizerRegistry:
    """Registry managing optimizer configurations and instantiation.

    This class provides a central registry for optimizer configurations and handles
    their instantiation with appropriate parameter groups and settings.
    """

    def __init__(self) -> None:
        self._optimizers: Dict[str, OptimInfo] = {}
        self._foreach_defaults: Set[str] = {'lion'}

    def register(self, info: OptimInfo) -> None:
        """Register an optimizer configuration.

        Args:
            info: The OptimInfo configuration containing name, type and description
        """
        name = info.name.lower()
        if name in self._optimizers:
            _logger.warning(f'Optimizer {name} already registered, overwriting')
        self._optimizers[name] = info

    def register_alias(self, alias: str, target: str) -> None:
        """Register an alias for an existing optimizer.

        Args:
            alias: The alias name
            target: The target optimizer name

        Raises:
            KeyError: If target optimizer doesn't exist
        """
        target = target.lower()
        if target not in self._optimizers:
            raise KeyError(f'Cannot create alias for non-existent optimizer {target}')
        self._optimizers[alias.lower()] = self._optimizers[target]

    def register_foreach_default(self, name: str) -> None:
        """Register an optimizer as defaulting to foreach=True."""
        self._foreach_defaults.add(name.lower())

    def list_optimizers(
            self,
            filter: str = '',
            exclude_filters: Optional[List[str]] = None,
            with_description: bool = False
    ) -> List[Union[str, Tuple[str, str]]]:
        """List available optimizer names, optionally filtered.

        Args:
            filter: Wildcard style filter string (e.g., 'adam*')
            exclude_filters: Optional list of wildcard patterns to exclude
            with_description: If True, return tuples of (name, description)

        Returns:
            List of either optimizer names or (name, description) tuples
        """
        names = sorted(self._optimizers.keys())

        if filter:
            names = [n for n in names if fnmatch(n, filter)]

        if exclude_filters:
            for exclude_filter in exclude_filters:
                names = [n for n in names if not fnmatch(n, exclude_filter)]

        if with_description:
            return [(name, self._optimizers[name].description) for name in names]
        return names

    def get_optimizer_info(self, name: str) -> OptimInfo:
        """Get the OptimInfo for an optimizer.

        Args:
            name: Name of the optimizer

        Returns:
            OptimInfo configuration

        Raises:
            ValueError: If optimizer is not found
        """
        name = name.lower()
        if name not in self._optimizers:
            raise ValueError(f'Optimizer {name} not found in registry')
        return self._optimizers[name]

    def get_optimizer_class(
            self,
            name_or_info: Union[str, OptimInfo],
            bind_defaults: bool = True,
    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
        """Get the optimizer class with any default arguments applied.

        This allows direct instantiation of optimizers with their default configs
        without going through the full factory.

        Args:
            name_or_info: Name of the optimizer
            bind_defaults: Bind default arguments to optimizer class via `partial` before returning

        Returns:
            Optimizer class or partial with defaults applied

        Raises:
            ValueError: If optimizer not found
        """
        if isinstance(name_or_info, str):
            opt_info = self.get_optimizer_info(name_or_info)
        else:
            assert isinstance(name_or_info, OptimInfo)
            opt_info = name_or_info

        if isinstance(opt_info.opt_class, str):
            # Special handling for APEX and BNB optimizers
            if opt_info.opt_class.startswith('apex.'):
                assert torch.cuda.is_available(), 'CUDA required for APEX optimizers'
                try:
                    opt_class = _import_class(opt_info.opt_class)
                except ImportError as e:
                    raise ImportError('APEX optimizers require apex to be installed') from e
            elif opt_info.opt_class.startswith('bitsandbytes.'):
                assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers'
                try:
                    opt_class = _import_class(opt_info.opt_class)
                except ImportError as e:
                    raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e
            else:
                opt_class = _import_class(opt_info.opt_class)
        else:
            opt_class = opt_info.opt_class

        # Return class or partial with defaults
        if bind_defaults and opt_info.defaults:
            opt_class = partial(opt_class, **opt_info.defaults)

        return opt_class

    def create_optimizer(
            self,
            model_or_params: Union[nn.Module, Params],
            opt: str,
            lr: Optional[float] = None,
            weight_decay: float = 0.,
            momentum: float = 0.9,
            foreach: Optional[bool] = None,
            weight_decay_exclude_1d: bool = True,
            layer_decay: Optional[float] = None,
            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
            **kwargs: Any,
    ) -> optim.Optimizer:
        """Create an optimizer instance.

        Args:
            model_or_params: Model or parameters to optimize
            opt: Name of optimizer to create
            lr: Learning rate
            weight_decay: Weight decay factor
            momentum: Momentum factor for applicable optimizers
            foreach: Enable/disable foreach operation
            weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
            layer_decay: Layer-wise learning rate decay
            param_group_fn: Optional custom parameter grouping function
            **kwargs: Additional optimizer-specific arguments

        Returns:
            Configured optimizer instance

        Raises:
            ValueError: If optimizer not found or configuration invalid
        """

        # Get parameters to optimize
        if isinstance(model_or_params, nn.Module):
            # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied
            no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())()

            if param_group_fn:
                # run custom fn to generate param groups from nn.Module
                parameters = param_group_fn(model_or_params)
            elif layer_decay is not None:
                parameters = param_groups_layer_decay(
                    model_or_params,
                    weight_decay=weight_decay,
                    layer_decay=layer_decay,
                    no_weight_decay_list=no_weight_decay,
                    weight_decay_exclude_1d=weight_decay_exclude_1d,
                )
                weight_decay = 0.
            elif weight_decay and weight_decay_exclude_1d:
                parameters = param_groups_weight_decay(
                    model_or_params,
                    weight_decay=weight_decay,
                    no_weight_decay_list=no_weight_decay,
                )
                weight_decay = 0.
            else:
                parameters = model_or_params.parameters()
        else:
            # pass parameters / parameter groups through to optimizer
            parameters = model_or_params

        # Parse optimizer name
        opt_split = opt.lower().split('_')
        opt_name = opt_split[-1]
        use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False

        opt_info = self.get_optimizer_info(opt_name)

        # Build optimizer arguments
        opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs}

        # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None.
        if lr is not None:
            opt_args['lr'] = lr

        # Apply optimizer-specific settings
        if opt_info.defaults:
            for k, v in opt_info.defaults.items():
                opt_args.setdefault(k, v)

        # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat.
        if opt_info.has_momentum:
            opt_args.setdefault('momentum', momentum)

        # Remove commonly used kwargs that aren't always supported
        if not opt_info.has_eps:
            opt_args.pop('eps', None)
        if not opt_info.has_betas:
            opt_args.pop('betas', None)

        if foreach is not None:
            # Explicitly activate or deactivate multi-tensor foreach impl.
            # Not all optimizers support this, and those that do usually default to using
            # multi-tensor impl if foreach is left as default 'None' and can be enabled.
            opt_args.setdefault('foreach', foreach)

        # Create optimizer
        opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
        optimizer = opt_class(parameters, **opt_args)

        # Apply Lookahead if requested
        if use_lookahead:
            optimizer = Lookahead(optimizer)

        return optimizer


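# NOTE (editor's illustration, not part of the commit): a minimal sketch of using the
# registry API above directly. The 'plainsgd' name and the model below are hypothetical;
# `nn` and `optim` refer to this module's imports.
#
#   registry = OptimizerRegistry()
#   registry.register(OptimInfo(
#       name='plainsgd',
#       opt_class=optim.SGD,
#       description='SGD without Nesterov momentum',
#       has_eps=False,
#       has_momentum=True,
#   ))
#   model = nn.Linear(10, 2)
#   optimizer = registry.create_optimizer(model, opt='plainsgd', lr=0.1, momentum=0.9)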
def _register_sgd_variants(registry: OptimizerRegistry) -> None:
|
||||
"""Register SGD-based optimizers"""
|
||||
sgd_optimizers = [
|
||||
OptimInfo(
|
||||
name='sgd',
|
||||
opt_class=optim.SGD,
|
||||
description='Stochastic Gradient Descent with Nesterov momentum (default)',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='momentum',
|
||||
opt_class=optim.SGD,
|
||||
description='Stochastic Gradient Descent with classical momentum',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': False}
|
||||
),
|
||||
OptimInfo(
|
||||
name='sgdp',
|
||||
opt_class=SGDP,
|
||||
description='SGD with built-in projection to unit norm sphere',
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='sgdw',
|
||||
opt_class=SGDW,
|
||||
description='SGD with decoupled weight decay and Nesterov momentum',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
]
|
||||
for opt in sgd_optimizers:
|
||||
registry.register(opt)
|
||||
|
||||
|
||||
def _register_adam_variants(registry: OptimizerRegistry) -> None:
|
||||
"""Register Adam-based optimizers"""
|
||||
adam_optimizers = [
|
||||
OptimInfo(
|
||||
name='adam',
|
||||
opt_class=optim.Adam,
|
||||
description='torch.optim Adam (Adaptive Moment Estimation)',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='adamw',
|
||||
opt_class=optim.AdamW,
|
||||
description='torch.optim Adam with decoupled weight decay regularization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='adamp',
|
||||
opt_class=AdamP,
|
||||
description='Adam with built-in projection to unit norm sphere',
|
||||
has_betas=True,
|
||||
defaults={'wd_ratio': 0.01, 'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='nadam',
|
||||
opt_class=Nadam,
|
||||
description='Adam with Nesterov momentum',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='nadamw',
|
||||
opt_class=NAdamW,
|
||||
description='Adam with Nesterov momentum and decoupled weight decay',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='radam',
|
||||
opt_class=RAdam,
|
||||
description='Rectified Adam with variance adaptation',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='adamax',
|
||||
opt_class=optim.Adamax,
|
||||
description='torch.optim Adamax, Adam with infinity norm for more stable updates',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='adafactor',
|
||||
opt_class=Adafactor,
|
||||
description='Memory-efficient implementation of Adam with factored gradients',
|
||||
),
|
||||
OptimInfo(
|
||||
name='adafactorbv',
|
||||
opt_class=AdafactorBigVision,
|
||||
description='Big Vision variant of Adafactor with factored gradients, half precision momentum.',
|
||||
),
|
||||
OptimInfo(
|
||||
name='adopt',
|
||||
opt_class=Adopt,
|
||||
description='Modified Adam that can converge with any beta2 with the optimal rate',
|
||||
),
|
||||
OptimInfo(
|
||||
name='adoptw',
|
||||
opt_class=Adopt,
|
||||
description='Modified Adam that can converge with any beta2 with the optimal rate, decoupled weight decay',
|
||||
defaults={'decoupled': True}
|
||||
),
|
||||
]
|
||||
for opt in adam_optimizers:
|
||||
registry.register(opt)
|
||||
|
||||
|
||||
def _register_lamb_lars(registry: OptimizerRegistry) -> None:
|
||||
"""Register LAMB and LARS variants"""
|
||||
lamb_lars_optimizers = [
|
||||
OptimInfo(
|
||||
name='lamb',
|
||||
opt_class=Lamb,
|
||||
description='Layer-wise Adaptive Moments for batch optimization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='lambc',
|
||||
opt_class=Lamb,
|
||||
description='LAMB with trust ratio clipping for stability',
|
||||
has_betas=True,
|
||||
defaults={'trust_clip': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='lars',
|
||||
opt_class=Lars,
|
||||
description='Layer-wise Adaptive Rate Scaling',
|
||||
has_momentum=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='larc',
|
||||
opt_class=Lars,
|
||||
description='LARS with trust ratio clipping for stability',
|
||||
has_momentum=True,
|
||||
defaults={'trust_clip': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='nlars',
|
||||
opt_class=Lars,
|
||||
description='LARS with Nesterov momentum',
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='nlarc',
|
||||
opt_class=Lars,
|
||||
description='LARS with Nesterov momentum & trust ratio clipping',
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True, 'trust_clip': True}
|
||||
),
|
||||
]
|
||||
for opt in lamb_lars_optimizers:
|
||||
registry.register(opt)
|
||||
|
||||
|
||||
def _register_other_optimizers(registry: OptimizerRegistry) -> None:
|
||||
"""Register miscellaneous optimizers"""
|
||||
other_optimizers = [
|
||||
OptimInfo(
|
||||
name='adabelief',
|
||||
opt_class=AdaBelief,
|
||||
description='Adapts learning rate based on gradient prediction error',
|
||||
has_betas=True,
|
||||
defaults={'rectify': False}
|
||||
),
|
||||
OptimInfo(
|
||||
name='radabelief',
|
||||
opt_class=AdaBelief,
|
||||
description='Rectified AdaBelief with variance adaptation',
|
||||
has_betas=True,
|
||||
defaults={'rectify': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='adadelta',
|
||||
opt_class=optim.Adadelta,
|
||||
description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
|
||||
),
|
||||
OptimInfo(
|
||||
name='adagrad',
|
||||
opt_class=optim.Adagrad,
|
||||
description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
|
||||
defaults={'eps': 1e-8}
|
||||
),
|
||||
OptimInfo(
|
||||
name='adan',
|
||||
opt_class=Adan,
|
||||
description='Adaptive Nesterov Momentum Algorithm',
|
||||
defaults={'no_prox': False},
|
||||
has_betas=True,
|
||||
num_betas=3
|
||||
),
|
||||
OptimInfo(
|
||||
name='adanw',
|
||||
opt_class=Adan,
|
||||
description='Adaptive Nesterov Momentum with decoupled weight decay',
|
||||
defaults={'no_prox': True},
|
||||
has_betas=True,
|
||||
num_betas=3
|
||||
),
|
||||
OptimInfo(
|
||||
name='lion',
|
||||
opt_class=Lion,
|
||||
description='Evolved Sign Momentum optimizer for improved convergence',
|
||||
has_eps=False,
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='madgrad',
|
||||
opt_class=MADGRAD,
|
||||
description='Momentum-based Adaptive gradient method',
|
||||
has_momentum=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='madgradw',
|
||||
opt_class=MADGRAD,
|
||||
description='MADGRAD with decoupled weight decay',
|
||||
has_momentum=True,
|
||||
defaults={'decoupled_decay': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='novograd',
|
||||
opt_class=NvNovoGrad,
|
||||
description='Normalized Adam with L2 norm gradient normalization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='rmsprop',
|
||||
opt_class=optim.RMSprop,
|
||||
description='torch.optim RMSprop, Root Mean Square Propagation',
|
||||
has_momentum=True,
|
||||
defaults={'alpha': 0.9}
|
||||
),
|
||||
OptimInfo(
|
||||
name='rmsproptf',
|
||||
opt_class=RMSpropTF,
|
||||
description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation',
|
||||
has_momentum=True,
|
||||
defaults={'alpha': 0.9}
|
||||
),
|
||||
]
|
||||
for opt in other_optimizers:
|
||||
registry.register(opt)
|
||||
registry.register_foreach_default('lion')
|
||||
|
||||
|
||||
def _register_apex_optimizers(registry: OptimizerRegistry) -> None:
|
||||
"""Register APEX optimizers (lazy import)"""
|
||||
apex_optimizers = [
|
||||
OptimInfo(
|
||||
name='fusedsgd',
|
||||
opt_class='apex.optimizers.FusedSGD',
|
||||
description='NVIDIA APEX fused SGD implementation for faster training',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='fusedadam',
|
||||
opt_class='apex.optimizers.FusedAdam',
|
||||
description='NVIDIA APEX fused Adam implementation',
|
||||
has_betas=True,
|
||||
defaults={'adam_w_mode': False}
|
||||
),
|
||||
OptimInfo(
|
||||
name='fusedadamw',
|
||||
opt_class='apex.optimizers.FusedAdam',
|
||||
description='NVIDIA APEX fused AdamW implementation',
|
||||
has_betas=True,
|
||||
defaults={'adam_w_mode': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='fusedlamb',
|
||||
opt_class='apex.optimizers.FusedLAMB',
|
||||
description='NVIDIA APEX fused LAMB implementation',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='fusednovograd',
|
||||
opt_class='apex.optimizers.FusedNovoGrad',
|
||||
description='NVIDIA APEX fused NovoGrad implementation',
|
||||
has_betas=True,
|
||||
defaults={'betas': (0.95, 0.98)}
|
||||
),
|
||||
]
|
||||
for opt in apex_optimizers:
|
||||
registry.register(opt)
|
||||
|
||||
|
||||
def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
|
||||
"""Register bitsandbytes optimizers (lazy import)"""
|
||||
bnb_optimizers = [
|
||||
OptimInfo(
|
||||
name='bnbsgd',
|
||||
opt_class='bitsandbytes.optim.SGD',
|
||||
description='bitsandbytes SGD',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='bnbsgd8bit',
|
||||
opt_class='bitsandbytes.optim.SGD8bit',
|
||||
description='bitsandbytes 8-bit SGD with dynamic quantization',
|
||||
has_eps=False,
|
||||
has_momentum=True,
|
||||
defaults={'nesterov': True}
|
||||
),
|
||||
OptimInfo(
|
||||
name='bnbadam',
|
||||
opt_class='bitsandbytes.optim.Adam',
|
||||
description='bitsandbytes Adam',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='bnbadam8bit',
|
||||
opt_class='bitsandbytes.optim.Adam',
|
||||
description='bitsandbytes 8-bit Adam with dynamic quantization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='bnbadamw',
|
||||
opt_class='bitsandbytes.optim.AdamW',
|
||||
description='bitsandbytes AdamW',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
name='bnbadamw8bit',
|
||||
opt_class='bitsandbytes.optim.AdamW',
|
||||
description='bitsandbytes 8-bit AdamW with dynamic quantization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
'bnblion',
|
||||
'bitsandbytes.optim.Lion',
|
||||
description='bitsandbytes Lion',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
'bnblion8bit',
|
||||
'bitsandbytes.optim.Lion8bit',
|
||||
description='bitsandbytes 8-bit Lion with dynamic quantization',
|
||||
has_betas=True
|
||||
),
|
||||
OptimInfo(
|
||||
'bnbademamix',
|
||||
'bitsandbytes.optim.AdEMAMix',
|
||||
description='bitsandbytes AdEMAMix',
|
||||
has_betas=True,
|
||||
num_betas=3,
|
||||
),
|
||||
OptimInfo(
|
||||
'bnbademamix8bit',
|
||||
'bitsandbytes.optim.AdEMAMix8bit',
|
||||
description='bitsandbytes 8-bit AdEMAMix with dynamic quantization',
|
||||
has_betas=True,
|
||||
num_betas=3,
|
||||
),
|
||||
]
|
||||
for opt in bnb_optimizers:
|
||||
registry.register(opt)
|
||||
|
||||
|
||||
default_registry = OptimizerRegistry()
|
||||
|
||||
def _register_default_optimizers() -> None:
|
||||
"""Register all default optimizers to the global registry."""
|
||||
# Register all optimizer groups
|
||||
_register_sgd_variants(default_registry)
|
||||
_register_adam_variants(default_registry)
|
||||
_register_lamb_lars(default_registry)
|
||||
_register_other_optimizers(default_registry)
|
||||
_register_apex_optimizers(default_registry)
|
||||
_register_bnb_optimizers(default_registry)
|
||||
|
||||
# Register aliases
|
||||
default_registry.register_alias('nesterov', 'sgd')
|
||||
default_registry.register_alias('nesterovw', 'sgdw')
|
||||
|
||||
|
||||
# Initialize default registry
|
||||
_register_default_optimizers()
|
||||
|
||||
# Public API
|
||||
|
||||
def list_optimizers(
|
||||
filter: str = '',
|
||||
exclude_filters: Optional[List[str]] = None,
|
||||
with_description: bool = False,
|
||||
) -> List[Union[str, Tuple[str, str]]]:
|
||||
"""List available optimizer names, optionally filtered.
|
||||
"""
|
||||
return default_registry.list_optimizers(filter, exclude_filters, with_description)
|
||||
|
||||
|
||||
def get_optimizer_class(
|
||||
name: str,
|
||||
bind_defaults: bool = False,
|
||||
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
|
||||
"""Get optimizer class by name with any defaults applied.
|
||||
"""
|
||||
return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
|
||||
|
||||
|
||||
def create_optimizer_v2(
|
||||
model_or_params: Union[nn.Module, Params],
|
||||
opt: str = 'sgd',
|
||||
lr: Optional[float] = None,
|
||||
weight_decay: float = 0.,
|
||||
momentum: float = 0.9,
|
||||
foreach: Optional[bool] = None,
|
||||
filter_bias_and_bn: bool = True,
|
||||
layer_decay: Optional[float] = None,
|
||||
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
|
||||
**kwargs: Any,
|
||||
) -> optim.Optimizer:
|
||||
"""Create an optimizer instance using the default registry."""
|
||||
return default_registry.create_optimizer(
|
||||
model_or_params,
|
||||
opt=opt,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
momentum=momentum,
|
||||
foreach=foreach,
|
||||
weight_decay_exclude_1d=filter_bias_and_bn,
|
||||
layer_decay=layer_decay,
|
||||
param_group_fn=param_group_fn,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def optimizer_kwargs(cfg):
|
||||
""" cfg/argparse to kwargs helper
|
||||
Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
|
||||
"""
|
||||
kwargs = dict(
|
||||
opt=cfg.opt,
|
||||
lr=cfg.lr,
|
||||
weight_decay=cfg.weight_decay,
|
||||
momentum=cfg.momentum,
|
||||
)
|
||||
if getattr(cfg, 'opt_eps', None) is not None:
|
||||
kwargs['eps'] = cfg.opt_eps
|
||||
if getattr(cfg, 'opt_betas', None) is not None:
|
||||
kwargs['betas'] = cfg.opt_betas
|
||||
if getattr(cfg, 'layer_decay', None) is not None:
|
||||
kwargs['layer_decay'] = cfg.layer_decay
|
||||
if getattr(cfg, 'opt_args', None) is not None:
|
||||
kwargs.update(cfg.opt_args)
|
||||
if getattr(cfg, 'opt_foreach', None) is not None:
|
||||
kwargs['foreach'] = cfg.opt_foreach
|
||||
return kwargs
|
||||
|
||||
|
||||
def create_optimizer(args, model, filter_bias_and_bn=True):
|
||||
""" Legacy optimizer factory for backwards compatibility.
|
||||
NOTE: Use create_optimizer_v2 for new code.
|
||||
"""
|
||||
return create_optimizer_v2(
|
||||
model,
|
||||
**optimizer_kwargs(cfg=args),
|
||||
filter_bias_and_bn=filter_bias_and_bn,
|
||||
)
|
||||
|
|
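For reference, a hedged sketch of the module-level helpers defined above (the filter value is arbitrary; output depends on what is registered):

    from timm.optim import list_optimizers

    for name, desc in list_optimizers(filter='lamb*', with_description=True):
        print(f'{name}: {desc}')
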
@ -0,0 +1,129 @@
|
|||
import logging
|
||||
from itertools import islice
|
||||
from typing import Collection, Optional, Tuple
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.models import group_parameters
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def param_groups_weight_decay(
|
||||
model: nn.Module,
|
||||
weight_decay: float = 1e-5,
|
||||
no_weight_decay_list: Collection[str] = (),
|
||||
):
|
||||
no_weight_decay_list = set(no_weight_decay_list)
|
||||
decay = []
|
||||
no_decay = []
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
|
||||
return [
|
||||
{'params': no_decay, 'weight_decay': 0.},
|
||||
{'params': decay, 'weight_decay': weight_decay}]
|
||||
|
||||
|
||||
def _group(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def _layer_map(model, layers_per_group=12, num_groups=None):
|
||||
def _in_head(n, hp):
|
||||
if not hp:
|
||||
return True
|
||||
elif isinstance(hp, (tuple, list)):
|
||||
return any([n.startswith(hpi) for hpi in hp])
|
||||
else:
|
||||
return n.startswith(hp)
|
||||
|
||||
head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
|
||||
names_trunk = []
|
||||
names_head = []
|
||||
for n, _ in model.named_parameters():
|
||||
names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
|
||||
|
||||
# group non-head layers
|
||||
num_trunk_layers = len(names_trunk)
|
||||
if num_groups is not None:
|
||||
layers_per_group = -(num_trunk_layers // -num_groups)
|
||||
names_trunk = list(_group(names_trunk, layers_per_group))
|
||||
|
||||
num_trunk_groups = len(names_trunk)
|
||||
layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
|
||||
layer_map.update({n: num_trunk_groups for n in names_head})
|
||||
return layer_map
|
||||
|
||||
|
||||
def param_groups_layer_decay(
|
||||
model: nn.Module,
|
||||
weight_decay: float = 0.05,
|
||||
no_weight_decay_list: Collection[str] = (),
|
||||
weight_decay_exclude_1d: bool = True,
|
||||
layer_decay: float = .75,
|
||||
end_layer_decay: Optional[float] = None,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Parameter groups for layer-wise lr decay & weight decay
|
||||
Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
|
||||
"""
|
||||
no_weight_decay_list = set(no_weight_decay_list)
|
||||
param_group_names = {} # NOTE for debugging
|
||||
param_groups = {}
|
||||
|
||||
if hasattr(model, 'group_matcher'):
|
||||
# FIXME interface needs more work
|
||||
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
|
||||
else:
|
||||
# fallback
|
||||
layer_map = _layer_map(model)
|
||||
num_layers = max(layer_map.values()) + 1
|
||||
layer_max = num_layers - 1
|
||||
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
# no decay: all 1D parameters and model specific ones
|
||||
if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list:
|
||||
g_decay = "no_decay"
|
||||
this_decay = 0.
|
||||
else:
|
||||
g_decay = "decay"
|
||||
this_decay = weight_decay
|
||||
|
||||
layer_id = layer_map.get(name, layer_max)
|
||||
group_name = "layer_%d_%s" % (layer_id, g_decay)
|
||||
|
||||
if group_name not in param_groups:
|
||||
this_scale = layer_scales[layer_id]
|
||||
param_group_names[group_name] = {
|
||||
"lr_scale": this_scale,
|
||||
"weight_decay": this_decay,
|
||||
"param_names": [],
|
||||
}
|
||||
param_groups[group_name] = {
|
||||
"lr_scale": this_scale,
|
||||
"weight_decay": this_decay,
|
||||
"params": [],
|
||||
}
|
||||
|
||||
param_group_names[group_name]["param_names"].append(name)
|
||||
param_groups[group_name]["params"].append(param)
|
||||
|
||||
if verbose:
|
||||
import json
|
||||
_logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
||||
|
||||
return list(param_groups.values())
|
|
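The layer_decay path of the new factory calls param_groups_layer_decay() from the file above; a hedged sketch of the equivalent high-level call (the model choice is an arbitrary example):

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet18', pretrained=False)
    # per-layer lr scaling via layer_decay, 1d params (bias/norm) excluded from weight decay
    optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)
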
@ -40,9 +40,18 @@ class AdaBelief(Optimizer):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
|
||||
decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
|
||||
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-16,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
decoupled_decay=True,
|
||||
fixed_decay=False,
|
||||
rectify=True,
|
||||
degenerated_to_sgd=True,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
|
@ -58,9 +67,17 @@ class AdaBelief(Optimizer):
|
|||
param['buffer'] = [[None, None, None] for _ in range(10)]
|
||||
|
||||
defaults = dict(
|
||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
|
||||
degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
|
||||
fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
degenerated_to_sgd=degenerated_to_sgd,
|
||||
decoupled_decay=decoupled_decay,
|
||||
rectify=rectify,
|
||||
fixed_decay=fixed_decay,
|
||||
buffer=[[None, None, None] for _ in range(10)]
|
||||
)
|
||||
super(AdaBelief, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
|
@ -16,6 +16,7 @@ import math
|
|||
|
||||
class Adafactor(torch.optim.Optimizer):
|
||||
"""Implements Adafactor algorithm.
|
||||
|
||||
This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
|
||||
(see https://arxiv.org/abs/1804.04235)
|
||||
|
||||
|
|
|
@ -23,8 +23,18 @@ class Adahessian(torch.optim.Optimizer):
|
|||
n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
|
||||
hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=0.1,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0.0,
|
||||
hessian_power=1.0,
|
||||
update_each=1,
|
||||
n_samples=1,
|
||||
avg_conv_kernel=False,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
|
@ -44,7 +54,13 @@ class Adahessian(torch.optim.Optimizer):
|
|||
self.seed = 2147483647
|
||||
self.generator = torch.Generator().manual_seed(self.seed)
|
||||
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
hessian_power=hessian_power,
|
||||
)
|
||||
super(Adahessian, self).__init__(params, defaults)
|
||||
|
||||
for p in self.get_params():
|
||||
|
|
|
@ -41,11 +41,26 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
|
|||
|
||||
|
||||
class AdamP(Optimizer):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
delta=0.1,
|
||||
wd_ratio=0.1,
|
||||
nesterov=False,
|
||||
):
|
||||
defaults = dict(
|
||||
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
delta=delta,
|
||||
wd_ratio=wd_ratio,
|
||||
nesterov=nesterov,
|
||||
)
|
||||
super(AdamP, self).__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -36,8 +36,16 @@ class AdamW(Optimizer):
|
|||
https://openreview.net/forum?id=ryQu7f-RZ
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=1e-2, amsgrad=False):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-2,
|
||||
amsgrad=False,
|
||||
):
|
||||
# NOTE: deprecated in favour of builtin torch.optim.AdamW
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
|
@ -46,8 +54,13 @@ class AdamW(Optimizer):
|
|||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
|
@ -137,7 +137,7 @@ def lion(
|
|||
"""
|
||||
if foreach is None:
|
||||
# Placeholder for more complex foreach logic to be added when value is not set
|
||||
foreach = False
|
||||
foreach = True
|
||||
|
||||
if foreach and torch.jit.is_scripting():
|
||||
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
|
||||
|
|
|
@ -71,7 +71,12 @@ class MADGRAD(torch.optim.Optimizer):
|
|||
raise ValueError(f"Eps must be non-negative")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
|
||||
lr=lr,
|
||||
eps=eps,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
decoupled_decay=decoupled_decay,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@property
|
||||
|
|
|
@ -27,8 +27,15 @@ class Nadam(Optimizer):
|
|||
NOTE: Has potential issues but does work well on some problems.
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||
weight_decay=0, schedule_decay=4e-3):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=2e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
schedule_decay=4e-3,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
defaults = dict(
|
||||
|
|
|
@ -29,8 +29,16 @@ class NvNovoGrad(Optimizer):
|
|||
(default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
|
||||
weight_decay=0, grad_averaging=False, amsgrad=False):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.95, 0.98),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
grad_averaging=False,
|
||||
amsgrad=False,
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
|
@ -39,10 +47,14 @@ class NvNovoGrad(Optimizer):
|
|||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
grad_averaging=grad_averaging,
|
||||
amsgrad=amsgrad)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
grad_averaging=grad_averaging,
|
||||
amsgrad=amsgrad,
|
||||
)
|
||||
|
||||
super(NvNovoGrad, self).__init__(params, defaults)
|
||||
|
||||
|
|
|
@ -1,431 +0,0 @@
|
|||
""" Optimizer Factory w/ Custom Weight Decay
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
from itertools import islice
|
||||
from typing import Optional, Callable, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from timm.models import group_parameters
|
||||
from . import AdafactorBigVision
|
||||
|
||||
from .adabelief import AdaBelief
|
||||
from .adafactor import Adafactor
|
||||
from .adahessian import Adahessian
|
||||
from .adamp import AdamP
|
||||
from .adan import Adan
|
||||
from .adopt import Adopt
|
||||
from .lamb import Lamb
|
||||
from .lars import Lars
|
||||
from .lion import Lion
|
||||
from .lookahead import Lookahead
|
||||
from .madgrad import MADGRAD
|
||||
from .nadam import Nadam
|
||||
from .nadamw import NAdamW
|
||||
from .nvnovograd import NvNovoGrad
|
||||
from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .sgdp import SGDP
|
||||
from .sgdw import SGDW
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# optimizers to default to multi-tensor
|
||||
_DEFAULT_FOREACH = {
|
||||
'lion',
|
||||
}
|
||||
|
||||
|
||||
def param_groups_weight_decay(
|
||||
model: nn.Module,
|
||||
weight_decay=1e-5,
|
||||
no_weight_decay_list=()
|
||||
):
|
||||
no_weight_decay_list = set(no_weight_decay_list)
|
||||
decay = []
|
||||
no_decay = []
|
||||
for name, param in model.named_parameters():
|
||||
if not param.requires_grad:
|
||||
continue
|
||||
|
||||
if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
|
||||
no_decay.append(param)
|
||||
else:
|
||||
decay.append(param)
|
||||
|
||||
return [
|
||||
{'params': no_decay, 'weight_decay': 0.},
|
||||
{'params': decay, 'weight_decay': weight_decay}]
|
||||
|
||||
|
||||


def _group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
    def _in_head(n, hp):
        if not hp:
            return True
        elif isinstance(hp, (tuple, list)):
            return any([n.startswith(hpi) for hpi in hp])
        else:
            return n.startswith(hp)

    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
    names_trunk = []
    names_head = []
    for n, _ in model.named_parameters():
        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)

    # group non-head layers
    num_trunk_layers = len(names_trunk)
    if num_groups is not None:
        layers_per_group = -(num_trunk_layers // -num_groups)
    names_trunk = list(_group(names_trunk, layers_per_group))

    num_trunk_groups = len(names_trunk)
    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
    layer_map.update({n: num_trunk_groups for n in names_head})
    return layer_map
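
# --- Editor's illustration (not part of the original optim_factory.py) ---
# _group() chunks an iterable into fixed-size tuples, and -(a // -b) is ceiling
# division, so num_groups yields evenly sized trunk chunks; head parameters are
# always assigned to the final (highest) group index. Both checks below hold:
#
#   assert list(_group(range(5), 2)) == [(0, 1), (2, 3), (4,)]
#   assert -(5 // -2) == 3  # ceil(5 / 2) without importing math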


def param_groups_layer_decay(
        model: nn.Module,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
        verbose: bool = False,
):
    """
    Parameter groups for layer-wise lr decay & weight decay
    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    no_weight_decay_list = set(no_weight_decay_list)
    param_group_names = {}  # NOTE for debugging
    param_groups = {}

    if hasattr(model, 'group_matcher'):
        # FIXME interface needs more work
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
        layer_map = _layer_map(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = layer_map.get(name, layer_max)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            this_scale = layer_scales[layer_id]
            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "param_names": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["param_names"].append(name)
        param_groups[group_name]["params"].append(param)

    if verbose:
        import json
        _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())
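
# --- Editor's illustration (not part of the original optim_factory.py) ---
# Each returned group carries an 'lr_scale' of layer_decay ** (layer_max - layer_id),
# so earlier layers get geometrically smaller learning rates; e.g. with layer_decay=0.75
# and 4 layer groups the scales are 0.75**3, 0.75**2, 0.75, 1.0. Model choice is assumed.
#
#   import timm
#   demo_vit = timm.create_model('vit_tiny_patch16_224', pretrained=False)
#   groups = param_groups_layer_decay(demo_vit, weight_decay=0.05, layer_decay=0.75)
#   # each entry looks like {'lr_scale': ..., 'weight_decay': 0.0 or 0.05, 'params': [...]}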


def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
    """
    kwargs = dict(
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
    )
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
        kwargs['betas'] = cfg.opt_betas
    if getattr(cfg, 'layer_decay', None) is not None:
        kwargs['layer_decay'] = cfg.layer_decay
    if getattr(cfg, 'opt_args', None) is not None:
        kwargs.update(cfg.opt_args)
    if getattr(cfg, 'opt_foreach', None) is not None:
        kwargs['foreach'] = cfg.opt_foreach
    return kwargs


def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
    return create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        filter_bias_and_bn=filter_bias_and_bn,
    )
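
# --- Editor's illustration (not part of the original optim_factory.py) ---
# optimizer_kwargs() only forwards the optional fields that are actually set on the
# cfg/argparse namespace. A hypothetical minimal cfg:
#
#   from types import SimpleNamespace
#   demo_cfg = SimpleNamespace(
#       opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9, opt_eps=1e-6, opt_betas=None)
#   optimizer_kwargs(cfg=demo_cfg)
#   # -> {'opt': 'adamw', 'lr': 0.001, 'weight_decay': 0.05, 'momentum': 0.9, 'eps': 1e-06}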


def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs,
):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params: a model (nn.Module) or an iterable of parameters / param groups to optimize
        opt: name of optimizer to create
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        layer_decay: layer-wise lr decay factor, builds param groups via param_groups_layer_decay when set
        param_group_fn: optional callable taking the model and returning param groups, overrides other grouping
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay,
            )
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower.startswith('fused'):
        try:
            from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
            has_apex = True
        except ImportError:
            has_apex = False
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    if opt_lower.startswith('bnb'):
        try:
            import bitsandbytes as bnb
            has_bnb = True
        except ImportError:
            has_bnb = False
        assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)

    if lr is not None:
        opt_args.setdefault('lr', lr)

    if foreach is None:
        if opt in _DEFAULT_FOREACH:
            opt_args.setdefault('foreach', True)
    else:
        opt_args['foreach'] = foreach
# basic SGD & related
|
||||
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
||||
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'momentum':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
|
||||
elif opt_lower == 'sgdp':
|
||||
optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
|
||||
# NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'momentumw':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)
|
||||
|
||||
# adaptive
|
||||
elif opt_lower == 'adam':
|
||||
optimizer = optim.Adam(parameters, **opt_args)
|
||||
elif opt_lower == 'adamw':
|
||||
optimizer = optim.AdamW(parameters, **opt_args)
|
||||
elif opt_lower == 'adamp':
|
||||
optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'nadam':
|
||||
try:
|
||||
# NOTE PyTorch >= 1.10 should have native NAdam
|
||||
optimizer = optim.Nadam(parameters, **opt_args)
|
||||
except AttributeError:
|
||||
optimizer = Nadam(parameters, **opt_args)
|
||||
elif opt_lower == 'nadamw':
|
||||
optimizer = NAdamW(parameters, **opt_args)
|
||||
elif opt_lower == 'radam':
|
||||
optimizer = RAdam(parameters, **opt_args)
|
||||
elif opt_lower == 'adamax':
|
||||
optimizer = optim.Adamax(parameters, **opt_args)
|
||||
elif opt_lower == 'adabelief':
|
||||
optimizer = AdaBelief(parameters, rectify=False, **opt_args)
|
||||
elif opt_lower == 'radabelief':
|
||||
optimizer = AdaBelief(parameters, rectify=True, **opt_args)
|
||||
elif opt_lower == 'adadelta':
|
||||
optimizer = optim.Adadelta(parameters, **opt_args)
|
||||
elif opt_lower == 'adagrad':
|
||||
opt_args.setdefault('eps', 1e-8)
|
||||
optimizer = optim.Adagrad(parameters, **opt_args)
|
||||
elif opt_lower == 'adafactor':
|
||||
optimizer = Adafactor(parameters, **opt_args)
|
||||
elif opt_lower == 'adanp':
|
||||
optimizer = Adan(parameters, no_prox=False, **opt_args)
|
||||
elif opt_lower == 'adanw':
|
||||
optimizer = Adan(parameters, no_prox=True, **opt_args)
|
||||
elif opt_lower == 'lamb':
|
||||
optimizer = Lamb(parameters, **opt_args)
|
||||
elif opt_lower == 'lambc':
|
||||
optimizer = Lamb(parameters, trust_clip=True, **opt_args)
|
||||
elif opt_lower == 'larc':
|
||||
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
|
||||
elif opt_lower == 'lars':
|
||||
optimizer = Lars(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'nlarc':
|
||||
optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'nlars':
|
||||
optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'madgrad':
|
||||
optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'madgradw':
|
||||
optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
|
||||
elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
|
||||
optimizer = NvNovoGrad(parameters, **opt_args)
|
||||
elif opt_lower == 'rmsprop':
|
||||
optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'rmsproptf':
|
||||
optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'lion':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = Lion(parameters, **opt_args)
|
||||
elif opt_lower == 'adafactorbv':
|
||||
optimizer = AdafactorBigVision(parameters, **opt_args)
|
||||
elif opt_lower == 'adopt':
|
||||
optimizer = Adopt(parameters, **opt_args)
|
||||
elif opt_lower == 'adoptw':
|
||||
optimizer = Adopt(parameters, decoupled=True, **opt_args)
|
||||
|
||||
# second order
|
||||
elif opt_lower == 'adahessian':
|
||||
optimizer = Adahessian(parameters, **opt_args)
|
||||
|
||||
# NVIDIA fused optimizers, require APEX to be installed
|
||||
elif opt_lower == 'fusedsgd':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'fusedmomentum':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
|
||||
elif opt_lower == 'fusedadam':
|
||||
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
||||
elif opt_lower == 'fusedadamw':
|
||||
optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
|
||||
elif opt_lower == 'fusedlamb':
|
||||
optimizer = FusedLAMB(parameters, **opt_args)
|
||||
elif opt_lower == 'fusednovograd':
|
||||
opt_args.setdefault('betas', (0.95, 0.98))
|
||||
optimizer = FusedNovoGrad(parameters, **opt_args)
|
||||
|
||||
# bitsandbytes optimizers, require bitsandbytes to be installed
|
||||
elif opt_lower == 'bnbsgd':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'bnbsgd8bit':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'bnbmomentum':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'bnbmomentum8bit':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'bnbadam':
|
||||
optimizer = bnb.optim.Adam(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadam8bit':
|
||||
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadamw':
|
||||
optimizer = bnb.optim.AdamW(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadamw8bit':
|
||||
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblamb':
|
||||
optimizer = bnb.optim.LAMB(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblamb8bit':
|
||||
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblars':
|
||||
optimizer = bnb.optim.LARS(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblarsb8bit':
|
||||
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblion':
|
||||
optimizer = bnb.optim.Lion(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblion8bit':
|
||||
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
|
||||
|
||||
else:
|
||||
assert False and "Invalid optimizer"
|
||||
raise ValueError
|
||||
|
||||
if len(opt_split) > 1:
|
||||
if opt_split[0] == 'lookahead':
|
||||
optimizer = Lookahead(optimizer)
|
||||
|
||||
return optimizer
|
|
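
For reference, a minimal sketch of driving the factory whose legacy module is removed above; the model choice and settings are illustrative assumptions, not part of this diff:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet18', pretrained=False)

    # plain AdamW with bias/norm params excluded from weight decay (filter_bias_and_bn=True)
    opt = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)

    # layer-wise lr decay plus the 'lookahead_' prefix wrapping the base optimizer
    opt = create_optimizer_v2(model, opt='lookahead_adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)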

@@ -9,10 +9,21 @@ from torch.optim.optimizer import Optimizer

class RAdam(Optimizer):

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
    def __init__(
            self,
            params,
            lr=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)])
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            buffer=[[None, None, None] for _ in range(10)]
        )
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):

@@ -45,8 +45,18 @@ class RMSpropTF(Optimizer):

    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
    def __init__(
            self,
            params,
            lr=1e-2,
            alpha=0.9,
            eps=1e-10,
            weight_decay=0,
            momentum=0.,
            centered=False,
            decoupled_decay=False,
            lr_in_momentum=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:

@@ -59,8 +69,15 @@ class RMSpropTF(Optimizer):
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(
            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
            lr_in_momentum=lr_in_momentum,
        )
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
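
The RMSpropTF change above is likewise formatting only; a minimal construction sketch for illustration (the module and values below are assumptions, not part of this diff):

    import torch.nn as nn
    from timm.optim.rmsprop_tf import RMSpropTF

    model = nn.Linear(10, 2)  # placeholder module
    opt = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, eps=1e-10,
                    momentum=0.9, decoupled_decay=True)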

@@ -17,11 +17,28 @@ from .adamp import projection


class SGDP(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
    def __init__(
            self,
            params,
            lr=required,
            momentum=0,
            dampening=0,
            weight_decay=0,
            nesterov=False,
            eps=1e-8,
            delta=0.1,
            wd_ratio=0.1
    ):
        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            eps=eps,
            delta=delta,
            wd_ratio=wd_ratio,
        )
        super(SGDP, self).__init__(params, defaults)

    @torch.no_grad()

@@ -35,10 +35,15 @@ class SGDW(Optimizer):
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening,
            weight_decay=weight_decay, nesterov=nesterov,
            maximize=maximize, foreach=foreach,
            differentiable=differentiable)
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            maximize=maximize,
            foreach=foreach,
            differentiable=differentiable,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

train.py
@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import copy
import importlib
import json
import logging

@@ -554,6 +555,13 @@ def main():
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )
    if utils.is_primary(args):
        defaults = copy.deepcopy(optimizer.defaults)
        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
        logging.info(
            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
        )

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing