mmpretrain/mmcls/models/utils/attention.py

371 lines
14 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.registry import DROPOUT_LAYERS
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule
from ..builder import ATTENTION
from .helpers import to_2tuple
class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Every window is attended independently. A learnable bias, selected by
    the relative offset between two positions of a window, is added to the
    attention logits of each head before the softmax.

    Args:
        embed_dims (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        # A falsy ``qk_scale`` (None or 0) selects the default scale.
        self.scale = qk_scale or (embed_dims // num_heads)**-0.5

        win_h, win_w = self.window_size

        # One learnable bias per head for each of the (2*Wh-1) * (2*Ww-1)
        # entries of the relative position table.
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads))

        # Lookup indices into the bias table for every pair of positions
        # inside a window, shape (Wh*Ww, Wh*Ww).
        flat_coords = self.double_step_seq(2 * win_w - 1, win_h, 1, win_w)
        pair_index = (flat_coords + flat_coords.T).flip(1).contiguous()
        self.register_buffer('relative_position_index', pair_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        """Run the base initialization, then draw the relative position bias
        table from a truncated normal distribution (std=0.02)."""
        super().init_weights()

        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """Apply window attention.

        Args:
            x (tensor): input features with shape of (num_windows*B, N, C)
            mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
                Wh*Ww), value should be between (-inf, 0].

        Returns:
            tensor: Attended features with shape of (num_windows*B, N, C).
        """
        total_windows, num_tokens, channels = x.shape
        heads = self.num_heads

        # (3, num_windows*B, nH, N, C // nH)
        qkv = self.qkv(x).reshape(total_windows, num_tokens, 3, heads,
                                  channels // heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Index explicitly instead of tuple-unpacking the tensor, which
        # torchscript does not support.
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        # Gather the per-head bias for every pair of window positions.
        tokens_per_window = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)]
        bias = bias.view(tokens_per_window, tokens_per_window,
                         -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Expose the window axis so the per-window mask broadcasts over
            # the batch and head axes, then flatten back.
            num_windows = mask.shape[0]
            attn = attn.view(total_windows // num_windows, num_windows,
                             heads, num_tokens, num_tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, heads, num_tokens, num_tokens)

        attn = self.attn_drop(self.softmax(attn))

        out = (attn @ v).transpose(1, 2).reshape(total_windows, num_tokens,
                                                 channels)
        return self.proj_drop(self.proj(out))

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        """Return every sum ``i * step1 + j * step2`` for ``i < len1`` and
        ``j < len2`` as a tensor of shape (1, len1 * len2)."""
        outer = torch.arange(0, step1 * len1, step1)
        inner = torch.arange(0, step2 * len2, step2)
        return (outer.unsqueeze(1) + inner.unsqueeze(0)).reshape(1, -1)
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
    """Shift Window Multihead Self-Attention Module.

    The feature map is (optionally) padded and cyclically shifted, split
    into non-overlapping square windows, attended per window by
    :class:`WindowMSA` and finally merged back to the input layout.

    Args:
        embed_dims (int): Number of input channels.
        input_resolution (Tuple[int, int]): The resolution of the input
            feature map.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window.
        shift_size (int, optional): The shift step of each window towards
            right-bottom. If zero, act as regular window-msa. Defaults to 0.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Defaults to None.
        attn_drop (float, optional): Dropout ratio of attention weight.
            Defaults to 0.0.
        proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
        dropout_layer (dict, optional): The dropout_layer used before output.
            Defaults to dict(type='DropPath', drop_prob=0.).
        auto_pad (bool, optional): Auto pad the feature map to be divisible
            by window_size, Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 input_resolution,
                 num_heads,
                 window_size,
                 shift_size=0,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0,
                 proj_drop=0,
                 dropout_layer=dict(type='DropPath', drop_prob=0.),
                 auto_pad=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.embed_dims = embed_dims
        self.input_resolution = input_resolution

        shortest_side = min(input_resolution)
        if shortest_side <= window_size:
            # The window is at least as large as the shorter side of the
            # map: shrink it to that side and disable shifting.
            shift_size = 0
            window_size = shortest_side
        self.shift_size = shift_size
        self.window_size = window_size

        self.w_msa = WindowMSA(
            embed_dims=embed_dims,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop)

        self.drop = build_dropout(dropout_layer)

        height, width = self.input_resolution
        win = self.window_size

        # Either pad right/bottom up to the next multiple of the window size
        # or insist that the map is already divisible.
        self.auto_pad = auto_pad
        if self.auto_pad:
            self.pad_r = (win - width % win) % win
            self.pad_b = (win - height % win) % win
            self.H_pad = height + self.pad_b
            self.W_pad = width + self.pad_r
        else:
            assert height % win + width % win == 0,\
                f'input_resolution({self.input_resolution}) is not ' \
                f'divisible by window_size({self.window_size}). Please ' \
                f'check feature map shape or set `auto_pad=True`.'
            self.pad_r, self.pad_b = 0, 0
            self.H_pad, self.W_pad = height, width

        self.register_buffer('attn_mask', self._build_attn_mask())

    def _build_attn_mask(self):
        """Build the additive attention mask used by SW-MSA.

        Returns:
            Tensor | None: ``None`` when no shift is applied. Otherwise a
            tensor of shape (nW, window_size**2, window_size**2) holding 0
            for pairs of positions carrying the same region label and -100
            for all other pairs.
        """
        if self.shift_size <= 0:
            return None

        win, shift = self.window_size, self.shift_size

        # Label the padded map with one id per combination of row band and
        # column band; the bands are delimited by the window and shift
        # sizes counted from the bottom/right border.
        region_map = torch.zeros((1, self.H_pad, self.W_pad, 1))  # 1 H W 1
        bands = (slice(0, -win), slice(-win, -shift), slice(-shift, None))
        region_pairs = [(rows, cols) for rows in bands for cols in bands]
        for region_id, (rows, cols) in enumerate(region_pairs):
            region_map[:, rows, cols, :] = region_id

        # nW, window_size * window_size
        labels = self.window_partition(region_map).view(-1, win * win)
        label_diff = labels.unsqueeze(1) - labels.unsqueeze(2)
        crosses_region = label_diff != 0
        same_region = label_diff == 0
        masked = label_diff.masked_fill(crosses_region, float(-100.0))
        return masked.masked_fill(same_region, float(0.0))

    def forward(self, query):
        """Apply (shifted) window attention.

        Args:
            query (Tensor): Input tokens of shape (B, H*W, C), where
                (H, W) is ``input_resolution``.

        Returns:
            Tensor: Output tokens of shape (B, H*W, C).
        """
        height, width = self.input_resolution
        batch, length, channels = query.shape
        assert length == height * width, 'input feature has wrong size'

        win, shift = self.window_size, self.shift_size
        needs_pad = self.pad_r or self.pad_b

        feat = query.view(batch, height, width, channels)
        if needs_pad:
            # F.pad lists pads from the last dim backwards: C, then W, H.
            feat = F.pad(feat, (0, 0, 0, self.pad_r, 0, self.pad_b))

        # cyclic shift
        if shift > 0:
            feat = torch.roll(feat, shifts=(-shift, -shift), dims=(1, 2))

        # nW*B, window_size*window_size, C
        windows = self.window_partition(feat).view(-1, win**2, channels)

        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
        windows = self.w_msa(windows, mask=self.attn_mask)

        # merge windows back to B H' W' C
        windows = windows.view(-1, win, win, channels)
        feat = self.window_reverse(windows, self.H_pad, self.W_pad)

        # reverse cyclic shift
        if shift > 0:
            feat = torch.roll(feat, shifts=(shift, shift), dims=(1, 2))

        # drop the rows/columns that were only added as padding
        if needs_pad:
            feat = feat[:, :height, :width, :].contiguous()

        return self.drop(feat.view(batch, height * width, channels))

    def window_reverse(self, windows, H, W):
        """Merge windows of shape (nW*B, window_size, window_size, C) back
        into a feature map of shape (B, H, W, C)."""
        win = self.window_size
        windows_per_image = H * W / win / win
        batch = int(windows.shape[0] / windows_per_image)
        grid = windows.view(batch, H // win, W // win, win, win, -1)
        merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
        return merged.view(batch, H, W, -1)

    def window_partition(self, x):
        """Split a feature map of shape (B, H, W, C) into windows of shape
        (nW*B, window_size, window_size, C)."""
        batch, height, width, channels = x.shape
        win = self.window_size
        grid = x.view(batch, height // win, win, width // win, win, channels)
        tiles = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
        return tiles.view(-1, win, win, channels)
class MultiheadAttention(BaseModule):
    """Multi-head Attention Module.

    This module implements multi-head attention that supports different input
    dims and embed dims. And it also supports a shortcut from ``value``, which
    is useful if input dims is not the same with embed dims.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 input_dims=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 v_shortcut=False,
                 init_cfg=None):
        super(MultiheadAttention, self).__init__(init_cfg=init_cfg)

        self.input_dims = input_dims or embed_dims
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.v_shortcut = v_shortcut

        self.head_dims = embed_dims // num_heads
        # A falsy ``qk_scale`` (None or 0) selects the default scale.
        self.scale = qk_scale or self.head_dims**-0.5

        self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.out_drop = DROPOUT_LAYERS.build(dropout_layer)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x (Tensor): Input tokens of shape (B, N, input_dims).

        Returns:
            Tensor: Output tokens of shape (B, N, embed_dims).
        """
        B, N, _ = x.shape
        # (3, B, num_heads, N, head_dims)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  self.head_dims).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            # Merge the heads of ``v`` the same way as the attention output
            # so the shortcut has shape (B, N, embed_dims) for any number of
            # heads. With a single head this equals ``v.squeeze(1)``, which
            # was the only configuration the former squeeze could handle.
            x = v.transpose(1, 2).reshape(B, N, self.embed_dims) + x
        return x