Merge c9c973b3b5 into c8c4f256b8
commit 8e7c4a5bf1
timm/models
@@ -107,6 +107,82 @@ class Attention(nn.Module):
        return x


class DiffAttention(nn.Module):
    """Differential attention: each head computes two softmax attention maps whose
    difference, weighted by a learnable re-parameterized scalar lambda, forms the
    output (DIFF Transformer style)."""
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RmsNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # each head is split into two halves, so the effective head dim is halved
        self.head_dim = dim // num_heads // 2
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # lambda is re-parameterized as exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        self.lambda_init = 0.8
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

        # per-head RMSNorm applied to the differential attention output
        self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5)

    def _set_lambda_init(self, depth: int):
        # depth-dependent lambda_init, increasing from ~0.2 toward 0.8 for deeper blocks
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=2)
        # q, k use 2 * num_heads half-width heads; v keeps num_heads full-width heads
        q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
            k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
            q1, q2 = q.unbind(2)
            k1, k2 = k.unbind(2)
            attn1 = F.scaled_dot_product_attention(
                q1, k1, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
            attn2 = F.scaled_dot_product_attention(
                q2, k2, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            x = attn1 - lambda_full * attn2
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
            lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
            lambda_full = lambda_1 - lambda_2 + self.lambda_init
            # subtract the second (noise-cancelling) attention map from the first
            attn = attn.view(B, self.num_heads, 2, N, N)
            attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
            x = attn @ v

        x = self.sub_norm(x)
        x = x * (1 - self.lambda_init)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
            self,
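
Below is a minimal, hypothetical usage sketch of the DiffAttention module from this diff. It only illustrates the expected tensor shapes; the embedding dim, head count, and block depth are arbitrary choices, and it assumes DiffAttention (plus the existing timm helpers RmsNorm and use_fused_attn) is importable from the patched module.

import torch

# assumes the DiffAttention class above is in scope; its exact module path is not shown in this diff
attn = DiffAttention(dim=384, num_heads=6, qkv_bias=True)
attn._set_lambda_init(depth=4)   # per-block lambda_init; block index 4 is illustrative only
attn.eval()

x = torch.randn(2, 197, 384)     # (batch, tokens, dim), e.g. a ViT-S/16 token sequence with cls token
with torch.no_grad():
    out = attn(x)
print(out.shape)                 # torch.Size([2, 197, 384]) -- differential attention preserves shape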