Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add convnext_base CLIP image tower weights for fine-tuning / features
commit 42bd8f7bcb (parent 65aea97067)
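For context, a minimal usage sketch (not part of the diff below): once these cfgs land, one of the new CLIP image towers can be pulled as a plain feature backbone through timm.create_model. The tag name and 256x256 input size follow the default_cfgs added at the bottom of the diff.

```python
# Minimal sketch, not part of this commit: use one of the new CLIP image towers
# as a feature backbone. Tag and input size come from the default_cfgs below.
import torch
import timm

# num_classes=0 drops the (CLIP projection) head and returns pooled trunk features
model = timm.create_model('convnext_base.clip_laion2b', pretrained=True, num_classes=0)
model.eval()

x = torch.randn(1, 3, 256, 256)  # cfg input_size for the 256x256 towers
with torch.no_grad():
    feats = model(x)
print(feats.shape)
```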
@@ -205,6 +205,7 @@ class ConvNeXt(nn.Module):
             use_grn=False,
             act_layer='gelu',
             norm_layer=None,
+            norm_eps=None,
             drop_rate=0.,
             drop_path_rate=0.,
     ):
@@ -236,10 +237,15 @@ class ConvNeXt(nn.Module):
         if norm_layer is None:
             norm_layer = LayerNorm2d
             norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+            if norm_eps is not None:
+                norm_layer = partial(norm_layer, eps=norm_eps)
+                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
         else:
             assert conv_mlp,\
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
             norm_layer_cl = norm_layer
+            if norm_eps is not None:
+                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)

         self.num_classes = num_classes
         self.drop_rate = drop_rate
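The norm_eps plumbing above simply pre-binds an eps override onto the norm constructors with functools.partial, so downstream code keeps calling norm_layer(dim) unchanged. A small illustration, using nn.LayerNorm as a stand-in and an arbitrary example eps:

```python
# Illustration of the norm_eps override; nn.LayerNorm stands in for timm's
# LayerNorm2d / LayerNorm, and eps=1e-5 is just an example value.
from functools import partial
import torch.nn as nn

norm_layer = nn.LayerNorm
norm_layer = partial(norm_layer, eps=1e-5)  # what the norm_eps branch does

norm = norm_layer(128)  # call sites still pass only the channel dim
print(norm.eps)         # 1e-05
```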
@@ -250,7 +256,7 @@ class ConvNeXt(nn.Module):
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
                 nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
-                norm_layer(dims[0])
+                norm_layer(dims[0]),
             )
             stem_stride = patch_size
         else:
@@ -376,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
         return state_dict  # non-FB checkpoint
     if 'model' in state_dict:
         state_dict = state_dict['model']
+
     out_dict = {}
+    if 'visual.trunk.stem.0.weight' in state_dict:
+        out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
+        if 'visual.head.proj.weight' in state_dict:
+            out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
+            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+        return out_dict
+
     import re
     for k, v in state_dict.items():
         k = k.replace('downsample_layers.0.', 'stem.')
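The OpenCLIP -> timm key remap added above can be checked in isolation; a toy sketch with a fabricated state_dict (key names mirror the code above, tensor shapes are invented):

```python
# Toy check of the visual.trunk remap above; tensors and shapes are invented.
import torch

state_dict = {
    'visual.trunk.stem.0.weight': torch.zeros(128, 3, 4, 4),
    'visual.trunk.stages.0.blocks.0.gamma': torch.zeros(128),
    'visual.head.proj.weight': torch.zeros(640, 1024),       # CLIP projection -> head.fc
    'text.token_embedding.weight': torch.zeros(49408, 512),  # text tower keys are dropped
}

out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
if 'visual.head.proj.weight' in state_dict:
    out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
    out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])

print(sorted(out_dict))
# ['head.fc.bias', 'head.fc.weight', 'stages.0.blocks.0.gamma', 'stem.0.weight']
```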
@@ -395,6 +409,7 @@ def checkpoint_filter_fn(state_dict, model):
             model_shape = model.state_dict()[k].shape
             v = v.reshape(model_shape)
         out_dict[k] = v
+
     return out_dict

@@ -685,6 +700,28 @@ default_cfgs = generate_default_cfgs({
         num_classes=0),

     'convnextv2_small.untrained': _cfg(),
+
+    # CLIP based weights, original image tower weights and fine-tunes
+    'convnext_base.clip_laion2b': _cfg(
+        hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+    'convnext_base.clip_laion2b_augreg': _cfg(
+        hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+    'convnext_base.clip_laiona': _cfg(
+        hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+    'convnext_base.clip_laiona_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
+    'convnext_base.clip_laiona_augreg_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
 })
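num_classes=640 in these cfgs matches the CLIP projection dim that checkpoint_filter_fn copies into head.fc (with a zero bias), so loading with the default head yields the CLIP image embedding (before any normalization). A hedged fine-tuning sketch; the 100-class head is an arbitrary example:

```python
# Sketch: fine-tune one of the new CLIP towers; 100 classes is an arbitrary example.
import timm

# Passing a different num_classes re-initializes head.fc; the trunk weights load as
# usual and the pretrained 640-dim CLIP projection is skipped.
model = timm.create_model('convnext_base.clip_laiona_augreg_320', pretrained=True, num_classes=100)

# Keeping the default 640-class head instead outputs the CLIP image embedding.
embed_tower = timm.create_model('convnext_base.clip_laion2b', pretrained=True)
```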