32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
from mmseg.models.decode_heads import STDCHead
|
|
from .utils import to_cuda
|
|
|
|
|
|
def test_stdc_head():
    """Smoke-test ``STDCHead``'s forward pass and loss computation.

    Builds a small head configured with two decode losses (cross-entropy
    named ``loss_ce`` and Dice named ``loss_dice``), runs it on a single
    random 32-channel 21x21 feature map, checks the shape of the returned
    logits, and checks that both named losses are non-zero for an all-ones
    label map.
    """
    inputs = [torch.randn(1, 32, 21, 21)]
    head = STDCHead(
        in_channels=32,
        channels=8,
        num_convs=1,
        num_classes=2,
        in_index=-1,
        loss_decode=[
            dict(
                type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
            dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
        ])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)

    outputs = head(inputs)
    assert isinstance(outputs, torch.Tensor)
    # Batch of 1, one channel per class, same spatial size as the input.
    assert outputs.shape == torch.Size([1, head.num_classes, 21, 21])

    # Create the (N, 1, H, W) label directly as int64. Deriving it from the
    # logits keeps it on the same device when the head was moved to CUDA.
    fake_label = torch.ones_like(outputs[:, 0:1, :, :], dtype=torch.long)
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    # Every loss configured through ``loss_name`` must be reported and
    # must be non-zero.
    assert loss['loss_ce'].item() != 0
    assert loss['loss_dice'].item() != 0
|