mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Further pos embed tweaks, rejig model defs for testing
parent 3dc90ed7a7
commit ee27b73da4
@@ -280,41 +280,51 @@ class FlexEmbeds(nn.Module):
         return x
 
-    #@torch.compiler.disable()
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
             naflex_grid_sizes: List[Tuple[int, int]],
     ):
-        orig_h, orig_w = self.pos_embed.shape[1:3]
-
-        # Determine unique grid sizes
-        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
-        for bi, (h, w) in enumerate(naflex_grid_sizes):
-            #k = h << 16 | w  # FIXME can get jit compat with this
-            k = (h, w)
-            if not k in size_to_indices:
-                size_to_indices[k] = [bi]
-            else:
-                size_to_indices[k].append(bi)
-
         # Handle each batch element separately with its own grid size
+        orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
-        for k, batch_indices in size_to_indices.items():
-            h, w = k
-            #h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
-            # Interpolate only once for this (h, w)
-            if (h == orig_h) and (w == orig_w):
+
+        def _interp(_size):
+            if (_size[0] == orig_h) and (_size[1] == orig_w):
                 pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
             else:
                 pos_embed_flat = F.interpolate(
                     pos_embed_nchw,
-                    size=(h, w),
+                    size=_size,
                     mode=self.pos_embed_interp_mode,
                     align_corners=False,
                     antialias=True,
                 ).flatten(2).transpose(1, 2)
-            pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
+            return pos_embed_flat.to(dtype=x.dtype)
+
+        # FIXME leaving alternative code commented here for now for comparisons
+        # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
+        # for i, s in enumerate(naflex_grid_sizes):
+        #     if s in pos_embed_cache:
+        #         pos_embed_flat = pos_embed_cache[s]
+        #     else:
+        #         pos_embed_flat = _interp(s)
+        #         pos_embed_cache[s] = pos_embed_flat
+        #
+        #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
+        #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
+
+        # Determine unique grid sizes
+        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
+        for bi, k in enumerate(naflex_grid_sizes):
+            # k = h << 16 | w  # FIXME can get jit compat with this
+            size_to_indices.setdefault(k, []).append(bi)
+
+        for k, batch_indices in size_to_indices.items():
+            # h, w = k >> 16, k & 0xFFFF  # FIXME can get jit compat with this
+            # Interpolate only once for this (h, w)
+            pos_embed_flat = _interp(k)
 
             seq_len = min(x.shape[1], pos_embed_flat.shape[1])
             x[:, :seq_len].index_add_(
                 0,
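For context: the reworked _apply_learned_naflex_pos_embed interpolates the learned position embedding once per unique (h, w) grid in the batch, then adds the shared result to every sample with that grid via index_add_. Below is a minimal standalone sketch of that idea; the function name, the (1, H, W, C) pos_embed layout, and the interpolation defaults are assumptions for illustration, not the exact FlexEmbeds internals (the hunk above truncates the index_add_ call, so its remaining arguments here are also assumed).

from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F


def apply_naflex_pos_embed(
        x: torch.Tensor,                    # (B, N, C) patch tokens, modified in place
        pos_embed: torch.Tensor,            # (1, H, W, C) learned pos embed table (assumed layout)
        grid_sizes: List[Tuple[int, int]],  # per-sample (h, w) patch grid sizes
        mode: str = 'bicubic',              # assumed interpolation mode
) -> torch.Tensor:
    orig_h, orig_w = pos_embed.shape[1:3]
    pos_nchw = pos_embed.permute(0, 3, 1, 2).float()  # (1, C, H, W) for F.interpolate

    # Group batch indices by grid size so each unique size is interpolated once
    size_to_indices: Dict[Tuple[int, int], List[int]] = {}
    for bi, k in enumerate(grid_sizes):
        size_to_indices.setdefault(k, []).append(bi)

    for (h, w), batch_indices in size_to_indices.items():
        if (h, w) == (orig_h, orig_w):
            # Native grid size: no resampling needed, just flatten
            pos_flat = pos_embed.reshape(1, orig_h * orig_w, -1)
        else:
            pos_flat = F.interpolate(
                pos_nchw, size=(h, w), mode=mode,
                align_corners=False, antialias=True,
            ).flatten(2).transpose(1, 2)  # (1, h*w, C)
        pos_flat = pos_flat.to(dtype=x.dtype)

        seq_len = min(x.shape[1], pos_flat.shape[1])
        # Add the shared pos embed to every sample with this grid size in one call
        x[:, :seq_len].index_add_(
            0,
            torch.as_tensor(batch_indices, device=x.device),
            pos_flat[:, :seq_len].expand(len(batch_indices), -1, -1),
        )
    return x

With grids like [(7, 7), (10, 5), (7, 7)], this performs two interpolations instead of three, which is the point of grouping by unique size.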
@@ -1015,7 +1025,6 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
 
 
 default_cfgs = generate_default_cfgs({
-    'vit_naflex_base_patch16': _cfg(),
     'vit_naflex_base_patch16_gap': _cfg(),
     'vit_naflex_base_patch16_map': _cfg(),
 
@@ -1050,43 +1059,15 @@ def _create_vision_transformer_flex(variant, pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_naflex_mediumd_patch16_reg4_gap(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
-        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_mediumd_patch16_reg4_gap', pretrained=pretrained, **model_args)
-    return model
-
-
-@register_model
-def vit_naflex_base_patch16(pretrained=False, **kwargs):
-    """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
-
-    This model supports:
-    1. Variable aspect ratios and resolutions via patch coordinates
-    2. Position embedding interpolation for arbitrary grid sizes
-    3. Explicit patch coordinates and valid token masking
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16', pretrained=pretrained, **model_args)
-    return model
-
-
 @register_model
 def vit_naflex_base_patch16_gap(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12,
-        global_pool='avg', class_token=False, reg_tokens=4, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='avg', class_token=False, reg_tokens=4, fc_norm=True, **kwargs)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
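A note on the recurring change to **dict(model_args, **kwargs) in these hunks: merging at the call site lets caller-supplied kwargs override the registered defaults, whereas folding **kwargs into the dict(...) literal raises a TypeError whenever a caller passes a key that is already a default. A quick illustration (plain Python, nothing timm-specific):

defaults = dict(patch_size=16, embed_dim=768, depth=12)

# dict(mapping, **overrides): override wins, no error
merged = dict(defaults, **{'embed_dim': 384})
# {'patch_size': 16, 'embed_dim': 384, 'depth': 12}

# duplicate keyword in the literal form fails
dict(patch_size=16, embed_dim=768, depth=12, **{'embed_dim': 384})
# TypeError: dict() got multiple values for keyword argument 'embed_dim'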
@@ -1095,9 +1076,10 @@ def vit_naflex_base_patch16_map(pretrained=False, **kwargs):
     """ViT-New with NaFlex functionality for variable aspect ratios and resolutions.
     """
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='map', **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        global_pool='map', reg_tokens=1)
     model = _create_vision_transformer_flex(
-        'vit_naflex_base_patch16_map', pretrained=pretrained, **model_args)
+        'vit_naflex_base_patch16_map', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1112,9 +1094,9 @@ def vit_naflex_so150m2_patch16_reg1_gap(pretrained=False, **kwargs):
     """
     model_args = dict(
         patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
-        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True, **kwargs)
+        qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', fc_norm=True)
     model = _create_vision_transformer_flex(
-        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **model_args)
+        'vit_naflex_so150m2_patch16_reg1_gap', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -1123,6 +1105,8 @@ def vit_naflex_base_patch16(pretrained: bool = False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12,
+        global_pool='token', class_token=True, pos_embed_grid_size=(14, 14))
     model = _create_vision_transformer_flex('vit_naflex_base_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
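Once registered via @register_model, these variants are constructed through the usual timm factory, with extra kwargs flowing through the dict(model_args, **kwargs) merge shown above. A hedged usage sketch, assuming the registry names introduced in this commit:

import timm

# num_classes here overrides the registered default via dict(model_args, **kwargs)
model = timm.create_model('vit_naflex_base_patch16', pretrained=False, num_classes=0)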
|
Loading…
x
Reference in New Issue
Block a user