diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index ce2786bf..3d25e8c8 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -470,8 +470,8 @@ class PPHGNet(nn.Module): class_expand=self.class_expand, use_lab=self.use_lab) else: - if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) + if global_pool == 'avg': + self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) else: self.head = nn.Identity() @@ -480,7 +480,7 @@ class PPHGNet(nn.Module): return self.stages(x) def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x)