mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update metaformers.py
This commit is contained in:
parent
61e8414ad0
commit
199b443884
@ -26,7 +26,7 @@ from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import trunc_normal_, DropPath
|
||||
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d
|
||||
from timm.layers.helpers import to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import FeatureInfo
|
||||
@ -640,6 +640,7 @@ class MetaFormer(nn.Module):
|
||||
res_scale_init_values=[None, None, 1.0, 1.0],
|
||||
output_norm=partial(nn.LayerNorm, eps=1e-6),
|
||||
head_fn=nn.Linear,
|
||||
global_pool = 'avg',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@ -705,7 +706,7 @@ class MetaFormer(nn.Module):
|
||||
self.stages = nn.Sequential(*stages)
|
||||
self.norm = output_norm(dims[-1])
|
||||
|
||||
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
|
||||
if head_dropout > 0.0:
|
||||
self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout)
|
||||
@ -731,7 +732,7 @@ class MetaFormer(nn.Module):
|
||||
def reset_classifier(self, num_classes=0, global_pool=None):
|
||||
|
||||
if num_classes == 0:
|
||||
self.head= nn.Identity()
|
||||
self.head = nn.Identity()
|
||||
self.norm = nn.Identity()
|
||||
else:
|
||||
if self.head_dropout > 0.0:
|
||||
@ -743,7 +744,7 @@ class MetaFormer(nn.Module):
|
||||
if pre_logits:
|
||||
return x
|
||||
|
||||
x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean
|
||||
x = self.global_pool(x)
|
||||
x = self.norm(x)
|
||||
# (B, H, W, C) -> (B, C)
|
||||
x = self.head(x)
|
||||
|
Loading…
x
Reference in New Issue
Block a user