# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer
from mmcv.cnn.bricks.transformer import build_dropout
from mmengine.logging import print_log
from mmengine.model import BaseModule
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from mmseg.registry import MODELS
@MODELS.register_module()
class StrideFormer(BaseModule):
"""The StrideFormer implementation based on torch.
The original article refers to:https://arxiv.org/abs/2304.05152
Args:
mobileV3_cfg(list): Each sublist describe the config for a
MobileNetV3 block.
channels(list): The input channels for each MobileNetV3 block.
embed_dims(list): The channels of the features input to the sea
attention block.
key_dims(list, optional): The embeding dims for each head in
attention.
depths(list, optional): describes the depth of the attention block.
i,e: M,N.
num_heads(int, optional): The number of heads of the attention
blocks.
attn_ratios(int, optional): The expand ratio of V.
mlp_ratios(list, optional): The ratio of mlp blocks.
drop_path_rate(float, optional): The drop path rate in attention
block.
act_cfg(dict, optional): The activation layer of AAM:
Aggregate Attention Module.
inj_type(string, optional): The type of injection/AAM.
out_channels(int, optional): The output channels of the AAM.
dims(list, optional): The dimension of the fusion block.
out_feat_chs(list, optional): The input channels of the AAM.
stride_attention(bool, optional): whether to stride attention in
each attention layer.
pretrained(str, optional): the path of pretrained model.
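
    Example:
        Below is a minimal, illustrative configuration; it is not a setting
        from the original paper, and every concrete value is an assumption
        made for the sketch. Each ``mobileV3_cfg`` row is
        ``(kernel_size, expand_channels, out_channels, use_se, act,
        stride)``.

        >>> cfgs = [
        ...     [[3, 16, 16, True, 'ReLU', 2]],  # -> 1/4
        ...     [[3, 64, 32, False, 'ReLU', 2]],  # -> 1/8 shared feature
        ...     [[3, 96, 64, True, 'HSwish', 2]],  # -> 1/16, then trans1
        ...     [[5, 256, 128, True, 'HSwish', 2]],  # -> 1/32, then trans2
        ... ]
        >>> model = StrideFormer(
        ...     mobileV3_cfg=cfgs,
        ...     channels=[16, 16, 32, 64],
        ...     embed_dims=[64, 128],
        ...     out_feat_chs=[32, 64, 128])
        >>> feat, hw = model(torch.randn(1, 3, 512, 512))
        >>> # feat: (1, 256, 32, 32), i.e. 1/16 resolution with the default
        >>> # ``out_channels``; hw is the input spatial size.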
"""
def __init__(
self,
mobileV3_cfg,
channels,
embed_dims,
key_dims=[16, 24],
depths=[2, 2],
num_heads=8,
attn_ratios=2,
mlp_ratios=[2, 4],
drop_path_rate=0.1,
act_cfg=dict(type='ReLU'),
inj_type='AAM',
out_channels=256,
dims=(128, 160),
out_feat_chs=None,
stride_attention=True,
pretrained=None,
init_cfg=None,
):
super().__init__(init_cfg=init_cfg)
assert not (init_cfg and pretrained
), 'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is not None:
raise TypeError('pretrained must be a str or None')
self.depths = depths
self.cfgs = mobileV3_cfg
self.dims = dims
for i in range(len(self.cfgs)):
smb = StackedMV3Block(
cfgs=self.cfgs[i],
                stem=(i == 0),
in_channels=channels[i],
)
setattr(self, f'smb{i + 1}', smb)
for i in range(len(depths)):
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, depths[i])
]
trans = BasicLayer(
block_num=depths[i],
embedding_dim=embed_dims[i],
key_dim=key_dims[i],
num_heads=num_heads,
mlp_ratio=mlp_ratios[i],
attn_ratio=attn_ratios,
drop=0,
attn_drop=0.0,
drop_path=dpr,
act_cfg=act_cfg,
stride_attention=stride_attention,
)
setattr(self, f'trans{i + 1}', trans)
self.inj_type = inj_type
if self.inj_type == 'AAM':
self.inj_module = InjectionMultiSumallmultiallsum(
in_channels=out_feat_chs, out_channels=out_channels)
self.feat_channels = [
out_channels,
]
elif self.inj_type == 'AAMSx8':
self.inj_module = InjectionMultiSumallmultiallsumSimpx8(
in_channels=out_feat_chs, out_channels=out_channels)
self.feat_channels = [
out_channels,
]
elif self.inj_type == 'origin':
for i in range(len(dims)):
fuse = FusionBlock(
out_feat_chs[0] if i == 0 else dims[i - 1],
out_feat_chs[i + 1],
embed_dim=dims[i],
act_cfg=None,
)
setattr(self, f'fuse{i + 1}', fuse)
self.feat_channels = [
dims[i],
]
else:
            raise NotImplementedError(self.inj_type + ' is not implemented')
self.pretrained = pretrained
# self.init_weights()
def init_weights(self):
if (isinstance(self.init_cfg, dict)
and self.init_cfg.get('type') == 'Pretrained'):
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
if 'pos_embed' in state_dict.keys():
if self.pos_embed.shape != state_dict['pos_embed'].shape:
print_log(msg=f'Resize the pos_embed shape from '
f'{state_dict["pos_embed"].shape} to '
f'{self.pos_embed.shape}')
h, w = self.img_size
pos_size = int(
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'],
(h // self.patch_size, w // self.patch_size),
(pos_size, pos_size),
self.interpolate_mode,
)
load_state_dict(self, state_dict, strict=False, logger=None)
def forward(self, x):
x_hw = x.shape[2:]
outputs = []
num_smb_stage = len(self.cfgs)
num_trans_stage = len(self.depths)
for i in range(num_smb_stage):
smb = getattr(self, f'smb{i + 1}')
x = smb(x)
# 1/8 shared feat
if i == 1:
outputs.append(x)
if num_trans_stage + i >= num_smb_stage:
trans = getattr(
self, f'trans{i + num_trans_stage - num_smb_stage + 1}')
x = trans(x)
outputs.append(x)
if self.inj_type == 'origin':
x_detail = outputs[0]
for i in range(len(self.dims)):
fuse = getattr(self, f'fuse{i + 1}')
x_detail = fuse(x_detail, outputs[i + 1])
output = x_detail
else:
output = self.inj_module(outputs)
return [output, x_hw]
class StackedMV3Block(nn.Module):
"""The MobileNetV3 block.
Args:
cfgs (list): The MobileNetV3 config list of a stage.
stem (bool): Whether is the first stage or not.
in_channels (int, optional): The channels of input image. Default: 3.
scale: float=1.0.
The coefficient that controls the size of network parameters.
Returns:
model: nn.Module.
A stage of specific MobileNetV3 model depends on args.
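
    Example:
        A minimal sketch with made-up values, showing the meaning of one
        cfg row ``(kernel_size, expand_channels, out_channels, use_se,
        act, stride)``:

        >>> stage = StackedMV3Block(
        ...     cfgs=[[3, 64, 32, False, 'ReLU', 2]],
        ...     stem=False,
        ...     in_channels=16)
        >>> out = stage(torch.randn(1, 16, 64, 64))
        >>> # out: (1, 32, 32, 32); the block expands 16 -> 64 channels,
        >>> # applies a stride-2 depthwise conv and projects to 32 channels.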
"""
def __init__(self,
cfgs,
stem,
in_channels,
scale=1.0,
norm_cfg=dict(type='BN')):
super().__init__()
self.scale = scale
self.stem = stem
if self.stem:
self.conv = ConvModule(
in_channels=3,
out_channels=_make_divisible(in_channels * self.scale),
kernel_size=3,
stride=2,
padding=1,
groups=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='HSwish'),
)
self.blocks = nn.ModuleList()
for i, (k, exp, c, se, act, s) in enumerate(cfgs):
self.blocks.append(
ResidualUnit(
in_channel=_make_divisible(in_channels * self.scale),
mid_channel=_make_divisible(self.scale * exp),
out_channel=_make_divisible(self.scale * c),
kernel_size=k,
stride=s,
use_se=se,
act=act,
dilation=1,
))
in_channels = _make_divisible(self.scale * c)
def forward(self, x):
if self.stem:
x = self.conv(x)
for i, block in enumerate(self.blocks):
x = block(x)
return x
class ResidualUnit(nn.Module):
"""The Residual module.
Args:
in_channel (int, optional): The channels of input feature.
mid_channel (int, optional): The channels of middle process.
out_channel (int, optional): The channels of output feature.
kernel_size (int, optional): The size of the convolving kernel.
stride (int, optional): The stride size.
use_se (bool, optional): if to use the SEModule.
act (string, optional): activation layer.
dilation (int, optional): The dilation size.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', requires_grad=True).
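
    Example:
        A minimal sketch with made-up sizes; with ``stride=1`` and equal
        input/output channels the identity shortcut is applied:

        >>> unit = ResidualUnit(
        ...     in_channel=32, mid_channel=64, out_channel=32,
        ...     kernel_size=3, stride=1, use_se=True, act='ReLU')
        >>> out = unit(torch.randn(1, 32, 56, 56))
        >>> # out: (1, 32, 56, 56), the same shape as the input.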
"""
def __init__(
self,
in_channel,
mid_channel,
out_channel,
kernel_size,
stride,
use_se,
act=None,
dilation=1,
norm_cfg=dict(type='BN'),
):
super().__init__()
self.if_shortcut = stride == 1 and in_channel == out_channel
self.if_se = use_se
self.expand_conv = ConvModule(
in_channels=in_channel,
out_channels=mid_channel,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=dict(type=act) if act is not None else None,
)
self.bottleneck_conv = ConvModule(
in_channels=mid_channel,
out_channels=mid_channel,
kernel_size=kernel_size,
stride=stride,
padding=int((kernel_size - 1) // 2) * dilation,
bias=False,
groups=mid_channel,
dilation=dilation,
norm_cfg=norm_cfg,
act_cfg=dict(type=act) if act is not None else None,
)
if self.if_se:
self.mid_se = SEModule(mid_channel)
self.linear_conv = ConvModule(
in_channels=mid_channel,
out_channels=out_channel,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
def forward(self, x):
identity = x
x = self.expand_conv(x)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = torch.add(identity, x)
return x
class SEModule(nn.Module):
"""SE Module.
Args:
channel (int, optional): The channels of input feature.
reduction (int, optional): The channel reduction rate.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
"""
def __init__(self, channel, reduction=4, act_cfg=dict(type='ReLU')):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_act1 = ConvModule(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
norm_cfg=None,
act_cfg=act_cfg,
)
self.conv_act2 = ConvModule(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
norm_cfg=None,
act_cfg=dict(type='Hardsigmoid', slope=0.2, offset=0.5),
)
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv_act1(x)
x = self.conv_act2(x)
return torch.mul(identity, x)
class BasicLayer(nn.Module):
"""The transformer basic layer.
Args:
block_num (int): the block nums of the transformer basic layer.
embedding_dim (int): The feature dimension.
key_dim (int): the key dim.
num_heads (int): Parallel attention heads.
mlp_ratio (float): the mlp ratio.
attn_ratio (float): the attention ratio.
drop (float): Probability of an element to be zeroed
after the feed forward layer.Default: 0.0.
attn_drop (float): The drop out rate for attention layer.
Default: 0.0.
drop_path (float): stochastic depth rate. Default 0.0.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
stride_attention (bool, optional): whether to stride attention in
each attention layer.
"""
def __init__(
self,
block_num,
embedding_dim,
key_dim,
num_heads,
mlp_ratio=4.0,
attn_ratio=2.0,
drop=0.0,
attn_drop=0.0,
drop_path=None,
act_cfg=None,
stride_attention=None,
):
super().__init__()
self.block_num = block_num
self.transformer_blocks = nn.ModuleList()
for i in range(self.block_num):
self.transformer_blocks.append(
Block(
embedding_dim,
key_dim=key_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
attn_ratio=attn_ratio,
drop=drop,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path,
act_cfg=act_cfg,
stride_attention=stride_attention,
))
def forward(self, x):
for i in range(self.block_num):
x = self.transformer_blocks[i](x)
return x
class Block(nn.Module):
"""the block of the transformer basic layer.
Args:
dim (int): The feature dimension.
key_dim (int): The key dimension.
num_heads (int): Parallel attention heads.
mlp_ratio (float): the mlp ratio.
attn_ratio (float): the attention ratio.
drop (float): Probability of an element to be zeroed
after the feed forward layer.Default: 0.0.
drop_path (float): stochastic depth rate. Default 0.0.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
stride_attention (bool, optional): whether to stride attention in
each attention layer.
"""
def __init__(
self,
dim,
key_dim,
num_heads,
mlp_ratio=4.0,
attn_ratio=2.0,
drop=0.0,
drop_path=0.0,
act_cfg=None,
stride_attention=None,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.attn = SeaAttention(
dim,
key_dim=key_dim,
num_heads=num_heads,
attn_ratio=attn_ratio,
act_cfg=act_cfg,
stride_attention=stride_attention,
)
self.drop_path = (
build_dropout(dict(type='DropPath', drop_prob=drop_path))
if drop_path > 0.0 else nn.Identity())
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = MLP(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop,
)
def forward(self, x1):
x1 = x1 + self.drop_path(self.attn(x1))
x1 = x1 + self.drop_path(self.mlp(x1))
return x1
class SqueezeAxialPositionalEmbedding(nn.Module):
"""the Squeeze Axial Positional Embedding.
Args:
dim (int): The feature dimension.
shape (int): The patch size.
"""
def __init__(self, dim, shape):
super().__init__()
self.pos_embed = nn.init.normal_(
nn.Parameter(torch.zeros(1, dim, shape)))
def forward(self, x):
B, C, N = x.shape
x = x + F.interpolate(
self.pos_embed, size=(N, ), mode='linear', align_corners=False)
return x
class SeaAttention(nn.Module):
"""The sea attention.
Args:
dim (int): The feature dimension.
key_dim (int): The key dimension.
num_heads (int): number of attention heads.
attn_ratio (float): the attention ratio.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
stride_attention (bool, optional): whether to stride attention in
each attention layer.
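
    Example:
        A minimal sketch with made-up sizes; the module preserves the
        spatial shape and channel count of its input (``attn_ratio=2``
        keeps the concatenated q/k/v compatible with the
        detail-enhancement convs):

        >>> attn = SeaAttention(
        ...     dim=64, key_dim=16, num_heads=8, attn_ratio=2,
        ...     act_cfg=dict(type='ReLU'))
        >>> out = attn(torch.randn(1, 64, 32, 32))
        >>> # out: (1, 64, 32, 32); row and column attention over the
        >>> # squeezed axes are fused with the depthwise/pointwise branch.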
"""
def __init__(
self,
dim,
key_dim,
num_heads,
attn_ratio=4.0,
act_cfg=None,
norm_cfg=dict(type='BN'),
stride_attention=False,
):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.to_q = ConvModule(
dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.to_k = ConvModule(
dim, nh_kd, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.to_v = ConvModule(
dim, self.dh, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.stride_attention = stride_attention
if self.stride_attention:
self.stride_conv = ConvModule(
dim,
dim,
kernel_size=3,
stride=2,
padding=1,
bias=True,
groups=dim,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.proj = ConvModule(
self.dh,
dim,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.proj_encode_row = ConvModule(
self.dh,
self.dh,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.proj_encode_column = ConvModule(
self.dh,
self.dh,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('act', 'conv', 'norm'),
)
self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
self.dwconv = ConvModule(
2 * self.dh,
2 * self.dh,
3,
padding=1,
groups=2 * self.dh,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
self.pwconv = ConvModule(
2 * self.dh, dim, 1, bias=False, norm_cfg=norm_cfg, act_cfg=None)
self.sigmoid = build_activation_layer(dict(type='HSigmoid'))
def forward(self, x):
B, C, H_ori, W_ori = x.shape
if self.stride_attention:
x = self.stride_conv(x)
B, C, H, W = x.shape
q = self.to_q(x) # [B, nhead*dim, H, W]
k = self.to_k(x)
v = self.to_v(x)
qkv = torch.cat([q, k, v], dim=1)
qkv = self.dwconv(qkv)
qkv = self.pwconv(qkv)
qrow = (self.pos_emb_rowq(q.mean(-1)).reshape(
[B, self.num_heads, -1, H]).permute(
(0, 1, 3, 2))) # [B, nhead, H, dim]
krow = self.pos_emb_rowk(k.mean(-1)).reshape(
[B, self.num_heads, -1, H]) # [B, nhead, dim, H]
vrow = (v.mean(-1).reshape([B, self.num_heads, -1,
H]).permute([0, 1, 3, 2])
) # [B, nhead, H, dim*attn_ratio]
attn_row = torch.matmul(qrow, krow) * self.scale # [B, nhead, H, H]
attn_row = nn.functional.softmax(attn_row, dim=-1)
xx_row = torch.matmul(attn_row, vrow) # [B, nhead, H, dim*attn_ratio]
xx_row = self.proj_encode_row(
xx_row.permute([0, 1, 3, 2]).reshape([B, self.dh, H, 1]))
# squeeze column
qcolumn = (
self.pos_emb_columnq(q.mean(-2)).reshape(
[B, self.num_heads, -1, W]).permute([0, 1, 3, 2]))
kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(
[B, self.num_heads, -1, W])
vcolumn = (
torch.mean(v, -2).reshape([B, self.num_heads, -1,
W]).permute([0, 1, 3, 2]))
attn_column = torch.matmul(qcolumn, kcolumn) * self.scale
attn_column = nn.functional.softmax(attn_column, dim=-1)
xx_column = torch.matmul(attn_column, vcolumn) # B nH W C
xx_column = self.proj_encode_column(
xx_column.permute([0, 1, 3, 2]).reshape([B, self.dh, 1, W]))
xx = torch.add(xx_row, xx_column) # [B, self.dh, H, W]
xx = torch.add(v, xx)
xx = self.proj(xx)
xx = self.sigmoid(xx) * qkv
if self.stride_attention:
xx = F.interpolate(xx, size=(H_ori, W_ori), mode='bilinear')
return xx
class MLP(nn.Module):
"""the Multilayer Perceptron.
Args:
in_features (int): the input feature.
hidden_features (int): the hidden feature.
out_features (int): the output feature.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='PReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
drop (float): Probability of an element to be zeroed.
Default 0.0
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=None,
norm_cfg=dict(type='BN'),
drop=0.0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ConvModule(
in_features,
hidden_features,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.dwconv = ConvModule(
hidden_features,
hidden_features,
kernel_size=3,
padding=1,
groups=hidden_features,
norm_cfg=None,
act_cfg=act_cfg,
)
self.fc2 = ConvModule(
hidden_features,
out_features,
1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.drop = build_dropout(dict(type='Dropout', drop_prob=drop))
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class FusionBlock(nn.Module):
"""The feature fusion block.
Args:
in_channel (int): the input channel.
out_channel (int): the output channel.
embed_dim (int): embedding dimension.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
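
    Example:
        A minimal sketch with made-up sizes; the local feature sets the
        output resolution and ``embed_dim`` sets the output channels:

        >>> fuse = FusionBlock(in_channel=32, out_channel=64, embed_dim=128)
        >>> x_l = torch.randn(1, 32, 64, 64)  # local / detail feature
        >>> x_g = torch.randn(1, 64, 32, 32)  # global / semantic feature
        >>> out = fuse(x_l, x_g)
        >>> # out: (1, 128, 64, 64); x_g is embedded, activated, upsampled
        >>> # to x_l's resolution and multiplied with the embedded x_l.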
"""
def __init__(
self,
in_channel,
out_channel,
embed_dim,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
) -> None:
super().__init__()
self.local_embedding = ConvModule(
in_channels=in_channel,
out_channels=embed_dim,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
)
self.global_act = ConvModule(
in_channels=out_channel,
out_channels=embed_dim,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg if act_cfg is not None else None,
)
def forward(self, x_l, x_g):
"""
x_g: global features
x_l: local features
"""
B, C, H, W = x_l.shape
local_feat = self.local_embedding(x_l)
global_act = self.global_act(x_g)
sig_act = F.interpolate(
global_act, size=(H, W), mode='bilinear', align_corners=False)
out = local_feat * sig_act
return out
class InjectionMultiSumallmultiallsum(nn.Module):
"""the Aggregate Attention Module.
Args:
in_channels (tuple): the input channel.
out_channels (int): the output channel.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
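
    Example:
        A minimal sketch with made-up sizes. ``forward`` expects three
        features (1/8, 1/16 and 1/32 of the input resolution) and fuses
        them at 1/16 resolution:

        >>> aam = InjectionMultiSumallmultiallsum(
        ...     in_channels=(32, 64, 128), out_channels=256)
        >>> feats = [
        ...     torch.randn(1, 32, 64, 64),   # 1/8
        ...     torch.randn(1, 64, 32, 32),   # 1/16
        ...     torch.randn(1, 128, 16, 16),  # 1/32
        ... ]
        >>> out = aam(feats)
        >>> # out: (1, 256, 32, 32)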
"""
def __init__(
self,
in_channels=(64, 128, 256, 384),
out_channels=256,
act_cfg=dict(type='Sigmoid'),
norm_cfg=dict(type='BN'),
):
super().__init__()
self.embedding_list = nn.ModuleList()
self.act_embedding_list = nn.ModuleList()
self.act_list = nn.ModuleList()
for i in range(len(in_channels)):
self.embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
))
self.act_embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
))
    def forward(self, inputs):  # inputs: x_x8, x_x16, x_x32
low_feat1 = F.interpolate(inputs[0], scale_factor=0.5, mode='bilinear')
low_feat1_act = self.act_embedding_list[0](low_feat1)
low_feat1 = self.embedding_list[0](low_feat1)
low_feat2 = F.interpolate(
inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
low_feat2_act = self.act_embedding_list[1](low_feat2) # x16
low_feat2 = self.embedding_list[1](low_feat2)
high_feat_act = F.interpolate(
self.act_embedding_list[2](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear',
)
high_feat = F.interpolate(
self.embedding_list[2](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear')
res = (
low_feat1_act * low_feat2_act * high_feat_act *
(low_feat1 + low_feat2) + high_feat)
return res
class InjectionMultiSumallmultiallsumSimpx8(nn.Module):
"""the Aggregate Attention Module.
Args:
in_channels (tuple): the input channel.
out_channels (int): the output channel.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN')
"""
def __init__(
self,
in_channels=(64, 128, 256, 384),
out_channels=256,
act_cfg=dict(type='Sigmoid'),
norm_cfg=dict(type='BN'),
):
super().__init__()
self.embedding_list = nn.ModuleList()
self.act_embedding_list = nn.ModuleList()
self.act_list = nn.ModuleList()
for i in range(len(in_channels)):
if i != 1:
self.embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None,
))
if i != 0:
self.act_embedding_list.append(
ConvModule(
in_channels=in_channels[i],
out_channels=out_channels,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
))
def forward(self, inputs):
# x_x8, x_x16, x_x32
low_feat1 = self.embedding_list[0](inputs[0])
low_feat2 = F.interpolate(
inputs[1], size=low_feat1.shape[-2:], mode='bilinear')
low_feat2_act = self.act_embedding_list[0](low_feat2)
high_feat_act = F.interpolate(
self.act_embedding_list[1](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear',
)
high_feat = F.interpolate(
self.embedding_list[1](inputs[2]),
size=low_feat2.shape[2:],
mode='bilinear')
res = low_feat2_act * high_feat_act * low_feat1 + high_feat
return res
def _make_divisible(v, divisor=8, min_value=None):
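    """Round ``v`` to the nearest multiple of ``divisor`` (at least
    ``min_value``), adding one extra ``divisor`` if rounding would drop
    the value below 90% of ``v``."""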
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
@MODELS.register_module()
class Hardsigmoid(nn.Module):
"""the hardsigmoid activation.
Args:
slope (float, optional): The slope of hardsigmoid function.
Default is 0.1666667.
offset (float, optional): The offset of hardsigmoid function.
Default is 0.5.
inplace (bool): can optionally do the operation in-place.
Default: ``False``
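
    Example:
        A small illustrative check of the saturation behaviour:

        >>> m = Hardsigmoid()
        >>> m(torch.tensor([-6.0, 0.0, 6.0]))
        >>> # -> approximately [0.0, 0.5, 1.0]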
"""
def __init__(self, slope=0.1666667, offset=0.5, inplace=False):
super().__init__()
self.slope = slope
self.offset = offset
def forward(self, x):
return (x * self.slope + self.offset).clamp(0, 1)