Merge pull request #2121 from huggingface/cleanup_vit_convert
Improve vit conversions. OpenAI convert pass through main convertpull/2126/head
commit
67b0b3d7c7
|
@ -771,28 +771,20 @@ def resize_pos_embed(
|
|||
antialias: bool = False,
|
||||
) -> torch.Tensor:
|
||||
""" Rescale the grid of position embeddings when loading from state_dict.
|
||||
|
||||
*DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
|
||||
|
||||
Adapted from:
|
||||
https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
||||
*DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
|
||||
"""
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if num_prefix_tokens:
|
||||
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
|
||||
ntok_new -= num_prefix_tokens
|
||||
else:
|
||||
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
ntok_new = posemb_new.shape[1] - num_prefix_tokens
|
||||
ntok_old = posemb.shape[1] - num_prefix_tokens
|
||||
gs_old = [int(math.sqrt(ntok_old))] * 2
|
||||
if not len(gs_new): # backwards compatibility
|
||||
gs_new = [int(math.sqrt(ntok_new))] * 2
|
||||
assert len(gs_new) >= 2
|
||||
_logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
||||
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
|
||||
return posemb
|
||||
return resample_abs_pos_embed(
|
||||
posemb, gs_new, gs_old,
|
||||
num_prefix_tokens=num_prefix_tokens,
|
||||
interpolation=interpolation,
|
||||
antialias=antialias,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -962,16 +954,6 @@ def _convert_openai_clip(
|
|||
v = v.unsqueeze(0).unsqueeze(1)
|
||||
elif k == 'pos_embed':
|
||||
v = v.unsqueeze(0)
|
||||
if v.shape[1] != model.pos_embed.shape[1]:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
|
||||
else getattr(model, 'num_prefix_tokens', 1)
|
||||
v = resample_abs_pos_embed(
|
||||
v,
|
||||
new_size=model.patch_embed.grid_size,
|
||||
num_prefix_tokens=num_prefix_tokens,
|
||||
verbose=True,
|
||||
)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
|
|||
prefix = ''
|
||||
|
||||
if 'visual.class_embedding' in state_dict:
|
||||
return _convert_openai_clip(state_dict, model)
|
||||
state_dict = _convert_openai_clip(state_dict, model)
|
||||
elif 'module.visual.class_embedding' in state_dict:
|
||||
return _convert_openai_clip(state_dict, model, prefix='module.visual.')
|
||||
|
||||
if "mask_token" in state_dict:
|
||||
state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
|
||||
elif "mask_token" in state_dict:
|
||||
state_dict = _convert_dinov2(state_dict, model)
|
||||
|
||||
if "encoder" in state_dict:
|
||||
elif "encoder" in state_dict:
|
||||
# IJEPA, vit in an 'encoder' submodule
|
||||
state_dict = state_dict['encoder']
|
||||
prefix = 'module.'
|
||||
|
||||
if 'visual.trunk.pos_embed' in state_dict:
|
||||
# convert an OpenCLIP model with timm vision encoder
|
||||
elif 'visual.trunk.pos_embed' in state_dict:
|
||||
# OpenCLIP model with timm vision encoder
|
||||
# FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
|
||||
prefix = 'visual.trunk.'
|
||||
|
||||
|
|
Loading…
Reference in New Issue