update pe to reuse timm layers

berniebear 2025-04-25 06:50:29 +00:00
parent 0f6e29019c
commit 8eafe2c21e


@@ -16,7 +16,8 @@ from torch.utils.checkpoint import checkpoint
### Import timm layers
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType, LayerScale
#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
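For context on this hunk: the point of the change is that PE composes stock timm layers instead of carrying its own copies. Below is a minimal sketch of how two of the newly imported layers fit together; the width and token count are illustrative and not taken from any PE config.

```python
import torch
from timm.layers import Mlp, LayerScale

dim = 768  # illustrative width, not a PE config value
mlp = Mlp(in_features=dim, hidden_features=dim * 4)
ls = LayerScale(dim, init_values=1e-5)

x = torch.randn(1, 196, dim)  # (batch, tokens, dim)
out = ls(mlp(x))              # timm Mlp followed by timm LayerScale
print(out.shape)              # torch.Size([1, 196, 768])
```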
@@ -302,7 +303,6 @@ class Rope2D:
        self.grid_size = (grid_h, grid_w)
        self.rope = self.rope.to(device)

        if self.use_cls_token:
            # +1 to leave space for the cls token to be (0, 0)
            grid_y_range = torch.arange(grid_h, device=device) + 1
@@ -310,9 +310,8 @@ class Rope2D:
        else:
            grid_y_range = torch.arange(grid_h, device=device)
            grid_x_range = torch.arange(grid_w, device=device)

        freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
        freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
        freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

        if self.use_cls_token:
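The two Rope2D hunks above touch the axial frequency-grid construction. Here is a self-contained sketch of that logic; `rope` is a toy stand-in for the class's 1-D rotary embedding (the real module computes sin/cos frequency bands), and `build_2d_freq_grid` is a hypothetical helper, not a function in the file.

```python
import torch

def build_2d_freq_grid(rope, grid_h, grid_w, use_cls_token=False, device=None):
    # With a cls token, shift the grid by +1 so position (0, 0) stays
    # reserved for it, matching the comment in the diff above.
    offset = 1 if use_cls_token else 0
    grid_y = torch.arange(grid_h, device=device) + offset
    grid_x = torch.arange(grid_w, device=device) + offset

    # Broadcast each axis' frequencies across the full grid, then pair them.
    freqs_y = rope(grid_y)[:, None].expand(grid_h, grid_w, -1)
    freqs_x = rope(grid_x)[None, :].expand(grid_h, grid_w, -1)
    return torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

# Toy frequency function: position -> 8 linearly spaced bands.
rope = lambda pos: pos[:, None].float() * torch.linspace(0.1, 1.0, 8)
print(build_2d_freq_grid(rope, 3, 4).shape)  # torch.Size([12, 16])
```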
@@ -581,7 +580,6 @@ class Transformer(nn.Module):
        return x

#class VisionTransformer(nn.Module):
class PE(nn.Module):
    def __init__(
            self,