diff --git a/timm/models/vision_transformer_flex.py b/timm/models/vision_transformer_flex.py
index c8c52586..e3a8d8ef 100644
--- a/timm/models/vision_transformer_flex.py
+++ b/timm/models/vision_transformer_flex.py
@@ -280,41 +280,51 @@ class FlexEmbeds(nn.Module):
         return x
 
+    #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
     ):
-        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-        # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
+
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
+
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
 
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
 
@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
 
 
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
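
For reference, the strategy in the reworked _apply_learned_naflex_pos_embed can be sketched as a standalone function: group the padded batch by unique patch-grid size, resample the learned position-embedding table once per size, and scatter the result into the matching batch rows with index_add_. The helper below is a minimal illustration rather than the timm code; the function name, the fixed 'bicubic' mode (standing in for self.pos_embed_interp_mode), and the expand/index_add_ details are assumptions made for the sketch.

from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F


def apply_naflex_pos_embed(
        x: torch.Tensor,                    # (B, N, C) padded patch tokens
        pos_embed: torch.Tensor,            # (1, H, W, C) learned pos embed table
        grid_sizes: List[Tuple[int, int]],  # per-sample (h, w) patch grid
) -> torch.Tensor:
    orig_h, orig_w = pos_embed.shape[1:3]
    pos_nchw = pos_embed.permute(0, 3, 1, 2).float()  # (1, C, H, W) for interpolate

    # Group batch indices by grid size so each unique (h, w) is resampled once.
    size_to_indices: Dict[Tuple[int, int], List[int]] = {}
    for bi, hw in enumerate(grid_sizes):
        size_to_indices.setdefault(hw, []).append(bi)

    for (h, w), batch_indices in size_to_indices.items():
        if (h, w) == (orig_h, orig_w):
            pos_flat = pos_embed.reshape(1, orig_h * orig_w, -1)
        else:
            # 'bicubic' is an assumption here; the module uses its configured mode.
            pos_flat = F.interpolate(
                pos_nchw, size=(h, w), mode='bicubic',
                align_corners=False, antialias=True,
            ).flatten(2).transpose(1, 2)  # (1, h*w, C)
        pos_flat = pos_flat.to(dtype=x.dtype)

        # Add the shared embedding to every sample with this grid size; tokens
        # beyond h*w (padding) are left untouched.
        seq_len = min(x.shape[1], pos_flat.shape[1])
        x[:, :seq_len].index_add_(
            0,
            torch.tensor(batch_indices, device=x.device),
            pos_flat[:, :seq_len].expand(len(batch_indices), -1, -1),
        )
    return x

With this grouping, a batch that mixes, say, 12x16 and 14x14 grids triggers only two interpolate calls regardless of batch size, and padding tokens past each sample's h*w are never touched.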