from typing import Final, Optional, Type

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .pos_embed_sincos import apply_rot_embed_cat


class Attention(nn.Module):
    """Standard Multi-head Self Attention module with QKV projection.

    This module implements the standard multi-head attention mechanism used in transformers.
    It supports both the fused attention implementation (scaled_dot_product_attention) for
    efficiency when available, and a manual implementation otherwise. The module includes
    options for QK normalization, attention dropout, and projection dropout.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to use bias in the query, key, value projections
            qk_norm: Whether to apply normalization to query and key vectors
            proj_bias: Whether to use bias in the output projection
            attn_drop: Dropout rate applied to the attention weights
            proj_drop: Dropout rate applied after the output projection
            norm_layer: Normalization layer constructor for QK normalization if enabled
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply multi-head self attention to a batch of token sequences.

        Args:
            x: Input tensor of shape (B, N, C)
            attn_mask: Optional mask. Passed unchanged to
                scaled_dot_product_attention on the fused path and added to the
                attention scores on the manual path.

        Returns:
            Tensor of shape (B, N, C)
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                # The functional op has no train/eval state, so gate dropout here.
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # (B, num_heads, N, head_dim) -> (B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionRope(nn.Module):
    """ A Self Attention module with ROPE support.

    Includes options for:
     * QK normalization option
     * Attention output (scale) normalization
     * Fused or unfused QKV projection support
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            num_prefix_tokens: int = 1,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
            qk_norm: bool = False,
            scale_norm: bool = False,
    ):
        """Initialize the Attention module.

        Args:
            dim: Input dimension of the token embeddings
            num_heads: Number of attention heads
            qkv_bias: Whether to add a bias term to the query, key, and value projections
            qkv_fused: Use a single fused linear layer for Q, K and V (True) or three
                separate q_proj / k_proj / v_proj layers (False)
            num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
                should not have position embeddings applied
            attn_drop: Dropout rate for attention weights
            proj_drop: Dropout rate for the output projection
            attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
            norm_layer: Normalization layer constructor to use for QK and scale normalization
            qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
            scale_norm: Enable normalization (scaling) of attention output with norm_layer
        """
        super().__init__()
        if scale_norm or qk_norm:
            assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        attn_dim = head_dim * self.num_heads
        # Keep the resolved sizes on the module: the QK norms below and the
        # output reshape in forward() need them, and attn_dim may differ from
        # dim when attn_head_dim is given.
        self.head_dim = head_dim
        self.attn_dim = attn_dim
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()

        if qkv_fused:
            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
            self.q_proj = self.k_proj = self.v_proj = None
        else:
            self.qkv = None
            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)

        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
        self.proj = nn.Linear(attn_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x: torch.Tensor,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass for the attention module.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
            rope: Rotary position embeddings tensor for position-aware attention; applied to
                Q and K for all tokens after the first ``num_prefix_tokens``
            attn_mask: Optional attention mask to apply during attention computation

        Returns:
            Tensor of shape (batch_size, sequence_length, embedding_dim)
        """
        B, N, C = x.shape

        if self.qkv is not None:
            qkv = self.qkv(x)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, head_dim
            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

        q, k = self.q_norm(q), self.k_norm(k)

        if rope is not None:
            # Prefix (cls/reg) tokens are passed through untouched; only the
            # remaining tokens are rotated.
            npt = self.num_prefix_tokens
            q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v)
            k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                # The functional op has no train/eval state, so gate dropout here.
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if attn_mask is not None:
                # NOTE(review): this path casts the mask to bool and indexes it
                # as [:, None, None, :], i.e. it is treated as a 2-D per-key
                # mask where nonzero means "keep". The fused path hands the mask
                # to scaled_dot_product_attention unchanged, so the two paths do
                # not appear to accept the same mask layout -- confirm callers.
                attn_mask = attn_mask.to(torch.bool)
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
            attn = attn.softmax(dim=-1)

            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads. The merged width is attn_dim, which only equals the
        # input width C when attn_head_dim was left unset.
        x = x.transpose(1, 2).reshape(B, N, self.attn_dim)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x