2022-03-04 13:43:49 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
2022-01-24 19:10:29 +08:00
|
|
|
from unittest.mock import patch
|
|
|
|
|
2021-12-15 19:07:01 +08:00
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from mmselfsup.utils.clustering import PIC, Kmeans
|
|
|
|
|
|
|
|
|
2022-01-24 19:10:29 +08:00
|
|
|
@pytest.fixture
def mock_faiss_in_clutering():
    """Patch the ``faiss`` module as seen by ``mmselfsup.utils.clustering``.

    The patch is started before the requesting test runs and stopped after
    it finishes. The fixture value is the mock that replaces ``faiss``.
    ``unittest.mock.patch`` creates a ``MagicMock`` when no replacement is
    given. Dependent fixtures request this fixture by its exact name, which
    includes the "clutering" spelling.
    """
    patcher = patch('mmselfsup.utils.clustering.faiss')
    fake_faiss = patcher.start()
    try:
        yield fake_faiss
    finally:
        # Always undo the patch, even if the test body raised.
        patcher.stop()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def mock_faiss(mock_faiss_in_clutering):
    """Give the patched ``faiss`` module canned return values.

    * ``faiss.PCAMatrix(...).apply_py(...)`` returns a ``(10, 8)`` array.
    * ``faiss.GpuIndexFlatL2(...).search(...)`` returns a tuple of two
      ``(1000, 6)`` arrays. This presumably mimics faiss's
      ``(distances, indices)`` search result; confirm against the faiss API.

    All arrays are uniform [0, 1) float64 values drawn from a fixed-seed
    generator, so a failing run can be reproduced exactly.

    Returns:
        The configured mock ``faiss`` module, so tests can inspect the calls
        made on it if they need to.
    """
    # Fixed seed: tests stay deterministic instead of relying on the global
    # NumPy RNG state.
    rng = np.random.default_rng(0)

    fake_faiss = mock_faiss_in_clutering
    pca_matrix = fake_faiss.PCAMatrix.return_value
    gpu_index = fake_faiss.GpuIndexFlatL2.return_value

    pca_matrix.apply_py.return_value = rng.random((10, 8))
    gpu_index.search.return_value = (
        rng.random((1000, 6)),
        rng.random((1000, 6)),
    )
    return fake_faiss
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('verbose', [True, False])
def test_kmeans(mock_faiss, verbose):
    """``Kmeans.cluster`` returns a loss for float32 input and rejects float64.

    ``faiss`` is mocked, so this exercises the control flow of ``cluster``,
    not the quality of the clustering.
    """
    fake_input = np.random.rand(10, 8).astype(np.float32)
    pca_dim = 2
    kmeans = Kmeans(2, pca_dim)

    loss = kmeans.cluster(fake_input, verbose=verbose)
    assert loss is not None

    # Reuse the same data with only the dtype changed. This makes the
    # expected AssertionError attributable to the non-float32 dtype rather
    # than to the particular values drawn.
    with pytest.raises(AssertionError):
        kmeans.cluster(fake_input.astype(np.float64), verbose=verbose)
|
2021-12-15 19:07:01 +08:00
|
|
|
|
|
|
|
|
2022-01-24 19:10:29 +08:00
|
|
|
def test_pic(mock_faiss):
    """``PIC.cluster`` returns 0 for a ``(1000, 16)`` float32 input.

    The ``mock_faiss`` fixture stands in for the real ``faiss`` module.
    The estimator is configured to PCA-reduce the features to 8 dimensions.
    """
    features = np.random.rand(1000, 16).astype(np.float32)
    clusterer = PIC(pca_dim=8)
    assert clusterer.cluster(features) == 0
|