Make NormMlpClassifier head reset args consistent with ClassifierHead

Ross Wightman 2024-02-10 16:25:33 -08:00
parent 87fec3dc14
commit d6c2cc91af
3 changed files with 6 additions and 6 deletions


@@ -180,10 +180,10 @@ class NormMlpClassifierHead(nn.Module):
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, global_pool=None):
-        if global_pool is not None:
-            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+    def reset(self, num_classes, pool_type=None):
+        if pool_type is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
+            self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         self.use_conv = self.global_pool.is_identity()
         linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
         if self.hidden_size:
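Note: with this change, NormMlpClassifierHead.reset() takes the same pool_type keyword as ClassifierHead.reset(). A minimal sketch of the now-unified call pattern, assuming the head classes exported from timm.layers (feature sizes and pooling values below are illustrative only):

    from timm.layers import ClassifierHead, NormMlpClassifierHead

    heads = [
        ClassifierHead(in_features=768, num_classes=1000, pool_type='avg'),
        NormMlpClassifierHead(in_features=768, num_classes=1000, pool_type='avg'),
    ]
    for head in heads:
        # the same keyword now works for both head types
        head.reset(num_classes=10, pool_type='max')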


@@ -569,7 +569,7 @@ class DaVit(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool=None):
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
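For DaVit the fix is purely at the call site: the head's reset() parameter is now named pool_type, so the model-level global_pool value is forwarded positionally instead of with the old keyword. An illustrative standalone toy example (not timm code) of why the positional form works:

    class _Head:
        def reset(self, num_classes, pool_type=None):
            print(f'{num_classes} classes, pooling={pool_type}')

    global_pool = 'avg'
    _Head().reset(10, global_pool)                # ok: binds positionally to pool_type
    # _Head().reset(10, global_pool=global_pool)  # old form: TypeError after the rename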


@@ -535,7 +535,7 @@ class TinyVit(nn.Module):
 
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        self.head.reset(num_classes, global_pool=global_pool)
+        self.head.reset(num_classes, pool_type=global_pool)
 
     def forward_features(self, x):
         x = self.patch_embed(x)
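TinyVit keeps the model-level global_pool argument and simply forwards it to the head under its new pool_type name, so the public API is unchanged. A usage sketch, assuming this commit is applied (the model name is an existing timm model, used here only as an example):

    import timm

    model = timm.create_model('tiny_vit_5m_224', num_classes=1000)
    # public API unchanged: global_pool is forwarded to the head as pool_type
    model.reset_classifier(10, global_pool='avg')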