from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_
|
|
|
|
|
|
class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query

    Cross-attends a small learned latent query (length ``latent_len``) against
    an input token sequence, applies a residual MLP block, then optionally
    pools the latent outputs to a single vector ('token' or 'avg').
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[Type[nn.Module]] = None,
            act_layer: Optional[Type[nn.Module]] = nn.GELU,
            drop: float = 0.0,
    ):
        """
        Args:
            in_features: Input token feature dimension. NOTE(review): the kv
                projection takes embed_dim-wide input, so in practice this is
                expected to equal embed_dim — confirm against callers.
            out_features: Sizes the pre-MLP norm; defaults to in_features.
                NOTE(review): that norm is applied to an embed_dim-wide tensor,
                so this is assumed == embed_dim — confirm.
            embed_dim: Attention embedding dimension; defaults to in_features.
            num_heads: Number of attention heads; must divide embed_dim.
            feat_size: Input sequence length; required when pos_embed == 'abs'.
            mlp_ratio: Hidden-dim multiplier for the MLP block.
            qkv_bias: Add bias terms to the q and kv projections.
            qk_norm: Apply per-head norm to q and k.
            latent_len: Number of learned latent query tokens.
            latent_dim: Only scales the latent init std; defaults to embed_dim.
            pos_embed: '' for none, 'abs' for a learned absolute pos embed.
            pool_type: 'token' (first latent), 'avg' (mean over latents),
                anything else returns all latents unpooled.
            norm_layer: Norm layer class (e.g. nn.LayerNorm); also used for
                q/k norm when qk_norm is set.
            act_layer: Activation class for the MLP.
            drop: Dropout rate after the output projection.
        """
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.feat_size = feat_size
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            # Learned absolute position embedding over the (fixed) input length.
            assert feat_size is not None
            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        # q comes from the latent tokens, k/v from the input sequence.
        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        if qk_norm:
            qk_norm_layer = norm_layer or nn.LayerNorm
            self.q_norm = qk_norm_layer(self.head_dim)
            self.k_norm = qk_norm_layer(self.head_dim)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)

        self.init_weights()

    def init_weights(self):
        """Truncated-normal init for the pos embed and latent query."""
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        """Pool a (B, N, C) token sequence via latent cross-attention.

        Returns (B, embed_dim) when pool_type is 'token' or 'avg', otherwise
        (B, latent_len, embed_dim).
        """
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        # Merge heads. The attention output width is embed_dim
        # (= num_heads * head_dim), which need not equal the input width C,
        # so flatten with -1 rather than reshaping to C.
        x = x.transpose(1, 2).reshape(B, self.latent_len, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x