mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

fix unit test for samvit

This commit is contained in:
parent ea1f52df3e
commit 15de561f2c
@@ -41,7 +41,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
 ]
 
 NUM_NON_STD = len(NON_STD_FILTERS)
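Note: the patterns in NON_STD_FILTERS (in timm's test suite) are shell-style wildcards; models matching them are routed onto a reduced, non-standard test path, which is what this one-line addition fixes for the new SAM ViTs, since they are headless 1024x1024 feature extractors. A minimal sketch of how such a filter list is applied, assuming fnmatch-style matching (the helper below is illustrative, not timm's actual test code):

import fnmatch

# Abridged copy of the list above; each entry is an fnmatch-style wildcard.
NON_STD_FILTERS = ['vit_*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*']

def is_non_std(model_name: str) -> bool:
    # True if any pattern matches, i.e. the model takes the non-standard test path.
    return any(fnmatch.fnmatch(model_name, pat) for pat in NON_STD_FILTERS)

assert is_non_std('samvit_base_patch16')   # newly covered by 'samvit_*'
assert not is_non_std('resnet50')          # standard model, full test coverage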
@@ -303,7 +303,7 @@ def add_decomposed_rel_pos(
 
 
 class VisionTransformerSAM(nn.Module):
-    """ Vision Transformer for vitsam or SAM
+    """ Vision Transformer for Segment-Anything Model (SAM)
 
     A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection` or `Segment Anything Model (SAM)`
         - https://arxiv.org/abs/2010.11929
@@ -533,19 +533,19 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs({
 
     # Segment-Anything Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
-    'vitsam_base_patch16.sa1b': _cfg(
+    'samvit_base_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_large_patch16.sa1b': _cfg(
+    'samvit_large_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_huge_patch16.sa1b': _cfg(
+    'samvit_huge_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
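Each 'model.tag' key above binds an architecture to a weight source and its preprocessing; num_classes=0 marks the SA-1B checkpoints as feature-only. A hedged sketch of how the renamed entries are consumed through timm's public API (assumes a timm build containing this commit; pretrained=True downloads the SAM checkpoint from the URL in the cfg):

import timm

# After the rename, model discovery happens under the samvit_ prefix.
print(timm.list_models('samvit*'))
# expected: ['samvit_base_patch16', 'samvit_huge_patch16', 'samvit_large_patch16']

# The '.sa1b' tag selects the Segment-Anything (SA-1B) weights defined above;
# num_classes=0 in the cfg means the model is built without a classifier head.
model = timm.create_model('samvit_base_patch16.sa1b', pretrained=True)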
@@ -569,7 +569,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
 
 
 @register_model
-def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-B/16 for Segment-Anything
     """
     model_args = dict(
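The function rename has to track the cfg rename because @register_model keys timm's model registry on the entrypoint's __name__, and pretrained cfg keys are matched against that same name; this commit therefore renames the cfg keys, the entrypoints, and the test filter together. A toy stand-in for the registration mechanism (not timm's actual implementation, which also wires in the pretrained cfgs):

# Hypothetical minimal registry illustrating the name->builder mapping.
_model_registry = {}

def register_model(fn):
    _model_registry[fn.__name__] = fn   # keyed on the function name
    return fn

@register_model
def samvit_base_patch16(pretrained=False, **kwargs):
    return f'would build samvit_base_patch16 (pretrained={pretrained})'

# A create_model-style lookup resolves purely by registered name:
print(_model_registry['samvit_base_patch16']())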
@@ -577,12 +577,12 @@ def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-L/16 for Segment-Anything
     """
     model_args = dict(
@@ -590,12 +590,12 @@ def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
 @register_model
-def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-H/16 for Segment-Anything
     """
     model_args = dict(
@@ -603,7 +603,7 @@ def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 # TODO:
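End to end, the renamed entrypoints can be smoke-tested roughly the way the unit tests exercise non-standard models; a hedged sketch (pretrained=False avoids any download; the 1024x1024 input comes from the cfg's input_size, and the exact output shape depends on the SAM neck, so it is printed rather than asserted):

import timm
import torch

# Build the renamed model without weights; num_classes=0 mirrors the feature-only cfgs.
model = timm.create_model('samvit_base_patch16', pretrained=False, num_classes=0)
model.eval()

x = torch.randn(1, 3, 1024, 1024)   # input_size=(3, 1024, 1024), crop_pct=1.0
with torch.no_grad():
    features = model(x)
print(features.shape)   # headless feature output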