43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.backbones import FastSCNN
|
|
|
|
|
|
def test_fastscnn_backbone():
|
|
with pytest.raises(AssertionError):
|
|
# Fast-SCNN channel constraints.
|
|
FastSCNN(
|
|
3, (32, 48),
|
|
64, (64, 96, 128), (2, 2, 1),
|
|
global_out_channels=127,
|
|
higher_in_channels=64,
|
|
lower_in_channels=128)
|
|
|
|
# Test FastSCNN Standard Forward
|
|
model = FastSCNN(
|
|
in_channels=3,
|
|
downsample_dw_channels=(4, 6),
|
|
global_in_channels=8,
|
|
global_block_channels=(8, 12, 16),
|
|
global_block_strides=(2, 2, 1),
|
|
global_out_channels=16,
|
|
higher_in_channels=8,
|
|
lower_in_channels=16,
|
|
fusion_out_channels=16,
|
|
)
|
|
model.init_weights()
|
|
model.train()
|
|
batch_size = 4
|
|
imgs = torch.randn(batch_size, 3, 64, 128)
|
|
feat = model(imgs)
|
|
|
|
assert len(feat) == 3
|
|
# higher-res
|
|
assert feat[0].shape == torch.Size([batch_size, 8, 8, 16])
|
|
# lower-res
|
|
assert feat[1].shape == torch.Size([batch_size, 16, 2, 4])
|
|
# FFM output
|
|
assert feat[2].shape == torch.Size([batch_size, 16, 8, 16])
|