68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch.nn as nn
|
|
from mmcv.cnn import build_norm_layer
|
|
|
|
from ..builder import NECKS
|
|
|
|
|
|
@NECKS.register_module()
class Feature2Pyramid(nn.Module):
    """Feature2Pyramid.

    A neck structure connecting a ViT backbone and decoder heads.

    One resampling layer is built per distinct value in ``rescales`` and,
    in ``forward``, ``inputs[i]`` is passed through the layer that
    corresponds to ``rescales[i]``.

    Args:
        embed_dim (int): Embedding dimension. Used as both the input and
            output channel count of the transposed convolutions and as the
            feature count of the normalization layer.
        rescales (list[float]): Sampling multiple applied to each input to
            obtain pyramid features. Supported values are 4, 2, 1, 0.5
            and 0.25. Default: [4, 2, 1, 0.5].
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='SyncBN', requires_grad=True).

    Raises:
        KeyError: If ``rescales`` contains an unsupported value.
    """

    # (rescale factor, attribute name) pairs. The attribute names are the
    # sub-module names, so they must stay stable: state_dict keys are
    # derived from them.
    _LAYER_ATTRS = (
        (4, 'upsample_4x'),
        (2, 'upsample_2x'),
        (1, 'identity'),
        (0.5, 'downsample_2x'),
        (0.25, 'downsample_4x'),
    )

    def __init__(self,
                 embed_dim,
                 rescales=[4, 2, 1, 0.5],
                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
        super(Feature2Pyramid, self).__init__()
        # Copy so that mutating ``self.rescales`` can never alter the shared
        # default list (or a list owned by the caller).
        self.rescales = list(rescales)
        # Kept for backward compatibility: this attribute always exists and
        # is ``None`` when no 4x upsampling was requested.
        self.upsample_4x = None
        for k in self.rescales:
            if k == 4:
                # Two stride-2 transposed convolutions -> 4x upsampling.
                self.upsample_4x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                    build_norm_layer(norm_cfg, embed_dim)[1],
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2),
                )
            elif k == 2:
                self.upsample_2x = nn.Sequential(
                    nn.ConvTranspose2d(
                        embed_dim, embed_dim, kernel_size=2, stride=2))
            elif k == 1:
                self.identity = nn.Identity()
            elif k == 0.5:
                self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
            elif k == 0.25:
                self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
            else:
                raise KeyError(f'invalid {k} for feature2pyramid')

    def _layer_for(self, scale):
        """Return the layer that was built for the rescale factor ``scale``.

        Raises:
            KeyError: If ``scale`` is not a supported rescale factor.
        """
        for factor, attr in self._LAYER_ATTRS:
            if scale == factor:
                return getattr(self, attr)
        raise KeyError(f'invalid {scale} for feature2pyramid')

    def forward(self, inputs):
        """Resample every input feature map by its configured factor.

        Args:
            inputs (Sequence): One feature map per entry of ``rescales``.

        Returns:
            tuple: ``inputs[i]`` resampled by ``rescales[i]``, in order.
        """
        assert len(inputs) == len(self.rescales)
        # Dispatch on ``self.rescales`` rather than on a fixed operator
        # order, so that every configuration accepted by ``__init__`` works
        # (e.g. fewer than four scales, repeated scales, or another order).
        return tuple(
            self._layer_for(scale)(feature)
            for scale, feature in zip(self.rescales, inputs))
|