Fix tests for rank-4 output where feature channels dim is -1 (3) and not 1
parent
d79f3d9d1e
commit
39b725e1c9
|
@ -202,12 +202,14 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
||||||
pytest.skip("Fixed input size model > limit.")
|
pytest.skip("Fixed input size model > limit.")
|
||||||
|
|
||||||
input_tensor = torch.randn((batch_size, *input_size))
|
input_tensor = torch.randn((batch_size, *input_size))
|
||||||
|
feat_dim = getattr(model, 'feature_dim', None)
|
||||||
|
|
||||||
outputs = model.forward_features(input_tensor)
|
outputs = model.forward_features(input_tensor)
|
||||||
if isinstance(outputs, (tuple, list)):
|
if isinstance(outputs, (tuple, list)):
|
||||||
# cannot currently verify multi-tensor output.
|
# cannot currently verify multi-tensor output.
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
|
if feat_dim is None:
|
||||||
feat_dim = -1 if outputs.ndim == 3 else 1
|
feat_dim = -1 if outputs.ndim == 3 else 1
|
||||||
assert outputs.shape[feat_dim] == model.num_features
|
assert outputs.shape[feat_dim] == model.num_features
|
||||||
|
|
||||||
|
@ -216,6 +218,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
||||||
outputs = model.forward(input_tensor)
|
outputs = model.forward(input_tensor)
|
||||||
if isinstance(outputs, (tuple, list)):
|
if isinstance(outputs, (tuple, list)):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
|
if feat_dim is None:
|
||||||
feat_dim = -1 if outputs.ndim == 3 else 1
|
feat_dim = -1 if outputs.ndim == 3 else 1
|
||||||
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
|
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
|
||||||
|
|
||||||
|
@ -223,6 +226,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
|
||||||
outputs = model.forward(input_tensor)
|
outputs = model.forward(input_tensor)
|
||||||
if isinstance(outputs, (tuple, list)):
|
if isinstance(outputs, (tuple, list)):
|
||||||
outputs = outputs[0]
|
outputs = outputs[0]
|
||||||
|
if feat_dim is None:
|
||||||
feat_dim = -1 if outputs.ndim == 3 else 1
|
feat_dim = -1 if outputs.ndim == 3 else 1
|
||||||
assert outputs.shape[feat_dim] == model.num_features
|
assert outputs.shape[feat_dim] == model.num_features
|
||||||
|
|
||||||
|
|
|
@ -288,6 +288,7 @@ class Sequencer2D(nn.Module):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.num_features = embed_dims[-1] # num_features for consistency with other models
|
self.num_features = embed_dims[-1] # num_features for consistency with other models
|
||||||
|
self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC)
|
||||||
self.embed_dims = embed_dims
|
self.embed_dims = embed_dims
|
||||||
self.stem = PatchEmbed(
|
self.stem = PatchEmbed(
|
||||||
img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
|
img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
|
||||||
|
@ -333,7 +334,7 @@ class Sequencer2D(nn.Module):
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool=None):
|
def reset_classifier(self, num_classes, global_pool=None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if self.global_pool is not None:
|
if global_pool is not None:
|
||||||
assert global_pool in ('', 'avg')
|
assert global_pool in ('', 'avg')
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
Loading…
Reference in New Issue