mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
update pe to reuse timm layers
This commit is contained in:
parent
0f6e29019c
commit
8eafe2c21e
@ -16,7 +16,8 @@ from torch.utils.checkpoint import checkpoint
|
||||
### Import timm layers
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \
|
||||
trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||
get_act_layer, get_norm_layer, LayerType, LayerScale
|
||||
get_act_layer, get_norm_layer, LayerType, LayerScale
|
||||
#from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features import feature_take_indices
|
||||
@ -302,7 +303,6 @@ class Rope2D:
|
||||
self.grid_size = (grid_h, grid_w)
|
||||
|
||||
self.rope = self.rope.to(device)
|
||||
|
||||
if self.use_cls_token:
|
||||
# +1 to leave space for the cls token to be (0, 0)
|
||||
grid_y_range = torch.arange(grid_h, device=device) + 1
|
||||
@ -310,9 +310,8 @@ class Rope2D:
|
||||
else:
|
||||
grid_y_range = torch.arange(grid_h, device=device)
|
||||
grid_x_range = torch.arange(grid_w, device=device)
|
||||
|
||||
freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
|
||||
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
|
||||
freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
|
||||
freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)
|
||||
|
||||
if self.use_cls_token:
|
||||
@ -581,7 +580,6 @@ class Transformer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
#class VisionTransformer(nn.Module):
|
||||
class PE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user