From 73d10ab48236eca7158deae95e01678ef1a09ce7 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 26 Nov 2024 12:13:21 -0800
Subject: [PATCH] Update tests, need handling for radamw with older PyTorch,
 need to back-off basic test LR in mars?

---
 tests/test_optim.py | 54 +++++++++++++++++++++++++++++----------------
 1 file changed, 33 insertions(+), 21 deletions(-)

diff --git a/tests/test_optim.py b/tests/test_optim.py
index 58cd40e7..3db4b248 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -299,27 +299,39 @@ def test_optim_factory(optimizer):
     opt_info = get_optimizer_info(optimizer)
     assert isinstance(opt_info, OptimInfo)
 
-    if not opt_info.second_order:  # basic tests don't support second order right now
-        # test basic cases that don't need specific tuning via factory test
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-        )
+    lr = (1e-3, 1e-2, 1e-2, 1e-2)
+    if optimizer in ('mars',):
+        lr = (1e-3, 1e-3, 1e-3, 1e-3)
+
+    try:
+        if not opt_info.second_order:  # basic tests don't support second order right now
+            # test basic cases that don't need specific tuning via factory test
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict(weight, bias, lr=lr[1]),
+                    optimizer,
+                    lr=lr[1] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[2]),
+                    optimizer,
+                    lr=lr[2] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[3]),
+                    optimizer)
+            )
+    except TypeError as e:
+        if 'radamw' in optimizer:
+            pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.")
+        else:
+            raise e
+
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
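
Note on the radamw handling above: torch.optim.RAdam only accepts a decoupled_weight_decay argument in newer PyTorch releases, so constructing the 'radamw' variant raises TypeError on older installs; the patch catches that and skips the test. A minimal sketch of an alternative up-front guard, assuming a hypothetical helper that is not part of this patch, would probe the constructor signature instead of catching the exception:

    import inspect

    import torch

    def radam_supports_decoupled_decay() -> bool:
        # Hypothetical helper: probe torch.optim.RAdam's constructor
        # signature rather than hard-coding a version cutoff; older
        # releases lack the decoupled_weight_decay kwarg and raise
        # TypeError when it is passed through the optimizer factory.
        params = inspect.signature(torch.optim.RAdam).parameters
        return 'decoupled_weight_decay' in params

A test could then call pytest.skip before building the optimizer when this returns False, rather than wrapping the whole body in try/except TypeError; the signature probe also avoids masking unrelated TypeErrors raised by other optimizers.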