Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add option to include relative pos embedding in the attention scaling as per references. See discussion #912
parent 2c33ca6d8c
commit 02daf2ab94
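The diff below adds a scale_pos_embed option to the bottleneck and halo self-attention layers: when enabled, the relative position embedding is added to the raw q @ k logits before the 1/sqrt(dim_head) scaling, instead of after it. A minimal standalone sketch of the two orderings in plain PyTorch (illustrative helper name and tensor shapes, not part of the commit):

import torch

def attn_logits(q, k, rel_pos_bias, scale, scale_pos_embed=False):
    # q, k: (batch * heads, seq, dim_head); rel_pos_bias: (batch * heads, seq, seq)
    if scale_pos_embed:
        # new option: the position term shares the 1/sqrt(dim_head) scaling with q @ k
        return (q @ k.transpose(-1, -2) + rel_pos_bias) * scale
    # previous behaviour (still the default): only q @ k is scaled, bias added afterwards
    return (q @ k.transpose(-1, -2)) * scale + rel_pos_bias

q = torch.randn(8, 64, 32)
k = torch.randn(8, 64, 32)
bias = torch.randn(8, 64, 64)
logits = attn_logits(q, k, bias, scale=32 ** -0.5, scale_pos_embed=True)  # (8, 64, 64)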
@@ -61,9 +61,8 @@ class PosEmbedRel(nn.Module):
         super().__init__()
         self.height, self.width = to_2tuple(feat_size)
         self.dim_head = dim_head
-        self.scale = scale
-        self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
-        self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
+        self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
+        self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)
 
     def forward(self, q):
         B, HW, _ = q.shape
@@ -101,10 +100,11 @@ class BottleneckAttn(nn.Module):
         dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
         qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
         qkv_bias (bool): add bias to q, k, and v projections
+        scale_pos_embed (bool): scale the position embedding as well as Q @ K
     """
     def __init__(
             self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
-            qk_ratio=1.0, qkv_bias=False):
+            qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False):
         super().__init__()
         assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
         dim_out = dim_out or dim
@@ -115,6 +115,7 @@ class BottleneckAttn(nn.Module):
         self.dim_out_qk = num_heads * self.dim_head_qk
         self.dim_out_v = num_heads * self.dim_head_v
         self.scale = self.dim_head_qk ** -0.5
+        self.scale_pos_embed = scale_pos_embed
 
         self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
 
@@ -144,8 +145,10 @@ class BottleneckAttn(nn.Module):
         k = k.reshape(B * self.num_heads, self.dim_head_qk, -1)  # no transpose, for q @ k
         v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
 
-        attn = (q @ k) * self.scale
-        attn = attn + self.pos_embed(q)  # B * num_heads, H * W, H * W
+        if self.scale_pos_embed:
+            attn = (q @ k + self.pos_embed(q)) * self.scale  # B * num_heads, H * W, H * W
+        else:
+            attn = (q @ k) * self.scale + self.pos_embed(q)
         attn = attn.softmax(dim=-1)
 
         out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
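The hunks above modify the bottleneck attention layer; the hunks that follow apply the same change to the halo attention layer. A hedged usage sketch of the new constructor argument, assuming BottleneckAttn is importable from timm.models.layers as it was around this commit:

import torch
from timm.models.layers import BottleneckAttn

# feat_size is required so the relative position tables can be sized to (H, W)
attn = BottleneckAttn(dim=256, feat_size=(16, 16), num_heads=4, scale_pos_embed=True)
x = torch.randn(2, 256, 16, 16)
y = attn(x)  # (2, 256, 16, 16): dim_out defaults to dim and stride=1 keeps the spatial size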
@@ -74,9 +74,8 @@ class PosEmbedRel(nn.Module):
         super().__init__()
         self.block_size = block_size
         self.dim_head = dim_head
-        self.scale = scale
-        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
-        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
+        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
+        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
 
     def forward(self, q):
         B, BB, HW, _ = q.shape
@@ -120,11 +119,11 @@ class HaloAttn(nn.Module):
         qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
         qkv_bias (bool) : add bias to q, k, and v projections
         avg_down (bool): use average pool downsample instead of strided query blocks
-
+        scale_pos_embed (bool): scale the position embedding as well as Q @ K
     """
     def __init__(
             self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
-            qk_ratio=1.0, qkv_bias=False, avg_down=False):
+            qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
         super().__init__()
         dim_out = dim_out or dim
         assert dim_out % num_heads == 0
@@ -135,6 +134,7 @@ class HaloAttn(nn.Module):
         self.dim_out_qk = num_heads * self.dim_head_qk
         self.dim_out_v = num_heads * self.dim_head_v
         self.scale = self.dim_head_qk ** -0.5
+        self.scale_pos_embed = scale_pos_embed
         self.block_size = self.block_size_ds = block_size
         self.halo_size = halo_size
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
@@ -190,8 +190,11 @@ class HaloAttn(nn.Module):
         k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
         # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
 
-        attn = (q @ k.transpose(-1, -2)) * self.scale
-        attn = attn + self.pos_embed(q)  # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
+        if self.scale_pos_embed:
+            attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
+        else:
+            attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
+        # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
         attn = attn.softmax(dim=-1)
 
         out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
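A corresponding sketch for the halo layer, again assuming HaloAttn is importable from timm.models.layers; block_size and halo_size follow the defaults shown in the signature above, and the input spatial size must be divisible by block_size:

import torch
from timm.models.layers import HaloAttn

attn = HaloAttn(dim=256, num_heads=8, block_size=8, halo_size=3, scale_pos_embed=True)
x = torch.randn(2, 256, 16, 16)  # 16 is divisible by block_size=8
y = attn(x)  # (2, 256, 16, 16) with stride=1 and dim_out defaulting to dim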