304 lines
8.2 KiB
Python
304 lines
8.2 KiB
Python
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
# fmt: off
|
|
# flake8: noqa
|
|
|
|
|
|
""":md
|
|
# Serializing codes separately, with IndexLSH and IndexPQ
|
|
|
|
Let's say, for example, you have a few vector embeddings per user
|
|
and want to shard a flat index by user so you can re-use the same LSH or PQ method
|
|
for all users but store each user's codes independently.
|
|
|
|
|
|
"""
|
|
|
|
""":py"""
|
|
import faiss
|
|
import numpy as np
|
|
|
|
""":py"""
|
|
d = 768
|
|
n = 1_000
|
|
ids = np.arange(n).astype('int64')
|
|
training_data = np.random.rand(n, d).astype('float32')
|
|
|
|
""":py"""
|
|
def read_ids_codes():
|
|
try:
|
|
return np.load("/tmp/ids.npy"), np.load("/tmp/codes.npy")
|
|
except FileNotFoundError:
|
|
return None, None
|
|
|
|
|
|
def write_ids_codes(ids, codes):
    """Persist (ids, codes) to /tmp so they can be re-attached to a fresh index.

    codes is stored as a 2-D array with one row of code bytes per id.
    """
    np.save("/tmp/ids.npy", ids)
    row_count = len(ids)
    np.save("/tmp/codes.npy", codes.reshape(row_count, -1))
|
|
|
|
|
|
def write_template_index(template_index):
    """Serialize the shared template index to /tmp/template.index."""
    faiss.write_index(template_index, "/tmp/template.index")
|
|
|
|
|
|
def read_template_index_instance():
    """Deserialize and return a fresh instance of the shared template index."""
    return faiss.read_index("/tmp/template.index")
|
|
|
|
""":md
|
|
## IndexLSH: separate codes
|
|
|
|
The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training. In fact, its compression method, a random projections matrix, is deterministic on construction, based on a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
|
|
"""
|
|
|
|
""":py"""
|
|
nbits = 1536
|
|
|
|
""":py"""
|
|
# demonstrating encoding is deterministic
|
|
|
|
codes = []
|
|
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
|
|
for i in range(10):
|
|
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
|
|
code = index.index.sa_encode(database_vector_float32)
|
|
codes.append(code)
|
|
|
|
for i in range(1, 10):
|
|
assert np.array_equal(codes[0], codes[i])
|
|
|
|
""":py"""
|
|
# new database vector
|
|
|
|
ids, codes = read_ids_codes()
|
|
database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
|
|
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
|
|
|
|
code = index.index.sa_encode(database_vector_float32)
|
|
|
|
if ids is not None and codes is not None:
|
|
ids = np.concatenate((ids, [database_vector_id]))
|
|
codes = np.vstack((codes, code))
|
|
else:
|
|
ids = np.array([database_vector_id])
|
|
codes = np.array([code])
|
|
|
|
write_ids_codes(ids, codes)
|
|
|
|
""":py '2840581589434841'"""
|
|
# then at query time
|
|
|
|
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
|
|
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
|
|
ids, codes = read_ids_codes()
|
|
|
|
index.add_sa_codes(codes, ids)
|
|
|
|
index.search(query_vector_float32, k=5)
|
|
|
|
""":py"""
|
|
!rm /tmp/ids.npy /tmp/codes.npy
|
|
|
|
""":md
|
|
## IndexPQ: separate codes from codebook
|
|
|
|
The second half of this notebook demonstrates how to separate serializing and deserializing the PQ codebook
|
|
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
|
|
where you have a few vector embeddings per user and want to shard the flat index by user you
|
|
can re-use the same PQ method for all users but store each user's codes independently.
|
|
|
|
"""
|
|
|
|
""":py"""
|
|
M = d//8
|
|
nbits = 8
|
|
|
|
""":py"""
|
|
# at train time
|
|
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
|
|
template_index.train(training_data)
|
|
write_template_index(template_index)
|
|
|
|
""":py"""
|
|
# New database vector
|
|
|
|
index = read_template_index_instance()
|
|
ids, codes = read_ids_codes()
|
|
database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)
|
|
|
|
code = index.index.sa_encode(database_vector_float32)
|
|
|
|
if ids is not None and codes is not None:
|
|
ids = np.concatenate((ids, [database_vector_id]))
|
|
codes = np.vstack((codes, code))
|
|
else:
|
|
ids = np.array([database_vector_id])
|
|
codes = np.array([code])
|
|
|
|
write_ids_codes(ids, codes)
|
|
|
|
""":py '1858280061369209'"""
|
|
# then at query time
|
|
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
|
|
id_wrapper_index = read_template_index_instance()
|
|
ids, codes = read_ids_codes()
|
|
|
|
id_wrapper_index.add_sa_codes(codes, ids)
|
|
|
|
id_wrapper_index.search(query_vector_float32, k=5)
|
|
|
|
""":py"""
|
|
!rm /tmp/ids.npy /tmp/codes.npy /tmp/template.index
|
|
|
|
""":md
|
|
## Comparing these methods
|
|
|
|
- methods: Flat, LSH, PQ
|
|
- vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
|
|
- measure: recall@1
|
|
|
|
We don't measure latency as the number of vectors per user shard is insignificant.
|
|
|
|
"""
|
|
|
|
""":py '2898032417027201'"""
|
|
n, d
|
|
|
|
""":py"""
|
|
database_vector_ids, database_vector_float32s = np.arange(n), np.random.rand(n, d).astype(np.float32)
|
|
query_vector_float32s = np.random.rand(n, d).astype(np.float32)
|
|
|
|
""":py"""
|
|
index = faiss.index_factory(d, "IDMap2,Flat")
|
|
index.add_with_ids(database_vector_float32s, database_vector_ids)
|
|
_, ground_truth_result_ids= index.search(query_vector_float32s, k=1)
|
|
|
|
""":py '857475336204238'"""
|
|
from dataclasses import dataclass
|
|
|
|
pq_m_nbits = (
|
|
# 96 bytes
|
|
(96, 8),
|
|
(192, 4),
|
|
# 192 bytes
|
|
(192, 8),
|
|
(384, 4),
|
|
# 384 bytes
|
|
(384, 8),
|
|
(768, 4),
|
|
)
|
|
lsh_nbits = (768, 1536, 3072, 6144, 12288, 24576)
|
|
|
|
|
|
@dataclass
|
|
class Record:
|
|
type_: str
|
|
index: faiss.Index
|
|
args: tuple
|
|
recall: float
|
|
|
|
|
|
results = []
|
|
|
|
for m, nbits in pq_m_nbits:
|
|
print("pq", m, nbits)
|
|
index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
|
|
index.train(training_data)
|
|
index.add_with_ids(database_vector_float32s, database_vector_ids)
|
|
_, result_ids = index.search(query_vector_float32s, k=1)
|
|
recall = sum(result_ids == ground_truth_result_ids)
|
|
results.append(Record("pq", index, (m, nbits), recall))
|
|
|
|
for nbits in lsh_nbits:
|
|
print("lsh", nbits)
|
|
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
|
|
index.add_with_ids(database_vector_float32s, database_vector_ids)
|
|
_, result_ids = index.search(query_vector_float32s, k=1)
|
|
recall = sum(result_ids == ground_truth_result_ids)
|
|
results.append(Record("lsh", index, (nbits,), recall))
|
|
|
|
""":py '556918346720794'"""
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
def create_grouped_bar_chart(x_values, y_values_list, labels_list, xlabel, ylabel, title):
    """Draw groups of labeled bars centered at each x position on a log x-axis.

    Args:
        x_values: x position (group center) for each group of bars.
        y_values_list: per-group list of bar heights; group i is centered
            at x_values[i].
        labels_list: per-group list of bar labels, annotated above each bar.
        xlabel: x-axis label.
        ylabel: y-axis label.
        title: chart title.
    """
    plt.figure(figsize=(12, 6))

    for x, y_values, labels in zip(x_values, y_values_list, labels_list):
        num_bars = len(y_values)
        # Width scales with x so groups keep similar visual width on a
        # logarithmic axis; positions spread the bars symmetrically about x.
        bar_width = 0.08 * x
        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x

        bars = plt.bar(bar_positions, y_values, width=bar_width)

        # Annotate each bar with its label, slightly above the bar top.
        for bar, label in zip(bars, labels):
            height = bar.get_height()
            plt.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom'
            )

    plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    # Label ticks with the literal group positions instead of log decades.
    plt.xticks(x_values, labels=[str(x) for x in x_values])
    plt.tight_layout()
    plt.show()
