mirror of https://github.com/alibaba/EasyCV.git
455 lines
20 KiB
Python
455 lines
20 KiB
Python
# ------------------------------------------------------------------------
|
|
# DAB-DETR
|
|
# Copyright (c) 2022 IDEA. All Rights Reserved.
|
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
# ------------------------------------------------------------------------
|
|
# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR)
|
|
# Copyright (c) 2021 Microsoft. All Rights Reserved.
|
|
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
# ------------------------------------------------------------------------
|
|
# Modified from DETR (https://github.com/facebookresearch/detr)
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
# ------------------------------------------------------------------------
|
|
# Modified from codes in torch.nn
|
|
# ------------------------------------------------------------------------
|
|
"""
|
|
MultiheadAttention that supports query, key, and value having different dimensions.
|
|
Query, key, and value projections are removed.
|
|
Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873
|
|
and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837
|
|
"""
|
|
|
|
import warnings
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.nn.functional import dropout, linear, pad, softmax
|
|
from torch.nn.init import constant_
|
|
from torch.nn.modules.linear import Linear
|
|
from torch.nn.modules.module import Module
|
|
|
|
try:
|
|
from torch.overrides import has_torch_function, handle_torch_function
|
|
except:
|
|
from torch._overrides import has_torch_function, handle_torch_function
|
|
|
|
|
|
class MultiheadAttention(Module):
    r"""Multi-head attention whose query, key and value may differ in width.

    Unlike ``torch.nn.MultiheadAttention`` this module owns no query/key/value
    input projections: the tensors passed to :meth:`forward` are attended
    over as given. The only learned parameters belong to ``out_proj``, a
    ``vdim -> vdim`` linear layer whose bias is initialised to zero.

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O

        \text{where} head_i = \text{Attention}(Q_i, K_i, V_i)

    Args:
        embed_dim: feature width of ``query``; must be divisible by
            ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability applied to the attention weights.
            Default: 0.0.
        bias: accepted for signature compatibility with
            ``torch.nn.MultiheadAttention``; not used (``out_proj`` is
            always created with a bias).
        add_bias_kv: accepted for signature compatibility; not used
            (``bias_k`` and ``bias_v`` are always ``None``).
        add_zero_attn: if True, a zero key/value position is appended along
            the source-sequence axis. Default: False.
        kdim: feature width of ``key``. Default: None (uses ``embed_dim``).
            Only consulted to compute ``_qkv_same_embed_dim``.
        vdim: feature width of ``value`` and of the module output.
            Default: None (uses ``embed_dim``).

    Examples::

        >>> attn = MultiheadAttention(embed_dim, num_heads, vdim=value_dim)
        >>> attn_output, attn_output_weights = attn(query, key, value)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False,
                 kdim=None,
                 vdim=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self._qkv_same_embed_dim = (self.kdim == embed_dim
                                    and self.vdim == embed_dim)

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, 'embed_dim must be divisible by num_heads'

        # Only the output projection is learned; inputs arrive pre-projected.
        self.out_proj = Linear(self.vdim, self.vdim)

        # Placeholders that mirror the attribute surface of
        # torch.nn.MultiheadAttention; none of them is a parameter here.
        self.in_proj_bias = None
        self.in_proj_weight = None
        self.bias_k = None
        self.bias_v = None
        self.q_proj_weight = None
        self.k_proj_weight = None
        self.v_proj_weight = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        # Zero the output-projection bias; the weight keeps Linear's own init.
        constant_(self.out_proj.bias, 0.)

    def __setstate__(self, state):
        # Pickles produced by v1.1.0 predate ``_qkv_same_embed_dim``; treat
        # them as having equal query/key/value widths.
        state.setdefault('_qkv_same_embed_dim', True)
        super().__setstate__(state)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                need_weights=True,
                attn_mask=None):
        # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        r"""Attend from ``query`` over ``key``/``value`` without input projections.

        Args:
            query, key, value: map a query and a set of key-value pairs to an
                output. See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, keys whose mask entry is True (bool
                mask) or non-zero (byte mask) are excluded from attention.
            need_weights: if True, also return the attention weights averaged
                over heads.
            attn_mask: 2D or 3D mask that prevents attention to certain
                positions. A 2D mask is broadcast over the batch, a 3D mask
                gives one mask per (batch, head) entry.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length,
              N is the batch size, E is ``embed_dim``.
            - key: :math:`(S, N, E)` where S is the source sequence length.
            - value: :math:`(S, N, V)` where V is ``vdim``.
            - key_padding_mask: :math:`(N, S)`. For a BoolTensor, ``True``
              positions are ignored; for a ByteTensor, non-zero positions are
              ignored.
            - attn_mask: 2D :math:`(L, S)` or 3D
              :math:`(N*\text{num_heads}, L, S)`. Bool ``True`` / byte
              non-zero entries are not allowed to attend; a float mask is
              added to the attention scores.

            - Outputs:
            - attn_output: :math:`(L, N, V)` where V is ``vdim``.
            - attn_output_weights: :math:`(N, L, S)` (head-averaged), or
              ``None`` when ``need_weights`` is False.
        """
        # When the q/k/v widths differ, request the separate-projection path
        # (as torch.nn.MultiheadAttention does); every projection weight is
        # None in this module either way.
        separate_proj_kwargs = {}
        if not self._qkv_same_embed_dim:
            separate_proj_kwargs = dict(
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)

        return multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            out_dim=self.vdim,
            **separate_proj_kwargs)
|
|
|
|
|
|
def multi_head_attention_forward(
        query: Tensor,
        key: Tensor,
        value: Tensor,
        embed_dim_to_check: int,
        num_heads: int,
        in_proj_weight: Optional[Tensor],
        in_proj_bias: Optional[Tensor],
        bias_k: Optional[Tensor],
        bias_v: Optional[Tensor],
        add_zero_attn: bool,
        dropout_p: float,
        out_proj_weight: Tensor,
        out_proj_bias: Optional[Tensor],
        training: bool = True,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        use_separate_proj_weight: bool = False,
        q_proj_weight: Optional[Tensor] = None,
        k_proj_weight: Optional[Tensor] = None,
        v_proj_weight: Optional[Tensor] = None,
        static_k: Optional[Tensor] = None,
        static_v: Optional[Tensor] = None,
        out_dim: Optional[int] = None) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Functional multi-head attention without query/key/value projections.

    ``query``, ``key`` and ``value`` are split into heads as given; only the
    output projection (``out_proj_weight`` / ``out_proj_bias``) is applied to
    the concatenated heads.

    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: expected feature width of ``query`` (and ``key``);
            must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        in_proj_weight, in_proj_bias: accepted for signature compatibility;
            only forwarded to ``__torch_function__`` overrides, never used in
            the computation.
        bias_k, bias_v: optional biases concatenated to the key and value
            sequences at dim=0. Must both be given or both be None.
        add_zero_attn: if True, append a zero key/value position along the
            source-sequence axis.
        dropout_p: dropout probability applied to the attention weights.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, keys whose mask entry is True (bool) or
            non-zero (byte) get a score of -inf before the softmax.
        need_weights: if True, also return the head-averaged attention weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions.
            A 2D mask is broadcast over the batch, a 3D mask gives one mask per
            (batch, head) entry.
        use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight:
            accepted for signature compatibility; only forwarded to
            ``__torch_function__`` overrides.
        static_k, static_v: already head-split key/value that replace the ones
            derived from ``key`` / ``value``.
        out_dim: feature width of ``value`` and of the output; must be
            divisible by ``num_heads``. Defaults to the feature width of
            ``query`` when None.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is
          the batch size, E is ``embed_dim_to_check``.
        - key: :math:`(S, N, E)` where S is the source sequence length.
        - value: :math:`(S, N, V)` where V is ``out_dim``.
        - key_padding_mask: :math:`(N, S)`.
        - attn_mask: 2D :math:`(L, S)` or 3D :math:`(N*num_heads, L, S)`. Bool
          ``True`` / byte non-zero entries are not allowed to attend; a float
          mask is added to the attention scores.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`.
        - static_v: :math:`(N*num_heads, S, V/num_heads)`.

        Outputs:
        - attn_output: :math:`(L, N, V)`.
        - attn_output_weights: :math:`(N, L, S')` averaged over heads, where
          S' also counts positions appended by ``bias_k`` / ``add_zero_attn``;
          ``None`` when ``need_weights`` is False.

    Raises:
        AssertionError: on inconsistent input shapes or dtypes.
        RuntimeError: if ``attn_mask`` has an unsupported rank or size.
    """
    if not torch.jit.is_scripting():
        tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k,
                    bias_v, out_proj_weight, out_proj_bias)
        if any([type(t) is not Tensor
                for t in tens_ops]) and has_torch_function(tens_ops):
            # Forward every argument -- including out_dim, which was
            # previously dropped -- so an override that re-enters this
            # function sees exactly the caller's arguments.
            return handle_torch_function(
                multi_head_attention_forward,
                tens_ops,
                query,
                key,
                value,
                embed_dim_to_check,
                num_heads,
                in_proj_weight,
                in_proj_bias,
                bias_k,
                bias_v,
                add_zero_attn,
                dropout_p,
                out_proj_weight,
                out_proj_bias,
                training=training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight,
                k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight,
                static_k=static_k,
                static_v=static_v,
                out_dim=out_dim)
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    # Width of value / output. A separate local keeps out_dim's Optional type
    # untouched and gives the old crash-on-None default a sensible meaning.
    v_dim = embed_dim if out_dim is None else out_dim

    head_dim = embed_dim // num_heads
    v_head_dim = v_dim // num_heads
    assert head_dim * num_heads == embed_dim, 'embed_dim must be divisible by num_heads'
    assert v_head_dim * num_heads == v_dim, 'out_dim must be divisible by num_heads'
    # The head split below uses view(-1, ...), which would silently fold a
    # mismatched feature width into the sequence axis; reject it up front.
    if static_k is None:
        assert key.size(-1) == embed_dim, \
            'key feature size {} does not match embed_dim {}'.format(
                key.size(-1), embed_dim)
    if static_v is None:
        assert value.size(-1) == v_dim, \
            'value feature size {} does not match out_dim {}'.format(
                value.size(-1), v_dim)
    scaling = float(head_dim)**-0.5

    q = query * scaling
    k = key
    v = value

    if attn_mask is not None:
        assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
            attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
            'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn(
                'Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.'
            )
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError(
                    'The size of the 2D attn_mask is not correct.')
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [
                    bsz * num_heads,
                    query.size(0), key.size(0)
            ]:
                raise RuntimeError(
                    'The size of the 3D attn_mask is not correct.')
        else:
            raise RuntimeError(
                "attn_mask's dimension {} is not supported".format(
                    attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            'Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.'
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            # The appended bias position is never masked.
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, 'bias cannot be added to static key.'
            assert static_v is None, 'bias cannot be added to static value.'
    else:
        assert bias_k is None
        assert bias_v is None

    # Split heads: (seq, N, heads*dim) -> (N*heads, seq, dim).
    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == v_head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([
            k,
            torch.zeros(
                (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)
        ],
                      dim=1)
        v = torch.cat([
            v,
            torch.zeros(
                (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)
        ],
                      dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(
        attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float('-inf'))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
                                                       src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads,
                                                       tgt_len, src_len)

    # Softmax is shift-invariant, so subtracting the row max leaves the result
    # mathematically unchanged; the explicit shift is retained as-is.
    attn_output_weights = softmax(
        attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0],
        dim=-1)
    attn_output_weights = dropout(
        attn_output_weights, p=dropout_p, training=training)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim]
    # Merge heads back: (N*heads, L, v_head_dim) -> (L, N, v_dim).
    attn_output = attn_output.transpose(0, 1).contiguous().view(
        tgt_len, bsz, v_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len,
                                                       src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
|