[Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`. (#1434)
* [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`.
* Support the `--local-rank` and `--amp` options for newer PyTorch versions.
* Fix imports and UT.
parent 164f16e248
commit b017670e1b
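For context, the change is built around PyTorch's fused attention kernel. A minimal sketch of calling it directly, assuming PyTorch >= 2.0; the tensor names and shapes are illustrative, not taken from the diff:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, num_heads, seq_len, head_dims)
q = torch.randn(2, 8, 197, 64)
k = torch.randn(2, 8, 197, 64)
v = torch.randn(2, 8, 197, 64)

# On PyTorch >= 2.0 this dispatches to a fused kernel (FlashAttention or
# memory-efficient attention) when available, instead of materializing the
# full (seq_len, seq_len) attention matrix in Python.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 197, 64])
```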
@@ -207,8 +207,8 @@ workflows:
             - lint
       - build_cpu_with_3rdparty:
           name: maximum_version_cpu
-          torch: 1.13.0
-          torchvision: 0.14.0
+          torch: 2.0.0
+          torchvision: 0.15.0
+          python: 3.10.0
           requires:
             - minimum_version_cpu
@@ -95,11 +95,10 @@ class GlobalSubsampledAttention(MultiheadAttention):
                                 self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.proj_drop(x))
 
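The refactor above rests on the fused call being numerically equivalent to the removed manual path when dropout is disabled. A rough self-check sketch (my own, not part of the PR), assuming PyTorch >= 2.0:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 50, 32)
k = torch.randn(1, 4, 50, 32)
v = torch.randn(1, 4, 50, 32)
scale = q.size(-1) ** -0.5  # default scaling, matches head_dims**-0.5

# Removed manual path: softmax((q @ k^T) * scale) @ v
attn = (q @ k.transpose(-2, -1)) * scale
x_manual = attn.softmax(dim=-1) @ v

# New fused path; dropout_p=0. mirrors eval mode / disabled attn_drop
x_fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

print(torch.allclose(x_manual, x_fused, atol=1e-5))  # expected: True
```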
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
+from functools import partial
 from typing import List, Optional, Union
 
 import numpy as np
@@ -19,12 +20,39 @@ from .layer_scale import LayerScale
 # will raise extra warning. For more details,
 # refers to https://github.com/pytorch/pytorch/issues/50276
 if digit_version(torch.__version__) >= digit_version('1.10.0'):
-    from functools import partial
     torch_meshgrid = partial(torch.meshgrid, indexing='ij')
 else:
     torch_meshgrid = torch.meshgrid
 
 
+def scaled_dot_product_attention_pyimpl(query,
+                                        key,
+                                        value,
+                                        attn_mask=None,
+                                        dropout_p=0.,
+                                        scale=None,
+                                        is_causal=False):
+    scale = scale or query.size(-1)**0.5
+    if is_causal and attn_mask is not None:
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) / scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
+
+
+if digit_version(torch.__version__) >= digit_version('2.0.0'):
+    scaled_dot_product_attention = F.scaled_dot_product_attention
+else:
+    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
+
+
 class WindowMSA(BaseModule):
     """Window based multi-head self-attention (W-MSA) module with relative
     position bias.
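One detail worth illustrating from the fallback above is boolean mask handling: blocked positions need an additive `-inf` before softmax, and the elementwise inversion is done with `~` on the tensor. A minimal sketch under my own naming, not a claim about the exact upstream behaviour:

```python
import torch

seq_len = 5
# Boolean mask: True = "may attend", False = "blocked" (causal example)
bool_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)

# Convert to an additive mask: 0 where attention is allowed, -inf where it
# is blocked, so softmax assigns those positions zero weight.
float_mask = torch.zeros(seq_len, seq_len)
float_mask = float_mask.masked_fill(~bool_mask, float('-inf'))

scores = torch.randn(seq_len, seq_len) + float_mask
weights = scores.softmax(dim=-1)
print(weights[0])  # row 0 of a causal mask: all weight on the first key
```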
@@ -525,10 +553,15 @@ class MultiheadAttention(BaseModule):
         self.v_shortcut = v_shortcut
 
         self.head_dims = embed_dims // num_heads
-        self.scale = qk_scale or self.head_dims**-0.5
+        if qk_scale is not None:
+            self.scaled_dot_product_attention = partial(
+                scaled_dot_product_attention_pyimpl,
+                scale=self.head_dims**-0.5)
+        else:
+            self.scaled_dot_product_attention = scaled_dot_product_attention
 
         self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop = attn_drop
         self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
 
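When `qk_scale` is set, the diff keeps the Python implementation and binds a fixed scale with `functools.partial` instead of using the fused op. A generic sketch of that partial-binding pattern, using an illustrative function of my own rather than the upstream module:

```python
from functools import partial

import torch


def attention_with_scale(q, k, v, scale=None):
    # Fall back to the standard 1/sqrt(head_dims) scaling when unset.
    scale = scale if scale is not None else q.size(-1) ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v


# Bind a custom scale once, then call it exactly like the unscaled variant.
custom_attention = partial(attention_with_scale, scale=0.1)

q = k = v = torch.randn(1, 2, 4, 8)
out = custom_attention(q, k, v)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```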
@@ -545,11 +578,10 @@ class MultiheadAttention(BaseModule):
                                   self.head_dims).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
 
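The `attn_drop = self.attn_drop if self.training else 0.` guard matters because the functional `dropout_p` argument, unlike an `nn.Dropout` module, is not switched off automatically in eval mode. A small sketch of the gating pattern with a toy module of my own, assuming PyTorch >= 2.0:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    def __init__(self, attn_drop=0.1):
        super().__init__()
        # Stored as a plain float now, not as nn.Dropout, so it must be
        # gated on self.training by hand.
        self.attn_drop = attn_drop

    def forward(self, q, k, v):
        drop = self.attn_drop if self.training else 0.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=drop)


m = TinyAttention()
q = k = v = torch.randn(1, 2, 4, 8)
m.eval()
# In eval mode dropout is disabled, so repeated calls are identical.
print(torch.allclose(m(q, k, v), m(q, k, v)))  # True
```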
@@ -1066,11 +1098,10 @@ class PromptMultiheadAttention(MultiheadAttention):
                                 self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
+        attn_drop = self.attn_drop if self.training else 0.
+        attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
         return x
@@ -34,6 +34,10 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        help='enable automatic-mixed-precision test')
     parser.add_argument(
         '--show-dir',
         help='directory where the visualization images will be saved.')
@@ -67,7 +71,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
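Giving the argument two option strings lets the same destination accept both spellings: the PyTorch >= 2.0 launcher passes `--local-rank`, older launchers pass `--local_rank`. A standalone sketch of that argparse pattern:

```python
import argparse
import os

parser = argparse.ArgumentParser()
# Both spellings map to args.local_rank (argparse derives the attribute
# name from the first long option string).
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

args = parser.parse_args(['--local-rank', '3'])
print(args.local_rank)  # 3

if 'LOCAL_RANK' not in os.environ:
    os.environ['LOCAL_RANK'] = str(args.local_rank)
```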
@@ -89,6 +96,10 @@ def merge_args(cfg, args):
 
     cfg.load_from = args.checkpoint
 
+    # enable automatic-mixed-precision test
+    if args.amp:
+        cfg.test_cfg.fp16 = True
+
     # -------------------- visualization --------------------
     if args.show or (args.show_dir is not None):
         assert 'visualization' in cfg.default_hooks, \
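The `--amp` flag is only a boolean switch; `merge_args` then translates it into the test config. A hedged sketch of that flag-to-config pattern, with a plain namespace standing in for the real mmengine Config object:

```python
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument(
    '--amp',
    action='store_true',
    help='enable automatic-mixed-precision test')
args = parser.parse_args(['--amp'])

# Stand-in for the loaded config; the real code mutates cfg.test_cfg.
cfg = SimpleNamespace(test_cfg=SimpleNamespace(fp16=False))
if args.amp:
    cfg.test_cfg.fp16 = True

print(cfg.test_cfg.fp16)  # True
```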
@@ -59,7 +59,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)