fix loading pretrained weight for samvit

This commit is contained in:
方曦 2023-05-18 08:49:29 +08:00
parent 15de561f2c
commit c1c6eeb909

View File

@ -512,11 +512,11 @@ def checkpoint_filter_fn(
""" Remap SAM checkpoints -> timm """ """ Remap SAM checkpoints -> timm """
out_dict = {} out_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'image_encoder.' in k: if 'image_encoder.' in k:
new_k = k.replace('image_encoder.', '') new_k = k.replace('image_encoder.', '')
new_k = new_k.replace('mlp.lin', 'mlp.fc') new_k = new_k.replace('mlp.lin', 'mlp.fc')
out_dict[new_k] = v out_dict[new_k] = v
return state_dict return out_dict
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -535,19 +535,19 @@ default_cfgs = generate_default_cfgs({
# Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only) # Segment-Anyhing Model (SAM) pretrained - https://github.com/facebookresearch/segment-anything (no classifier head, for fine-tune/features only)
'samvit_base_patch16.sa1b': _cfg( 'samvit_base_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth', url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
hf_hub_id='timm/', # hf_hub_id='timm/',
license='apache-2.0', license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0), input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_large_patch16.sa1b': _cfg( 'samvit_large_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth', url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
hf_hub_id='timm/', # hf_hub_id='timm/',
license='apache-2.0', license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0), input_size=(3, 1024, 1024), crop_pct=1.0),
'samvit_huge_patch16.sa1b': _cfg( 'samvit_huge_patch16.sa1b': _cfg(
url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', url='https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
hf_hub_id='timm/', # hf_hub_id='timm/',
license='apache-2.0', license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 1024, 1024), crop_pct=1.0), input_size=(3, 1024, 1024), crop_pct=1.0),