From 59564510148a5996beeab3731fa05f7a37588fe1 Mon Sep 17 00:00:00 2001
From: Junjun2016
Date: Thu, 22 Oct 2020 02:24:38 +0800
Subject: [PATCH] add unet (#161)

* add unet

* add unet

* add unet

* update test_unet

* update test_unet

* update test_unet

* update test_unet

* fix bugs

* add init method for unet

* add test of UNet init_weights method

* add registry

* merge upsample

* fix test

* Update mmseg/models/backbones/unet.py

Co-authored-by: Jerry Jiarui XU

* Update mmseg/models/backbones/unet.py

Co-authored-by: Jerry Jiarui XU

* split UpConvBlock from UNet

* use reversed

* rename upsample module

* rename upsample module

* rename upsample module

* rename upsample module

Co-authored-by: Jerry Jiarui XU
---
 mmseg/models/backbones/__init__.py  |   3 +-
 mmseg/models/backbones/unet.py      | 428 ++++++++++++++
 mmseg/models/utils/__init__.py      |   4 +-
 mmseg/models/utils/up_conv_block.py | 101 ++++
 tests/test_models/test_unet.py      | 833 ++++++++++++++++++++++++++++
 5 files changed, 1367 insertions(+), 2 deletions(-)
 create mode 100644 mmseg/models/backbones/unet.py
 create mode 100644 mmseg/models/utils/up_conv_block.py
 create mode 100644 tests/test_models/test_unet.py

diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 6253bab42..db5eb1c5c 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -4,8 +4,9 @@ from .mobilenet_v2 import MobileNetV2
 from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
+from .unet import UNet
 
 __all__ = [
     'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
-    'ResNeSt', 'MobileNetV2'
+    'ResNeSt', 'MobileNetV2', 'UNet'
 ]
diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py
new file mode 100644
index 000000000..0e1b001c8
--- /dev/null
+++ b/mmseg/models/backbones/unet.py
@@ -0,0 +1,428 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+                      build_norm_layer, constant_init, kaiming_init)
+from mmcv.runner import load_checkpoint
+from mmcv.utils.parrots_wrapper import _BatchNorm
+
+from mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+
+
+class BasicConvBlock(nn.Module):
+    """Basic convolutional block for UNet.
+
+    This module consists of several plain convolutional layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Whether to use stride convolution to downsample the
+            input feature map. If stride=2, stride convolution is only used
+            in the first convolutional layer to downsample the input feature
+            map. Options are 1 or 2. Default: 1.
+        dilation (int): Dilation rate of the convolutional layers after the
+            first one; the dilation rate of the first convolutional layer is
+            always 1. Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): Plugins for convolutional layers. Default: None.
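+
+    Example:
+        >>> # A minimal sketch (illustrative): stride=2 halves the spatial
+        >>> # size in the first convolutional layer of the block.
+        >>> import torch
+        >>> self = BasicConvBlock(16, 32, num_convs=2, stride=2)
+        >>> x = torch.randn(1, 16, 64, 64)
+        >>> self(x).shape
+        torch.Size([1, 32, 32, 32])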
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 dcn=None,
+                 plugins=None):
+        super(BasicConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.with_cp = with_cp
+        convs = []
+        for i in range(num_convs):
+            convs.append(
+                ConvModule(
+                    in_channels=in_channels if i == 0 else out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=stride if i == 0 else 1,
+                    dilation=1 if i == 0 else dilation,
+                    padding=1 if i == 0 else dilation,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+        self.convs = nn.Sequential(*convs)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.convs, x)
+        else:
+            out = self.convs(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+    """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+    This module uses deconvolution to upsample the feature map in the decoder
+    of UNet.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of the deconvolution layer. Default: 4.
+        scale_factor (int): Upsample scale factor, which equals the stride of
+            the deconvolution layer. Default: 2.
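+
+    Example:
+        >>> # A minimal sketch (illustrative): stride equals scale_factor and
+        >>> # padding is (kernel_size - scale_factor) // 2, so a 4x4 deconv
+        >>> # with scale_factor=2 upsamples exactly 2X.
+        >>> import torch
+        >>> self = DeconvModule(64, 32, kernel_size=4, scale_factor=2)
+        >>> x = torch.randn(1, 64, 128, 128)
+        >>> self(x).shape
+        torch.Size([1, 32, 256, 256])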
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 kernel_size=4,
+                 scale_factor=2):
+        super(DeconvModule, self).__init__()
+
+        assert (kernel_size - scale_factor >= 0) and\
+            (kernel_size - scale_factor) % 2 == 0,\
+            f'kernel_size should be greater than or equal to scale_factor '\
+            f'and (kernel_size - scale_factor) should be an even number, '\
+            f'while the kernel size is {kernel_size} and scale_factor is '\
+            f'{scale_factor}.'
+
+        stride = scale_factor
+        padding = (kernel_size - scale_factor) // 2
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+
+        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsampling = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.deconv_upsampling, x)
+        else:
+            out = self.deconv_upsampling(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    This module uses interpolation to upsample the feature map in the decoder
+    of UNet. It consists of one interpolation upsample layer and one
+    convolutional layer. It can be one interpolation upsample layer followed
+    by one convolutional layer (conv_first=False) or one convolutional layer
+    followed by one interpolation upsample layer (conv_first=True).
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether the convolutional layer comes before the
+            interpolation upsample layer. Default: False, i.e. the
+            interpolation upsample layer is followed by the convolutional
+            layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Interpolation config of the upsample layer.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
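+
+    Example:
+        >>> # A minimal sketch (illustrative): with conv_first=False (the
+        >>> # default) the feature map is first upsampled 2X and then passed
+        >>> # through a 1x1 ConvModule.
+        >>> import torch
+        >>> self = InterpConv(64, 32)
+        >>> x = torch.randn(1, 64, 128, 128)
+        >>> self(x).shape
+        torch.Size([1, 32, 256, 256])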
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super(InterpConv, self).__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = nn.Upsample(**upsample_cfg)
+        if conv_first:
+            self.interp_upsample = nn.Sequential(conv, upsample)
+        else:
+            self.interp_upsample = nn.Sequential(upsample, conv)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.interp_upsample, x)
+        else:
+            out = self.interp_upsample(x)
+        return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+    """UNet backbone.
+
+    U-Net: Convolutional Networks for Biomedical Image Segmentation.
+    https://arxiv.org/pdf/1505.04597.pdf
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels, which is also the
+            number of output channels of the first stage. Default: 64.
+        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        strides (Sequence[int]): Strides of each stage in encoder; each
+            element is 1 or 2, and len(strides) is equal to num_stages.
+            Normally the stride of the first stage in encoder is 1. If
+            strides[i]=2, stride convolution is used to downsample in the
+            corresponding encoder stage. Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[bool]): Whether to use MaxPool to downsample the
+            feature map after the first stage of encoder
+            (stages: [1, num_stages)). If the corresponding encoder stage
+            uses stride convolution (strides[i]=2), MaxPool is never used to
+            downsample, even if downsamples[i-1]=True.
+            Default: (True, True, True, True).
+        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+            Default: (1, 1, 1, 1, 1).
+        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+            Default: (1, 1, 1, 1).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: This only has an
+            effect on BatchNorm and its variants. Default: False.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): Plugins for convolutional layers. Default: None.
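+
+    Example:
+        >>> # A minimal sketch (illustrative): with the default settings the
+        >>> # whole downsample rate is 16 and stage i of the encoder outputs
+        >>> # base_channels * 2**i channels.
+        >>> import torch
+        >>> self = UNet(in_channels=3, base_channels=4, num_stages=5)
+        >>> x = torch.randn(1, 3, 64, 64)
+        >>> [tuple(out.shape) for out in self(x)]
+        [(1, 64, 4, 4), (1, 32, 8, 8), (1, 16, 16, 16), (1, 8, 32, 32), (1, 4, 64, 64)]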
+
+    Notice:
+        The input image size should be divisible by the whole downsample rate
+        of the encoder. Details of the whole downsample rate can be found in
+        UNet._check_input_divisible.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 base_channels=64,
+                 num_stages=5,
+                 strides=(1, 1, 1, 1, 1),
+                 enc_num_convs=(2, 2, 2, 2, 2),
+                 dec_num_convs=(2, 2, 2, 2),
+                 downsamples=(True, True, True, True),
+                 enc_dilations=(1, 1, 1, 1, 1),
+                 dec_dilations=(1, 1, 1, 1),
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 norm_eval=False,
+                 dcn=None,
+                 plugins=None):
+        super(UNet, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+        assert len(strides) == num_stages, \
+            'The length of strides should be equal to num_stages, '\
+            f'while the strides is {strides}, the length of '\
+            f'strides is {len(strides)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_num_convs) == num_stages, \
+            'The length of enc_num_convs should be equal to num_stages, '\
+            f'while the enc_num_convs is {enc_num_convs}, the length of '\
+            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_num_convs) == (num_stages-1), \
+            'The length of dec_num_convs should be equal to (num_stages-1), '\
+            f'while the dec_num_convs is {dec_num_convs}, the length of '\
+            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(downsamples) == (num_stages-1), \
+            'The length of downsamples should be equal to (num_stages-1), '\
+            f'while the downsamples is {downsamples}, the length of '\
+            f'downsamples is {len(downsamples)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_dilations) == num_stages, \
+            'The length of enc_dilations should be equal to num_stages, '\
+            f'while the enc_dilations is {enc_dilations}, the length of '\
+            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_dilations) == (num_stages-1), \
+            'The length of dec_dilations should be equal to (num_stages-1), '\
+            f'while the dec_dilations is {dec_dilations}, the length of '\
+            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+
+        self.num_stages = num_stages
+        self.strides = strides
+        self.downsamples = downsamples
+        self.norm_eval = norm_eval
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        for i in range(num_stages):
+            enc_conv_block = []
+            if i != 0:
+                if strides[i] == 1 and downsamples[i - 1]:
+                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+                upsample = (strides[i] != 1 or downsamples[i - 1])
+                self.decoder.append(
+                    UpConvBlock(
+                        conv_block=BasicConvBlock,
+                        in_channels=base_channels * 2**i,
+                        skip_channels=base_channels * 2**(i - 1),
+                        out_channels=base_channels * 2**(i - 1),
+                        num_convs=dec_num_convs[i - 1],
+                        stride=1,
+                        dilation=dec_dilations[i - 1],
+                        with_cp=with_cp,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg,
+                        upsample_cfg=upsample_cfg if upsample else None,
+                        dcn=None,
+                        plugins=None))
+
+            enc_conv_block.append(
+                BasicConvBlock(
+                    in_channels=in_channels,
+                    out_channels=base_channels * 2**i,
+                    num_convs=enc_num_convs[i],
+                    stride=strides[i],
+                    dilation=enc_dilations[i],
+                    with_cp=with_cp,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    dcn=None,
+                    plugins=None))
+            self.encoder.append(nn.Sequential(*enc_conv_block))
+            in_channels = base_channels * 2**i
+
+    def forward(self, x):
+        """Forward function."""
+
+        self._check_input_divisible(x)
+        enc_outs = []
+        for enc in self.encoder:
+            x = enc(x)
+            enc_outs.append(x)
+        dec_outs = [x]
+        for i in reversed(range(len(self.decoder))):
+            x = self.decoder[i](enc_outs[i], x)
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super(UNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+                if isinstance(m, _BatchNorm):
+                    m.eval()
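+
+    # The whole downsample rate is the product of the per-stage 2X
+    # downsamples. A worked example (illustrative): strides=(1, 2, 2, 1, 1)
+    # with downsamples=(True, True, True, True) downsamples 2X in stages 1-4,
+    # so whole_downsample_rate = 2**4 = 16 and both H and W of the input must
+    # be multiples of 16.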
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i - 1]:
+                whole_downsample_rate *= 2
+        assert (h % whole_downsample_rate == 0) \
+            and (w % whole_downsample_rate == 0),\
+            f'The input image size {(h, w)} should be divisible by the whole '\
+            f'downsample rate {whole_downsample_rate}, when num_stages is '\
+            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+            f'is {self.downsamples}.'
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index 969a0c7d9..5d233a423 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -2,7 +2,9 @@ from .inverted_residual import InvertedResidual
 from .make_divisible import make_divisible
 from .res_layer import ResLayer
 from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
 
 __all__ = [
-    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual'
+    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+    'UpConvBlock'
 ]
diff --git a/mmseg/models/utils/up_conv_block.py b/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 000000000..df8a2aa7d
--- /dev/null
+++ b/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+    """Upsample convolution block in decoder for UNet.
+
+    This upsample convolution block consists of one upsample module followed
+    by one convolution block. The upsample module expands the high-level
+    low-resolution feature map, and the convolution block fuses the upsampled
+    high-level low-resolution feature map with the low-level high-resolution
+    feature map from the encoder.
+
+    Args:
+        conv_block (nn.Module): Class of the convolution block to build,
+            e.g. BasicConvBlock.
+        in_channels (int): Number of input channels of the high-level
+            low-resolution feature map from the decoder.
+        skip_channels (int): Number of input channels of the low-level
+            high-resolution feature map from encoder.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers in the conv_block.
+            Default: 2.
+        stride (int): Stride of convolutional layer in conv_block. Default: 1.
+        dilation (int): Dilation rate of convolutional layer in conv_block.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv'). If the size of the
+            high-level feature map is the same as that of the skip feature
+            map (low-level feature map from encoder), the high-level feature
+            map does not need to be upsampled and upsample_cfg should be
+            None.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): Plugins for convolutional layers. Default: None.
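+
+    Example:
+        >>> # A minimal sketch (illustrative): x is upsampled 2X to match
+        >>> # skip, concatenated with it, then fused back to out_channels.
+        >>> import torch
+        >>> from mmseg.models.backbones.unet import BasicConvBlock
+        >>> self = UpConvBlock(BasicConvBlock, in_channels=64,
+        ...                    skip_channels=32, out_channels=32)
+        >>> skip = torch.randn(1, 32, 64, 64)
+        >>> x = torch.randn(1, 64, 32, 32)
+        >>> self(skip, x).shape
+        torch.Size([1, 32, 64, 64])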
+    """
+
+    def __init__(self,
+                 conv_block,
+                 in_channels,
+                 skip_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 dcn=None,
+                 plugins=None):
+        super(UpConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.conv_block = conv_block(
+            in_channels=2 * skip_channels,
+            out_channels=out_channels,
+            num_convs=num_convs,
+            stride=stride,
+            dilation=dilation,
+            with_cp=with_cp,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            dcn=None,
+            plugins=None)
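+
+        # ``upsample_cfg`` is consumed by mmcv's build_upsample_layer below,
+        # e.g. dict(type='DeconvModule', kernel_size=4, scale_factor=2)
+        # builds a DeconvModule from the UPSAMPLE_LAYERS registry (both
+        # DeconvModule and InterpConv are registered in
+        # mmseg/models/backbones/unet.py).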
+        if upsample_cfg is not None:
+            self.upsample = build_upsample_layer(
+                cfg=upsample_cfg,
+                in_channels=in_channels,
+                out_channels=skip_channels,
+                with_cp=with_cp,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.upsample = ConvModule(
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+
+    def forward(self, skip, x):
+        """Forward function."""
+
+        x = self.upsample(x)
+        out = torch.cat([skip, x], dim=1)
+        out = self.conv_block(out)
+
+        return out
diff --git a/tests/test_models/test_unet.py b/tests/test_models/test_unet.py
new file mode 100644
index 000000000..febe4f0c9
--- /dev/null
+++ b/tests/test_models/test_unet.py
@@ -0,0 +1,833 @@
+import pytest
+import torch
+from mmcv.cnn import ConvModule
+from mmcv.utils.parrots_wrapper import _BatchNorm
+from torch import nn
+
+from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
+                                         InterpConv, UNet, UpConvBlock)
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_unet_basic_conv_block():
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
+        BasicConvBlock(64, 64, dcn=dcn)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(type='ContextBlock', ratio=1. / 16),
+                position='after_conv3')
+        ]
+        BasicConvBlock(64, 64, plugins=plugins)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(
+                    type='GeneralizedAttention',
+                    spatial_range=-1,
+                    num_heads=8,
+                    attention_type='0010',
+                    kv_stride=2),
+                position='after_conv2')
+        ]
+        BasicConvBlock(64, 64, plugins=plugins)
+
+    # test BasicConvBlock with checkpoint forward
+    block = BasicConvBlock(16, 16, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 16, 64, 64, requires_grad=True)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 64, 64])
+
+    block = BasicConvBlock(16, 16, with_cp=False)
+    assert not block.with_cp
+    x = torch.randn(1, 16, 64, 64)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 64, 64])
+
+    # test BasicConvBlock with stride convolution to downsample
+    block = BasicConvBlock(16, 16, stride=2)
+    x = torch.randn(1, 16, 64, 64)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 16, 32, 32])
+
+    # test BasicConvBlock structure and forward
+    block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
+    assert block.convs[0].conv.in_channels == 16
+    assert block.convs[0].conv.out_channels == 64
+    assert block.convs[0].conv.kernel_size == (3, 3)
+    assert block.convs[0].conv.dilation == (1, 1)
+    assert block.convs[0].conv.padding == (1, 1)
+
+    assert block.convs[1].conv.in_channels == 64
+    assert block.convs[1].conv.out_channels == 64
+    assert block.convs[1].conv.kernel_size == (3, 3)
+    assert block.convs[1].conv.dilation == (3, 3)
+    assert block.convs[1].conv.padding == (3, 3)
+
+    assert block.convs[2].conv.in_channels == 64
+    assert block.convs[2].conv.out_channels == 64
+    assert block.convs[2].conv.kernel_size == (3, 3)
+    assert block.convs[2].conv.dilation == (3, 3)
+    assert block.convs[2].conv.padding == (3, 3)
+
+
+def test_deconv_module():
+    with pytest.raises(AssertionError):
+        # kernel_size should be greater than or equal to scale_factor and
+        # (kernel_size - scale_factor) should be an even number
+        DeconvModule(64, 32, kernel_size=1, scale_factor=2)
+
+    with pytest.raises(AssertionError):
+        # kernel_size should be greater than or equal to scale_factor and
+        # (kernel_size - scale_factor) should be an even number
+        DeconvModule(64, 32, kernel_size=3, scale_factor=2)
+
+    with pytest.raises(AssertionError):
+        # kernel_size should be greater than or equal to scale_factor and
+        # (kernel_size - scale_factor) should be an even number
+        DeconvModule(64, 32, kernel_size=5, scale_factor=4)
+
+    # test DeconvModule with checkpoint forward and upsample 2X.
+    block = DeconvModule(64, 32, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 64, 128, 128, requires_grad=True)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    block = DeconvModule(64, 32, with_cp=False)
+    assert not block.with_cp
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test DeconvModule with different kernel size for upsample 2X.
+    x = torch.randn(1, 64, 64, 64)
+    block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 128, 128])
+
+    block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 128, 128])
+
+    # test DeconvModule with different kernel size for upsample 4X.
+    x = torch.randn(1, 64, 64, 64)
+    block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+
+def test_interp_conv():
+    # test InterpConv with checkpoint forward and upsample 2X.
+    block = InterpConv(64, 32, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 64, 128, 128, requires_grad=True)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    block = InterpConv(64, 32, with_cp=False)
+    assert not block.with_cp
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test InterpConv with conv_first=False for upsample 2X.
+    block = InterpConv(64, 32, conv_first=False)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[1], ConvModule)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test InterpConv with conv_first=True for upsample 2X.
+    block = InterpConv(64, 32, conv_first=True)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert isinstance(block.interp_upsample[0], ConvModule)
+    assert isinstance(block.interp_upsample[1], nn.Upsample)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test InterpConv with bilinear upsample for upsample 2X.
+    block = InterpConv(
+        64,
+        32,
+        conv_first=False,
+        upsample_cfg=dict(
+            scale_factor=2, mode='bilinear', align_corners=False))
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[1], ConvModule)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+    assert block.interp_upsample[0].mode == 'bilinear'
+
+    # test InterpConv with nearest upsample for upsample 2X.
+    block = InterpConv(
+        64,
+        32,
+        conv_first=False,
+        upsample_cfg=dict(scale_factor=2, mode='nearest'))
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(x)
+    assert isinstance(block.interp_upsample[0], nn.Upsample)
+    assert isinstance(block.interp_upsample[1], ConvModule)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+    assert block.interp_upsample[0].mode == 'nearest'
+
+
+def test_up_conv_block():
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
+        UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(type='ContextBlock', ratio=1. / 16),
+                position='after_conv3')
+        ]
+        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(
+                    type='GeneralizedAttention',
+                    spatial_range=-1,
+                    num_heads=8,
+                    attention_type='0010',
+                    kv_stride=2),
+                position='after_conv2')
+        ]
+        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
+
+    # test UpConvBlock with checkpoint forward and upsample 2X.
+    block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
+    skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
+    x = torch.randn(1, 64, 128, 128, requires_grad=True)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test UpConvBlock with an upsample module for upsample 2X. The spatial
+    # size of skip_x is 2X larger than that of x.
+    block = UpConvBlock(
+        BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
+    skip_x = torch.randn(1, 32, 256, 256)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test UpConvBlock without an upsample module (upsample_cfg=None). The
+    # spatial size of skip_x is the same as that of x.
+    block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
+    skip_x = torch.randn(1, 32, 256, 256)
+    x = torch.randn(1, 64, 256, 256)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test UpConvBlock with different upsample method for upsample 2X.
+    # The upsample method is interpolation upsample (bilinear or nearest).
+    block = UpConvBlock(
+        BasicConvBlock,
+        64,
+        32,
+        32,
+        upsample_cfg=dict(
+            type='InterpConv',
+            upsample_cfg=dict(
+                scale_factor=2, mode='bilinear', align_corners=False)))
+    skip_x = torch.randn(1, 32, 256, 256)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test UpConvBlock with different upsample method for upsample 2X.
+    # The upsample method is deconvolution upsample.
+    block = UpConvBlock(
+        BasicConvBlock,
+        64,
+        32,
+        32,
+        upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
+    skip_x = torch.randn(1, 32, 256, 256)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    # test BasicConvBlock structure and forward
+    block = UpConvBlock(
+        conv_block=BasicConvBlock,
+        in_channels=64,
+        skip_channels=32,
+        out_channels=32,
+        num_convs=3,
+        dilation=3,
+        upsample_cfg=dict(
+            type='InterpConv',
+            upsample_cfg=dict(
+                scale_factor=2, mode='bilinear', align_corners=False)))
+    skip_x = torch.randn(1, 32, 256, 256)
+    x = torch.randn(1, 64, 128, 128)
+    x_out = block(skip_x, x)
+    assert x_out.shape == torch.Size([1, 32, 256, 256])
+
+    assert block.conv_block.convs[0].conv.in_channels == 64
+    assert block.conv_block.convs[0].conv.out_channels == 32
+    assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
+    assert block.conv_block.convs[0].conv.dilation == (1, 1)
+    assert block.conv_block.convs[0].conv.padding == (1, 1)
+
+    assert block.conv_block.convs[1].conv.in_channels == 32
+    assert block.conv_block.convs[1].conv.out_channels == 32
+    assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
+    assert block.conv_block.convs[1].conv.dilation == (3, 3)
+    assert block.conv_block.convs[1].conv.padding == (3, 3)
+
+    assert block.conv_block.convs[2].conv.in_channels == 32
+    assert block.conv_block.convs[2].conv.out_channels == 32
+    assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
+    assert block.conv_block.convs[2].conv.dilation == (3, 3)
+    assert block.conv_block.convs[2].conv.padding == (3, 3)
+
+    assert block.upsample.interp_upsample[1].conv.in_channels == 64
+    assert block.upsample.interp_upsample[1].conv.out_channels == 32
+    assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
+    assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
+    assert block.upsample.interp_upsample[1].conv.padding == (0, 0)
+
+
+def test_unet():
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
+        UNet(3, 64, 5, dcn=dcn)
+
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(type='ContextBlock', ratio=1. / 16),
+                position='after_conv3')
+        ]
+        UNet(3, 64, 5, plugins=plugins)
+    with pytest.raises(AssertionError):
+        # Not implemented yet.
+        plugins = [
+            dict(
+                cfg=dict(
+                    type='GeneralizedAttention',
+                    spatial_range=-1,
+                    num_heads=8,
+                    attention_type='0010',
+                    kv_stride=2),
+                position='after_conv2')
+        ]
+        UNet(3, 64, 5, plugins=plugins)
+
+    with pytest.raises(AssertionError):
+        # Check whether the input image size is divisible by the whole
+        # downsample rate of the encoder. The whole downsample rate of this
+        # case is 8.
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=4,
+            strides=(1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2),
+            downsamples=(True, True, True),
+            enc_dilations=(1, 1, 1, 1),
+            dec_dilations=(1, 1, 1))
+        x = torch.randn(2, 3, 65, 65)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check whether the input image size is divisible by the whole
+        # downsample rate of the encoder. The whole downsample rate of this
+        # case is 16.
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 65, 65)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check whether the input image size is divisible by the whole
+        # downsample rate of the encoder. The whole downsample rate of this
+        # case is 8.
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, False),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 65, 65)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check whether the input image size is divisible by the whole
+        # downsample rate of the encoder. The whole downsample rate of this
+        # case is 8.
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 2, 2, 2, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, False),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 65, 65)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check whether the input image size is divisible by the whole
+        # downsample rate of the encoder. The whole downsample rate of this
+        # case is 32.
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=6,
+            strides=(1, 1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2, 2),
+            downsamples=(True, True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1, 1))
+        x = torch.randn(2, 3, 65, 65)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(strides) matches num_stages
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(enc_num_convs) matches num_stages
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(dec_num_convs) matches num_stages - 1
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(downsamples) matches num_stages - 1
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(enc_dilations) matches num_stages
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    with pytest.raises(AssertionError):
+        # Check that len(dec_dilations) matches num_stages - 1
+        unet = UNet(
+            in_channels=3,
+            base_channels=64,
+            num_stages=5,
+            strides=(1, 1, 1, 1, 1),
+            enc_num_convs=(2, 2, 2, 2, 2),
+            dec_num_convs=(2, 2, 2, 2),
+            downsamples=(True, True, True, True),
+            enc_dilations=(1, 1, 1, 1, 1),
+            dec_dilations=(1, 1, 1, 1, 1))
+        x = torch.randn(2, 3, 64, 64)
+        unet(x)
+
+    # test UNet norm_eval=True
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        norm_eval=True)
+    unet.train()
+    assert check_norm_state(unet.modules(), False)
+
+    # test UNet norm_eval=False
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        norm_eval=False)
+    unet.train()
+    assert check_norm_state(unet.modules(), True)
+
+    # test UNet forward and outputs. The whole downsample rate is 16.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 8.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 8.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 2, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 4.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
+    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 4.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
+    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 8.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 4.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
+    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 2.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, False, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 64, 64])
+    assert x_outs[1].shape == torch.Size([2, 512, 64, 64])
+    assert x_outs[2].shape == torch.Size([2, 256, 64, 64])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 1.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(False, False, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 128, 128])
+    assert x_outs[1].shape == torch.Size([2, 512, 128, 128])
+    assert x_outs[2].shape == torch.Size([2, 256, 128, 128])
+    assert x_outs[3].shape == torch.Size([2, 128, 128, 128])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 16.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+    print(unet)
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 8.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+    print(unet)
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 8.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 2, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+    print(unet)
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
+    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet forward and outputs. The whole downsample rate is 4.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+    print(unet)
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
+    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
+
+    # test UNet init_weights method.
+    unet = UNet(
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 2, 2, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, False, False),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1))
+    unet.init_weights(pretrained=None)
+    print(unet)
+    x = torch.randn(2, 3, 128, 128)
+    x_outs = unet(x)
+    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
+    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
+    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
+    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
+    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])