mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
fix EfficientMultiheadAttention in SegFormer (#1037)
This commit is contained in:
parent
7a1c9a5499
commit
d665f6b085
@ -146,8 +146,49 @@ class EfficientMultiheadAttention(MultiheadAttention):
|
||||
# The ret[0] of build_norm_layer is norm name.
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
|
||||
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
|
||||
from mmseg import digit_version, mmcv_version
|
||||
if mmcv_version < digit_version('1.3.17'):
|
||||
warnings.warn('The legacy version of forward function in'
|
||||
'EfficientMultiheadAttention is deprecated in'
|
||||
'mmcv>=1.3.17 and will no longer support in the'
|
||||
'future. Please upgrade your mmcv.')
|
||||
self.forward = self.legacy_forward
|
||||
|
||||
def forward(self, x, hw_shape, identity=None):
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
x_kv = self.sr(x_kv)
|
||||
x_kv = nchw_to_nlc(x_kv)
|
||||
x_kv = self.norm(x_kv)
|
||||
else:
|
||||
x_kv = x
|
||||
|
||||
if identity is None:
|
||||
identity = x_q
|
||||
|
||||
# Because the dataflow('key', 'query', 'value') of
|
||||
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
|
||||
# embed_dims), We should adjust the shape of dataflow from
|
||||
# batch_first (batch, num_query, embed_dims) to num_query_first
|
||||
# (num_query ,batch, embed_dims), and recover ``attn_output``
|
||||
# from num_query_first to batch_first.
|
||||
if self.batch_first:
|
||||
x_q = x_q.transpose(0, 1)
|
||||
x_kv = x_kv.transpose(0, 1)
|
||||
|
||||
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
||||
|
||||
if self.batch_first:
|
||||
out = out.transpose(0, 1)
|
||||
|
||||
return identity + self.dropout_layer(self.proj_drop(out))
|
||||
|
||||
def legacy_forward(self, x, hw_shape, identity=None):
|
||||
"""multi head attention forward in mmcv version < 1.3.17."""
|
||||
|
||||
x_q = x
|
||||
if self.sr_ratio > 1:
|
||||
x_kv = nlc_to_nchw(x, hw_shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user