# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmselfsup.models.necks import DenseCLNeck


def test_densecl_neck():
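    """Test DenseCLNeck output shapes with and without num_grid."""
    # Both the global head (mlp, linear layers) and the dense head
    # (mlp2, conv layers) should be built as 16 -> 32 -> 16.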
    neck = DenseCLNeck(16, 32, 16)
    assert isinstance(neck.mlp, nn.Sequential)
    assert isinstance(neck.mlp2, nn.Sequential)
    assert neck.mlp[0].in_features == 16
    assert neck.mlp[2].in_features == 32
    assert neck.mlp[2].out_features == 16
    assert neck.mlp2[0].in_channels == 16
    assert neck.mlp2[2].in_channels == 32
    assert neck.mlp2[2].out_channels == 16

    # test neck when num_grid is None: the dense branch keeps all
    # 5x5 = 25 spatial positions
    fake_in = torch.rand((32, 16, 5, 5))
    fake_out = neck.forward([fake_in])
    assert fake_out[0].shape == torch.Size([32, 16])
    assert fake_out[1].shape == torch.Size([32, 16, 25])
    assert fake_out[2].shape == torch.Size([32, 16])

    # test neck when num_grid is not None: the feature map is pooled
    # to a 3x3 grid, so the dense branch yields 9 positions
    neck = DenseCLNeck(16, 32, 16, num_grid=3)
    fake_in = torch.rand((32, 16, 5, 5))
    fake_out = neck.forward([fake_in])
    assert fake_out[0].shape == torch.Size([32, 16])
    assert fake_out[1].shape == torch.Size([32, 16, 9])
    assert fake_out[2].shape == torch.Size([32, 16])