diff --git a/tests/test_models.py b/tests/test_models.py
index 3ba3615d..5b23e86b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -53,13 +53,14 @@ FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*'
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'swiftformer',
+    'starnet', 'shvit',
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
 ]
diff --git a/timm/models/shvit.py b/timm/models/shvit.py
index 4cbaff69..541c8729 100644
--- a/timm/models/shvit.py
+++ b/timm/models/shvit.py
@@ -235,7 +235,7 @@ class StageBlock(nn.Module):
             PatchMerging(prev_dim, dim, act_layer),
             Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
             Residule(FFN(dim, int(dim * 2), act_layer)),
-        ) if prev_dim is not None else nn.Identity()
+        ) if prev_dim != dim else nn.Identity()
         self.block = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])
@@ -269,19 +269,20 @@ class SHViT(nn.Module):
         self.feature_info = []
 
         # Patch embedding
+        stem_chs = embed_dim[0]
         self.patch_embed = nn.Sequential(
-            Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1),
+            Conv2d_BN(in_chans, stem_chs // 8, 3, 2, 1),
             act_layer(),
-            Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1),
+            Conv2d_BN(stem_chs // 8, stem_chs // 4, 3, 2, 1),
             act_layer(),
-            Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1),
+            Conv2d_BN(stem_chs // 4, stem_chs // 2, 3, 2, 1),
             act_layer(),
-            Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1)
+            Conv2d_BN(stem_chs // 2, stem_chs, 3, 2, 1)
         )
 
         # Build SHViT blocks
         blocks = []
-        prev_chs = None
+        prev_chs = stem_chs
         for i in range(len(embed_dim)):
             blocks.append(StageBlock(
                 prev_dim=prev_chs,
diff --git a/timm/models/swiftformer.py b/timm/models/swiftformer.py
index 748569b4..f6eadc2a 100644
--- a/timm/models/swiftformer.py
+++ b/timm/models/swiftformer.py
@@ -497,11 +497,7 @@ class SwiftFormer(nn.Module):
         x = self.norm(x)
         return x
 
-    def forward_head(
-            self,
-            x: torch.Tensor,
-            pre_logits: bool = False,
-    ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
         if self.global_pool == 'avg':
             x = x.mean(dim=(2, 3))
         x = self.head_drop(x)
@@ -515,7 +511,7 @@ class SwiftFormer(nn.Module):
         # during standard train/finetune, inference average the classifier predictions
         return (x + x_dist) / 2
 
-    def forward(self, x: torch.Tensor) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+    def forward(self, x: torch.Tensor):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
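
Adding 'swiftformer', 'starnet', and 'shvit' to FEAT_INTER_FILTERS opts those families into the intermediate-feature tests. The following is a minimal sketch of the kind of call those tests exercise, assuming the standard timm forward_intermediates() signature and that the 'swiftformer_xs' config name is available (both are assumptions, not confirmed by this diff):

import torch
import timm

# Build a small model from one of the newly covered families (name assumed).
model = timm.create_model('swiftformer_xs', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# intermediates_only=True returns only the per-stage NCHW feature maps,
# roughly the behaviour the FEAT_INTER_FILTERS tests rely on.
feats = model.forward_intermediates(x, intermediates_only=True)
for f in feats:
    print(tuple(f.shape))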