# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init
from torch import nn as nn
from torch.nn import functional as F


class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            the key and query projections.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to use ConvModule (with norm and
            activation) for the key/query projection.
        value_out_norm (bool): Whether to use ConvModule (with norm and
            activation) for the value/out projection.
        matmul_norm (bool): Whether to normalize the attention map by
            1/sqrt(channels).
        with_out (bool): Whether to use the out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()

    def init_weights(self):
        """Initialize weights of the out projection layer."""
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                # Only zero-initialize the out projection when it is a plain
                # conv layer; a ConvModule handles its own initialization.
                constant_init(self.out_project, 0)

    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs
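    # Illustrative note on build_project above (example sizes are assumptions,
    # not from the original source): a call like
    # build_project(64, 32, num_convs=2, use_conv_module=True, ...) returns
    # nn.Sequential(ConvModule(64, 32, 1), ConvModule(32, 32, 1)), while
    # build_project(64, 32, num_convs=1, use_conv_module=False, ...) returns
    # a single nn.Conv2d(64, 32, 1).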
    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
        # Project the query and flatten its (optionally downsampled) spatial
        # dims: (N, C_query, H, W) -> (N, H*W, channels).
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        # Project key/value from the key features and flatten:
        # key -> (N, channels, H*W), value -> (N, H*W, value_channels).
        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # Scaled dot-product attention over all spatial positions.
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # Aggregate the values and restore the spatial layout of the query.
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
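

# The block below is a minimal usage sketch, not part of the original module:
# it assumes mmcv/mmengine are installed and uses arbitrary example channel
# sizes and configs (64/32 channels, BN norm, ReLU activation) purely to show
# the expected input/output shapes of SelfAttentionBlock.
if __name__ == '__main__':
    self_attn = SelfAttentionBlock(
        key_in_channels=64,
        query_in_channels=64,
        channels=32,
        out_channels=64,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))
    query_feats = torch.rand(2, 64, 16, 16)
    key_feats = torch.rand(2, 64, 16, 16)
    # The output keeps the query's spatial size and has out_channels channels.
    context = self_attn(query_feats, key_feats)
    assert context.shape == (2, 64, 16, 16)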