Fix #1726, dropout not used in NormMlpClassifierHead. Make dropout more consistent across both classifier heads (nn.Dropout)

pull/1741/head
Ross Wightman 2023-03-20 09:37:05 -07:00
parent 041de79f9e
commit 8db20dc240
1 changed file with 7 additions and 6 deletions


@@ -88,18 +88,20 @@ class ClassifierHead(nn.Module):
             drop_rate: Pre-classifier dropout rate.
         """
         super(ClassifierHead, self).__init__()
-        self.drop_rate = drop_rate
         self.in_features = in_features
         self.use_conv = use_conv
         self.input_fmt = input_fmt

-        self.global_pool, self.fc = create_classifier(
+        global_pool, fc = create_classifier(
             in_features,
             num_classes,
             pool_type,
             use_conv=use_conv,
             input_fmt=input_fmt,
         )
+        self.global_pool = global_pool
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

     def reset(self, num_classes, pool_type=None):
@@ -122,8 +124,7 @@ class ClassifierHead(nn.Module):

     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
-        if self.drop_rate:
-            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.drop(x)
         if pre_logits:
             return self.flatten(x)
         x = self.fc(x)
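For reference, the old ClassifierHead path only applied dropout when drop_rate was truthy, calling F.dropout directly; the patched path always routes through an nn.Dropout module, which follows train()/eval() mode and is a no-op at p=0.0. A minimal standalone sketch of the two patterns (illustration only, not the timm code):

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 8)
drop_rate = 0.5

# Old pattern: functional dropout, guarded by a truthiness check on the rate.
if drop_rate:
    y_old = F.dropout(x, p=float(drop_rate), training=True)

# New pattern: an nn.Dropout module owns the rate and follows train()/eval().
drop = nn.Dropout(drop_rate)
y_train = drop(x)   # stochastic (a fresh module defaults to training mode)
drop.eval()
y_eval = drop(x)    # identity in eval mode
assert torch.equal(y_eval, x)

# With drop_rate == 0.0 the module reduces to the identity,
# so it is safe to call unconditionally.
assert torch.equal(nn.Dropout(0.0)(x), x)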
@@ -153,7 +154,6 @@ class NormMlpClassifierHead(nn.Module):
             act_layer: MLP activation layer type (only used if hidden_size is not None).
         """
         super().__init__()
-        self.drop_rate = drop_rate
         self.in_features = in_features
         self.hidden_size = hidden_size
         self.num_features = in_features
@@ -173,7 +173,7 @@ class NormMlpClassifierHead(nn.Module):
             self.num_features = hidden_size
         else:
             self.pre_logits = nn.Identity()
-        self.drop = nn.Dropout(self.drop_rate)
+        self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def reset(self, num_classes, global_pool=None):
@@ -197,6 +197,7 @@ class NormMlpClassifierHead(nn.Module):
         x = self.norm(x)
         x = self.flatten(x)
         x = self.pre_logits(x)
+        x = self.drop(x)
         if pre_logits:
             return x
         x = self.fc(x)
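The effect of the newly added self.drop(x) call (per #1726, drop_rate was stored on NormMlpClassifierHead but never applied in forward) can be checked with a small stand-in head. TinyNormHead below is a hypothetical, simplified class that only mirrors where dropout sits relative to the pre-logits features and fc; it is not the timm implementation:

import torch
import torch.nn as nn

class TinyNormHead(nn.Module):
    # Hypothetical, simplified stand-in (not the timm class): dropout sits
    # between the pre-logits features and the final fc, as in the patched head.
    def __init__(self, in_features, num_classes, drop_rate=0.0):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.norm = nn.LayerNorm(in_features)
        self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(drop_rate)
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x, pre_logits: bool = False):
        x = self.pool(x).flatten(1)
        x = self.norm(x)
        x = self.pre_logits(x)
        x = self.drop(x)  # the call that was missing before this fix
        if pre_logits:
            return x
        return self.fc(x)

head = TinyNormHead(16, 10, drop_rate=0.5)
x = torch.randn(4, 16, 7, 7)

head.train()
a, b = head(x), head(x)
assert not torch.equal(a, b)          # dropout active: repeated calls differ (w.h.p.)

head.eval()
assert torch.equal(head(x), head(x))  # dropout disabled: outputs deterministic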