A bit of an optimizer overhaul: added an improved factory, list_optimizers and get_optimizer_class helpers, and OptimInfo classes with descriptions and arg configs.

parent c1cf8c52b9
commit ee5f6e76bb
@@ -15,7 +15,7 @@ from torch.nn import Parameter
 from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
 
-from timm.optim import create_optimizer_v2
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
 
 import importlib
 import os
@@ -293,10 +293,11 @@ def _build_params_dict_single(weight, bias, **kwargs):
     return [dict(params=bias, **kwargs)]
 
 
-#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
-# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
-@pytest.mark.parametrize('optimizer', ['sgd'])
-def test_sgd(optimizer):
+@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
+def test_optim_factory(optimizer):
+    get_optimizer_class(optimizer)
+    # test basic cases that don't need specific tuning via factory test
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
     )
@@ -316,6 +317,12 @@ def test_sgd(optimizer):
         lambda weight, bias: create_optimizer_v2(
             _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
     )
+
+
+#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
+# FIXME momentum variant frequently fails in GitHub runner, but never local after many attempts
+@pytest.mark.parametrize('optimizer', ['sgd'])
+def test_sgd(optimizer):
     # _test_basic_cases(
     #     lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3),
     #     [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
@@ -358,21 +365,6 @@ def test_sgd(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw'])
 def test_adam(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
     )
@@ -381,21 +373,6 @@ def test_adam(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
     # FIXME rosenbrock is not passing for ADOPT
     # _test_rosenbrock(
     #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
@@ -405,25 +382,6 @@ def test_adopt(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adabelief'])
 def test_adabelief(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
     )
@@ -435,21 +393,6 @@ def test_adabelief(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['radam', 'radabelief'])
 def test_rectified(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -458,25 +401,6 @@ def test_rectified(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adadelta', 'adagrad'])
 def test_adaother(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
     )
@@ -488,24 +412,6 @@ def test_adaother(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adafactor', 'adafactorbv'])
 def test_adafactor(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(_build_params_dict_single(weight, bias), optimizer)
-    )
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, weight_decay=1)
     )
@@ -517,25 +423,6 @@ def test_adafactor(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['lamb', 'lambc'])
 def test_lamb(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -544,25 +431,6 @@ def test_lamb(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
 def test_lars(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -571,25 +439,6 @@ def test_lars(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
     )
@@ -598,25 +447,6 @@ def test_madgrad(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['novograd'])
 def test_novograd(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -625,25 +455,6 @@ def test_novograd(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['rmsprop', 'rmsproptf'])
 def test_rmsprop(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
     )
@@ -652,25 +463,6 @@ def test_rmsprop(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adamp'])
 def test_adamp(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
     )
@@ -679,25 +471,6 @@ def test_adamp(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['sgdp'])
 def test_sgdp(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -706,25 +479,6 @@ def test_sgdp(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['lookahead_sgd', 'lookahead_momentum'])
 def test_lookahead_sgd(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
     )
@@ -732,25 +486,6 @@ def test_lookahead_sgd(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['lookahead_adamw', 'lookahead_adam'])
 def test_lookahead_adam(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)
     )
@@ -758,25 +493,6 @@ def test_lookahead_adam(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['lookahead_radam'])
 def test_lookahead_radam(optimizer):
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=3e-3), optimizer)
-    )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
     )
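Note: the per-optimizer `_test_basic_cases` boilerplate removed above is now covered once by the `test_optim_factory` test parametrized over `list_optimizers(...)`. A minimal sketch of that pattern follows; it is illustrative only, not the actual test code, and the toy `nn.Linear` model and single optimization step are assumptions.

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class, list_optimizers

def smoke_test_registered_optimizers():
    # 'fused*' and 'bnb*' entries need APEX / bitsandbytes plus CUDA, so skip them here
    for opt_name in list_optimizers(exclude_filters=('fused*', 'bnb*')):
        get_optimizer_class(opt_name)  # resolves the class, raises if the registry entry is broken
        model = nn.Linear(8, 2)
        opt = create_optimizer_v2(model, opt_name, lr=1e-3)
        model(torch.randn(4, 8)).sum().backward()
        opt.step()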
@@ -17,4 +17,5 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 
-from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
+    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
@@ -0,0 +1,798 @@
+""" Optimizer Factory w/ custom Weight Decay & Layer Decay support
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from fnmatch import fnmatch
+import importlib
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from .adabelief import AdaBelief
+from .adafactor import Adafactor
+from .adafactor_bv import AdafactorBigVision
+from .adamp import AdamP
+from .adan import Adan
+from .adopt import Adopt
+from .lamb import Lamb
+from .lars import Lars
+from .lion import Lion
+from .lookahead import Lookahead
+from .madgrad import MADGRAD
+from .nadam import Nadam
+from .nadamw import NAdamW
+from .nvnovograd import NvNovoGrad
+from .radam import RAdam
+from .rmsprop_tf import RMSpropTF
+from .sgdp import SGDP
+from .sgdw import SGDW
+
+_logger = logging.getLogger(__name__)
+
+# Type variables
+T = TypeVar('T')
+Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
+OptimType = TypeVar('OptimType', bound='optim.Optimizer')
+
+
+def _import_class(class_string: str) -> Type:
+    """Dynamically import a class from a string."""
+    try:
+        module_name, class_name = class_string.rsplit(".", 1)
+        module = importlib.import_module(module_name)
+        return getattr(module, class_name)
+    except (ImportError, AttributeError) as e:
+        raise ImportError(f"Could not import {class_string}: {e}")
+
+
+class OptimizerCallable(Protocol):
+    """Protocol for optimizer constructor signatures."""
+
+    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
+
+
+@dataclass(frozen=True)
+class OptimInfo:
+    """Immutable configuration for an optimizer.
+
+    Attributes:
+        name: Unique identifier for the optimizer
+        opt_class: The optimizer class
+        description: Brief description of the optimizer's characteristics and behavior
+        has_eps: Whether the optimizer accepts epsilon parameter
+        has_momentum: Whether the optimizer accepts momentum parameter
+        has_betas: Whether the optimizer accepts a tuple of beta parameters
+        num_betas: number of betas in tuple (valid IFF has_betas = True)
+        defaults: Optional default parameters for the optimizer
+    """
+    name: str
+    opt_class: Union[str, Type[optim.Optimizer]]
+    description: str = ''
+    has_eps: bool = True
+    has_momentum: bool = False
+    has_betas: bool = False
+    num_betas: int = 2
+    defaults: Optional[Dict[str, Any]] = None
+
+
+class OptimizerRegistry:
+    """Registry managing optimizer configurations and instantiation.
+
+    This class provides a central registry for optimizer configurations and handles
+    their instantiation with appropriate parameter groups and settings.
+    """
+
+    def __init__(self) -> None:
+        self._optimizers: Dict[str, OptimInfo] = {}
+        self._foreach_defaults: Set[str] = {'lion'}
+
+    def register(self, info: OptimInfo) -> None:
+        """Register an optimizer configuration.
+
+        Args:
+            info: The OptimInfo configuration containing name, type and description
+        """
+        name = info.name.lower()
+        if name in self._optimizers:
+            _logger.warning(f'Optimizer {name} already registered, overwriting')
+        self._optimizers[name] = info
+
+    def register_alias(self, alias: str, target: str) -> None:
+        """Register an alias for an existing optimizer.
+
+        Args:
+            alias: The alias name
+            target: The target optimizer name
+
+        Raises:
+            KeyError: If target optimizer doesn't exist
+        """
+        target = target.lower()
+        if target not in self._optimizers:
+            raise KeyError(f'Cannot create alias for non-existent optimizer {target}')
+        self._optimizers[alias.lower()] = self._optimizers[target]
+
+    def register_foreach_default(self, name: str) -> None:
+        """Register an optimizer as defaulting to foreach=True."""
+        self._foreach_defaults.add(name.lower())
+
+    def list_optimizers(
+            self,
+            filter: str = '',
+            exclude_filters: Optional[List[str]] = None,
+            with_description: bool = False
+    ) -> List[Union[str, Tuple[str, str]]]:
+        """List available optimizer names, optionally filtered.
+
+        Args:
+            filter: Wildcard style filter string (e.g., 'adam*')
+            exclude_filters: Optional list of wildcard patterns to exclude
+            with_description: If True, return tuples of (name, description)
+
+        Returns:
+            List of either optimizer names or (name, description) tuples
+        """
+        names = sorted(self._optimizers.keys())
+
+        if filter:
+            names = [n for n in names if fnmatch(n, filter)]
+
+        if exclude_filters:
+            for exclude_filter in exclude_filters:
+                names = [n for n in names if not fnmatch(n, exclude_filter)]
+
+        if with_description:
+            return [(name, self._optimizers[name].description) for name in names]
+        return names
+
+    def get_optimizer_info(self, name: str) -> OptimInfo:
+        """Get the OptimInfo for an optimizer.
+
+        Args:
+            name: Name of the optimizer
+
+        Returns:
+            OptimInfo configuration
+
+        Raises:
+            ValueError: If optimizer is not found
+        """
+        name = name.lower()
+        if name not in self._optimizers:
+            raise ValueError(f'Optimizer {name} not found in registry')
+        return self._optimizers[name]
+
+    def get_optimizer_class(
+            self,
+            name_or_info: Union[str, OptimInfo],
+            bind_defaults: bool = True,
+    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+        """Get the optimizer class with any default arguments applied.
+
+        This allows direct instantiation of optimizers with their default configs
+        without going through the full factory.
+
+        Args:
+            name_or_info: Name of the optimizer
+            bind_defaults: Bind default arguments to optimizer class via `partial` before returning
+
+        Returns:
+            Optimizer class or partial with defaults applied
+
+        Raises:
+            ValueError: If optimizer not found
+        """
+        if isinstance(name_or_info, str):
+            opt_info = self.get_optimizer_info(name_or_info)
+        else:
+            assert isinstance(name_or_info, OptimInfo)
+            opt_info = name_or_info
+
+        if isinstance(opt_info.opt_class, str):
+            # Special handling for APEX and BNB optimizers
+            if opt_info.opt_class.startswith('apex.'):
+                assert torch.cuda.is_available(), 'CUDA required for APEX optimizers'
+                try:
+                    opt_class = _import_class(opt_info.opt_class)
+                except ImportError as e:
+                    raise ImportError('APEX optimizers require apex to be installed') from e
+            elif opt_info.opt_class.startswith('bitsandbytes.'):
+                assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers'
+                try:
+                    opt_class = _import_class(opt_info.opt_class)
+                except ImportError as e:
+                    raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e
+            else:
+                opt_class = _import_class(opt_info.opt_class)
+        else:
+            opt_class = opt_info.opt_class
+
+        # Return class or partial with defaults
+        if bind_defaults and opt_info.defaults:
+            opt_class = partial(opt_class, **opt_info.defaults)
+
+        return opt_class
+
+    def create_optimizer(
+            self,
+            model_or_params: Union[nn.Module, Params],
+            opt: str,
+            lr: Optional[float] = None,
+            weight_decay: float = 0.,
+            momentum: float = 0.9,
+            foreach: Optional[bool] = None,
+            weight_decay_exclude_1d: bool = True,
+            layer_decay: Optional[float] = None,
+            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            **kwargs: Any,
+    ) -> optim.Optimizer:
+        """Create an optimizer instance.
+
+        Args:
+            model_or_params: Model or parameters to optimize
+            opt: Name of optimizer to create
+            lr: Learning rate
+            weight_decay: Weight decay factor
+            momentum: Momentum factor for applicable optimizers
+            foreach: Enable/disable foreach operation
+            weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
+            layer_decay: Layer-wise learning rate decay
+            param_group_fn: Optional custom parameter grouping function
+            **kwargs: Additional optimizer-specific arguments
+
+        Returns:
+            Configured optimizer instance
+
+        Raises:
+            ValueError: If optimizer not found or configuration invalid
+        """
+
+        # Get parameters to optimize
+        if isinstance(model_or_params, nn.Module):
+            # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied
+            no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())()
+
+            if param_group_fn:
+                # run custom fn to generate param groups from nn.Module
+                parameters = param_group_fn(model_or_params)
+            elif layer_decay is not None:
+                parameters = param_groups_layer_decay(
+                    model_or_params,
+                    weight_decay=weight_decay,
+                    layer_decay=layer_decay,
+                    no_weight_decay_list=no_weight_decay,
+                    weight_decay_exclude_1d=weight_decay_exclude_1d,
+                )
+                weight_decay = 0.
+            elif weight_decay and weight_decay_exclude_1d:
+                parameters = param_groups_weight_decay(
+                    model_or_params,
+                    weight_decay=weight_decay,
+                    no_weight_decay_list=no_weight_decay,
+                )
+                weight_decay = 0.
+            else:
+                parameters = model_or_params.parameters()
+        else:
+            # pass parameters / parameter groups through to optimizer
+            parameters = model_or_params
+
+        # Parse optimizer name
+        opt_split = opt.lower().split('_')
+        opt_name = opt_split[-1]
+        use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False
+
+        opt_info = self.get_optimizer_info(opt_name)
+
+        # Build optimizer arguments
+        opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs}
+
+        # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None.
+        if lr is not None:
+            opt_args['lr'] = lr
+
+        # Apply optimizer-specific settings
+        if opt_info.defaults:
+            for k, v in opt_info.defaults.items():
+                opt_args.setdefault(k, v)
+
+        # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat.
+        if opt_info.has_momentum:
+            opt_args.setdefault('momentum', momentum)
+
+        # Remove commonly used kwargs that aren't always supported
+        if not opt_info.has_eps:
+            opt_args.pop('eps', None)
+        if not opt_info.has_betas:
+            opt_args.pop('betas', None)
+
+        if foreach is not None:
+            # Explicitly activate or deactivate multi-tensor foreach impl.
+            # Not all optimizers support this, and those that do usually default to using
+            # multi-tensor impl if foreach is left as default 'None' and can be enabled.
+            opt_args.setdefault('foreach', foreach)
+
+        # Create optimizer
+        opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
+        optimizer = opt_class(parameters, **opt_args)
+
+        # Apply Lookahead if requested
+        if use_lookahead:
+            optimizer = Lookahead(optimizer)
+
+        return optimizer
+
+
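The registry is also usable standalone. Below is a minimal sketch of registering a custom optimizer via OptimInfo on a private OptimizerRegistry; the MyOptimizer class and the registry instance are hypothetical, only the OptimInfo / OptimizerRegistry API comes from the diff above.

import torch.nn as nn
import torch.optim as optim
from timm.optim import OptimInfo, OptimizerRegistry

class MyOptimizer(optim.SGD):
    """Hypothetical custom optimizer, stands in for any torch.optim.Optimizer subclass."""
    pass

my_registry = OptimizerRegistry()
my_registry.register(OptimInfo(
    name='mysgd',
    opt_class=MyOptimizer,
    description='Custom SGD subclass registered for demonstration',
    has_eps=False,
    has_momentum=True,
    defaults={'nesterov': True},
))

# weight decay is applied via param groups, momentum defaults to 0.9, nesterov comes from defaults
model = nn.Linear(8, 2)
optimizer = my_registry.create_optimizer(model, 'mysgd', lr=0.1, weight_decay=1e-4)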
+def _register_sgd_variants(registry: OptimizerRegistry) -> None:
+    """Register SGD-based optimizers"""
+    sgd_optimizers = [
+        OptimInfo(
+            name='sgd',
+            opt_class=optim.SGD,
+            description='Stochastic Gradient Descent with Nesterov momentum (default)',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='momentum',
+            opt_class=optim.SGD,
+            description='Stochastic Gradient Descent with classical momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': False}
+        ),
+        OptimInfo(
+            name='sgdp',
+            opt_class=SGDP,
+            description='SGD with built-in projection to unit norm sphere',
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='sgdw',
+            opt_class=SGDW,
+            description='SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+    ]
+    for opt in sgd_optimizers:
+        registry.register(opt)
+
+
+def _register_adam_variants(registry: OptimizerRegistry) -> None:
+    """Register Adam-based optimizers"""
+    adam_optimizers = [
+        OptimInfo(
+            name='adam',
+            opt_class=optim.Adam,
+            description='torch.optim Adam (Adaptive Moment Estimation)',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamw',
+            opt_class=optim.AdamW,
+            description='torch.optim Adam with decoupled weight decay regularization',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamp',
+            opt_class=AdamP,
+            description='Adam with built-in projection to unit norm sphere',
+            has_betas=True,
+            defaults={'wd_ratio': 0.01, 'nesterov': True}
+        ),
+        OptimInfo(
+            name='nadam',
+            opt_class=Nadam,
+            description='Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamw',
+            opt_class=NAdamW,
+            description='Adam with Nesterov momentum and decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radam',
+            opt_class=RAdam,
+            description='Rectified Adam with variance adaptation',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamax',
+            opt_class=optim.Adamax,
+            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adafactor',
+            opt_class=Adafactor,
+            description='Memory-efficient implementation of Adam with factored gradients',
+        ),
+        OptimInfo(
+            name='adafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Big Vision variant of Adafactor with factored gradients, half precision momentum.',
+        ),
+        OptimInfo(
+            name='adopt',
+            opt_class=Adopt,
+            description='Memory-efficient implementation of Adam with factored gradients',
+        ),
+        OptimInfo(
+            name='adoptw',
+            opt_class=Adopt,
+            description='Memory-efficient implementation of Adam with factored gradients',
+            defaults={'decoupled': True}
+        ),
+    ]
+    for opt in adam_optimizers:
+        registry.register(opt)
+
+
+def _register_lamb_lars(registry: OptimizerRegistry) -> None:
+    """Register LAMB and LARS variants"""
+    lamb_lars_optimizers = [
+        OptimInfo(
+            name='lamb',
+            opt_class=Lamb,
+            description='Layer-wise Adaptive Moments for batch optimization',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='lambc',
+            opt_class=Lamb,
+            description='LAMB with trust ratio clipping for stability',
+            has_betas=True,
+            defaults={'trust_clip': True}
+        ),
+        OptimInfo(
+            name='lars',
+            opt_class=Lars,
+            description='Layer-wise Adaptive Rate Scaling',
+            has_momentum=True
+        ),
+        OptimInfo(
+            name='larc',
+            opt_class=Lars,
+            description='LARS with trust ratio clipping for stability',
+            has_momentum=True,
+            defaults={'trust_clip': True}
+        ),
+        OptimInfo(
+            name='nlars',
+            opt_class=Lars,
+            description='LARS with Nesterov momentum',
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='nlarc',
+            opt_class=Lars,
+            description='LARS with Nesterov momentum & trust ratio clipping',
+            has_momentum=True,
+            defaults={'nesterov': True, 'trust_clip': True}
+        ),
+    ]
+    for opt in lamb_lars_optimizers:
+        registry.register(opt)
+
+
+def _register_other_optimizers(registry: OptimizerRegistry) -> None:
+    """Register miscellaneous optimizers"""
+    other_optimizers = [
+        OptimInfo(
+            name='adabelief',
+            opt_class=AdaBelief,
+            description='Adapts learning rate based on gradient prediction error',
+            has_betas=True,
+            defaults={'rectify': False}
+        ),
+        OptimInfo(
+            name='radabelief',
+            opt_class=AdaBelief,
+            description='Rectified AdaBelief with variance adaptation',
+            has_betas=True,
+            defaults={'rectify': True}
+        ),
+        OptimInfo(
+            name='adadelta',
+            opt_class=optim.Adadelta,
+            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+        ),
+        OptimInfo(
+            name='adagrad',
+            opt_class=optim.Adagrad,
+            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            defaults={'eps': 1e-8}
+        ),
+        OptimInfo(
+            name='adan',
+            opt_class=Adan,
+            description='Adaptive Nesterov Momentum Algorithm',
+            defaults={'no_prox': False},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='adanw',
+            opt_class=Adan,
+            description='Adaptive Nesterov Momentum with decoupled weight decay',
+            defaults={'no_prox': True},
+            has_betas=True,
+            num_betas=3
+        ),
+        OptimInfo(
+            name='lion',
+            opt_class=Lion,
+            description='Evolved Sign Momentum optimizer for improved convergence',
+            has_eps=False,
+            has_betas=True
+        ),
+        OptimInfo(
+            name='madgrad',
+            opt_class=MADGRAD,
+            description='Momentum-based Adaptive gradient method',
+            has_momentum=True
+        ),
+        OptimInfo(
+            name='madgradw',
+            opt_class=MADGRAD,
+            description='MADGRAD with decoupled weight decay',
+            has_momentum=True,
+            defaults={'decoupled_decay': True}
+        ),
+        OptimInfo(
+            name='novograd',
+            opt_class=NvNovoGrad,
+            description='Normalized Adam with L2 norm gradient normalization',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='rmsprop',
+            opt_class=optim.RMSprop,
+            description='torch.optim RMSprop, Root Mean Square Propagation',
+            has_momentum=True,
+            defaults={'alpha': 0.9}
+        ),
+        OptimInfo(
+            name='rmsproptf',
+            opt_class=RMSpropTF,
+            description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation',
+            has_momentum=True,
+            defaults={'alpha': 0.9}
+        ),
+    ]
+    for opt in other_optimizers:
+        registry.register(opt)
+    registry.register_foreach_default('lion')
+
+
+def _register_apex_optimizers(registry: OptimizerRegistry) -> None:
+    """Register APEX optimizers (lazy import)"""
+    apex_optimizers = [
+        OptimInfo(
+            name='fusedsgd',
+            opt_class='apex.optimizers.FusedSGD',
+            description='NVIDIA APEX fused SGD implementation for faster training',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='fusedadam',
+            opt_class='apex.optimizers.FusedAdam',
+            description='NVIDIA APEX fused Adam implementation',
+            has_betas=True,
+            defaults={'adam_w_mode': False}
+        ),
+        OptimInfo(
+            name='fusedadamw',
+            opt_class='apex.optimizers.FusedAdam',
+            description='NVIDIA APEX fused AdamW implementation',
+            has_betas=True,
+            defaults={'adam_w_mode': True}
+        ),
+        OptimInfo(
+            name='fusedlamb',
+            opt_class='apex.optimizers.FusedLAMB',
+            description='NVIDIA APEX fused LAMB implementation',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='fusednovograd',
+            opt_class='apex.optimizers.FusedNovoGrad',
+            description='NVIDIA APEX fused NovoGrad implementation',
+            has_betas=True,
+            defaults={'betas': (0.95, 0.98)}
+        ),
+    ]
+    for opt in apex_optimizers:
+        registry.register(opt)
+
+
+def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
+    """Register bitsandbytes optimizers (lazy import)"""
+    bnb_optimizers = [
+        OptimInfo(
+            name='bnbsgd',
+            opt_class='bitsandbytes.optim.SGD',
+            description='bitsandbytes SGD',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='bnbsgd8bit',
+            opt_class='bitsandbytes.optim.SGD8bit',
+            description='bitsandbytes 8-bit SGD with dynamic quantization',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True}
+        ),
+        OptimInfo(
+            name='bnbadam',
+            opt_class='bitsandbytes.optim.Adam',
+            description='bitsandbytes Adam',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='bnbadam8bit',
+            opt_class='bitsandbytes.optim.Adam',
+            description='bitsandbytes 8-bit Adam with dynamic quantization',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='bnbadamw',
+            opt_class='bitsandbytes.optim.AdamW',
+            description='bitsandbytes AdamW',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='bnbadamw8bit',
+            opt_class='bitsandbytes.optim.AdamW',
+            description='bitsandbytes 8-bit AdamW with dynamic quantization',
+            has_betas=True
+        ),
+        OptimInfo(
+            'bnblion',
+            'bitsandbytes.optim.Lion',
+            description='bitsandbytes Lion',
+            has_betas=True
+        ),
+        OptimInfo(
+            'bnblion8bit',
+            'bitsandbytes.optim.Lion8bit',
+            description='bitsandbytes 8-bit Lion with dynamic quantization',
+            has_betas=True
+        ),
+        OptimInfo(
+            'bnbademamix',
+            'bitsandbytes.optim.AdEMAMix',
+            description='bitsandbytes AdEMAMix',
+            has_betas=True,
+            num_betas=3,
+        ),
+        OptimInfo(
+            'bnbademamix8bit',
+            'bitsandbytes.optim.AdEMAMix8bit',
+            description='bitsandbytes 8-bit AdEMAMix with dynamic quantization',
+            has_betas=True,
+            num_betas=3,
+        ),
+    ]
+    for opt in bnb_optimizers:
+        registry.register(opt)
+
+
+default_registry = OptimizerRegistry()
+
+def _register_default_optimizers() -> None:
+    """Register all default optimizers to the global registry."""
+    # Register all optimizer groups
+    _register_sgd_variants(default_registry)
+    _register_adam_variants(default_registry)
+    _register_lamb_lars(default_registry)
+    _register_other_optimizers(default_registry)
+    _register_apex_optimizers(default_registry)
+    _register_bnb_optimizers(default_registry)
+
+    # Register aliases
+    default_registry.register_alias('nesterov', 'sgd')
+    default_registry.register_alias('nesterovw', 'sgdw')
+
+
+# Initialize default registry
+_register_default_optimizers()
+
+# Public API
+
+def list_optimizers(
+        filter: str = '',
+        exclude_filters: Optional[List[str]] = None,
+        with_description: bool = False,
+) -> List[Union[str, Tuple[str, str]]]:
+    """List available optimizer names, optionally filtered.
+    """
+    return default_registry.list_optimizers(filter, exclude_filters, with_description)
+
+
+def get_optimizer_class(
+        name: str,
+        bind_defaults: bool = False,
+) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    """Get optimizer class by name with any defaults applied.
+    """
+    return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)
+
+
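Note the asymmetry as a design choice: OptimizerRegistry.get_optimizer_class binds an entry's defaults via functools.partial by default, while this module-level wrapper defaults bind_defaults=False and returns the bare class. A small sketch of both paths, assuming the 'sgd' entry registered above (nesterov=True in its defaults):

import torch.nn as nn
from timm.optim import get_optimizer_class

model = nn.Linear(8, 2)

sgd_cls = get_optimizer_class('sgd')                       # bare torch.optim.SGD class
opt_a = sgd_cls(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)

sgd_ctor = get_optimizer_class('sgd', bind_defaults=True)  # partial with nesterov=True pre-bound
opt_b = sgd_ctor(model.parameters(), lr=0.1, momentum=0.9)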
+def create_optimizer_v2(
+        model_or_params: Union[nn.Module, Params],
+        opt: str = 'sgd',
+        lr: Optional[float] = None,
+        weight_decay: float = 0.,
+        momentum: float = 0.9,
+        foreach: Optional[bool] = None,
+        filter_bias_and_bn: bool = True,
+        layer_decay: Optional[float] = None,
+        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        **kwargs: Any,
+) -> optim.Optimizer:
+    """Create an optimizer instance using the default registry."""
+    return default_registry.create_optimizer(
+        model_or_params,
+        opt=opt,
+        lr=lr,
+        weight_decay=weight_decay,
+        momentum=momentum,
+        foreach=foreach,
+        weight_decay_exclude_1d=filter_bias_and_bn,
+        layer_decay=layer_decay,
+        param_group_fn=param_group_fn,
+        **kwargs
+    )
+
+
+def optimizer_kwargs(cfg):
+    """ cfg/argparse to kwargs helper
+    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
+    """
+    kwargs = dict(
+        opt=cfg.opt,
+        lr=cfg.lr,
+        weight_decay=cfg.weight_decay,
+        momentum=cfg.momentum,
+    )
+    if getattr(cfg, 'opt_eps', None) is not None:
+        kwargs['eps'] = cfg.opt_eps
+    if getattr(cfg, 'opt_betas', None) is not None:
+        kwargs['betas'] = cfg.opt_betas
+    if getattr(cfg, 'layer_decay', None) is not None:
+        kwargs['layer_decay'] = cfg.layer_decay
+    if getattr(cfg, 'opt_args', None) is not None:
+        kwargs.update(cfg.opt_args)
+    if getattr(cfg, 'opt_foreach', None) is not None:
+        kwargs['foreach'] = cfg.opt_foreach
+    return kwargs
+
+
+def create_optimizer(args, model, filter_bias_and_bn=True):
+    """ Legacy optimizer factory for backwards compatibility.
+    NOTE: Use create_optimizer_v2 for new code.
+    """
+    return create_optimizer_v2(
+        model,
+        **optimizer_kwargs(cfg=args),
+        filter_bias_and_bn=filter_bias_and_bn,
+    )
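Sketch of the cfg/argparse path through optimizer_kwargs into create_optimizer_v2. The SimpleNamespace below stands in for parsed argparse args and its values are assumptions; only the field names follow the helper above.

from types import SimpleNamespace
import torch.nn as nn
from timm.optim import create_optimizer_v2, optimizer_kwargs

cfg = SimpleNamespace(
    opt='adamw',
    lr=5e-4,
    weight_decay=0.05,
    momentum=0.9,
    opt_eps=None,            # leave eps at the optimizer default
    opt_betas=(0.9, 0.95),
    layer_decay=None,
    opt_args=None,
    opt_foreach=None,
)

model = nn.Linear(8, 2)
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg))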
@@ -0,0 +1,129 @@
+import logging
+from itertools import islice
+from typing import Collection, Optional, Tuple
+
+from torch import nn as nn
+
+from timm.models import group_parameters
+
+
+_logger = logging.getLogger(__name__)
+
+
+def param_groups_weight_decay(
+        model: nn.Module,
+        weight_decay: float = 1e-5,
+        no_weight_decay_list: Collection[str] = (),
+):
+    no_weight_decay_list = set(no_weight_decay_list)
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+            no_decay.append(param)
+        else:
+            decay.append(param)
+
+    return [
+        {'params': no_decay, 'weight_decay': 0.},
+        {'params': decay, 'weight_decay': weight_decay}]
+
+
+def _group(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
+def _layer_map(model, layers_per_group=12, num_groups=None):
+    def _in_head(n, hp):
+        if not hp:
+            return True
+        elif isinstance(hp, (tuple, list)):
+            return any([n.startswith(hpi) for hpi in hp])
+        else:
+            return n.startswith(hp)
+
+    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
+    names_trunk = []
+    names_head = []
+    for n, _ in model.named_parameters():
+        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)
+
+    # group non-head layers
+    num_trunk_layers = len(names_trunk)
+    if num_groups is not None:
+        layers_per_group = -(num_trunk_layers // -num_groups)
+    names_trunk = list(_group(names_trunk, layers_per_group))
+
+    num_trunk_groups = len(names_trunk)
+    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
+    layer_map.update({n: num_trunk_groups for n in names_head})
+    return layer_map
+
+
+def param_groups_layer_decay(
+        model: nn.Module,
+        weight_decay: float = 0.05,
+        no_weight_decay_list: Collection[str] = (),
+        weight_decay_exclude_1d: bool = True,
+        layer_decay: float = .75,
+        end_layer_decay: Optional[float] = None,
+        verbose: bool = False,
+):
+    """
+    Parameter groups for layer-wise lr decay & weight decay
+    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+    """
+    no_weight_decay_list = set(no_weight_decay_list)
+    param_group_names = {}  # NOTE for debugging
+    param_groups = {}
+
+    if hasattr(model, 'group_matcher'):
+        # FIXME interface needs more work
+        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
+    else:
+        # fallback
+        layer_map = _layer_map(model)
+    num_layers = max(layer_map.values()) + 1
+    layer_max = num_layers - 1
+    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        # no decay: all 1D parameters and model specific ones
+        if (weight_decay_exclude_1d and param.ndim <= 1) or name in no_weight_decay_list:
+            g_decay = "no_decay"
+            this_decay = 0.
+        else:
+            g_decay = "decay"
+            this_decay = weight_decay
+
+        layer_id = layer_map.get(name, layer_max)
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+        if group_name not in param_groups:
+            this_scale = layer_scales[layer_id]
+            param_group_names[group_name] = {
+                "lr_scale": this_scale,
+                "weight_decay": this_decay,
+                "param_names": [],
+            }
+            param_groups[group_name] = {
+                "lr_scale": this_scale,
+                "weight_decay": this_decay,
+                "params": [],
+            }
+
+        param_group_names[group_name]["param_names"].append(name)
+        param_groups[group_name]["params"].append(param)
+
+    if verbose:
+        import json
+        _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+    return list(param_groups.values())
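A sketch of how layer-wise decay surfaces in the resulting param groups when driven through create_optimizer_v2. The vit_tiny_patch16_224 model is just a convenient timm model that exposes a group_matcher; the per-group lr_scale attached here is intended to be consumed by timm's LR schedulers rather than applied directly in this snippet.

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('vit_tiny_patch16_224', pretrained=False)
optimizer = create_optimizer_v2(model, 'adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)

for group in optimizer.param_groups:
    # earlier layers get smaller lr_scale, 1D params and no_weight_decay entries get weight_decay 0
    print(group['lr_scale'], group['weight_decay'], len(group['params']))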
@@ -40,9 +40,18 @@ class AdaBelief(Optimizer):
     """

     def __init__(
-            self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
-            decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-16,
+            weight_decay=0,
+            amsgrad=False,
+            decoupled_decay=True,
+            fixed_decay=False,
+            rectify=True,
+            degenerated_to_sgd=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -58,9 +67,17 @@ class AdaBelief(Optimizer):
                 param['buffer'] = [[None, None, None] for _ in range(10)]

         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
-            degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
-            fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+            degenerated_to_sgd=degenerated_to_sgd,
+            decoupled_decay=decoupled_decay,
+            rectify=rectify,
+            fixed_decay=fixed_decay,
+            buffer=[[None, None, None] for _ in range(10)]
+        )
         super(AdaBelief, self).__init__(params, defaults)

     def __setstate__(self, state):
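The AdaBelief change above is purely a signature reformat; defaults are unchanged. A hedged construction sketch for reference (the toy model and chosen values are assumptions, the import path mirrors the module layout shown later in this diff):

import torch
from timm.optim.adabelief import AdaBelief

model = torch.nn.Linear(10, 2)
# rectify=True corresponds to the factory's 'radabelief' variant; rectify=False is plain 'adabelief'
optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16, weight_decay=1e-2, rectify=True)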
@@ -16,6 +16,7 @@ import math

 class Adafactor(torch.optim.Optimizer):
     """Implements Adafactor algorithm.

     This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
     (see https://arxiv.org/abs/1804.04235)
@@ -23,8 +23,18 @@ class Adahessian(torch.optim.Optimizer):
         n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
     """

-    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
-                 hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
+    def __init__(
+            self,
+            params,
+            lr=0.1,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0.0,
+            hessian_power=1.0,
+            update_each=1,
+            n_samples=1,
+            avg_conv_kernel=False,
+    ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
@@ -44,7 +54,13 @@ class Adahessian(torch.optim.Optimizer):
         self.seed = 2147483647
         self.generator = torch.Generator().manual_seed(self.seed)

-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            hessian_power=hessian_power,
+        )
         super(Adahessian, self).__init__(params, defaults)

         for p in self.get_params():
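Adahessian samples `z` to approximate the Hessian trace from the gradients, so the backward pass has to build a double-backward graph. A minimal training-step sketch, assumed usage rather than part of the diff (toy model and data are illustrative):

import torch
import torch.nn.functional as F
from timm.optim.adahessian import Adahessian

model = torch.nn.Linear(10, 2)
optimizer = Adahessian(model.parameters(), lr=0.1, n_samples=1)
x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward(create_graph=True)  # second-order: the trace estimate differentiates through the gradients
optimizer.step()
optimizer.zero_grad()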
@@ -41,11 +41,26 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):


 class AdamP(Optimizer):
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0,
+            delta=0.1,
+            wd_ratio=0.1,
+            nesterov=False,
+    ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-            delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            delta=delta,
+            wd_ratio=wd_ratio,
+            nesterov=nesterov,
+        )
         super(AdamP, self).__init__(params, defaults)

     @torch.no_grad()
@@ -36,8 +36,16 @@ class AdamW(Optimizer):
         https://openreview.net/forum?id=ryQu7f-RZ
     """

-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=1e-2, amsgrad=False):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=1e-2,
+            amsgrad=False,
+    ):
+        # NOTE: deprecated in favour of builtin torch.optim.AdamW
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -46,8 +54,13 @@ class AdamW(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay, amsgrad=amsgrad)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
+        )
         super(AdamW, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -137,7 +137,7 @@ def lion(
    """
    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = True

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
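With this default flipped, lion takes the multi-tensor ('foreach') path whenever the caller leaves it unset. A hedged sketch of opting back out through the factory, which passes foreach straight into the optimizer kwargs (toy model and values are assumptions):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)
# foreach=False forces the single-tensor path; foreach=None keeps the new default above
optimizer = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.01, foreach=False)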
@@ -71,7 +71,12 @@ class MADGRAD(torch.optim.Optimizer):
             raise ValueError(f"Eps must be non-negative")

         defaults = dict(
-            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
+            lr=lr,
+            eps=eps,
+            momentum=momentum,
+            weight_decay=weight_decay,
+            decoupled_decay=decoupled_decay,
+        )
         super().__init__(params, defaults)

     @property
@@ -27,8 +27,15 @@ class Nadam(Optimizer):
     NOTE: Has potential issues but does work well on some problems.
     """

-    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0, schedule_decay=4e-3):
+    def __init__(
+            self,
+            params,
+            lr=2e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0,
+            schedule_decay=4e-3,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         defaults = dict(
@@ -29,8 +29,16 @@ class NvNovoGrad(Optimizer):
         (default: False)
     """

-    def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
-                 weight_decay=0, grad_averaging=False, amsgrad=False):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.95, 0.98),
+            eps=1e-8,
+            weight_decay=0,
+            grad_averaging=False,
+            amsgrad=False,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -39,10 +47,14 @@ class NvNovoGrad(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
-        defaults = dict(lr=lr, betas=betas, eps=eps,
-                        weight_decay=weight_decay,
-                        grad_averaging=grad_averaging,
-                        amsgrad=amsgrad)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            amsgrad=amsgrad,
+        )

         super(NvNovoGrad, self).__init__(params, defaults)
@@ -1,431 +0,0 @@
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
from itertools import islice
from typing import Optional, Callable, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from timm.models import group_parameters
from . import AdafactorBigVision

from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
from .lamb import Lamb
from .lars import Lars
from .lion import Lion
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
from .nadamw import NAdamW
from .nvnovograd import NvNovoGrad
from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from .sgdw import SGDW


_logger = logging.getLogger(__name__)


# optimizers to default to multi-tensor
_DEFAULT_FOREACH = {
    'lion',
}


def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    no_weight_decay_list = set(no_weight_decay_list)
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
            no_decay.append(param)
        else:
            decay.append(param)

    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]


def _group(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def _layer_map(model, layers_per_group=12, num_groups=None):
    def _in_head(n, hp):
        if not hp:
            return True
        elif isinstance(hp, (tuple, list)):
            return any([n.startswith(hpi) for hpi in hp])
        else:
            return n.startswith(hp)

    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)
    names_trunk = []
    names_head = []
    for n, _ in model.named_parameters():
        names_head.append(n) if _in_head(n, head_prefix) else names_trunk.append(n)

    # group non-head layers
    num_trunk_layers = len(names_trunk)
    if num_groups is not None:
        layers_per_group = -(num_trunk_layers // -num_groups)
    names_trunk = list(_group(names_trunk, layers_per_group))

    num_trunk_groups = len(names_trunk)
    layer_map = {n: i for i, l in enumerate(names_trunk) for n in l}
    layer_map.update({n: num_trunk_groups for n in names_head})
    return layer_map


def param_groups_layer_decay(
        model: nn.Module,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
        verbose: bool = False,
):
    """
    Parameter groups for layer-wise lr decay & weight decay
    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
    """
    no_weight_decay_list = set(no_weight_decay_list)
    param_group_names = {}  # NOTE for debugging
    param_groups = {}

    if hasattr(model, 'group_matcher'):
        # FIXME interface needs more work
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
        layer_map = _layer_map(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = layer_map.get(name, layer_max)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            this_scale = layer_scales[layer_id]
            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "param_names": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["param_names"].append(name)
        param_groups[group_name]["params"].append(param)

    if verbose:
        import json
        _logger.info("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
    """
    kwargs = dict(
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
    )
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
        kwargs['betas'] = cfg.opt_betas
    if getattr(cfg, 'layer_decay', None) is not None:
        kwargs['layer_decay'] = cfg.layer_decay
    if getattr(cfg, 'opt_args', None) is not None:
        kwargs.update(cfg.opt_args)
    if getattr(cfg, 'opt_foreach', None) is not None:
        kwargs['foreach'] = cfg.opt_foreach
    return kwargs


def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
    return create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        filter_bias_and_bn=filter_bias_and_bn,
    )


def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs,
):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module): model containing parameters to optimize
        opt: name of optimizer to create
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay,
            )
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower.startswith('fused'):
        try:
            from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
            has_apex = True
        except ImportError:
            has_apex = False
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    if opt_lower.startswith('bnb'):
        try:
            import bitsandbytes as bnb
            has_bnb = True
        except ImportError:
            has_bnb = False
        assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)

    if lr is not None:
        opt_args.setdefault('lr', lr)

    if foreach is None:
        if opt in _DEFAULT_FOREACH:
            opt_args.setdefault('foreach', True)
    else:
        opt_args['foreach'] = foreach

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentumw':
        opt_args.pop('eps', None)
        optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 should have native NAdam
            optimizer = optim.Nadam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'nadamw':
        optimizer = NAdamW(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adanp':
        optimizer = Adan(parameters, no_prox=False, **opt_args)
    elif opt_lower == 'adanw':
        optimizer = Adan(parameters, no_prox=True, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'lion':
        opt_args.pop('eps', None)
        optimizer = Lion(parameters, **opt_args)
    elif opt_lower == 'adafactorbv':
        optimizer = AdafactorBigVision(parameters, **opt_args)
    elif opt_lower == 'adopt':
        optimizer = Adopt(parameters, **opt_args)
    elif opt_lower == 'adoptw':
        optimizer = Adopt(parameters, decoupled=True, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    # bitsandbytes optimizers, require bitsandbytes to be installed
    elif opt_lower == 'bnbsgd':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'bnbsgd8bit':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'bnbmomentum':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'bnbmomentum8bit':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'bnbadam':
        optimizer = bnb.optim.Adam(parameters, **opt_args)
    elif opt_lower == 'bnbadam8bit':
        optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
    elif opt_lower == 'bnbadamw':
        optimizer = bnb.optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'bnbadamw8bit':
        optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
    elif opt_lower == 'bnblamb':
        optimizer = bnb.optim.LAMB(parameters, **opt_args)
    elif opt_lower == 'bnblamb8bit':
        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
    elif opt_lower == 'bnblars':
        optimizer = bnb.optim.LARS(parameters, **opt_args)
    elif opt_lower == 'bnblarsb8bit':
        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
    elif opt_lower == 'bnblion':
        optimizer = bnb.optim.Lion(parameters, **opt_args)
    elif opt_lower == 'bnblion8bit':
        optimizer = bnb.optim.Lion8bit(parameters, **opt_args)

    else:
        assert False and "Invalid optimizer"
        raise ValueError

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
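The removed factory splits the opt string into an optional prefix plus the base optimizer name, with a 'lookahead_' prefix wrapping the created optimizer. A small sketch of that naming convention, assuming the replacement factory keeps it (toy model and values are illustrative):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)
optimizer = create_optimizer_v2(model, opt='lookahead_adamw', lr=1e-3, weight_decay=0.05)
print(type(optimizer).__name__)  # Lookahead, wrapping torch.optim.AdamW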
@@ -9,10 +9,21 @@ from torch.optim.optimizer import Optimizer

 class RAdam(Optimizer):

-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            betas=(0.9, 0.999),
+            eps=1e-8,
+            weight_decay=0,
+    ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
-            buffer=[[None, None, None] for _ in range(10)])
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            buffer=[[None, None, None] for _ in range(10)]
+        )
         super(RAdam, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -45,8 +45,18 @@ class RMSpropTF(Optimizer):

     """

-    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
-                 decoupled_decay=False, lr_in_momentum=True):
+    def __init__(
+            self,
+            params,
+            lr=1e-2,
+            alpha=0.9,
+            eps=1e-10,
+            weight_decay=0,
+            momentum=0.,
+            centered=False,
+            decoupled_decay=False,
+            lr_in_momentum=True,
+    ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -59,8 +69,15 @@ class RMSpropTF(Optimizer):
             raise ValueError("Invalid alpha value: {}".format(alpha))

         defaults = dict(
-            lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
-            decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
+            lr=lr,
+            momentum=momentum,
+            alpha=alpha,
+            eps=eps,
+            centered=centered,
+            weight_decay=weight_decay,
+            decoupled_decay=decoupled_decay,
+            lr_in_momentum=lr_in_momentum,
+        )
         super(RMSpropTF, self).__init__(params, defaults)

     def __setstate__(self, state):
@@ -17,11 +17,28 @@ from .adamp import projection


 class SGDP(Optimizer):
-    def __init__(self, params, lr=required, momentum=0, dampening=0,
-                 weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
+    def __init__(
+            self,
+            params,
+            lr=required,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            eps=1e-8,
+            delta=0.1,
+            wd_ratio=0.1
+    ):
         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
-            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            eps=eps,
+            delta=delta,
+            wd_ratio=wd_ratio,
+        )
         super(SGDP, self).__init__(params, defaults)

     @torch.no_grad()
@@ -35,10 +35,15 @@ class SGDW(Optimizer):
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

         defaults = dict(
-            lr=lr, momentum=momentum, dampening=dampening,
-            weight_decay=weight_decay, nesterov=nesterov,
-            maximize=maximize, foreach=foreach,
-            differentiable=differentiable)
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            maximize=maximize,
+            foreach=foreach,
+            differentiable=differentiable,
+        )
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
train.py

@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import copy
 import importlib
 import json
 import logging
@@ -554,6 +555,13 @@ def main():
         **optimizer_kwargs(cfg=args),
         **args.opt_kwargs,
     )
+    if utils.is_primary(args):
+        defaults = copy.deepcopy(optimizer.defaults)
+        defaults['weight_decay'] = args.weight_decay  # this isn't stored in optimizer.defaults
+        defaults = ', '.join([f'{k}: {v}' for k, v in defaults.items()])
+        logging.info(
+            f'Created {type(optimizer).__name__} ({args.opt}) optimizer: {defaults}'
+        )

     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing