mmsegmentation/tests/test_models/test_heads/test_pidnet_head.py

90 lines
2.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.registry import init_default_scope
from mmseg.registry import MODELS
def test_pidnet_head():
init_default_scope('mmseg')
# Test PIDNet decode head Standard Forward
norm_cfg = dict(type='BN', requires_grad=True)
backbone_cfg = dict(
type='PIDNet',
in_channels=3,
channels=32,
ppm_channels=96,
num_stem_blocks=2,
num_branch_blocks=3,
align_corners=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True))
decode_head_cfg = dict(
type='PIDHead',
in_channels=128,
channels=128,
num_classes=19,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True),
align_corners=True,
loss_decode=[
dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=0.4),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=1.0),
dict(type='BoundaryLoss', loss_weight=20.0),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=1.0)
])
backbone = MODELS.build(backbone_cfg)
head = MODELS.build(decode_head_cfg)
# Test train mode
backbone.train()
head.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feats = backbone(imgs)
seg_logit = head(feats)
assert isinstance(seg_logit, tuple)
assert len(seg_logit) == 3
p_logits, i_logits, d_logits = seg_logit
assert p_logits.shape == (batch_size, 19, 8, 16)
assert i_logits.shape == (batch_size, 19, 8, 16)
assert d_logits.shape == (batch_size, 1, 8, 16)
# Test eval mode
backbone.eval()
head.eval()
feats = backbone(imgs)
seg_logit = head(feats)
assert isinstance(seg_logit, torch.Tensor)
assert seg_logit.shape == (batch_size, 19, 8, 16)