mmsegmentation/tests/test_models/test_heads/test_ham_head.py

45 lines
1.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import LightHamHead
from .utils import _conv_has_norm, to_cuda
# Shared GroupNorm config passed as ``norm_cfg`` to the head under test.
ham_norm_cfg = {'type': 'GN', 'num_groups': 32, 'requires_grad': True}
def test_ham_head():
    """Smoke-test ``LightHamHead`` construction and one forward pass.

    Builds the head with the module-level ``ham_norm_cfg`` and a small
    matrix-decomposition config (``ham_kwargs``). It then feeds a
    four-level feature pyramid of which ``in_index`` selects the last three
    levels, and checks the stored channel configuration and the shape of
    the returned logits. The head and inputs are moved to GPU when CUDA is
    available; otherwise everything runs on CPU.
    """
    # Kept in locals so the constructor arguments and the assertions below
    # cannot drift apart.
    in_channels = [16, 32, 64]
    ham_channels = 64

    head = LightHamHead(
        in_channels=in_channels,
        in_index=[1, 2, 3],
        channels=64,
        ham_channels=ham_channels,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=ham_norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        ham_kwargs=dict(
            MD_S=1,
            MD_R=64,
            train_steps=6,
            eval_steps=7,
            inv_t=100,
            rand_init=True))
    # Expect at least one conv in the head that is NOT fully normalized
    # under a non-SyncBN check. Presumably this is the
    # ``hamburger.ham_in`` projection, built without a norm layer --
    # NOTE(review): confirm against ``_conv_has_norm`` in ``.utils``.
    assert not _conv_has_norm(head, sync_bn=False)

    # Level 0 (8 channels) is not selected by ``in_index``. Levels 1-3
    # match ``in_channels`` one-to-one.
    inputs = [
        torch.randn(1, 8, 32, 32),
        torch.randn(1, 16, 16, 16),
        torch.randn(1, 32, 8, 8),
        torch.randn(1, 64, 4, 4)
    ]
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.in_channels == in_channels
    assert head.hamburger.ham_in.in_channels == ham_channels
    outputs = head(inputs)
    # 16x16 is the spatial size of inputs[1], the first level selected by
    # ``in_index``.
    assert outputs.shape == (1, head.num_classes, 16, 16)