# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner.base_module import BaseModule


class ConditionalPositionEncoding(BaseModule):
    """Conditional Position Encoding (CPE) module.

    Implements the positional-encoding generator from `Conditional
    Positional Encodings for Vision Transformers
    <https://arxiv.org/abs/2102.10882>`_: a depth-wise 3x3 convolution
    applied to the token map produces position information conditioned on
    the input itself.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Depth-wise (groups=embed_dims) 3x3 conv that generates the
        # positional encoding; padding=1 preserves spatial size at stride 1.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        """Add conditional positional encoding to the token sequence.

        Args:
            x (Tensor): Input tokens of shape (B, N, C) with N == H * W.
            hw_shape (tuple[int]): Spatial shape (H, W) of the token map.

        Returns:
            Tensor: Position-encoded tokens of shape (B, N', C); N' == N
            when ``stride`` is 1.
        """
        batch, _, channels = x.shape
        height, width = hw_shape
        # Reshape tokens (B, N, C) into a feature map (B, C, H, W) so the
        # convolution can exploit the 2D spatial layout.
        feat_map = x.transpose(1, 2).view(batch, channels, height, width)
        if self.stride == 1:
            # Residual connection: keep original features alongside the
            # generated positional signal.
            out = self.proj(feat_map) + feat_map
        else:
            out = self.proj(feat_map)
        # Flatten the spatial dims back into a token sequence (B, N', C).
        return out.flatten(2).transpose(1, 2)