diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index b7c0f5dd..d34b5485 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -276,7 +276,7 @@ class GhostNet(nn.Module): # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.conv_stem(x)