From b82f9e869cb88cfa64089e64e2a33b3f19b6b8df Mon Sep 17 00:00:00 2001
From: berniebear
Date: Fri, 25 Apr 2025 21:26:34 +0000
Subject: [PATCH] Rename models and update checkpoints to ViT-only

---
 timm/models/pe.py | 352 ++++++++++++++++++++--------------------------
 1 file changed, 150 insertions(+), 202 deletions(-)

diff --git a/timm/models/pe.py b/timm/models/pe.py
index 953a9906..62ebfa7b 100644
--- a/timm/models/pe.py
+++ b/timm/models/pe.py
@@ -14,10 +14,27 @@ from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
 
 ### Import timm layers
-from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
-    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType, LayerScale
-#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
+from timm.layers import (
+    PatchEmbed,
+    Mlp,
+    DropPath,
+    AttentionPoolLatent,
+    RmsNorm,
+    PatchDropout,
+    SwiGLUPacked,
+    SwiGLU,
+    trunc_normal_,
+    lecun_normal_,
+    resample_patch_embed,
+    resample_abs_pos_embed,
+    use_fused_attn,
+    get_act_layer,
+    get_norm_layer,
+    LayerType,
+    LayerScale,
+)
+
+# from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 
 from ._builder import build_model_with_cfg
@@ -28,20 +45,22 @@ from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
 __all__ = ['PE']
 
-####### PE's Rope ########
-
+######## PE's Rope ########
 def exists(val):
     return val is not None
 
+
 def default(val, d):
     return val if exists(val) else d
 
+
 def rotate_half(x):
     x = rearrange(x, "... (d r) -> ... d r", r=2)
     x1, x2 = x.unbind(dim=-1)
     x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, "... d r -> ... (d r)")
 
+
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
     dtype = t.dtype
@@ -73,9 +92,7 @@ class RotaryEmbedding(Module):
         self,
         dim,
         custom_freqs: Optional[Tensor] = None,
-        freqs_for: Union[
-            Literal["lang"], Literal["pixel"], Literal["constant"]
-        ] = "lang",
+        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
         theta=10000,
         max_freq=10,
         num_freqs=1,
@@ -99,9 +116,7 @@ class RotaryEmbedding(Module):
         if exists(custom_freqs):
             freqs = custom_freqs
         elif freqs_for == "lang":
-            freqs = 1.0 / (
-                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
-            )
+            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
         elif freqs_for == "pixel":
             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
         elif freqs_for == "constant":
@@ -154,9 +169,7 @@ class RotaryEmbedding(Module):
         self.register_buffer(key, value, persistent=False)
 
     def get_seq_pos(self, seq_len, device, dtype, offset=0):
-        return (
-            torch.arange(seq_len, device=device, dtype=dtype) + offset
-        ) / self.interpolate_factor
+        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
 
     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         seq_dim = default(seq_dim, self.default_seq_dim)
@@ -184,9 +197,7 @@ class RotaryEmbedding(Module):
         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
         assert q_len <= k_len
 
-        rotated_q = self.rotate_queries_or_keys(
-            q, seq_dim=seq_dim, offset=k_len - q_len + offset
-        )
+        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)
 
         rotated_q = rotated_q.type(q.dtype)
@@ -222,11 +233,7 @@ class RotaryEmbedding(Module):
 
         should_cache = self.cache_if_possible and exists(seq_len)
 
-        if (
-            should_cache
-            and exists(self.cached_scales)
-            and (seq_len + offset) <= self.cached_scales.shape[0]
-        ):
+        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
             return self.cached_scales[offset : (offset + seq_len)]
 
         scale = 1.0
@@ -264,17 +271,10 @@ class RotaryEmbedding(Module):
     @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
         should_cache = (
-            self.cache_if_possible
-            and not self.learned_freq
-            and exists(seq_len)
-            and self.freqs_for != "pixel"
+            self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
         )
 
-        if (
-            should_cache
-            and exists(self.cached_freqs)
-            and (offset + seq_len) <= self.cached_freqs.shape[0]
-        ):
+        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
             return self.cached_freqs[offset : (offset + seq_len)].detach()
 
         freqs = self.freqs
@@ -289,7 +289,7 @@ class RotaryEmbedding(Module):
 
 
 class Rope2D:
-    """ Helper class to apply RoPE2D as well as interpolate on the fly. """
+    """Helper class to apply RoPE2D as well as interpolate on the fly."""
 
     def __init__(self, dim, use_cls_token=False):
         self.dim = dim
@@ -313,13 +313,11 @@ class Rope2D:
         grid_y_range = torch.arange(grid_h, device=device)
         grid_x_range = torch.arange(grid_w, device=device)
         freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
-        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1) 
+        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
 
         if self.use_cls_token:
-            freq = torch.cat(
-                [freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0
-            )
+            freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)
 
         self.freq = freq[None, ...]
@@ -332,8 +330,8 @@ class Rope2D:
         return q, k
 
 
-####### PE's Modules ########
 
+######## PE Modules ########
 class AttentionPooling(nn.Module):
     def __init__(
         self,
@@ -349,14 +347,10 @@ class AttentionPooling(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
 
-        assert (
-            self.embed_dim % num_heads == 0
-        ), "embed_dim must be divisible by num_heads"
+        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
         self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
-        self.attn = nn.MultiheadAttention(
-            self.embed_dim, self.num_heads, batch_first=True
-        )
+        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)
 
         self.layernorm = norm_layer(embed_dim)
         self.mlp_width = int(embed_dim * mlp_ratio)
@@ -396,9 +390,7 @@ class SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = embed_dim // num_heads
 
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
 
         # To make this compatibile with nn.MultiHeadAttention
         self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
@@ -418,13 +410,7 @@ class SelfAttention(nn.Module):
         proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)
 
         # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
-        proj = (
-            proj.unflatten(-1, (3, embed_dim))
-            .unsqueeze(0)
-            .transpose(0, -2)
-            .squeeze(-2)
-            .contiguous()
-        )
+        proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
         q, k, v = proj[0], proj[1], proj[2]
 
         # Use "q_" so that we don't accidentally quit in pdb :)
@@ -435,9 +421,7 @@ class SelfAttention(nn.Module):
         if self.rope:
             q, k = self.rope(q, k)
 
-        attn = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale
-        )
+        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
         attn = rearrange(attn, "b h s d -> b s (h d)")
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
@@ -462,16 +446,8 @@ class ResidualAttentionBlock(nn.Module):
         else:
             self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
 
-        self.ls_1 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
-        self.ls_2 = (
-            LayerScale(d_model, ls_init_value)
-            if ls_init_value is not None
-            else nn.Identity()
-        )
+        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
+        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
 
         self.ln_1 = norm_layer(d_model)
         self.ln_2 = norm_layer(d_model)
@@ -511,9 +487,7 @@ class ResidualAttentionBlock(nn.Module):
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ):
-        x = x + self.drop_path1(
-            self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask))
-        )
+        x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
         return x
 
@@ -558,9 +532,9 @@ class Transformer(nn.Module):
 
     @torch.jit.ignore
     def truncate(self, layer_idx: int):
-        """ Delete layers so the last layer is the given layer index. """
+        """Delete layers so the last layer is the given layer index."""
         self.layers = ((self.layers + layer_idx) % self.layers) + 1
-        self.resblocks = nn.ModuleList(self.resblocks[:self.layers])
+        self.resblocks = nn.ModuleList(self.resblocks[: self.layers])
 
     def forward(
         self,
@@ -576,7 +550,7 @@ class Transformer(nn.Module):
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-            
+
             if i == stop_idx:
                 break
 
@@ -604,7 +578,7 @@ class PE(nn.Module):
         output_dim: Optional[int] = 1280,
         attn_pooler_heads: int = 8,
         pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
-        num_classes: int = 1000, # no use for now
+        num_classes: int = 1000,  # no use for now
         in_chans: int = 3,
     ):
         super().__init__()
@@ -666,12 +640,11 @@ class PE(nn.Module):
 
         self.init_tensors()
 
-
     def init_tensors(self):
         def init_submodule_tensors(module):
             for name, child in module.named_children():
                 if hasattr(child, "init_tensors"):
-                    #logger.debug(f"Initializing tensors for submodule: {name}")
+                    # logger.debug(f"Initializing tensors for submodule: {name}")
                     child.init_tensors()
                     init_submodule_tensors(child)
 
@@ -687,19 +660,14 @@ class PE(nn.Module):
         if self.use_abs_posemb:
             self.posemb_grid_size = self.image_size // self.patch_size
             self.positional_embedding = nn.Parameter(
-                init_scale
-                * torch.randn(
-                    int(self.use_cls_token) + self.posemb_grid_size**2, self.width
-                )
+                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
             )
 
         if self.proj_dim is not None:
-            self.proj = nn.Parameter(
-                init_scale * torch.randn(self.width, self.proj_dim)
-            )
+            self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))
 
     def truncate(self, layer_idx: int):
-        """ Delete layers so the last layer is the given layer index. """
+        """Delete layers so the last layer is the given layer index."""
         self.transformer.truncate(layer_idx)
         self.layers = self.transformer.layers
 
@@ -717,13 +685,9 @@ class PE(nn.Module):
             cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]
 
         pos_embed = (
-            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1)
-            .permute(0, 3, 1, 2)
-            .contiguous()
-        )
-        pos_embed = F.interpolate(
-            pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False
+            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
         )
+        pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
         pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()
 
         if self.use_cls_token:
@@ -743,13 +707,7 @@ class PE(nn.Module):
         else:
             raise NotImplementedError
 
-    def forward_features(
-        self,
-        x: torch.Tensor,
-        norm: bool = False,
-        layer_idx: int = -1,
-        strip_cls_token: bool = False
-    ):
+    def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
         batch, _, h, w = x.shape
         grid_h, grid_w = h // self.patch_size, w // self.patch_size
@@ -790,14 +748,8 @@ class PE(nn.Module):
 
 
 def checkpoint_filter_fn(
-    state_dict: Dict[str, torch.Tensor],
-    model: PE,
-    adapt_layer_scale: bool = False,
-    interpolation: str = 'bicubic',
-    antialias: bool = True,
+    state_dict: Dict[str, torch.Tensor],
 ) -> Dict[str, torch.Tensor]:
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    import re
     state_dict = state_dict.get('model', state_dict)
     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
     if any(k.startswith("visual.") for k in state_dict):
@@ -805,42 +757,33 @@ def checkpoint_filter_fn(
     return state_dict
 
 
+######## PE Config ########
 def _cfg(url='', **kwargs):
     return {
         'license': 'apache-2.0',
         'num_classes': 0,
         'interpolation': 'bilinear',
         'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 
+        'mean': IMAGENET_INCEPTION_MEAN,
         'std': IMAGENET_INCEPTION_STD,
-        **kwargs
+        **kwargs,
     }
 
-default_cfgs = generate_default_cfgs({
-    'pe_core_b16_224': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 224, 224)),
-    'pe_core_l14_336': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 336, 336)),
-    'pe_core_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_l14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_lang_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-    'pe_spatial_G14_448': _cfg(
-        hf_hub_id='timm/',
-        input_size=(3, 448, 448)),
-})
+
+default_cfgs = generate_default_cfgs(
+    {
+        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
+        'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
+        'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+        'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
+    }
+)
 
 
 def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
     out_indices = kwargs.pop('out_indices', 3)
-
     return build_model_with_cfg(
         PE,
         variant,
@@ -851,99 +794,104 @@ def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
         **kwargs,
     )
 
-@register_model
-def pe_core_b16_224(pretrained=False, **kwargs):
-    model_args = dict(
-        image_size = 224,
-        patch_size = 16,
-        width = 768,
-        layers = 12,
-        heads = 12,
-        mlp_ratio = 4.0,
-        output_dim = 1024,
-        use_cls_token = True,
-        pool_type = 'attn',
-    )
-    return _create_pe('pe_core_b16_224', pretrained=pretrained, **dict(model_args, **kwargs))
 
 @register_model
-def pe_core_l14_336(pretrained=False, **kwargs):
+def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 336,
-        patch_size = 14,
-        width = 1024,
-        layers = 24,
-        heads = 16,
-        mlp_ratio = 4.0,
-        output_dim = 1024,
-        use_cls_token = True,
-        pool_type = 'attn',
+        image_size=224,
+        patch_size=16,
+        width=768,
+        layers=12,
+        heads=12,
+        mlp_ratio=4.0,
+        output_dim=1024,
+        use_cls_token=True,
+        pool_type='attn',
     )
-    return _create_pe('pe_core_l14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
 
 @register_model
-def pe_core_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 50,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = 1280,
-        use_cls_token = False,
-        pool_type = 'attn',
+        image_size=336,
+        patch_size=14,
+        width=1024,
+        layers=24,
+        heads=16,
+        mlp_ratio=4.0,
+        output_dim=1024,
+        use_cls_token=True,
+        pool_type='attn',
    )
-    return _create_pe('pe_core_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
+
 
 @register_model
-def pe_lang_G14_448(pretrained=False, **kwargs):
+def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 47,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=50,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=1280,
+        use_cls_token=False,
+        pool_type='attn',
     )
-    return _create_pe('pe_lang_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
 
 @register_model
-def pe_lang_l14_448(pretrained=False, **kwargs):
+def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1024,
-        layers = 23,
-        heads = 16,
-        mlp_ratio = 4.0,
-        output_dim = None,
-        use_cls_token = True,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1024,
+        layers=23,
+        heads=16,
+        mlp_ratio=4.0,
+        output_dim=None,
+        use_cls_token=True,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
     )
-    return _create_pe('pe_lang_l14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
 
 @register_model
-def pe_spatial_G14_448(pretrained=False, **kwargs):
+def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
     model_args = dict(
-        image_size = 448,
-        patch_size = 14,
-        width = 1536,
-        layers = 50,
-        heads = 16,
-        mlp_ratio = 8960 / 1536,
-        output_dim = None,
-        use_cls_token = False,
-        use_ln_post = False,
-        pool_type = 'none',
-        ls_init_value = 0.1,
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=47,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
     )
-    return _create_pe('pe_spatial_G14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
+    model_args = dict(
+        image_size=448,
+        patch_size=14,
+        width=1536,
+        layers=50,
+        heads=16,
+        mlp_ratio=8960 / 1536,
+        output_dim=None,
+        use_cls_token=False,
+        use_ln_post=False,
+        pool_type='none',
+        ls_init_value=0.1,
    )
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
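
Reviewer note (editor's addition, not part of the patch): the snippet below is a minimal smoke
test for the renamed registrations. It assumes this patch is applied to a timm source tree that
includes timm/models/pe.py; pretrained=False builds the model without fetching the hf_hub_id='timm/'
checkpoints, which are assumed to be published separately.

    import torch
    import timm

    # Renamed registry entries from this patch, e.g. pe_core_b16_224 is now
    # vit_pe_core_base_patch16_224 (the other five variants follow the same pattern).
    model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
    model.eval()

    # This variant is registered for 224x224 inputs (patch_size=16, width=768, output_dim=1024).
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = model(x)

    # With pool_type='attn' and output_dim=1024 the pooled, projected output is expected to
    # have shape (1, 1024); the exact shape depends on PE's pooling/projection path.
    print(out.shape)

The new names follow timm's vit_<family>_<size>_patch<P>_<res> convention, which keeps the PE
variants discoverable alongside the existing ViT model registrations.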