[Refactor] clean ut
parent
40b58076f7
commit
b414011530
|
@ -1,19 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.backbones import MAEViT
|
||||
|
||||
backbone = dict(arch='b', patch_size=16, mask_ratio=0.75)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_mae_pretrain_vit():
|
||||
mae_pretrain_backbone = MAEViT(**backbone)
|
||||
mae_pretrain_backbone.init_weights()
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_outputs = mae_pretrain_backbone(fake_inputs)[0]
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 50, 768]
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.backbones import MIMVisionTransformer
|
||||
|
||||
finetune_backbone = dict(
|
||||
arch='b', patch_size=16, drop_path_rate=0.1, final_norm=False)
|
||||
|
||||
finetune_backbone_norm = dict(
|
||||
arch='b', patch_size=16, drop_path_rate=0.1, final_norm=True)
|
||||
|
||||
linprobe_backbone = dict(
|
||||
arch='b', patch_size=16, finetune=False, final_norm=False)
|
||||
|
||||
linprobe_backbone_use_window = dict(
|
||||
arch='b', patch_size=16, finetune=False, final_norm=False, use_window=True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_mae_cls_vit():
|
||||
mae_finetune_backbone = MIMVisionTransformer(**finetune_backbone)
|
||||
mae_finetune_backbone_norm = MIMVisionTransformer(**finetune_backbone_norm)
|
||||
mae_linprobe_backbone = MIMVisionTransformer(**linprobe_backbone)
|
||||
mae_linprobe_backbone_use_window = MIMVisionTransformer(
|
||||
**linprobe_backbone_use_window)
|
||||
mae_linprobe_backbone.train()
|
||||
|
||||
fake_inputs = torch.randn((2, 3, 224, 224))
|
||||
fake_finetune_outputs = mae_finetune_backbone(fake_inputs)
|
||||
fake_finetune_outputs_norm = mae_finetune_backbone_norm(fake_inputs)
|
||||
fake_linprobe_outputs = mae_linprobe_backbone(fake_inputs)
|
||||
fake_linprobe_outputs_use_window = mae_linprobe_backbone_use_window(
|
||||
fake_inputs)
|
||||
assert list(fake_finetune_outputs.shape) == [2, 768]
|
||||
assert list(fake_linprobe_outputs.shape) == [2, 768]
|
||||
assert list(fake_finetune_outputs_norm.shape) == [2, 768]
|
||||
assert list(fake_linprobe_outputs_use_window.shape) == [2, 768]
|
|
@ -1,139 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmselfsup.models.backbones import ResNet
|
||||
from mmselfsup.models.backbones.resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
def is_block(modules):
|
||||
"""Check if is ResNet building block."""
|
||||
if isinstance(modules, (BasicBlock, Bottleneck)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def all_zeros(modules):
|
||||
"""Check if the weight(and bias) is all zero."""
|
||||
weight_zero = torch.equal(modules.weight.data,
|
||||
torch.zeros_like(modules.weight.data))
|
||||
if hasattr(modules, 'bias'):
|
||||
bias_zero = torch.equal(modules.bias.data,
|
||||
torch.zeros_like(modules.bias.data))
|
||||
else:
|
||||
bias_zero = True
|
||||
|
||||
return weight_zero and bias_zero
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_resnet():
|
||||
"""Test resnet backbone."""
|
||||
# Test ResNet50 norm_eval=True
|
||||
model = ResNet(50, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet50 with torchvision pretrained weight
|
||||
model = ResNet(depth=50, norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ResNet50 with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ResNet(50, frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert model.norm1.training is False
|
||||
for layer in [model.conv1, model.norm1]:
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(1, frozen_stages + 1):
|
||||
layer = getattr(model, f'layer{i}')
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ResNet18 forward
|
||||
model = ResNet(18, out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 64, 56, 56)
|
||||
assert feat[2].shape == (1, 128, 28, 28)
|
||||
assert feat[3].shape == (1, 256, 14, 14)
|
||||
assert feat[4].shape == (1, 512, 7, 7)
|
||||
|
||||
# Test ResNet50 with BatchNorm forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2, 3, 4))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 256, 56, 56)
|
||||
assert feat[2].shape == (1, 512, 28, 28)
|
||||
assert feat[3].shape == (1, 1024, 14, 14)
|
||||
assert feat[4].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with layers 3 (top feature maps) out forward
|
||||
model = ResNet(50, out_indices=(4, ))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert feat[0].shape == (1, 2048, 7, 7)
|
||||
|
||||
# Test ResNet50 with checkpoint forward
|
||||
model = ResNet(50, out_indices=(0, 1, 2, 3, 4), with_cp=True)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.with_cp
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == (1, 64, 112, 112)
|
||||
assert feat[1].shape == (1, 256, 56, 56)
|
||||
assert feat[2].shape == (1, 512, 28, 28)
|
||||
assert feat[3].shape == (1, 1024, 14, 14)
|
||||
assert feat[4].shape == (1, 2048, 7, 7)
|
||||
|
||||
# zero initialization of residual blocks
|
||||
model = ResNet(50, zero_init_residual=True)
|
||||
model.init_weights()
|
||||
for m in model.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
assert all_zeros(m.norm3)
|
||||
elif isinstance(m, BasicBlock):
|
||||
assert all_zeros(m.norm2)
|
||||
|
||||
# non-zero initialization of residual blocks
|
||||
model = ResNet(50, zero_init_residual=False)
|
||||
model.init_weights()
|
||||
for m in model.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
assert not all_zeros(m.norm3)
|
||||
elif isinstance(m, BasicBlock):
|
||||
assert not all_zeros(m.norm2)
|
|
@ -1,43 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.backbones import ResNeXt
|
||||
from mmselfsup.models.backbones.resnext import Bottleneck as BottleneckX
|
||||
|
||||
|
||||
def test_resnext():
|
||||
with pytest.raises(KeyError):
|
||||
# ResNeXt depth should be in [50, 101, 152]
|
||||
ResNeXt(depth=18)
|
||||
|
||||
# Test ResNeXt with group 32, width_per_group 4
|
||||
model = ResNeXt(
|
||||
depth=50, groups=32, width_per_group=4, out_indices=(0, 1, 2, 3, 4))
|
||||
for m in model.modules():
|
||||
if isinstance(m, BottleneckX):
|
||||
assert m.conv2.groups == 32
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
assert feat[0].shape == torch.Size([1, 64, 112, 112])
|
||||
assert feat[1].shape == torch.Size([1, 256, 56, 56])
|
||||
assert feat[2].shape == torch.Size([1, 512, 28, 28])
|
||||
assert feat[3].shape == torch.Size([1, 1024, 14, 14])
|
||||
assert feat[4].shape == torch.Size([1, 2048, 7, 7])
|
||||
|
||||
# Test ResNeXt with group 32, width_per_group 4 and layers 3 out forward
|
||||
model = ResNeXt(depth=50, groups=32, width_per_group=4, out_indices=(4, ))
|
||||
for m in model.modules():
|
||||
if isinstance(m, BottleneckX):
|
||||
assert m.conv2.groups == 32
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 1
|
||||
assert feat[0].shape == torch.Size([1, 2048, 7, 7])
|
|
@ -1,11 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmselfsup.models.backbones import VisionTransformer
|
||||
|
||||
|
||||
def test_vision_transformer():
|
||||
vit = VisionTransformer(
|
||||
arch='mocov3-small', patch_size=16, frozen_stages=12, norm_eval=True)
|
||||
vit.train()
|
||||
|
||||
for p in vit.parameters():
|
||||
assert p.requires_grad is False
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import AvgPool2dNeck
|
||||
|
||||
|
||||
def test_avgpool2d_neck():
|
||||
fake_in = [torch.randn((2, 3, 8, 8))]
|
||||
|
||||
# test default
|
||||
neck = AvgPool2dNeck()
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 1, 1)
|
||||
|
||||
# test custom
|
||||
neck = AvgPool2dNeck(2)
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 2, 2)
|
||||
|
||||
# test custom
|
||||
neck = AvgPool2dNeck((1, 2))
|
||||
fake_out = neck(fake_in)
|
||||
assert fake_out[0].shape == (2, 3, 1, 2)
|
|
@ -1,32 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import DenseCLNeck
|
||||
|
||||
|
||||
def test_densecl_neck():
|
||||
neck = DenseCLNeck(16, 32, 16)
|
||||
assert isinstance(neck.mlp, nn.Sequential)
|
||||
assert isinstance(neck.mlp2, nn.Sequential)
|
||||
assert neck.mlp[0].in_features == 16
|
||||
assert neck.mlp[2].in_features == 32
|
||||
assert neck.mlp[2].out_features == 16
|
||||
assert neck.mlp2[0].in_channels == 16
|
||||
assert neck.mlp2[2].in_channels == 32
|
||||
assert neck.mlp2[2].out_channels == 16
|
||||
|
||||
# test neck when num_grid is None
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
assert fake_out[1].shape == torch.Size([32, 16, 25])
|
||||
assert fake_out[2].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck when num_grid is not None
|
||||
neck = DenseCLNeck(16, 32, 16, num_grid=3)
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
assert fake_out[1].shape == torch.Size([32, 16, 9])
|
||||
assert fake_out[2].shape == torch.Size([32, 16])
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import LinearNeck
|
||||
|
||||
|
||||
def test_linear_neck():
|
||||
neck = LinearNeck(16, 32, with_avg_pool=True)
|
||||
assert isinstance(neck.avgpool, nn.Module)
|
||||
assert neck.fc.in_features == 16
|
||||
assert neck.fc.out_features == 32
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = LinearNeck(16, 32, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import MAEPretrainDecoder
|
||||
|
||||
|
||||
def test_linear_neck():
|
||||
decoder = MAEPretrainDecoder()
|
||||
decoder.init_weights()
|
||||
decoder.eval()
|
||||
inputs = torch.rand(1, 50, 1024)
|
||||
ids_restore = torch.arange(0, 196).unsqueeze(0)
|
||||
level_outputs = decoder.forward(inputs, ids_restore)
|
||||
assert tuple(level_outputs.shape) == (1, 196, 768)
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import MoCoV2Neck
|
||||
|
||||
|
||||
def test_mocov2_neck():
|
||||
neck = MoCoV2Neck(16, 32, 16)
|
||||
assert isinstance(neck.mlp, nn.Sequential)
|
||||
assert neck.mlp[0].in_features == 16
|
||||
assert neck.mlp[2].in_features == 32
|
||||
assert neck.mlp[2].out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = MoCoV2Neck(16, 32, 16, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import NonLinearNeck
|
||||
|
||||
|
||||
def test_nonlinear_neck():
|
||||
# test neck arch
|
||||
neck = NonLinearNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert neck.fc0.in_features == 16
|
||||
assert neck.fc0.out_features == 32
|
||||
assert neck.bn0.num_features == 32
|
||||
fc = getattr(neck, neck.fc_names[-1])
|
||||
assert fc.out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = NonLinearNeck(
|
||||
16, 32, 16, with_avg_pool=False, norm_cfg=dict(type='BN1d'))
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck with vit_backbone
|
||||
neck = NonLinearNeck(
|
||||
in_channels=16,
|
||||
hid_channels=32,
|
||||
out_channels=16,
|
||||
with_avg_pool=False,
|
||||
norm_cfg=dict(type='BN1d'),
|
||||
vit_backbone=True)
|
||||
fake_cls_token = torch.rand((32, 16))
|
||||
fake_patch_token = torch.rand((32, 16, 14, 14))
|
||||
fake_in = [fake_patch_token, fake_cls_token]
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -1,24 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import ODCNeck
|
||||
|
||||
|
||||
def test_odc_neck():
|
||||
neck = ODCNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert neck.fc0.in_features == 16
|
||||
assert neck.fc0.out_features == 32
|
||||
assert neck.bn0.num_features == 32
|
||||
assert neck.fc1.in_features == 32
|
||||
assert neck.fc1.out_features == 16
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 16, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = ODCNeck(16, 32, 16, with_avg_pool=False, norm_cfg=dict(type='BN1d'))
|
||||
fake_in = torch.rand((32, 16))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 16])
|
|
@ -1,22 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.necks import RelativeLocNeck
|
||||
|
||||
|
||||
def test_relative_loc_neck():
|
||||
neck = RelativeLocNeck(16, 32)
|
||||
assert neck.fc.in_features == 32
|
||||
assert neck.fc.out_features == 32
|
||||
assert neck.bn.num_features == 32
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = torch.rand((32, 32, 5, 5))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
||||
|
||||
# test neck without avgpool
|
||||
neck = RelativeLocNeck(16, 32, with_avg_pool=False)
|
||||
fake_in = torch.rand((32, 32))
|
||||
fake_out = neck.forward([fake_in])
|
||||
assert fake_out[0].shape == torch.Size([32, 32])
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.necks import SwAVNeck
|
||||
|
||||
|
||||
def test_swav_neck():
|
||||
neck = SwAVNeck(16, 32, 16, norm_cfg=dict(type='BN1d'))
|
||||
assert isinstance(neck.projection_neck, (nn.Module, nn.Sequential))
|
||||
|
||||
# test neck with avgpool
|
||||
fake_in = [[torch.rand((32, 16, 5, 5))], [torch.rand((32, 16, 5, 5))],
|
||||
[torch.rand((32, 16, 3, 3))]]
|
||||
fake_out = neck.forward(fake_in)
|
||||
assert fake_out[0].shape == torch.Size([32 * len(fake_in), 16])
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import Encoder
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_dalle():
|
||||
model = Encoder()
|
||||
fake_inputs = torch.rand((2, 3, 112, 112))
|
||||
fake_outputs = model(fake_inputs)
|
||||
|
||||
assert list(fake_outputs.shape) == [2, 8192, 14, 14]
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import knn_classifier
|
||||
|
||||
|
||||
def test_knn_classifier():
|
||||
train_feats = torch.ones(200, 3)
|
||||
train_labels = torch.ones(200).long()
|
||||
test_feats = torch.ones(200, 3)
|
||||
test_labels = torch.ones(200).long()
|
||||
num_knn = [10, 20, 100, 200]
|
||||
for k in num_knn:
|
||||
top1, top5 = knn_classifier(train_feats, train_labels, test_feats,
|
||||
test_labels, k, 0.07)
|
||||
assert top1 == 100.
|
||||
assert top5 == 100.
|
|
@ -1,37 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import MultiPooling
|
||||
|
||||
|
||||
def test_multi_pooling():
|
||||
# adaptive
|
||||
layer = MultiPooling(pool_type='adaptive', in_indices=(0, 1, 2))
|
||||
fake_in = [
|
||||
torch.rand((1, 32, 112, 112)),
|
||||
torch.rand((1, 64, 56, 56)),
|
||||
torch.rand((1, 128, 28, 28)),
|
||||
]
|
||||
res = layer.forward(fake_in)
|
||||
assert res[0].shape == (1, 32, 12, 12)
|
||||
assert res[1].shape == (1, 64, 6, 6)
|
||||
assert res[2].shape == (1, 128, 4, 4)
|
||||
|
||||
# specified
|
||||
layer = MultiPooling(pool_type='specified', in_indices=(0, 1, 2))
|
||||
fake_in = [
|
||||
torch.rand((1, 32, 112, 112)),
|
||||
torch.rand((1, 64, 56, 56)),
|
||||
torch.rand((1, 128, 28, 28)),
|
||||
]
|
||||
res = layer.forward(fake_in)
|
||||
assert res[0].shape == (1, 32, 12, 12)
|
||||
assert res[1].shape == (1, 64, 6, 6)
|
||||
assert res[2].shape == (1, 128, 4, 4)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPooling(pool_type='other')
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPooling(backbone='resnet101')
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmselfsup.models.utils import MultiPrototypes
|
||||
|
||||
|
||||
def test_multi_prototypes():
|
||||
with pytest.raises(AssertionError):
|
||||
layer = MultiPrototypes(output_dim=16, num_prototypes=2)
|
||||
|
||||
layer = MultiPrototypes(output_dim=16, num_prototypes=[3, 4, 5])
|
||||
assert isinstance(getattr(layer, 'prototypes0'), nn.Module)
|
||||
assert isinstance(getattr(layer, 'prototypes1'), nn.Module)
|
||||
assert isinstance(getattr(layer, 'prototypes2'), nn.Module)
|
||||
|
||||
fake_in = torch.rand((32, 16))
|
||||
res = layer.forward(fake_in)
|
||||
assert len(res) == 3
|
||||
assert res[0].shape == (32, 3)
|
||||
assert res[1].shape == (32, 4)
|
||||
assert res[2].shape == (32, 5)
|
|
@ -1,14 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from mmselfsup.models.utils import Sobel
|
||||
|
||||
|
||||
def test_sobel():
|
||||
sobel_layer = Sobel()
|
||||
fake_input = torch.rand((1, 3, 224, 224))
|
||||
fake_res = sobel_layer(fake_input)
|
||||
assert fake_res.shape == (1, 2, 224, 224)
|
||||
|
||||
for p in sobel_layer.sobel.parameters():
|
||||
assert p.requires_grad is False
|
Loading…
Reference in New Issue