MengzhangLI 5783bc1d99
[Feature] Support STDC Network (new) (#995)
* refactor stdc code

* update key

* fix backbone inference

* remove comments

* fixing errors

* fixing version conflict

* fix typo

* use STDCHead

* upload models&logs

* adding model converters script and fix unittest

* fix error

* fix error

* fix error

* delete redundant keys in config

* fix errors in configs and unittest

* fix errors in configs and unittest

* fix errors in configs and unittest

* change Memory name

* refactor stdc2mmseg

* change name to STDC

* refactor stdc

* refactor stdc

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* stdc refactor

* refactor stdc

* stdc refactor

Co-authored-by: xiexinch <xinchen.xie@qq.com>
2021-12-10 23:09:32 +08:00

32 lines
1.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import STDCHead
from .utils import to_cuda
def test_stdc_head():
    """Smoke-test ``STDCHead``: forward output shape and configured losses.

    Builds a small head (32 input channels, 8 channels, 1 conv, 2 classes)
    with cross-entropy and Dice losses and runs one 1x32x21x21 feature map
    through it (on GPU when available). It then checks the logits shape,
    computes the losses against an all-ones label map, and asserts that
    neither loss is zero.
    """
    # Create the random input before building the head so the RNG is
    # consumed in the same order as before (input first, then weights).
    feats = [torch.randn(1, 32, 21, 21)]

    loss_cfgs = [
        dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
        dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0),
    ]
    head = STDCHead(
        in_channels=32,
        channels=8,
        num_convs=1,
        num_classes=2,
        in_index=-1,
        loss_decode=loss_cfgs)

    if torch.cuda.is_available():
        # to_cuda comes from the local test utils; presumably it moves both
        # the module and its inputs to the GPU (TODO confirm).
        head, feats = to_cuda(head, feats)

    seg_logits = head(feats)
    assert isinstance(seg_logits, torch.Tensor)
    assert len(seg_logits) == 1  # leading (batch) dimension
    # Logits keep the input's 21x21 spatial size, one channel per class.
    assert seg_logits.shape == torch.Size([1, head.num_classes, 21, 21])

    # All-ones integer label map shaped like a single logit channel.
    fake_label = torch.ones_like(seg_logits[:, :1], dtype=torch.long)
    losses = head.losses(seg_logit=seg_logits, seg_label=fake_label)

    # Each configured loss must be present and non-zero.
    for loss_name in ('loss_ce', 'loss_dice'):
        loss_value = losses[loss_name]
        assert loss_value != torch.zeros_like(loss_value)