pull/2275/head
Ross Wightman 2024-09-02 13:20:05 -07:00
parent ebbe530ee4
commit a50713ce6e
1 changed files with 1 additions and 1 deletions

View File

@ -629,7 +629,7 @@ class VisionTransformer(nn.Module):
assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map ' and self.attn_pool is not None:
elif global_pool != 'map' and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()