diff --git a/tests/test_optim.py b/tests/test_optim.py
index 737674e5..9bdfd682 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -10,7 +10,7 @@ from copy import deepcopy
 
 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.autograd import Variable
+from torch.nn import Parameter
 
 from timm.scheduler import PlateauLRScheduler
 from timm.optim import create_optimizer_v2
@@ -21,9 +21,9 @@ torch_tc = TestCase()
 
 
 def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
     optimizer = constructor(weight, bias)
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
@@ -55,9 +55,9 @@ def _test_basic_cases_template(weight, bias, input, constructor, scheduler_const
 
 
 def _test_state_dict(weight, bias, input, constructor):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
 
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
@@ -73,8 +73,9 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
     # Clone the weights and construct new optimizer for them
-    weight_c = Variable(weight.data.clone(), requires_grad=True)
-    bias_c = Variable(bias.data.clone(), requires_grad=True)
+    with torch.no_grad():
+        weight_c = Parameter(weight.clone().detach())
+        bias_c = Parameter(bias.clone().detach())
     optimizer_c = constructor(weight_c, bias_c)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
     # Load state dict
@@ -86,12 +87,8 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
         optimizer_c.step(fn_c)
-        #assert torch.equal(weight, weight_c)
-        #assert torch.equal(bias, bias_c)
         torch_tc.assertEqual(weight, weight_c)
         torch_tc.assertEqual(bias, bias_c)
-    # Make sure state dict wasn't modified
-    torch_tc.assertEqual(state_dict, state_dict_c)
     # Make sure state dict is deterministic with equal but not identical parameters
     torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
     # Make sure repeated parameters have identical representation in state dict
@@ -103,9 +100,10 @@ def _test_state_dict(weight, bias, input, constructor):
     if not torch.cuda.is_available():
         return
 
-    input_cuda = Variable(input.data.float().cuda())
-    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
-    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+    with torch.no_grad():
+        input_cuda = Parameter(input.clone().detach().float().cuda())
+        weight_cuda = Parameter(weight.clone().detach().cuda())
+        bias_cuda = Parameter(bias.clone().detach().cuda())
     optimizer_cuda = constructor(weight_cuda, bias_cuda)
     fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
 
@@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
         scheduler_constructors = []
     params_t = torch.tensor([1.5, 1.5])
 
-    params = Variable(params_t, requires_grad=True)
+    params = Parameter(params_t)
     optimizer = constructor([params])
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
         schedulers.append(scheduler_constructor(optimizer))
 
     solution = torch.tensor([1, 1])
-    initial_dist = params.data.dist(solution)
+    initial_dist = params.clone().detach().dist(solution)
 
     def eval(params, w):
         # Depending on w, provide only the x or y gradient
         optimizer.zero_grad()
         loss = rosenbrock(params)
         loss.backward()
-        grad = drosenbrock(params.data)
+        grad = drosenbrock(params.clone().detach())
         # NB: We torture test the optimizer by returning an
         # uncoalesced sparse tensor
         if w:
@@ -256,7 +254,7 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
         else:
             scheduler.step()
 
-    torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
+    torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)
 
 
 def _build_params_dict(weight, bias, **kwargs):
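
Every hunk above applies the same substitution: the deprecated `torch.autograd.Variable(t, requires_grad=True)` becomes `torch.nn.Parameter(t)`, and reads of `.data` become `clone().detach()` so no autograd history is recorded. A minimal standalone sketch of that pattern follows; the tensor `t` and names `w`, `snapshot` are illustrative and not part of the test file.

```python
import torch
from torch.nn import Parameter

t = torch.randn(3)

# Old style (deprecated):
#   w = torch.autograd.Variable(t, requires_grad=True)
#   snapshot = w.data.clone()

# Style used in the diff:
w = Parameter(t)                   # leaf tensor, requires_grad=True by default
with torch.no_grad():
    snapshot = w.clone().detach()  # copy with no autograd history

assert w.requires_grad and not snapshot.requires_grad
```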