Switch the HF Hub entries for the new AIMv2 / DFN weights to point to timm locations. Undo the forced `device='cpu'` on the SDR (stochastic depth) `torch.linspace` call; that belongs to a separate change.
parent
cc7fd34015
commit
b0068ba5d0
|
@ -556,7 +556,7 @@ class VisionTransformer(nn.Module):
|
|||
self.patch_drop = nn.Identity()
|
||||
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')] # stochastic depth decay rule
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.Sequential(*[
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
|
@ -1158,22 +1158,12 @@ def _convert_aimv2(
|
|||
k = k.replace('preprocessor.pos_embed', 'pos_embed')
|
||||
k = k.replace('trunk.', '')
|
||||
k = k.replace('post_trunk_norm.', 'norm.')
|
||||
|
||||
# packed ver, FIXME to delete
|
||||
# if 'mlp.fc1' in k:
|
||||
# if k in out_dict:
|
||||
# v = torch.cat([v, out_dict[k]], dim=0)
|
||||
# elif 'mlp.fc3' in k:
|
||||
# k = k.replace('mlp.fc3', 'mlp.fc1')
|
||||
# if k in out_dict:
|
||||
# v = torch.cat([out_dict[k], v], dim=0)
|
||||
k = k.replace('mlp.fc1', 'mlp.fc1_g')
|
||||
k = k.replace('mlp.fc3', 'mlp.fc1_x')
|
||||
|
||||
out_dict[k] = v
|
||||
|
||||
return out_dict
|
||||
|
||||
|
||||
def checkpoint_filter_fn(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
model: VisionTransformer,
|
||||
|
@ -1688,8 +1678,7 @@ default_cfgs = {
|
|||
license='apple-ascl',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
|
||||
'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
|
||||
#hf_hub_id='timm/',
|
||||
hf_hub_id='apple/DFN2B-CLIP-ViT-L-14-39B', hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
hf_hub_id='timm/',
|
||||
license='apple-ascl',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
|
||||
'vit_large_patch14_clip_224.dfn2b': _cfg(
|
||||
|
@ -2177,59 +2166,59 @@ default_cfgs = {
|
|||
),
|
||||
|
||||
'aimv2_large_patch14_224.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-large-patch14-224',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
crop_pct=1.0, num_classes=0),
|
||||
'aimv2_large_patch14_224.apple_pt_dist': _cfg(
|
||||
hf_hub_id='apple/aimv2-large-patch14-224-distilled',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
crop_pct=1.0, num_classes=0),
|
||||
'aimv2_huge_patch14_224.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-huge-patch14-224',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
crop_pct=1.0, num_classes=0),
|
||||
'aimv2_1b_patch14_224.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-1b-patch14-224',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
crop_pct=1.0, num_classes=0),
|
||||
'aimv2_3b_patch14_224.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-3b-patch14-224',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
crop_pct=1.0, num_classes=0),
|
||||
'aimv2_large_patch14_336.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-large-patch14-336',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_large_patch14_336.apple_pt_dist': _cfg(
|
||||
hf_hub_id='apple/aimv2-large-patch14-336-distilled',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_huge_patch14_336.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-huge-patch14-336',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_1b_patch14_336.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-1b-patch14-336',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_3b_patch14_336.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-3b-patch14-336',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_large_patch14_448.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-large-patch14-448',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_huge_patch14_448.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-huge-patch14-448',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_1b_patch14_448.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-1b-patch14-448',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
|
||||
'aimv2_3b_patch14_448.apple_pt': _cfg(
|
||||
hf_hub_id='apple/aimv2-3b-patch14-448',
|
||||
hf_hub_id='timm/',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
|
||||
input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
|
||||
|
||||
|
|
Loading…
Reference in New Issue