mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update tests; need handling for radamw with older PyTorch; need to back off the basic-test LR for mars.
This commit is contained in:
parent
7d3146b97b
commit
bc7d2247bf
@ -9,6 +9,7 @@ import functools
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from IPython.testing.decorators import skip_win32
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.nn import Parameter
|
||||
|
||||
@ -299,27 +300,39 @@ def test_optim_factory(optimizer):
|
||||
opt_info = get_optimizer_info(optimizer)
|
||||
assert isinstance(opt_info, OptimInfo)
|
||||
|
||||
if not opt_info.second_order: # basic tests don't support second order right now
|
||||
# test basic cases that don't need specific tuning via factory test
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2),
|
||||
optimizer,
|
||||
lr=1e-3)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
|
||||
)
|
||||
lr = (1e-3, 1e-2, 1e-2, 1e-2)
|
||||
if optimizer in ('mars',):
|
||||
lr = (1e-3, 1e-3, 1e-3, 1e-3)
|
||||
|
||||
try:
|
||||
if not opt_info.second_order: # basic tests don't support second order right now
|
||||
# test basic cases that don't need specific tuning via factory test
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict(weight, bias, lr=lr[1]),
|
||||
optimizer,
|
||||
lr=lr[1] / 10)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=lr[2]),
|
||||
optimizer,
|
||||
lr=lr[2] / 10)
|
||||
)
|
||||
_test_basic_cases(
|
||||
lambda weight, bias: create_optimizer_v2(
|
||||
_build_params_dict_single(weight, bias, lr=lr[3]),
|
||||
optimizer)
|
||||
)
|
||||
except TypeError as e:
|
||||
if 'radamw' in optimizer:
|
||||
pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.")
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
|
||||
#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user