Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
DFN CLIP ViT support
parent d5f1525334
commit c55bc41a42
@@ -904,16 +904,17 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
 
 
-def _convert_openai_clip(state_dict, model):
+def _convert_openai_clip(state_dict, model, prefix='visual.'):
     out_dict = {}
     swaps = [
-        ('visual.', ''), ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
+        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
         ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
         ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
     ]
     for k, v in state_dict.items():
-        if not k.startswith('visual.'):
+        if not k.startswith(prefix):
             continue
+        k = k.replace(prefix, '')
         for sp in swaps:
             k = k.replace(sp[0], sp[1])
 
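The hunk above generalizes the OpenAI-CLIP key remapping: instead of hard-coding the 'visual.' prefix inside the swap list, the prefix is now an argument that is stripped before the swaps run, so checkpoints whose image-tower keys live under a different prefix (e.g. 'module.visual.') can reuse the same converter. Below is a minimal sketch of just the renaming behaviour, with an illustrative key that is not taken from a real checkpoint; the real _convert_openai_clip additionally reshapes several tensors (pos_embed, qkv, proj) on top of this renaming.

def _remap_clip_keys(state_dict, prefix='visual.'):
    # Simplified stand-in for the renaming part of _convert_openai_clip.
    swaps = [
        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
        ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
        ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
    ]
    out_dict = {}
    for k, v in state_dict.items():
        if not k.startswith(prefix):
            continue                       # skip text tower, logit_scale, etc.
        k = k.replace(prefix, '')          # strip 'visual.' or 'module.visual.'
        for old, new in swaps:
            k = k.replace(old, new)
        out_dict[k] = v
    return out_dict

# Illustrative key: 'module.visual.transformer.resblocks.0.ln_1.weight'
#   -> 'blocks.0.norm1.weight' when prefix='module.visual.'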
@@ -974,6 +975,8 @@ def checkpoint_filter_fn(
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
+    elif 'module.visual.class_embedding' in state_dict:
+        return _convert_openai_clip(state_dict, model, prefix='module.visual.')
 
     if "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
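The new elif branch catches CLIP checkpoints saved from a (Distributed)DataParallel-wrapped model, where every key carries an extra 'module.' prefix; given this commit, the released DFN open_clip weights are presumably stored in that form. A quick, illustrative way to see where that prefix comes from:

import torch.nn as nn

m = nn.Linear(4, 4)
print(list(m.state_dict()))                   # ['weight', 'bias']
print(list(nn.DataParallel(m).state_dict()))  # ['module.weight', 'module.bias']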
@@ -1416,6 +1419,10 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
+    'vit_base_patch16_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-B-16',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
     'vit_large_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
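With this cfg registered, the DFN2B B/16 weights should load through the normal timm path: pretrained=True pulls open_clip_pytorch_model.bin from the listed Hub repo and routes it through checkpoint_filter_fn / _convert_openai_clip above. A hedged usage sketch, assuming the apple/DFN2B-CLIP-ViT-B-16 repo is public and your timm install includes this commit:

import timm
import torch

model = timm.create_model('vit_base_patch16_clip_224.dfn2b', pretrained=True).eval()

# num_classes=512 in the cfg is the CLIP projection width, so the head output
# is the image embedding rather than classification logits.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    emb = model(x)
print(emb.shape)  # torch.Size([1, 512])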
@@ -1424,10 +1431,22 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_large_patch14_clip_224.dfn2b': _cfg(
+        hf_hub_id='apple/DFN2B-CLIP-ViT-L-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
     'vit_huge_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_224.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_huge_patch14_clip_378.dfn5b': _cfg(
+        hf_hub_id='apple/DFN5B-CLIP-ViT-H-14-378',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
     'vit_giant_patch14_clip_224.laion2b': _cfg(
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
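The DFN5B H/14 tags work the same way. If the unprojected transformer output is wanted instead of the 1024-dim CLIP embedding, timm's forward_features / forward_head split applies. A hedged sketch using the tags added above; shapes assume the default 224px input:

import timm
import torch

model = timm.create_model('vit_huge_patch14_clip_224.dfn5b', pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    tokens = model.forward_features(x)  # (1, 257, 1280): class token + 16*16 patch tokens
    embed = model.forward_head(tokens)  # (1, 1024): pooled + projected CLIP image embedding
print(tokens.shape, embed.shape)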
@@ -2026,6 +2045,16 @@ def vit_huge_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_huge_patch14_clip_378(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378
+    """
+    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_giant_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
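The new entrypoint reuses the ViT-H/14 CLIP args from the 336px variant; the 378px resolution itself comes from the pretrained cfg or an img_size override rather than the model args. A hedged sketch that instantiates the architecture without weights just to check the geometry:

import timm

m = timm.create_model('vit_huge_patch14_clip_378', pretrained=False, img_size=378, num_classes=0)
print(m.patch_embed.grid_size)  # (27, 27): 378 / 14 patches per side
print(m.embed_dim)              # 1280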