mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update tests, need handling for radamw with older PyTorch, need to back off basic test LR in mars?
This commit is contained in:
parent 09bc21774e
commit 73d10ab482
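The radamw handling added below guards against older PyTorch. A minimal sketch of the suspected failure mode, assuming it stems from torch.optim.RAdam's `decoupled_weight_decay` keyword (added upstream in PyTorch 2.1; older versions reject the unknown keyword with a TypeError, which is what the updated test skips):

import torch

params = [torch.nn.Parameter(torch.zeros(2))]
try:
    # Assumption: 'radamw' maps to torch.optim.RAdam with decoupled weight
    # decay enabled; PyTorch < 2.1 has no such keyword and raises TypeError.
    opt = torch.optim.RAdam(params, lr=1e-3, decoupled_weight_decay=True)
except TypeError:
    print("older PyTorch: RAdam lacks decoupled_weight_decay")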
tests/test_optim.py
@@ -9,6 +9,7 @@ import functools
 from copy import deepcopy
 
 import torch
+from IPython.testing.decorators import skip_win32
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
 
@@ -299,27 +300,39 @@ def test_optim_factory(optimizer):
     opt_info = get_optimizer_info(optimizer)
     assert isinstance(opt_info, OptimInfo)
 
-    if not opt_info.second_order:  # basic tests don't support second order right now
-        # test basic cases that don't need specific tuning via factory test
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2),
-                optimizer,
-                lr=1e-3)
-        )
-        _test_basic_cases(
-            lambda weight, bias: create_optimizer_v2(
-                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-        )
+    lr = (1e-3, 1e-2, 1e-2, 1e-2)
+    if optimizer in ('mars',):
+        lr = (1e-3, 1e-3, 1e-3, 1e-3)
+
+    try:
+        if not opt_info.second_order:  # basic tests don't support second order right now
+            # test basic cases that don't need specific tuning via factory test
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=lr[0])
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict(weight, bias, lr=lr[1]),
+                    optimizer,
+                    lr=lr[1] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[2]),
+                    optimizer,
+                    lr=lr[2] / 10)
+            )
+            _test_basic_cases(
+                lambda weight, bias: create_optimizer_v2(
+                    _build_params_dict_single(weight, bias, lr=lr[3]),
+                    optimizer)
+            )
+    except TypeError as e:
+        if 'radamw' in optimizer:
+            pytest.skip("Expected for 'radamw' (decoupled decay) to fail in older PyTorch versions.")
+        else:
+            raise e
+
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
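A short usage sketch of the backed-off mars learning rate through the factory under test (create_optimizer_v2 and the 'mars' optimizer name are timm's; the toy Linear model and tensor shapes are illustrative):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)
# mars gets the conservative 1e-3 rate in all four basic cases above,
# where most optimizers use 1e-2 for the param-group variants
opt = create_optimizer_v2(model, 'mars', lr=1e-3)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()  # standard torch.optim.Optimizer interface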