Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit e628ed7e67 (parent 7c685a4ef3)
device agnostic testing
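The hunks below appear to touch the layer, model, and optimizer test modules (file paths are not shown in this view). The common change: instead of assuming CUDA, each module reads a TORCH_BACKEND and a TORCH_DEVICE environment variable, optionally imports an out-of-tree backend package, and moves models and inputs onto the selected device with .to() rather than .cuda(). A condensed sketch of that pattern, using only names that appear in the diff (the Linear model and tensor shapes are illustrative, not part of the commit):

    import importlib
    import os

    import torch

    # Optionally import a backend package named by TORCH_BACKEND so it can
    # register its device type with PyTorch before any tensors are created.
    torch_backend = os.environ.get('TORCH_BACKEND')
    if torch_backend is not None:
        importlib.import_module(torch_backend)

    # Pick the test device; the layer/model tests default to 'cpu',
    # the optimizer tests default to 'cuda'.
    torch_device = os.environ.get('TORCH_DEVICE', 'cpu')

    # Device-agnostic usage: move the module and inputs with .to().
    model = torch.nn.Linear(8, 2).to(torch_device)
    inputs = torch.randn(4, 8, device=torch_device)
    outputs = model(inputs)
    assert outputs.shape == (4, 2)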
@@ -3,6 +3,13 @@ import torch.nn as nn
 
 from timm.layers import create_act_layer, set_layer_config
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
 
 class MLP(nn.Module):
     def __init__(self, act_layer="relu", inplace=True):
@@ -30,6 +37,9 @@ def _run_act_layer_grad(act_type, inplace=True):
         l = (out - 0).pow(2).sum()
         return l
 
+    x = x.to(device=torch_device)
+    m.to(device=torch_device)
+
     out_me = _run(x)
 
     with set_layer_config(scriptable=True):
@@ -30,6 +30,17 @@ from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_v
 from timm.layers import Format, get_spatial_dim, get_channel_dim
 from timm.models import get_notrace_modules, get_notrace_functions
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
+timeout = os.environ.get('TIMEOUT')
+timeout120 = int(timeout) if timeout else 120
+timeout300 = int(timeout) if timeout else 300
+
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
     # no need for the fusion performance here
@@ -100,7 +111,7 @@ def _get_input_size(model=None, model_name='', target=None):
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward(model_name, batch_size):
@@ -112,6 +123,8 @@ def test_model_forward(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
 
     assert outputs.shape[0] == batch_size
@@ -119,7 +132,7 @@ def test_model_forward(model_name, batch_size):
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
@@ -133,6 +146,8 @@ def test_model_backward(model_name, batch_size):
     model.train()
 
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
@@ -147,7 +162,7 @@ def test_model_backward(model_name, batch_size):
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(
     exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -155,6 +170,7 @@ def test_model_default_cfgs(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -169,7 +185,7 @@ def test_model_default_cfgs(model_name, batch_size):
             not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
         # output sizes only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
-        input_tensor = torch.randn((batch_size, *input_size))
+        input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
 
         # test forward_features (always unpooled)
         outputs = model.forward_features(input_tensor)
@@ -180,12 +196,14 @@ def test_model_default_cfgs(model_name, batch_size):
 
         # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
         model.reset_classifier(0)
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 2
         assert outputs.shape[1] == model.num_features
 
         # test model forward without pooling and classifier
         model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 4
         if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -195,6 +213,7 @@ def test_model_default_cfgs(model_name, batch_size):
         if 'pruned' not in model_name:  # FIXME better pruned model handling
             # test classifier + global pool deletion via __init__
             model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+            model.to(torch_device)
             outputs = model.forward(input_tensor)
             assert len(outputs.shape) == 4
             if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -218,13 +237,14 @@ def test_model_default_cfgs(model_name, batch_size):
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_default_cfgs_non_std(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -232,7 +252,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     if max(input_size) > 320:  # FIXME const
         pytest.skip("Fixed input size model > limit.")
 
-    input_tensor = torch.randn((batch_size, *input_size))
+    input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
     feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
@@ -246,6 +266,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
 
     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
     model.reset_classifier(0)
+    model.to(torch_device)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -254,6 +275,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    model.to(torch_device)
    outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -297,7 +319,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 
 
 @pytest.mark.torchscript
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize(
     'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -312,6 +334,7 @@ def test_model_forward_torchscript(model_name, batch_size):
     model.eval()
 
     model = torch.jit.script(model)
+    model.to(torch_device)
     outputs = model(torch.randn((batch_size, *input_size)))
 
     assert outputs.shape[0] == batch_size
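Beyond the device handling, the model-test hunks above also make the pytest timeouts configurable: a TIMEOUT environment variable overrides the previously hard-coded @pytest.mark.timeout(120) and @pytest.mark.timeout(300) values, which remain the defaults. A minimal sketch of that pattern (the env-var name and defaults come from the diff; the test bodies are placeholders, and the pytest-timeout plugin must be installed for the marker to take effect):

    import os

    import pytest

    timeout = os.environ.get('TIMEOUT')
    timeout120 = int(timeout) if timeout else 120  # default 120 s unless TIMEOUT is set
    timeout300 = int(timeout) if timeout else 300  # default 300 s unless TIMEOUT is set


    @pytest.mark.timeout(timeout120)
    def test_fast_path():
        assert True


    @pytest.mark.timeout(timeout300)
    def test_slow_path():
        assert True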
@@ -15,6 +15,13 @@ from timm.scheduler import PlateauLRScheduler
 
 from timm.optim import create_optimizer_v2
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cuda')
 
 # HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
 torch_tc = TestCase()
@@ -61,7 +68,7 @@ def _test_state_dict(weight, bias, input, constructor):
 
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
-        i = input_cuda if weight.is_cuda else input
+        i = input_device if weight.device.type != 'cpu' else input
         loss = (weight.mv(i) + bias).pow(2).sum()
         loss.backward()
         return loss
@@ -97,28 +104,28 @@ def _test_state_dict(weight, bias, input, constructor):
 
     # Check that state dict can be loaded even when we cast parameters
     # to a different type and move to a different device.
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
 
     with torch.no_grad():
-        input_cuda = Parameter(input.clone().detach().float().cuda())
-        weight_cuda = Parameter(weight.clone().detach().cuda())
-        bias_cuda = Parameter(bias.clone().detach().cuda())
-    optimizer_cuda = constructor(weight_cuda, bias_cuda)
-    fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
+        input_device = Parameter(input.clone().detach().float().to(torch_device))
+        weight_device = Parameter(weight.clone().detach().to(torch_device))
+        bias_device = Parameter(bias.clone().detach().to(torch_device))
+    optimizer_device = constructor(weight_device, bias_device)
+    fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)
 
     state_dict = deepcopy(optimizer.state_dict())
     state_dict_c = deepcopy(optimizer.state_dict())
-    optimizer_cuda.load_state_dict(state_dict_c)
+    optimizer_device.load_state_dict(state_dict_c)
 
     # Make sure state dict wasn't modified
     torch_tc.assertEqual(state_dict, state_dict_c)
 
     for _i in range(20):
         optimizer.step(fn)
-        optimizer_cuda.step(fn_cuda)
-        torch_tc.assertEqual(weight, weight_cuda)
-        torch_tc.assertEqual(bias, bias_cuda)
+        optimizer_device.step(fn_device)
+        torch_tc.assertEqual(weight, weight_device)
+        torch_tc.assertEqual(bias, bias_device)
 
     # validate deepcopy() copies all public attributes
     def getPublicAttr(obj):
@@ -152,12 +159,12 @@ def _test_basic_cases(constructor, scheduler_constructors=None):
         scheduler_constructors
     )
     # CUDA
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
     _test_basic_cases_template(
-        torch.randn(10, 5).cuda(),
-        torch.randn(10).cuda(),
-        torch.randn(5).cuda(),
+        torch.randn(10, 5).to(torch_device),
+        torch.randn(10).to(torch_device),
+        torch.randn(5).to(torch_device),
         constructor,
         scheduler_constructors
     )
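The optimizer-test hunks apply the same idea to the CUDA-only paths: the torch.cuda.is_available() guard becomes a check against torch_device, and every .cuda() call becomes .to(torch_device), so the second-device copy of the parameters used for the state-dict comparison can live on any available accelerator. A small sketch of the guard-and-copy pattern, assuming the torch_device setup shown earlier; the helper name is hypothetical, and the default here is 'cpu' so the snippet runs anywhere, whereas the optimizer tests default to 'cuda':

    import os

    import torch
    from torch.nn import Parameter

    torch_device = os.environ.get('TORCH_DEVICE', 'cpu')


    def second_device_copy(t: torch.Tensor):
        # Previously: `if not torch.cuda.is_available(): return` followed by `.cuda()`.
        # Now any non-CPU device selected via TORCH_DEVICE is exercised.
        if torch_device == 'cpu':
            return None  # no second device to compare against
        return Parameter(t.clone().detach().float().to(torch_device))


    print(second_device_copy(torch.randn(4)))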