Update optim test to remove Variable/.data and fix _state_dict optim test

Ross Wightman 2023-10-11 11:58:04 -07:00
parent 7ce65a83a2
commit 52ca108fe6
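The diff below swaps the long-deprecated torch.autograd.Variable wrapper and .data accesses for torch.nn.Parameter and explicit clone().detach() snapshots. A minimal sketch of that pattern, with illustrative tensor names that do not come from the test file:

    import torch
    from torch.nn import Parameter

    w = torch.randn(10, 5)

    # Old style being removed: Variable wrapper plus .data access
    #   weight = torch.autograd.Variable(w, requires_grad=True)
    #   snapshot = weight.data.clone()

    # New style used in the test: Parameter is a leaf tensor with
    # requires_grad=True, and clone().detach() takes a snapshot that
    # autograd does not track
    weight = Parameter(w)
    with torch.no_grad():
        snapshot = Parameter(weight.clone().detach())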


@@ -10,7 +10,7 @@ from copy import deepcopy
 import torch
 from torch.testing._internal.common_utils import TestCase
-from torch.autograd import Variable
+from torch.nn import Parameter
 from timm.scheduler import PlateauLRScheduler
 from timm.optim import create_optimizer_v2
@@ -21,9 +21,9 @@ torch_tc = TestCase()
 def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
     optimizer = constructor(weight, bias)
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
@@ -55,9 +55,9 @@ def _test_basic_cases_template(weight, bias, input, constructor, scheduler_const
 def _test_state_dict(weight, bias, input, constructor):
-    weight = Variable(weight, requires_grad=True)
-    bias = Variable(bias, requires_grad=True)
-    input = Variable(input)
+    weight = Parameter(weight)
+    bias = Parameter(bias)
+    input = Parameter(input)
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
@@ -73,8 +73,9 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
     # Clone the weights and construct new optimizer for them
-    weight_c = Variable(weight.data.clone(), requires_grad=True)
-    bias_c = Variable(bias.data.clone(), requires_grad=True)
+    with torch.no_grad():
+        weight_c = Parameter(weight.clone().detach())
+        bias_c = Parameter(bias.clone().detach())
     optimizer_c = constructor(weight_c, bias_c)
     fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
     # Load state dict
@@ -86,12 +87,8 @@ def _test_state_dict(weight, bias, input, constructor):
     for _i in range(20):
         optimizer.step(fn)
         optimizer_c.step(fn_c)
-        #assert torch.equal(weight, weight_c)
-        #assert torch.equal(bias, bias_c)
         torch_tc.assertEqual(weight, weight_c)
         torch_tc.assertEqual(bias, bias_c)
-    # Make sure state dict wasn't modified
-    torch_tc.assertEqual(state_dict, state_dict_c)
     # Make sure state dict is deterministic with equal but not identical parameters
     torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
     # Make sure repeated parameters have identical representation in state dict
@@ -103,9 +100,10 @@ def _test_state_dict(weight, bias, input, constructor):
     if not torch.cuda.is_available():
        return
-    input_cuda = Variable(input.data.float().cuda())
-    weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
-    bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
+    with torch.no_grad():
+        input_cuda = Parameter(input.clone().detach().float().cuda())
+        weight_cuda = Parameter(weight.clone().detach().cuda())
+        bias_cuda = Parameter(bias.clone().detach().cuda())
     optimizer_cuda = constructor(weight_cuda, bias_cuda)
     fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
@@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
         scheduler_constructors = []
     params_t = torch.tensor([1.5, 1.5])
-    params = Variable(params_t, requires_grad=True)
+    params = Parameter(params_t)
     optimizer = constructor([params])
     schedulers = []
     for scheduler_constructor in scheduler_constructors:
         schedulers.append(scheduler_constructor(optimizer))
     solution = torch.tensor([1, 1])
-    initial_dist = params.data.dist(solution)
+    initial_dist = params.clone().detach().dist(solution)
     def eval(params, w):
         # Depending on w, provide only the x or y gradient
         optimizer.zero_grad()
         loss = rosenbrock(params)
         loss.backward()
-        grad = drosenbrock(params.data)
+        grad = drosenbrock(params.clone().detach())
         # NB: We torture test the optimizer by returning an
         # uncoalesced sparse tensor
         if w:
@@ -256,7 +254,7 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
             else:
                 scheduler.step()
-    torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
+    torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)
 def _build_params_dict(weight, bias, **kwargs):
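
For context on the _test_state_dict fix, the check the test performs is a state-dict round trip: clone the trained parameters, build a second optimizer, load the first optimizer's state into it, then step both in parallel and expect identical results. A condensed, self-contained sketch of that round trip, assuming torch.optim.SGD and a toy quadratic loss in place of the constructor, weight/bias pair, and fn_base used by the real test:

    import functools

    import torch
    from torch.nn import Parameter

    weight = Parameter(torch.randn(5))

    def step_fn(optimizer, param):
        # Closure passed to optimizer.step(); stands in for fn_base in the test
        optimizer.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()
        return loss

    optimizer = torch.optim.SGD([weight], lr=1e-2, momentum=0.9)
    fn = functools.partial(step_fn, optimizer, weight)
    for _ in range(20):
        optimizer.step(fn)

    # Clone the trained parameter without touching .data, then build a
    # second optimizer and load the first optimizer's state into it
    with torch.no_grad():
        weight_c = Parameter(weight.clone().detach())
    optimizer_c = torch.optim.SGD([weight_c], lr=1e-2, momentum=0.9)
    optimizer_c.load_state_dict(optimizer.state_dict())
    fn_c = functools.partial(step_fn, optimizer_c, weight_c)

    # With equal parameters and equal optimizer state, stepping both in
    # parallel should keep the two copies exactly equal
    for _ in range(20):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        assert torch.equal(weight, weight_c)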