[Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`. (#1434)

* [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`.

* Support the `--local-rank` and `--amp` options for newer PyTorch versions.

* Fix imports and UT.
Ma Zerun 2023-03-29 15:50:44 +08:00 committed by GitHub
parent 164f16e248
commit b017670e1b
5 changed files with 63 additions and 19 deletions
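
For orientation, a minimal before/after sketch of the core change (shapes and names are illustrative, not the committed code): the hand-written softmax attention is replaced by the fused op added in PyTorch 2.0.

```python
import torch
import torch.nn.functional as F

B, num_heads, N, head_dims = 2, 8, 196, 64   # illustrative shapes
q = k = v = torch.randn(B, num_heads, N, head_dims)

# Before: explicit attention matrix, softmax, then a matmul with the values.
scale = head_dims**-0.5
attn = (q @ k.transpose(-2, -1)) * scale
x_old = attn.softmax(dim=-1) @ v

# After: a single fused call; on PyTorch >= 2.0 this can dispatch to
# FlashAttention or memory-efficient kernels when supported.
x_new = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

assert torch.allclose(x_old, x_new, atol=1e-5)
```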

View File

@@ -207,8 +207,8 @@ workflows:
             - lint
       - build_cpu_with_3rdparty:
           name: maximum_version_cpu
-          torch: 1.13.0
-          torchvision: 0.14.0
+          torch: 2.0.0
+          torchvision: 0.15.0
           python: 3.10.0
           requires:
             - minimum_version_cpu

View File

@@ -95,11 +95,10 @@ class GlobalSubsampledAttention(MultiheadAttention):
                                  self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.out_drop(self.proj_drop(x))
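
Note on the `attn_drop` handling above: `nn.Dropout` silently becomes a no-op in eval mode, whereas the `dropout_p` argument of the fused op is applied whenever it is non-zero, hence the explicit `self.training` gate. A small sketch (the probability is illustrative):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 2, 4, 8)
attn_drop = 0.1   # attention-dropout probability now stored as a plain float

# The functional op has no train/eval state of its own, so the caller must
# zero the probability outside of training, mirroring nn.Dropout.eval().
training = False
p = attn_drop if training else 0.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=p)
```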

View File

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
+from functools import partial
 from typing import List, Optional, Union
 
 import numpy as np
@@ -19,12 +20,39 @@ from .layer_scale import LayerScale
 # will raise extra warning. For more details,
 # refers to https://github.com/pytorch/pytorch/issues/50276
 if digit_version(torch.__version__) >= digit_version('1.10.0'):
-    from functools import partial
     torch_meshgrid = partial(torch.meshgrid, indexing='ij')
 else:
     torch_meshgrid = torch.meshgrid
 
 
+def scaled_dot_product_attention_pyimpl(query,
+                                        key,
+                                        value,
+                                        attn_mask=None,
+                                        dropout_p=0.,
+                                        scale=None,
+                                        is_causal=False):
+    scale = scale or query.size(-1)**0.5
+    if is_causal and attn_mask is not None:
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) / scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
+
+
+if digit_version(torch.__version__) >= digit_version('2.0.0'):
+    scaled_dot_product_attention = F.scaled_dot_product_attention
+else:
+    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
+
+
 class WindowMSA(BaseModule):
     """Window based multi-head self-attention (W-MSA) module with relative
     position bias.
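
One caveat in the pure-Python fallback above: its boolean-mask branch relies on `not attn_mask`, which raises for a multi-element tensor mask; none of the call sites in this diff pass an `attn_mask`, so the branch is never exercised here. For reference, a sketch of boolean-mask conversion closer to the behaviour documented for `F.scaled_dot_product_attention` (assuming `True` marks positions that may be attended to; the helper name is hypothetical):

```python
import torch


def bool_mask_to_bias(attn_mask: torch.Tensor, dtype=torch.float32):
    """Hypothetical helper: turn a boolean keep-mask into an additive bias."""
    bias = torch.zeros(attn_mask.shape, dtype=dtype)
    return bias.masked_fill(attn_mask.logical_not(), float('-inf'))


mask = torch.tensor([[True, False],
                     [True, True]])
print(bool_mask_to_bias(mask))
# tensor([[0., -inf],
#         [0., 0.]])
```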
@@ -525,10 +553,15 @@ class MultiheadAttention(BaseModule):
         self.v_shortcut = v_shortcut
 
         self.head_dims = embed_dims // num_heads
-        self.scale = qk_scale or self.head_dims**-0.5
+        if qk_scale is not None:
+            self.scaled_dot_product_attention = partial(
+                scaled_dot_product_attention_pyimpl,
+                scale=self.head_dims**-0.5)
+        else:
+            self.scaled_dot_product_attention = scaled_dot_product_attention
 
         self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop = attn_drop
         self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
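
A note on the `qk_scale` branch: when a custom scale is requested, the layer pins the pure-Python fallback with the scale pre-bound through `functools.partial`, presumably because `F.scaled_dot_product_attention` in PyTorch 2.0 offers no way to override the softmax scaling (an explicit `scale` argument only arrived in later releases). A minimal sketch of the same binding pattern, with illustrative values (the fallback above uses `scale` as the divisor of `q @ k^T`):

```python
from functools import partial

import torch

# `scaled_dot_product_attention_pyimpl` is the fallback defined in the hunk
# above; shapes and values below are illustrative.
head_dims = 64
q = k = v = torch.randn(2, 8, 49, head_dims)

# Bind a fixed divisor so the call site keeps the fused-op signature
# (q, k, v, dropout_p=...).  sqrt(head_dims) gives the standard 1/sqrt(d)
# softmax scaling, since the fallback computes q @ k^T / scale.
sdpa = partial(scaled_dot_product_attention_pyimpl, scale=head_dims**0.5)
out = sdpa(q, k, v, dropout_p=0.)   # shape: (2, 8, 49, 64)
```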
@@ -545,11 +578,10 @@ class MultiheadAttention(BaseModule):
                                   self.head_dims).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
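
With the forward pass routed through the fused op, PyTorch >= 2.0 picks a backend (FlashAttention, memory-efficient attention, or the plain math kernel) based on device, dtype and shapes. For benchmarking the speed-up this commit targets, the choice can be restricted with a context manager; a sketch using the PyTorch 2.0-era API (later deprecated in favour of `torch.nn.attention.sdpa_kernel`), which assumes a CUDA device:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 8, 196, 64, device='cuda', dtype=torch.float16)

# Force FlashAttention only; flip the flags to compare against the math path.
with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)
```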
@@ -1066,11 +1098,10 @@ class PromptMultiheadAttention(MultiheadAttention):
                                 self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        attn_drop = self.attn_drop if self.training else 0.
+        attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
 
         return x

View File

@@ -34,6 +34,10 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        help='enable automatic-mixed-precision test')
     parser.add_argument(
         '--show-dir',
         help='directory where the visualization images will be saved.')
@@ -67,7 +71,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
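
The dual option strings register a single argument under one destination, so either spelling works regardless of how the launcher formats the flag. A quick standalone check:

```python
import argparse

parser = argparse.ArgumentParser()
# Both spellings map to dest='local_rank' (dashes become underscores).
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

print(parser.parse_args(['--local-rank', '3']).local_rank)  # 3
print(parser.parse_args(['--local_rank', '2']).local_rank)  # 2
print(parser.parse_args([]).local_rank)                     # 0
```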
@@ -89,6 +96,10 @@ def merge_args(cfg, args):
     cfg.load_from = args.checkpoint
 
+    # enable automatic-mixed-precision test
+    if args.amp:
+        cfg.test_cfg.fp16 = True
+
     # -------------------- visualization --------------------
     if args.show or (args.show_dir is not None):
         assert 'visualization' in cfg.default_hooks, \
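
The `--amp` flag itself only toggles a config field; a minimal sketch of the merge step in isolation (the config construction here is illustrative):

```python
from mmengine.config import Config

cfg = Config(dict(test_cfg=dict()))
amp = True  # i.e. `--amp` was passed on the command line

# Mirrors the merge_args hunk above: enable fp16 evaluation in test_cfg.
if amp:
    cfg.test_cfg.fp16 = True

print(cfg.test_cfg)   # {'fp16': True}
```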

View File

@@ -59,7 +59,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)