# mmsegmentation/tests/test_models/test_heads/test_decode_head.py
#
# 194 lines
# 6.8 KiB
# Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch
import pytest
import torch
from mmengine.structures import PixelData
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.structures import SegDataSample
from .utils import to_cuda
@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():
    """Check ``BaseDecodeHead`` argument validation, transforms and losses.

    ``__abstractmethods__`` is patched to an empty set for the duration of
    the test so that ``BaseDecodeHead`` can be instantiated directly.

    The test covers, in order:

    1. constructor arguments that must be rejected,
    2. ``out_channels`` and ``dropout`` attributes after construction,
    3. ``_transform_inputs`` with the default transform and with
       ``'resize_concat'``,
    4. ``loss_by_feat`` with ``loss_decode`` given as a dict, a list of
       dicts, a tuple of dicts, and several losses sharing one name.
    """
    # ---- constructor arguments that must be rejected ---------------------
    with pytest.raises(AssertionError):
        # default input_transform doesn't accept multiple inputs
        BaseDecodeHead([32, 16], 16, num_classes=19)

    with pytest.raises(AssertionError):
        # default input_transform doesn't accept multiple inputs
        BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])

    with pytest.raises(AssertionError):
        # 'concat' is not an accepted input_transform mode
        BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')

    with pytest.raises(AssertionError):
        # in_channels should be list|tuple
        BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')

    with pytest.raises(AssertionError):
        # in_index should be list|tuple
        BaseDecodeHead([32],
                       16,
                       in_index=-1,
                       num_classes=19,
                       input_transform='resize_concat')

    with pytest.raises(AssertionError):
        # len(in_index) should equal len(in_channels)
        BaseDecodeHead([32, 16],
                       16,
                       num_classes=19,
                       in_index=[-1],
                       input_transform='resize_concat')

    with pytest.raises(ValueError):
        # out_channels must match num_classes; the one accepted exception
        # (num_classes == 2 with out_channels == 1) is checked below
        BaseDecodeHead(32, 16, num_classes=19, out_channels=18)

    # 'loss_decode' must be a dict or sequence of dict
    with pytest.raises(TypeError):
        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
    with pytest.raises(TypeError):
        BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)

    # ---- attributes after construction -----------------------------------
    # test out_channels
    head = BaseDecodeHead(32, 16, num_classes=2)
    assert head.out_channels == 2

    # test out_channels == 1 and num_classes == 2
    head = BaseDecodeHead(32, 16, num_classes=2, out_channels=1)
    assert head.out_channels == 1 and head.num_classes == 2

    # test default dropout
    head = BaseDecodeHead(32, 16, num_classes=19)
    assert hasattr(head, 'dropout') and head.dropout.p == 0.1

    # test set dropout
    head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
    assert hasattr(head, 'dropout') and head.dropout.p == 0.2

    # ---- input transforms ------------------------------------------------
    # test no input_transform
    inputs = [torch.randn(1, 32, 45, 45)]
    head = BaseDecodeHead(32, 16, num_classes=19)
    head, inputs = _to_cuda_if_available(head, inputs)
    assert head.in_channels == 32
    assert head.input_transform is None
    transformed_inputs = head._transform_inputs(inputs)
    assert transformed_inputs.shape == (1, 32, 45, 45)

    # test input_transform = resize_concat
    inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
    head = BaseDecodeHead([32, 16],
                          16,
                          num_classes=19,
                          in_index=[0, 1],
                          input_transform='resize_concat')
    head, inputs = _to_cuda_if_available(head, inputs)
    assert head.in_channels == 48
    assert head.input_transform == 'resize_concat'
    transformed_inputs = head._transform_inputs(inputs)
    assert transformed_inputs.shape == (1, 48, 45, 45)

    # ---- losses ----------------------------------------------------------
    # test single loss, loss_decode is a dict
    inputs, data_samples = _make_loss_inputs()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    head, inputs = _to_cuda_if_available(head, inputs)
    loss = head.loss_by_feat(
        seg_logits=inputs, batch_data_samples=data_samples)
    assert 'loss_ce' in loss

    # test multi-loss, loss_decode is a list of dict
    inputs, data_samples = _make_loss_inputs()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_1'),
            dict(type='CrossEntropyLoss', loss_name='loss_2')
        ])
    head, inputs = _to_cuda_if_available(head, inputs)
    loss = head.loss_by_feat(
        seg_logits=inputs, batch_data_samples=data_samples)
    assert 'loss_1' in loss
    assert 'loss_2' in loss

    # test multi-loss, loss_decode is a tuple of dict
    inputs, data_samples = _make_loss_inputs()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
                     dict(type='CrossEntropyLoss', loss_name='loss_2'),
                     dict(type='CrossEntropyLoss', loss_name='loss_3')))
    head, inputs = _to_cuda_if_available(head, inputs)
    loss = head.loss_by_feat(
        seg_logits=inputs, batch_data_samples=data_samples)
    assert 'loss_1' in loss
    assert 'loss_2' in loss
    assert 'loss_3' in loss

    # test multi-loss, loss_decode is a tuple of dict whose names are
    # identical: the three entries must be reported under one key
    inputs, data_samples = _make_loss_inputs()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
                     dict(type='CrossEntropyLoss', loss_name='loss_ce'),
                     dict(type='CrossEntropyLoss', loss_name='loss_ce')))
    head, inputs = _to_cuda_if_available(head, inputs)
    loss_3 = head.loss_by_feat(
        seg_logits=inputs, batch_data_samples=data_samples)

    # reference: one loss with the same name, evaluated on the very same
    # inputs and data samples so the two results are directly comparable
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=dict(type='CrossEntropyLoss', loss_name='loss_ce'))
    head, inputs = _to_cuda_if_available(head, inputs)
    loss = head.loss_by_feat(
        seg_logits=inputs, batch_data_samples=data_samples)
    assert 'loss_ce' in loss
    assert 'loss_ce' in loss_3
    # three identically named losses should add up to 3x the single loss
    assert loss_3['loss_ce'] == 3 * loss['loss_ce']


def _to_cuda_if_available(head, inputs):
    """Pass ``head`` and ``inputs`` through ``to_cuda`` when CUDA exists.

    Args:
        head: The decode head under test.
        inputs: The tensor, or list of tensors, fed to the head.

    Returns:
        tuple: ``to_cuda(head, inputs)`` if ``torch.cuda.is_available()``,
        otherwise the unchanged ``(head, inputs)`` pair.
    """
    if torch.cuda.is_available():
        return to_cuda(head, inputs)
    return head, inputs


def _make_loss_inputs(batch_size=2, num_classes=19):
    """Build random logits and matching data samples for the loss checks.

    Args:
        batch_size (int): Number of samples in the batch. Defaults to 2.
        num_classes (int): Channel count of the logits. Defaults to 19.

    Returns:
        tuple: ``(seg_logits, data_samples)`` where ``seg_logits`` is a
        random float tensor of shape ``(batch_size, num_classes, 8, 8)``
        and ``data_samples`` is a list of ``batch_size`` ``SegDataSample``
        objects whose ``gt_sem_seg`` is a 64x64 map of ones (``long``).
    """
    seg_logits = torch.randn(batch_size, num_classes, 8, 8).float()
    data_samples = [
        SegDataSample(gt_sem_seg=PixelData(data=torch.ones(64, 64).long()))
        for _ in range(batch_size)
    ]
    return seg_logits, data_samples