mirror of
https://github.com/facebookresearch/faiss.git
synced 2025-06-03 21:54:02 +08:00
80 lines
1.9 KiB
Python
80 lines
1.9 KiB
Python
|
|
# Copyright (c) 2015-present, Facebook, Inc.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the CC-by-NC license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
#! /usr/bin/env python2
|
|
|
|
import libfb.py.mkl # noqa
|
|
|
|
import numpy as np
|
|
from libfb import testutil
|
|
|
|
import faiss
|
|
|
|
|
|
class TestClustering(testutil.BaseFacebookTestCase):
    """Sanity checks for faiss.Kmeans on uniformly random data."""

    def test_clustering(self):
        """Kmeans objective shrinks every iteration; more centroids fit better."""
        dim, nb = 64, 1000
        np.random.seed(123)
        data = np.random.random(size=(nb, dim)).astype('float32')

        kmeans32 = faiss.Kmeans(dim, 32, niter=10)
        err32 = kmeans32.train(data)

        # The per-iteration objective must be strictly decreasing. The 1e50
        # sentinel only requires the first recorded value to be below it.
        trace = [1e50] + list(kmeans32.obj)
        for before, after in zip(trace, trace[1:]):
            self.assertGreater(before, after)

        kmeans64 = faiss.Kmeans(dim, 64, niter=10)
        err64 = kmeans64.train(data)

        # Doubling the number of centroids should lower the quantization error.
        self.assertGreater(err32, err64)
|
|
|
|
|
|
class TestPCA(testutil.BaseFacebookTestCase):
    """Sanity check for faiss.PCAMatrix."""

    def test_pca(self):
        """Projected components come out in strictly decreasing energy order."""
        dim, nb = 64, 1000
        np.random.seed(123)
        data = np.random.random(size=(nb, dim)).astype('float32')

        pca = faiss.PCAMatrix(dim, 10)
        pca.train(data)
        projected = pca.apply_py(data)

        # Energy of each output component = sum of squares down its column.
        energy = np.square(projected).sum(axis=0)

        # Each component must carry strictly less energy than the one before
        # it. The 1e50 sentinel only bounds the first component.
        bounds = [1e50] + list(energy)
        for larger, smaller in zip(bounds, bounds[1:]):
            self.assertGreater(larger, smaller)
|
|
|
|
|
|
class TestProductQuantizer(testutil.BaseFacebookTestCase):
    """Round-trip check for faiss.ProductQuantizer encode/decode."""

    def test_pq(self):
        """Decoding the codes of the training data reproduces it within tolerance."""
        d = 64
        n = 1000
        # Constructor args are (d, M, nbits). Presumably this means M=4
        # sub-quantizers with 8-bit codes -- confirm against the faiss
        # ProductQuantizer API.
        cs = 4
        # Upper bound on the total squared reconstruction error. With seed 123
        # the error was previously recorded as 1807.98, so 2500 leaves headroom
        # for platform/BLAS-dependent variation.
        max_error = 2500

        np.random.seed(123)
        x = np.random.random(size=(n, d)).astype('float32')

        pq = faiss.ProductQuantizer(d, cs, 8)
        pq.train(x)
        codes = pq.compute_codes(x)
        x2 = pq.decode(codes)

        # Total squared reconstruction error over all vectors and dimensions.
        diff = ((x - x2) ** 2).sum()
        self.assertGreater(max_error, diff)
|