Switch hf hub entries for new aimv2 / dfn weights to point to timm locations. Undo forced device for SDR linspace, part of another change.

pull/2045/merge
Ross Wightman 2024-12-30 16:59:55 -08:00 committed by Ross Wightman
parent cc7fd34015
commit b0068ba5d0
1 changed files with 17 additions and 28 deletions

View File

@ -556,7 +556,7 @@ class VisionTransformer(nn.Module):
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')] # stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
@ -1158,22 +1158,12 @@ def _convert_aimv2(
k = k.replace('preprocessor.pos_embed', 'pos_embed')
k = k.replace('trunk.', '')
k = k.replace('post_trunk_norm.', 'norm.')
# packed ver, FIXME to delete
# if 'mlp.fc1' in k:
# if k in out_dict:
# v = torch.cat([v, out_dict[k]], dim=0)
# elif 'mlp.fc3' in k:
# k = k.replace('mlp.fc3', 'mlp.fc1')
# if k in out_dict:
# v = torch.cat([out_dict[k], v], dim=0)
k = k.replace('mlp.fc1', 'mlp.fc1_g')
k = k.replace('mlp.fc3', 'mlp.fc1_x')
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(
state_dict: Dict[str, torch.Tensor],
model: VisionTransformer,
@ -1688,8 +1678,7 @@ default_cfgs = {
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
#hf_hub_id='timm/',
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
hf_hub_id='timm/',
license='apple-ascl',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
'vit_large_patch14_clip_224.dfn2b': _cfg(
@ -2177,59 +2166,59 @@ default_cfgs = {
),
'aimv2_large_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_224.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-224-distilled',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_224.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-224',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_336.apple_pt_dist': _cfg(
hf_hub_id='apple/aimv2-large-patch14-336-distilled',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_336.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-336',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
'aimv2_large_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-large-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_huge_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-huge-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_1b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-1b-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
'aimv2_3b_patch14_448.apple_pt': _cfg(
hf_hub_id='apple/aimv2-3b-patch14-448',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),