Add num_classes assertion after reset_classifier

This commit is contained in:
Ryan 2025-01-22 01:43:32 +08:00 committed by Ross Wightman
parent 17eabaad17
commit bda46f8e6f

View File

@ -265,6 +265,7 @@ def test_model_default_cfgs(model_name, batch_size):
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
@ -339,6 +340,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
model.to(torch_device)
outputs = model.forward(input_tensor)
if isinstance(outputs, (tuple, list)):