# mmsegmentation/tests/test_models/test_heads/test_isa_head.py
# (21 lines, 525 B, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.decode_heads import ISAHead
from .utils import to_cuda
def test_isa_head():
    """Check that ``ISAHead`` produces a per-class map at the input resolution.

    The input feature map is 45x45, which is not a multiple of
    ``down_factor=(8, 8)``. The test therefore presumably also exercises the
    head's handling of non-divisible spatial sizes (TODO confirm against
    ``ISAHead``'s implementation).
    """
    inputs = [torch.randn(1, 32, 45, 45)]
    isa_head = ISAHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        isa_channels=16,
        down_factor=(8, 8))
    # Exercise the GPU path when a device is present; ``to_cuda`` (from
    # ``.utils``) presumably moves both the module and the inputs over.
    if torch.cuda.is_available():
        isa_head, inputs = to_cuda(isa_head, inputs)
    output = isa_head(inputs)
    # Expect batch 1, one channel per class, and the original 45x45 size.
    assert output.shape == (1, isa_head.num_classes, 45, 45)