diff --git a/tests/test_layers.py b/tests/test_layers.py
index da061870..92f6b683 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -3,6 +3,13 @@
 import torch.nn as nn
 
 from timm.layers import create_act_layer, set_layer_config
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
 class MLP(nn.Module):
     def __init__(self, act_layer="relu", inplace=True):
@@ -30,6 +37,9 @@ def _run_act_layer_grad(act_type, inplace=True):
         l = (out - 0).pow(2).sum()
         return l
 
+    x = x.to(device=torch_device)
+    m.to(device=torch_device)
+
     out_me = _run(x)
 
     with set_layer_config(scriptable=True):
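The block added at the top of this file is the heart of the whole patch: importing the module named by TORCH_BACKEND is what lets an out-of-tree accelerator register its device type with PyTorch (such packages do this as an import side effect), and TORCH_DEVICE then names where tensors and modules should live. A minimal self-contained sketch of the same mechanism, with my_backend standing in as a purely hypothetical package name:

    import importlib
    import os

    import torch

    # Importing an out-of-tree backend package registers its device type
    # (e.g. 'xpu') with PyTorch as an import side effect; 'my_backend' is a
    # hypothetical placeholder, not a real package name.
    backend = os.environ.get('TORCH_BACKEND')
    if backend is not None:
        importlib.import_module(backend)

    # Default to CPU so the sketch stays runnable on any machine.
    device = os.environ.get('TORCH_DEVICE', 'cpu')

    x = torch.rand(10, 1000).to(device=device)  # Tensor.to returns a new tensor
    m = torch.nn.Linear(1000, 10)
    m.to(device=device)                         # Module.to moves parameters in place
    print(m(x).shape)                           # torch.Size([10, 10])

This is also why the hunk above reassigns x (`x = x.to(...)`) but calls `m.to(...)` bare: Tensor.to is out-of-place, while Module.to mutates the module.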
diff --git a/tests/test_models.py b/tests/test_models.py
index b1b2bf19..a6411a78 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -30,6 +30,17 @@ from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_v
 from timm.layers import Format, get_spatial_dim, get_channel_dim
 from timm.models import get_notrace_modules, get_notrace_functions
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
+timeout = os.environ.get('TIMEOUT')
+timeout120 = int(timeout) if timeout else 120
+timeout300 = int(timeout) if timeout else 300
+
 if hasattr(torch._C, '_jit_set_profiling_executor'):
     # legacy executor is too slow to compile large models for unit tests
     # no need for the fusion performance here
@@ -100,7 +111,7 @@ def _get_input_size(model=None, model_name='', target=None):
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward(model_name, batch_size):
@@ -112,6 +123,8 @@ def test_model_forward(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
 
     assert outputs.shape[0] == batch_size
@@ -119,7 +132,7 @@
 
 
 @pytest.mark.base
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
@@ -133,6 +146,8 @@ def test_model_backward(model_name, batch_size):
     model.train()
 
     inputs = torch.randn((batch_size, *input_size))
+    inputs = inputs.to(torch_device)
+    model.to(torch_device)
     outputs = model(inputs)
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
@@ -147,7 +162,7 @@
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(
     exclude_filters=EXCLUDE_FILTERS + NON_STD_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -155,6 +170,7 @@ def test_model_default_cfgs(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -169,7 +185,7 @@
             not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
         # output sizes only checked if default res <= 448 * 448 to keep resource down
         input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
-        input_tensor = torch.randn((batch_size, *input_size))
+        input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
 
         # test forward_features (always unpooled)
         outputs = model.forward_features(input_tensor)
@@ -180,12 +196,14 @@
 
         # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
         model.reset_classifier(0)
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 2
         assert outputs.shape[1] == model.num_features
 
         # test model forward without pooling and classifier
         model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
+        model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 4
         if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -195,6 +213,7 @@
         if 'pruned' not in model_name:  # FIXME better pruned model handling
             # test classifier + global pool deletion via __init__
             model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+            model.to(torch_device)
             outputs = model.forward(input_tensor)
             assert len(outputs.shape) == 4
             if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
@@ -218,13 +237,14 @@
 
 
 @pytest.mark.cfg
-@pytest.mark.timeout(300)
+@pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_default_cfgs_non_std(model_name, batch_size):
     """Run a single forward pass with each model"""
     model = create_model(model_name, pretrained=False)
     model.eval()
+    model.to(torch_device)
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -232,7 +252,7 @@
     if max(input_size) > 320:  # FIXME const
         pytest.skip("Fixed input size model > limit.")
 
-    input_tensor = torch.randn((batch_size, *input_size))
+    input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
     feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
@@ -246,6 +266,7 @@
 
     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
     model.reset_classifier(0)
+    model.to(torch_device)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -254,6 +275,7 @@
     assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    model.to(torch_device)
     outputs = model.forward(input_tensor)
     if isinstance(outputs, (tuple, list)):
         outputs = outputs[0]
@@ -297,7 +319,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 
 
 @pytest.mark.torchscript
-@pytest.mark.timeout(120)
+@pytest.mark.timeout(timeout120)
 @pytest.mark.parametrize(
     'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [1])
@@ -312,6 +334,7 @@ def test_model_forward_torchscript(model_name, batch_size):
     model.eval()
 
     model = torch.jit.script(model)
-    outputs = model(torch.randn((batch_size, *input_size)))
+    model.to(torch_device)
+    outputs = model(torch.randn((batch_size, *input_size), device=torch_device))
 
     assert outputs.shape[0] == batch_size
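Two details in this file are easy to miss. The timeout values have to be computed at module import because pytest evaluates `@pytest.mark.timeout(...)` arguments when it collects the tests, not when it runs them; reading TIMEOUT once therefore lets a slow accelerator stretch both the 120s and 300s budgets with a single variable. And every forward pass moves the model and its inputs as a pair, since a tensor on one device fed to parameters on another fails at the first op. A hedged, self-contained sketch of both points (the test body is illustrative only, and assumes the pytest-timeout plugin this suite already relies on):

    import os

    import pytest
    import torch

    # Read once at import time; pytest evaluates mark arguments at collection.
    timeout = os.environ.get('TIMEOUT')
    timeout120 = int(timeout) if timeout else 120

    torch_device = os.environ.get('TORCH_DEVICE', 'cpu')


    @pytest.mark.timeout(timeout120)
    def test_forward_sketch():  # hypothetical stand-in for test_model_forward
        model = torch.nn.Linear(8, 2).to(torch_device)   # model and inputs must
        inputs = torch.randn(4, 8, device=torch_device)  # land on the same device
        assert model(inputs).shape == (4, 2)

The repeated `model.to(torch_device)` after each `reset_classifier(...)` call is deliberate: resetting the classifier constructs a fresh head on the CPU, so the model has to be moved again before the next forward pass.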
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 9bdfd682..b1e900c2 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -15,6 +15,13 @@
 from timm.scheduler import PlateauLRScheduler
 from timm.optim import create_optimizer_v2
 
+import importlib
+import os
+
+torch_backend = os.environ.get('TORCH_BACKEND')
+if torch_backend is not None:
+    importlib.import_module(torch_backend)
+torch_device = os.environ.get('TORCH_DEVICE', 'cuda' if torch.cuda.is_available() else 'cpu')
 
 # HACK relying on internal PyTorch test functionality for comparisons that I don't want to write
 torch_tc = TestCase()
@@ -61,7 +68,7 @@ def _test_state_dict(weight, bias, input, constructor):
 
     def fn_base(optimizer, weight, bias):
         optimizer.zero_grad()
-        i = input_cuda if weight.is_cuda else input
+        i = input_device if weight.device.type != 'cpu' else input
         loss = (weight.mv(i) + bias).pow(2).sum()
         loss.backward()
         return loss
@@ -97,28 +104,28 @@ def _test_state_dict(weight, bias, input, constructor):
 
     # Check that state dict can be loaded even when we cast parameters
     # to a different type and move to a different device.
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
 
     with torch.no_grad():
-        input_cuda = Parameter(input.clone().detach().float().cuda())
-        weight_cuda = Parameter(weight.clone().detach().cuda())
-        bias_cuda = Parameter(bias.clone().detach().cuda())
-        optimizer_cuda = constructor(weight_cuda, bias_cuda)
-        fn_cuda = functools.partial(fn_base, optimizer_cuda, weight_cuda, bias_cuda)
+        input_device = Parameter(input.clone().detach().float().to(torch_device))
+        weight_device = Parameter(weight.clone().detach().to(torch_device))
+        bias_device = Parameter(bias.clone().detach().to(torch_device))
+        optimizer_device = constructor(weight_device, bias_device)
+        fn_device = functools.partial(fn_base, optimizer_device, weight_device, bias_device)
 
     state_dict = deepcopy(optimizer.state_dict())
     state_dict_c = deepcopy(optimizer.state_dict())
-    optimizer_cuda.load_state_dict(state_dict_c)
+    optimizer_device.load_state_dict(state_dict_c)
 
     # Make sure state dict wasn't modified
     torch_tc.assertEqual(state_dict, state_dict_c)
 
     for _i in range(20):
         optimizer.step(fn)
-        optimizer_cuda.step(fn_cuda)
-        torch_tc.assertEqual(weight, weight_cuda)
-        torch_tc.assertEqual(bias, bias_cuda)
+        optimizer_device.step(fn_device)
+        torch_tc.assertEqual(weight, weight_device)
+        torch_tc.assertEqual(bias, bias_device)
 
     # validate deepcopy() copies all public attributes
     def getPublicAttr(obj):
@@ -152,12 +159,12 @@ def _test_basic_cases(constructor, scheduler_constructors=None):
         scheduler_constructors
     )
     # CUDA
-    if not torch.cuda.is_available():
+    if torch_device == 'cpu':
         return
     _test_basic_cases_template(
-        torch.randn(10, 5).cuda(),
-        torch.randn(10).cuda(),
-        torch.randn(5).cuda(),
+        torch.randn(10, 5).to(torch_device),
+        torch.randn(10).to(torch_device),
+        torch.randn(5).to(torch_device),
         constructor,
         scheduler_constructors
     )
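The fn_base change above is the template for making device checks generic: Tensor.is_cuda only answers "is this tensor on CUDA?", while comparing `device.type` against `'cpu'` treats any accelerator backend (CUDA, XPU, MPS, ...) uniformly. A quick runnable illustration of the two predicates:

    import torch

    w = torch.zeros(2)             # a CPU tensor
    print(w.is_cuda)               # False, but only distinguishes CUDA from everything else
    print(w.device.type != 'cpu')  # False, and generalizes to any non-CPU backend

Unlike the other two files, the default device here is derived from CUDA availability rather than hard-coded to `'cpu'`: these state-dict comparisons originally ran only under CUDA, and the guard `if torch_device == 'cpu': return` now plays the role of the old `torch.cuda.is_available()` early-out while still honoring an explicit TORCH_DEVICE override.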