# mmsegmentation/tests/test_models/test_heads/test_knet_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
                                                 KernelUpdateHead)
from .utils import to_cuda

num_stages = 3
conv_kernel_size = 1
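
# Shared config for the adaptive kernel update operation: every
# KernelUpdateHead stage below uses this updator to refine its mask
# kernels from the features assembled under the current masks.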
kernel_updator_cfg = dict(
    type='KernelUpdator',
    in_channels=16,
    feat_channels=16,
    out_channels=16,
    gate_norm_act=True,
    activate_out=True,
    act_cfg=dict(type='ReLU', inplace=True),
    norm_cfg=dict(type='LN'))


def test_knet_head():
    # test init function of kernel update head
    kernel_update_head = KernelUpdateHead(
        num_classes=150,
        num_ffn_fcs=2,
        num_heads=8,
        num_mask_fcs=1,
        feedforward_channels=128,
        in_channels=32,
        out_channels=32,
        dropout=0.0,
        conv_kernel_size=conv_kernel_size,
        ffn_act_cfg=dict(type='ReLU', inplace=True),
        with_ffn=True,
        feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
        kernel_init=True,
        kernel_updator_cfg=kernel_updator_cfg)
    kernel_update_head.init_weights()
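
    # IterativeDecodeHead chains a kernel generate head (an FCNHead here),
    # which predicts the initial masks and kernels, with num_stages
    # KernelUpdateHead stages that iteratively refine them.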
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
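
    # fake multi-level backbone features; FCNHead reads the last level
    # (in_index=3, 128 channels) to generate the initial prediction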
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
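    # in training mode the head returns one prediction per stage;
    # the final stage should keep the input's spatial resolution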
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test whether only the prediction of the last stage
    # is returned during testing
    with torch.no_grad():
        head.eval()
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 27, 16)

    # test K-Net without `feat_transform_cfg`
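    # (feat_transform_cfg=None skips the feature transform conv, so each
    # stage consumes the incoming features directly)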
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=None,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test K-Net with
    # self.mask_transform_stride == 2 and self.feat_gather_stride == 1
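    # (mask_transform_stride=2 makes each stage downsample its mask
    # predictions before the kernel update and upsample the refined masks
    # afterwards; feat_gather_stride=1 keeps the features at full size)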
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                mask_transform_stride=2,
                feat_gather_stride=1,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
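    # height 27 is odd: the stride-2 mask transform floors it to 13 on the
    # way down and doubles it back to 26 on the way up, while the even
    # width 16 survives the round trip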
    assert outputs[-1].shape == (1, head.num_classes, 26, 16)

    # test loss function in K-Net
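    # one CE loss is reported per prediction: s0 from the kernel generate
    # head, s1-s3 from the three kernel update stages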
    fake_label = torch.ones_like(
        outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
    assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
    assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
    assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])