Yuan Liu 20488d01b4
[Refactor]: Refactor data flow (#429)
* [Refactor]: Refactor data flow

* [Fix]: Change data sample to data samples

* [Fix]: Change batch_inputs to inputs

* [Fix]: Fix lint and UT

* [Fix]: Fix UT

* [Fix]: Fix lint

* [Fix]: Fix docstring

* [Fix]: Fix UT

* [Refactor]: Add assert in data preprocessor

* [Fix]: Fix lint
2022-08-30 11:34:04 +08:00

101 lines
2.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
from unittest.mock import MagicMock
import pytest
import torch
import mmselfsup
from mmselfsup.models.algorithms.densecl import DenseCL
from mmselfsup.structures import SelfSupDataSample
from mmselfsup.utils import register_all_modules
# Runs at import time, before any config below is turned into a model.
# NOTE(review): presumably fills the registries so the string ``type=...``
# entries in the configs can be resolved -- confirm in mmselfsup.utils.
register_all_modules()
# Hyper-parameters forwarded unchanged to the DenseCL constructor in
# ``test_densecl``.
queue_len = 32  # asserted there as the second dim of alg.queue / alg.queue2
feat_dim = 2  # asserted there as the first dim of alg.queue / alg.queue2
momentum = 0.999
loss_lambda = 0.5
# Backbone config handed to DenseCL in ``test_densecl`` (ResNet, depth 18).
backbone = {
    'type': 'ResNet',
    'depth': 18,
    'in_channels': 3,
    'out_indices': [4],  # 0: conv-1, x: stage-x
    'norm_cfg': {'type': 'BN'},
}
# Neck config handed to DenseCL in ``test_densecl``; ``in_channels`` is 512
# and both hidden and output widths are 2.
neck = {
    'type': 'DenseCLNeck',
    'in_channels': 512,
    'hid_channels': 2,
    'out_channels': 2,
    'num_grid': None,
}
# Head config handed to DenseCL in ``test_densecl``: a contrastive head with
# a cross-entropy loss and temperature 0.2.
head = {
    'type': 'ContrastiveHead',
    'loss': {'type': 'mmcls.CrossEntropyLoss'},
    'temperature': 0.2,
}
def mock_batch_shuffle_ddp(img):
    """Stand-in used to patch ``batch_shuffle_ddp`` in ``test_densecl``.

    Args:
        img: Any object; it is handed back untouched.

    Returns:
        tuple: ``(img, 0)`` -- the very same input object paired with the
        constant ``0`` as a placeholder second element.
    """
    placeholder = 0
    return (img, placeholder)
def mock_batch_unshuffle_ddp(img, mock_input):
    """Stand-in used to patch ``batch_unshuffle_ddp`` in ``test_densecl``.

    Args:
        img: Any object; it is handed back untouched.
        mock_input: Accepted so the two-argument call shape works; its value
            is never read.

    Returns:
        The very same ``img`` object that was passed in.
    """
    untouched = img
    return untouched
def mock_concat_all_gather(img):
    """Stand-in used to patch ``concat_all_gather`` in ``test_densecl``.

    Args:
        img: Any object; it is handed back untouched.

    Returns:
        The very same ``img`` object that was passed in.
    """
    passthrough = img
    return passthrough
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_densecl():
    """Smoke-test DenseCL construction and its ``loss``/``tensor``/``predict``
    forward modes.

    The test builds a DenseCL from the module-level ``backbone``, ``neck`` and
    ``head`` configs, checks both queue buffers have shape
    ``(feat_dim, queue_len)``, then runs one batch of two random
    ``2 x 3 x 224 x 224`` views through each mode and checks the outputs.

    The three helpers in ``mmselfsup.models.algorithms.densecl``
    (``batch_shuffle_ddp``, ``batch_unshuffle_ddp``, ``concat_all_gather``)
    are swapped for the pass-through mocks defined in this file. The swap is
    scoped to a ``with`` block so the module is restored on exit, even when an
    assertion fails, and the mocks cannot leak into other tests.
    """
    # Function-scope import: the file header only brings in ``MagicMock``.
    from unittest.mock import patch

    data_preprocessor = {
        'mean': (123.675, 116.28, 103.53),
        'std': (58.395, 57.12, 57.375),
        'bgr_to_rgb': True
    }

    alg = DenseCL(
        backbone=backbone,
        neck=neck,
        head=head,
        queue_len=queue_len,
        feat_dim=feat_dim,
        momentum=momentum,
        loss_lambda=loss_lambda,
        data_preprocessor=copy.deepcopy(data_preprocessor))
    assert alg.queue.size() == torch.Size([feat_dim, queue_len])
    assert alg.queue2.size() == torch.Size([feat_dim, queue_len])

    # NOTE(review): the key is 'data_sample' (singular) although the values
    # are unpacked as ``fake_data_samples`` below -- confirm the data
    # preprocessor does not look entries up by key name.
    fake_data = {
        'inputs':
        [torch.randn((2, 3, 224, 224)),
         torch.randn((2, 3, 224, 224))],
        'data_sample': [SelfSupDataSample() for _ in range(2)]
    }

    ddp_mocks = {
        'batch_shuffle_ddp': MagicMock(side_effect=mock_batch_shuffle_ddp),
        'batch_unshuffle_ddp': MagicMock(side_effect=mock_batch_unshuffle_ddp),
        'concat_all_gather': MagicMock(side_effect=mock_concat_all_gather),
    }
    # ``patch.multiple`` restores the original attributes on exit. It also
    # raises AttributeError if a name is missing from the module, so an
    # upstream rename surfaces here instead of silently patching nothing.
    with patch.multiple(mmselfsup.models.algorithms.densecl, **ddp_mocks):
        fake_inputs, fake_data_samples = alg.data_preprocessor(fake_data)

        # mode='loss': both loss terms must be positive Python floats.
        fake_loss = alg(fake_inputs, fake_data_samples, mode='loss')
        assert isinstance(fake_loss['loss_single'].item(), float)
        assert isinstance(fake_loss['loss_dense'].item(), float)
        assert fake_loss['loss_single'].item() > 0
        assert fake_loss['loss_dense'].item() > 0
        # Both queue pointers are expected at 2 after this single loss pass.
        assert alg.queue_ptr.item() == 2
        assert alg.queue2_ptr.item() == 2

        # mode='tensor': first returned feature map is [2, 512, 7, 7].
        fake_feat = alg(fake_inputs, fake_data_samples, mode='tensor')
        assert list(fake_feat[0].shape) == [2, 512, 7, 7]

        # mode='predict': output carries ``q_grid.value`` of [2, 512, 49].
        fake_outputs = alg(fake_inputs, fake_data_samples, mode='predict')
        assert 'q_grid' in fake_outputs
        assert 'value' in fake_outputs.q_grid
        assert list(fake_outputs.q_grid.value.shape) == [2, 512, 49]