# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from functools import partial
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version

from mmpretrain.registry import MODELS
from .helpers import to_2tuple
from .layer_scale import LayerScale

# After PyTorch v1.10.0, calling torch.meshgrid without the `indexing`
# argument raises an extra warning. For more details, refer to
# https://github.com/pytorch/pytorch/issues/50276
if digit_version(torch.__version__) >= digit_version('1.10.0'):
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
else:
    torch_meshgrid = torch.meshgrid


def scaled_dot_product_attention_pyimpl(query,
                                        key,
                                        value,
                                        attn_mask=None,
                                        dropout_p=0.,
                                        scale=None,
                                        is_causal=False):
    scale = scale or query.size(-1)**0.5
    if is_causal and attn_mask is None:
        # Build a lower-triangular causal mask when no explicit mask is given.
        attn_mask = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool,
            device=query.device).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # Convert the boolean mask (True = attend) into an additive float
        # mask before it is added to the attention logits.
        attn_mask = torch.zeros_like(
            attn_mask, dtype=query.dtype).masked_fill(~attn_mask,
                                                      -float('inf'))

    attn_weight = query @ key.transpose(-2, -1) / scale
    if attn_mask is not None:
        attn_weight += attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value


if digit_version(torch.__version__) >= digit_version('2.0.0'):
    scaled_dot_product_attention = F.scaled_dot_product_attention
else:
    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
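

# A minimal usage sketch of the module-level ``scaled_dot_product_attention``
# alias above: it expects the (batch, num_heads, seq_len, head_dims) layout.
# The tensor sizes below are arbitrary assumptions chosen for illustration.
def _demo_scaled_dot_product_attention():
    q = torch.randn(2, 4, 16, 32)  # (batch, num_heads, seq_len, head_dims)
    k = torch.randn(2, 4, 16, 32)
    v = torch.randn(2, 4, 16, 32)
    # Dispatches to F.scaled_dot_product_attention on PyTorch >= 2.0,
    # otherwise to the pure-Python fallback defined above.
    out = scaled_dot_product_attention(q, k, v, dropout_p=0.)
    assert out.shape == (2, 4, 16, 32)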


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 init_cfg=None):

        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        super(WindowMSA, self).init_weights()

        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
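

# A minimal usage sketch of ``WindowMSA``: the input is a batch of flattened
# windows with shape (num_windows * B, Wh * Ww, C). The dimensions below are
# illustrative assumptions, not values mandated by the module.
def _demo_window_msa():
    msa = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=3)
    x = torch.randn(4, 7 * 7, 96)  # (num_windows * B, Wh * Ww, C)
    out = msa(x)
    assert out.shape == (4, 49, 96)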


class WindowMSAV2(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Based on the implementation in the original Swin Transformer V2 repo.
    Refer to
    https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py
    for more details.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        attn_drop (float): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float): Dropout ratio of output. Defaults to 0.
        cpb_mlp_hidden_dims (int): The hidden dimensions of the continuous
            relative position bias network. Defaults to 512.
        pretrained_window_size (tuple(int)): The height and width of the window
            in pre-training. Defaults to (0, 0), which means the pretrained
            window size is not used.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 attn_drop=0.,
                 proj_drop=0.,
                 cpb_mlp_hidden_dims=512,
                 pretrained_window_size=(0, 0),
                 init_cfg=None):

        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads

        # Use small network for continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(
                in_features=2, out_features=cpb_mlp_hidden_dims, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(
                in_features=cpb_mlp_hidden_dims,
                out_features=num_heads,
                bias=False))

        # Add learnable scalar for cosine attention
        self.logit_scale = nn.Parameter(
            torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # get relative_coords_table
        relative_coords_h = torch.arange(
            -(self.window_size[0] - 1),
            self.window_size[0],
            dtype=torch.float32)
        relative_coords_w = torch.arange(
            -(self.window_size[1] - 1),
            self.window_size[1],
            dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch_meshgrid([relative_coords_h, relative_coords_w])).permute(
                1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (
                pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (
                pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        self.register_buffer('relative_coords_table', relative_coords_table)

        # get pair-wise relative position index
        # for each token inside the window
        indexes_h = torch.arange(self.window_size[0])
        indexes_w = torch.arange(self.window_size[1])
        coordinates = torch.stack(
            torch_meshgrid([indexes_h, indexes_w]), dim=0)  # 2, Wh, Ww
        coordinates = torch.flatten(coordinates, start_dim=1)  # 2, Wh*Ww
        # 2, Wh*Ww, Wh*Ww
        relative_coordinates = coordinates[:, :, None] - coordinates[:,
                                                                     None, :]
        relative_coordinates = relative_coordinates.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

        relative_coordinates[:, :, 0] += self.window_size[
            0] - 1  # shift to start from 0
        relative_coordinates[:, :, 1] += self.window_size[1] - 1
        relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coordinates.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias,
                                  requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads,
                          C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (
            F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(
            self.logit_scale, max=np.log(1. / 0.01)).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(
            self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
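

# A minimal usage sketch of ``WindowMSAV2`` (the Swin V2 variant with cosine
# attention and a continuous relative position bias MLP). The sizes are
# illustrative assumptions; ``pretrained_window_size`` keeps its default.
def _demo_window_msa_v2():
    msa = WindowMSAV2(embed_dims=96, window_size=(7, 7), num_heads=3)
    x = torch.randn(4, 7 * 7, 96)  # (num_windows * B, Wh * Ww, C)
    out = msa(x)
    assert out.shape == (4, 49, 96)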


@MODELS.register_module()
class ShiftWindowMSA(BaseModule):
    """Shift Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults to dict(type='DropPath', drop_prob=0.).
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        window_msa (Callable): To build a window multi-head attention module.
            Defaults to :class:`WindowMSA`.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
        **kwargs: Other keyword arguments to build the window multi-head
            attention module.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 pad_small_map=False,
                 window_msa=WindowMSA,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)

        self.shift_size = shift_size
        self.window_size = window_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = window_msa(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(self.window_size),
            **kwargs,
        )

        self.drop = build_dropout(dropout_layer)
        self.pad_small_map = pad_small_map

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, f"The query length {L} doesn't match the input "\
            f'shape ({H}, {W}).'
        query = query.view(B, H, W, C)

        window_size = self.window_size
        shift_size = self.shift_size

        if min(H, W) == window_size:
            # If not pad small feature map, avoid shifting when the window
            # size is equal to the size of the feature map. It's to align
            # with the behavior of the original implementation.
            shift_size = shift_size if self.pad_small_map else 0
        elif min(H, W) < window_size:
            # In the original implementation, the window size will be shrunk
            # to the size of the feature map. The behavior is different from
            # the swin-transformer used for downstream tasks. To support
            # dynamic input shape, we don't allow this feature.
            assert self.pad_small_map, \
                f'The input shape ({H}, {W}) is smaller than the window ' \
                f'size ({window_size}). Please set `pad_small_map=True`, or ' \
                'decrease the `window_size`.'

        pad_r = (window_size - W % window_size) % window_size
        pad_b = (window_size - H % window_size) % window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))

        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if shift_size > 0:
            query = torch.roll(
                query, shifts=(-shift_size, -shift_size), dims=(1, 2))

        attn_mask = self.get_attn_mask((H_pad, W_pad),
                                       window_size=window_size,
                                       shift_size=shift_size,
                                       device=query.device)

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(query, window_size)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, window_size, window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
                                        window_size)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if H != H_pad or W != W_pad:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)

        return x

    @staticmethod
    def window_reverse(windows, H, W, window_size):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    @staticmethod
    def window_partition(x, window_size):
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows

    @staticmethod
    def get_attn_mask(hw_shape, window_size, shift_size, device=None):
        if shift_size > 0:
            img_mask = torch.zeros(1, *hw_shape, 1, device=device)
            h_slices = (slice(0, -window_size), slice(-window_size,
                                                      -shift_size),
                        slice(-shift_size, None))
            w_slices = (slice(0, -window_size), slice(-window_size,
                                                      -shift_size),
                        slice(-shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = ShiftWindowMSA.window_partition(
                img_mask, window_size)
            mask_windows = mask_windows.view(-1, window_size * window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
            attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        return attn_mask
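

# A minimal usage sketch of ``ShiftWindowMSA``: the query is a flattened
# feature map of shape (B, H*W, C) plus its spatial shape. The 14x14 map,
# window size 7 and shift of 3 are illustrative assumptions.
def _demo_shift_window_msa():
    attn = ShiftWindowMSA(
        embed_dims=96, num_heads=3, window_size=7, shift_size=3)
    x = torch.randn(2, 14 * 14, 96)  # (B, H*W, C)
    out = attn(x, hw_shape=(14, 14))
    assert out.shape == (2, 14 * 14, 96)

    # The additive mask that keeps shifted windows from attending across
    # the cyclic-shift boundary.
    mask = ShiftWindowMSA.get_attn_mask((14, 14), window_size=7, shift_size=3)
    assert mask.shape == (4, 7 * 7, 7 * 7)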


class MultiheadAttention(BaseModule):
    """Multi-head Attention Module.

    This module implements multi-head attention that supports different input
    dims and embed dims. And it also supports a shortcut from ``value``, which
    is useful if input dims is not the same as embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to apply a layer scale to the output
            projection. Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 v_shortcut=False,
                 use_layer_scale=False,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        if qk_scale is not None:
            self.scaled_dot_product_attention = partial(
                scaled_dot_product_attention_pyimpl,
                scale=self.head_dims**-0.5)
        else:
            self.scaled_dot_product_attention = scaled_dot_product_attention

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

        if use_layer_scale:
            self.gamma1 = LayerScale(embed_dims)
        else:
            self.gamma1 = nn.Identity()

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_drop = self.attn_drop if self.training else 0.
        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x
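

# A minimal usage sketch of ``MultiheadAttention`` on a (B, N, C) token
# sequence; the sizes are illustrative assumptions.
def _demo_multihead_attention():
    attn = MultiheadAttention(embed_dims=128, num_heads=4)
    x = torch.randn(2, 50, 128)  # (B, N, C)
    out = attn(x)
    assert out.shape == (2, 50, 128)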


class BEiTAttention(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    The initial implementation is in MMSegmentation.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int, int]): The height and width of the window.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
        bias (str): The option to add learnable bias for q, k, v. If bias is
            True, it will add learnable bias. If bias is 'qv_bias', it will
            only add learnable bias for q, v. If bias is False, it will not
            add bias for q, k, v. Defaults to 'qv_bias'.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float): Dropout ratio of output. Default: 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 use_rel_pos_bias,
                 bias='qv_bias',
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.bias = bias
        self.scale = qk_scale or head_embed_dims**-0.5

        qkv_bias = bias
        if bias == 'qv_bias':
            self._init_qv_bias()
            qkv_bias = False

        if window_size is None:
            assert not use_rel_pos_bias
        else:
            assert isinstance(window_size, tuple)
        self.window_size = window_size
        self.use_rel_pos_bias = use_rel_pos_bias
        self._init_rel_pos_embedding()

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

    def _init_qv_bias(self):
        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))

    def _init_rel_pos_embedding(self):
        if self.use_rel_pos_bias:
            Wh, Ww = self.window_size
            # cls to token & token to cls & cls to cls
            self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
            # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, self.num_heads))

            # get pair-wise relative position index for
            # each token inside the window
            coords_h = torch.arange(Wh)
            coords_w = torch.arange(Ww)
            # coords shape is (2, Wh, Ww)
            coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
            # coords_flatten shape is (2, Wh*Ww)
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :])
            # relative_coords shape is (Wh*Ww, Wh*Ww, 2)
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            # shift to start from 0
            relative_coords[:, :, 0] += Wh - 1
            relative_coords[:, :, 1] += Ww - 1
            relative_coords[:, :, 0] *= 2 * Ww - 1
            relative_position_index = torch.zeros(
                size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
            # relative_position_index shape is (Wh*Ww + 1, Wh*Ww + 1)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer('relative_position_index',
                                 relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

    def init_weights(self):
        super().init_weights()
        if self.use_rel_pos_bias:
            trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, rel_pos_bias=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C).
            rel_pos_bias (tensor): input relative position bias with shape of
                (num_heads, N, N).
        """
        B, N, C = x.shape

        if self.bias == 'qv_bias':
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        else:
            qkv = self.qkv(x)

        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            Wh = self.window_size[0]
            Ww = self.window_size[1]
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    Wh * Ww + 1, Wh * Ww + 1, -1)
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            # use shared relative position bias
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
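

# A minimal usage sketch of ``BEiTAttention`` with its own relative position
# bias. With ``use_rel_pos_bias=True`` the bias table covers Wh*Ww patch
# tokens plus one cls token, so N must equal Wh*Ww + 1. The 14x14 window and
# channel sizes are illustrative assumptions.
def _demo_beit_attention():
    attn = BEiTAttention(
        embed_dims=96,
        num_heads=3,
        window_size=(14, 14),
        use_rel_pos_bias=True)
    x = torch.randn(2, 14 * 14 + 1, 96)  # (B, Wh*Ww + 1, C)
    out = attn(x)
    assert out.shape == (2, 14 * 14 + 1, 96)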


class ChannelMultiheadAttention(BaseModule):
    """Channel Multihead Self-attention Module.

    This module implements channel multi-head attention that supports
    different input dims and embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to False.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        qk_scale_type (str): The scale type of qk scale.
            Defaults to 'learnable'. It can be 'learnable', 'fixed' or 'none'.
        qk_scale (float, optional): If set qk_scale_type to 'none', this
            should be specified with valid float number. Defaults to None.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=False,
                 proj_bias=True,
                 qk_scale_type='learnable',
                 qk_scale=None,
                 v_shortcut=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        if qk_scale_type == 'learnable':
            self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))
        elif qk_scale_type == 'fixed':
            self.scale = self.head_dims**-0.5
        elif qk_scale_type == 'none':
            assert qk_scale is not None
            self.scale = qk_scale

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)

        q, k, v = [item.transpose(-2, -1) for item in [qkv[0], qkv[1], qkv[2]]]

        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, self.embed_dims)
        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            x = qkv[2].squeeze(1) + x
        return x
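

# A minimal usage sketch of ``ChannelMultiheadAttention``. Queries and keys
# are L2-normalized and attention is computed across the channel dimension of
# each head rather than across tokens. The sizes are illustrative assumptions.
def _demo_channel_multihead_attention():
    attn = ChannelMultiheadAttention(embed_dims=64, num_heads=8)
    x = torch.randn(2, 50, 64)  # (B, N, C)
    out = attn(x)
    assert out.shape == (2, 50, 64)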


class LeAttention(BaseModule):
    """LeViT Attention. Multi-head attention with attention bias, which is
    proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster
    Inference <https://arxiv.org/abs/2104.01136>`_

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8.
        key_dim (int): Dimension of the key in each head.
        attn_ratio (int): Ratio of the dimension of the value to the
            dimension of the key in each head. Default: 4.
        resolution (tuple[int]): Input resolution. Default: (14, 14).
        init_cfg (dict, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=4,
                 resolution=(14, 14),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        points = list(
            itertools.product(range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer(
            'attention_bias_idxs',
            torch.LongTensor(idxs).view(N, N),
            persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,N,C)
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads,
                           -1).split([self.key_dim, self.key_dim, self.d],
                                     dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x
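

# A minimal usage sketch of ``LeAttention``. The attention bias is indexed by
# the pair-wise offsets of a fixed grid, so the token count N has to equal
# resolution[0] * resolution[1]. The 14x14 grid and channel sizes are
# illustrative assumptions.
def _demo_le_attention():
    attn = LeAttention(dim=64, key_dim=16, num_heads=4, resolution=(14, 14))
    x = torch.randn(2, 14 * 14, 64)  # (B, N, C) with N == 14 * 14
    out = attn(x)
    assert out.shape == (2, 14 * 14, 64)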


class CrossMultiheadAttention(BaseModule):
    """Cross attention between queries and the union of keys and values.

    This module is different from ``MultiheadAttention``, for the attention
    is computed between queries and the union of keys and values.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to False.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(embed_dims, embed_dims, bias=False)
        self.k = nn.Linear(embed_dims, embed_dims, bias=False)
        self.v = nn.Linear(embed_dims, embed_dims, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self,
                x: torch.Tensor,
                k: torch.Tensor = None,
                v: torch.Tensor = None) -> torch.Tensor:
        """Forward function."""
        B, N, _ = x.shape

        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        q = F.linear(
            input=x, weight=self.q.weight, bias=q_bias)  # (B, N_q, dim)
        k = F.linear(
            input=k, weight=self.k.weight, bias=k_bias)  # (B, N_k, dim)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)

        q = q.reshape(B, N, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_q, dim)
        k = k.reshape(B, N_k, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_k, dim)
        v = v.reshape(B, N_v, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_v, dim)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
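

# A minimal usage sketch of ``CrossMultiheadAttention``: queries come from one
# sequence while keys and values come from another (here the same tensor is
# reused for k and v). All sizes are illustrative assumptions.
def _demo_cross_multihead_attention():
    attn = CrossMultiheadAttention(embed_dims=64, num_heads=8)
    queries = torch.randn(2, 10, 64)  # (B, N_q, C)
    context = torch.randn(2, 30, 64)  # (B, N_k, C)
    out = attn(queries, k=context, v=context)
    assert out.shape == (2, 10, 64)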


class PromptMultiheadAttention(MultiheadAttention):
    """Prompt Multihead Attention for MILAN.

    This module is specific for the prompt encoder in MILAN. It will not
    update the visible tokens from the encoder.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to apply a layer scale to the output
            projection. Defaults to False.
        return_attention (bool): If True, return the attention map, computed
            by the cross attention between the class token and all other
            tokens. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 input_dims: Optional[int] = None,
                 attn_drop: float = 0,
                 proj_drop: float = 0,
                 dropout_layer: dict = dict(type='Dropout', drop_prob=0.),
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 proj_bias: bool = True,
                 v_shortcut: bool = False,
                 use_layer_scale: bool = False,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(embed_dims, num_heads, input_dims, attn_drop,
                         proj_drop, dropout_layer, qkv_bias, qk_scale,
                         proj_bias, v_shortcut, use_layer_scale, init_cfg)
        # no longer need qkv
        del self.qkv

        # to project the mask tokens
        self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
        # to project all the tokens
        self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible tokens features from
                encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        assert x_.shape[1] == ids_restore.shape[1]
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        # full sequence shape
        B, _, _ = x_.shape
        q = self.q(x).reshape(B, x.shape[1], self.num_heads,
                              self.head_dims).permute(0, 2, 1, 3)
        kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads,
                                 self.head_dims).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn_drop = self.attn_drop if self.training else 0.
        attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))
        return x
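

# A minimal usage sketch of ``PromptMultiheadAttention`` as used in a
# MILAN-style decoder: mask tokens attend to the full restored sequence
# (cls token + visible tokens + mask tokens). All sizes are illustrative
# assumptions.
def _demo_prompt_multihead_attention():
    B, L_visible, L_mask, C = 2, 5, 12, 64  # L_visible includes the cls token
    L = (L_visible - 1) + L_mask  # number of patch tokens in the image
    attn = PromptMultiheadAttention(embed_dims=C, num_heads=4)
    mask_tokens = torch.randn(B, L_mask, C)
    visible_tokens = torch.randn(B, L_visible, C)
    ids_restore = torch.stack([torch.randperm(L) for _ in range(B)])
    out = attn(mask_tokens, visible_tokens, ids_restore)
    assert out.shape == (B, L_mask, C)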