mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
device agnostic testing
This commit is contained in:
parent
7c685a4ef3
commit
e628ed7e67
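
This change makes the timm test suite device agnostic: instead of hard-coding CPU (or CUDA in the optimizer tests), each test module now reads two environment variables. TORCH_DEVICE selects the device string passed to .to(...), and TORCH_BACKEND optionally names a Python module to import first, so an out-of-tree backend can register its device type before any tensors are created. A TIMEOUT variable additionally overrides the fixed 120s/300s pytest-timeout budgets.

A minimal sketch of the intended usage; the backend module name 'intel_extension_for_pytorch' and the 'xpu' device are illustrative assumptions, not something this commit prescribes:

    import os

    # Must be set before the test modules are imported, since they read
    # the environment at import time.
    os.environ['TORCH_BACKEND'] = 'intel_extension_for_pytorch'  # hypothetical backend
    os.environ['TORCH_DEVICE'] = 'xpu'                           # device passed to .to()
    os.environ['TIMEOUT'] = '600'                                # overrides both timeout tiers

    import pytest
    pytest.main(['tests/test_models.py', '-k', 'test_model_forward'])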
@@ -3,6 +3,13 @@ import torch.nn as nn
 
 from timm.layers import create_act_layer, set_layer_config
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
 
 class MLP(nn.Module):
     def __init__(self, act_layer="relu", inplace=True):
@@ -30,6 +37,9 @@ def _run_act_layer_grad(act_type, inplace=True):
         l = (out - 0).pow(2).sum()
         return l
 
+    x = x.to(device=torch_device)
+    m.to(device=torch_device)
+
     out_me = _run(x)
 
     with set_layer_config(scriptable=True):
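
The preamble above lands in the layer tests; the hunks below repeat it in the model tests, with two extra timeout knobs. Written out as a standalone helper (a sketch for illustration only; no such function exists in timm), the repeated pattern is:

    import importlib
    import os

    def resolve_test_device(default='cpu'):
        """Import an optional out-of-tree backend, then pick the test device."""
        backend = os.environ.get('TORCH_BACKEND')
        if backend is not None:
            # Importing the backend package is what registers its device
            # type with PyTorch before any tensors are created.
            importlib.import_module(backend)
        return os.environ.get('TORCH_DEVICE', default)

    torch_device = resolve_test_device()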
@@ -30,6 +30,17 @@ from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_v
 from timm.layers import Format, get_spatial_dim, get_channel_dim
 from timm.models import get_notrace_modules, get_notrace_functions
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
+timeout = os.environ.get('TIMEOUT')
+timeout120 = int(timeout) if timeout else 120
+timeout300 = int(timeout) if timeout else 300
+
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
     # no need for the fusion performance here
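
A note on the timeout knobs: @pytest.mark.timeout(...) arguments are evaluated once, when the module is imported, so the override has to be a module-level value computed from the environment; a fixture would be too late. A minimal sketch, assuming the pytest-timeout plugin is installed:

    import os
    import pytest

    timeout = os.environ.get('TIMEOUT')
    timeout120 = int(timeout) if timeout else 120  # one env var overrides every tier

    @pytest.mark.timeout(timeout120)
    def test_something_fast():
        assert True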
@@ -100,7 +111,7 @@ def _get_input_size(model=None, model_name='', target=None):
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward(model_name, batch_size):
@@ -112,6 +123,8 @@ def test_model_forward(model_name, batch_size):
         if max(input_size) > MAX_FWD_SIZE:
             pytest.skip("Fixed input size model > limit.")
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
 
     assert outputs.shape[0] == batch_size
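
The two added lines follow the standard PyTorch movement idiom, and the asymmetry is worth noting: Tensor.to() is out-of-place and returns a new tensor, so the result must be reassigned, while Module.to() moves parameters and buffers in place. A self-contained sketch of the same pattern:

    import torch
    import torch.nn as nn

    torch_device = 'cpu'  # stands in for os.environ.get('TORCH_DEVICE', 'cpu')

    model = nn.Linear(8, 2)
    inputs = torch.randn(1, 8)

    inputs = inputs.to(torch_device)  # returns a new tensor; reassignment required
    model.to(torch_device)            # moves parameters in place; return value ignorable

    outputs = model(inputs)
    assert outputs.shape[0] == 1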
@@ -119,7 +132,7 @@ def test_model_forward(model_name, batch_size):
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
@@ -133,6 +146,8 @@ def test_model_backward(model_name, batch_size):
     model.train()
 
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
@@ -147,7 +162,7 @@ def test_model_backward(model_name, batch_size):
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(
     exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -155,6 +170,7 @@ def test_model_default_cfgs(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -169,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
             not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
         # output sizes only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
-        input_tensor = torch.randn((batch_size, *input_size))
+        input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
 
         # test forward_features (always unpooled)
         outputs = model.forward_features(input_tensor)
@@ -180,12 +196,14 @@ def test_model_default_cfgs(model_name, batch_size):
 
         # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
         model.reset_classifier(0)
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 2
         assert outputs.shape[1] == model.num_features
 
         # test model forward without pooling and classifier
         model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 4
         if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -195,6 +213,7 @@ def test_model_default_cfgs(model_name, batch_size):
         if 'pruned' not in model_name:  # FIXME better pruned model handling
             # test classifier + global pool deletion via __init__
             model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+            model.to(torch_device)
             outputs = model.forward(input_tensor)
             assert len(outputs.shape) == 4
             if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -218,13 +237,14 @@ def test_model_default_cfgs(model_name, batch_size):
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_default_cfgs_non_std(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -232,7 +252,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     if max(input_size) > 320:  # FIXME const
         pytest.skip("Fixed input size model > limit.")
 
-    input_tensor = torch.randn((batch_size, *input_size))
+    input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
     feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
@@ -246,6 +266,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
 
     # test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
     model.reset_classifier(0)
+    model.to(torch_device)
    outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -254,6 +275,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    model.to(torch_device)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -297,7 +319,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 
 
 @pytest.mark.torchscript
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize(
     'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -312,6 +334,7 @@ def test_model_forward_torchscript(model_name, batch_size):
         model.eval()
 
     model = torch.jit.script(model)
+    model.to(torch_device)
     outputs = model(torch.randn((batch_size, *input_size)))
 
     assert outputs.shape[0] == batch_size
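
The scripted path needs no special casing: torch.jit.script() returns a ScriptModule, which still exposes Module.to(), so the device move composes with compilation. The remaining hunks apply the same treatment to the optimizer tests, which were previously hard-coded to CUDA.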
@@ -15,6 +15,13 @@ from timm.scheduler import PlateauLRScheduler
 
 from timm.optim import create_optimizer_v2
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cuda')
 
 # HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
 torch_tc = TestCase()
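
Note the different default here: the layer and model tests fall back to 'cpu', but the optimizer tests fall back to 'cuda', because the state-dict and basic-case checks below compare a CPU run against an accelerator run and simply return early when torch_device is 'cpu'.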
@@ -61,7 +68,7 @@ def _test_state_dict(weight, bias, input, constructor):
 
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
-        i = input_cuda if weight.is_cuda else input
+        i = input_device if weight.device.type != 'cpu' else input
         loss = (weight.mv(i) + bias).pow(2).sum()
         loss.backward()
         return loss
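
The rewritten condition generalizes the old CUDA-only check: Tensor.is_cuda is True only for CUDA tensors, whereas tensor.device.type != 'cpu' treats any non-CPU backend as the device under test. A quick illustration:

    import torch

    t = torch.randn(2)
    print(t.is_cuda)               # True only for CUDA tensors
    print(t.device.type != 'cpu')  # False on CPU, True on 'cuda', 'xpu', 'mps', ...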
@@ -97,28 +104,28 @@ def _test_state_dict(weight, bias, input, constructor):
 
     # Check that state dict can be loaded even when we cast parameters
     # to a different type and move to a different device.
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
 
     with torch.no_grad():
-        input_cuda = Parameter(input.clone().detach().float().cuda())
-        weight_cuda = Parameter(weight.clone().detach().cuda())
-        bias_cuda = Parameter(bias.clone().detach().cuda())
-    optimizer_cuda = constructor(weight_cuda, bias_cuda)
-    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
+        input_device = Parameter(input.clone().detach().float().to(torch_device))
+        weight_device = Parameter(weight.clone().detach().to(torch_device))
+        bias_device = Parameter(bias.clone().detach().to(torch_device))
+    optimizer_device = constructor(weight_device, bias_device)
+    fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)
 
     state_dict = deepcopy(optimizer.state_dict())
     state_dict_c = deepcopy(optimizer.state_dict())
-    optimizer_cuda.load_state_dict(state_dict_c)
+    optimizer_device.load_state_dict(state_dict_c)
 
     # Make sure state dict wasn't modified
     torch_tc.assertEqual(state_dict, state_dict_c)
 
     for _i in range(20):
         optimizer.step(fn)
-        optimizer_cuda.step(fn_cuda)
-        torch_tc.assertEqual(weight, weight_cuda)
-        torch_tc.assertEqual(bias, bias_cuda)
+        optimizer_device.step(fn_device)
+        torch_tc.assertEqual(weight, weight_device)
+        torch_tc.assertEqual(bias, bias_device)
 
     # validate deepcopy() copies all public attributes
     def getPublicAttr(obj):
@@ -152,12 +159,12 @@ def _test_basic_cases(constructor, scheduler_constructors=None):
         scheduler_constructors
     )
     # CUDA
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
     _test_basic_cases_template(
-        torch.randn(10, 5).cuda(),
-        torch.randn(10).cuda(),
-        torch.randn(5).cuda(),
+        torch.randn(10, 5).to(torch_device),
+        torch.randn(10).to(torch_device),
+        torch.randn(5).to(torch_device),
         constructor,
         scheduler_constructors
     )
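
With these hunks applied, nothing in the test suite names CUDA directly: TORCH_DEVICE=cpu keeps the optimizer tests CPU-only (their accelerator halves return early), while setting TORCH_DEVICE, plus TORCH_BACKEND for an out-of-tree backend, runs the exact same assertions on an accelerator.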