[Improve] Use PyTorch official scaled_dot_product_attention to accelerate MultiheadAttention. (#1434)

* [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`.

* Support the `--local-rank` and `--amp` options for newer PyTorch versions.

* Fix imports and unit tests.

Ma Zerun authored 2023-03-29 15:50:44 +08:00, committed by GitHub
parent 164f16e248
commit b017670e1b
5 changed files with 63 additions and 19 deletions
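At its core, the change replaces the hand-written softmax attention in `MultiheadAttention` with the fused kernel that PyTorch 2.0 exposes as `torch.nn.functional.scaled_dot_product_attention`, keeping a pure-Python fallback for older versions. A minimal sketch of why the swap is safe, using made-up tensor shapes not taken from the diff:

import torch
import torch.nn.functional as F

# toy shapes: (batch, num_heads, seq_len, head_dims)
q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))

# old explicit path
scale = q.size(-1)**-0.5
attn = (q @ k.transpose(-2, -1)) * scale
out_manual = attn.softmax(dim=-1) @ v

# fused path (PyTorch >= 2.0); same math, usually faster
out_fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(out_manual, out_fused, atol=1e-5)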


@@ -207,8 +207,8 @@ workflows:
       - lint
       - build_cpu_with_3rdparty:
           name: maximum_version_cpu
-          torch: 1.13.0
-          torchvision: 0.14.0
+          torch: 2.0.0
+          torchvision: 0.15.0
           python: 3.10.0
           requires:
             - minimum_version_cpu
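The CI maximum-version job is bumped to torch 2.0.0 / torchvision 0.15.0, which is what exercises the native kernel path; older environments fall back to the pure-Python implementation added further down. A quick, hedged way to see which path a local install would take:

import torch
import torch.nn.functional as F

# PyTorch >= 2.0 ships a fused scaled_dot_product_attention
print(torch.__version__, hasattr(F, 'scaled_dot_product_attention'))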


@@ -95,11 +95,10 @@ class GlobalSubsampledAttention(MultiheadAttention):
                                 self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.proj_drop(x))
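Note how dropout is handled here: the functional kernel knows nothing about a module's train/eval state, so the probability is kept as a plain float and forced to `0.` at inference time. A small standalone sketch of the same gating pattern (the `TinyAttention` class is illustrative, not the class from the diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    def __init__(self, attn_drop=0.1):
        super().__init__()
        # keep the probability, not an nn.Dropout module
        self.attn_drop = attn_drop

    def forward(self, q, k, v):
        # dropout is only active while the module is in training mode
        dropout_p = self.attn_drop if self.training else 0.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

attn = TinyAttention().eval()
q = k = v = torch.randn(1, 4, 8, 16)
out = attn(q, k, v)  # deterministic in eval mode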


@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import itertools
+from functools import partial
 from typing import List, Optional, Union
 
 import numpy as np
@@ -19,12 +20,39 @@ from .layer_scale import LayerScale
 # will raise extra warning. For more details,
 # refers to https://github.com/pytorch/pytorch/issues/50276
 if digit_version(torch.__version__) >= digit_version('1.10.0'):
-    from functools import partial
     torch_meshgrid = partial(torch.meshgrid, indexing='ij')
 else:
     torch_meshgrid = torch.meshgrid
 
 
+def scaled_dot_product_attention_pyimpl(query,
+                                        key,
+                                        value,
+                                        attn_mask=None,
+                                        dropout_p=0.,
+                                        scale=None,
+                                        is_causal=False):
+    scale = scale or query.size(-1)**0.5
+    if is_causal and attn_mask is not None:
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) / scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value
+
+
+if digit_version(torch.__version__) >= digit_version('2.0.0'):
+    scaled_dot_product_attention = F.scaled_dot_product_attention
+else:
+    scaled_dot_product_attention = scaled_dot_product_attention_pyimpl
+
+
 class WindowMSA(BaseModule):
     """Window based multi-head self-attention (W-MSA) module with relative
     position bias.
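Assuming no mask, no dropout and the default scale, the pure-Python fallback above should agree with the native kernel on PyTorch >= 2.0. A hedged consistency check, run with `scaled_dot_product_attention_pyimpl` from the hunk above pasted into the same session (the tolerance is a guess):

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))

ref = scaled_dot_product_attention_pyimpl(q, k, v)  # fallback path
native = F.scaled_dot_product_attention(q, k, v)    # native kernel
assert torch.allclose(ref, native, atol=1e-5)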
@@ -525,10 +553,15 @@ class MultiheadAttention(BaseModule):
         self.v_shortcut = v_shortcut
 
         self.head_dims = embed_dims // num_heads
-        self.scale = qk_scale or self.head_dims**-0.5
+        if qk_scale is not None:
+            self.scaled_dot_product_attention = partial(
+                scaled_dot_product_attention_pyimpl,
+                scale=self.head_dims**-0.5)
+        else:
+            self.scaled_dot_product_attention = scaled_dot_product_attention
 
         self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop = attn_drop
         self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -545,11 +578,10 @@ class MultiheadAttention(BaseModule):
                                   self.head_dims).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
+        attn_drop = self.attn_drop if self.training else 0.
+        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
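For context, the q/k/v tensors reaching this call were already laid out as `(B, num_heads, N, head_dims)` by the qkv projection a few lines above the hunk, and the transpose/reshape afterwards merges the heads back. A compressed, illustrative version of that plumbing with made-up sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, embed_dims, num_heads = 2, 16, 256, 8
head_dims = embed_dims // num_heads

qkv_proj = nn.Linear(embed_dims, embed_dims * 3)
x = torch.randn(B, N, embed_dims)

# (B, N, 3, num_heads, head_dims) -> (3, B, num_heads, N, head_dims)
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, head_dims).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

out = F.scaled_dot_product_attention(q, k, v)        # (B, num_heads, N, head_dims)
out = out.transpose(1, 2).reshape(B, N, embed_dims)  # merge heads back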
@@ -1066,11 +1098,10 @@ class PromptMultiheadAttention(MultiheadAttention):
                               self.head_dims).permute(2, 0, 3, 1, 4)
         k, v = kv[0], kv[1]
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
+        attn_drop = self.attn_drop if self.training else 0.
+        attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
+        x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
+
         x = self.proj(x)
         x = self.out_drop(self.gamma1(self.proj_drop(x)))
 
         return x


@@ -34,6 +34,10 @@ def parse_args():
         'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
         'Note that the quotation marks are necessary and that no white space '
         'is allowed.')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        help='enable automatic-mixed-precision test')
     parser.add_argument(
         '--show-dir',
         help='directory where the visualization images will be saved.')
@@ -67,7 +71,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -89,6 +96,10 @@ def merge_args(cfg, args):
     cfg.load_from = args.checkpoint
 
+    # enable automatic-mixed-precision test
+    if args.amp:
+        cfg.test_cfg.fp16 = True
+
     # -------------------- visualization --------------------
     if args.show or (args.show_dir is not None):
         assert 'visualization' in cfg.default_hooks, \
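The new flag only flips a config field, `cfg.test_cfg.fp16`. A rough sketch of that config mutation, assuming an `mmengine` `Config` built from a plain dict (field names as in the diff, the printed form is a guess):

from mmengine.config import Config

cfg = Config(dict(test_cfg=dict()))
amp = True                     # stand-in for args.amp
if amp:
    cfg.test_cfg.fp16 = True   # what the --amp branch above does
print(cfg.test_cfg)            # expected: {'fp16': True}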


@@ -59,7 +59,10 @@ def parse_args():
         choices=['none', 'pytorch', 'slurm', 'mpi'],
         default='none',
         help='job launcher')
-    parser.add_argument('--local_rank', type=int, default=0)
+    # When using PyTorch version >= 2.0.0, the `torch.distributed.launch`
+    # will pass the `--local-rank` parameter to `tools/train.py` instead
+    # of `--local_rank`.
+    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
     args = parser.parse_args()
     if 'LOCAL_RANK' not in os.environ:
         os.environ['LOCAL_RANK'] = str(args.local_rank)
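Registering both `'--local_rank'` and `'--local-rank'` for the same option lets launchers from either PyTorch generation work; argparse keeps `local_rank` as the destination because it derives the attribute name from the first long option string. A minimal sketch of that behaviour:

import argparse

parser = argparse.ArgumentParser()
# both spellings map to args.local_rank (dest comes from the first option string)
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

print(parser.parse_args(['--local-rank', '3']).local_rank)  # 3
print(parser.parse_args(['--local_rank', '5']).local_rank)  # 5
print(parser.parse_args([]).local_rank)                     # 0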