43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
|
|
import pytest
|
|
from mmengine.registry import init_default_scope
|
|
|
|
from mmselfsup.datasets import DeepClusterImageNet
|
|
|
|
# Dataset settings: load each image from disk, then take a random resized
# crop of size 4 (the test below expects a 4x4x3 image out of the pipeline).
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'RandomResizedCrop', 'size': 4},
]
|
|
|
|
|
|
def test_deepcluster_dataset():
    """Check DeepClusterImageNet sample loading and clustering-label updates.

    Builds the dataset from the shared fixtures under ``../data`` and checks
    the following:

    * freshly loaded samples report a clustering label of ``-1``;
    * ``assign_labels`` raises ``AssertionError`` for a label list shorter
      than the dataset;
    * a correctly sized label list is stored in ``clustering_labels`` and is
      reflected in samples fetched afterwards.
    """
    init_default_scope('mmselfsup')

    # The annotation list and the images share one fixture directory; compute
    # it once so ``ann_file`` and ``data_root`` cannot drift apart.
    data_root = osp.join(osp.dirname(__file__), '..', 'data')
    data = dict(
        ann_file=osp.join(data_root, 'data_list.txt'),
        metainfo=None,
        data_root=data_root,
        pipeline=train_pipeline)
    dataset = DeepClusterImageNet(**data)
    # The fixture annotation list is expected to yield two samples.
    assert len(dataset.clustering_labels) == 2

    x = dataset[0]
    # RandomResizedCrop(size=4) in the pipeline yields a 4x4, 3-channel image.
    assert x['img'].shape == (4, 4, 3)
    # No labels have been assigned yet, so the sample reports -1.
    assert x['clustering_label'] == -1
    assert x['sample_idx'] == 0

    # A label list whose length does not match the dataset must be rejected.
    with pytest.raises(AssertionError):
        dataset.assign_labels([1])

    dataset.assign_labels([1, 0])
    assert dataset.clustering_labels[0] == 1
    assert dataset.clustering_labels[1] == 0

    # Re-fetching a sample must pick up the newly assigned label.
    x = dataset[0]
    assert x['clustering_label'] == 1
|