Fix in_features for linear layer in reset_classifier.

This commit is contained in:
Thorsten Hempel 2023-09-13 10:38:13 +02:00 committed by Ross Wightman
parent 730b907b4d
commit 7eb7d13845

View File

@ -276,7 +276,7 @@ class GhostNet(nn.Module):
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)