Update optim test to remove Variable/.data and fix _state_dict optim test for PyTorch 2.1 (#1988)
* Update optim test to remove Variable/.data and fix _state_dict optim test * Attempt to run python 3.11 w/ 2.1 * Try factoring out testmarker to common var * More fiddling * Abandon attempt to reduce redunancy * Another tryvit_siglip_and_reg
parent
7ce65a83a2
commit
68b2824e49
|
@ -16,10 +16,12 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python: ['3.10']
|
||||
torch: ['1.13.0']
|
||||
torchvision: ['0.14.0']
|
||||
python: ['3.10', '3.11']
|
||||
torch: [{base: '1.13.0', vision: '0.14.0'}, {base: '2.1.0', vision: '0.16.0'}]
|
||||
testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
|
||||
exclude:
|
||||
- python: '3.11'
|
||||
torch: {base: '1.13.0', vision: '0.14.0'}
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
|
@ -34,17 +36,17 @@ jobs:
|
|||
pip install -r requirements-dev.txt
|
||||
- name: Install torch on mac
|
||||
if: startsWith(matrix.os, 'macOS')
|
||||
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
|
||||
run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
|
||||
- name: Install torch on Windows
|
||||
if: startsWith(matrix.os, 'windows')
|
||||
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
|
||||
run: pip install --no-cache-dir torch==${{ matrix.torch.base }} torchvision==${{ matrix.torch.vision }}
|
||||
- name: Install torch on ubuntu
|
||||
if: startsWith(matrix.os, 'ubuntu')
|
||||
run: |
|
||||
sudo sed -i 's/azure\.//' /etc/apt/sources.list
|
||||
sudo apt update
|
||||
sudo apt install -y google-perftools
|
||||
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install --no-cache-dir torch==${{ matrix.torch.base }}+cpu torchvision==${{ matrix.torch.vision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install requirements
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
|
|
|
@ -10,7 +10,7 @@ from copy import deepcopy
|
|||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.autograd import Variable
|
||||
from torch.nn import Parameter
|
||||
from timm.scheduler import PlateauLRScheduler
|
||||
|
||||
from timm.optim import create_optimizer_v2
|
||||
|
@ -21,9 +21,9 @@ torch_tc = TestCase()
|
|||
|
||||
|
||||
def _test_basic_cases_template(weight, bias, input, constructor, scheduler_constructors):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
weight = Parameter(weight)
|
||||
bias = Parameter(bias)
|
||||
input = Parameter(input)
|
||||
optimizer = constructor(weight, bias)
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
|
@ -55,9 +55,9 @@ def _test_basic_cases_template(weight, bias, input, constructor, scheduler_const
|
|||
|
||||
|
||||
def _test_state_dict(weight, bias, input, constructor):
|
||||
weight = Variable(weight, requires_grad=True)
|
||||
bias = Variable(bias, requires_grad=True)
|
||||
input = Variable(input)
|
||||
weight = Parameter(weight)
|
||||
bias = Parameter(bias)
|
||||
input = Parameter(input)
|
||||
|
||||
def fn_base(optimizer, weight, bias):
|
||||
optimizer.zero_grad()
|
||||
|
@ -73,8 +73,9 @@ def _test_state_dict(weight, bias, input, constructor):
|
|||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
# Clone the weights and construct new optimizer for them
|
||||
weight_c = Variable(weight.data.clone(), requires_grad=True)
|
||||
bias_c = Variable(bias.data.clone(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
weight_c = Parameter(weight.clone().detach())
|
||||
bias_c = Parameter(bias.clone().detach())
|
||||
optimizer_c = constructor(weight_c, bias_c)
|
||||
fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c)
|
||||
# Load state dict
|
||||
|
@ -86,12 +87,8 @@ def _test_state_dict(weight, bias, input, constructor):
|
|||
for _i in range(20):
|
||||
optimizer.step(fn)
|
||||
optimizer_c.step(fn_c)
|
||||
#assert torch.equal(weight, weight_c)
|
||||
#assert torch.equal(bias, bias_c)
|
||||
torch_tc.assertEqual(weight, weight_c)
|
||||
torch_tc.assertEqual(bias, bias_c)
|
||||
# Make sure state dict wasn't modified
|
||||
torch_tc.assertEqual(state_dict, state_dict_c)
|
||||
# Make sure state dict is deterministic with equal but not identical parameters
|
||||
torch_tc.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
|
||||
# Make sure repeated parameters have identical representation in state dict
|
||||
|
@ -103,9 +100,10 @@ def _test_state_dict(weight, bias, input, constructor):
|
|||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
input_cuda = Variable(input.data.float().cuda())
|
||||
weight_cuda = Variable(weight.data.float().cuda(), requires_grad=True)
|
||||
bias_cuda = Variable(bias.data.float().cuda(), requires_grad=True)
|
||||
with torch.no_grad():
|
||||
input_cuda = Parameter(input.clone().detach().float().cuda())
|
||||
weight_cuda = Parameter(weight.clone().detach().cuda())
|
||||
bias_cuda = Parameter(bias.clone().detach().cuda())
|
||||
optimizer_cuda = constructor(weight_cuda, bias_cuda)
|
||||
fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
|
||||
|
||||
|
@ -216,21 +214,21 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
|
|||
scheduler_constructors = []
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
params = Variable(params_t, requires_grad=True)
|
||||
params = Parameter(params_t)
|
||||
optimizer = constructor([params])
|
||||
schedulers = []
|
||||
for scheduler_constructor in scheduler_constructors:
|
||||
schedulers.append(scheduler_constructor(optimizer))
|
||||
|
||||
solution = torch.tensor([1, 1])
|
||||
initial_dist = params.data.dist(solution)
|
||||
initial_dist = params.clone().detach().dist(solution)
|
||||
|
||||
def eval(params, w):
|
||||
# Depending on w, provide only the x or y gradient
|
||||
optimizer.zero_grad()
|
||||
loss = rosenbrock(params)
|
||||
loss.backward()
|
||||
grad = drosenbrock(params.data)
|
||||
grad = drosenbrock(params.clone().detach())
|
||||
# NB: We torture test the optimizer by returning an
|
||||
# uncoalesced sparse tensor
|
||||
if w:
|
||||
|
@ -256,7 +254,7 @@ def _test_rosenbrock(constructor, scheduler_constructors=None):
|
|||
else:
|
||||
scheduler.step()
|
||||
|
||||
torch_tc.assertLessEqual(params.data.dist(solution), initial_dist)
|
||||
torch_tc.assertLessEqual(params.clone().detach().dist(solution), initial_dist)
|
||||
|
||||
|
||||
def _build_params_dict(weight, bias, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue