import copy
import warnings
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import Linear
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import ConfigDict
from torch import Tensor

from easycv.framework.errors import RuntimeError
from easycv.models.builder import (build_attention, build_feedforward_network,
                                   build_transformer_layer)
from easycv.models.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                    TRANSFORMER_LAYER,
                                    TRANSFORMER_LAYER_SEQUENCE)
from easycv.models.utils.activation import build_activation_layer
from easycv.models.utils.norm import build_norm_layer


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


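# Usage sketch for MLP (illustrative only; the dimensions below are assumptions,
# not values taken from this repository). A small 3-layer MLP of this kind is
# commonly used as a lightweight prediction head on top of transformer embeddings:
#
#     mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
#     out = mlp(torch.rand(2, 100, 256))  # -> shape (2, 100, 4)

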
class Mlp(nn.Module):
    """Multilayer perceptron.

    Args:
        act_layer: The activation layer class to use. Defaults to nn.GELU.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvMlp(nn.Module):
    """MLP whose fully-connected layers are implemented as 1x1 convolutions,
    so it operates directly on (B, C, H, W) feature maps."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


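# Usage sketch for ConvMlp (illustrative only; shapes below are assumptions).
# Because its layers are 1x1 convolutions, it consumes channel-first feature maps:
#
#     conv_mlp = ConvMlp(in_features=64, hidden_features=128, drop=0.1)
#     out = conv_mlp(torch.rand(2, 64, 32, 32))  # -> shape (2, 64, 32, 32)

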
def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return 'p={}'.format(self.drop_prob)


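# Usage sketch for DropPath (illustrative only; shapes are assumptions).
# In training mode each sample is zeroed with probability `drop_prob` and the
# survivors are rescaled by 1 / (1 - drop_prob), so the expected output equals
# the input; in eval mode the module is an identity:
#
#     dp = DropPath(drop_prob=0.2)
#     dp.train()
#     y = dp(torch.rand(8, 196, 256))  # some samples zeroed, rest scaled by 1.25
#     dp.eval()
#     y = dp(torch.rand(8, 196, 256))  # unchanged

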
class TransformerEncoder(nn.Module):

    def __init__(self,
                 encoder_layer,
                 num_layers,
                 norm=None,
                 d_model=256,
                 query_scale_type=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.query_scale_type = query_scale_type
        if query_scale_type == 'cond_elewise':
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.norm = norm

    def forward(self,
                src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer_id, layer in enumerate(self.layers):
            # rescale the content and pos sim
            if self.query_scale_type == 'cond_elewise':
                pos_scales = self.query_scale(output)
            else:
                pos_scales = 1
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos * pos_scales)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(nn.Module):

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        # NOTE: `normalize_before` is stored but the forward below always
        # follows the post-norm ordering.
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self,
                src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q,
            k,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


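# Usage sketch combining the two classes above (illustrative only; the sizes are
# assumptions). The encoder stacks clones of a post-norm encoder layer; inputs use
# the (sequence, batch, channel) layout expected by nn.MultiheadAttention:
#
#     layer = TransformerEncoderLayer(d_model=256, nhead=8)
#     encoder = TransformerEncoder(layer, num_layers=6)
#     src = torch.rand(100, 2, 256)   # (num_tokens, bs, d_model)
#     pos = torch.rand(100, 2, 256)   # positional encoding, same shape as src
#     memory = encoder(src, pos=pos)  # -> (100, 2, 256)

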
def _get_clones(module, N, layer_share=False):
    if layer_share:
        return nn.ModuleList([module for i in range(N)])
    else:
        return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    if activation == 'prelu':
        return nn.PReLU()
    if activation == 'selu':
        return F.selu
    raise RuntimeError(
        f'activation should be relu/gelu/glu/prelu/selu, not {activation}.')


@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Default: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        If ``identity`` is None, ``x`` itself is used for the identity
        connection.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


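# Usage sketch for FFN (illustrative only; values are assumptions). Since FFN is
# registered in FEEDFORWARD_NETWORK, it can be instantiated directly or built from
# a config dict via build_feedforward_network:
#
#     ffn = FFN(embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)
#     x = torch.rand(2, 100, 256)
#     out = ffn(x)              # identity connection: out = x + ffn_branch(x)
#     out = ffn(x, identity=x)  # equivalent, identity passed explicitly

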
@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN` or `LN` layers
    and different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with the corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when the first element is `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): If True, Key, Query and Value have shape
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The argument `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & {
            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contain all four operation types ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `BaseTransformerLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): Each element is a 2D Tensor
                used in the calculation of the corresponding attention.
                The length should equal the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query=query,
                    key=key,
                    value=value,
                    identity=identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query


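# Usage sketch for BaseTransformerLayer (illustrative only; the config values and
# the tensors `memory`, `query_pos` and `pos_embed` are placeholders, not values
# from this repository). A DETR-style decoder layer can be described purely by
# config, with the attention modules resolved through the ATTENTION registry:
#
#     layer_cfg = dict(
#         type='BaseTransformerLayer',
#         attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
#         ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
#         operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
#                          'norm'))
#     layer = build_transformer_layer(layer_cfg)
#     query = layer(query, key=memory, value=memory,
#                   query_pos=query_pos, key_pos=pos_embed)

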
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    It supports customization such as specifying different kinds of
    `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layers` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super().__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is a 2D Tensor
                which is used in the calculation of the corresponding
                attention in operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: results with shape [num_queries, bs, embed_dims].
        """
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs)
        return query


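# Usage sketch for TransformerLayerSequence (illustrative only; values are
# assumptions). Passing a single layer config repeats it `num_layers` times,
# which yields a simple self-attention encoder stack:
#
#     coder = TransformerLayerSequence(
#         transformerlayers=dict(
#             type='BaseTransformerLayer',
#             attn_cfgs=dict(type='MultiheadAttention', embed_dims=256,
#                            num_heads=8),
#             operation_order=('self_attn', 'norm', 'ffn', 'norm')),
#         num_layers=6)
#     out = coder(query, query, query)  # key/value unused by pure self-attention

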
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The argument `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. If not None, it will
                be added to `query` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
                [num_queries, bs, embed_dims]
                if self.batch_first is False, else
                [bs, num_queries, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch, embed_dims),
        # we adjust the shape from batch_first (batch, num_query, embed_dims)
        # to (num_query, batch, embed_dims), and recover ``attn_output``
        # back to batch_first afterwards.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
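

# Usage sketch for this MultiheadAttention wrapper (illustrative only; shapes are
# assumptions). Unlike the raw torch module, positional encodings and the identity
# shortcut are handled inside forward, and key/value default to the query:
#
#     attn = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True)
#     query = torch.rand(2, 100, 256)  # (bs, num_queries, embed_dims)
#     out = attn(query, query_pos=torch.rand(2, 100, 256))
#     # query_pos is added to q and k (not to v); out has shape (2, 100, 256)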