mmselfsup/tests/test_models/test_necks/test_linear_neck.py

24 lines
702 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmselfsup.models.necks import LinearNeck
def test_linear_neck():
    """Check LinearNeck's ``fc`` layer sizes and output shapes.

    Covers both configurations: ``with_avg_pool=True`` fed a 4D feature
    map, and ``with_avg_pool=False`` fed an already-flat 2D input. In both
    cases the neck is given a one-element list and the first element of
    its output is expected to be ``(batch_size, out_channels)``.
    """
    in_channels, out_channels = 16, 32
    # Batch size deliberately differs from ``out_channels`` so the shape
    # assertions can distinguish the batch dimension from the feature one.
    batch_size = 4

    # test neck with avgpool
    neck = LinearNeck(in_channels, out_channels, with_avg_pool=True)
    assert isinstance(neck.avgpool, nn.Module)
    assert neck.fc.in_features == in_channels
    assert neck.fc.out_features == out_channels
    fake_in = torch.rand((batch_size, in_channels, 5, 5))
    # Call the module itself (not ``.forward``) so registered hooks run,
    # as PyTorch recommends.
    fake_out = neck([fake_in])
    assert fake_out[0].shape == torch.Size([batch_size, out_channels])

    # test neck without avgpool
    neck = LinearNeck(in_channels, out_channels, with_avg_pool=False)
    assert neck.fc.in_features == in_channels
    assert neck.fc.out_features == out_channels
    fake_in = torch.rand((batch_size, in_channels))
    fake_out = neck([fake_in])
    assert fake_out[0].shape == torch.Size([batch_size, out_channels])