mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add num_classes assertion after reset_classifier
This commit is contained in:
parent
17eabaad17
commit
bda46f8e6f
@ -265,6 +265,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
|
||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||
model.reset_classifier(0)
|
||||
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
|
||||
model.to(torch_device)
|
||||
outputs = model.forward(input_tensor)
|
||||
assert len(outputs.shape) == 2
|
||||
@ -339,6 +340,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
||||
|
||||
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
|
||||
model.reset_classifier(0)
|
||||
assert model.num_classes == 0, f'Expected num_classes to be 0 after reset_classifier(0), but got {model.num_classes}'
|
||||
model.to(torch_device)
|
||||
outputs = model.forward(input_tensor)
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
|
Loading…
x
Reference in New Issue
Block a user