mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
|
|
from mmselfsup.models.utils import MultiPooling
|
|
|
|
|
|
def test_multi_pooling():
|
|
# adaptive
|
|
layer = MultiPooling(pool_type='adaptive', in_indices=(0, 1, 2))
|
|
fake_in = [
|
|
torch.rand((1, 32, 112, 112)),
|
|
torch.rand((1, 64, 56, 56)),
|
|
torch.rand((1, 128, 28, 28)),
|
|
]
|
|
res = layer.forward(fake_in)
|
|
assert res[0].shape == (1, 32, 12, 12)
|
|
assert res[1].shape == (1, 64, 6, 6)
|
|
assert res[2].shape == (1, 128, 4, 4)
|
|
|
|
# specified
|
|
layer = MultiPooling(pool_type='specified', in_indices=(0, 1, 2))
|
|
fake_in = [
|
|
torch.rand((1, 32, 112, 112)),
|
|
torch.rand((1, 64, 56, 56)),
|
|
torch.rand((1, 128, 28, 28)),
|
|
]
|
|
res = layer.forward(fake_in)
|
|
assert res[0].shape == (1, 32, 12, 12)
|
|
assert res[1].shape == (1, 64, 6, 6)
|
|
assert res[2].shape == (1, 128, 4, 4)
|
|
|
|
with pytest.raises(AssertionError):
|
|
layer = MultiPooling(pool_type='other')
|
|
|
|
with pytest.raises(AssertionError):
|
|
layer = MultiPooling(backbone='resnet101')
|