add unet (#161)
* add unet * add unet * add unet * update test_unet * update test_unet * update test_unet * update test_unet * fix bugs * add init method for unet * add test of UNet init_weights method * add registry * merge upsample * fix test * Update mmseg/models/backbones/unet.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Update mmseg/models/backbones/unet.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * split UpConvBlock from UNet * use reversed * rename upsample module * rename upsample module * rename upsample module * rename upsample module Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>pull/1801/head
parent
eaefe54e8d
commit
5956451014
|
@ -4,8 +4,9 @@ from .mobilenet_v2 import MobileNetV2
|
|||
from .resnest import ResNeSt
|
||||
from .resnet import ResNet, ResNetV1c, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .unet import UNet
|
||||
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2'
|
||||
'ResNeSt', 'MobileNetV2', 'UNet'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,428 @@
|
|||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
||||
build_norm_layer, constant_init, kaiming_init)
|
||||
from mmcv.runner import load_checkpoint
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import UpConvBlock
|
||||
|
||||
|
||||
class BasicConvBlock(nn.Module):
|
||||
"""Basic convolutional block for UNet.
|
||||
|
||||
This module consists of several plain convolutional layers.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
num_convs (int): Number of convolutional layers. Default: 2.
|
||||
stride (int): Whether use stride convolution to downsample
|
||||
the input feature map. If stride=2, it only uses stride convolution
|
||||
in the first convolutional layer to downsample the input feature
|
||||
map. Options are 1 or 2. Default: 1.
|
||||
dilation (int): Whether use dilated convolution to expand the
|
||||
receptive field. Set dilation rate of each convolutional layer and
|
||||
the dilation rate of the first convolutional layer is always 1.
|
||||
Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_convs=2,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super(BasicConvBlock, self).__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.with_cp = with_cp
|
||||
convs = []
|
||||
for i in range(num_convs):
|
||||
convs.append(
|
||||
ConvModule(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride if i == 0 else 1,
|
||||
dilation=1 if i == 0 else dilation,
|
||||
padding=1 if i == 0 else dilation,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
self.convs = nn.Sequential(*convs)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.convs, x)
|
||||
else:
|
||||
out = self.convs(x)
|
||||
return out
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
|
||||
class DeconvModule(nn.Module):
|
||||
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
||||
|
||||
This module uses deconvolution to upsample feature map in the decoder
|
||||
of UNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
with_cp=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
*,
|
||||
kernel_size=4,
|
||||
scale_factor=2):
|
||||
super(DeconvModule, self).__init__()
|
||||
|
||||
assert (kernel_size - scale_factor >= 0) and\
|
||||
(kernel_size - scale_factor) % 2 == 0,\
|
||||
f'kernel_size should be greater than or equal to scale_factor '\
|
||||
f'and (kernel_size - scale_factor) should be even numbers, '\
|
||||
f'while the kernel size is {kernel_size} and scale_factor is '\
|
||||
f'{scale_factor}.'
|
||||
|
||||
stride = scale_factor
|
||||
padding = (kernel_size - scale_factor) // 2
|
||||
self.with_cp = with_cp
|
||||
deconv = nn.ConvTranspose2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
|
||||
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
|
||||
activate = build_activation_layer(act_cfg)
|
||||
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.deconv_upsamping, x)
|
||||
else:
|
||||
out = self.deconv_upsamping(x)
|
||||
return out
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
|
||||
class InterpConv(nn.Module):
|
||||
"""Interpolation upsample module in decoder for UNet.
|
||||
|
||||
This module uses interpolation to upsample feature map in the decoder
|
||||
of UNet. It consists of one interpolation upsample layer and one
|
||||
convolutional layer. It can be one interpolation upsample layer followed
|
||||
by one convolutional layer (conv_first=False) or one convolutional layer
|
||||
followed by one interpolation upsample layer (conv_first=True).
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
conv_first (bool): Whether convolutional layer or interpolation
|
||||
upsample layer first. Default: False. It means interpolation
|
||||
upsample layer followed by one convolutional layer.
|
||||
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
|
||||
stride (int): Stride of the convolutional layer. Default: 1.
|
||||
padding (int): Padding of the convolutional layer. Default: 1.
|
||||
upsampe_cfg (dict): Interpolation config of the upsample layer.
|
||||
Default: dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
with_cp=False,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
*,
|
||||
conv_cfg=None,
|
||||
conv_first=False,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
upsampe_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)):
|
||||
super(InterpConv, self).__init__()
|
||||
|
||||
self.with_cp = with_cp
|
||||
conv = ConvModule(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
upsample = nn.Upsample(**upsampe_cfg)
|
||||
if conv_first:
|
||||
self.interp_upsample = nn.Sequential(conv, upsample)
|
||||
else:
|
||||
self.interp_upsample = nn.Sequential(upsample, conv)
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(self.interp_upsample, x)
|
||||
else:
|
||||
out = self.interp_upsample(x)
|
||||
return out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class UNet(nn.Module):
|
||||
"""UNet backbone.
|
||||
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
||||
https://arxiv.org/pdf/1505.04597.pdf
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input image channels. Default" 3.
|
||||
base_channels (int): Number of base channels of each stage.
|
||||
The output channels of the first stage. Default: 64.
|
||||
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
||||
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
||||
len(strides) is equal to num_stages. Normally the stride of the
|
||||
first stage in encoder is 1. If strides[i]=2, it uses stride
|
||||
convolution to downsample in the correspondance encoder stage.
|
||||
Default: (1, 1, 1, 1, 1).
|
||||
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
||||
convolution block of the correspondance encoder stage.
|
||||
Default: (2, 2, 2, 2, 2).
|
||||
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
||||
convolution block of the correspondance decoder stage.
|
||||
Default: (2, 2, 2, 2).
|
||||
downsamples (Sequence[int]): Whether use MaxPool to downsample the
|
||||
feature map after the first stage of encoder
|
||||
(stages: [1, num_stages)). If the correspondance encoder stage use
|
||||
stride convolution (strides[i]=2), it will never use MaxPool to
|
||||
downsample, even downsamples[i-1]=True.
|
||||
Default: (True, True, True, True).
|
||||
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
||||
Default: (1, 1, 1, 1, 1).
|
||||
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
||||
Default: (1, 1, 1, 1).
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
upsample_cfg (dict): The upsample config of the upsample module in
|
||||
decoder. Default: dict(type='InterpConv').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
|
||||
Notice:
|
||||
The input image size should be devisible by the whole downsample rate
|
||||
of the encoder. More detail of the whole downsample rate can be found
|
||||
in UNet._check_input_devisible.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
upsample_cfg=dict(type='InterpConv'),
|
||||
norm_eval=False,
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super(UNet, self).__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
assert len(strides) == num_stages, \
|
||||
'The length of strides should be equal to num_stages, '\
|
||||
f'while the strides is {strides}, the length of '\
|
||||
f'strides is {len(strides)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(enc_num_convs) == num_stages, \
|
||||
'The length of enc_num_convs should be equal to num_stages, '\
|
||||
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
||||
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(dec_num_convs) == (num_stages-1), \
|
||||
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
||||
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
||||
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(downsamples) == (num_stages-1), \
|
||||
'The length of downsamples should be equal to (num_stages-1), '\
|
||||
f'while the downsamples is {downsamples}, the length of '\
|
||||
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(enc_dilations) == num_stages, \
|
||||
'The length of enc_dilations should be equal to num_stages, '\
|
||||
f'while the enc_dilations is {enc_dilations}, the length of '\
|
||||
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
assert len(dec_dilations) == (num_stages-1), \
|
||||
'The length of dec_dilations should be equal to (num_stages-1), '\
|
||||
f'while the dec_dilations is {dec_dilations}, the length of '\
|
||||
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
||||
f'{num_stages}.'
|
||||
self.num_stages = num_stages
|
||||
self.strides = strides
|
||||
self.downsamples = downsamples
|
||||
self.norm_eval = norm_eval
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
for i in range(num_stages):
|
||||
enc_conv_block = []
|
||||
if i != 0:
|
||||
if strides[i] == 1 and downsamples[i - 1]:
|
||||
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
||||
upsample = (strides[i] != 1 or downsamples[i - 1])
|
||||
self.decoder.append(
|
||||
UpConvBlock(
|
||||
conv_block=BasicConvBlock,
|
||||
in_channels=base_channels * 2**i,
|
||||
skip_channels=base_channels * 2**(i - 1),
|
||||
out_channels=base_channels * 2**(i - 1),
|
||||
num_convs=dec_num_convs[i - 1],
|
||||
stride=1,
|
||||
dilation=dec_dilations[i - 1],
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
upsample_cfg=upsample_cfg if upsample else None,
|
||||
dcn=None,
|
||||
plugins=None))
|
||||
|
||||
enc_conv_block.append(
|
||||
BasicConvBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=base_channels * 2**i,
|
||||
num_convs=enc_num_convs[i],
|
||||
stride=strides[i],
|
||||
dilation=enc_dilations[i],
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
dcn=None,
|
||||
plugins=None))
|
||||
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
||||
in_channels = base_channels * 2**i
|
||||
|
||||
def forward(self, x):
|
||||
self._check_input_devisible(x)
|
||||
enc_outs = []
|
||||
for enc in self.encoder:
|
||||
x = enc(x)
|
||||
enc_outs.append(x)
|
||||
dec_outs = [x]
|
||||
for i in reversed(range(len(self.decoder))):
|
||||
x = self.decoder[i](enc_outs[i], x)
|
||||
dec_outs.append(x)
|
||||
|
||||
return dec_outs
|
||||
|
||||
def train(self, mode=True):
|
||||
"""Convert the model into training mode while keep normalization layer
|
||||
freezed."""
|
||||
super(UNet, self).train(mode)
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
# trick: eval have effect on BatchNorm only
|
||||
if isinstance(m, _BatchNorm):
|
||||
m.eval()
|
||||
|
||||
def _check_input_devisible(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
whole_downsample_rate = 1
|
||||
for i in range(1, self.num_stages):
|
||||
if self.strides[i] == 2 or self.downsamples[i - 1]:
|
||||
whole_downsample_rate *= 2
|
||||
assert (h % whole_downsample_rate == 0) \
|
||||
and (w % whole_downsample_rate == 0),\
|
||||
f'The input image size {(h, w)} should be devisible by the whole '\
|
||||
f'downsample rate {whole_downsample_rate}, when num_stages is '\
|
||||
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
|
||||
f'is {self.downsamples}.'
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
"""Initialize the weights in backbone.
|
||||
|
||||
Args:
|
||||
pretrained (str, optional): Path to pre-trained weights.
|
||||
Defaults to None.
|
||||
"""
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
||||
elif pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None')
|
|
@ -2,7 +2,9 @@ from .inverted_residual import InvertedResidual
|
|||
from .make_divisible import make_divisible
|
||||
from .res_layer import ResLayer
|
||||
from .self_attention_block import SelfAttentionBlock
|
||||
from .up_conv_block import UpConvBlock
|
||||
|
||||
__all__ = [
|
||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual'
|
||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||
'UpConvBlock'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_upsample_layer
|
||||
|
||||
|
||||
class UpConvBlock(nn.Module):
|
||||
"""Upsample convolution block in decoder for UNet.
|
||||
|
||||
This upsample convolution block consists of one upsample module
|
||||
followed by one convolution block. The upsample module expands the
|
||||
high-level low-resolution feature map and the convolution block fuses
|
||||
the upsampled high-level low-resolution feature map and the low-level
|
||||
high-resolution feature map from encoder.
|
||||
|
||||
Args:
|
||||
conv_block (nn.Sequential): Sequential of convolutional layers.
|
||||
in_channels (int): Number of input channels of the high-level
|
||||
skip_channels (int): Number of input channels of the low-level
|
||||
high-resolution feature map from encoder.
|
||||
out_channels (int): Number of output channels.
|
||||
num_convs (int): Number of convolutional layers in the conv_block.
|
||||
Default: 2.
|
||||
stride (int): Stride of convolutional layer in conv_block. Default: 1.
|
||||
dilation (int): Dilation rate of convolutional layer in conv_block.
|
||||
Default: 1.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
conv_cfg (dict | None): Config dict for convolution layer.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
||||
Default: dict(type='ReLU').
|
||||
upsample_cfg (dict): The upsample config of the upsample module in
|
||||
decoder. Default: dict(type='InterpConv'). If the size of
|
||||
high-level feature map is the same as that of skip feature map
|
||||
(low-level feature map from encoder), it does not need upsample the
|
||||
high-level feature map and the upsample_cfg is None.
|
||||
dcn (bool): Use deformable convoluton in convolutional layer or not.
|
||||
Default: None.
|
||||
plugins (dict): plugins for convolutional layers. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
conv_block,
|
||||
in_channels,
|
||||
skip_channels,
|
||||
out_channels,
|
||||
num_convs=2,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
upsample_cfg=dict(type='InterpConv'),
|
||||
dcn=None,
|
||||
plugins=None):
|
||||
super(UpConvBlock, self).__init__()
|
||||
assert dcn is None, 'Not implemented yet.'
|
||||
assert plugins is None, 'Not implemented yet.'
|
||||
|
||||
self.conv_block = conv_block(
|
||||
in_channels=2 * skip_channels,
|
||||
out_channels=out_channels,
|
||||
num_convs=num_convs,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
with_cp=with_cp,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
dcn=None,
|
||||
plugins=None)
|
||||
if upsample_cfg is not None:
|
||||
self.upsample = build_upsample_layer(
|
||||
cfg=upsample_cfg,
|
||||
in_channels=in_channels,
|
||||
out_channels=skip_channels,
|
||||
with_cp=with_cp,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
else:
|
||||
self.upsample = ConvModule(
|
||||
in_channels,
|
||||
skip_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, skip, x):
|
||||
"""Forward function."""
|
||||
|
||||
x = self.upsample(x)
|
||||
out = torch.cat([skip, x], dim=1)
|
||||
out = self.conv_block(out)
|
||||
|
||||
return out
|
|
@ -0,0 +1,833 @@
|
|||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_unet_basic_conv_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
BasicConvBlock(64, 64, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
BasicConvBlock(64, 64, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
BasicConvBlock(64, 64, plugins=plugins)
|
||||
|
||||
# test BasicConvBlock with checkpoint forward
|
||||
block = BasicConvBlock(16, 16, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 16, 64, 64, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 64, 64])
|
||||
|
||||
block = BasicConvBlock(16, 16, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 16, 64, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 64, 64])
|
||||
|
||||
# test BasicConvBlock with stride convolution to downsample
|
||||
block = BasicConvBlock(16, 16, stride=2)
|
||||
x = torch.randn(1, 16, 64, 64)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 16, 32, 32])
|
||||
|
||||
# test BasicConvBlock structure and forward
|
||||
block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
|
||||
assert block.convs[0].conv.in_channels == 16
|
||||
assert block.convs[0].conv.out_channels == 64
|
||||
assert block.convs[0].conv.kernel_size == (3, 3)
|
||||
assert block.convs[0].conv.dilation == (1, 1)
|
||||
assert block.convs[0].conv.padding == (1, 1)
|
||||
|
||||
assert block.convs[1].conv.in_channels == 64
|
||||
assert block.convs[1].conv.out_channels == 64
|
||||
assert block.convs[1].conv.kernel_size == (3, 3)
|
||||
assert block.convs[1].conv.dilation == (3, 3)
|
||||
assert block.convs[1].conv.padding == (3, 3)
|
||||
|
||||
assert block.convs[2].conv.in_channels == 64
|
||||
assert block.convs[2].conv.out_channels == 64
|
||||
assert block.convs[2].conv.kernel_size == (3, 3)
|
||||
assert block.convs[2].conv.dilation == (3, 3)
|
||||
assert block.convs[2].conv.padding == (3, 3)
|
||||
|
||||
|
||||
def test_deconv_module():
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=1, scale_factor=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=3, scale_factor=2)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# kernel_size should be greater than or equal to scale_factor and
|
||||
# (kernel_size - scale_factor) should be even numbers
|
||||
DeconvModule(64, 32, kernel_size=5, scale_factor=4)
|
||||
|
||||
# test DeconvModule with checkpoint forward and upsample 2X.
|
||||
block = DeconvModule(64, 32, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = DeconvModule(64, 32, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test DeconvModule with different kernel size for upsample 2X.
|
||||
x = torch.randn(1, 64, 64, 64)
|
||||
block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 128, 128])
|
||||
|
||||
block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 128, 128])
|
||||
|
||||
# test DeconvModule with different kernel size for upsample 4X.
|
||||
x = torch.randn(1, 64, 64, 64)
|
||||
block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
|
||||
def test_interp_conv():
|
||||
# test InterpConv with checkpoint forward and upsample 2X.
|
||||
block = InterpConv(64, 32, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
block = InterpConv(64, 32, with_cp=False)
|
||||
assert not block.with_cp
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with conv_first=False for upsample 2X.
|
||||
block = InterpConv(64, 32, conv_first=False)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with conv_first=True for upsample 2X.
|
||||
block = InterpConv(64, 32, conv_first=True)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], ConvModule)
|
||||
assert isinstance(block.interp_upsample[1], nn.Upsample)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test InterpConv with bilinear upsample for upsample 2X.
|
||||
block = InterpConv(
|
||||
64,
|
||||
32,
|
||||
conv_first=False,
|
||||
upsampe_cfg=dict(scale_factor=2, mode='bilinear', align_corners=False))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'bilinear'
|
||||
|
||||
# test InterpConv with nearest upsample for upsample 2X.
|
||||
block = InterpConv(
|
||||
64,
|
||||
32,
|
||||
conv_first=False,
|
||||
upsampe_cfg=dict(scale_factor=2, mode='nearest'))
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(x)
|
||||
assert isinstance(block.interp_upsample[0], nn.Upsample)
|
||||
assert isinstance(block.interp_upsample[1], ConvModule)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
assert block.interp_upsample[0].mode == 'nearest'
|
||||
|
||||
|
||||
def test_up_conv_block():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
|
||||
|
||||
# test UpConvBlock with checkpoint forward and upsample 2X.
|
||||
block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
|
||||
skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
|
||||
x = torch.randn(1, 64, 128, 128, requires_grad=True)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with upsample=True for upsample 2X. The spatial size of
|
||||
# skip_x is 2X larger than x.
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with upsample=False for upsample 2X. The spatial size of
|
||||
# skip_x is the same as that of x.
|
||||
block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 256, 256)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with different upsample method for upsample 2X.
|
||||
# The upsample method is interpolation upsample (bilinear or nearest).
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
upsample_cfg=dict(
|
||||
type='InterpConv',
|
||||
upsampe_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test UpConvBlock with different upsample method for upsample 2X.
|
||||
# The upsample method is deconvolution upsample.
|
||||
block = UpConvBlock(
|
||||
BasicConvBlock,
|
||||
64,
|
||||
32,
|
||||
32,
|
||||
upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
# test BasicConvBlock structure and forward
|
||||
block = UpConvBlock(
|
||||
conv_block=BasicConvBlock,
|
||||
in_channels=64,
|
||||
skip_channels=32,
|
||||
out_channels=32,
|
||||
num_convs=3,
|
||||
dilation=3,
|
||||
upsample_cfg=dict(
|
||||
type='InterpConv',
|
||||
upsampe_cfg=dict(
|
||||
scale_factor=2, mode='bilinear', align_corners=False)))
|
||||
skip_x = torch.randn(1, 32, 256, 256)
|
||||
x = torch.randn(1, 64, 128, 128)
|
||||
x_out = block(skip_x, x)
|
||||
assert x_out.shape == torch.Size([1, 32, 256, 256])
|
||||
|
||||
assert block.conv_block.convs[0].conv.in_channels == 64
|
||||
assert block.conv_block.convs[0].conv.out_channels == 32
|
||||
assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[0].conv.dilation == (1, 1)
|
||||
assert block.conv_block.convs[0].conv.padding == (1, 1)
|
||||
|
||||
assert block.conv_block.convs[1].conv.in_channels == 32
|
||||
assert block.conv_block.convs[1].conv.out_channels == 32
|
||||
assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[1].conv.dilation == (3, 3)
|
||||
assert block.conv_block.convs[1].conv.padding == (3, 3)
|
||||
|
||||
assert block.conv_block.convs[2].conv.in_channels == 32
|
||||
assert block.conv_block.convs[2].conv.out_channels == 32
|
||||
assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
|
||||
assert block.conv_block.convs[2].conv.dilation == (3, 3)
|
||||
assert block.conv_block.convs[2].conv.padding == (3, 3)
|
||||
|
||||
assert block.upsample.interp_upsample[1].conv.in_channels == 64
|
||||
assert block.upsample.interp_upsample[1].conv.out_channels == 32
|
||||
assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
|
||||
assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
|
||||
assert block.upsample.interp_upsample[1].conv.padding == (0, 0)
|
||||
|
||||
|
||||
def test_unet():
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
|
||||
UNet(3, 64, 5, dcn=dcn)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet.
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(type='ContextBlock', ratio=1. / 16),
|
||||
position='after_conv3')
|
||||
]
|
||||
UNet(3, 64, 5, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Not implemented yet
|
||||
plugins = [
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0010',
|
||||
kv_stride=2),
|
||||
position='after_conv2')
|
||||
]
|
||||
UNet(3, 64, 5, plugins=plugins)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be devisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=4,
|
||||
strides=(1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2),
|
||||
downsamples=(True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be devisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be devisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be devisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check whether the input image size can be devisible by the whole
|
||||
# downsample rate of the encoder. The whole downsample rate of this
|
||||
# case is 32.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=6,
|
||||
strides=(1, 1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 65, 65)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(strides)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(enc_num_convs)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(dec_num_convs)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(downsamples)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(enc_dilations)=num_stages
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Check if num_stages matchs strides, len(dec_dilations)=num_stages-1
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1, 1))
|
||||
x = torch.randn(2, 3, 64, 64)
|
||||
unet(x)
|
||||
|
||||
# test UNet norm_eval=True
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
norm_eval=True)
|
||||
unet.train()
|
||||
assert check_norm_state(unet.modules(), False)
|
||||
|
||||
# test UNet norm_eval=False
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
norm_eval=False)
|
||||
unet.train()
|
||||
assert check_norm_state(unet.modules(), True)
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 2.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, False, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 64, 64])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 64, 64])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 64, 64])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 1.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(False, False, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 128, 128])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 128, 128])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 128, 128])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 128, 128])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 16.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
print(unet)
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
print(unet)
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 8.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 2, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
print(unet)
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet forward and outputs. The whole downsample rate is 4.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
print(unet)
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
||||
|
||||
# test UNet init_weights method.
|
||||
unet = UNet(
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 2, 2, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, False, False),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1))
|
||||
unet.init_weights(pretrained=None)
|
||||
print(unet)
|
||||
x = torch.randn(2, 3, 128, 128)
|
||||
x_outs = unet(x)
|
||||
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
|
||||
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
|
||||
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
|
||||
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
|
||||
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
|
Loading…
Reference in New Issue