faiss/benchs/bench_fw/index.py
George Wang 6441b564da PQFS into Index trainer (#3941)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3941

Support PQFS with index trainer

Reviewed By: kuarora

Differential Revision: D64259953

fbshipit-source-id: fd7ed90aed2ff6b6351460dcf7b61058c59cd25b
2024-10-11 17:10:03 -07:00

1138 lines
41 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from collections import OrderedDict
from copy import copy
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Optional
import faiss # @manual=//faiss/python:pyfaiss
import numpy as np
from faiss.benchs.bench_fw.descriptors import IndexBaseDescriptor
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib
knn_intersection_measure,
OperatingPointsWithRanges,
)
from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib
reverse_index_factory,
)
from faiss.contrib.ivf_tools import ( # @manual=//faiss/contrib:faiss_contrib
add_preassigned,
replace_ivf_quantizer,
)
from .descriptors import DatasetDescriptor
from .utils import (
distance_ratio_measure,
get_cpu_info,
refine_distances_knn,
refine_distances_range,
timer,
)
logger = logging.getLogger(__name__)
# The classes below are wrappers around Faiss indices, with different
# implementations for the case when we start with an already trained
# index (IndexFromCodec) vs factory strings (IndexFromFactory).
# In both cases the classes have operations for adding to an index
# and searching it, and outputs are cached on disk.
# IndexFromFactory also decomposes the index (pretransform and quantizer)
# and trains sub-indices independently.
class IndexBase:
def set_io(self, benchmark_io):
self.io = benchmark_io
@staticmethod
def set_index_param_dict_list(index, param_dict_list, assert_same=False):
if not param_dict_list:
return
index = faiss.downcast_index(index)
for param_dict in param_dict_list:
assert index is not None
IndexBase.set_index_param_dict(index, param_dict, assert_same)
index = faiss.try_extract_index_ivf(index)
if index is not None:
index = index.quantizer
@staticmethod
def set_index_param_dict(index, param_dict, assert_same=False):
if not param_dict:
return
for name, val in param_dict.items():
IndexBase.set_index_param(index, name, val, assert_same)
@staticmethod
def set_index_param(index, name, val, assert_same=False):
index = faiss.downcast_index(index)
val = int(val)
if (
isinstance(index, faiss.IndexPreTransform)
or isinstance(index, faiss.IndexIDMap)
):
Index.set_index_param(index.index, name, val)
return
elif name == "snap":
return
elif name == "lsq_gpu":
if val == 1:
ngpus = faiss.get_num_gpus()
icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus)
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
for i in range(index.plsq.nsplits):
lsq = faiss.downcast_Quantizer(
index.plsq.subquantizer(i)
)
if lsq.icm_encoder_factory is None:
lsq.icm_encoder_factory = icm_encoder_factory
else:
if index.lsq.icm_encoder_factory is None:
index.lsq.icm_encoder_factory = icm_encoder_factory
return
elif name in ["efSearch", "efConstruction"]:
obj = index.hnsw
elif name in ["nprobe", "parallel_mode"]:
obj = faiss.extract_index_ivf(index)
elif name in ["use_beam_LUT", "max_beam_size"]:
if isinstance(index, faiss.IndexProductResidualQuantizer):
obj = [
faiss.downcast_Quantizer(index.prq.subquantizer(i))
for i in range(index.prq.nsplits)
]
else:
obj = index.rq
elif name == "encode_ils_iters":
if isinstance(index, faiss.IndexProductLocalSearchQuantizer):
obj = [
faiss.downcast_Quantizer(index.plsq.subquantizer(i))
for i in range(index.plsq.nsplits)
]
else:
obj = index.lsq
else:
obj = index
if not isinstance(obj, list):
obj = [obj]
for o in obj:
test = getattr(o, name)
if assert_same and not name == "use_beam_LUT":
assert test == val
else:
setattr(o, name, val)
@staticmethod
def filter_index_param_dict_list(param_dict_list):
if (
param_dict_list is not None
and param_dict_list[0] is not None
and "k_factor" in param_dict_list[0]
):
filtered = copy(param_dict_list)
del filtered[0]["k_factor"]
return filtered
else:
return param_dict_list
def is_flat(self):
model = faiss.downcast_index(self.get_model())
return isinstance(model, faiss.IndexFlat)
def is_ivf(self):
return False
model = self.get_model()
return faiss.try_extract_index_ivf(model) is not None
def is_2layer(self):
def is_2layer_(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return is_2layer_(index.index)
return isinstance(index, faiss.Index2Layer)
model = self.get_model()
return is_2layer_(model)
def is_decode_supported(self):
model = self.get_model()
if isinstance(model, faiss.IndexPreTransform):
for i in range(model.chain.size()):
vt = faiss.downcast_VectorTransform(model.chain.at(i))
if isinstance(vt, faiss.ITQTransform):
return False
return True
def is_pretransform(self):
codec = self.get_model()
if isinstance(codec, faiss.IndexRefine):
codec = faiss.downcast_index(codec.base_index)
return isinstance(codec, faiss.IndexPreTransform)
# index is a codec + database vectors
# in other words: a trained Faiss index
# that contains database vectors
def get_index_name(self):
raise NotImplementedError
def get_index(self):
raise NotImplementedError
# codec is a trained model
# in other words: a trained Faiss index
# without any database vectors
def get_codec_name(self):
raise NotImplementedError
def get_codec(self):
raise NotImplementedError
# model is an untrained Faiss index
# it can be used for training (see codec)
# or to inspect its structure
def get_model_name(self):
raise NotImplementedError
def get_model(self):
raise NotImplementedError
def get_construction_params(self):
raise NotImplementedError
def transform(self, vectors):
transformed_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}{self.get_codec_name()}transform.npy"
)
if not self.io.file_exist(transformed_vectors.tablename):
codec = self.get_codec()
assert isinstance(codec, faiss.IndexPreTransform)
transform = faiss.downcast_VectorTransform(codec.chain.at(0))
x = self.io.get_dataset(vectors)
xt = transform.apply(x)
self.io.write_nparray(xt, transformed_vectors.tablename)
return transformed_vectors
def snap(self, vectors):
transformed_vectors = DatasetDescriptor(
tablename=f"{vectors.get_filename()}{self.get_codec_name()}snap.npy"
)
if not self.io.file_exist(transformed_vectors.tablename):
codec = self.get_codec()
x = self.io.get_dataset(vectors)
xt = codec.sa_decode(codec.sa_encode(x))
self.io.write_nparray(xt, transformed_vectors.tablename)
return transformed_vectors
def knn_search_quantizer(self, query_vectors, k):
if self.is_pretransform():
pretransform = self.get_pretransform()
quantizer_query_vectors = pretransform.transform(query_vectors)
else:
pretransform = None
quantizer_query_vectors = query_vectors
quantizer, _, _ = self.get_quantizer(
dry_run=False, pretransform=pretransform
)
QD, QI, _, QP, _ = quantizer.knn_search(
dry_run=False,
search_parameters=None,
query_vectors=quantizer_query_vectors,
k=k,
)
xqt = self.io.get_dataset(quantizer_query_vectors)
return xqt, QD, QI, QP
def get_knn_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
reconstruct: bool = False,
):
name = self.get_index_name()
name += IndexBaseDescriptor.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
name += f"k_{k}."
name += f"t_{self.num_threads}."
if reconstruct:
name += "rec."
else:
name += "knn."
return name
def knn_search(
self,
dry_run,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
I_gt=None,
D_gt=None,
):
logger.info("knn_search: begin")
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
D, I, R, P = self.io.read_file(filename, ["D", "I", "R", "P"])
else:
if dry_run:
return None, None, None, None, filename
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_2layer():
# Index2Layer doesn't support search
xq = self.io.get_dataset(query_vectors)
xb = index.reconstruct_n(0, index.ntotal)
(D, I), t, _ = timer(
"knn_search 2layer", lambda: faiss.knn(xq, xb, k)
)
elif self.is_ivf() and not isinstance(index, faiss.IndexRefine):
index_ivf = faiss.extract_index_ivf(index)
nprobe = (
search_parameters["nprobe"]
if search_parameters is not None
and "nprobe" in search_parameters
else index_ivf.nprobe
)
xqt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=query_vectors,
k=nprobe,
)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(D, I), t, repeat = timer(
"knn_search_preassigned",
lambda: index_ivf.search_preassigned(xqt, k, QI, QD),
)
# Dref, Iref = index.search(xq, k)
# np.testing.assert_array_equal(I, Iref)
# np.testing.assert_allclose(D, Dref)
else:
xq = self.io.get_dataset(query_vectors)
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if (
self.is_flat() or
not hasattr(self, "database_vectors") or
(self.database_vectors is None)
): # TODO
R = D
else:
xq = self.io.get_dataset(query_vectors)
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_knn(xq, xb, I, self.metric_type)
P = {
"time": t,
"k": k,
}
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(filename, ["D", "I", "R", "P"], [D, I, R, P])
P |= {
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": (
knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None
),
"distance_ratio": (
distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None
),
}
logger.info("knn_search: end")
return D, I, R, P, None
def reconstruct(
self,
dry_run,
parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
k: int,
I_gt,
):
logger.info("reconstruct: begin")
filename = (
self.get_knn_search_name(
parameters, query_vectors, k, reconstruct=True
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
(P,) = self.io.read_file(filename, ["P"])
P["index"] = self.get_index_name()
P["codec"] = self.get_codec_name()
P["factory"] = self.get_model_name()
P["reconstruct_params"] = parameters
P["construction_params"] = self.get_construction_params()
else:
if dry_run:
return None, filename
codec = self.get_codec()
codec_meta = self.fetch_meta()
Index.set_index_param_dict(codec, parameters)
xb = self.io.get_dataset(self.database_vectors)
xb_encoded, encode_t, _ = timer(
"sa_encode", lambda: codec.sa_encode(xb)
)
xq = self.io.get_dataset(query_vectors)
if self.is_decode_supported():
xb_decoded, decode_t, _ = timer(
"sa_decode", lambda: codec.sa_decode(xb_encoded)
)
mse = np.square(xb_decoded - xb).sum(axis=1).mean().item()
_, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type)
asym_recall = knn_intersection_measure(I, I_gt)
xq_decoded = codec.sa_decode(codec.sa_encode(xq))
_, I = faiss.knn(
xq_decoded, xb_decoded, k, metric=self.metric_type
)
else:
mse = None
asym_recall = None
decode_t = None
# assume hamming for sym
xq_encoded = codec.sa_encode(xq)
bin = faiss.IndexBinaryFlat(xq_encoded.shape[1] * 8)
bin.add(xb_encoded)
_, I = bin.search(xq_encoded, k)
sym_recall = knn_intersection_measure(I, I_gt)
P = {
"encode_time": encode_t,
"decode_time": decode_t,
"mse": mse,
"sym_recall": sym_recall,
"asym_recall": asym_recall,
"cpu": get_cpu_info(),
"num_threads": self.num_threads,
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"reconstruct_params": parameters,
"construction_params": self.get_construction_params(),
"codec_meta": codec_meta,
}
self.io.write_file(filename, ["P"], [P])
logger.info("reconstruct: end")
return P, None
def get_range_search_name(
self,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
name = self.get_index_name()
name += Index.param_dict_to_name(search_parameters)
name += query_vectors.get_filename("q")
if radius is not None:
name += f"r_{int(radius * 1000)}."
else:
name += "r_auto."
return name
def range_search(
self,
dry_run,
search_parameters: Optional[Dict[str, int]],
query_vectors: DatasetDescriptor,
radius: Optional[float] = None,
):
logger.info("range_search: begin")
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_range_search_name(
search_parameters, query_vectors, radius
)
+ "zip"
)
if self.io.file_exist(filename):
logger.info(f"Using cached results for {filename}")
lims, D, I, R, P = self.io.read_file(
filename, ["lims", "D", "I", "R", "P"]
)
else:
if dry_run:
return None, None, None, None, None, filename
xq = self.io.get_dataset(query_vectors)
index = self.get_index()
Index.set_index_param_dict(index, search_parameters)
if self.is_ivf():
xqt, QD, QI, QP = self.knn_search_quantizer(
query_vectors, search_parameters["nprobe"]
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
(lims, D, I), t, repeat = timer(
"range_search_preassigned",
lambda: index_ivf.range_search_preassigned(
xqt, radius, QI, QD
),
)
else:
(lims, D, I), t, _ = timer(
"range_search", lambda: index.range_search(xq, radius)
)
if self.is_flat():
R = D
else:
xb = self.io.get_dataset(self.database_vectors)
R = refine_distances_range(
lims, D, I, xq, xb, self.metric_type
)
P = {
"time": t,
"radius": radius,
"count": len(I),
}
if self.is_ivf():
stats = faiss.cvar.indexIVF_stats
P |= {
"quantizer": QP,
"nq": int(stats.nq // repeat),
"nlist": int(stats.nlist // repeat),
"ndis": int(stats.ndis // repeat),
"nheap_updates": int(stats.nheap_updates // repeat),
"quantization_time": int(
stats.quantization_time // repeat
),
"search_time": int(stats.search_time // repeat),
}
self.io.write_file(
filename, ["lims", "D", "I", "R", "P"], [lims, D, I, R, P]
)
P |= {
"index": self.get_index_name(),
"codec": self.get_codec_name(),
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
}
logger.info("range_seach: end")
return lims, D, I, R, P, None
# Common base for IndexFromCodec and IndexFromFactory,
# but not for the sub-indices of codec-based indices
# IndexFromQuantizer and IndexFromPreTransform, because
# they share the configuration of their parent IndexFromCodec
@dataclass
class Index(IndexBase):
num_threads: int
d: int
metric: str
codec_name: Optional[str] = None
index_name: Optional[str] = None
database_vectors: Optional[DatasetDescriptor] = None
construction_params: Optional[List[Dict[str, int]]] = None
search_params: Optional[Dict[str, int]] = None
serialize_full_index: bool = False
bucket: Optional[str] = None
index_path: Optional[str] = None
cached_codec: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
cached_index: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
def __post_init__(self):
logger.info(f"Initializing metric_type to {self.metric}")
if isinstance(self.metric, str):
if self.metric == "IP":
self.metric_type = faiss.METRIC_INNER_PRODUCT
elif self.metric == "L2":
self.metric_type = faiss.METRIC_L2
else:
raise ValueError
elif isinstance(self.metric, int):
self.metric_type = self.metric
if self.metric_type == faiss.METRIC_INNER_PRODUCT:
self.metric = "IP"
elif self.metric_type == faiss.METRIC_L2:
self.metric = "L2"
else:
raise ValueError
else:
raise ValueError
def supports_range_search(self):
codec = self.get_codec()
return not type(codec) in [
faiss.IndexHNSWFlat,
faiss.IndexIVFFastScan,
faiss.IndexRefine,
faiss.IndexPQ,
]
def fetch_codec(self):
raise NotImplementedError
def get_codec(self):
codec_name = self.get_codec_name()
if codec_name not in Index.cached_codec:
Index.cached_codec[codec_name], _, _ = self.fetch_codec()
if len(Index.cached_codec) > 1:
Index.cached_codec.popitem(last=False)
return Index.cached_codec[codec_name]
def get_model(self):
return self.get_index()
def get_model_name(self):
return self.get_index_name()
def get_codec_name(self) -> Optional[str]:
return self.codec_name
def get_index_name(self) -> Optional[str]:
return self.index_name
def fetch_index(self):
# read index from file if it is already available
index_filename = None
if self.index_path:
index_filename = os.path.basename(self.index_path)
elif self.index_name:
index_filename = self.index_name + "index"
if index_filename and self.io.file_exist(index_filename):
if self.index_path:
index = self.io.read_index(
index_filename,
self.bucket,
os.path.dirname(self.index_path),
)
else:
index = self.io.read_index(index_filename)
assert self.d == index.d
assert self.metric_type == index.metric_type
return index, 0
index = self.get_codec()
index.reset()
assert index.ntotal == 0
logger.info("Adding vectors to index")
xb = self.io.get_dataset(self.database_vectors)
if self.is_ivf() and not isinstance(index, faiss.IndexRefine):
xbt, QD, QI, QP = self.knn_search_quantizer(
query_vectors=self.database_vectors,
k=1,
)
index_ivf = faiss.extract_index_ivf(index)
if index_ivf.parallel_mode != 2:
logger.info("Setting IVF parallel mode")
index_ivf.parallel_mode = 2
_, t, _ = timer(
"add_preassigned",
lambda: add_preassigned(index_ivf, xbt, QI.ravel()),
once=True,
)
else:
_, t, _ = timer(
"add",
lambda: index.add(xb),
once=True,
)
assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
logger.info("Added vectors to index")
if self.serialize_full_index and index_filename:
codec_size = self.io.write_index(index, index_filename)
assert codec_size is not None
return index, t
def get_index(self):
index_name = self.index_name
# TODO(kuarora) : retrieve file from bucket and path.
if index_name not in Index.cached_index:
Index.cached_index[index_name], _ = self.fetch_index()
if len(Index.cached_index) > 3:
Index.cached_index.popitem(last=False)
return Index.cached_index[index_name]
def get_construction_params(self):
return self.construction_params
def get_code_size(self, codec=None):
def get_index_code_size(index):
index = faiss.downcast_index(index)
if isinstance(index, faiss.IndexPreTransform):
return get_index_code_size(index.index)
elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]:
return get_index_code_size(
index.base_index
) + get_index_code_size(index.refine_index)
else:
return index.code_size if hasattr(index, "code_size") else 0
if codec is None:
codec = self.get_codec()
return get_index_code_size(codec)
def get_sa_code_size(self, codec=None):
if codec is None:
codec = self.get_codec()
try:
return codec.sa_code_size()
except:
return None
def get_operating_points(self):
op = OperatingPointsWithRanges()
def add_range_or_val(name, range):
op.add_range(
name,
(
[self.search_params[name]]
if self.search_params and name in self.search_params
else range
),
)
add_range_or_val("snap", [0])
model = self.get_model()
model_ivf = faiss.try_extract_index_ivf(model)
if model_ivf is not None:
add_range_or_val(
"nprobe",
[2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5],
# [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [
# i
# for i in range(32, 64, 8)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(64, 128, 16)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(128, 256, 32)
# if i <= model_ivf.nlist * 0.1
# ] + [
# i
# for i in range(256, 512, 64)
# if i <= model_ivf.nlist * 0.1
# ] + [
# 2**i
# for i in range(9, 12)
# if 2**i <= model_ivf.nlist * 0.1
# ],
)
model = faiss.downcast_index(model)
if isinstance(model, faiss.IndexRefine):
add_range_or_val(
"k_factor",
[2**i for i in range(13)],
)
elif isinstance(model, faiss.IndexHNSWFlat):
add_range_or_val(
"efSearch",
[2**i for i in range(3, 11)],
)
elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(
model, faiss.IndexProductResidualQuantizer
):
add_range_or_val(
"max_beam_size",
[1, 2, 4, 8, 16, 32],
)
add_range_or_val(
"use_beam_LUT",
[1],
)
elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(
model, faiss.IndexProductLocalSearchQuantizer
):
add_range_or_val(
"encode_ils_iters",
[2, 4, 8, 16],
)
add_range_or_val(
"lsq_gpu",
[1],
)
return op
def is_flat_index(self):
return self.get_index_name().startswith("Flat")
# IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform
# are used to wrap pre-trained Faiss indices (codecs)
@dataclass
class IndexFromCodec(Index):
path: Optional[str] = None # remote or local path to the codec
def __post_init__(self):
super().__post_init__()
if self.path is None and self.codec_name is None:
raise ValueError("path or desc_name is not set")
def get_quantizer(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromQuantizer(self)
quantizer.set_io(self.io)
return quantizer
def get_pretransform(self):
if not self.is_ivf():
raise ValueError("Not an IVF index")
quantizer = IndexFromPreTransform(self)
quantizer.set_io(self.io)
return quantizer
def get_model_name(self):
if self.path is not None:
return os.path.basename(self.path)
else:
return self.get_codec_name()
def fetch_meta(self, dry_run=False):
return None, None
def fetch_codec(self):
if self.path is not None:
codec_filename = os.path.basename(self.path)
remote_path = os.path.dirname(self.path)
else:
codec_filename = self.get_codec_name() + "codec"
remote_path = None
codec = self.io.read_index(
codec_filename,
self.bucket,
remote_path,
)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
Index.set_index_param_dict_list(codec, self.construction_params)
return codec, None, None
def get_model(self):
return self.get_codec()
class IndexFromQuantizer(IndexBase):
ivf_index: Index
def __init__(self, ivf_index: Index):
self.ivf_index = ivf_index
super().__init__()
def get_model_name(self):
return self.get_index_name()
def get_codec_name(self):
return self.get_index_name()
def get_codec(self):
return self.get_index()
def get_index_name(self):
ivf_codec_name = self.ivf_index.get_codec_name()
return f"{ivf_codec_name}quantizer."
def get_index(self):
ivf_codec = faiss.extract_index_ivf(self.ivf_index.get_codec())
return ivf_codec.quantizer
class IndexFromPreTransform(IndexBase):
pre_transform_index: Index
def __init__(self, pre_transform_index: Index):
self.pre_transform_index = pre_transform_index
super().__init__()
def get_codec_name(self):
pre_transform_codec_name = self.pre_transform_index.get_codec_name()
return f"{pre_transform_codec_name}pretransform."
def get_codec(self):
return self.get_codec()
# IndexFromFactory is for creating and training indices from scratch
@dataclass
class IndexFromFactory(Index):
factory: Optional[str] = None
training_vectors: Optional[DatasetDescriptor] = None
assemble_opaque: bool = True
def __post_init__(self):
super().__post_init__()
if self.factory is None:
raise ValueError("factory is not set")
if self.factory != "Flat" and self.training_vectors is None:
raise ValueError(f"training_vectors is not set for {self.factory}")
def get_codec_name(self):
codec_name = super().get_codec_name()
if codec_name is None:
codec_name = f"{self.factory.replace(',', '_')}."
codec_name += f"d_{self.d}.{self.metric.upper()}."
if self.factory != "Flat":
assert self.training_vectors is not None
codec_name += self.training_vectors.get_filename("xt")
if self.construction_params is not None:
codec_name += IndexBaseDescriptor.param_dict_list_to_name(self.construction_params)
self.codec_name = codec_name
return self.codec_name
def fetch_meta(self, dry_run=False):
meta_filename = self.get_codec_name() + "json"
if self.io.file_exist(meta_filename):
meta = self.io.read_json(meta_filename)
report = None
else:
_, meta, report = self.fetch_codec(dry_run=dry_run)
return meta, report
def fetch_codec(self, dry_run=False):
codec_filename = self.get_codec_name() + "codec"
meta_filename = self.get_codec_name() + "json"
if self.io.file_exist(codec_filename) and self.io.file_exist(
meta_filename
):
codec = self.io.read_index(codec_filename)
assert self.d == codec.d
assert self.metric_type == codec.metric_type
meta = self.io.read_json(meta_filename)
else:
codec, training_time, requires = self.assemble(dry_run=dry_run)
if requires is not None:
assert dry_run
if requires == "":
return None, None, codec_filename
else:
return None, None, requires
codec_size = self.io.write_index(codec, codec_filename)
assert codec_size is not None
meta = {
"training_time": training_time,
"training_size": self.training_vectors.num_vectors if self.training_vectors else 0,
"codec_size": codec_size,
"sa_code_size": self.get_sa_code_size(codec),
"code_size": self.get_code_size(codec),
"cpu": get_cpu_info(),
}
self.io.write_json(meta, meta_filename, overwrite=True)
Index.set_index_param_dict_list(
codec, self.construction_params, assert_same=True
)
return codec, meta, None
def get_model_name(self):
return self.factory
def get_model(self):
model = faiss.index_factory(self.d, self.factory, self.metric_type)
Index.set_index_param_dict_list(model, self.construction_params)
return model
def get_pretransform(self):
model = self.get_model()
assert isinstance(model, faiss.IndexPreTransform)
sub_index = faiss.downcast_index(model.index)
if isinstance(sub_index, faiss.IndexFlat):
return self
# replace the sub-index with Flat
model.index = faiss.IndexFlat(model.index.d, model.index.metric_type)
pretransform = IndexFromFactory(
num_threads=self.num_threads,
d=model.d,
metric=model.metric_type,
database_vectors=self.database_vectors,
construction_params=self.construction_params,
search_params=None,
factory=reverse_index_factory(model),
training_vectors=self.training_vectors,
)
pretransform.set_io(self.io)
return pretransform
def get_quantizer(self, dry_run, pretransform=None):
model = self.get_model()
model_ivf = faiss.extract_index_ivf(model)
assert isinstance(model_ivf, faiss.IndexIVF)
assert ord(model_ivf.quantizer_trains_alone) in [0, 2]
if pretransform is None:
training_vectors = self.training_vectors
else:
training_vectors = pretransform.transform(self.training_vectors)
centroids, t, requires = training_vectors.k_means(
self.io, model_ivf.nlist, dry_run
)
if requires is not None:
return None, None, requires
quantizer = IndexFromFactory(
num_threads=self.num_threads,
d=model_ivf.quantizer.d,
metric=model_ivf.quantizer.metric_type,
database_vectors=centroids,
construction_params=(
self.construction_params[1:]
if self.construction_params is not None
else None
),
search_params=None,
factory=reverse_index_factory(model_ivf.quantizer),
training_vectors=centroids,
)
quantizer.set_io(self.io)
return quantizer, t, None
def assemble(self, dry_run):
logger.info(f"assemble {self.factory}")
model = self.get_model()
t_aggregate = 0
# try:
# reverse_index_factory(model)
# opaque = False
# except NotImplementedError:
# opaque = True
if self.assemble_opaque:
codec = model
else:
if isinstance(model, faiss.IndexPreTransform):
logger.info(f"assemble: pretransform {self.factory}")
sub_index = faiss.downcast_index(model.index)
if not isinstance(sub_index, faiss.IndexFlat):
# replace the sub-index with Flat and fetch pre-trained
pretransform = self.get_pretransform()
codec, meta, report = pretransform.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
assert codec.is_trained
transformed_training_vectors = pretransform.transform(
self.training_vectors
)
# replace the Flat index with the required sub-index
wrapper = IndexFromFactory(
num_threads=self.num_threads,
d=sub_index.d,
metric=sub_index.metric_type,
database_vectors=None,
construction_params=self.construction_params,
search_params=None,
factory=reverse_index_factory(sub_index),
training_vectors=transformed_training_vectors,
)
wrapper.set_io(self.io)
codec.index, meta, report = wrapper.fetch_codec(
dry_run=dry_run
)
if report is not None:
return None, None, report
t_aggregate += meta["training_time"]
assert codec.index.is_trained
else:
codec = model
elif isinstance(model, faiss.IndexIVF):
logger.info(f"assemble: ivf {self.factory}")
# replace the quantizer
quantizer, t, requires = self.get_quantizer(dry_run=dry_run)
if requires is not None:
return None, None, requires
t_aggregate += t
codec = faiss.clone_index(model)
quantizer_index, t = quantizer.fetch_index()
t_aggregate += t
replace_ivf_quantizer(codec, quantizer_index)
assert codec.quantizer.is_trained
assert codec.nlist == codec.quantizer.ntotal
elif isinstance(model, faiss.IndexRefine) or isinstance(
model, faiss.IndexRefineFlat
):
logger.info(f"assemble: refine {self.factory}")
# replace base_index
wrapper = IndexFromFactory(
num_threads=self.num_threads,
d=model.base_index.d,
metric=model.base_index.metric_type,
database_vectors=self.database_vectors,
construction_params=IndexBase.filter_index_param_dict_list(
self.construction_params
),
search_params=None,
factory=reverse_index_factory(model.base_index),
training_vectors=self.training_vectors,
)
wrapper.set_io(self.io)
codec = faiss.clone_index(model)
codec.base_index, meta, requires = wrapper.fetch_codec(
dry_run=dry_run
)
if requires is not None:
return None, None, requires
t_aggregate += meta["training_time"]
assert codec.base_index.is_trained
else:
codec = model
if self.factory != "Flat":
if dry_run:
return None, None, ""
logger.info(f"assemble, train {self.factory}")
xt = self.io.get_dataset(self.training_vectors)
_, t, _ = timer("train", lambda: codec.train(xt), once=True)
t_aggregate += t
return codec, t_aggregate, None