diff --git a/tests/test_models.py b/tests/test_models.py
index d8ac8d64..e75b17f9 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -41,7 +41,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
 
diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index d7c01a0f..674e960c 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -661,14 +661,14 @@ def efficientvit_b0(pretrained=False, **kwargs):
 @register_model
 def efficientvit_b1(pretrained=False, **kwargs):
     model_args = dict(width_list=[16, 32, 64, 128, 256], depth_list=[1, 2, 3, 3, 4], dim=16, head_width_list=[1536, 1600])
-    return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
 
 @register_model
 def efficientvit_b2(pretrained=False, **kwargs):
     model_args = dict(width_list=[24, 48, 96, 192, 384], depth_list=[1, 3, 4, 4, 6], dim=32, head_width_list=[2304, 2560])
-    return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
 
 @register_model
 def efficientvit_b3(pretrained=False, **kwargs):
     model_args = dict(width_list=[32, 64, 128, 256, 512], depth_list=[1, 4, 6, 6, 9], dim=32, head_width_list=[2304, 2560])
-    return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py
index 57d5cf6d..ca93283e 100644
--- a/timm/models/efficientvit_msra.py
+++ b/timm/models/efficientvit_msra.py
@@ -373,8 +373,8 @@ class EfficientViTMSRA(nn.Module):
 
         self.stages = nn.Sequential(*stages)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
-        self.out_dims = embed_dim[-1]
-        self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
+        self.num_features = embed_dim[-1]
+        self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity()
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -396,7 +396,7 @@
         self.num_classes = num_classes
         if global_pool is not None:
             self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
-        self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
+        self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity()
 
     def forward_features(self, x):
         x = self.patch_embed(x)