diff --git a/tests/test_models.py b/tests/test_models.py index 65a7ebb3..fd99bb46 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -66,5 +66,5 @@ def test_model_default_cfgs(model_name, batch_size): input_size = tuple([min(x, 448) for x in input_size]) outputs = model.forward_features(torch.randn((batch_size, *input_size))) assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] - assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params' - assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params' + assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params' + assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params'