mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix reset head in hgnet
This commit is contained in:
parent
6862c9850a
commit
56ae8b906d
@ -470,8 +470,8 @@ class PPHGNet(nn.Module):
|
||||
class_expand=self.class_expand,
|
||||
use_lab=self.use_lab)
|
||||
else:
|
||||
if self.global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
|
||||
if global_pool == 'avg':
|
||||
self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
else:
|
||||
self.head = nn.Identity()
|
||||
|
||||
@ -480,7 +480,7 @@ class PPHGNet(nn.Module):
|
||||
return self.stages(x)
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
return self.head(x, pre_logits=pre_logits)
|
||||
return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user