mirror of
https://github.com/huggingface/pytorch-image-models.git
fix unit test for samvit
commit 15de561f2c
parent ea1f52df3e
@@ -41,7 +41,7 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)
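For context, a minimal sketch of why the new filter entry matters, assuming the test suite applies these glob patterns through `timm.list_models(exclude_filters=...)` with fnmatch-style matching; the snippet and its abbreviated filter list are illustrative, not the actual test code:

import fnmatch
import timm

# Abbreviated for illustration; the full list lives in the test module.
NON_STD_FILTERS = ['vit_*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*']

# Excluding 'samvit_*' keeps the SAM backbones out of tests that assume the
# standard classifier interface (these are headless 1024x1024 models).
std_models = timm.list_models(exclude_filters=NON_STD_FILTERS)
assert not any(fnmatch.fnmatch(name, 'samvit_*') for name in std_models)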
@@ -303,7 +303,7 @@ def add_decomposed_rel_pos(


 class VisionTransformerSAM(nn.Module):
-    """ Vision Transformer for vitsam or SAM
+    """ Vision Transformer for Segment-Anything Model(SAM)

     A PyTorch impl of : `Exploring Plain Vision Transformer Backbones for Object Detection` or `Segment Anything Model (SAM)`
         - https://arxiv.org/abs/2010.11929
@@ -533,19 +533,19 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs({

     # Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
-    'vitsam_base_patch16.sa1b': _cfg(
+    'samvit_base_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_large_patch16.sa1b': _cfg(
+    'samvit_large_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 1024, 1024), crop_pct=1.0),
-    'vitsam_huge_patch16.sa1b': _cfg(
+    'samvit_huge_patch16.sa1b': _cfg(
         url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
         hf_hub_id='timm/',
         license='apache-2.0',
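The renamed pretrained tags are headless (num_classes=0) feature backbones at 1024x1024 input. A hedged usage sketch, assuming the tag resolves through `timm.create_model` the same way as other timm models:

import torch
import timm

# The SA-1B weights ship without a classifier head, so this yields features only.
model = timm.create_model('samvit_base_patch16.sa1b', pretrained=True)
model.eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 1024, 1024))  # 1024x1024 input per the cfg
print(features.shape)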
@@ -569,7 +569,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):


 @register_model
-def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-B/16 for Segment-Anything
     """
     model_args = dict(
@@ -577,12 +577,12 @@ def vitsam_base_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
-def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-L/16 for Segment-Anything
     """
     model_args = dict(
@@ -590,12 +590,12 @@ def vitsam_large_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_large_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
-def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
+def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     """ ViT-H/16 for Segment-Anything
     """
     model_args = dict(
@@ -603,7 +603,7 @@ def vitsam_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
         window_size=14, use_rel_pos=True, img_size=1024,
     )
     model = _create_vision_transformer(
-        'vitsam_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
+        'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

 # TODO:
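A small sanity check consistent with the rename, assuming `timm.list_models` reflects the `@register_model` registry (a hypothetical check, not part of the commit):

import timm

# Only the new samvit_* names should be registered after this change.
assert 'samvit_base_patch16' in timm.list_models('samvit_*')
assert not timm.list_models('vitsam_*')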