mmsegmentation/tests/test_models/test_heads/test_decode_head.py

166 lines
6.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import pytest
import torch
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .utils import to_cuda
@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():
with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead([32, 16], 16, num_classes=19)
with pytest.raises(AssertionError):
# default input_transform doesn't accept multiple inputs
BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])
with pytest.raises(AssertionError):
# supported mode is resize_concat only
BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')
with pytest.raises(AssertionError):
# in_channels should be list|tuple
BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')
with pytest.raises(AssertionError):
# in_index should be list|tuple
BaseDecodeHead([32],
16,
in_index=-1,
num_classes=19,
input_transform='resize_concat')
with pytest.raises(AssertionError):
# len(in_index) should equal len(in_channels)
BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[-1],
input_transform='resize_concat')
# test default dropout
head = BaseDecodeHead(32, 16, num_classes=19)
assert hasattr(head, 'dropout') and head.dropout.p == 0.1
# test set dropout
head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
assert hasattr(head, 'dropout') and head.dropout.p == 0.2
# test no input_transform
inputs = [torch.randn(1, 32, 45, 45)]
head = BaseDecodeHead(32, 16, num_classes=19)
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 32
assert head.input_transform is None
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 32, 45, 45)
# test input_transform = resize_concat
inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
head = BaseDecodeHead([32, 16],
16,
num_classes=19,
in_index=[0, 1],
input_transform='resize_concat')
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
assert head.in_channels == 48
assert head.input_transform == 'resize_concat'
transformed_inputs = head._transform_inputs(inputs)
assert transformed_inputs.shape == (1, 48, 45, 45)
# test multi-loss, loss_decode is dict
with pytest.raises(TypeError):
# loss_decode must be a dict or sequence of dict.
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=[
dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2')
])
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
# 'loss_decode' must be a dict or sequence of dict
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
with pytest.raises(TypeError):
BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)
# test multi-loss, loss_decode is list of dict
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
dict(type='CrossEntropyLoss', loss_name='loss_2'),
dict(type='CrossEntropyLoss', loss_name='loss_3')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_1' in loss
assert 'loss_2' in loss
assert 'loss_3' in loss
# test multi-loss, loss_decode is list of dict, names of them are identical
inputs = torch.randn(2, 19, 8, 8).float()
target = torch.ones(2, 1, 64, 64).long()
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce'),
dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss_3 = head.losses(seg_logit=inputs, seg_label=target)
head = BaseDecodeHead(
3,
16,
num_classes=19,
loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
head, target = to_cuda(head, target)
loss = head.losses(seg_logit=inputs, seg_label=target)
assert 'loss_ce' in loss
assert 'loss_ce' in loss_3
assert loss_3['loss_ce'] == 3 * loss['loss_ce']