mmselfsup/tests/test_engine/test_hooks/test_deepcluster_hook.py
Yuan Liu 20488d01b4
[Refactor]: Refactor data flow (#429)
* [Refactor]: Refactor data flow

* [Fix]: Change data sample to data samples

* [Fix]: Change batch_inputs to inputs

* [Fix]: Fix lint and UT

* [Fix]: Fix UT

* [Fix]: Fix lint

* [Fix]: Fix docstring

* [Fix]: Fix UT

* [Refactor]: Add assert in data preprocessor

* [Fix]: Fix lint
2022-08-30 11:34:04 +08:00

76 lines
2.0 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import tempfile
from unittest import TestCase
import torch
from mmengine.structures import LabelData
from torch.utils.data import Dataset
from mmselfsup.engine.hooks import DeepClusterHook
from mmselfsup.structures import SelfSupDataSample
# Model-component configs for a DeepCluster-style setup. Within this file
# only ``num_classes`` is referenced outside this block (as the k-means
# ``k``); ``with_sobel``, ``backbone``, ``neck``, ``head`` and ``loss`` are
# defined but not consumed by the test.
num_classes = 5
# Deliberately no trailing comma: ``True,`` would bind the one-element
# tuple ``(True,)`` rather than a bool.
with_sobel = True
backbone = dict(
    type='ResNet',
    depth=18,
    # 2 input channels -- presumably to match a Sobel-filtered input (see
    # ``with_sobel``); TODO confirm against the DeepCluster config.
    in_channels=2,
    out_indices=[4],  # 0: conv-1, x: stage-x
    norm_cfg=dict(type='BN'))
neck = dict(type='AvgPool2dNeck')
head = dict(
    type='ClsHead',
    with_avg_pool=False,  # already has avgpool in the neck
    in_channels=512,
    num_classes=num_classes)
loss = dict(type='mmcls.CrossEntropyLoss')
class DummyDataset(Dataset):
    """Tiny in-memory dataset fed to the hook's extract dataloader.

    Holds 12 random rows of shape ``(2,)`` in ``data``, each paired with
    the label ``1.0`` in ``label``. Both tensors are created once, when the
    class is defined, and are shared by every instance.
    """

    METAINFO = dict()  # type: ignore
    data = torch.randn(12, 2)
    label = torch.ones(12)

    @property
    def metainfo(self):
        """Return the (empty) class-level ``METAINFO`` dict."""
        return self.METAINFO

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in ``data``."""
        return self.data.size(0)

    def __getitem__(self, index):
        """Return sample ``index`` as ``dict(inputs=..., data_samples=...)``.

        ``inputs`` is row ``index`` of ``data``. ``data_samples`` is a
        ``SelfSupDataSample`` whose ``gt_label`` wraps ``label[index]``.
        """
        data_sample = SelfSupDataSample()
        data_sample.gt_label = LabelData(value=self.label[index])
        # Plural ``data_samples`` key, matching the naming of the refactored
        # data flow (the singular ``data_sample`` is the pre-refactor key).
        return dict(inputs=self.data[index], data_samples=data_sample)
class TestDeepClusterHook(TestCase):
    """Unit tests for constructing a ``DeepClusterHook``."""

    def setUp(self):
        # Scratch directory, removed again in ``tearDown``. It is not
        # referenced by ``test_deepcluster_hook``.
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_deepcluster_hook(self):
        """Build the hook from a dataloader config and check its settings."""
        dummy_dataset = DummyDataset()

        # Deterministic, single-process dataloader config over the dummy
        # dataset; it is passed through as a config dict.
        extract_dataloader = dict(
            dataset=dummy_dataset,
            sampler=dict(type='DefaultSampler', shuffle=False),
            batch_size=1,
            num_workers=0,
            persistent_workers=False)

        deepcluster_hook = DeepClusterHook(
            extract_dataloader=extract_dataloader,
            clustering=dict(type='Kmeans', k=num_classes, pca_dim=16),
            unif_sampling=True,
            reweight=False,
            reweight_pow=0.5,
            initial=True,
            interval=1)

        # Expect ``clustering_type`` to reflect the ``type`` passed in
        # ``clustering``. ``assertEqual`` is used instead of a bare
        # ``assert`` so the check is not stripped under ``python -O`` and
        # both values are reported on failure.
        self.assertEqual(deepcluster_hook.clustering_type, 'Kmeans')