mmclassification/mmcls/models/utils/position_encoding.py

42 lines
1.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner.base_module import BaseModule
class ConditionalPositionEncoding(BaseModule):
    """Conditional Position Encoding (CPE) module.

    Implementation of the position encoding described in `Conditional
    Positional Encodings for Vision Transformers
    <https://arxiv.org/abs/2102.10882>`_. The token sequence is folded back
    into a 2-D feature map, passed through a 3x3 grouped convolution
    (``groups=embed_dims``, ``padding=1``) and unfolded into a sequence again.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. It is used both as the
            number of output channels and as the number of conv groups.
            Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (optional): Initialization config, forwarded unchanged to
            the ``BaseModule`` constructor. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Remembered so that ``forward`` can decide whether the residual
        # connection is applied.
        self.stride = stride
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=embed_dims,
            bias=True)

    def forward(self, x, hw_shape):
        """Apply the conditional position encoding to a token sequence.

        Args:
            x (Tensor): Token sequence of shape (B, N, C). It is viewed as a
                (B, C, H, W) map, so N has to equal H * W.
            hw_shape (tuple[int, int]): Spatial size (H, W) of the feature
                map the tokens were flattened from.

        Returns:
            Tensor: Encoded tokens, flattened back to (B, N', C'), where N'
            is the number of spatial positions of the conv output and C'
            its number of channels.
        """
        # The three-way unpack also enforces that ``x`` is a 3-D tensor.
        batch_size, _, channels = x.shape
        height, width = hw_shape

        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        feat_map = x.permute(0, 2, 1).view(batch_size, channels, height,
                                           width)

        encoded = self.proj(feat_map)
        # The skip connection is only used when the conv keeps the spatial
        # resolution; with any other stride the conv output is returned
        # as is.
        if self.stride == 1:
            encoded = encoded + feat_map

        # (B, C', H', W') -> (B, C', H' * W') -> (B, H' * W', C')
        tokens = encoded.flatten(2)
        return tokens.permute(0, 2, 1)