|
|
|
|
# # Example usage:
|
|
# x_values = [1, 2, 4, 8, 16, 32]
|
|
# y_values_list = [
|
|
# [2.5, 3.6, 1.8],
|
|
# [3.0, 2.8],
|
|
# [2.5, 3.5, 4.0, 1.0],
|
|
# [4.2],
|
|
# [3.0, 5.5, 2.2],
|
|
# [6.0, 4.5]
|
|
# ]
|
|
# labels_list = [
|
|
# ['A1', 'B1', 'C1'],
|
|
# ['A2', 'B2'],
|
|
# ['A3', 'B3', 'C3', 'D3'],
|
|
# ['A4'],
|
|
# ['A5', 'B5', 'C5'],
|
|
# ['A6', 'B6']
|
|
# ]
|
|
|
|
# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")
|
|
|
|
""":py '1630106834206134'"""
|
|
# x-axis: compression ratio
|
|
# y-axis: recall@1
|
|
|
|
from collections import defaultdict
|
|
|
|
x = defaultdict(list)
|
|
x[1].append(("flat", 1.00))
|
|
for r in results:
|
|
y_value = r.recall[0] / n
|
|
x_value = int(d * 4 / r.index.sa_code_size())
|
|
label = None
|
|
if r.type_ == "pq":
|
|
label = f"PQ{r.args[0]}x{r.args[1]}"
|
|
if r.type_ == "lsh":
|
|
label = f"LSH{r.args[0]}"
|
|
x[x_value].append((label, y_value))
|
|
|
|
x_values = sorted(list(x.keys()))
|
|
create_grouped_bar_chart(
|
|
x_values,
|
|
[[e[1] for e in x[x_value]] for x_value in x_values],
|
|
[[e[0] for e in x[x_value]] for x_value in x_values],
|
|
"compression ratio",
|
|
"recall@1 q=1,000 queries",
|
|
"recall@1 for a database of n=1,000 d=768 vectors",
|
|
)
|