mirror of https://github.com/alibaba/EasyCV.git
737 lines
28 KiB
Python
737 lines
28 KiB
Python
import copy
|
||
import warnings
|
||
from typing import Optional
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from mmcv.cnn.bricks import Linear
|
||
from mmcv.cnn.bricks.drop import build_dropout
|
||
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
|
||
from mmcv.utils import ConfigDict
|
||
from torch import Tensor
|
||
|
||
from easycv.framework.errors import RuntimeError
|
||
from easycv.models.builder import (build_attention, build_feedforward_network,
|
||
build_transformer_layer)
|
||
from easycv.models.registry import (ATTENTION, FEEDFORWARD_NETWORK,
|
||
TRANSFORMER_LAYER,
|
||
TRANSFORMER_LAYER_SEQUENCE)
|
||
from easycv.models.utils.activation import build_activation_layer
|
||
from easycv.models.utils.norm import build_norm_layer
|
||
|
||
|
||
class MLP(nn.Module):
    """Plain stack of linear layers with ReLU in between (a simple FFN).

    Args:
        input_dim: Feature size consumed by the first linear layer.
        hidden_dim: Feature size of every intermediate linear layer.
        output_dim: Feature size produced by the last linear layer.
        num_layers: Number of linear layers in the stack.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        fan_ins = [input_dim] + hidden_sizes
        fan_outs = hidden_sizes + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out)
             for fan_in, fan_out in zip(fan_ins, fan_outs)])

    def forward(self, x):
        """Run ``x`` through the stack; ReLU follows all but the last layer."""
        last_index = self.num_layers - 1
        for index, linear in enumerate(self.layers):
            x = linear(x)
            if index < last_index:
                x = F.relu(x)
        return x
|
||
|
||
|
||
class Mlp(nn.Module):
    """Two-layer perceptron: linear -> activation -> dropout -> linear -> dropout.

    Parameters:
        in_features: Input feature size.
        hidden_features: Hidden feature size; ``in_features`` is used when
            this is falsy (e.g. ``None``).
        out_features: Output feature size; ``in_features`` is used when
            this is falsy (e.g. ``None``).
        act_layer: Callable that builds the activation module,
            default use nn.GELU.
        drop: Probability of the dropout applied after each linear layer.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # A single dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two linear layers with activation and dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||
|
||
|
||
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    One random keep/drop decision is drawn per entry along the first
    dimension of ``x`` and broadcast over all remaining dimensions. Kept
    samples are divided by ``1 - drop_prob``.

    Args:
        x (Tensor): Input tensor; the first dimension indexes the samples.
        drop_prob (float | None): Probability of dropping a sample. ``None``
            and ``0`` both disable dropping.
        training (bool): Dropping is only applied when this is True.

    Returns:
        Tensor: ``x`` itself when dropping is disabled, otherwise a new
        tensor with the same shape as ``x``.
    """
    # ``DropPath()`` stores ``None`` by default, so ``None`` has to behave
    # like "no dropping" instead of failing in ``1 - drop_prob``.
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with diff dim tensors, not just 2D ConvNets
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    if keep_prob <= 0:
        # Every sample is dropped. Dividing by ``keep_prob`` here would
        # produce inf * 0 = NaN, so multiply by an explicit zero mask.
        return x * x.new_zeros(shape)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 keeps the sample, 0 drops it
    return x.div(keep_prob) * random_tensor
|
||
|
||
|
||
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (Stochastic Depth per sample).

    Intended for the main path of residual blocks. The stored
    ``drop_prob`` and the module's ``training`` flag are forwarded to
    ``drop_path`` on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` to ``x`` using this module's settings."""
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        """Show the drop probability in the module's ``repr``."""
        return f'p={self.drop_prob}'
