# File-viewer header captured with this listing: 102 lines, 3.9 KiB, Python.
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule, build_upsample_layer
|
|
|
|
|
|
class UpConvBlock(nn.Module):
|
|
"""Upsample convolution block in decoder for UNet.
|
|
|
|
This upsample convolution block consists of one upsample module
|
|
followed by one convolution block. The upsample module expands the
|
|
high-level low-resolution feature map and the convolution block fuses
|
|
the upsampled high-level low-resolution feature map and the low-level
|
|
high-resolution feature map from encoder.
|
|
|
|
Args:
|
|
conv_block (nn.Sequential): Sequential of convolutional layers.
|
|
in_channels (int): Number of input channels of the high-level
|
|
skip_channels (int): Number of input channels of the low-level
|
|
high-resolution feature map from encoder.
|
|
out_channels (int): Number of output channels.
|
|
num_convs (int): Number of convolutional layers in the conv_block.
|
|
Default: 2.
|
|
stride (int): Stride of convolutional layer in conv_block. Default: 1.
|
|
dilation (int): Dilation rate of convolutional layer in conv_block.
|
|
Default: 1.
|
|
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
|
memory while slowing down the training speed. Default: False.
|
|
conv_cfg (dict | None): Config dict for convolution layer.
|
|
Default: None.
|
|
norm_cfg (dict | None): Config dict for normalization layer.
|
|
Default: dict(type='BN').
|
|
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
|
Default: dict(type='ReLU').
|
|
upsample_cfg (dict): The upsample config of the upsample module in
|
|
decoder. Default: dict(type='InterpConv'). If the size of
|
|
high-level feature map is the same as that of skip feature map
|
|
(low-level feature map from encoder), it does not need upsample the
|
|
high-level feature map and the upsample_cfg is None.
|
|
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
|
Default: None.
|
|
plugins (dict): plugins for convolutional layers. Default: None.
|
|
"""
|
|
|
|
def __init__(self,
|
|
conv_block,
|
|
in_channels,
|
|
skip_channels,
|
|
out_channels,
|
|
num_convs=2,
|
|
stride=1,
|
|
dilation=1,
|
|
with_cp=False,
|
|
conv_cfg=None,
|
|
norm_cfg=dict(type='BN'),
|
|
act_cfg=dict(type='ReLU'),
|
|
upsample_cfg=dict(type='InterpConv'),
|
|
dcn=None,
|
|
plugins=None):
|
|
super(UpConvBlock, self).__init__()
|
|
assert dcn is None, 'Not implemented yet.'
|
|
assert plugins is None, 'Not implemented yet.'
|
|
|
|
self.conv_block = conv_block(
|
|
in_channels=2 * skip_channels,
|
|
out_channels=out_channels,
|
|
num_convs=num_convs,
|
|
stride=stride,
|
|
dilation=dilation,
|
|
with_cp=with_cp,
|
|
conv_cfg=conv_cfg,
|
|
norm_cfg=norm_cfg,
|
|
act_cfg=act_cfg,
|
|
dcn=None,
|
|
plugins=None)
|
|
if upsample_cfg is not None:
|
|
self.upsample = build_upsample_layer(
|
|
cfg=upsample_cfg,
|
|
in_channels=in_channels,
|
|
out_channels=skip_channels,
|
|
with_cp=with_cp,
|
|
norm_cfg=norm_cfg,
|
|
act_cfg=act_cfg)
|
|
else:
|
|
self.upsample = ConvModule(
|
|
in_channels,
|
|
skip_channels,
|
|
kernel_size=1,
|
|
stride=1,
|
|
padding=0,
|
|
conv_cfg=conv_cfg,
|
|
norm_cfg=norm_cfg,
|
|
act_cfg=act_cfg)
|
|
|
|
def forward(self, skip, x):
|
|
"""Forward function."""
|
|
|
|
x = self.upsample(x)
|
|
out = torch.cat([skip, x], dim=1)
|
|
out = self.conv_block(out)
|
|
|
|
return out
|