# Copyright (c) OpenMMLab. All rights reserved.
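"""Unit tests for the UNet backbone and its building blocks.

Covers BasicConvBlock, DeconvModule, InterpConv, UpConvBlock and the full
UNet encoder-decoder, including checks that invalid configurations raise
AssertionError.
"""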
import pytest
import torch
from mmcv.cnn import ConvModule

from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
                                         InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from .utils import check_norm_state


def test_unet_basic_conv_block():
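    """Test BasicConvBlock: rejected options, checkpointing, stride and the
    dilation/padding of each conv layer."""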
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        BasicConvBlock(64, 64, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        BasicConvBlock(64, 64, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        BasicConvBlock(64, 64, plugins=plugins)

    # test BasicConvBlock with checkpoint forward
    block = BasicConvBlock(16, 16, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 16, 64, 64, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 64, 64])

    block = BasicConvBlock(16, 16, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 16, 64, 64)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 64, 64])

    # test BasicConvBlock with stride convolution to downsample
    block = BasicConvBlock(16, 16, stride=2)
    x = torch.randn(1, 16, 64, 64)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 32, 32])

    # test BasicConvBlock structure and forward
    block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
    assert block.convs[0].conv.in_channels == 16
    assert block.convs[0].conv.out_channels == 64
    assert block.convs[0].conv.kernel_size == (3, 3)
    assert block.convs[0].conv.dilation == (1, 1)
    assert block.convs[0].conv.padding == (1, 1)

    assert block.convs[1].conv.in_channels == 64
    assert block.convs[1].conv.out_channels == 64
    assert block.convs[1].conv.kernel_size == (3, 3)
    assert block.convs[1].conv.dilation == (3, 3)
    assert block.convs[1].conv.padding == (3, 3)

    assert block.convs[2].conv.in_channels == 64
    assert block.convs[2].conv.out_channels == 64
    assert block.convs[2].conv.kernel_size == (3, 3)
    assert block.convs[2].conv.dilation == (3, 3)
    assert block.convs[2].conv.padding == (3, 3)


def test_deconv_module():
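    """Test DeconvModule: kernel_size/scale_factor constraints and shapes.

    The constraint exercised below is that kernel_size >= scale_factor and
    (kernel_size - scale_factor) is even, presumably so the transposed conv
    (stride=scale_factor, padding=(kernel_size - scale_factor) // 2) enlarges
    the input by exactly scale_factor.
    """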
    with pytest.raises(AssertionError):
        # kernel_size should be greater than or equal to scale_factor and
        # (kernel_size - scale_factor) should be an even number
        DeconvModule(64, 32, kernel_size=1, scale_factor=2)

    with pytest.raises(AssertionError):
        # kernel_size should be greater than or equal to scale_factor and
        # (kernel_size - scale_factor) should be an even number
        DeconvModule(64, 32, kernel_size=3, scale_factor=2)

    with pytest.raises(AssertionError):
        # kernel_size should be greater than or equal to scale_factor and
        # (kernel_size - scale_factor) should be an even number
        DeconvModule(64, 32, kernel_size=5, scale_factor=4)

    # test DeconvModule with checkpoint forward and upsample 2X.
    block = DeconvModule(64, 32, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = DeconvModule(64, 32, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test DeconvModule with different kernel size for upsample 2X.
    x = torch.randn(1, 64, 64, 64)
    block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 128, 128])

    block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 128, 128])

    # test DeconvModule with different kernel size for upsample 4X.
    x = torch.randn(1, 64, 64, 64)
    block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])


def test_interp_conv():
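    """Test InterpConv: checkpointing, conv/upsample ordering (conv_first)
    and bilinear vs. nearest interpolation modes."""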
    # test InterpConv with checkpoint forward and upsample 2X.
    block = InterpConv(64, 32, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = InterpConv(64, 32, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with conv_first=False for upsample 2X.
    block = InterpConv(64, 32, conv_first=False)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with conv_first=True for upsample 2X.
    block = InterpConv(64, 32, conv_first=True)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], ConvModule)
    assert isinstance(block.interp_upsample[1], Upsample)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with bilinear upsample for upsample 2X.
    block = InterpConv(
        64,
        32,
        conv_first=False,
        upsample_cfg=dict(
            scale_factor=2, mode='bilinear', align_corners=False))
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])
    assert block.interp_upsample[0].mode == 'bilinear'

    # test InterpConv with nearest upsample for upsample 2X.
    block = InterpConv(
        64,
        32,
        conv_first=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'))
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])
    assert block.interp_upsample[0].mode == 'nearest'


def test_up_conv_block():
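    """Test UpConvBlock: upsample x, concatenate skip_x, then fuse with the
    given conv block.

    skip_x is 2X the spatial size of x in the upsampling cases; with
    upsample_cfg=None both inputs share the same resolution.
    """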
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)

    # test UpConvBlock with checkpoint forward and upsample 2X.
    block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
    skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with InterpConv upsampling for upsample 2X. The
    # spatial size of skip_x is 2X larger than that of x.
    block = UpConvBlock(
        BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock without upsampling (upsample_cfg=None). The spatial
    # size of skip_x is the same as that of x.
    block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 256, 256)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with different upsample method for upsample 2X.
    # The upsample method is interpolation upsample (bilinear or nearest).
    block = UpConvBlock(
        BasicConvBlock,
        64,
        32,
        32,
        upsample_cfg=dict(
            type='InterpConv',
            upsample_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with different upsample method for upsample 2X.
    # The upsample method is deconvolution upsample.
    block = UpConvBlock(
        BasicConvBlock,
        64,
        32,
        32,
        upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock structure and forward
    block = UpConvBlock(
        conv_block=BasicConvBlock,
        in_channels=64,
        skip_channels=32,
        out_channels=32,
        num_convs=3,
        dilation=3,
        upsample_cfg=dict(
            type='InterpConv',
            upsample_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    assert block.conv_block.convs[0].conv.in_channels == 64
    assert block.conv_block.convs[0].conv.out_channels == 32
    assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[0].conv.dilation == (1, 1)
    assert block.conv_block.convs[0].conv.padding == (1, 1)

    assert block.conv_block.convs[1].conv.in_channels == 32
    assert block.conv_block.convs[1].conv.out_channels == 32
    assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[1].conv.dilation == (3, 3)
    assert block.conv_block.convs[1].conv.padding == (3, 3)

    assert block.conv_block.convs[2].conv.in_channels == 32
    assert block.conv_block.convs[2].conv.out_channels == 32
    assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[2].conv.dilation == (3, 3)
    assert block.conv_block.convs[2].conv.padding == (3, 3)

    assert block.upsample.interp_upsample[1].conv.in_channels == 64
    assert block.upsample.interp_upsample[1].conv.out_channels == 32
    assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
    assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
    assert block.upsample.interp_upsample[1].conv.padding == (0, 0)


def test_unet():
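    """Test UNet: configuration checks and forward output shapes.

    The "whole downsample rate" in the comments below is the product of the
    per-stage factors (a stage appears to halve the resolution when its
    stride is 2 or its downsample flag is True); the input height and width
    must be divisible by this rate.
    """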
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        UNet(3, 64, 5, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        UNet(3, 64, 5, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        UNet(3, 64, 5, plugins=plugins)

    with pytest.raises(AssertionError):
        # Check whether the input image size can be divisible by the whole
        # downsample rate of the encoder. The whole downsample rate of this
        # case is 8.
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=4,
            strides=(1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2),
            dec_num_convs=(2, 2, 2),
            downsamples=(True, True, True),
            enc_dilations=(1, 1, 1, 1),
            dec_dilations=(1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # Check whether the input image size can be divisible by the whole
        # downsample rate of the encoder. The whole downsample rate of this
        # case is 16.
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # Check whether the input image size can be divisible by the whole
        # downsample rate of the encoder. The whole downsample rate of this
        # case is 8.
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, False),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # Check whether the input image size can be divisible by the whole
        # downsample rate of the encoder. The whole downsample rate of this
        # case is 8.
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 2, 2, 2, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, False),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # Check whether the input image size can be divisible by the whole
        # downsample rate of the encoder. The whole downsample rate of this
        # case is 32.
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=6,
            strides=(1, 1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2, 2),
            downsamples=(True, True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches strides, len(strides)=num_stages
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches enc_num_convs,
        # len(enc_num_convs)=num_stages
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches dec_num_convs,
        # len(dec_num_convs)=num_stages-1
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches downsamples,
        # len(downsamples)=num_stages-1
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches enc_dilations,
        # len(enc_dilations)=num_stages
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # Check if num_stages matches dec_dilations,
        # len(dec_dilations)=num_stages-1
        unet = UNet(
            in_channels=3,
            base_channels=4,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    # test UNet norm_eval=True
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        norm_eval=True)
    unet.train()
    assert check_norm_state(unet.modules(), False)

    # test UNet norm_eval=False
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        norm_eval=False)
    unet.train()
    assert check_norm_state(unet.modules(), True)

    # test UNet forward and outputs. The whole downsample rate is 16.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 8, 8])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 8.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 8.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 2, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 4.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 4.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 8.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 4.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 2.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, False, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 64, 64])
    assert x_outs[1].shape == torch.Size([2, 32, 64, 64])
    assert x_outs[2].shape == torch.Size([2, 16, 64, 64])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 1.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(False, False, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 128, 128])
    assert x_outs[1].shape == torch.Size([2, 32, 128, 128])
    assert x_outs[2].shape == torch.Size([2, 16, 128, 128])
    assert x_outs[3].shape == torch.Size([2, 8, 128, 128])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 16.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 8, 8])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 8.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 8.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 2, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 32, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet forward and outputs. The whole downsample rate is 4.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])

    # test UNet init_weights method.
    unet = UNet(
        in_channels=3,
        base_channels=4,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        pretrained=None)
    unet.init_weights()
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 64, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 32, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 16, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 8, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 4, 128, 128])