From 8db20dc2401871b206280b4fd8de0a3ac94bf93f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 20 Mar 2023 09:37:05 -0700
Subject: [PATCH] Fix #1726, dropout not used in NormMlpClassifierHead. Make
 dropout more consistent across both classifier heads (nn.Dropout)

---
 timm/layers/classifier.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py
index a95a4dfe..78adbf9a 100644
--- a/timm/layers/classifier.py
+++ b/timm/layers/classifier.py
@@ -88,18 +88,20 @@ class ClassifierHead(nn.Module):
             drop_rate: Pre-classifier dropout rate.
         """
         super(ClassifierHead, self).__init__()
-        self.drop_rate = drop_rate
         self.in_features = in_features
         self.use_conv = use_conv
         self.input_fmt = input_fmt
 
-        self.global_pool, self.fc = create_classifier(
+        global_pool, fc = create_classifier(
             in_features,
             num_classes,
             pool_type,
             use_conv=use_conv,
             input_fmt=input_fmt,
         )
+        self.global_pool = global_pool
+        self.drop = nn.Dropout(drop_rate)
+        self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
     def reset(self, num_classes, pool_type=None):
@@ -122,8 +124,7 @@ class ClassifierHead(nn.Module):
 
     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
-        if self.drop_rate:
-            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+        x = self.drop(x)
         if pre_logits:
             return self.flatten(x)
         x = self.fc(x)
@@ -153,7 +154,6 @@ class NormMlpClassifierHead(nn.Module):
             act_layer: MLP activation layer type (only used if hidden_size is not None).
         """
         super().__init__()
-        self.drop_rate = drop_rate
         self.in_features = in_features
         self.hidden_size = hidden_size
         self.num_features = in_features
@@ -173,7 +173,7 @@ class NormMlpClassifierHead(nn.Module):
             self.num_features = hidden_size
         else:
             self.pre_logits = nn.Identity()
-        self.drop = nn.Dropout(self.drop_rate)
+        self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def reset(self, num_classes, global_pool=None):
@@ -197,6 +197,7 @@ class NormMlpClassifierHead(nn.Module):
         x = self.norm(x)
         x = self.flatten(x)
         x = self.pre_logits(x)
+        x = self.drop(x)
         if pre_logits:
             return x
         x = self.fc(x)
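
For reference, not part of the patch: a minimal sketch of how the change can be exercised. It assumes a timm install with the patched timm/layers/classifier.py, and uses only constructor arguments visible in the diff above; tensor shapes and drop_rate values are arbitrary example choices.

# Minimal sketch (assumes the patched timm/layers/classifier.py is importable).
import torch
from timm.layers.classifier import ClassifierHead, NormMlpClassifierHead

x = torch.randn(2, 512, 7, 7)

# ClassifierHead now holds an nn.Dropout module instead of conditionally calling
# F.dropout in forward(); with drop_rate > 0, two training-mode passes should differ.
head = ClassifierHead(in_features=512, num_classes=10, pool_type='avg', drop_rate=0.5)
head.train()
print(torch.equal(head(x), head(x)))  # expected: False

# NormMlpClassifierHead previously constructed self.drop but never applied it (#1726);
# after the patch, dropout runs between pre_logits and fc.
norm_head = NormMlpClassifierHead(in_features=512, num_classes=10, drop_rate=0.5)
norm_head.train()
print(torch.equal(norm_head(x), norm_head(x)))  # expected: False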