mmsegmentation/tests/test_models/test_backbones/test_pidnet.py

88 lines
2.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import torch
from mmengine.registry import init_default_scope
from mmseg.registry import MODELS
init_default_scope('mmseg')
def test_pidnet_backbone():
# Test PIDNet Standard Forward
norm_cfg = dict(type='BN', requires_grad=True)
backbone_cfg = dict(
type='PIDNet',
in_channels=3,
channels=32,
ppm_channels=96,
num_stem_blocks=2,
num_branch_blocks=3,
align_corners=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True))
model = MODELS.build(backbone_cfg)
model.init_weights()
# Test init weights
temp_file = tempfile.NamedTemporaryFile()
temp_file.close()
torch.save(model.state_dict(), temp_file.name)
backbone_cfg.update(
init_cfg=dict(type='Pretrained', checkpoint=temp_file.name))
model = MODELS.build(backbone_cfg)
model.init_weights()
os.remove(temp_file.name)
# Test eval mode
model.eval()
batch_size = 1
imgs = torch.randn(batch_size, 3, 64, 128)
feats = model(imgs)
assert type(feats) == torch.Tensor
assert feats.shape == torch.Size([batch_size, 128, 8, 16])
# Test train mode
model.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 64, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 128, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 64, 8, 16])
# Test pidnet-m
backbone_cfg.update(channels=64)
model = MODELS.build(backbone_cfg)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
# Test pidnet-l
backbone_cfg.update(
channels=64, ppm_channesl=112, num_stem_blocks=3, num_branch_blocks=4)
model = MODELS.build(backbone_cfg)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])