mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
update pe to reuse timm layers
This commit is contained in:
parent
0f6e29019c
commit
8eafe2c21e
@ -17,6 +17,7 @@ from torch.utils.checkpoint import checkpoint
|
|||||||
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
|
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
|
||||||
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||||
get_act_layer, get_norm_layer, LayerType, LayerScale
|
get_act_layer, get_norm_layer, LayerType, LayerScale
|
||||||
|
#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
|
||||||
|
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features import feature_take_indices
|
from ._features import feature_take_indices
|
||||||
@ -302,7 +303,6 @@ class Rope2D:
|
|||||||
self.grid_size = (grid_h, grid_w)
|
self.grid_size = (grid_h, grid_w)
|
||||||
|
|
||||||
self.rope = self.rope.to(device)
|
self.rope = self.rope.to(device)
|
||||||
|
|
||||||
if self.use_cls_token:
|
if self.use_cls_token:
|
||||||
# +1 to leave space for the cls token to be (0, 0)
|
# +1 to leave space for the cls token to be (0, 0)
|
||||||
grid_y_range = torch.arange(grid_h, device=device) + 1
|
grid_y_range = torch.arange(grid_h, device=device) + 1
|
||||||
@ -310,7 +310,6 @@ class Rope2D:
|
|||||||
else:
|
else:
|
||||||
grid_y_range = torch.arange(grid_h, device=device)
|
grid_y_range = torch.arange(grid_h, device=device)
|
||||||
grid_x_range = torch.arange(grid_w, device=device)
|
grid_x_range = torch.arange(grid_w, device=device)
|
||||||
|
|
||||||
freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
|
freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
|
||||||
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
|
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
|
||||||
freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
|
freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
|
||||||
@ -581,7 +580,6 @@ class Transformer(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
#class VisionTransformer(nn.Module):
|
|
||||||
class PE(nn.Module):
|
class PE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user