17 lines
539 B
Python
17 lines
539 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmselfsup.models.necks import SwAVNeck
|
|
|
|
|
|
def test_swav_neck():
    """Check that ``SwAVNeck`` builds a projection head and handles crops.

    Three crops are fed through the neck as a list of single-tensor lists.
    Each crop has batch size 32 and 16 channels; two are 5x5 feature maps
    and one is 3x3. The first output must contain one row per
    (crop, sample) pair, i.e. ``batch_size * len(fake_in)`` rows, each of
    width ``out_channels``.
    """
    in_channels, hid_channels, out_channels = 16, 32, 16
    batch_size = 32

    neck = SwAVNeck(
        in_channels, hid_channels, out_channels, norm_cfg=dict(type='BN1d'))
    # nn.Sequential is itself a subclass of nn.Module, so checking against
    # nn.Module alone accepts everything the original tuple accepted.
    assert isinstance(neck.projection_neck, nn.Module)

    # test neck with avgpool: the crops deliberately have different spatial
    # sizes (5x5 and 3x3), yet they must still yield same-width outputs.
    fake_in = [[torch.rand((batch_size, in_channels, 5, 5))],
               [torch.rand((batch_size, in_channels, 5, 5))],
               [torch.rand((batch_size, in_channels, 3, 3))]]
    # Invoke the module via __call__ rather than .forward so that any
    # registered forward hooks run, as they would in real training code.
    fake_out = neck(fake_in)
    assert fake_out[0].shape == torch.Size(
        [batch_size * len(fake_in), out_channels])