|
||
|
||
|
||
class TransformerEncoder(nn.Module):
    """Stack of encoder layers with optional rescaling of the positional input.

    Args:
        encoder_layer (nn.Module): Layer that is cloned ``num_layers`` times
            through ``_get_clones``.
        num_layers (int): Number of layers in the stack.
        norm (nn.Module, optional): Applied to the output of the last layer
            when given. Default: None.
        d_model (int): Input, hidden and output size of the query-scale MLP.
            Only used when ``query_scale_type == 'cond_elewise'``.
            Default: 256.
        query_scale_type (str, optional): With ``'cond_elewise'`` the
            positional input of every layer is multiplied by a scale
            predicted from that layer's input; any other value leaves it
            unscaled. Default: None.
    """

    def __init__(self,
                 encoder_layer,
                 num_layers,
                 norm=None,
                 d_model=256,
                 query_scale_type=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.query_scale_type = query_scale_type
        if query_scale_type == 'cond_elewise':
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.norm = norm

    def forward(self,
                src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Run ``src`` through every layer and the optional final norm.

        Args:
            src (Tensor): Input of the first layer.
            mask (Tensor, optional): Passed to each layer as ``src_mask``.
            src_key_padding_mask (Tensor, optional): Passed to each layer
                unchanged.
            pos (Tensor, optional): Positional input. It is multiplied by
                the per-layer scale before being passed on; ``None`` is
                passed on as ``None``.

        Returns:
            Tensor: Output of the last layer, normalized when ``norm`` is set.
        """
        output = src

        for layer in self.layers:
            if pos is None:
                # ``pos`` is optional; multiplying ``None`` by the scale
                # would raise a TypeError, so hand ``None`` to the layer.
                layer_pos = None
            else:
                # rescale the content and pos sim
                if self.query_scale_type == 'cond_elewise':
                    pos_scales = self.query_scale(output)
                else:
                    pos_scales = 1
                layer_pos = pos * pos_scales
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=layer_pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
|
||
|
||
|
||
class TransformerEncoderLayer(nn.Module):
    """Self-attention followed by a two-layer feed-forward network.

    Each sub-block adds its (dropped-out) result to its input and then
    applies a LayerNorm.

    Args:
        d_model (int): Feature size of the attention and the LayerNorms.
        nhead (int): Number of attention heads.
        dim_feedforward (int): Hidden size of the feed-forward network.
            Default: 2048.
        dropout (float): Probability used by the attention and by every
            dropout module of this layer. Default: 0.1.
        activation (str): Name resolved through ``_get_activation_fn``.
            Default: 'relu'.
        normalize_before (bool): Stored on the instance; ``forward`` in this
            class does not read it. Default: False.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation='relu',
                 normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward sub-block
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization and residual dropout for both sub-blocks
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self,
                src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Apply self-attention and the feed-forward network to ``src``.

        ``pos`` is added to the query and key only; the value is ``src``.
        """
        attn_input = self.with_pos_embed(src, pos)
        attn_output = self.self_attn(
            attn_input,
            attn_input,
            value=src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(attn_output))

        hidden = self.dropout(self.activation(self.linear1(src)))
        ffn_output = self.linear2(hidden)
        src = self.norm2(src + self.dropout2(ffn_output))
        return src
|
||
|
||
|
||
def _get_clones(module, N, layer_share=False):
    """Build an ``nn.ModuleList`` with ``N`` entries derived from ``module``.

    Args:
        module (nn.Module): Module to replicate.
        N (int): Number of entries in the returned list.
        layer_share (bool): When True every entry is the same ``module``
            object; otherwise every entry is an independent deep copy.

    Returns:
        nn.ModuleList: The ``N`` entries.
    """
    if layer_share:
        replicas = [module] * N
    else:
        replicas = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(replicas)
|
||
|
||
|
||
def _get_activation_fn(activation):
    """Return the activation callable selected by the string ``activation``.

    Supported names are ``'relu'``, ``'gelu'``, ``'glu'``, ``'prelu'`` and
    ``'selu'``. ``'prelu'`` returns a newly created ``nn.PReLU`` module on
    every call; the other names return the matching function from
    ``torch.nn.functional``.

    Raises:
        RuntimeError: If ``activation`` is none of the supported names.
    """
    if activation == 'relu':
        return F.relu
    if activation == 'gelu':
        return F.gelu
    if activation == 'glu':
        return F.glu
    if activation == 'prelu':
        return nn.PReLU()
    if activation == 'selu':
        return F.selu
    # List every accepted name so the message matches what is supported.
    raise RuntimeError(
        f'activation should be relu/gelu/glu/prelu/selu, not {activation}.')
|
||
|
||
|
||
@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Feed-forward network (FFN) with an optional identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Must be at least 2. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True)
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. An identity module is used when
            this is falsy.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, \
            f'num_fcs should be no less than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # One activation module is built and reused by every hidden block.
        self.activate = build_activation_layer(act_cfg)

        # ``num_fcs - 1`` hidden blocks of linear -> activation -> dropout ...
        hidden_blocks = []
        fan_in = embed_dims
        for _ in range(num_fcs - 1):
            hidden_blocks.append(
                Sequential(
                    Linear(fan_in, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            fan_in = feedforward_channels
        # ... followed by the projection back to ``embed_dims`` and dropout.
        self.layers = Sequential(*hidden_blocks,
                                 Linear(feedforward_channels, embed_dims),
                                 nn.Dropout(ffn_drop))

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        When ``add_identity`` is True, ``identity`` is added to the output;
        ``x`` is used as the identity if ``identity`` is None.
        """
        out = self.dropout_layer(self.layers(x))
        if not self.add_identity:
            return out
        shortcut = x if identity is None else identity
        return shortcut + out
|
||
|
||
|
||
@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the ffn modules in operation_order
            will be built with this config. The argument is deep-copied
            before use, so the passed object is not modified.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default:None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        # Work on a private copy. The default ``ffn_cfgs`` dict is a single
        # object shared by every call, and the deprecated-argument handling
        # below writes into it; without the copy one instance's settings
        # would leak into all later instances (and into caller-owned dicts).
        ffn_cfgs = copy.deepcopy(ffn_cfgs)

        # Map deprecated keyword names to their key inside ``ffn_cfgs``.
        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        self.batch_first = batch_first

        # Every entry of ``operation_order`` must be one of the four names.
        assert set(operation_order) & {
            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            # A single config is replicated once per attention operation.
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                # The layer-level ``batch_first`` must agree with the
                # attention config; it is filled in when the config omits it.
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            # A single config is replicated once per ffn operation.
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            # The FFN width must match the attention width.
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `BaseTransformerLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | Tensor | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. A single Tensor is copied for every
                attention. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Only used in `cross_attn` layer.
                Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                # With pre-norm the shortcut is the un-normalized tensor
                # saved in ``identity``; otherwise the attention module
                # falls back to its own input.
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
|
||
|
||
|
||
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    Builds ``num_layers`` transformer layers from configs and applies them
    one after another. Supports customization such as specifying a different
    kind of `transformer_layer` for every position.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layer` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super().__init__(init_cfg)
        if not isinstance(transformerlayers, dict):
            # An explicit list must provide exactly one config per layer.
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        else:
            # A single config is replicated, one independent copy per layer.
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        self.num_layers = num_layers
        self.layers = ModuleList()
        for layer_index in range(num_layers):
            layer_cfg = transformerlayers[layer_index]
            self.layers.append(build_transformer_layer(layer_cfg))
        # Expose the first layer's settings on the sequence itself.
        first_layer = self.layers[0]
        self.embed_dims = first_layer.embed_dims
        self.pre_norm = first_layer.pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Every layer receives the same arguments; only ``query`` is replaced
        by the previous layer's output.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor:  results with shape [num_queries, bs, embed_dims].
        """
        shared_kwargs = dict(
            query_pos=query_pos,
            key_pos=key_pos,
            attn_masks=attn_masks,
            query_key_padding_mask=query_key_padding_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        for transformer_layer in self.layers:
            query = transformer_layer(query, key, value, **shared_kwargs)
        return query
|
||
|
||
|
||
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. The argument is deep-copied before
            use, so the passed object is not modified.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True,  Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
             Default to False.
        **kwargs: Forwarded to ``torch.nn.MultiheadAttention``. The
            deprecated key ``dropout`` is consumed here and overrides
            ``attn_drop`` and ``dropout_layer['drop_prob']``.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        # Work on a private copy. The default ``dropout_layer`` dict is a
        # single object shared by every call, and the deprecated ``dropout``
        # handling below writes into it; without the copy one instance's
        # drop probability would leak into all later instances.
        dropout_layer = copy.deepcopy(dropout_layer)
        if 'dropout' in kwargs:
            warnings.warn(
                'The arguments `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. If not None, it will
                be added to `query` before forward function.
                Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
                [num_queries, bs, embed_dims]
                if self.batch_first is False, else
                [bs, num_queries embed_dims].
        """

        # Fill in the documented fallbacks. ``value`` and ``identity`` are
        # captured before any positional encoding is added below.
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
|