Mirror of https://github.com/open-mmlab/mmsegmentation.git
Synced 2025-06-03 22:03:48 +08:00
49 lines · 1.3 KiB · Python
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from mmseg.models.decode_heads import DPTHead
|
||
|
|
||
|
|
||
|
def test_dpt_head():
    """Check ``DPTHead`` construction rules and forward output shape.

    The head must reject a config that omits
    ``input_transform='multiple_select'``. With that transform set, the head
    built without an explicit ``readout_type``, and the heads built with
    ``readout_type='add'`` and ``readout_type='project'``, must each map the
    synthetic inputs below to a ``(4, 19, 16, 16)`` tensor.
    """

    def build_head(**overrides):
        # Build fresh list objects on every call. This mirrors the per-call
        # literals of a hand-written constructor call, so no config lists
        # are shared between head instances.
        return DPTHead(
            in_channels=[768] * 4,
            channels=256,
            num_classes=19,
            in_index=list(range(4)),
            **overrides)

    expected_shape = torch.Size((4, 19, 16, 16))

    # Construction without input_transform='multiple_select' must fail.
    with pytest.raises(AssertionError):
        build_head()

    head = build_head(input_transform='multiple_select')

    # Four selected levels. Each one pairs a 4-D (4, 768, 2, 2) tensor with
    # a 2-D (4, 768) tensor. The second element is presumably the token fed
    # to the readout step; TODO confirm against DPTHead.forward.
    inputs = [[torch.randn(4, 768, 2, 2), torch.randn(4, 768)]
              for _ in range(4)]
    assert head(inputs).shape == expected_shape

    # Exercise the explicit readout operations on the same inputs.
    for readout in ('add', 'project'):
        head = build_head(
            input_transform='multiple_select', readout_type=readout)
        assert head(inputs).shape == expected_shape
|