mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More model and test fixes
This commit is contained in:
parent
ca52108c2b
commit
8c9696c9df
@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
|||||||
NON_STD_FILTERS = [
|
NON_STD_FILTERS = [
|
||||||
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
|
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
|
||||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
|
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||||
|
'coatne?t_*', 'max?vit_*',
|
||||||
|
]
|
||||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||||
|
|
||||||
# exclude models that cause specific test failures
|
# exclude models that cause specific test failures
|
||||||
|
@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
||||||
'fixed_input_size': True,
|
'fixed_input_size': True,
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ class Downsample2d(nn.Module):
|
|||||||
dim_out=None,
|
dim_out=None,
|
||||||
reduction='conv',
|
reduction='conv',
|
||||||
act_layer=nn.GELU,
|
act_layer=nn.GELU,
|
||||||
norm_layer=LayerNorm2d,
|
norm_layer=LayerNorm2d, # NOTE in NCHW
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
dim_out = dim_out or dim
|
dim_out = dim_out or dim
|
||||||
@ -163,12 +163,10 @@ class Stem(nn.Module):
|
|||||||
self,
|
self,
|
||||||
in_chs: int = 3,
|
in_chs: int = 3,
|
||||||
out_chs: int = 96,
|
out_chs: int = 96,
|
||||||
act_layer: str = 'gelu',
|
act_layer: Callable = nn.GELU,
|
||||||
norm_layer: str = 'layernorm2d', # NOTE norm for NCHW
|
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act_layer = get_act_layer(act_layer)
|
|
||||||
norm_layer = get_norm_layer(norm_layer)
|
|
||||||
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
|
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
|
||||||
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
|
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
|
||||||
|
|
||||||
@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module):
|
|||||||
proj_drop: float = 0.,
|
proj_drop: float = 0.,
|
||||||
attn_drop: float = 0.,
|
attn_drop: float = 0.,
|
||||||
drop_path: Union[List[float], float] = 0.0,
|
drop_path: Union[List[float], float] = 0.0,
|
||||||
act_layer: str = 'gelu',
|
act_layer: Callable = nn.GELU,
|
||||||
norm_layer: str = 'layernorm2d',
|
norm_layer: Callable = nn.LayerNorm,
|
||||||
norm_layer_cl: str = 'layernorm',
|
norm_layer_cl: Callable = LayerNorm2d,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act_layer = get_act_layer(act_layer)
|
|
||||||
norm_layer = get_norm_layer(norm_layer)
|
|
||||||
norm_layer_cl = get_norm_layer(norm_layer_cl)
|
|
||||||
|
|
||||||
if downsample:
|
if downsample:
|
||||||
self.downsample = Downsample2d(
|
self.downsample = Downsample2d(
|
||||||
dim=dim,
|
dim=dim,
|
||||||
@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module):
|
|||||||
act_layer: str = 'gelu',
|
act_layer: str = 'gelu',
|
||||||
norm_layer: str = 'layernorm2d',
|
norm_layer: str = 'layernorm2d',
|
||||||
norm_layer_cl: str = 'layernorm',
|
norm_layer_cl: str = 'layernorm',
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
act_layer = get_act_layer(act_layer)
|
||||||
|
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
||||||
|
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
||||||
|
|
||||||
img_size = to_2tuple(img_size)
|
img_size = to_2tuple(img_size)
|
||||||
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
||||||
self.global_pool = global_pool
|
self.global_pool = global_pool
|
||||||
@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module):
|
|||||||
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
||||||
|
|
||||||
self.stem = Stem(
|
self.stem = Stem(
|
||||||
in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
|
in_chs=in_chans,
|
||||||
|
out_chs=embed_dim,
|
||||||
|
act_layer=act_layer,
|
||||||
|
norm_layer=norm_layer
|
||||||
|
)
|
||||||
|
|
||||||
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
||||||
stages = []
|
stages = []
|
||||||
|
@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict(
|
|||||||
batchnorm=BatchNormAct2d,
|
batchnorm=BatchNormAct2d,
|
||||||
batchnorm2d=BatchNormAct2d,
|
batchnorm2d=BatchNormAct2d,
|
||||||
groupnorm=GroupNormAct,
|
groupnorm=GroupNormAct,
|
||||||
|
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
|
||||||
layernorm=LayerNormAct,
|
layernorm=LayerNormAct,
|
||||||
layernorm2d=LayerNormAct2d,
|
layernorm2d=LayerNormAct2d,
|
||||||
evonormb0=EvoNorm2dB0,
|
evonormb0=EvoNorm2dB0,
|
||||||
@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
|
|||||||
norm_act_layer = BatchNormAct2d
|
norm_act_layer = BatchNormAct2d
|
||||||
elif type_name.startswith('groupnorm'):
|
elif type_name.startswith('groupnorm'):
|
||||||
norm_act_layer = GroupNormAct
|
norm_act_layer = GroupNormAct
|
||||||
|
elif type_name.startswith('groupnorm1'):
|
||||||
|
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
|
||||||
elif type_name.startswith('layernorm2d'):
|
elif type_name.startswith('layernorm2d'):
|
||||||
norm_act_layer = LayerNormAct2d
|
norm_act_layer = LayerNormAct2d
|
||||||
elif type_name.startswith('layernorm'):
|
elif type_name.startswith('layernorm'):
|
||||||
|
@ -226,6 +226,7 @@ class LayerNormAct2d(nn.LayerNorm):
|
|||||||
self.act = act_layer(**act_args)
|
self.act = act_layer(**act_args)
|
||||||
else:
|
else:
|
||||||
self.act = nn.Identity()
|
self.act = nn.Identity()
|
||||||
|
self._fast_norm = is_fast_norm()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.permute(0, 2, 3, 1)
|
x = x.permute(0, 2, 3, 1)
|
||||||
|
@ -24,6 +24,7 @@ import torch.utils.checkpoint as checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .fx_features import register_notrace_function
|
||||||
from .helpers import build_model_with_cfg
|
from .helpers import build_model_with_cfg
|
||||||
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
|
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
@ -35,7 +36,8 @@ def _cfg(url='', **kwargs):
|
|||||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
|
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||||
|
'fixed_input_size': True,
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -169,6 +171,7 @@ class PatchEmbed(nn.Module):
|
|||||||
return x.flatten(2).transpose(1, 2), x.shape[-2:]
|
return x.flatten(2).transpose(1, 2), x.shape[-2:]
|
||||||
|
|
||||||
|
|
||||||
|
@register_notrace_function
|
||||||
def reshape_pre_pool(
|
def reshape_pre_pool(
|
||||||
x,
|
x,
|
||||||
feat_size: List[int],
|
feat_size: List[int],
|
||||||
@ -183,6 +186,7 @@ def reshape_pre_pool(
|
|||||||
return x, cls_tok
|
return x, cls_tok
|
||||||
|
|
||||||
|
|
||||||
|
@register_notrace_function
|
||||||
def reshape_post_pool(
|
def reshape_post_pool(
|
||||||
x,
|
x,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
@ -196,6 +200,7 @@ def reshape_post_pool(
|
|||||||
return x, feat_size
|
return x, feat_size
|
||||||
|
|
||||||
|
|
||||||
|
@register_notrace_function
|
||||||
def cal_rel_pos_type(
|
def cal_rel_pos_type(
|
||||||
attn: torch.Tensor,
|
attn: torch.Tensor,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
|
@ -36,7 +36,7 @@ def _cfg(url='', **kwargs):
|
|||||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||||
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False,
|
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user