diff --git a/tests/test_optim.py b/tests/test_optim.py
index 58cd40e7..3db4b248 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -299,27 +299,40 @@ def test_optim_factory(optimizer):
     opt_info = get_optimizer_info(optimizer)
     assert isinstance(opt_info, OptimInfo)
 
-    if not opt_info.second_order:  # basic tests don't support second order right now
-        # test basic cases that don't need specific tuning via factory test
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-        )
+    # lr for each of the four basic test cases below (mars gets lower values)
+    lr = (1e-3, 1e-2, 1e-2, 1e-2)
+    if optimizer in ('mars',):
+        lr = (1e-3, 1e-3, 1e-3, 1e-3)
+
+    try:
+        if not opt_info.second_order:  # basic tests don't support second order right now
+            # test basic cases that don't need specific tuning via factory test
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict(weight, bias, lr=lr[1]),
+                    optimizer,
+                    lr=lr[1] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[2]),
+                    optimizer,
+                    lr=lr[2] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[3]),
+                    optimizer)
+            )
+    except TypeError:
+        if 'radamw' in optimizer:
+            pytest.skip("'radamw' (decoupled decay) is expected to fail on older PyTorch versions.")
+        else:
+            raise
+
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
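For context on why the param-group cases pass `lr=lr[i] / 10` as the factory's base learning rate: the `_build_params_dict*` helpers (test-local; their exact bodies are not shown in this diff) produce a list of param-group dicts where some groups carry their own `lr`, and a group-level `lr` overrides the base `lr` given to `create_optimizer_v2`. Below is a minimal sketch of that override behaviour using timm's public factory; the tensor names, the group layout, and the `'adamw'` choice are illustrative stand-ins, not part of the patch.

```python
import torch
from timm.optim import create_optimizer_v2

# Two tensors standing in for the test fixture's weight/bias parameters.
weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))

# Illustrative param groups, shaped like what _build_params_dict presumably returns:
# one group with no explicit lr, one group with its own lr override.
param_groups = [
    {'params': [weight]},            # inherits the factory's base lr
    {'params': [bias], 'lr': 1e-2},  # per-group override, analogous to lr[1] above
]

# Base lr is one tenth of the group override, mirroring the lr[1] / 10 pattern in the test.
opt = create_optimizer_v2(param_groups, 'adamw', lr=1e-2 / 10)

# Groups without an explicit lr fall back to the base lr; overrides are kept.
print([g['lr'] for g in opt.param_groups])  # -> [0.001, 0.01]
```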