from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer
from torch import Tensor

from mmseg.registry import MODELS


class PatchTransformerEncoder(nn.Module):
    """The Patch Transformer Encoder.

    Args:
        in_channels (int): The number of input channels.
        patch_size (int): The patch size. Defaults to 10.
        embedding_dim (int): The feature dimension. Defaults to 128.
        num_heads (int): The number of encoder heads. Defaults to 4.
        conv_cfg (dict): Config dict for the convolution layer.
            Defaults to ``dict(type='Conv')``.
    """
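
    # Note: nn.TransformerEncoderLayer defaults to batch_first=False, so the
    # stacked encoder below expects sequences in (S, N, E) layout; forward()
    # permutes the patch embeddings into that layout before encoding.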
    def __init__(self,
                 in_channels,
                 patch_size=10,
                 embedding_dim=128,
                 num_heads=4,
                 conv_cfg=dict(type='Conv')):
        super().__init__()
        encoder_layers = nn.TransformerEncoderLayer(
            embedding_dim, num_heads, dim_feedforward=1024)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=4)  # takes shape S,N,E

        self.embedding_convPxP = build_conv_layer(
            conv_cfg,
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size)
        self.positional_encodings = nn.Parameter(
            torch.rand(500, embedding_dim), requires_grad=True)
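
    # Illustrative shape note: for an input of shape (N, C, H, W), the PxP
    # convolution (kernel_size == stride == patch_size) yields
    # (N, embedding_dim, H // patch_size, W // patch_size), so the number of
    # patch tokens is S = (H // patch_size) * (W // patch_size); S must not
    # exceed the 500 learned positional encodings above.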
    def forward(self, x):
        embeddings = self.embedding_convPxP(x).flatten(
            2)  # .shape = n,c,s = n, embedding_dim, s
        embeddings = embeddings + self.positional_encodings[:embeddings.shape[
            2], :].T.unsqueeze(0)

        # change to S,N,E format required by transformer
        embeddings = embeddings.permute(2, 0, 1)
        x = self.transformer_encoder(embeddings)  # .shape = S, N, E
        return x


class PixelWiseDotProduct(nn.Module):
    """The pixel-wise dot product between feature maps and query vectors."""

    def __init__(self):
        super().__init__()
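
    # Equivalent to torch.einsum('nchw,njc->njhw', x, K): every query vector
    # in K is dotted with the C-dimensional feature vector at each spatial
    # location, producing one response map per query.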
    def forward(self, x, K):
        n, c, h, w = x.size()
        _, cout, ck = K.size()
        assert c == ck, 'Number of channels in x and Embedding dimension ' \
            '(at dim 2) of K matrix must match'
        y = torch.matmul(
            x.view(n, c, h * w).permute(0, 2, 1),
            K.permute(0, 2, 1))  # .shape = n, hw, cout
        return y.permute(0, 2, 1).view(n, cout, h, w)


@MODELS.register_module()
class AdabinsHead(nn.Module):
    """The AdaBins head, including the mViT module.

    Args:
        in_channels (int): The number of input channels.
        n_query_channels (int): The number of query channels.
            Defaults to 128.
        patch_size (int): The patch size. Defaults to 16.
        embedding_dim (int): The feature dimension. Defaults to 128.
        num_heads (int): The number of attention heads. Defaults to 4.
        n_bins (int): The number of bins. Defaults to 100.
        min_val (float): The minimum value of the depth range.
            Defaults to 0.1.
        max_val (float): The maximum value of the depth range.
            Defaults to 10.
        conv_cfg (dict): Config dict for the convolution layer.
            Defaults to ``dict(type='Conv')``.
        norm (str): The normalization applied to the bin-width logits,
            one of 'linear', 'softmax' or 'sigmoid'. Defaults to 'linear'.
        align_corners (bool, optional): Geometrically, we consider the pixels
            of the input and output as squares rather than points.
            Defaults to False.
        threshold (float): Threshold value. Defaults to 0.
    """

    def __init__(self,
                 in_channels,
                 n_query_channels=128,
                 patch_size=16,
                 embedding_dim=128,
                 num_heads=4,
                 n_bins=100,
                 min_val=0.1,
                 max_val=10,
                 conv_cfg=dict(type='Conv'),
                 norm='linear',
                 align_corners=False,
                 threshold=0):
        super().__init__()
        self.out_channels = n_bins
        self.align_corners = align_corners
        self.norm = norm
        self.num_classes = n_bins
        self.min_val = min_val
        self.max_val = max_val
        self.n_query_channels = n_query_channels
        self.patch_transformer = PatchTransformerEncoder(
            in_channels, patch_size, embedding_dim, num_heads)
        self.dot_product_layer = PixelWiseDotProduct()
        self.threshold = threshold
        self.conv3x3 = build_conv_layer(
            conv_cfg,
            in_channels,
            embedding_dim,
            kernel_size=3,
            stride=1,
            padding=1)
        self.regressor = nn.Sequential(
            nn.Linear(embedding_dim, 256), nn.LeakyReLU(), nn.Linear(256, 256),
            nn.LeakyReLU(), nn.Linear(256, n_bins))
        self.conv_out = nn.Sequential(
            build_conv_layer(conv_cfg, in_channels, n_bins, kernel_size=1),
            nn.Softmax(dim=1))
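
    # Forward overview (as implemented below): the first transformer token is
    # regressed to n_bins bin-width logits, the next n_query_channels tokens
    # serve as queries for per-pixel range-attention maps, and the predicted
    # depth is the probability-weighted sum of the bin centers at every pixel
    # (the 'softmax' norm returns early with the raw probabilities and the
    # attention maps instead).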
    def forward(self, x):
        # n, c, h, w = x.size()
        tgt = self.patch_transformer(x.clone())  # .shape = S, N, E

        x = self.conv3x3(x)

        regression_head = tgt[0, ...]
        queries = tgt[1:self.n_query_channels + 1, ...]

        # Change from S, N, E to N, S, E
        queries = queries.permute(1, 0, 2)
        range_attention_maps = self.dot_product_layer(
            x, queries)  # .shape = n, n_query_channels, h, w

        y = self.regressor(regression_head)  # .shape = N, dim_out
        if self.norm == 'linear':
            y = torch.relu(y)
            eps = 0.1
            y = y + eps
        elif self.norm == 'softmax':
            return torch.softmax(y, dim=1), range_attention_maps
        else:
            y = torch.sigmoid(y)
        bin_widths_normed = y / y.sum(dim=1, keepdim=True)
        out = self.conv_out(range_attention_maps)

        bin_widths = (self.max_val -
                      self.min_val) * bin_widths_normed  # .shape = N, dim_out
        bin_widths = F.pad(
            bin_widths, (1, 0), mode='constant', value=self.min_val)
        bin_edges = torch.cumsum(bin_widths, dim=1)

        centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
        n, dim_out = centers.size()
        centers = centers.view(n, dim_out, 1, 1)

        pred = torch.sum(out * centers, dim=1, keepdim=True)
        return bin_edges, pred
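
    # Note: with the default norm='linear' (and with 'sigmoid'), forward()
    # returns (bin_edges, pred), so indexing with [-1] below selects the
    # per-pixel depth prediction.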
    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg, **kwargs) -> Tensor:
        """Forward function for testing, returning the clamped depth map."""
        pred = self.forward(inputs)[-1]
        final = torch.clamp(pred, self.min_val, self.max_val)

        final[torch.isinf(final)] = self.max_val
        final[torch.isnan(final)] = self.min_val
        return final
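

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch only; the feature shape and
    # channel count below are assumptions, not values from any released
    # AdaBins config).
    head = AdabinsHead(in_channels=128)
    # A 192x192 feature map with patch_size=16 yields (192 // 16) ** 2 = 144
    # patch tokens, enough to supply 1 regression token + 128 query tokens.
    feats = torch.rand(2, 128, 192, 192)
    bin_edges, pred = head(feats)
    print(bin_edges.shape, pred.shape)  # (2, n_bins + 1) and (2, 1, 192, 192)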