Further pos embed tweaks, rejig model defs for testing

This commit is contained in:
Ross Wightman 2025-04-28 09:15:11 -07:00
parent 3dc90ed7a7
commit ee27b73da4

View File

@ -280,41 +280,51 @@ class FlexEmbeds(nn.Module):
return x
#@torch.compiler.disable()
def _apply_learned_naflex_pos_embed(
self,
x: torch.Tensor,
naflex_grid_sizes: List[Tuple[int, int]],
):
orig_h, orig_w = self.pos_embed.shape[1:3]
# Determine unique grid sizes
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, (h, w) in enumerate(naflex_grid_sizes):
#k = h << 16 | w # FIXME can get jit compat with this
k = (h, w)
if not k in size_to_indices:
size_to_indices[k] = [bi]
else:
size_to_indices[k].append(bi)
# Handle each batch element separately with its own grid size
orig_h, orig_w = self.pos_embed.shape[1:3]
pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W
for k, batch_indices in size_to_indices.items():
h, w = k
#h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
if (h == orig_h) and (w == orig_w):
def _interp(_size):
if (_size[0] == orig_h) and (_size[1] == orig_w):
pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
else:
pos_embed_flat = F.interpolate(
pos_embed_nchw,
size=(h, w),
size=_size,
mode=self.pos_embed_interp_mode,
align_corners=False,
antialias=True,
).flatten(2).transpose(1, 2)
pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
return pos_embed_flat.to(dtype=x.dtype)
# FIXME leaving alternative code commented here for now for comparisons
# pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
# for i, s in enumerate(naflex_grid_sizes):
# if s in pos_embed_cache:
# pos_embed_flat = pos_embed_cache[s]
# else:
# pos_embed_flat = _interp(s)
# pos_embed_cache[s] = pos_embed_flat
#
# seq_len = min(x.shape[1], pos_embed_flat.shape[1])
# x[i, :seq_len] += pos_embed_flat[0, :seq_len]
# Determine unique grid sizes
size_to_indices: Dict[Tuple[int, int], List[int]] = {}
for bi, k in enumerate(naflex_grid_sizes):
# k = h << 16 | w # FIXME can get jit compat with this
size_to_indices.setdefault(k, []).append(bi)
for k, batch_indices in size_to_indices.items():
# h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
# Interpolate only once for this (h, w)
pos_embed_flat = _interp(k)
seq_len = min(x.shape[1], pos_embed_flat.shape[1])
x[:, :seq_len].index_add_(
0,
@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
default_cfgs = generate_default_cfgs({
'vit_naflex_base_patch16': _cfg(),
'vit_naflex_base_patch16_gap': _cfg(),
'vit_naflex_base_patch16_map': _cfg(),
@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
return model
@register_model
def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
return model
@register_model
def vit_naflex_base_patch16(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
This model supports:
1. Variable aspect ratios and resolutions via patch coordinates
2. Position embedding interpolation for arbitrary grid sizes
3. Explicit patch coordinates and valid token masking
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
return model
@register_model
def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
"""ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
"""
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
global_pool='map', reg_tokens=1)
model = _create_vision_transformer_flex(
'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
"""
model_args = dict(
patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
model = _create_vision_transformer_flex(
'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12,
global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model