From 8eafe2c21e700646ffb04c566c26854f82a331ad Mon Sep 17 00:00:00 2001
From: berniebear
Date: Fri, 25 Apr 2025 06:50:29 +0000
Subject: [PATCH] update pe to reuse timm layers

---
 timm/models/pe.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/timm/models/pe.py b/timm/models/pe.py
index a102f83e..147cecee 100644
--- a/timm/models/pe.py
+++ b/timm/models/pe.py
@@ -16,7 +16,8 @@ from torch.utils.checkpoint import checkpoint
 ### Import timm layers
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType, LayerScale
+    get_act_layer, get_norm_layer, LayerType, LayerScale
+#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -302,7 +303,6 @@ class Rope2D:
         self.grid_size = (grid_h, grid_w)
         self.rope = self.rope.to(device)
 
-
         if self.use_cls_token:
             # +1 to leave space for the cls token to be (0, 0)
             grid_y_range = torch.arange(grid_h, device=device) + 1
@@ -310,9 +310,9 @@ class Rope2D:
         else:
             grid_y_range = torch.arange(grid_h, device=device)
             grid_x_range = torch.arange(grid_w, device=device)
-        freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
-        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
+        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
+        freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
 
         freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
 
         if self.use_cls_token:
@@ -581,7 +581,6 @@ class Transformer(nn.Module):
         return x
 
 
-#class VisionTransformer(nn.Module):
 class PE(nn.Module):
     def __init__(
         self,