From b017670e1bb5bfccfa3550cbdbf68c78889ce79a Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Wed, 29 Mar 2023 15:50:44 +0800 Subject: [PATCH] [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`. (#1434) * [Improve] Use PyTorch official `scaled_dot_product_attention` to accelerate `MultiheadAttention`. * Support `--local-rank` and `--amp` option for new version PyTorch. * Fix imports and UT. --- .circleci/test.yml | 4 +-- mmpretrain/models/backbones/twins.py | 7 ++-- mmpretrain/models/utils/attention.py | 53 ++++++++++++++++++++++------ tools/test.py | 13 ++++++- tools/train.py | 5 ++- 5 files changed, 63 insertions(+), 19 deletions(-) diff --git a/.circleci/test.yml b/.circleci/test.yml index e78a1e1ff..c2fa5cf4b 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -207,8 +207,8 @@ workflows: - lint - build_cpu_with_3rdparty: name: maximum_version_cpu - torch: 1.13.0 - torchvision: 0.14.0 + torch: 2.0.0 + torchvision: 0.15.0 python: 3.10.0 requires: - minimum_version_cpu diff --git a/mmpretrain/models/backbones/twins.py b/mmpretrain/models/backbones/twins.py index 429167066..be55c02db 100644 --- a/mmpretrain/models/backbones/twins.py +++ b/mmpretrain/models/backbones/twins.py @@ -95,11 +95,10 @@ class GlobalSubsampledAttention(MultiheadAttention): self.head_dims).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + attn_drop = self.attn_drop if self.training else 0. + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) - x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.out_drop(self.proj_drop(x)) diff --git a/mmpretrain/models/utils/attention.py b/mmpretrain/models/utils/attention.py index 5b1e4b1af..35e49a3a2 100644 --- a/mmpretrain/models/utils/attention.py +++ b/mmpretrain/models/utils/attention.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import itertools +from functools import partial from typing import List, Optional, Union import numpy as np @@ -19,12 +20,39 @@ from .layer_scale import LayerScale # will raise extra warning. For more details, # refers to https://github.com/pytorch/pytorch/issues/50276 if digit_version(torch.__version__) >= digit_version('1.10.0'): - from functools import partial torch_meshgrid = partial(torch.meshgrid, indexing='ij') else: torch_meshgrid = torch.meshgrid +def scaled_dot_product_attention_pyimpl(query, + key, + value, + attn_mask=None, + dropout_p=0., + scale=None, + is_causal=False): + scale = scale or query.size(-1)**0.5 + if is_causal and attn_mask is not None: + attn_mask = torch.ones( + query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0) + if attn_mask is not None and attn_mask.dtype == torch.bool: + attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) + + attn_weight = query @ key.transpose(-2, -1) / scale + if attn_mask is not None: + attn_weight += attn_mask + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +if digit_version(torch.__version__) >= digit_version('2.0.0'): + scaled_dot_product_attention = F.scaled_dot_product_attention +else: + scaled_dot_product_attention = scaled_dot_product_attention_pyimpl + + class WindowMSA(BaseModule): """Window based multi-head self-attention (W-MSA) module with relative position bias. 
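
Editor's note (not part of the patch): the hunk above adds a pure-Python fallback and, at import time, binds the module-level name `scaled_dot_product_attention` to `F.scaled_dot_product_attention` on PyTorch >= 2.0 and to the fallback otherwise. The sketch below is a minimal sanity check of that dispatch under the assumption that PyTorch >= 2.0 and an mmpretrain checkout containing this change are installed; tensor shapes are arbitrary. With `scale=None` the fallback divides the attention logits by `sqrt(head_dims)`, which matches the official kernel's default scaling, so the two paths should agree when no mask or dropout is used.

```python
# Sanity check: pure-Python fallback vs. the official PyTorch 2.0 kernel.
# Assumes PyTorch >= 2.0 and mmpretrain with this patch applied.
import torch
import torch.nn.functional as F

from mmpretrain.models.utils.attention import (
    scaled_dot_product_attention_pyimpl)

B, num_heads, N, head_dims = 2, 4, 16, 32
q = torch.randn(B, num_heads, N, head_dims)
k = torch.randn(B, num_heads, N, head_dims)
v = torch.randn(B, num_heads, N, head_dims)

# dropout_p=0. keeps both paths deterministic so the outputs can be compared;
# scale=None means the fallback uses the default sqrt(head_dims) scaling.
out_py = scaled_dot_product_attention_pyimpl(q, k, v, dropout_p=0.)
out_torch = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

torch.testing.assert_close(out_py, out_torch, rtol=1e-5, atol=1e-5)
```
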
@@ -525,10 +553,15 @@ class MultiheadAttention(BaseModule): self.v_shortcut = v_shortcut self.head_dims = embed_dims // num_heads - self.scale = qk_scale or self.head_dims**-0.5 + if qk_scale is not None: + self.scaled_dot_product_attention = partial( + scaled_dot_product_attention_pyimpl, + scale=self.head_dims**-0.5) + else: + self.scaled_dot_product_attention = scaled_dot_product_attention self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) + self.attn_drop = attn_drop self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) @@ -545,11 +578,10 @@ class MultiheadAttention(BaseModule): self.head_dims).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + attn_drop = self.attn_drop if self.training else 0. + x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = x.transpose(1, 2).reshape(B, N, self.embed_dims) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims) x = self.proj(x) x = self.out_drop(self.gamma1(self.proj_drop(x))) @@ -1066,11 +1098,10 @@ class PromptMultiheadAttention(MultiheadAttention): self.head_dims).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + attn_drop = self.attn_drop if self.training else 0. + attn = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop) + x = attn.transpose(1, 2).reshape(B, x.shape[1], self.embed_dims) - x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims) x = self.proj(x) x = self.out_drop(self.gamma1(self.proj_drop(x))) return x diff --git a/tools/test.py b/tools/test.py index 3d62a95b1..426644959 100644 --- a/tools/test.py +++ b/tools/test.py @@ -34,6 +34,10 @@ def parse_args(): 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 'Note that the quotation marks are necessary and that no white space ' 'is allowed.') + parser.add_argument( + '--amp', + action='store_true', + help='enable automatic-mixed-precision test') parser.add_argument( '--show-dir', help='directory where the visualization images will be saved.') @@ -67,7 +71,10 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. 
+ parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -89,6 +96,10 @@ def merge_args(cfg, args): cfg.load_from = args.checkpoint + # enable automatic-mixed-precision test + if args.amp: + cfg.test_cfg.fp16 = True + # -------------------- visualization -------------------- if args.show or (args.show_dir is not None): assert 'visualization' in cfg.default_hooks, \ diff --git a/tools/train.py b/tools/train.py index 3ba0b1974..35413f966 100644 --- a/tools/train.py +++ b/tools/train.py @@ -59,7 +59,10 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') - parser.add_argument('--local_rank', type=int, default=0) + # When using PyTorch version >= 2.0.0, the `torch.distributed.launch` + # will pass the `--local-rank` parameter to `tools/train.py` instead + # of `--local_rank`. + parser.add_argument('--local_rank', '--local-rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank)
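
Editor's note (not part of the patch): in the attention hunks above, `MultiheadAttention` (and the `GlobalSubsampledAttention` / `PromptMultiheadAttention` subclasses) now keep `attn_drop` as a plain float and forward it to `scaled_dot_product_attention` only while `self.training` is true, so attention dropout is skipped entirely in `eval()` mode. The snippet below illustrates that behaviour; it assumes an mmpretrain install that already contains this patch, and the dimensions are arbitrary.

```python
# Illustration of the new dropout handling in MultiheadAttention.
# Assumes mmpretrain with this patch applied.
import torch

from mmpretrain.models.utils.attention import MultiheadAttention

attn = MultiheadAttention(embed_dims=64, num_heads=4, attn_drop=0.1)
x = torch.randn(2, 16, 64)  # (batch, num_tokens, embed_dims)

# Training mode: dropout_p=0.1 is forwarded to scaled_dot_product_attention,
# so repeated calls on the same input generally differ.
attn.train()
y_train = attn(x)

# Eval mode: dropout_p is forced to 0., so the output is deterministic.
attn.eval()
with torch.no_grad():
    y1 = attn(x)
    y2 = attn(x)
torch.testing.assert_close(y1, y2)
print(y_train.shape, y1.shape)  # torch.Size([2, 16, 64]) for both
```
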
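Editor's note (not part of the patch): the `tools/test.py` and `tools/train.py` hunks register both spellings of the rank option on a single argument, so the scripts accept `--local_rank` from older `torch.distributed.launch` versions as well as `--local-rank` from PyTorch >= 2.0 launchers, while `--amp` in `tools/test.py` simply sets `cfg.test_cfg.fp16 = True`. A standalone sketch of the argparse aliasing (the option name mirrors the patch; everything else is illustrative):

```python
# Both spellings map to the same destination attribute, `local_rank`,
# because argparse derives `dest` from the first long option string.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', '--local-rank', type=int, default=0)

print(parser.parse_args(['--local_rank', '1']).local_rank)  # 1
print(parser.parse_args(['--local-rank', '2']).local_rank)  # 2
print(parser.parse_args([]).local_rank)                     # 0
```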