add more pooling mechanisms
parent cd8d9d9ff3
commit 82f75f4cbb
@@ -0,0 +1,45 @@

import torch
import torch.nn as nn


class GeneralizedMP(nn.Module):
    """
    Implements Generalized Max Pooling (GMP), a global pooling operation that
    generalizes the concept of max pooling to capture more complex and discriminative
    features from the input tensor.

    The class operates by computing a linear kernel from the input tensor and then
    solving a linear system to obtain the pooling coefficients. These coefficients
    are used to weight and aggregate the input features, resulting in a pooled feature vector.

    Parameters:
        lamb (float, optional): A regularization parameter used in the linear system
            to ensure numerical stability. Default value is 1e3.

    Note:
        - The input tensor is expected to be in the format (B, D, H, W), where B is batch size,
          D is depth (channels), H is height, and W is width.
        - The implementation uses PyTorch's linear algebra functions to solve the linear system.
    """
    def __init__(self, lamb=1e3):
        super().__init__()
        self.lamb = nn.Parameter(lamb * torch.ones(1))

    def forward(self, x):
        B, D, H, W = x.shape
        N = H * W
        # Build the identity on the same device/dtype as the input instead of hard-coding CUDA.
        identity = torch.eye(N, device=x.device, dtype=x.dtype)
        # Reshape x so that the GMP formulation can be used as a global pooling operation.
        x = x.view(B, D, N)
        x = x.permute(0, 2, 1)
        # Compute the linear kernel.
        K = torch.bmm(x, x.permute(0, 2, 1))
        # Solve the linear system (K + lambda * I) * alpha = ones.
        A = K + self.lamb * identity
        o = torch.ones(B, N, 1, device=x.device, dtype=x.dtype)
        alphas = torch.linalg.solve(A, o)  # replaces the deprecated torch.gesv(o, A)
        alphas = alphas.view(B, 1, -1)
        xi = torch.bmm(alphas, x)
        xi = xi.view(B, -1)
        return xi
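A minimal usage sketch for GeneralizedMP, assuming a 512-channel, 7x7 CNN feature map (the shapes are illustrative, not part of the commit):

feats = torch.randn(2, 512, 7, 7)   # (B, D, H, W) backbone feature map
gmp = GeneralizedMP(lamb=1e3)
pooled = gmp(feats)                 # (B, D) pooled descriptor
print(pooled.shape)                 # torch.Size([2, 512])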
@@ -0,0 +1,97 @@

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class HOWPooling(nn.Module):
    """
    Implements HOW, as described in the paper
    'Learning and Aggregating Deep Local Descriptors for Instance-Level Recognition'.
    This pooling method focuses on aggregating deep local descriptors
    for enhanced instance-level recognition.

    The class includes functions for L2-based attention, smoothing average pooling,
    L2 normalization (l2n), and a forward method that integrates these components.
    It applies dimensionality reduction to the input features before the pooling operation.

    Parameters:
        input_dim (int): Dimension of the input features.
        dim_reduction (int): Target dimension after reduction.
        kernel_size (int): Size of the kernel used in smoothing average pooling.
    """
    def __init__(self, input_dim=512, dim_reduction=128, kernel_size=3):
        super(HOWPooling, self).__init__()
        self.kernel_size = kernel_size
        self.dimreduction = ConvDimReduction(input_dim, dim_reduction)

    def L2Attention(self, x):
        # Per-location attention weights given by the L2 norm over channels.
        return (x.pow(2.0).sum(1) + 1e-10).sqrt().squeeze(0)

    def smoothing_avg_pooling(self, feats):
        """Smoothing average pooling

        :param torch.Tensor feats: Feature map
        :return torch.Tensor: Smoothed feature map
        """
        pad = self.kernel_size // 2
        return F.avg_pool2d(feats, (self.kernel_size, self.kernel_size), stride=1, padding=pad,
                            count_include_pad=False)

    def l2n(self, x, eps=1e-6):
        return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)

    def forward(self, x):
        weights = self.L2Attention(x)
        x = self.smoothing_avg_pooling(x)
        x = self.dimreduction(x)
        x = (x * weights.unsqueeze(1)).sum((-2, -1))
        return self.l2n(x)


class ConvDimReduction(nn.Conv2d):
    """
    Implements dimensionality reduction using a 1x1 convolutional layer. This layer is
    designed for reducing the dimensions of input features, particularly for use in
    aggregation and pooling operations like in the HOWPooling class.

    The class also includes methods for learning and applying PCA whitening with shrinkage,
    a technique that reduces dimensionality while preserving important feature variations.

    Parameters:
        input_dim (int): The input dimension (number of channels) of the network.
        dim (int): The target output dimension for the whitening process.
    """
    def __init__(self, input_dim, dim):
        super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)

    @staticmethod
    def pcawhitenlearn_shrinkage(X, s=1.0):
        """Learn PCA whitening with shrinkage from given descriptors."""
        N = X.shape[0]

        # Learn PCA without annotations.
        m = X.mean(axis=0, keepdims=True)
        Xc = X - m
        Xcov = np.dot(Xc.T, Xc)
        Xcov = (Xcov + Xcov.T) / (2 * N)
        eigval, eigvec = np.linalg.eig(Xcov)
        order = eigval.argsort()[::-1]
        eigval = eigval[order]
        eigvec = eigvec[:, order]

        eigval = np.clip(eigval, a_min=1e-14, a_max=None)
        P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5 * s))), eigvec.T)

        return m, P.T

    def initialize_pca_whitening(self, des):
        """Initialize PCA whitening from given descriptors. Return tuple of shift and projection."""
        m, P = self.pcawhitenlearn_shrinkage(des)
        m, P = m.T, P.T

        projection = torch.Tensor(P[:self.weight.shape[0], :]).unsqueeze(-1).unsqueeze(-1)
        self.weight.data = projection.to(self.weight.device)

        projected_shift = -torch.mm(torch.FloatTensor(P), torch.FloatTensor(m)).squeeze()
        self.bias.data = projected_shift[:self.weight.shape[0]].to(self.bias.device)
        return m.T, P.T
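A minimal usage sketch for HOWPooling, assuming a 512-channel, 14x14 feature map (illustrative shapes only; the 1x1 reduction starts from its random Conv2d initialization and can later be set with initialize_pca_whitening on collected descriptors):

feats = torch.randn(2, 512, 14, 14)                           # (B, D, H, W) local descriptors
how = HOWPooling(input_dim=512, dim_reduction=128, kernel_size=3)
desc = how(feats)                                             # (B, 128) L2-normalized global descriptor
print(desc.shape)                                             # torch.Size([2, 128])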
@@ -0,0 +1,36 @@

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSEPool(nn.Module):
    """
    Implements LogSumExp (LSE) pooling, an advanced pooling technique that provides
    a smooth approximation to the max pooling operation. This pooling method is useful
    for capturing the global distribution of features across spatial dimensions (height and width)
    of the input tensor, while still maintaining differentiability.

    The class supports learnable pooling behavior with an optional learnable parameter 'r'.
    When 'r' is large, LSE pooling closely approximates max pooling, and when 'r' is small,
    it behaves more like average pooling. The 'r' parameter can either be a fixed value or
    learned during training.

    Parameters:
        r (float, optional): The initial value of the pooling parameter. Default is 10.
        learnable (bool, optional): If True, 'r' is a learnable parameter. Default is True.
    """

    def __init__(self, r=10, learnable=True):
        super(LSEPool, self).__init__()
        if learnable:
            self.r = nn.Parameter(torch.ones(1) * r)
        else:
            self.r = r

    def forward(self, x):
        # Number of spatial positions.
        s = x.size(2) * x.size(3)
        # Subtract the per-channel maximum inside the exponential for numerical stability.
        x_max = F.adaptive_max_pool2d(x, 1)
        exp = torch.exp(self.r * (x - x_max))
        sumexp = 1 / s * torch.sum(exp, dim=(2, 3))
        sumexp = sumexp.view(sumexp.size(0), -1, 1, 1)
        logsumexp = x_max + 1 / self.r * torch.log(sumexp)
        return logsumexp
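A minimal usage sketch for LSEPool, assuming a 256-channel, 7x7 feature map (illustrative shapes only):

feats = torch.randn(2, 256, 7, 7)    # (B, C, H, W) feature map
lse = LSEPool(r=10, learnable=True)
pooled = lse(feats)                  # (B, C, 1, 1) smooth max over spatial positions
print(pooled.shape)                  # torch.Size([2, 256, 1, 1])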
@@ -0,0 +1,101 @@

import torch
import torch.nn as nn


class SimPool(nn.Module):
    """
    Implements SimPool as described in the ICCV 2023 paper
    "Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?".
    This class is designed to provide an efficient and effective pooling strategy
    for both Transformer and CNN architectures.

    SimPool applies a global average pooling (GAP) operation as an initial step
    and then utilizes a simple but powerful attention mechanism to refine the pooled features.
    The attention mechanism uses linear transformations for queries and keys, followed by
    softmax normalization to compute attention scores.

    Parameters:
        dim (int): Dimension of the input features.
        num_heads (int, optional): Number of attention heads. Default is 1.
        qkv_bias (bool, optional): If True, adds a learnable bias to query, key, value projections. Default is False.
        qk_scale (float, optional): Scaling factor for the query-key dot product. Default is None, which uses the inverse square root of the head dimension.
        gamma (float, optional): Scaling parameter for value vectors, used if not None. Default is None.
        use_beta (bool, optional): If True, adds a learnable translation to the value vectors after applying gamma. Default is False.
    """
    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)

        if gamma is not None:
            if use_beta:
                # Learnable translation applied after the gamma-scaled aggregation.
                self.beta = nn.Parameter(torch.tensor([0.0]))
            self.eps = 1e-6

        self.gamma = gamma
        self.use_beta = use_beta

    def prepare_input(self, x):
        if len(x.shape) == 3:  # Transformer
            # Input tensor dimensions:
            # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
            B, N, d = x.shape
            gap_cls = x.mean(-2)  # (B, N, d) -> (B, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        if len(x.shape) == 4:  # CNN
            # Input tensor dimensions:
            # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
            B, d, H, W = x.shape
            gap_cls = x.mean([-2, -1])  # (B, d, H, W) -> (B, d)
            x = x.reshape(B, d, H * W).permute(0, 2, 1)  # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

    def forward(self, x):
        # Prepare the input tensor and perform GAP as initialization
        gap_cls, x = self.prepare_input(x)

        # Prepare queries (q), keys (k), and values (v)
        q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

        # Extract dimensions after normalization
        Bq, Nq, dq = q.shape
        Bk, Nk, dk = k.shape
        Bv, Nv, dv = v.shape

        # Check dimension consistency across batches and channels
        assert Bq == Bk == Bv
        assert dq == dk == dv

        # Apply linear transformations for queries and keys, then reshape
        qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3)  # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
        kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3)  # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)

        vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3)  # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)

        # Compute attention scores
        attn = (qq @ kk.transpose(-2, -1)) * self.scale
        # Apply softmax for normalization
        attn = attn.softmax(dim=-1)

        # If gamma scaling is used
        if self.gamma is not None:
            # Apply gamma scaling on values and compute the weighted sum using attention scores
            x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1 / self.gamma)  # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
            # If use_beta, add a learnable translation
            if self.use_beta:
                x = x + self.beta
        else:
            # Compute the weighted sum using attention scores
            x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

        return x.squeeze()
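A minimal usage sketch for SimPool with the default single head and no gamma scaling; the 384-dimensional features and token/map shapes are illustrative assumptions:

cnn_feats = torch.randn(2, 384, 7, 7)    # (B, d, H, W) CNN feature map
tok_feats = torch.randn(2, 196, 384)     # (B, N, d) Transformer patch tokens
simpool = SimPool(dim=384, num_heads=1)
print(simpool(cnn_feats).shape)          # torch.Size([2, 384])
print(simpool(tok_feats).shape)          # torch.Size([2, 384])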
@@ -0,0 +1,119 @@

import torch
import torch.nn as nn
from torch.nn import init


def prepare_input(x):
    """
    Prepares the input tensor for different neural network architectures (Transformers and CNNs).

    This function adjusts the shape of the input tensor based on its dimensionality.
    It supports input tensors for Transformers (3D) and CNNs (4D), ensuring they are
    correctly formatted for the pooling module.

    For a Transformer, it expects a tensor of shape (B, N, d), where B is the batch size,
    N are patch tokens, and d is the depth (channels). The tensor is returned as is.

    For a CNN, it expects a tensor of shape (B, d, H, W), where B is the batch size,
    d is the depth (channels), H is the height, and W is the width. The tensor is reshaped
    and permuted to the shape (B, H*W, d) so that spatial locations become tokens.

    Parameters:
        x (torch.Tensor): The input tensor to be preprocessed.
    """
    if len(x.shape) == 3:  # Transformer
        # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
        B, N, d = x.shape
        return x
    if len(x.shape) == 4:  # CNN
        # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
        B, d, H, W = x.shape
        x = x.reshape(B, d, H * W).permute(0, 2, 1)  # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
        return x
    raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")


class SlotPooling(nn.Module):
    """
    This class implements the Slot Attention module as described in the paper
    "Object-Centric Learning with Slot Attention".

    The module is designed for object-centric learning tasks and utilizes the concept of
    'slots' to represent distinct object features within an input.
    It iteratively refines these slots through a pooling mechanism to capture
    complex object representations.

    Parameters:
        num_slots (int): Number of slots to be used.
        dim (int): Dimensionality of the input features.
        iters (int, optional): Number of iterations for slot refinement. Default is 3.
        eps (float, optional): A small epsilon value to avoid division by zero. Default is 1e-8.
        hidden_dim (int, optional): Dimensionality of the hidden layer within the module. Default is 128.
    """
    def __init__(self, num_slots, dim, iters=3, eps=1e-8, hidden_dim=128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))

        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        init.xavier_uniform_(self.slots_logsigma)

        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

        self.gru = nn.GRUCell(dim, dim)

        hidden_dim = max(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, dim)
        )

        self.norm_input = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
        self.norm_pre_ff = nn.LayerNorm(dim)

    def forward(self, inputs, num_slots=None):
        inputs = prepare_input(inputs)
        b, n, d, device, dtype = *inputs.shape, inputs.device, inputs.dtype
        n_s = num_slots if num_slots is not None else self.num_slots

        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)

        slots = mu + sigma * torch.randn(mu.shape, device=device, dtype=dtype)

        inputs = self.norm_input(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots

            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps

            attn = attn / attn.sum(dim=-1, keepdim=True)

            updates = torch.einsum('bjd,bij->bid', v, attn)

            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )

            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_pre_ff(slots))

        # Reduce the refined slots to a single pooled vector per sample.
        slots = slots.max(dim=1)[0]

        return slots
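A minimal usage sketch for SlotPooling, assuming a 256-channel, 7x7 feature map and 4 slots (illustrative values):

feats = torch.randn(2, 256, 7, 7)          # (B, d, H, W) feature map; tokens built internally
slot_pool = SlotPooling(num_slots=4, dim=256)
pooled = slot_pool(feats)                  # (B, d) after max over the refined slots
print(pooled.shape)                        # torch.Size([2, 256])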
@@ -0,0 +1,142 @@

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Optional, Tuple


def prepare_input(x):
    """
    Prepares the input tensor for different neural network architectures (Transformers and CNNs).

    This function adjusts the shape of the input tensor based on its dimensionality.
    It supports input tensors for Transformers (3D) and CNNs (4D), ensuring they are
    correctly formatted for the pooling module.

    For a Transformer, it expects a tensor of shape (B, N, d), where B is the batch size,
    N are patch tokens, and d is the depth (channels). The tensor is returned as is.

    For a CNN, it expects a tensor of shape (B, d, H, W), where B is the batch size,
    d is the depth (channels), H is the height, and W is the width. The tensor is reshaped
    and permuted to the shape (B, H*W, d) so that spatial locations become tokens.

    Parameters:
        x (torch.Tensor): The input tensor to be preprocessed.
    """
    if len(x.shape) == 3:  # Transformer
        # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
        B, N, d = x.shape
        return x
    if len(x.shape) == 4:  # CNN
        # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
        B, d, H, W = x.shape
        x = x.reshape(B, d, H * W).permute(0, 2, 1)  # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
        return x
    raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need".
    Computes the dot products of the query with all keys, divides each by sqrt(dim),
    and applies a softmax function to obtain the weights on the values.

    Args:
        dim (int): dimension of attention

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing the projection vector for the decoder.
        - **key** (batch, k_len, d_model): tensor containing the projection vector for the encoder.
        - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
        - **mask** (-): tensor containing indices to be masked

    Returns: context, attn
        - **context**: tensor containing the context vector from the attention mechanism.
        - **attn**: tensor containing the attention (alignment) from the encoder outputs.
    """
    def __init__(self, dim: int):
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            score.masked_fill_(mask.view(score.size()), -float('Inf'))

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn


class ViTPooling(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need".
    Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
    the queries, keys and values are projected h times with different, learned linear projections to d_head dimensions.
    These are concatenated and once again projected, resulting in the final values.
    Multi-head attention allows the model to jointly attend to information from different representation
    subspaces at different positions.

        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
        where head_i = Attention(Q · W_q, K · W_k, V · W_v)

    Args:
        d_model (int): The dimension of keys / values / queries (default: 512)
        num_heads (int): The number of attention heads (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): In the transformer, queries come from one of three places:
            Case 1: the previous decoder layer
            Case 2: the input embedding
            Case 3: the output embedding (masked)
        - **key** (batch, k_len, d_model): In the transformer, keys come from one of three places:
            Case 1: the output of the encoder
            Case 2: the input embeddings
            Case 3: the output embedding (masked)
        - **value** (batch, v_len, d_model): In the transformer, values come from one of three places:
            Case 1: the output of the encoder
            Case 2: the input embeddings
            Case 3: the output embedding (masked)
        - **mask** (-): tensor containing indices to be masked

    Returns: output
        - **output** (batch, d_model): tensor containing the attended output features at the [CLS] position.
    """
    def __init__(self, d_model: int = 512, num_heads: int = 8):
        super(ViTPooling, self).__init__()

        assert d_model % num_heads == 0, "d_model % num_heads should be zero."

        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
        self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        x = prepare_input(x)
        B = x.size(0)
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        query = self.query_proj(x).view(B, -1, self.num_heads, self.d_head)  # BxQ_LENxNxD
        key = self.key_proj(x).view(B, -1, self.num_heads, self.d_head)      # BxK_LENxNxD
        value = self.value_proj(x).view(B, -1, self.num_heads, self.d_head)  # BxV_LENxNxD

        query = query.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head)  # BNxQ_LENxD
        key = key.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head)      # BNxK_LENxD
        value = value.permute(2, 0, 1, 3).contiguous().view(B * self.num_heads, -1, self.d_head)  # BNxV_LENxD

        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)  # BxNxQ_LENxK_LEN

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        context = context.view(self.num_heads, B, -1, self.d_head)
        context = context.permute(1, 2, 0, 3).contiguous().view(B, -1, self.num_heads * self.d_head)  # BxTxND

        # Return the pooled representation at the [CLS] position.
        return context[:, 0]
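A minimal usage sketch for ViTPooling, assuming a 512-channel, 7x7 feature map (illustrative shapes); spatial locations are flattened into tokens, a learnable [CLS] token is prepended, and the attended [CLS] output is returned:

feats = torch.randn(2, 512, 7, 7)            # (B, d_model, H, W) feature map
vit_pool = ViTPooling(d_model=512, num_heads=8)
pooled = vit_pool(feats)                     # (B, d_model) pooled [CLS] representation
print(pooled.shape)                          # torch.Size([2, 512])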