2022-06-06 09:29:01 +00:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import os.path as osp
|
|
|
|
|
|
|
|
import pytest
|
2023-03-14 14:34:27 +08:00
|
|
|
from mmengine.registry import init_default_scope
|
2022-06-06 09:29:01 +00:00
|
|
|
|
|
|
|
from mmselfsup.datasets import DeepClusterImageNet
|
|
|
|
|
|
|
|
# dataset settings
|
|
|
|
train_pipeline = [
    dict(type='LoadImageFromFile'),
    # Crop to a tiny 4x4 patch so the image-shape check in
    # ``test_deepcluster_dataset`` is fast and deterministic in size.
    dict(type='RandomResizedCrop', size=4),
]
|
|
|
|
|
|
|
|
|
|
|
|
def test_deepcluster_dataset():
    """Exercise DeepClusterImageNet loading and clustering-label updates.

    The dataset is built from ``data_list.txt`` inside the ``data``
    directory one level above this file; that fixture is expected to
    describe exactly two samples. The test checks that:

    * a fetched sample has the cropped image shape, its index, and the
      placeholder clustering label ``-1`` before any labels are assigned;
    * ``assign_labels`` rejects a label list whose length does not match
      the dataset (raises ``AssertionError``);
    * correctly sized labels are stored in ``clustering_labels`` and show
      up in samples returned afterwards.
    """
    # Make mmselfsup the default registry scope; presumably required so
    # the string ``type`` names in ``train_pipeline`` can be resolved.
    init_default_scope('mmselfsup')

    data_root = osp.join(osp.dirname(__file__), '..', 'data')
    dataset = DeepClusterImageNet(
        ann_file=osp.join(data_root, 'data_list.txt'),
        metainfo=None,
        data_root=data_root,
        pipeline=train_pipeline)
    assert len(dataset.clustering_labels) == 2

    # Before assign_labels is called, samples carry the placeholder -1.
    sample = dataset[0]
    assert sample['img'].shape == (4, 4, 3)
    assert sample['clustering_label'] == -1
    assert sample['sample_idx'] == 0

    # A single label for a two-sample dataset must be rejected.
    with pytest.raises(AssertionError):
        dataset.assign_labels([1])

    dataset.assign_labels([1, 0])
    assert dataset.clustering_labels[0] == 1
    assert dataset.clustering_labels[1] == 0

    # Samples fetched after assignment reflect the new labels.
    sample = dataset[0]
    assert sample['clustering_label'] == 1
|