Merge pull request #2121 from huggingface/cleanup_vit_convert

Improve vit conversions. OpenAI convert pass through main convert
Ross Wightman 2024-03-21 13:13:53 -07:00 committed by GitHub
commit 67b0b3d7c7
1 changed file with 18 additions and 38 deletions


@@ -771,28 +771,20 @@ def resize_pos_embed(
         antialias: bool = False,
 ) -> torch.Tensor:
     """ Rescale the grid of position embeddings when loading from state_dict.
-    *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
-    Adapted from:
-        https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
     """
-    ntok_new = posemb_new.shape[1]
-    if num_prefix_tokens:
-        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
-        ntok_new -= num_prefix_tokens
-    else:
-        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
+    ntok_new = posemb_new.shape[1] - num_prefix_tokens
+    ntok_old = posemb.shape[1] - num_prefix_tokens
+    gs_old = [int(math.sqrt(ntok_old))] * 2
     if not len(gs_new):  # backwards compatibility
         gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
-    return posemb
+    return resample_abs_pos_embed(
+        posemb, gs_new, gs_old,
+        num_prefix_tokens=num_prefix_tokens,
+        interpolation=interpolation,
+        antialias=antialias,
+        verbose=True,
+    )
 
 @torch.no_grad()
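
For reference, the delegation above is behavior-preserving: the removed body is essentially the logic that resample_abs_pos_embed centralizes. Below is a minimal standalone sketch of that resampling, for readers following the diff; the function name naive_resample_pos_embed and the square-grid assumption are illustrative, not part of timm.

import math
import torch
import torch.nn.functional as F

def naive_resample_pos_embed(
        posemb: torch.Tensor,   # (1, num_prefix_tokens + H * W, C)
        new_size,               # (new_H, new_W)
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = False,
) -> torch.Tensor:
    # split prefix (class/register) tokens from the spatial grid
    prefix, grid = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    gs_old = int(math.sqrt(grid.shape[1]))  # assumes a square old grid
    # (1, H * W, C) -> (1, C, H, W) so F.interpolate sees a 2D grid
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=new_size, mode=interpolation, antialias=antialias)
    # back to (1, new_H * new_W, C), then re-attach the prefix tokens
    grid = grid.permute(0, 2, 3, 1).reshape(1, new_size[0] * new_size[1], -1)
    return torch.cat([prefix, grid], dim=1)

For example, a checkpoint trained at 224px with 16px patches carries a 14x14 grid (196 spatial tokens); passing new_size=(24, 24) adapts it for 384px inputs.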
@@ -962,16 +954,6 @@ def _convert_openai_clip(
             v = v.unsqueeze(0).unsqueeze(1)
         elif k == 'pos_embed':
             v = v.unsqueeze(0)
-            if v.shape[1] != model.pos_embed.shape[1]:
-                # To resize pos embedding when using model at different size from pretrained weights
-                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
-                    else getattr(model, 'num_prefix_tokens', 1)
-                v = resample_abs_pos_embed(
-                    v,
-                    new_size=model.patch_embed.grid_size,
-                    num_prefix_tokens=num_prefix_tokens,
-                    verbose=True,
-                )
         out_dict[k] = v
     return out_dict
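
This removal is what the commit title means by the OpenAI convert passing through the main convert path: _convert_openai_clip no longer resizes pos_embed itself, so the shared path applies the check once for all checkpoint formats. A sketch of that guard, reusing the attributes from the removed lines (model.pos_embed, model.patch_embed.grid_size, model.num_prefix_tokens); the wrapper name maybe_resize_pos_embed is illustrative.

from timm.layers import resample_abs_pos_embed

def maybe_resize_pos_embed(v, model):
    # resample only when the checkpoint grid differs from the model grid
    if v.shape[1] != model.pos_embed.shape[1]:
        num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
            else getattr(model, 'num_prefix_tokens', 1)
        v = resample_abs_pos_embed(
            v,
            new_size=model.patch_embed.grid_size,
            num_prefix_tokens=num_prefix_tokens,
            verbose=True,
        )
    return v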
@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
     prefix = ''
     if 'visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model)
+        state_dict = _convert_openai_clip(state_dict, model)
     elif 'module.visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model, prefix='module.visual.')
-    if "mask_token" in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
+    elif "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
-    if "encoder" in state_dict:
+    elif "encoder" in state_dict:
         # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-    if 'visual.trunk.pos_embed' in state_dict:
-        # convert an OpenCLIP model with timm vision encoder
+    elif 'visual.trunk.pos_embed' in state_dict:
+        # OpenCLIP model with timm vision encoder
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
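
The net effect of the return-to-assignment and if-to-elif changes is that every checkpoint flavor now falls through to the shared remapping after the dispatch (prefix stripping, pos_embed resampling) instead of exiting early. A hedged usage sketch of what this enables, assuming the standard timm create_model API and that the pretrained tag below exists in the registry:

import timm

# Loading OpenAI CLIP weights at a non-native input size now routes through
# checkpoint_filter_fn's shared tail, which performs the pos_embed resample
# that _convert_openai_clip previously did on its own.
model = timm.create_model(
    'vit_base_patch16_clip_224.openai',  # assumed pretrained tag
    pretrained=True,
    img_size=336,  # differs from the pretrained 224, triggering a resample
)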