# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import warnings
from functools import partial
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version

from mmpretrain.registry import MODELS
from .helpers import to_2tuple
from .layer_scale import LayerScale

# After PyTorch v1.10.0, calling ``torch.meshgrid`` without the ``indexing``
# argument raises an extra warning. For more details, refer to
# https://github.com/pytorch/pytorch/issues/50276
if digit_version(torch.__version__) >= digit_version('1.10.0'):
    torch_meshgrid = partial(torch.meshgrid, indexing='ij')
else:
    torch_meshgrid = torch.meshgrid


def scaled_dot_product_attention_pyimpl(query,
                                        key,
                                        value,
                                        attn_mask=None,
                                        dropout_p=0.,
                                        scale=None,
                                        is_causal=False):
    # ``scale`` is the multiplicative factor applied to the attention logits,
    # matching the semantics of ``F.scaled_dot_product_attention``.
    scale = scale or query.size(-1)**-0.5
    if is_causal:
        # Build a lower-triangular (causal) boolean mask. Any explicit
        # ``attn_mask`` is replaced in this case.
        attn_mask = torch.ones(
            query.size(-2),
            key.size(-2),
            dtype=torch.bool,
            device=query.device).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # Convert the boolean mask (True = keep) into an additive float mask.
        attn_mask = torch.zeros_like(
            attn_mask, dtype=query.dtype).masked_fill(~attn_mask,
                                                      float('-inf'))

    attn_weight = query @ key.transpose(-2, -1) * scale
    if attn_mask is not None:
        attn_weight += attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value


if digit_version(torch.__version__) >= digit_version('2.0.0'):
    scaled_dot_product_attention = F.scaled_dot_product_attention
else:
    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
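
# Usage note (illustrative sketch, not part of the public API): both branches
# above expose the same call signature, e.g.
#   q = k = v = torch.rand(2, 4, 49, 32)   # (B, num_heads, N, head_dims)
#   out = scaled_dot_product_attention(q, k, v, dropout_p=0.1)
# where the shapes are assumptions chosen only for this example.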


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
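
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=3)
        >>> x = torch.rand(8, 49, 96)  # (num_windows*B, Wh*Ww, C)
        >>> attn(x).shape
        torch.Size([8, 49, 96])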
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 init_cfg=None):

        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # About 2x faster than original impl
        Wh, Ww = self.window_size
        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
        rel_position_index = rel_index_coords + rel_index_coords.T
        rel_position_index = rel_position_index.flip(1).contiguous()
        self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        super(WindowMSA, self).init_weights()

        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)


class WindowMSAV2(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Based on the implementation in the Swin Transformer V2 original repo.
    Refer to
    https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py
    for more details.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        attn_drop (float): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float): Dropout ratio of output. Defaults to 0.
        cpb_mlp_hidden_dims (int): The hidden dimensions of the continuous
            relative position bias network. Defaults to 512.
        pretrained_window_size (tuple(int)): The height and width of the
            window in pre-training. Defaults to (0, 0), which means the
            pretrained window size is not used.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
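
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = WindowMSAV2(embed_dims=96, window_size=(7, 7), num_heads=3)
        >>> x = torch.rand(8, 49, 96)  # (num_windows*B, Wh*Ww, C)
        >>> attn(x).shape
        torch.Size([8, 49, 96])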
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 attn_drop=0.,
                 proj_drop=0.,
                 cpb_mlp_hidden_dims=512,
                 pretrained_window_size=(0, 0),
                 init_cfg=None):

        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads

        # Use small network for continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(
                in_features=2, out_features=cpb_mlp_hidden_dims, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(
                in_features=cpb_mlp_hidden_dims,
                out_features=num_heads,
                bias=False))

        # Add learnable scalar for cosine attention
        self.logit_scale = nn.Parameter(
            torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        # get relative_coords_table
        relative_coords_h = torch.arange(
            -(self.window_size[0] - 1),
            self.window_size[0],
            dtype=torch.float32)
        relative_coords_w = torch.arange(
            -(self.window_size[1] - 1),
            self.window_size[1],
            dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch_meshgrid([relative_coords_h, relative_coords_w])).permute(
                1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (
                pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (
                pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            torch.abs(relative_coords_table) + 1.0) / np.log2(8)
        self.register_buffer('relative_coords_table', relative_coords_table)

        # get pair-wise relative position index
        # for each token inside the window
        indexes_h = torch.arange(self.window_size[0])
        indexes_w = torch.arange(self.window_size[1])
        coordinates = torch.stack(
            torch_meshgrid([indexes_h, indexes_w]), dim=0)  # 2, Wh, Ww
        coordinates = torch.flatten(coordinates, start_dim=1)  # 2, Wh*Ww
        # 2, Wh*Ww, Wh*Ww
        relative_coordinates = coordinates[:, :, None] - coordinates[:,
                                                                     None, :]
        relative_coordinates = relative_coordinates.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

        relative_coordinates[:, :, 0] += self.window_size[
            0] - 1  # shift to start from 0
        relative_coordinates[:, :, 1] += self.window_size[1] - 1
        relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coordinates.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (self.q_bias,
                 torch.zeros_like(self.v_bias,
                                  requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads,
                          C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        # cosine attention
        attn = (
            F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(
            self.logit_scale, max=np.log(1. / 0.01)).exp()
        attn = attn * logit_scale

        relative_position_bias_table = self.cpb_mlp(
            self.relative_coords_table).view(-1, self.num_heads)
        relative_position_bias = relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


@MODELS.register_module()
class ShiftWindowMSA(BaseModule):
    """Shift Window Multihead Self-Attention Module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults to dict(type='DropPath', drop_prob=0.).
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of the feature map, which is commonly used in classification.
            Defaults to False.
        window_msa (Callable): To build a window multi-head attention module.
            Defaults to :class:`WindowMSA`.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
        **kwargs: Other keyword arguments to build the window multi-head
            attention module.
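
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = ShiftWindowMSA(
        ...     embed_dims=96, num_heads=3, window_size=7, shift_size=3)
        >>> query = torch.rand(2, 14 * 14, 96)  # (B, H*W, C)
        >>> attn(query, hw_shape=(14, 14)).shape
        torch.Size([2, 196, 96])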
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 shift_size=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 pad_small_map=False,
                 window_msa=WindowMSA,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)

        self.shift_size = shift_size
        self.window_size = window_size
        assert 0 <= self.shift_size < self.window_size

        self.w_msa = window_msa(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=to_2tuple(self.window_size),
            **kwargs,
        )

        self.drop = build_dropout(dropout_layer)
        self.pad_small_map = pad_small_map

    def forward(self, query, hw_shape):
        B, L, C = query.shape
        H, W = hw_shape
        assert L == H * W, f"The query length {L} doesn't match the input "\
            f'shape ({H}, {W}).'
        query = query.view(B, H, W, C)

        window_size = self.window_size
        shift_size = self.shift_size

        if min(H, W) == window_size:
            # If not pad small feature map, avoid shifting when the window
            # size is equal to the size of feature map. It's to align with
            # the behavior of the original implementation.
            shift_size = shift_size if self.pad_small_map else 0
        elif min(H, W) < window_size:
            # In the original implementation, the window size will be shrunk
            # to the size of feature map. The behavior is different with
            # swin-transformer for downstream tasks. To support dynamic input
            # shape, we don't allow this feature.
            assert self.pad_small_map, \
                f'The input shape ({H}, {W}) is smaller than the window ' \
                f'size ({window_size}). Please set `pad_small_map=True`, or ' \
                'decrease the `window_size`.'

        pad_r = (window_size - W % window_size) % window_size
        pad_b = (window_size - H % window_size) % window_size
        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))

        H_pad, W_pad = query.shape[1], query.shape[2]

        # cyclic shift
        if shift_size > 0:
            query = torch.roll(
                query, shifts=(-shift_size, -shift_size), dims=(1, 2))

        attn_mask = self.get_attn_mask((H_pad, W_pad),
                                       window_size=window_size,
                                       shift_size=shift_size,
                                       device=query.device)

        # nW*B, window_size, window_size, C
        query_windows = self.window_partition(query, window_size)
        # nW*B, window_size*window_size, C
        query_windows = query_windows.view(-1, window_size**2, C)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        attn_windows = self.w_msa(query_windows, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, window_size, window_size, C)

        # B H' W' C
        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
                                        window_size)
        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if H != H_pad or W != W_pad:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        x = self.drop(x)

        return x

    @staticmethod
    def window_reverse(windows, H, W, window_size):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size,
                         window_size, -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
        return x

    @staticmethod
    def window_partition(x, window_size):
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size,
                   window_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        windows = windows.view(-1, window_size, window_size, C)
        return windows

    @staticmethod
    def get_attn_mask(hw_shape, window_size, shift_size, device=None):
        if shift_size > 0:
            img_mask = torch.zeros(1, *hw_shape, 1, device=device)
            h_slices = (slice(0, -window_size), slice(-window_size,
                                                      -shift_size),
                        slice(-shift_size, None))
            w_slices = (slice(0, -window_size), slice(-window_size,
                                                      -shift_size),
                        slice(-shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # nW, window_size, window_size, 1
            mask_windows = ShiftWindowMSA.window_partition(
                img_mask, window_size)
            mask_windows = mask_windows.view(-1, window_size * window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
            attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
        else:
            attn_mask = None
        return attn_mask


class MultiheadAttention(BaseModule):
    """Multi-head Attention Module.

    This module implements multi-head attention that supports different input
    dims and embed dims. It also supports a shortcut from ``value``, which is
    useful when the input dims differ from the embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to use layer scale. Defaults to False.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
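
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = MultiheadAttention(embed_dims=96, num_heads=4)
        >>> x = torch.rand(2, 197, 96)  # (B, N, C)
        >>> attn(x).shape
        torch.Size([2, 197, 96])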
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 v_shortcut=False,
                 use_layer_scale=False,
                 layer_scale_init_value=0.,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        if qk_scale is not None:
            # Fall back to the Python implementation so that the custom
            # ``qk_scale`` can be applied to the attention logits.
            self.scaled_dot_product_attention = partial(
                scaled_dot_product_attention_pyimpl, scale=qk_scale)
        else:
            self.scaled_dot_product_attention = scaled_dot_product_attention

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

        if use_layer_scale:
            warnings.warn('The `use_layer_scale` in `MultiheadAttention` will '
                          'be deprecated. Please use `layer_scale_init_value` '
                          'to control whether using layer scale or not.')

        if use_layer_scale or (layer_scale_init_value > 0):
            layer_scale_init_value = layer_scale_init_value or 1e-5
            self.gamma1 = LayerScale(
                embed_dims, layer_scale_init_value=layer_scale_init_value)
        else:
            self.gamma1 = nn.Identity()

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_drop = self.attn_drop if self.training else 0.
        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x


class BEiTAttention(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    The initial implementation is in MMSegmentation.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int, int]): The height and width of the window.
        use_rel_pos_bias (bool): Whether to use unique relative position bias,
            if False, use shared relative position bias defined in backbone.
        bias (str): The option to add learnable bias for q, k, v. If bias is
            True, it will add a learnable bias. If bias is 'qv_bias', it will
            only add a learnable bias for q, v. If bias is False, it will not
            add bias for q, k, v. Defaults to 'qv_bias'.
        qk_scale (float | None, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop_rate (float): Dropout ratio of output. Defaults to 0.
        init_cfg (dict | None, optional): The Config for initialization.
            Defaults to None.
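
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = BEiTAttention(
        ...     embed_dims=96, num_heads=4, window_size=(14, 14),
        ...     use_rel_pos_bias=True)
        >>> x = torch.rand(2, 14 * 14 + 1, 96)  # (B, N, C) with a cls token
        >>> attn(x).shape
        torch.Size([2, 197, 96])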
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 use_rel_pos_bias,
                 bias='qv_bias',
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.bias = bias
        self.scale = qk_scale or head_embed_dims**-0.5

        qkv_bias = bias
        if bias == 'qv_bias':
            self._init_qv_bias()
            qkv_bias = False

        if window_size is None:
            assert not use_rel_pos_bias
        else:
            assert isinstance(window_size, tuple)
        self.window_size = window_size
        self.use_rel_pos_bias = use_rel_pos_bias
        self._init_rel_pos_embedding()

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

    def _init_qv_bias(self):
        self.q_bias = nn.Parameter(torch.zeros(self.embed_dims))
        self.v_bias = nn.Parameter(torch.zeros(self.embed_dims))

    def _init_rel_pos_embedding(self):
        if self.use_rel_pos_bias:
            Wh, Ww = self.window_size
            # cls to token & token to cls & cls to cls
            self.num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
            # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, self.num_heads))

            # get pair-wise relative position index for
            # each token inside the window
            coords_h = torch.arange(Wh)
            coords_w = torch.arange(Ww)
            # coords shape is (2, Wh, Ww)
            coords = torch.stack(torch_meshgrid([coords_h, coords_w]))
            # coords_flatten shape is (2, Wh*Ww)
            coords_flatten = torch.flatten(coords, 1)
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :])
            # relative_coords shape is (Wh*Ww, Wh*Ww, 2)
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()
            # shift to start from 0
            relative_coords[:, :, 0] += Wh - 1
            relative_coords[:, :, 1] += Ww - 1
            relative_coords[:, :, 0] *= 2 * Ww - 1
            relative_position_index = torch.zeros(
                size=(Wh * Ww + 1, ) * 2, dtype=relative_coords.dtype)
            # relative_position_index shape is (Wh*Ww + 1, Wh*Ww + 1)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer('relative_position_index',
                                 relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

    def init_weights(self):
        super().init_weights()
        if self.use_rel_pos_bias:
            trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, rel_pos_bias=None):
        """
        Args:
            x (tensor): input features with shape of (num_windows*B, N, C).
            rel_pos_bias (tensor): input relative position bias with shape of
                (num_heads, N, N).
        """
        B, N, C = x.shape

        if self.bias == 'qv_bias':
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        else:
            qkv = self.qkv(x)

        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            Wh = self.window_size[0]
            Ww = self.window_size[1]
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    Wh * Ww + 1, Wh * Ww + 1, -1)
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            # use shared relative position bias
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class ChannelMultiheadAttention(BaseModule):
    """Channel Multihead Self-attention Module.

    This module implements channel multi-head attention that supports
    different input dims and embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to False.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        qk_scale_type (str): The scale type of qk scale.
            Defaults to 'learnable'. It can be 'learnable', 'fixed' or 'none'.
        qk_scale (float, optional): If ``qk_scale_type`` is 'none', this
            should be specified with a valid float number. Defaults to None.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
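
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = ChannelMultiheadAttention(embed_dims=96, num_heads=4)
        >>> x = torch.rand(2, 197, 96)  # (B, N, C)
        >>> attn(x).shape
        torch.Size([2, 197, 96])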
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=False,
                 proj_bias=True,
                 qk_scale_type='learnable',
                 qk_scale=None,
                 v_shortcut=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        if qk_scale_type == 'learnable':
            self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))
        elif qk_scale_type == 'fixed':
            self.scale = self.head_dims**-0.5
        elif qk_scale_type == 'none':
            assert qk_scale is not None
            self.scale = qk_scale

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = build_dropout(dropout_layer)

    def forward(self, x):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)

        q, k, v = [item.transpose(-2, -1) for item in [qkv[0], qkv[1], qkv[2]]]

        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, self.embed_dims)
        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            x = qkv[2].squeeze(1) + x
        return x


class LeAttention(BaseModule):
    """LeViT Attention. Multi-head attention with attention bias, which is
    proposed in `LeViT: a Vision Transformer in ConvNet’s Clothing for Faster
    Inference <https://arxiv.org/abs/2104.01136>`_

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8.
        key_dim (int): Dimension of the key in each head.
        attn_ratio (int): Ratio of the dimension of the value to ``key_dim``
            in each head. Default: 4.
        resolution (tuple[int]): Input resolution. Default: (14, 14).
        init_cfg (dict, optional): The Config for initialization.
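
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = LeAttention(
        ...     dim=192, key_dim=16, num_heads=4, attn_ratio=2,
        ...     resolution=(7, 7))
        >>> x = torch.rand(2, 7 * 7, 192)  # (B, resolution[0]*resolution[1], C)
        >>> attn(x).shape
        torch.Size([2, 49, 192])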
    """

    def __init__(self,
                 dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=4,
                 resolution=(14, 14),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # (h, w)
        assert isinstance(resolution, tuple) and len(resolution) == 2
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2

        self.norm = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, h)
        self.proj = nn.Linear(self.dh, dim)

        points = list(
            itertools.product(range(resolution[0]), range(resolution[1])))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer(
            'attention_bias_idxs',
            torch.LongTensor(idxs).view(N, N),
            persistent=False)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]

    def forward(self, x):  # x (B,N,C)
        B, N, _ = x.shape

        # Normalization
        x = self.norm(x)

        qkv = self.qkv(x)
        # (B, N, num_heads, d)
        q, k, v = qkv.view(B, N, self.num_heads,
                           -1).split([self.key_dim, self.key_dim, self.d],
                                     dim=3)
        # (B, num_heads, N, d)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = ((q @ k.transpose(-2, -1)) * self.scale +
                (self.attention_biases[:, self.attention_bias_idxs]
                 if self.training else self.ab))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x


class CrossMultiheadAttention(BaseModule):
    """Cross attention between queries and the union of keys and values.

    This module is different from ``MultiheadAttention`` in that the attention
    is computed between the queries and the union of the keys and values.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to False.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
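
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = CrossMultiheadAttention(embed_dims=96, num_heads=4)
        >>> x = torch.rand(2, 10, 96)   # queries, (B, N_q, C)
        >>> kv = torch.rand(2, 50, 96)  # keys and values, (B, N_k, C)
        >>> attn(x, k=kv, v=kv).shape
        torch.Size([2, 10, 96])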
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(embed_dims, embed_dims, bias=False)
        self.k = nn.Linear(embed_dims, embed_dims, bias=False)
        self.v = nn.Linear(embed_dims, embed_dims, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self,
                x: torch.Tensor,
                k: torch.Tensor = None,
                v: torch.Tensor = None) -> torch.Tensor:
        """Forward function."""
        B, N, _ = x.shape

        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        q = F.linear(
            input=x, weight=self.q.weight, bias=q_bias)  # (B, N_q, dim)
        k = F.linear(
            input=k, weight=self.k.weight, bias=k_bias)  # (B, N_k, dim)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)

        q = q.reshape(B, N, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_q, dim)
        k = k.reshape(B, N_k, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_k, dim)
        v = v.reshape(B, N_v, 1, self.num_heads,
                      -1).permute(2, 0, 3, 1,
                                  4).squeeze(0)  # (B, num_heads, N_v, dim)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class PromptMultiheadAttention(MultiheadAttention):
    """Prompt Multihead Attention for MILAN.

    This module is specific for the prompt encoder in MILAN. It will not
    update the visible tokens from the encoder.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to use layer scale. Defaults to False.
        return_attention (bool): If True, return the attention map, computed
            by the cross attention between the class token and all other
            tokens. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): The Config for
            initialization. Defaults to None.
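
    Examples:
        >>> # A minimal usage sketch; the sizes below are illustrative
        >>> # assumptions rather than values required by this module.
        >>> import torch
        >>> attn = PromptMultiheadAttention(embed_dims=96, num_heads=4)
        >>> mask_tokens = torch.rand(2, 147, 96)    # N x L_m x C
        >>> visible_tokens = torch.rand(2, 50, 96)  # N x L_v x C (cls + 49)
        >>> ids_restore = torch.argsort(torch.rand(2, 196), dim=1)  # N x L
        >>> attn(mask_tokens, visible_tokens, ids_restore).shape
        torch.Size([2, 147, 96])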
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 input_dims: Optional[int] = None,
                 attn_drop: float = 0,
                 proj_drop: float = 0,
                 dropout_layer: dict = dict(type='Dropout', drop_prob=0.),
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 proj_bias: bool = True,
                 v_shortcut: bool = False,
                 use_layer_scale: bool = False,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            input_dims=input_dims,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            dropout_layer=dropout_layer,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            proj_bias=proj_bias,
            v_shortcut=v_shortcut,
            use_layer_scale=use_layer_scale,
            init_cfg=init_cfg)
        # no longer need qkv
        del self.qkv

        # to project the mask tokens
        self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
        # to project all the tokens
        self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible tokens features from
                encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        assert x_.shape[1] == ids_restore.shape[1]
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        # full sequence shape
        B, _, _ = x_.shape
        q = self.q(x).reshape(B, x.shape[1], self.num_heads,
                              self.head_dims).permute(0, 2, 1, 3)
        kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads,
                                 self.head_dims).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn_drop = self.attn_drop if self.training else 0.
        attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))
        return x