Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Further pos embed tweaks, rejig model defs for testing
This commit is contained in:
parent 3dc90ed7a7
commit ee27b73da4
@@ -280,41 +280,51 @@ class FlexEmbeds(nn.Module):
         return x
 
     #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
     ):
+        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-
-        # Handle each batch element separately with its own grid size
-        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
+
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
 
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
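The rewrite above factors interpolation into a local _interp closure and buckets batch indices by unique grid size, so each distinct (h, w) is interpolated once and then scattered to every matching sample with index_add_. A minimal self-contained sketch of the same pattern; add_pos_embed_by_size and its argument names are hypothetical, not part of the timm module:

import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple

def add_pos_embed_by_size(
        x: torch.Tensor,                    # (B, N, C) patch tokens, padded to a common seq len
        pos_embed: torch.Tensor,            # (1, H, W, C) learned position embedding
        grid_sizes: List[Tuple[int, int]],  # per-sample (h, w) patch grids
) -> torch.Tensor:
    orig_h, orig_w = pos_embed.shape[1:3]
    pos_nchw = pos_embed.permute(0, 3, 1, 2).float()  # (1, C, H, W) for F.interpolate

    # Bucket batch indices by grid size so each unique size is interpolated once.
    buckets: Dict[Tuple[int, int], List[int]] = {}
    for bi, size in enumerate(grid_sizes):
        buckets.setdefault(size, []).append(bi)

    for (h, w), indices in buckets.items():
        if (h, w) == (orig_h, orig_w):
            pe = pos_embed.reshape(1, orig_h * orig_w, -1)  # no resampling needed
        else:
            pe = F.interpolate(
                pos_nchw, size=(h, w), mode='bicubic',
                align_corners=False, antialias=True,
            ).flatten(2).transpose(1, 2)  # (1, h*w, C)
        pe = pe.to(dtype=x.dtype)
        seq_len = min(x.shape[1], pe.shape[1])
        idx = torch.tensor(indices, device=x.device)
        # index_add_ adds the shared embedding into every sample in this bucket in place.
        x[:, :seq_len].index_add_(0, idx, pe[:, :seq_len].expand(len(indices), -1, -1))
    return x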
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
-
 
 default_cfgs = generate_default_cfgs({
     'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
 
@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
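A recurring change in these model defs is moving **kwargs out of model_args and into the create call as **dict(model_args, **kwargs). Since dict(a, **b) merges b over a, caller kwargs now override the variant defaults, whereas the old dict(..., **kwargs) form raised TypeError: got multiple values for keyword argument whenever a caller passed a key the variant already set. A toy illustration; make_cfg is a hypothetical stand-in for the registration functions:

# Hypothetical illustration of the override pattern used above.
def make_cfg(**kwargs):
    model_args = dict(depth=12, num_heads=12, global_pool='avg')
    return dict(model_args, **kwargs)  # kwargs win on key conflicts

print(make_cfg())                      # {'depth': 12, 'num_heads': 12, 'global_pool': 'avg'}
print(make_cfg(global_pool='token'))   # caller overrides the default pooling, no TypeError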
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
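For reference, global_pool='map' selects multi-head attention pooling (a learned probe token attending over the patch tokens) rather than mean pooling. A generic sketch of that head under the usual formulation; it is an assumption-level stand-in, not timm's exact AttentionPoolLatent:

import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    """Minimal MAP-style head: one learned probe token cross-attends over patch tokens."""
    def __init__(self, dim: int, num_heads: int = 12):
        super().__init__()
        self.probe = nn.Parameter(torch.zeros(1, 1, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, N, C)
        probe = self.probe.expand(x.shape[0], -1, -1)     # broadcast probe across the batch
        pooled, _ = self.attn(probe, x, x)                # (B, 1, C)
        return pooled[:, 0]                               # (B, C) pooled representation

# e.g. AttentionPool(768)(torch.randn(2, 196, 768)).shape -> torch.Size([2, 768])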
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
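The so150m2 shape uses a fractional mlp_ratio=34/13; with embed_dim=832 this keeps both the MLP hidden size and the per-head dimension integral, which a quick check confirms:

embed_dim, num_heads, mlp_ratio = 832, 13, 34 / 13
print(int(embed_dim * mlp_ratio))   # 2176 -> MLP hidden dim is a whole number
print(embed_dim / num_heads)        # 64.0 -> standard head_dim with 13 heads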
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
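Once registered via @register_model, the rejigged variants are built through the normal timm factory. A hedged usage sketch: the naflex names come from this in-progress file and may not be present in released timm wheels.

import timm
import torch

# 'vit_naflex_base_patch16' is registered by this file; availability depends on the checkout.
model = timm.create_model('vit_naflex_base_patch16', pretrained=False, num_classes=10)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([2, 10])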