mmsegmentation/tests/test_models/test_heads/test_dpt_head.py

49 lines
1.3 KiB
Python

import pytest
import torch
from mmseg.models.decode_heads import DPTHead
def test_dpt_head():
with pytest.raises(AssertionError):
# input_transform must be 'multiple_select'
head = DPTHead(
in_channels=[768, 768, 768, 768],
channels=4,
num_classes=19,
in_index=[0, 1, 2, 3])
head = DPTHead(
in_channels=[768, 768, 768, 768],
channels=4,
num_classes=19,
in_index=[0, 1, 2, 3],
input_transform='multiple_select')
inputs = [[torch.randn(4, 768, 2, 2),
torch.randn(4, 768)] for _ in range(4)]
output = head(inputs)
assert output.shape == torch.Size((4, 19, 16, 16))
# test readout operation
head = DPTHead(
in_channels=[768, 768, 768, 768],
channels=4,
num_classes=19,
in_index=[0, 1, 2, 3],
input_transform='multiple_select',
readout_type='add')
output = head(inputs)
assert output.shape == torch.Size((4, 19, 16, 16))
head = DPTHead(
in_channels=[768, 768, 768, 768],
channels=4,
num_classes=19,
in_index=[0, 1, 2, 3],
input_transform='multiple_select',
readout_type='project')
output = head(inputs)
assert output.shape == torch.Size((4, 19, 16, 16))