mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update metaformers.py
This commit is contained in:
parent
f400e8a3c9
commit
1b1b1d83b4
@ -706,13 +706,16 @@ class MetaFormer(nn.Module):
|
||||
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.norm = self.output_norm(self.num_features)
|
||||
|
||||
'''
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
|
||||
if head_dropout > 0.0:
|
||||
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
||||
else:
|
||||
self.head = self.head_fn(self.num_features, self.num_classes)
|
||||
|
||||
'''
|
||||
self.reset_classifier(self.num_classes, global_pool)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@ -742,9 +745,9 @@ class MetaFormer(nn.Module):
|
||||
else:
|
||||
self.norm = self.output_norm(self.num_features)
|
||||
if self.head_dropout > 0.0:
|
||||
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
||||
self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout)
|
||||
else:
|
||||
self.head = self.head_fn(self.num_features, self.num_classes)
|
||||
self.head = self.head_fn(self.num_features, num_classes)
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if pre_logits:
|
||||
|
Loading…
x
Reference in New Issue
Block a user