mirror of https://github.com/alibaba/EasyCV.git
fix import error
parent
873c749714
commit
ff3c2bd2c1
|
@ -3,11 +3,11 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import xavier_init
|
||||
from mmcv.cnn.bricks.registry import ATTENTION
|
||||
from mmcv.runner import force_fp32
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
|
||||
from easycv.models.builder import build_attention
|
||||
from easycv.models.registry import ATTENTION
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
|
|
|
@ -6,11 +6,12 @@ import warnings
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import constant_init, xavier_init
|
||||
from mmcv.cnn.bricks.registry import ATTENTION
|
||||
from mmcv.ops.multi_scale_deform_attn import \
|
||||
multi_scale_deformable_attn_pytorch
|
||||
from mmcv.runner.base_module import BaseModule
|
||||
|
||||
from easycv.models.registry import ATTENTION
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
class TemporalSelfAttention(BaseModule):
|
||||
|
|
|
@ -8,8 +8,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv import ConfigDict
|
||||
from mmcv.cnn import build_norm_layer, xavier_init
|
||||
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
|
||||
TransformerLayerSequence)
|
||||
from mmcv.runner import auto_fp16, force_fp32
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
from mmcv.utils import TORCH_VERSION, digit_version
|
||||
|
@ -22,6 +20,8 @@ from easycv.models.detection.utils.misc import inverse_sigmoid
|
|||
from easycv.models.registry import (POSITIONAL_ENCODING, TRANSFORMER,
|
||||
TRANSFORMER_LAYER,
|
||||
TRANSFORMER_LAYER_SEQUENCE)
|
||||
from easycv.models.utils.transformer import (BaseTransformerLayer,
|
||||
TransformerLayerSequence)
|
||||
from . import (CustomMSDeformableAttention, MSDeformableAttention3D,
|
||||
TemporalSelfAttention)
|
||||
|
||||
|
|
|
@ -59,7 +59,8 @@ norm_cfg = {
|
|||
# and potentially 'SN'
|
||||
'IBN': ('ibn', IBN),
|
||||
'SyncIBN': ('ibn', SyncIBN),
|
||||
'IN': ('in', nn.InstanceNorm2d)
|
||||
'IN': ('in', nn.InstanceNorm2d),
|
||||
'LN': ('ln', nn.LayerNorm)
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1,12 +1,24 @@
|
|||
import copy
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn.bricks import Linear
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
|
||||
from mmcv.utils import ConfigDict
|
||||
from torch import Tensor
|
||||
|
||||
from easycv.framework.errors import RuntimeError
|
||||
from easycv.models.builder import (build_attention, build_feedforward_network,
|
||||
build_transformer_layer)
|
||||
from easycv.models.registry import (ATTENTION, FEEDFORWARD_NETWORK,
|
||||
TRANSFORMER_LAYER,
|
||||
TRANSFORMER_LAYER_SEQUENCE)
|
||||
from easycv.models.utils.activation import build_activation_layer
|
||||
from easycv.models.utils.norm import build_norm_layer
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
@ -190,3 +202,535 @@ def _get_activation_fn(activation):
|
|||
if activation == 'selu':
|
||||
return F.selu
|
||||
raise RuntimeError(F'activation should be relu/gelu, not {activation}.')
|
||||
|
||||
|
||||
@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Feed-forward network (stack of fully-connected layers) with an
    optional identity (shortcut) connection.

    Args:
        embed_dims (int): Input and output feature dimension.
            Defaults: 256.
        feedforward_channels (int): Width of the hidden fully-connected
            layers. Defaults: 1024.
        num_fcs (int, optional): Total number of fully-connected layers;
            must be at least 2. Default: 2.
        act_cfg (dict, optional): Config passed to
            `build_activation_layer`. Default: dict(type='ReLU', inplace=True).
        ffn_drop (float, optional): Dropout probability applied after every
            fully-connected layer. Default 0.0.
        dropout_layer (obj:`ConfigDict`): Config of the dropout applied to
            the output before the shortcut is added. A falsy value selects
            `torch.nn.Identity`. Default: None.
        add_identity (bool, optional): Whether to add the identity
            connection. Default: `True`.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        **kwargs: Accepted and ignored by this constructor.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # One activation module instance is shared by every hidden block.
        self.activate = build_activation_layer(act_cfg)

        # Input width of each hidden block: the first block consumes
        # `embed_dims`, every later one consumes `feedforward_channels`.
        hidden_inputs = [
            embed_dims if position == 0 else feedforward_channels
            for position in range(num_fcs - 1)
        ]
        stages = [
            Sequential(
                Linear(width, feedforward_channels), self.activate,
                nn.Dropout(ffn_drop)) for width in hidden_inputs
        ]
        # Project back to `embed_dims`, followed by a final dropout.
        stages.append(Linear(feedforward_channels, embed_dims))
        stages.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*stages)

        if dropout_layer:
            self.dropout_layer = build_dropout(dropout_layer)
        else:
            self.dropout_layer = torch.nn.Identity()
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        Runs `x` through the layer stack and the shortcut dropout. When
        `add_identity` is set, the result is added to `identity`, or to `x`
        itself if `identity` is None.
        """
        projected = self.dropout_layer(self.layers(x))
        if not self.add_identity:
            return projected
        shortcut = x if identity is None else identity
        return shortcut + projected
|
||||
|
||||
|
||||
@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and support more flexible
    customization, for example, using any number of `FFN or LN ` and
    use different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specifying `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for `self_attention` or `cross_attention` modules,
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
            Configs for FFN, The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the ffn modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Support `prenorm` when you specifying first element as `norm`.
            Default:None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
        **kwargs: Only the deprecated names `feedforward_channels`,
            `ffn_dropout` and `ffn_num_fcs` are read; each is copied into
            `ffn_cfgs` with a `DeprecationWarning`.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        # `ffn_cfgs` may be the shared default dict (or a dict owned by the
        # caller). The deprecated-argument handling below writes into it, so
        # work on a private copy to avoid leaking one instance's settings
        # into every later instance.
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = copy.deepcopy(ffn_cfgs)

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                ffn_cfgs[new_name] = kwargs[ori_name]

        super().__init__(init_cfg)

        self.batch_first = batch_first

        # Every entry of `operation_order` must be one of the known
        # operation names (a subset check, not an "all four" check).
        valid_operations = {'self_attn', 'norm', 'ffn', 'cross_attn'}
        assert set(operation_order) <= valid_operations, \
            f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'only contain operation types from ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                # The layer-level `batch_first` is authoritative: a config
                # that sets it must agree, otherwise it is filled in.
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        # The embedding width is taken from the first attention module.
        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. A single Tensor is copied for every
                attention (with a warning). Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Only used in `cross_attn` layer.
                Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                # Self-attention: the query doubles as key and value, and
                # its positional encoding is reused for the key.
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                # With pre-norm the shortcut is the un-normalized tensor
                # saved in `identity`; otherwise the FFN uses its own input.
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
|
||||
|
||||
|
||||
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    Holds an ordered stack of transformer layers built from configs and
    applies them one after another to the query.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of the layers. A single dict is
            deep-copied `num_layers` times; a list must contain exactly
            `num_layers` configs. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super().__init__(init_cfg)
        if not isinstance(transformerlayers, dict):
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        else:
            # One independent copy of the shared config per layer.
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        self.num_layers = num_layers
        self.layers = ModuleList()
        for layer_cfg in transformerlayers:
            self.layers.append(build_transformer_layer(layer_cfg))
        # Expose the width and norm placement of the first layer.
        first_layer = self.layers[0]
        self.embed_dims = first_layer.embed_dims
        self.pre_norm = first_layer.pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Every layer receives the same `key`, `value`, positional encodings
        and masks; only `query` is updated from layer to layer.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention
                Default: None.
            key_padding_mask (Tensor): ByteTensor with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor:  results with shape [num_queries, bs, embed_dims].
        """
        # Keyword arguments that are identical for every layer call.
        shared_args = dict(
            query_pos=query_pos,
            key_pos=key_pos,
            attn_masks=attn_masks,
            query_key_padding_mask=query_key_padding_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
        for block in self.layers:
            query = block(query, key, value, **shared_args)
        return query
|
||||
|
||||
|
||||
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. A falsy value selects `nn.Identity`.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
        **kwargs: Forwarded to ``torch.nn.MultiheadAttention``. The
            deprecated key `dropout` is consumed here: it overrides
            `attn_drop` and the `drop_prob` of `dropout_layer`.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The arguments `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            attn_drop = kwargs['dropout']
            # Copy before writing: `dropout_layer` may be the shared default
            # dict (or a dict owned by the caller), and mutating it in place
            # would leak this drop probability into later instances.
            dropout_layer = dict(dropout_layer)
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `query`. If not None, it will
                be added to `query` before forward function.
                Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries embed_dims].
        """

        # Resolve the defaults in dependency order: value falls back to the
        # key *before* positional encodings are added to the key, and the
        # shortcut falls back to the query before its encoding is added.
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        # Index 0 selects the attention output; the weights are discarded.
        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
|
Loading…
Reference in New Issue