# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
                                                 KernelUpdateHead)
from .utils import to_cuda

num_stages = 3
conv_kernel_size = 1

kernel_updator_cfg = dict(
    type='KernelUpdator',
    in_channels=16,
    feat_channels=16,
    out_channels=16,
    gate_norm_act=True,
    activate_out=True,
    act_cfg=dict(type='ReLU', inplace=True),
    norm_cfg=dict(type='LN'))
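# NOTE: this updator config is shared by every KernelUpdateHead built below;
# gate_norm_act=True and activate_out=True presumably exercise the gated
# update branch of KernelUpdator.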


def test_knet_head():
    # test init function of kernel update head
    kernel_update_head = KernelUpdateHead(
        num_classes=150,
        num_ffn_fcs=2,
        num_heads=8,
        num_mask_fcs=1,
        feedforward_channels=128,
        in_channels=32,
        out_channels=32,
        dropout=0.0,
        conv_kernel_size=conv_kernel_size,
        ffn_act_cfg=dict(type='ReLU', inplace=True),
        with_ffn=True,
        feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
        kernel_init=True,
        kernel_updator_cfg=kernel_updator_cfg)
    kernel_update_head.init_weights()
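    # kernel_init=True exercises the extra kernel-initialization branch of
    # init_weights(); the stage heads below leave it disabled.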

    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)
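    # During training the head returns one seg prediction per stage (the
    # kernel_generate_head plus num_stages update heads), so outputs[-1]
    # is the final-stage prediction.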

    # test that only the prediction of the last stage
    # is returned during testing
    with torch.no_grad():
        head.eval()
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 27, 16)
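        # in eval mode the head returns a single tensor rather than the
        # per-stage list, which is what inference-time callers expect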

    # test K-Net without `feat_transform_cfg`
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=None,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)
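    # without `feat_transform_cfg` the update heads consume the
    # kernel_generate_head feature map directly, so its `channels` (32)
    # presumably has to match the update heads' `in_channels`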

    # test K-Net with
    # self.mask_transform_stride == 2 and self.feat_gather_stride == 1
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                mask_transform_stride=2,
                feat_gather_stride=1,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 26, 16)
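    # with mask_transform_stride=2 the masks are interpolated down by 2x and
    # back up, so the odd input height 27 rounds to floor(27 / 2) * 2 = 26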

    # test loss function in K-Net
    fake_label = torch.ones_like(
        outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
    assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
    assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
    assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])
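    # four CE terms are expected: s0 presumably from the kernel_generate_head
    # prediction and s1-s3 from the three kernel update stages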