谢昕辰 ef4b30038f [Feature] add DPT head (#605)
* add DPT head

* [fix] fix init error

* use mmcv function

* delete code

* remove transpose clas

* support NLC output shape

* Delete post_process_layer.py

* add unittest and docstring

* rename variables

* fix project error and add unittest

* match dpt weights

* add configs

* fix vit pos_embed bug and dpt feature fusion bug

* match vit output

* fix gelu

* minor change

* update unittest

* fix configs error

* inference test

* remove auxiliary

* use local pretrain

* update training results

* update yml

* update fps and memory test

* update doc

* update readme

* add yml

* update doc

* remove with_cp

* update config

* update docstring

* remove dpt-l

* add init_cfg and modify readme.md

* Update dpt_vit-b16.py

* zh-n README

* use constructor instead of build function

* prevent tensor being modified by ConvModule

* fix unittest

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
2021-08-30 16:53:05 +08:00

import pytest
import torch

from mmseg.models.decode_heads import DPTHead


def test_dpt_head():

    with pytest.raises(AssertionError):
        # input_transform must be 'multiple_select'
        head = DPTHead(
            in_channels=[768, 768, 768, 768],
            channels=256,
            num_classes=19,
            in_index=[0, 1, 2, 3])

    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=256,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select')

    # One [feature_map (B, C, H, W), cls_token (B, C)] pair per selected layer
    inputs = [[torch.randn(4, 768, 2, 2),
               torch.randn(4, 768)] for _ in range(4)]
    output = head(inputs)
    # Expected output layout: (batch, num_classes, H_out, W_out)
    assert output.shape == torch.Size((4, 19, 16, 16))

    # test readout operation
    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=256,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select',
        readout_type='add')
    output = head(inputs)
    assert output.shape == torch.Size((4, 19, 16, 16))

    head = DPTHead(
        in_channels=[768, 768, 768, 768],
        channels=256,
        num_classes=19,
        in_index=[0, 1, 2, 3],
        input_transform='multiple_select',
        readout_type='project')
    output = head(inputs)
    assert output.shape == torch.Size((4, 19, 16, 16))
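The readout test above only checks output shapes. To make clearer what `readout_type='add'` and `readout_type='project'` do with each `[feature_map, cls_token]` pair, here is a minimal plain-PyTorch sketch of the three readout strategies described in the DPT paper (ignore, add, project). It is an illustration under stated assumptions, not mmseg's implementation: `apply_readout` is a hypothetical helper, and the `nn.Linear` + `nn.GELU` projection is assumed from the paper's description of the project readout.

```python
import torch
import torch.nn as nn


def apply_readout(x, cls_token, readout_type, project=None):
    """Hypothetical helper (not mmseg API): fold the class token into a
    patch-token map.

    x: patch tokens reshaped to a map, shape (B, C, H, W).
    cls_token: class (readout) token, shape (B, C).
    readout_type: 'ignore', 'add' or 'project'.
    project: module mapping 2*C -> C features, used only for 'project'
        (assumed here to be Linear + GELU, as in the DPT paper).
    """
    b, c, h, w = x.shape
    if readout_type == 'ignore':
        # Drop the class token; patch tokens pass through unchanged.
        return x
    if readout_type == 'add':
        # Broadcast-add the class token to every spatial position.
        return x + cls_token[:, :, None, None]
    if readout_type == 'project':
        tokens = x.flatten(2).transpose(1, 2)                # (B, H*W, C)
        readout = cls_token.unsqueeze(1).expand_as(tokens)   # (B, H*W, C)
        tokens = project(torch.cat((tokens, readout), dim=-1))  # (B, H*W, C)
        return tokens.transpose(1, 2).reshape(b, c, h, w)
    raise ValueError(f'unknown readout_type: {readout_type}')


# Same nested input layout as the test: [feature_map, cls_token] per layer.
inputs = [[torch.randn(4, 768, 2, 2), torch.randn(4, 768)] for _ in range(4)]
project = nn.Sequential(nn.Linear(2 * 768, 768), nn.GELU())
for readout_type in ('ignore', 'add', 'project'):
    outs = [apply_readout(x, cls, readout_type, project) for x, cls in inputs]
    print(readout_type, [tuple(o.shape) for o in outs])
```

All three strategies keep each feature map at shape (4, 768, 2, 2), which is why the test can assert the same final output shape regardless of `readout_type`; only the values differ.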