From c559c3911f5de876d1060feb5b6230f45a67cf8d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 21 Mar 2024 10:00:43 -0700
Subject: [PATCH] Improve vit conversions. OpenAI convert pass through main convert for patch & pos resize. Fix #2120

---
 timm/models/vision_transformer.py | 56 ++++++++++---------
 1 file changed, 18 insertions(+), 38 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 701fcb84..ce65ee4a 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -771,28 +771,20 @@ def resize_pos_embed(
         antialias: bool = False,
 ) -> torch.Tensor:
     """ Rescale the grid of position embeddings when loading from state_dict.
-
-    *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
-
-    Adapted from:
-        https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
     """
-    ntok_new = posemb_new.shape[1]
-    if num_prefix_tokens:
-        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
-        ntok_new -= num_prefix_tokens
-    else:
-        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
+    ntok_new = posemb_new.shape[1] - num_prefix_tokens
+    ntok_old = posemb.shape[1] - num_prefix_tokens
+    gs_old = [int(math.sqrt(ntok_old))] * 2
     if not len(gs_new):  # backwards compatibility
         gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
-    return posemb
+    return resample_abs_pos_embed(
+        posemb, gs_new, gs_old,
+        num_prefix_tokens=num_prefix_tokens,
+        interpolation=interpolation,
+        antialias=antialias,
+        verbose=True,
+    )
 
 
 @torch.no_grad()
@@ -962,16 +954,6 @@ def _convert_openai_clip(
             v = v.unsqueeze(0).unsqueeze(1)
         elif k == 'pos_embed':
             v = v.unsqueeze(0)
-            if v.shape[1] != model.pos_embed.shape[1]:
-                # To resize pos embedding when using model at different size from pretrained weights
-                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
-                    else getattr(model, 'num_prefix_tokens', 1)
-                v = resample_abs_pos_embed(
-                    v,
-                    new_size=model.patch_embed.grid_size,
-                    num_prefix_tokens=num_prefix_tokens,
-                    verbose=True,
-                )
         out_dict[k] = v
     return out_dict
 
@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
     prefix = ''
 
     if 'visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model)
+        state_dict = _convert_openai_clip(state_dict, model)
     elif 'module.visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model, prefix='module.visual.')
-
-    if "mask_token" in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
+    elif "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
-
-    if "encoder" in state_dict:
+    elif "encoder" in state_dict:
+        # IJEPA, vit in an 'encoder' submodule
        state_dict = state_dict['encoder']
         prefix = 'module.'
-
-    if 'visual.trunk.pos_embed' in state_dict:
-        # convert an OpenCLIP model with timm vision encoder
+    elif 'visual.trunk.pos_embed' in state_dict:
+        # OpenCLIP model with timm vision encoder
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
 
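Usage note (illustration only, not part of the diff above): with this change the deprecated resize_pos_embed just derives the old/new grid sizes and delegates to resample_abs_pos_embed. A minimal sketch of calling resample_abs_pos_embed from timm.layers directly; the tensor shape, grid sizes, and argument values below are assumed examples, not taken from the patch:

    # Minimal sketch with assumed example values: resample a ViT pos_embed trained at
    # 224px / patch 16 (14x14 grid + 1 class token) to a 384px / patch 16 (24x24) grid.
    import torch
    from timm.layers import resample_abs_pos_embed

    posemb = torch.randn(1, 1 + 14 * 14, 768)  # [batch, prefix + grid tokens, dim]
    posemb_new = resample_abs_pos_embed(
        posemb,
        new_size=[24, 24],
        old_size=[14, 14],
        num_prefix_tokens=1,      # class token kept as-is, only the grid is interpolated
        interpolation='bicubic',
        antialias=True,
    )
    print(posemb_new.shape)       # torch.Size([1, 577, 768])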