mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add mlp head support for convnext_large, add laion2b CLIP weights, prep fine-tuned weight tags
This commit is contained in:
parent
6f28b562c6
commit
316bdf8955
@ -397,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
|
||||
if 'visual.head.proj.weight' in state_dict:
|
||||
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
|
||||
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
|
||||
elif 'visual.head.mlp.fc1.weight' in state_dict:
|
||||
out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
|
||||
out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
|
||||
out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
|
||||
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
|
||||
return out_dict
|
||||
|
||||
import re
|
||||
@ -716,6 +721,16 @@ default_cfgs = generate_default_cfgs({
|
||||
|
||||
'convnextv2_small.untrained': _cfg(),
|
||||
|
||||
# CLIP weights, fine-tuned on in1k or in12k + in1k
|
||||
'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
|
||||
'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
|
||||
|
||||
# CLIP based weights, original image tower weights and fine-tunes
|
||||
'convnext_base.clip_laion2b': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
|
||||
@ -742,6 +757,11 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
|
||||
'convnext_large_mlp.clip_laion2b_augreg': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
|
||||
})
|
||||
|
||||
|
||||
@ -854,6 +874,13 @@ def convnext_large(pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_large_mlp(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
|
||||
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_xlarge(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user