mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
add test_models
This commit is contained in:
parent
77788f4f92
commit
91e6e1737e
@ -53,13 +53,14 @@ FEAT_INTER_FILTERS = [
|
|||||||
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
|
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
|
||||||
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
|
||||||
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
|
||||||
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
|
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'swiftformer',
|
||||||
|
'starnet', 'shvit',
|
||||||
]
|
]
|
||||||
|
|
||||||
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
|
||||||
NON_STD_FILTERS = [
|
NON_STD_FILTERS = [
|
||||||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
|
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*',
|
||||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
|
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
|
||||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
|
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
|
||||||
]
|
]
|
||||||
|
@ -235,7 +235,7 @@ class StageBlock(nn.Module):
|
|||||||
PatchMerging(prev_dim, dim, act_layer),
|
PatchMerging(prev_dim, dim, act_layer),
|
||||||
Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
|
Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
|
||||||
Residule(FFN(dim, int(dim * 2), act_layer)),
|
Residule(FFN(dim, int(dim * 2), act_layer)),
|
||||||
) if prev_dim is not None else nn.Identity()
|
) if prev_dim != dim else nn.Identity()
|
||||||
|
|
||||||
self.block = nn.Sequential(*[
|
self.block = nn.Sequential(*[
|
||||||
BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
|
BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
|
||||||
@ -269,19 +269,20 @@ class SHViT(nn.Module):
|
|||||||
self.feature_info = []
|
self.feature_info = []
|
||||||
|
|
||||||
# Patch embedding
|
# Patch embedding
|
||||||
|
stem_chs = embed_dim[0]
|
||||||
self.patch_embed = nn.Sequential(
|
self.patch_embed = nn.Sequential(
|
||||||
Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1),
|
Conv2d_BN(in_chans, stem_chs // 8, 3, 2, 1),
|
||||||
act_layer(),
|
act_layer(),
|
||||||
Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1),
|
Conv2d_BN(stem_chs // 8, stem_chs // 4, 3, 2, 1),
|
||||||
act_layer(),
|
act_layer(),
|
||||||
Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1),
|
Conv2d_BN(stem_chs // 4, stem_chs // 2, 3, 2, 1),
|
||||||
act_layer(),
|
act_layer(),
|
||||||
Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1)
|
Conv2d_BN(stem_chs // 2, stem_chs, 3, 2, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build SHViT blocks
|
# Build SHViT blocks
|
||||||
blocks = []
|
blocks = []
|
||||||
prev_chs = None
|
prev_chs = stem_chs
|
||||||
for i in range(len(embed_dim)):
|
for i in range(len(embed_dim)):
|
||||||
blocks.append(StageBlock(
|
blocks.append(StageBlock(
|
||||||
prev_dim=prev_chs,
|
prev_dim=prev_chs,
|
||||||
|
@ -497,11 +497,7 @@ class SwiftFormer(nn.Module):
|
|||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_head(
|
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
pre_logits: bool = False,
|
|
||||||
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
|
||||||
if self.global_pool == 'avg':
|
if self.global_pool == 'avg':
|
||||||
x = x.mean(dim=(2, 3))
|
x = x.mean(dim=(2, 3))
|
||||||
x = self.head_drop(x)
|
x = self.head_drop(x)
|
||||||
@ -515,7 +511,7 @@ class SwiftFormer(nn.Module):
|
|||||||
# during standard train/finetune, inference average the classifier predictions
|
# during standard train/finetune, inference average the classifier predictions
|
||||||
return (x + x_dist) / 2
|
return (x + x_dist) / 2
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
def forward(self, x: torch.Tensor):
|
||||||
x = self.forward_features(x)
|
x = self.forward_features(x)
|
||||||
x = self.forward_head(x)
|
x = self.forward_head(x)
|
||||||
return x
|
return x
|
||||||
|
Loading…
x
Reference in New Issue
Block a user