diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 9a03f7df..289c83f3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -131,7 +131,7 @@ class EfficientNet(nn.Module): else: self.conv_head = nn.Identity() self.bn2 = nn.Identity() - self.num_features = head_chs + self.num_features = self.head_hidden_size = head_chs self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool)