fix reset head in hgnet

This commit is contained in:
SeeFun 2023-12-27 20:11:29 +08:00 committed by GitHub
parent 6862c9850a
commit 56ae8b906d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -470,8 +470,8 @@ class PPHGNet(nn.Module):
class_expand=self.class_expand,
use_lab=self.use_lab)
else:
if self.global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
if global_pool == 'avg':
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
else:
self.head = nn.Identity()
@ -480,7 +480,7 @@ class PPHGNet(nn.Module):
return self.stages(x)
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)