fix unit test for samvit

方曦 2023-05-17 12:51:12 +08:00
parent ea1f52df3e
commit 15de561f2c
2 changed files with 11 additions and 11 deletions

View File

@@ -41,7 +41,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-'eva_*', 'flexivit*', 'eva02*'
+'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
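
Not part of the diff: a rough sketch of how a wildcard entry like 'samvit_*' in NON_STD_FILTERS excludes the renamed models, assuming fnmatch-style matching (the convention timm's model filters use).

# Illustrative sketch only (not from the commit): NON_STD_FILTERS entries
# are treated as fnmatch-style wildcards, so 'samvit_*' covers every
# samvit variant introduced by this rename.
import fnmatch

filters = ['eva_*', 'flexivit*', 'eva02*', 'samvit_*']
names = ['samvit_base_patch16', 'samvit_huge_patch16', 'resnet50']

excluded = [n for n in names if any(fnmatch.fnmatch(n, f) for f in filters)]
print(excluded)  # ['samvit_base_patch16', 'samvit_huge_patch16']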

View File

@@ -303,7 +303,7 @@ def add_decomposed_rel_pos(
class VisionTransformerSAM(nn.Module):
""" Vision Transformer for vitsam or SAM
""" Vision Transformer for Segment-Anything Model(SAM)
A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection` or `Segment Anything Model (SAM)`
- https://arxiv.org/abs/2010.11929
@@ -533,19 +533,19 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
-'vitsam_base_patch16.sa1b': _cfg(
+'samvit_base_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
-'vitsam_large_patch16.sa1b': _cfg(
+'samvit_large_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0),
-'vitsam_huge_patch16.sa1b': _cfg(
+'samvit_huge_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
hf_hub_id='timm/',
license='apache-2.0',
@@ -569,7 +569,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
@register_model
-def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-B/16 for Segment-Anything
"""
model_args = dict(
@@ -577,12 +577,12 @@ def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
-'vitsam_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
-def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-L/16 for Segment-Anything
"""
model_args = dict(
@@ -590,12 +590,12 @@ def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
-'vitsam_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
-def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
""" ViT-H/16 for Segment-Anything
"""
model_args = dict(
@@ -603,7 +603,7 @@ def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
window_size=14, use_rel_pos=True, img_size=1024,
)
model = _create_vision_transformer(
-'vitsam_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
# TODO:
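
Usage note, not part of the diff: after this commit the SAM backbones register under the samvit_* names. A minimal sketch, assuming a timm build that includes this rename; the printed shape depends on the model's neck/head configuration.

# Minimal sketch (assumption, not from the diff): build the renamed SAM
# ViT backbone by its new name. pretrained=False avoids downloading the
# SA-1B checkpoint; the model is configured for 1024x1024 inputs (img_size=1024).
import torch
import timm

model = timm.create_model('samvit_base_patch16', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 1024, 1024))
print(out.shape)  # exact shape depends on the neck/head configuration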