Merge pull request #2121 from huggingface/cleanup_vit_convert
Improve vit conversions. OpenAI convert pass through main convert

commit 67b0b3d7c7
@@ -771,28 +771,20 @@ def resize_pos_embed(
         antialias: bool = False,
 ) -> torch.Tensor:
     """ Rescale the grid of position embeddings when loading from state_dict.
 
-    *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
-
-    Adapted from:
-        https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
     """
-    ntok_new = posemb_new.shape[1]
-    if num_prefix_tokens:
-        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
-        ntok_new -= num_prefix_tokens
-    else:
-        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
+    ntok_new = posemb_new.shape[1] - num_prefix_tokens
+    ntok_old = posemb.shape[1] - num_prefix_tokens
+    gs_old = [int(math.sqrt(ntok_old))] * 2
     if not len(gs_new):  # backwards compatibility
         gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
-    return posemb
+    return resample_abs_pos_embed(
+        posemb, gs_new, gs_old,
+        num_prefix_tokens=num_prefix_tokens,
+        interpolation=interpolation,
+        antialias=antialias,
+        verbose=True,
+    )
 
 
 @torch.no_grad()
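Note: the deprecated helper is now a thin shim over resample_abs_pos_embed. A minimal sketch of calling the replacement API directly (not part of this diff; assumes the timm.layers.resample_abs_pos_embed signature used in the hunk above):

import torch
from timm.layers import resample_abs_pos_embed

# ViT-B/16 pos_embed trained at 224x224: 1 class token + a 14x14 patch grid
posemb = torch.randn(1, 1 + 14 * 14, 768)

# resample to the 24x24 grid of a 384x384 input (patch size 16)
posemb_384 = resample_abs_pos_embed(
    posemb,
    new_size=(24, 24),
    old_size=(14, 14),    # optional; inferred from the grid length if omitted
    num_prefix_tokens=1,  # the class token passes through unresized
)
print(posemb_384.shape)   # torch.Size([1, 577, 768])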
@@ -962,16 +954,6 @@ def _convert_openai_clip(
             v = v.unsqueeze(0).unsqueeze(1)
         elif k == 'pos_embed':
             v = v.unsqueeze(0)
-            if v.shape[1] != model.pos_embed.shape[1]:
-                # To resize pos embedding when using model at different size from pretrained weights
-                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
-                    else getattr(model, 'num_prefix_tokens', 1)
-                v = resample_abs_pos_embed(
-                    v,
-                    new_size=model.patch_embed.grid_size,
-                    num_prefix_tokens=num_prefix_tokens,
-                    verbose=True,
-                )
         out_dict[k] = v
     return out_dict
 
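Note: with this deletion, OpenAI CLIP checkpoints no longer resize pos_embed inside _convert_openai_clip; the converted state_dict falls through to the shared resize logic in the main convert path. A self-contained sketch of the behaviour that moved (shapes are illustrative; the resample_abs_pos_embed call is as in timm.layers):

import torch
from timm.layers import resample_abs_pos_embed

clip_pos_embed = torch.randn(50, 768)  # OpenAI ViT-B/32 CLIP: 1 cls + 7x7 grid, stored [N, D]
v = clip_pos_embed.unsqueeze(0)        # -> [1, 50, 768], as in the hunk above

target_grid = (14, 14)                 # e.g. a model built for 448x448 at patch size 32
if v.shape[1] != 1 + target_grid[0] * target_grid[1]:
    v = resample_abs_pos_embed(v, new_size=target_grid, num_prefix_tokens=1)
print(v.shape)                         # torch.Size([1, 197, 768])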
@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
     prefix = ''
 
     if 'visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model)
+        state_dict = _convert_openai_clip(state_dict, model)
     elif 'module.visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model, prefix='module.visual.')
-
-    if "mask_token" in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
+    elif "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
-
-    if "encoder" in state_dict:
+    elif "encoder" in state_dict:
+        # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-
-    if 'visual.trunk.pos_embed' in state_dict:
-        # convert an OpenCLIP model with timm vision encoder
+    elif 'visual.trunk.pos_embed' in state_dict:
+        # OpenCLIP model with timm vision encoder
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
 
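Note: user-facing behaviour is unchanged; creating a model at a non-default resolution still resamples pos_embed on load, now through a single elif chain shared by OpenAI CLIP, DINOv2, IJEPA, and OpenCLIP-timm checkpoints. Usage sketch (the model tag is one of timm's CLIP-pretrained ViTs and assumed available in the installed timm version):

import timm

model = timm.create_model(
    'vit_base_patch32_clip_224.openai',  # assumed tag; any CLIP ViT variant works
    pretrained=True,
    img_size=448,     # pos_embed is resampled to the 14x14 grid during load
    num_classes=0,    # no classifier head, feature output only
)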