mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Make NormMlpClassifier head reset args consistent with ClassifierHead
This commit is contained in:
parent
87fec3dc14
commit
d6c2cc91af
@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
|
|||||||
self.drop = nn.Dropout(drop_rate)
|
self.drop = nn.Dropout(drop_rate)
|
||||||
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||||
|
|
||||||
def reset(self, num_classes, global_pool=None):
|
def reset(self, num_classes, pool_type=None):
|
||||||
if global_pool is not None:
|
if pool_type is not None:
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
|
||||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
|
||||||
self.use_conv = self.global_pool.is_identity()
|
self.use_conv = self.global_pool.is_identity()
|
||||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
|
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
|
||||||
if self.hidden_size:
|
if self.hidden_size:
|
||||||
|
@ -569,7 +569,7 @@ class DaVit(nn.Module):
|
|||||||
return self.head.fc
|
return self.head.fc
|
||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool=None):
|
def reset_classifier(self, num_classes, global_pool=None):
|
||||||
self.head.reset(num_classes, global_pool=global_pool)
|
self.head.reset(num_classes, global_pool)
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.stem(x)
|
x = self.stem(x)
|
||||||
|
@ -535,7 +535,7 @@ class TinyVit(nn.Module):
|
|||||||
|
|
||||||
def reset_classifier(self, num_classes, global_pool=None):
|
def reset_classifier(self, num_classes, global_pool=None):
|
||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.head.reset(num_classes, global_pool=global_pool)
|
self.head.reset(num_classes, pool_type=global_pool)
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user