mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make NormMlpClassifier head reset args consistent with ClassifierHead
This commit is contained in:
parent
87fec3dc14
commit
d6c2cc91af
@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
|
||||
self.drop = nn.Dropout(drop_rate)
|
||||
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def reset(self, num_classes, global_pool=None):
|
||||
if global_pool is not None:
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
||||
def reset(self, num_classes, pool_type=None):
|
||||
if pool_type is not None:
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
|
||||
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
|
||||
self.use_conv = self.global_pool.is_identity()
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
|
||||
if self.hidden_size:
|
||||
|
@ -569,7 +569,7 @@ class DaVit(nn.Module):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.head.reset(num_classes, global_pool=global_pool)
|
||||
self.head.reset(num_classes, global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
|
@ -535,7 +535,7 @@ class TinyVit(nn.Module):
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=None):
|
||||
self.num_classes = num_classes
|
||||
self.head.reset(num_classes, global_pool=global_pool)
|
||||
self.head.reset(num_classes, pool_type=global_pool)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user