# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner.base_module import BaseModule


class ConditionalPositionEncoding(BaseModule):
    """Conditional Position Encoding (CPE).

    Implements the position encoding from `Conditional Positional
    Encodings for Vision Transformers
    <https://arxiv.org/abs/2102.10882>`_: a 3x3 grouped convolution over
    the 2-D feature map, optionally combined with a residual connection
    when ``stride == 1``.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): Initialization config passed to
            :class:`BaseModule`. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
        self.stride = stride
        # Grouped 3x3 conv with groups=embed_dims; this is a depthwise
        # conv whenever in_channels == embed_dims (the usual CPE setup —
        # NOTE(review): in_channels must be divisible by embed_dims).
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)

    def forward(self, x, hw_shape):
        """Add conditional position encoding to the token sequence.

        Args:
            x (Tensor): Input tokens of shape ``(B, N, C)`` where
                ``N == H * W``.
            hw_shape (tuple[int]): Spatial shape ``(H, W)`` of the tokens.

        Returns:
            Tensor: Encoded tokens of shape ``(B, N', embed_dims)`` —
            ``N' == N`` when ``stride == 1``.
        """
        batch, _, channels = x.shape
        height, width = hw_shape
        # (B, N, C) -> (B, C, H, W); splitting the token dim into H x W
        # is a legal view even after the transpose.
        feat_2d = x.transpose(1, 2).view(batch, channels, height, width)
        encoded = self.proj(feat_2d)
        if self.stride == 1:
            # Residual keeps the original features alongside the encoding.
            encoded = encoded + feat_2d
        # (B, C, H', W') -> (B, N', C): back to a token sequence.
        return encoded.flatten(2).transpose(1, 2)
|