Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Support loading of paligemma weights into GAP variants of SigLIP ViT. Minor tweak to npz loading for packed transformer weights.
This commit is contained in:
parent 04462f554f
commit 7b3b11b63f
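For orientation, a minimal usage sketch of what this commit enables, assuming a timm build that includes it (the model/tag names come from the cfg additions further down in the diff):

import timm

# PaliGemma vision-tower weights loaded into the GAP (global-average-pool)
# SigLIP SO400M ViT. The .pali_pt tag points at an .npz checkpoint on the
# Hugging Face Hub and is routed through the new custom_load='hf' path
# added in this commit.
model = timm.create_model('vit_so400m_patch14_siglip_gap_224.pali_pt', pretrained=True)
model.eval()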
@@ -10,7 +10,8 @@ from torch.hub import load_state_dict_from_url
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
+    load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -185,7 +186,12 @@ def load_pretrained(
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
-            state_dict = load_state_dict_from_hf(*pretrained_loc)
+            custom_load = pretrained_cfg.get('custom_load', False)
+            if isinstance(custom_load, str) and custom_load == 'hf':
+                load_custom_from_hf(*pretrained_loc, model)
+                return
+            else:
+                state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc)
     else:
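Reading the hunk above (timm's model builder, presumably timm/models/_builder.py): when a hub checkpoint is given as an (hf_hub_id, hf_hub_filename) pair and the pretrained cfg sets custom_load='hf', load_pretrained now hands the downloaded file to the model's own loader instead of reading a PyTorch state dict. A simplified, hypothetical sketch of that dispatch (the helper name _hf_hub_dispatch is illustrative only, not part of timm):

def _hf_hub_dispatch(model, pretrained_cfg, pretrained_loc):
    # pretrained_loc is (hf_hub_id, hf_hub_filename) for hub checkpoints that
    # name an explicit file, e.g. ('google/paligemma-3b-pt-224-jax', 'paligemma-3b-pt-224.npz').
    if isinstance(pretrained_loc, (list, tuple)):
        if pretrained_cfg.get('custom_load', False) == 'hf':
            # Non-PyTorch checkpoint (.npz): let the model parse it itself.
            load_custom_from_hf(*pretrained_loc, model)
            return None
        return load_state_dict_from_hf(*pretrained_loc)
    # Plain hub id: default PyTorch weights file.
    return load_state_dict_from_hf(pretrained_loc)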
@@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     return torch.load(cached_file, map_location='cpu')
 
 
+def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+    assert has_hf_hub(True)
+    hf_model_id, hf_revision = hf_split(model_id)
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    return model.load_pretrained(cached_file)
+
+
 def save_config_for_hf(
         model,
         config_path: str,
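The new load_custom_from_hf helper (added next to load_state_dict_from_hf, presumably in timm/models/_hub.py) downloads the named file from the Hub and delegates parsing to model.load_pretrained(), which for VisionTransformer ends up in the _load_weights .npz reader patched below. A hedged sketch of calling it directly (normally it is reached via create_model with pretrained=True; the PaliGemma repos may require accepting a license on the Hub):

import timm
from timm.models._hub import load_custom_from_hf  # available after this commit

model = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False, num_classes=0)
# Downloads the PaliGemma .npz from the Hub and lets the ViT's own
# load_pretrained() parse the JAX-style arrays.
load_custom_from_hf('google/paligemma-3b-pt-224-jax', 'paligemma-3b-pt-224.npz', model)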
@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np
 
-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:
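The new idx argument lets _n2p (inside _load_weights in vision_transformer.py) pick a single block's slice out of an array that packs all transformer blocks along a leading depth axis, which is how the PaliGemma .npz checkpoints store the encoder ("packed transformer weights" in the commit message). A toy numpy illustration with assumed shapes:

import numpy as np

# A packed checkpoint stores one array per parameter with a leading depth axis,
# e.g. the LayerNorm scale for all 27 blocks at once (shapes are illustrative).
depth, dim = 27, 1152
packed_scale = np.ones((depth, dim), dtype=np.float32)

idx = 5                        # block index
per_block = packed_scale[idx]  # what _n2p(w, idx=idx) selects before converting
assert per_block.shape == (dim,)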
@@ -955,21 +957,28 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
 
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
 
 def _convert_openai_clip(
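The block loop now sniffs the checkpoint's key layout: classic exports keep one namespace per block (encoderblock_{i}/...), while the packed export uses a single encoderblock/ namespace with a leading depth dimension, in which case idx=i is passed so _n2p slices out block i. A small hypothetical helper restating that detection (block_key_and_idx is illustrative, not part of timm; the keys and shapes are assumptions):

# Key layouts the loader now handles:
#   per-block export:  'Transformer/encoderblock_0/LayerNorm_0/scale' -> shape (1152,)
#   packed export:     'Transformer/encoderblock/LayerNorm_0/scale'   -> shape (27, 1152)

def block_key_and_idx(w, prefix, i):
    """Mimic the detection in the hunk above: return (block_prefix, idx)."""
    if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
        return f'{prefix}Transformer/encoderblock/', i     # packed: slice with idx
    return f'{prefix}Transformer/encoderblock_{i}/', None  # per-block keys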
@@ -1769,6 +1778,44 @@ default_cfgs = {
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
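Each new default cfg maps a pretrained tag either to OpenCLIP PyTorch weights (the .webli tags) or to a PaliGemma JAX .npz on the Hub; the .npz entries set custom_load='hf' so they take the new loader path. A quick check, assuming this commit is installed:

import timm

# The new tags should appear alongside the existing SigLIP entries.
print(timm.list_pretrained('vit_so400m_patch14_siglip_gap_*'))

# e.g. the 448px PaliGemma-pretrained tower as a feature extractor
# (num_classes=0 in the cfg, so no classification head).
model = timm.create_model('vit_so400m_patch14_siglip_gap_448.pali_pt', pretrained=True)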
@@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 
 @register_model
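All four registered GAP variants share the SO400M shape (embed_dim=1152, depth=27, 16 heads, mlp_ratio ~3.7362) and differ only in the input resolution carried by their pretrained cfgs; class_token=False with global_pool='avg' is what makes them GAP models. A small sanity-check sketch, assuming this commit is installed (the shape comment is the expected result, not verified output):

import torch
import timm

m = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False, num_classes=0)
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = m(x)
print(feats.shape)  # expected: torch.Size([1, 1152]), average-pooled patch tokens, no class token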