88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import tempfile
|
|
|
|
import torch
|
|
from mmengine.registry import init_default_scope
|
|
|
|
from mmseg.registry import MODELS
|
|
|
|
init_default_scope('mmseg')
|
|
|
|
|
|
def test_pidnet_backbone():
|
|
# Test PIDNet Standard Forward
|
|
norm_cfg = dict(type='BN', requires_grad=True)
|
|
backbone_cfg = dict(
|
|
type='PIDNet',
|
|
in_channels=3,
|
|
channels=32,
|
|
ppm_channels=96,
|
|
num_stem_blocks=2,
|
|
num_branch_blocks=3,
|
|
align_corners=False,
|
|
norm_cfg=norm_cfg,
|
|
act_cfg=dict(type='ReLU', inplace=True))
|
|
model = MODELS.build(backbone_cfg)
|
|
model.init_weights()
|
|
|
|
# Test init weights
|
|
temp_file = tempfile.NamedTemporaryFile()
|
|
temp_file.close()
|
|
torch.save(model.state_dict(), temp_file.name)
|
|
backbone_cfg.update(
|
|
init_cfg=dict(type='Pretrained', checkpoint=temp_file.name))
|
|
model = MODELS.build(backbone_cfg)
|
|
model.init_weights()
|
|
os.remove(temp_file.name)
|
|
|
|
# Test eval mode
|
|
model.eval()
|
|
batch_size = 1
|
|
imgs = torch.randn(batch_size, 3, 64, 128)
|
|
feats = model(imgs)
|
|
|
|
assert type(feats) == torch.Tensor
|
|
assert feats.shape == torch.Size([batch_size, 128, 8, 16])
|
|
|
|
# Test train mode
|
|
model.train()
|
|
batch_size = 2
|
|
imgs = torch.randn(batch_size, 3, 64, 128)
|
|
feats = model(imgs)
|
|
|
|
assert len(feats) == 3
|
|
# test output for P branch
|
|
assert feats[0].shape == torch.Size([batch_size, 64, 8, 16])
|
|
# test output for I branch
|
|
assert feats[1].shape == torch.Size([batch_size, 128, 8, 16])
|
|
# test output for D branch
|
|
assert feats[2].shape == torch.Size([batch_size, 64, 8, 16])
|
|
|
|
# Test pidnet-m
|
|
backbone_cfg.update(channels=64)
|
|
model = MODELS.build(backbone_cfg)
|
|
feats = model(imgs)
|
|
|
|
assert len(feats) == 3
|
|
# test output for P branch
|
|
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
|
|
# test output for I branch
|
|
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
|
|
# test output for D branch
|
|
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
|
|
|
|
# Test pidnet-l
|
|
backbone_cfg.update(
|
|
channels=64, ppm_channesl=112, num_stem_blocks=3, num_branch_blocks=4)
|
|
model = MODELS.build(backbone_cfg)
|
|
feats = model(imgs)
|
|
|
|
assert len(feats) == 3
|
|
# test output for P branch
|
|
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
|
|
# test output for I branch
|
|
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
|
|
# test output for D branch
|
|
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
|