mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix loading pretrained weight for samvit
This commit is contained in:
parent
15de561f2c
commit
c1c6eeb909
@ -512,11 +512,11 @@ def checkpoint_filter_fn(
|
||||
""" Remap SAM checkpoints -> timm """
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'image_encoder.' in k:
|
||||
new_k = k.replace('image_encoder.', '')
|
||||
new_k = new_k.replace('mlp.lin', 'mlp.fc')
|
||||
out_dict[new_k] = v
|
||||
return state_dict
|
||||
if 'image_encoder.' in k:
|
||||
new_k = k.replace('image_encoder.', '')
|
||||
new_k = new_k.replace('mlp.lin', 'mlp.fc')
|
||||
out_dict[new_k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
@ -535,19 +535,19 @@ default_cfgs = generate_default_cfgs({
|
||||
# Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
|
||||
'samvit_base_patch16.sa1b': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
|
||||
hf_hub_id='timm/',
|
||||
# hf_hub_id='timm/',
|
||||
license='apache-2.0',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
|
||||
input_size=(3, 1024, 1024), crop_pct=1.0),
|
||||
'samvit_large_patch16.sa1b': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
|
||||
hf_hub_id='timm/',
|
||||
# hf_hub_id='timm/',
|
||||
license='apache-2.0',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
|
||||
input_size=(3, 1024, 1024), crop_pct=1.0),
|
||||
'samvit_huge_patch16.sa1b': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
|
||||
hf_hub_id='timm/',
|
||||
# hf_hub_id='timm/',
|
||||
license='apache-2.0',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
|
||||
input_size=(3, 1024, 1024), crop_pct=1.0),
|
||||
|
Loading…
x
Reference in New Issue
Block a user