Fix #2272
parent
ebbe530ee4
commit
a50713ce6e
timm/models
|
@ -629,7 +629,7 @@ class VisionTransformer(nn.Module):
|
|||
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
|
||||
if global_pool == 'map' and self.attn_pool is None:
|
||||
assert False, "Cannot currently add attention pooling in reset_classifier()."
|
||||
elif global_pool != 'map ' and self.attn_pool is not None:
|
||||
elif global_pool != 'map' and self.attn_pool is not None:
|
||||
self.attn_pool = None # remove attention pooling
|
||||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
|
Loading…
Reference in New Issue