mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

More model and test fixes

commit 8c9696c9df
parent ca52108c2b

@@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'coatne?t_*', 'max?vit_*',
+]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
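
Note on the two new exclusion patterns (this hunk appears to be tests/test_models.py, whose name filters use fnmatch semantics): '?' matches exactly one character, so 'coatne?t_*' matches the 'coatnext_*' spelling and 'max?vit_*' matches 'maxxvit_*', while single-x 'maxvit_*' names are one character short of the pattern. A quick sketch with hypothetical model names, not taken from this diff:

    import fnmatch

    names = ['coatnext_nano_rw_224', 'maxxvit_rmlp_nano_rw_256', 'maxvit_tiny_rw_224']
    print(fnmatch.filter(names, 'coatne?t_*'))  # ['coatnext_nano_rw_224']
    print(fnmatch.filter(names, 'max?vit_*'))   # ['maxxvit_rmlp_nano_rw_256'], not the single-x name
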

@@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'first_conv': 'stem.conv1', 'classifier': 'head.fc',
         'fixed_input_size': True,
         **kwargs
     }
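
This looks like the GCViT default pretrained config (the Stem hunks below are gcvit): 'first_conv' tells timm's helpers which convolution to adapt when loading pretrained weights with non-RGB inputs, so it must name a module that exists, and the stem stores its first conv as conv1 (see the Stem class further down). A hedged sanity check, assuming a timm of this vintage with the gcvit models registered:

    import timm

    model = timm.create_model('gcvit_tiny', pretrained=False)
    # the default_cfg path should resolve to a real submodule
    print(model.get_submodule(model.default_cfg['first_conv']))  # expect the stem's nn.Conv2d
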

@@ -106,7 +106,7 @@ class Downsample2d(nn.Module):
             dim_out=None,
             reduction='conv',
             act_layer=nn.GELU,
-            norm_layer=LayerNorm2d,
+            norm_layer=LayerNorm2d,  # NOTE in NCHW
     ):
         super().__init__()
         dim_out = dim_out or dim

@@ -163,12 +163,10 @@ class Stem(nn.Module):
             self,
             in_chs: int = 3,
             out_chs: int = 96,
-            act_layer: str = 'gelu',
-            norm_layer: str = 'layernorm2d',  # NOTE norm for NCHW
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm2d,  # NOTE stem in NCHW
     ):
         super().__init__()
-        act_layer = get_act_layer(act_layer)
-        norm_layer = get_norm_layer(norm_layer)
         self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
         self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
 
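
The Stem (and the stage class in the next hunk) switches from string-named layers to concrete Callables; the get_act_layer/get_norm_layer lookups move up into the top-level GlobalContextVit (later hunk), so they run once instead of in every submodule. A minimal sketch of the factory behavior this relies on, assuming timm's layer factories of this era:

    import torch.nn as nn
    from timm.models.layers import get_act_layer, get_norm_layer

    act = get_act_layer('gelu')     # string resolved to nn.GELU
    same = get_act_layer(nn.GELU)   # callables pass through unchanged
    norm = get_norm_layer('layernorm2d')
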

@@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: Union[List[float], float] = 0.0,
-            act_layer: str = 'gelu',
-            norm_layer: str = 'layernorm2d',
-            norm_layer_cl: str = 'layernorm',
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = nn.LayerNorm,
+            norm_layer_cl: Callable = LayerNorm2d,
     ):
         super().__init__()
-        act_layer = get_act_layer(act_layer)
-        norm_layer = get_norm_layer(norm_layer)
-        norm_layer_cl = get_norm_layer(norm_layer_cl)
-
         if downsample:
             self.downsample = Downsample2d(
                 dim=dim,

@@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module):
             act_layer: str = 'gelu',
             norm_layer: str = 'layernorm2d',
             norm_layer_cl: str = 'layernorm',
+            norm_eps: float = 1e-5,
     ):
         super().__init__()
+        act_layer = get_act_layer(act_layer)
+        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
+        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
+
         img_size = to_2tuple(img_size)
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
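
The new norm_eps is bound once via functools.partial, so every norm the stages build shares one epsilon while downstream code keeps passing only the dimension. The pattern in isolation, with nn.LayerNorm standing in for the resolved norm class:

    from functools import partial
    import torch.nn as nn

    norm_layer = partial(nn.LayerNorm, eps=1e-5)  # eps pre-bound
    norm = norm_layer(96)                         # callers supply only the width
    assert norm.eps == 1e-5
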

@@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module):
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
 
         self.stem = Stem(
-            in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
+            in_chs=in_chans,
+            out_chs=embed_dim,
+            act_layer=act_layer,
+            norm_layer=norm_layer
+        )
 
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
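
The dpr context line above is the usual linear stochastic-depth schedule: per-block drop-path rates ramp from 0 to drop_path_rate across all blocks, then split back into per-stage lists. A small worked example (the depths are illustrative, not from this diff):

    import torch

    depths, drop_path_rate = (2, 2, 6, 2), 0.2
    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    # dpr[0] starts at 0.0; the last entry of dpr[-1] is the full drop_path_rate
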

@@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict(
     batchnorm=BatchNormAct2d,
     batchnorm2d=BatchNormAct2d,
     groupnorm=GroupNormAct,
+    groupnorm1=functools.partial(GroupNormAct, num_groups=1),
     layernorm=LayerNormAct,
     layernorm2d=LayerNormAct2d,
     evonormb0=EvoNorm2dB0,

@@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
             norm_act_layer = BatchNormAct2d
         elif type_name.startswith('groupnorm'):
             norm_act_layer = GroupNormAct
+        elif type_name.startswith('groupnorm1'):
+            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
         elif type_name.startswith('layernorm2d'):
             norm_act_layer = LayerNormAct2d
         elif type_name.startswith('layernorm'):
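
These two hunks appear to extend timm's norm+act factory: strings resolve through the _NORM_ACT_MAP dict, while norm classes are matched on their lowercased __name__ via startswith. One hedged observation: any name starting with 'groupnorm1' also starts with 'groupnorm', so the new elif branch is shadowed by the branch above it in this type-name path; in practice 'groupnorm1' resolves through the dict entry added in the previous hunk. Usage sketch, assuming timm's layer exports of this era:

    import torch
    import torch.nn as nn
    from timm.models.layers import get_norm_act_layer

    norm_act_cls = get_norm_act_layer('layernorm2d', act_layer=nn.ReLU)
    layer = norm_act_cls(64)             # LayerNormAct2d over 64 channels, ReLU after the norm
    y = layer(torch.randn(2, 64, 8, 8))  # NCHW in, NCHW out
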

@@ -226,6 +226,7 @@ class LayerNormAct2d(nn.LayerNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()
 
     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
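
LayerNormAct2d keeps plain nn.LayerNorm parameters but takes NCHW input, hence the permute in forward; the cached is_fast_norm() flag lets forward route through timm's fast-norm path when that mode is enabled. The permute dance spelled out functionally (a sketch, not the class's exact code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 64, 8, 8)                # NCHW
    w, b = torch.ones(64), torch.zeros(64)
    y = x.permute(0, 2, 3, 1)                   # NHWC: LayerNorm normalizes the last dim
    y = F.layer_norm(y, (64,), w, b, eps=1e-6)
    y = y.permute(0, 3, 1, 2)                   # back to NCHW
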

@@ -24,6 +24,7 @@ import torch.utils.checkpoint as checkpoint
 from torch import nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg
 from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from .registry import register_model

@@ -35,7 +36,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
+        'fixed_input_size': True,
         **kwargs
     }
 
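
In this _cfg (mvitv2, judging by the imports above), 'classifier' moves from 'head' to 'head.fc': the head is evidently now a submodule whose final Linear lives at .fc, and timm's helpers use this path to locate and remap the classifier, e.g. when a model is created with a different num_classes. A hedged check, assuming the mvitv2 models are registered:

    import timm

    m = timm.create_model('mvitv2_tiny', pretrained=False, num_classes=10)
    print(m.get_submodule(m.default_cfg['classifier']))  # expect Linear(..., out_features=10)
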

@@ -169,6 +171,7 @@ class PatchEmbed(nn.Module):
         return x.flatten(2).transpose(1, 2), x.shape[-2:]
 
 
+@register_notrace_function
 def reshape_pre_pool(
         x,
         feat_size: List[int],

@@ -183,6 +186,7 @@ def reshape_pre_pool(
     return x, cls_tok
 
 
+@register_notrace_function
 def reshape_post_pool(
         x,
         num_heads: int,

@@ -196,6 +200,7 @@ def reshape_post_pool(
     return x, feat_size
 
 
+@register_notrace_function
 def cal_rel_pos_type(
         attn: torch.Tensor,
         q: torch.Tensor,
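
The three @register_notrace_function hunks mark MViT-V2's pooling and relative-position helpers as leaves for torch.fx: they do data-dependent shape arithmetic on feat_size that symbolic tracing cannot follow. The registration itself is simple bookkeeping; roughly (paraphrasing timm's fx_features, not verbatim):

    _autowrap_functions = set()

    def register_notrace_function(func):
        # collected functions are later handed to torch.fx as autowrap leaves
        _autowrap_functions.add(func)
        return func
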

@@ -36,7 +36,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
         **kwargs
     }
 
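
This last hunk (evidently PVT-V2's _cfg, given 'patch_embed.proj') corrects a first_conv path that pointed at a non-existent 'patch_embed.conv'; the overlapping patch embed stores its convolution as proj. The path matters when pretrained weights are adapted to a different in_chans; with random init it should simply resolve. Hedged check:

    import timm

    m = timm.create_model('pvt_v2_b0', pretrained=False, in_chans=1)
    print(m.patch_embed.proj)  # expect Conv2d(1, ...)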