diff --git a/tests/test_models.py b/tests/test_models.py index d4c39b39..3ba3615d 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -265,6 +265,7 @@ def test_model_default_cfgs(model_name, batch_size): # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) + assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}' model.to(torch_device) outputs = model.forward(input_tensor) assert len(outputs.shape) == 2 @@ -339,6 +340,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) + assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}' model.to(torch_device) outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)):