From 8c9696c9df93d54ac17e0afadf5ef687f329fb8f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 22 Aug 2022 17:40:31 -0700
Subject: [PATCH] More model and test fixes

---
 tests/test_models.py                  |  4 +++-
 timm/models/gcvit.py                  | 31 +++++++++++++++------------
 timm/models/layers/create_norm_act.py |  3 +++
 timm/models/layers/norm_act.py        |  1 +
 timm/models/mvitv2.py                 |  7 +++++-
 timm/models/pvt_v2.py                 |  2 +-
 6 files changed, 31 insertions(+), 17 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 0f9b8c0b..5daee76d 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'coatne?t_*', 'max?vit_*',
+]
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py
index e7eccea8..bad40bd6 100644
--- a/timm/models/gcvit.py
+++ b/timm/models/gcvit.py
@@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        'first_conv': 'stem.conv1', 'classifier': 'head.fc',
         'fixed_input_size': True,
         **kwargs
     }
@@ -106,7 +106,7 @@ class Downsample2d(nn.Module):
             dim_out=None,
             reduction='conv',
             act_layer=nn.GELU,
-            norm_layer=LayerNorm2d,
+            norm_layer=LayerNorm2d,  # NOTE in NCHW
     ):
         super().__init__()
         dim_out = dim_out or dim
@@ -163,12 +163,10 @@ class Stem(nn.Module):
             self,
             in_chs: int = 3,
             out_chs: int = 96,
-            act_layer: str = 'gelu',
-            norm_layer: str = 'layernorm2d',  # NOTE norm for NCHW
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm2d,  # NOTE stem in NCHW
     ):
         super().__init__()
-        act_layer = get_act_layer(act_layer)
-        norm_layer = get_norm_layer(norm_layer)
         self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
         self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)

@@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: Union[List[float], float] = 0.0,
-            act_layer: str = 'gelu',
-            norm_layer: str = 'layernorm2d',
-            norm_layer_cl: str = 'layernorm',
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = nn.LayerNorm,
+            norm_layer_cl: Callable = LayerNorm2d,
     ):
         super().__init__()
-        act_layer = get_act_layer(act_layer)
-        norm_layer = get_norm_layer(norm_layer)
-        norm_layer_cl = get_norm_layer(norm_layer_cl)
-
         if downsample:
             self.downsample = Downsample2d(
                 dim=dim,
@@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module):
             act_layer: str = 'gelu',
             norm_layer: str = 'layernorm2d',
             norm_layer_cl: str = 'layernorm',
+            norm_eps: float = 1e-5,
     ):
         super().__init__()
+        act_layer = get_act_layer(act_layer)
+        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
+        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
+
         img_size = to_2tuple(img_size)
         feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
         self.global_pool = global_pool
@@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module):
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))

         self.stem = Stem(
-            in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
+            in_chs=in_chans,
+            out_chs=embed_dim,
+            act_layer=act_layer,
+            norm_layer=norm_layer
+        )

         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         stages = []
diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py
index cd15c2f8..78dd9a51 100644
--- a/timm/models/layers/create_norm_act.py
+++ b/timm/models/layers/create_norm_act.py
@@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict(
     batchnorm=BatchNormAct2d,
     batchnorm2d=BatchNormAct2d,
     groupnorm=GroupNormAct,
+    groupnorm1=functools.partial(GroupNormAct, num_groups=1),
     layernorm=LayerNormAct,
     layernorm2d=LayerNormAct2d,
     evonormb0=EvoNorm2dB0,
@@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
             norm_act_layer = BatchNormAct2d
         elif type_name.startswith('groupnorm'):
             norm_act_layer = GroupNormAct
+        elif type_name.startswith('groupnorm1'):
+            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
         elif type_name.startswith('layernorm2d'):
             norm_act_layer = LayerNormAct2d
         elif type_name.startswith('layernorm'):
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index be1edead..dc077160 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -226,6 +226,7 @@ class LayerNormAct2d(nn.LayerNorm):
             self.act = act_layer(**act_args)
         else:
             self.act = nn.Identity()
+        self._fast_norm = is_fast_norm()

     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index fc29f113..002225c6 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -24,6 +24,7 @@ import torch.utils.checkpoint as checkpoint
 from torch import nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg
 from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
 from .registry import register_model
@@ -35,7 +36,8 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
+        'fixed_input_size': True,
         **kwargs
     }

@@ -169,6 +171,7 @@ class PatchEmbed(nn.Module):
         return x.flatten(2).transpose(1, 2), x.shape[-2:]


+@register_notrace_function
 def reshape_pre_pool(
         x,
         feat_size: List[int],
@@ -183,6 +186,7 @@ def reshape_pre_pool(
     return x, cls_tok


+@register_notrace_function
 def reshape_post_pool(
         x,
         num_heads: int,
@@ -196,6 +200,7 @@ def reshape_post_pool(
     return x, feat_size


+@register_notrace_function
 def cal_rel_pos_type(
         attn: torch.Tensor,
         q: torch.Tensor,
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index ce4cbf56..dd3cf690 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -36,7 +36,7 @@ def _cfg(url='', **kwargs):
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.9, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
         **kwargs
     }