Update tests, need handling for radamw with older PyTorch, need to back-off basic test LR in mars?

Ross Wightman 2024-11-26 12:13:21 -08:00 committed by Ross Wightman
parent 09bc21774e
commit 73d10ab482

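The 'radamw' handling added below skips the basic factory test when optimizer construction raises a TypeError on older PyTorch, presumably because torch.optim.RAdam gained its decoupled weight decay option only in newer releases. As an aside (not part of this commit), a signature check like the following sketch is one way to probe for that support without parsing version strings:

# Sketch only, not from this commit: detect whether the installed
# torch.optim.RAdam accepts the decoupled_weight_decay keyword that the
# 'radamw' preset relies on; older releases raise TypeError when it is passed.
import inspect

import torch


def radam_supports_decoupled_decay() -> bool:
    params = inspect.signature(torch.optim.RAdam.__init__).parameters
    return 'decoupled_weight_decay' in params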

@@ -9,6 +9,7 @@ import functools
from copy import deepcopy
import torch
from IPython.testing.decorators import skip_win32
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
@@ -299,27 +300,39 @@ def test_optim_factory(optimizer):
     opt_info = get_optimizer_info(optimizer)
     assert isinstance(opt_info, OptimInfo)
 
+    lr = (1e-3, 1e-2, 1e-2, 1e-2)
+    if optimizer in ('mars',):
+        lr = (1e-3, 1e-3, 1e-3, 1e-3)
+
+    try:
         if not opt_info.second_order:  # basic tests don't support second order right now
             # test basic cases that don't need specific tuning via factory test
             _test_basic_cases(
-                lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+                lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
             )
             _test_basic_cases(
                 lambda weight, bias: create_optimizer_v2(
-                    _build_params_dict(weight, bias, lr=1e-2),
+                    _build_params_dict(weight, bias, lr=lr[1]),
                     optimizer,
-                    lr=1e-3)
+                    lr=lr[1] / 10)
             )
             _test_basic_cases(
                 lambda weight, bias: create_optimizer_v2(
-                    _build_params_dict_single(weight, bias, lr=1e-2),
+                    _build_params_dict_single(weight, bias, lr=lr[2]),
                     optimizer,
-                    lr=1e-3)
+                    lr=lr[2] / 10)
             )
             _test_basic_cases(
                 lambda weight, bias: create_optimizer_v2(
-                    _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+                    _build_params_dict_single(weight, bias, lr=lr[3]),
+                    optimizer)
             )
+    except TypeError as e:
+        if 'radamw' in optimizer:
+            pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.")
+        else:
+            raise e
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
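The lambdas above rely on two param-group helpers defined elsewhere in tests/test_optim.py and not shown in this hunk. As a rough sketch (the exact definitions may differ), they build optimizer parameter-group dicts with an optional per-group learning rate, which is what exercises create_optimizer_v2's handling of pre-built param groups:

# Sketch of the helpers referenced above (not shown in this diff); each returns
# a list of param-group dicts, optionally carrying a group-specific lr.
def _build_params_dict(weight, bias, **kwargs):
    return [{'params': [weight]}, dict(params=[bias], **kwargs)]


def _build_params_dict_single(weight, bias, **kwargs):
    return [dict(params=bias, **kwargs)]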