mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix bug in ci for efficientvits
This commit is contained in:
parent
a56e2bbf19
commit
00f670fa69
@ -41,7 +41,7 @@ NON_STD_FILTERS = [
|
|||||||
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
|
||||||
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
|
||||||
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
|
||||||
'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
|
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*'
|
||||||
]
|
]
|
||||||
NUM_NON_STD = len(NON_STD_FILTERS)
|
NUM_NON_STD = len(NON_STD_FILTERS)
|
||||||
|
|
||||||
|
@ -661,14 +661,14 @@ def efficientvit_b0(pretrained=False, **kwargs):
|
|||||||
@register_model
|
@register_model
|
||||||
def efficientvit_b1(pretrained=False, **kwargs):
|
def efficientvit_b1(pretrained=False, **kwargs):
|
||||||
model_args = dict(width_list=[16, 32, 64, 128, 256], depth_list=[1, 2, 3, 3, 4], dim=16, head_width_list=[1536, 1600])
|
model_args = dict(width_list=[16, 32, 64, 128, 256], depth_list=[1, 2, 3, 3, 4], dim=16, head_width_list=[1536, 1600])
|
||||||
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def efficientvit_b2(pretrained=False, **kwargs):
|
def efficientvit_b2(pretrained=False, **kwargs):
|
||||||
model_args = dict(width_list=[24, 48, 96, 192, 384], depth_list=[1, 3, 4, 4, 6], dim=32, head_width_list=[2304, 2560])
|
model_args = dict(width_list=[24, 48, 96, 192, 384], depth_list=[1, 3, 4, 4, 6], dim=32, head_width_list=[2304, 2560])
|
||||||
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def efficientvit_b3(pretrained=False, **kwargs):
|
def efficientvit_b3(pretrained=False, **kwargs):
|
||||||
model_args = dict(width_list=[32, 64, 128, 256, 512], depth_list=[1, 4, 6, 6, 9], dim=32, head_width_list=[2304, 2560])
|
model_args = dict(width_list=[32, 64, 128, 256, 512], depth_list=[1, 4, 6, 6, 9], dim=32, head_width_list=[2304, 2560])
|
||||||
return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
|
return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||||
|
@ -373,8 +373,8 @@ class EfficientViTMSRA(nn.Module):
|
|||||||
self.stages = nn.Sequential(*stages)
|
self.stages = nn.Sequential(*stages)
|
||||||
|
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
||||||
self.out_dims = embed_dim[-1]
|
self.num_features = embed_dim[-1]
|
||||||
self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
|
self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def group_matcher(self, coarse=False):
|
def group_matcher(self, coarse=False):
|
||||||
@ -396,7 +396,7 @@ class EfficientViTMSRA(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
if global_pool is not None:
|
if global_pool is not None:
|
||||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW')
|
||||||
self.head = BNLinear(self.out_dims, num_classes) if num_classes > 0 else torch.nn.Identity()
|
self.head = BNLinear(self.num_features, num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||||
|
|
||||||
def forward_features(self, x):
|
def forward_features(self, x):
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user