101 lines
2.8 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
from unittest.mock import MagicMock
import pytest
import torch
import mmselfsup
from mmselfsup.models.algorithms.densecl import DenseCL
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules
# Runs once at import time, before any test in this module executes.
# Presumably registers the mmselfsup components so that the string ``type``
# names used in the config dicts of this module can be resolved
# -- TODO confirm against mmselfsup.utils.
register_all_modules()
# Hyper-parameters handed to the DenseCL constructor in the test below.
queue_len = 32
feat_dim = 2  # same value as the neck's ``out_channels``
momentum = 0.999
loss_lambda = 0.5

# Component configs; each ``type`` string names the class to build.
backbone = dict(
    type='ResNet',
    depth=18,
    in_channels=3,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))
neck = dict(
    type='DenseCLNeck',
    in_channels=512,
    hid_channels=2,
    out_channels=2,
    num_grid=None)
head = dict(
    type='ContrastiveHead',
    loss=dict(type='mmcls.CrossEntropyLoss'),
    temperature=0.2)
def mock_batch_shuffle_ddp(img):
    """Stand-in for ``batch_shuffle_ddp`` that performs no shuffling.

    Args:
        img: The batch to "shuffle"; returned untouched.

    Returns:
        tuple: ``(img, 0)`` -- the same object plus a placeholder ``0`` as
        the second element.
    """
    return img, 0
def mock_batch_unshuffle_ddp(img, mock_input):
    """Stand-in for ``batch_unshuffle_ddp`` that performs no unshuffling.

    Args:
        img: The batch to "unshuffle"; returned untouched.
        mock_input: Accepted only to match the two-argument call; ignored.

    Returns:
        The ``img`` argument, unchanged.
    """
    return img
def mock_concat_all_gather(img):
    """Stand-in for ``concat_all_gather`` that gathers nothing.

    Args:
        img: The value to "gather"; returned untouched.

    Returns:
        The ``img`` argument, unchanged.
    """
    return img
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_densecl():
    """Build a small DenseCL model and exercise its three forward modes.

    Checks, in order: the shapes of both queues after construction, the
    ``loss`` mode outputs and queue pointers, the ``tensor`` mode feature
    shape, and the ``predict`` mode ``q_grid`` field.
    """
    # Function-scope import: the module header only imports ``MagicMock``.
    from unittest.mock import patch

    data_preprocessor = {
        'mean': (123.675, 116.28, 103.53),
        'std': (58.395, 57.12, 57.375),
        'bgr_to_rgb': True
    }

    # Deep-copy so the model receives its own copy of the preprocessor dict.
    alg = DenseCL(
        backbone=backbone,
        neck=neck,
        head=head,
        queue_len=queue_len,
        feat_dim=feat_dim,
        momentum=momentum,
        loss_lambda=loss_lambda,
        data_preprocessor=copy.deepcopy(data_preprocessor))
    assert alg.queue.size() == torch.Size([feat_dim, queue_len])
    assert alg.queue2.size() == torch.Size([feat_dim, queue_len])

    # Two views, each a batch of two 3x224x224 images.
    fake_data = {
        'inputs':
        [torch.randn((2, 3, 224, 224)),
         torch.randn((2, 3, 224, 224))],
        'data_sample': [SelfSupDataSample() for _ in range(2)]
    }

    # Swap the ``*_ddp`` / all-gather helpers of the densecl module for the
    # pass-through mocks (presumably the real ones need an initialised
    # distributed process group -- TODO confirm). ``patch.multiple`` restores
    # the original attributes on exit, so the mocks cannot leak into other
    # tests that run in the same session.
    ddp_mocks = dict(
        batch_shuffle_ddp=MagicMock(side_effect=mock_batch_shuffle_ddp),
        batch_unshuffle_ddp=MagicMock(side_effect=mock_batch_unshuffle_ddp),
        concat_all_gather=MagicMock(side_effect=mock_concat_all_gather))
    with patch.multiple(mmselfsup.models.algorithms.densecl, **ddp_mocks):
        fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)

        # mode='loss': both loss terms are positive floats, and each queue
        # pointer is expected at 2 after this single batch of two samples.
        fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
        assert isinstance(fake_loss['loss_single'].item(), float)
        assert isinstance(fake_loss['loss_dense'].item(), float)
        assert fake_loss['loss_single'].item() > 0
        assert fake_loss['loss_dense'].item() > 0
        assert alg.queue_ptr.item() == 2
        assert alg.queue2_ptr.item() == 2

        # mode='tensor': first returned feature map is 2 x 512 x 7 x 7.
        fake_feat = alg(fake_inputs, fake_data_samples, mode='tensor')
        assert list(fake_feat[0].shape) == [2, 512, 7, 7]

        # mode='predict': result carries ``q_grid.value`` of 2 x 512 x 49.
        fake_outputs = alg(fake_inputs, fake_data_samples, mode='predict')
        assert 'q_grid' in fake_outputs
        assert 'value' in fake_outputs.q_grid
        assert list(fake_outputs.q_grid.value.shape) == [2, 512, 49]