#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# fmt: off
# flake8: noqa
""" :md
2024-11-05 02:21:44 +08:00
# Serializing codes separately, with IndexLSH and IndexPQ
Let ' s say, for example, you have a few vector embeddings per user
and want to shard a flat index by user so you can re - use the same LSH or PQ method
for all users but store each user ' s codes independently.
2024-10-24 20:42:41 +08:00
"""
""" :py """
import faiss
import numpy as np
""" :py """
d = 768
2024-11-05 02:21:44 +08:00
n = 1_000
2024-10-24 20:42:41 +08:00
ids = np . arange ( n ) . astype ( ' int64 ' )
training_data = np . random . rand ( n , d ) . astype ( ' float32 ' )
""" :py """
def read_ids_codes ( ) :
try :
return np . load ( " /tmp/ids.npy " ) , np . load ( " /tmp/codes.npy " )
except FileNotFoundError :
return None , None
def write_ids_codes ( ids , codes ) :
np . save ( " /tmp/ids.npy " , ids )
np . save ( " /tmp/codes.npy " , codes . reshape ( len ( ids ) , - 1 ) )
def write_template_index ( template_index ) :
faiss . write_index ( template_index , " /tmp/template.index " )
def read_template_index_instance ( ) :
2024-10-29 10:56:00 +08:00
return faiss . read_index ( " /tmp/template.index " )
""" :md
## IndexLSH: separate codes
The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training. In fact, its compression method, a random projections matrix, is deterministic on construction based on a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
"""
""" :py """
2024-11-05 02:21:44 +08:00
nbits = 1536
""" :py """
# demonstrating encoding is deterministic
codes = [ ]
database_vector_float32 = np . random . rand ( 1 , d ) . astype ( np . float32 )
for i in range ( 10 ) :
index = faiss . IndexIDMap2 ( faiss . IndexLSH ( d , nbits ) )
code = index . index . sa_encode ( database_vector_float32 )
codes . append ( code )
for i in range ( 1 , 10 ) :
assert np . array_equal ( codes [ 0 ] , codes [ i ] )
""" :py """
# new database vector
ids , codes = read_ids_codes ( )
database_vector_id , database_vector_float32 = max ( ids ) + 1 if ids is not None else 1 , np . random . rand ( 1 , d ) . astype ( np . float32 )
index = faiss . IndexIDMap2 ( faiss . IndexLSH ( d , nbits ) )
code = index . index . sa_encode ( database_vector_float32 )
2024-10-24 20:42:41 +08:00
2024-11-05 02:21:44 +08:00
if ids is not None and codes is not None :
ids = np . concatenate ( ( ids , [ database_vector_id ] ) )
codes = np . vstack ( ( codes , code ) )
else :
ids = np . array ( [ database_vector_id ] )
codes = np . array ( [ code ] )
write_ids_codes ( ids , codes )
""" :py ' 2840581589434841 ' """
# then at query time
query_vector_float32 = np . random . rand ( 1 , d ) . astype ( np . float32 )
index = faiss . IndexIDMap2 ( faiss . IndexLSH ( d , nbits ) )
ids , codes = read_ids_codes ( )
index . add_sa_codes ( codes , ids )
index . search ( query_vector_float32 , k = 5 )
""" :py """
! rm / tmp / ids . npy / tmp / codes . npy
""" :md
## IndexPQ: separate codes from codebook
The second half of this notebook demonstrates how to separate serializing and deserializing the PQ codebook
( via faiss . write_index for IndexPQ ) independently of the vector codes . For example , in the case
where you have a few vector embeddings per user and want to shard the flat index by user you
can re - use the same PQ method for all users but store each user ' s codes independently.
"""
""" :py """
M = d / / 8
nbits = 8
""" :py """
# at train time
2024-10-29 10:56:00 +08:00
template_index = faiss . index_factory ( d , f " IDMap2,PQ { M } x { nbits } " )
2024-10-24 20:42:41 +08:00
template_index . train ( training_data )
write_template_index ( template_index )
""" :py """
# New database vector
2024-10-29 10:56:00 +08:00
index = read_template_index_instance ( )
2024-10-24 20:42:41 +08:00
ids , codes = read_ids_codes ( )
2024-11-05 02:21:44 +08:00
database_vector_id , database_vector_float32 = max ( ids ) + 1 if ids is not None else 1 , np . random . rand ( 1 , d ) . astype ( np . float32 )
2024-10-29 10:56:00 +08:00
code = index . index . sa_encode ( database_vector_float32 )
2024-10-24 20:42:41 +08:00
if ids is not None and codes is not None :
ids = np . concatenate ( ( ids , [ database_vector_id ] ) )
codes = np . vstack ( ( codes , code ) )
else :
ids = np . array ( [ database_vector_id ] )
codes = np . array ( [ code ] )
2024-10-29 10:56:00 +08:00
2024-10-24 20:42:41 +08:00
write_ids_codes ( ids , codes )
""" :py '1858280061369209' """
# then at query time: reload the shared codebook and repopulate it from stored codes
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
id_wrapper_index = read_template_index_instance()
ids, codes = read_ids_codes()
id_wrapper_index.add_sa_codes(codes, ids)
id_wrapper_index.search(query_vector_float32, k=5)
""" :py """
! rm / tmp / ids . npy / tmp / codes . npy / tmp / template . index
""" :md
## Comparing these methods
- methods: Flat, LSH, PQ
- vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
- measure: recall@1
We don't measure latency as the number of vectors per user shard is insignificant.
"""
""" :py ' 2898032417027201 ' """
n , d
""" :py """
database_vector_ids , database_vector_float32s = np . arange ( n ) , np . random . rand ( n , d ) . astype ( np . float32 )
query_vector_float32s = np . random . rand ( n , d ) . astype ( np . float32 )
""" :py """
index = faiss . index_factory ( d , " IDMap2,Flat " )
index . add_with_ids ( database_vector_float32s , database_vector_ids )
_ , ground_truth_result_ids = index . search ( query_vector_float32s , k = 1 )
""" :py ' 857475336204238 ' """
from dataclasses import dataclass
pq_m_nbits = (
# 96 bytes
( 96 , 8 ) ,
( 192 , 4 ) ,
# 192 bytes
( 192 , 8 ) ,
( 384 , 4 ) ,
# 384 bytes
( 384 , 8 ) ,
( 768 , 4 ) ,
)
lsh_nbits = ( 768 , 1536 , 3072 , 6144 , 12288 , 24576 )
@dataclass
class Record :
type_ : str
index : faiss . Index
args : tuple
recall : float
results = [ ]
for m , nbits in pq_m_nbits :
print ( " pq " , m , nbits )
index = faiss . index_factory ( d , f " IDMap2,PQ { m } x { nbits } " )
index . train ( training_data )
index . add_with_ids ( database_vector_float32s , database_vector_ids )
_ , result_ids = index . search ( query_vector_float32s , k = 1 )
recall = sum ( result_ids == ground_truth_result_ids )
results . append ( Record ( " pq " , index , ( m , nbits ) , recall ) )
for nbits in lsh_nbits :
print ( " lsh " , nbits )
index = faiss . IndexIDMap2 ( faiss . IndexLSH ( d , nbits ) )
index . add_with_ids ( database_vector_float32s , database_vector_ids )
_ , result_ids = index . search ( query_vector_float32s , k = 1 )
recall = sum ( result_ids == ground_truth_result_ids )
results . append ( Record ( " lsh " , index , ( nbits , ) , recall ) )
""" :py ' 556918346720794 ' """
import matplotlib . pyplot as plt
import numpy as np
def create_grouped_bar_chart ( x_values , y_values_list , labels_list , xlabel , ylabel , title ) :
num_bars_per_group = len ( x_values )
plt . figure ( figsize = ( 12 , 6 ) )
for x , y_values , labels in zip ( x_values , y_values_list , labels_list ) :
num_bars = len ( y_values )
bar_width = 0.08 * x
bar_positions = np . arange ( num_bars ) * bar_width - ( num_bars - 1 ) * bar_width / 2 + x
bars = plt . bar ( bar_positions , y_values , width = bar_width )
for bar , label in zip ( bars , labels ) :
height = bar . get_height ( )
plt . annotate (
label ,
xy = ( bar . get_x ( ) + bar . get_width ( ) / 2 , height ) ,
xytext = ( 0 , 3 ) ,
textcoords = " offset points " ,
ha = ' center ' , va = ' bottom '
)
plt . xscale ( ' log ' )
plt . xlabel ( xlabel )
plt . ylabel ( ylabel )
plt . title ( title )
plt . xticks ( x_values , labels = [ str ( x ) for x in x_values ] )
plt . tight_layout ( )
plt . show ( )
# # Example usage:
# x_values = [1, 2, 4, 8, 16, 32]
# y_values_list = [
# [2.5, 3.6, 1.8],
# [3.0, 2.8],
# [2.5, 3.5, 4.0, 1.0],
# [4.2],
# [3.0, 5.5, 2.2],
# [6.0, 4.5]
# ]
# labels_list = [
# ['A1', 'B1', 'C1'],
# ['A2', 'B2'],
# ['A3', 'B3', 'C3', 'D3'],
# ['A4'],
# ['A5', 'B5', 'C5'],
# ['A6', 'B6']
# ]
# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")
""" :py ' 1630106834206134 ' """
# x-axis: compression ratio
# y-axis: recall@1
from collections import defaultdict
x = defaultdict ( list )
x [ 1 ] . append ( ( " flat " , 1.00 ) )
for r in results :
y_value = r . recall [ 0 ] / n
x_value = int ( d * 4 / r . index . sa_code_size ( ) )
label = None
if r . type_ == " pq " :
label = f " PQ { r . args [ 0 ] } x { r . args [ 1 ] } "
if r . type_ == " lsh " :
label = f " LSH { r . args [ 0 ] } "
x [ x_value ] . append ( ( label , y_value ) )
x_values = sorted ( list ( x . keys ( ) ) )
create_grouped_bar_chart (
x_values ,
[ [ e [ 1 ] for e in x [ x_value ] ] for x_value in x_values ] ,
[ [ e [ 0 ] for e in x [ x_value ] ] for x_value in x_values ] ,
" compression ratio " ,
" recall@1 q=1,000 queries " ,
" recall@1 for a database of n=1,000 d=768 vectors " ,
)