mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
142 lines
4.4 KiB
Python
142 lines
4.4 KiB
Python
import timm
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from mmcv.cnn import ConvModule, build_conv_layer
|
|
from mmengine.model import BaseModule
|
|
|
|
from mmseg.registry import MODELS
|
|
|
|
|
|
class UpSampleBN(nn.Module):
    """Upsample a feature map and fuse it with a skip connection.

    ``forward`` bilinearly resizes ``x`` (``align_corners=True``) to the
    spatial size of ``concat_with``, concatenates the two along the channel
    dimension and feeds the result through two 3x3 ``ConvModule`` layers
    (conv -> norm -> activation, stride 1, padding 1).

    Args:
        skip_input (int): Number of input channels of the first conv, i.e.
            the channel count of ``x`` plus the channel count of
            ``concat_with``.
        output_features (int): Number of output channels of both convs.
        norm_cfg (dict, optional): Config dict for the normalization layer
            of each ``ConvModule``. Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for the activation layer of
            each ``ConvModule``. Default: dict(type='LeakyReLU').
    """

    def __init__(self,
                 skip_input,
                 output_features,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='LeakyReLU')):
        super().__init__()
        # NOTE(review): a conv bias followed by BN is redundant; presumably
        # kept to stay weight-compatible with the reference model — confirm
        # before switching to bias=False.
        self._net = nn.Sequential(
            ConvModule(
                in_channels=skip_input,
                out_channels=output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ),
            ConvModule(
                in_channels=output_features,
                out_channels=output_features,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            ))

    def forward(self, x, concat_with):
        """Upsample ``x`` to ``concat_with``'s size, concat, and convolve.

        Args:
            x (Tensor): Feature map to upsample; must be 4-D (N, C, H, W)
                since bilinear interpolation only supports 4-D input.
            concat_with (Tensor): Skip feature whose spatial size (dims 2
                and 3) is the upsampling target.

        Returns:
            Tensor: Output of the two convs, with ``output_features``
            channels and the spatial size of ``concat_with``.
        """
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode='bilinear',
            align_corners=True)
        # Channel-wise fusion: upsampled features first, then the skip.
        return self._net(torch.cat([up_x, concat_with], dim=1))
|
|
|
|
|
|
class Encoder(nn.Module):
    """Multi-scale feature extractor wrapping a ``timm`` model.

    The wrapped model's ``global_pool`` and ``classifier`` children are
    replaced with ``nn.Identity`` so they no longer alter the features.
    ``forward`` then runs the model child by child and records every
    intermediate output.

    Args:
        basemodel_name (str): Model name passed to ``timm.create_model``.
        pretrained (bool): Whether ``timm`` should load pretrained weights.
            Default: True (the previously hard-coded behaviour).
    """

    def __init__(self, basemodel_name, pretrained=True):
        super().__init__()
        self.original_model = timm.create_model(
            basemodel_name, pretrained=pretrained)
        # Drop the classification head: these children still run in
        # ``forward`` but now pass their input through unchanged.
        self.original_model.global_pool = nn.Identity()
        self.original_model.classifier = nn.Identity()

    def forward(self, x):
        """Run the wrapped model sequentially, collecting each output.

        The model's direct children are applied in registration order,
        each to the previous output. A child named ``'blocks'`` is expanded:
        its own children are applied and recorded one by one instead.

        Args:
            x (Tensor): Input passed to the first child.

        Returns:
            list[Tensor]: ``[x, out_1, out_2, ...]`` — the input followed by
            the output of every executed (sub)module, in execution order.
        """
        features = [x]
        for name, module in self.original_model._modules.items():
            if name == 'blocks':
                # Expand the stage container so each stage output is kept.
                for block in module._modules.values():
                    features.append(block(features[-1]))
            else:
                features.append(module(features[-1]))
        return features
|
|
|
|
|
|
@MODELS.register_module()
class AdabinsBackbone(BaseModule):
    """AdaBins backbone: a ``timm`` encoder plus a 4-stage upsampling decoder.

    Args:
        basemodel_name (str): Name of the ``timm`` model wrapped by
            ``Encoder``.
        num_features (int): Output channels of ``conv2`` (the decoder
            input). The four upsampling stages output //2, //4, //8 and //16
            of this width. Default: 2048.
        num_classes (int): Output channels of the final 3x3 conv
            (``conv3``). Default: 128.
        bottleneck_features (int): Input channels of ``conv2``, i.e. the
            channel count of the deepest encoder feature used. Default: 2048.
        conv_cfg (dict): Config dict for the ``conv2`` and ``conv3`` layers.
            Default: dict(type='Conv').

    NOTE(review): the skip-feature channel counts in ``__init__`` and the
    encoder feature indices in ``forward`` are hard-coded, presumably for an
    EfficientNet-B5 encoder — confirm before using another
    ``basemodel_name``.
    """

    def __init__(self,
                 basemodel_name,
                 num_features=2048,
                 num_classes=128,
                 bottleneck_features=2048,
                 conv_cfg=dict(type='Conv')):
        super().__init__()
        self.encoder = Encoder(basemodel_name)
        width = int(num_features)

        # NOTE(review): kernel_size=1 with padding=1 enlarges H and W by 2.
        # up1 resizes to the skip feature anyway; presumably kept to match
        # the reference AdaBins implementation — do not "fix" silently.
        self.conv2 = build_conv_layer(
            conv_cfg,
            bottleneck_features,
            width,
            kernel_size=1,
            stride=1,
            padding=1)

        # Each stage takes the previous decoder output concatenated with an
        # encoder skip feature; the skip channel counts are written as sums.
        self.up1 = UpSampleBN(
            skip_input=width + 112 + 64, output_features=width // 2)
        self.up2 = UpSampleBN(
            skip_input=width // 2 + 40 + 24, output_features=width // 4)
        self.up3 = UpSampleBN(
            skip_input=width // 4 + 24 + 16, output_features=width // 8)
        self.up4 = UpSampleBN(
            skip_input=width // 8 + 16 + 8, output_features=width // 16)

        self.conv3 = build_conv_layer(
            conv_cfg,
            width // 16,
            num_classes,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Encode ``x``, then decode it through the four upsampling stages.

        Args:
            x (Tensor): Input image batch passed to the encoder.

        Returns:
            Tensor: Output of ``conv3`` with ``num_classes`` channels.
        """
        feats = self.encoder(x)
        # Encoder outputs used: four skips (shallow -> deep) and the
        # deepest feature that enters the decoder.
        skip0, skip1, skip2, skip3, deepest = (
            feats[idx] for idx in (3, 4, 5, 7, 10))

        decoded = self.conv2(deepest)
        stages = (
            (self.up1, skip3),
            (self.up2, skip2),
            (self.up3, skip1),
            (self.up4, skip0),
        )
        for stage, skip in stages:
            decoded = stage(decoded, skip)
        return self.conv3(decoded)
|