Refactor bench_fw to support train, build & search in parallel (#3527)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3527 **Context** Design Doc: [Faiss Benchmarking](https://docs.google.com/document/d/1c7zziITa4RD6jZsbG9_yOgyRjWdyueldSPH6QdZzL98/edit) **In this diff** 1. Codecs and indexes can be referenced from the blobstore (bucket & path) outside the experiment. 2. To support #1, naming is moved to descriptors. 3. Built indexes can be written as well. 4. You can run a benchmark with train, then refer to the trained codec in index build, and then refer to the built index in knn search. Index serialization is optional, although it is not yet exposed through the index descriptor. 5. The benchmark can support indexes with different dataset sizes. 6. Working with varying datasets now supports multiple ground truths. There may be small fixes needed before we can use this. 7. Added targets for bench_fw_range, ivf, codecs and optimize. **Analysis of ivf result**: D58823037 Reviewed By: algoriddle Differential Revision: D57236543 fbshipit-source-id: ad03b28bae937a35f8c20f12e0a5b0a27c34ff3b pull/3533/head
parent
3a7c718ace
commit
da75d03442
File diff suppressed because it is too large
Load Diff
|
@ -53,6 +53,7 @@ class BenchmarkIO:
|
|||
def __post_init__(self):
|
||||
self.cached_ds = {}
|
||||
|
||||
# TODO(kuarora): rename it as get_local_file
|
||||
def get_local_filename(self, filename):
|
||||
if len(filename) > 184:
|
||||
fn, ext = os.path.splitext(filename)
|
||||
|
@ -61,6 +62,9 @@ class BenchmarkIO:
|
|||
)
|
||||
return os.path.join(self.path, filename)
|
||||
|
||||
def get_remote_filepath(self, filename) -> Optional[str]:
    """Return the remote file path for ``filename``.

    This implementation does not consult ``filename`` and always
    returns ``None``.
    """
    return None
|
||||
|
||||
def download_file_from_blobstore(
|
||||
self,
|
||||
filename: str,
|
||||
|
@ -219,7 +223,7 @@ class BenchmarkIO:
|
|||
fn = self.download_file_from_blobstore(filename, bucket, path)
|
||||
logger.info(f"Loading index {fn}")
|
||||
ext = os.path.splitext(fn)[1]
|
||||
if ext in [".faiss", ".codec"]:
|
||||
if ext in [".faiss", ".codec", ".index"]:
|
||||
index = faiss.read_index(fn)
|
||||
elif ext == ".pkl":
|
||||
with open(fn, "rb") as model_file:
|
||||
|
|
|
@ -3,18 +3,21 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import faiss # @manual=//faiss/python:pyfaiss_gpu
|
||||
|
||||
from .benchmark_io import BenchmarkIO
|
||||
from .utils import timer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexDescriptor:
|
||||
class IndexDescriptorClassic:
|
||||
bucket: Optional[str] = None
|
||||
# either path or factory should be set,
|
||||
# but not both at the same time.
|
||||
|
@ -45,7 +48,6 @@ class IndexDescriptor:
|
|||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetDescriptor:
|
||||
# namespace possible values:
|
||||
|
@ -81,7 +83,7 @@ class DatasetDescriptor:
|
|||
|
||||
def get_filename(
|
||||
self,
|
||||
prefix: str = None,
|
||||
prefix: Optional[str] = None,
|
||||
) -> str:
|
||||
filename = ""
|
||||
if prefix is not None:
|
||||
|
@ -116,3 +118,208 @@ class DatasetDescriptor:
|
|||
else:
|
||||
t = io.read_json(meta_filename)["k_means_time"]
|
||||
return kmeans_vectors, t, None
|
||||
|
||||
@dataclass
class IndexBaseDescriptor:
    """Fields and naming helpers shared by the codec, index and knn descriptors.

    ``desc_name`` and ``flat_desc_name`` cache the names computed by the
    subclasses' ``get_name`` / ``flat_name``; ``bucket`` and ``path`` locate
    the described object in remote storage.
    """

    d: int
    metric: str
    desc_name: Optional[str] = None
    flat_desc_name: Optional[str] = None
    bucket: Optional[str] = None
    path: Optional[str] = None
    num_threads: int = 1

    def get_name(self) -> str:
        """Return the descriptor name; every subclass must override this."""
        raise NotImplementedError()

    def get_path(self, benchmark_io: "BenchmarkIO") -> Optional[str]:
        """Return the remote path of this descriptor, resolving it on first use.

        When ``self.path`` is unset, the result of
        ``benchmark_io.get_remote_filepath(self.desc_name)`` is stored on
        ``self.path`` (even if it is ``None``) and returned.
        """
        if self.path is None:
            self.path = benchmark_io.get_remote_filepath(self.desc_name)
        return self.path

    @staticmethod
    def param_dict_list_to_name(param_dict_list):
        """Concatenate the name fragments of a list of parameter dicts.

        The dict at position ``i`` is rendered with the prefix ``cp<i>``.
        Empty or ``None`` entries contribute nothing but still consume their
        position, so the remaining prefixes keep their original indices.
        """
        if not param_dict_list:
            return ""
        return "".join(
            IndexBaseDescriptor.param_dict_to_name(param_dict, f"cp{position}")
            for position, param_dict in enumerate(param_dict_list)
        )

    @staticmethod
    def param_dict_to_name(param_dict, prefix="sp"):
        """Render ``param_dict`` as ``<prefix>_<name>_<val>...`` plus a final ".".

        ``snap`` is always skipped; ``lsq_gpu`` and ``use_beam_LUT`` are
        skipped when their value equals 0. Returns "" when the dict is empty
        or every entry was skipped.
        """
        if not param_dict:
            return ""
        fragments = []
        for name, val in param_dict.items():
            if name == "snap":
                continue
            if name in ("lsq_gpu", "use_beam_LUT") and val == 0:
                continue
            fragments.append(f"_{name}_{val}")
        if not fragments:
            return ""
        return prefix + "".join(fragments) + "."
|
||||
|
||||
|
||||
@dataclass
class CodecDescriptor(IndexBaseDescriptor):
    """Describes a codec by a factory string or by the path of a stored file.

    Either ``path`` or ``factory`` should be set, but not both at the same
    time.
    """

    factory: Optional[str] = None
    construction_params: Optional[List[Dict[str, int]]] = None
    training_vectors: Optional[DatasetDescriptor] = None

    def __post_init__(self):
        # Resolve and cache the name eagerly so that a descriptor that cannot
        # be named fails at construction time.
        self.get_name()

    def is_trained(self):
        """Return True when only ``path`` (and no ``factory``) is set."""
        return self.factory is None and self.path is not None

    def is_valid(self):
        """Return True when at least one of ``factory`` / ``path`` is set."""
        return self.factory is not None or self.path is not None

    def get_name(self) -> str:
        """Return the codec name, deriving and caching it on first use.

        An explicit ``desc_name`` wins; otherwise the name is derived from
        ``factory``, and failing that from ``path``.

        Raises:
            ValueError: if none of ``desc_name``, ``factory``, ``path`` is set.
        """
        if self.desc_name is None:
            if self.factory is not None:
                self.desc_name = self.name_from_factory()
            elif self.path is not None:
                self.desc_name = self.name_from_path()
            else:
                raise ValueError("name, factory or path must be set")
        return self.desc_name

    def flat_name(self) -> str:
        """Return (and cache) the name of the Flat codec with the same d/metric."""
        if self.flat_desc_name is None:
            self.flat_desc_name = f"Flat.d_{self.d}.{self.metric.upper()}."
        return self.flat_desc_name

    # NOTE(review): the inherited dataclass field ``path`` is assigned on
    # every instance by ``__init__``, so this method is shadowed and cannot be
    # reached through an instance; ``get_path`` is the usable accessor.
    # Kept unchanged to preserve the class interface.
    def path(self, benchmark_io) -> str:
        if self.path is not None:
            return self.path
        return benchmark_io.get_remote_filepath(self.get_name())

    def name_from_factory(self) -> str:
        """Build the name from the factory string, d, metric and training data.

        Commas in the factory string become underscores. Any factory other
        than "Flat" also contributes the training vectors' filename (prefix
        "xt"). The construction parameters are appended last.

        Raises:
            ValueError: if a non-"Flat" factory has no ``training_vectors``.
        """
        assert self.factory is not None
        assert self.d is not None
        assert self.metric is not None
        name = f"{self.factory.replace(',', '_')}."
        name += f"d_{self.d}.{self.metric.upper()}."
        if self.factory != "Flat":
            # Raise instead of assert: asserts vanish under ``python -O``.
            if self.training_vectors is None:
                raise ValueError(
                    f"training_vectors is not set for {self.factory}"
                )
            name += self.training_vectors.get_filename("xt")
        name += IndexBaseDescriptor.param_dict_list_to_name(
            self.construction_params
        )
        return name

    def name_from_path(self):
        """Return the basename of ``path`` with its extension removed.

        The "." that preceded the extension is kept, e.g. "a/b/x.y.codec"
        gives "x.y.". A basename without any "." is returned unchanged.
        """
        assert self.path is not None
        filename = os.path.basename(self.path)
        stem, dot, _ext = filename.rpartition(".")
        if not dot:
            # No extension to strip; keep the whole basename instead of
            # collapsing the name to an empty string.
            return filename
        return stem + dot

    def alias(self, benchmark_io: BenchmarkIO):
        """Return a descriptor that refers to this codec by name only.

        When ``benchmark_io`` has a ``bucket`` attribute the alias also
        carries that bucket and this descriptor's remote path.
        """
        if hasattr(benchmark_io, "bucket"):
            return CodecDescriptor(
                desc_name=self.get_name(),
                bucket=benchmark_io.bucket,
                path=self.get_path(benchmark_io),
                d=self.d,
                metric=self.metric,
            )
        return CodecDescriptor(
            desc_name=self.get_name(),
            d=self.d,
            metric=self.metric,
        )
|
||||
|
||||
|
||||
@dataclass
class IndexDescriptor(IndexBaseDescriptor):
    """Describes an index as a codec plus the database vectors added to it.

    A descriptor with neither ``codec_desc`` nor ``database_desc`` refers to
    an index by ``desc_name`` alone (see ``alias``).
    """

    codec_desc: Optional[CodecDescriptor] = None
    database_desc: Optional[DatasetDescriptor] = None

    def __hash__(self):
        return hash(str(self))

    def __post_init__(self):
        # Resolve and cache the name eagerly so that a descriptor that cannot
        # be named fails at construction time.
        self.get_name()

    def is_built(self):
        """Return True when neither ``codec_desc`` nor ``database_desc`` is set."""
        return self.codec_desc is None and self.database_desc is None

    def _require_components(self, name_field):
        """Return ``(codec_desc, database_desc)`` or raise if either is missing."""
        if self.codec_desc is None or self.database_desc is None:
            raise ValueError(
                f"{name_field}, or both codec_desc and database_desc, "
                "must be set"
            )
        return self.codec_desc, self.database_desc

    def get_name(self) -> str:
        """Return the index name, deriving and caching it on first use.

        The derived name is the codec name followed by the database
        vectors' filename (prefix "xb").

        Raises:
            ValueError: if ``desc_name`` is unset and ``codec_desc`` or
                ``database_desc`` is missing.
        """
        if self.desc_name is None:
            codec_desc, database_desc = self._require_components("desc_name")
            self.desc_name = (
                codec_desc.get_name()
                + database_desc.get_filename(prefix="xb")
            )
        return self.desc_name

    def flat_name(self):
        """Return (and cache) the flat codec name plus the database filename.

        Raises:
            ValueError: if ``flat_desc_name`` is unset and ``codec_desc`` or
                ``database_desc`` is missing.
        """
        if self.flat_desc_name is None:
            codec_desc, database_desc = self._require_components(
                "flat_desc_name"
            )
            self.flat_desc_name = (
                codec_desc.flat_name()
                + database_desc.get_filename(prefix="xb")
            )
        return self.flat_desc_name

    # alias is used to refer to an index after it has been uploaded to the
    # blobstore and is referred to again
    def alias(self, benchmark_io: BenchmarkIO):
        """Return a descriptor that refers to this index by name only.

        When ``benchmark_io`` has a ``bucket`` attribute the alias also
        carries that bucket and this descriptor's remote path.
        """
        if hasattr(benchmark_io, "bucket"):
            return IndexDescriptor(
                desc_name=self.get_name(),
                bucket=benchmark_io.bucket,
                path=self.get_path(benchmark_io),
                d=self.d,
                metric=self.metric,
            )
        return IndexDescriptor(
            desc_name=self.get_name(),
            d=self.d,
            metric=self.metric,
        )
|
||||
|
||||
@dataclass
class KnnDescriptor(IndexBaseDescriptor):
    """Describes one search run: an index, a query dataset and search settings."""

    index_desc: Optional[IndexDescriptor] = None
    gt_index_desc: Optional[IndexDescriptor] = None
    query_dataset: Optional[DatasetDescriptor] = None
    search_params: Optional[Dict[str, int]] = None
    reconstruct: bool = False
    # range metric definitions
    # key: name
    # value: one of the following:
    #
    # radius
    #    [0..radius) -> 1
    #    [radius..inf) -> 0
    #
    # [[radius1, score1], ...]
    #    [0..radius1) -> score1
    #    [radius1..radius2) -> score2
    #
    # [[radius1_from, radius1_to, score1], ...]
    #    [radius1_from, radius1_to) -> score1,
    #    [radius2_from, radius2_to) -> score2
    range_metrics: Optional[Dict[str, Any]] = None
    radius: Optional[float] = None
    k: int = 1

    range_ref_index_desc: Optional[str] = None

    def __hash__(self):
        return hash(str(self))

    def _query_suffix(self):
        """Return the name part shared by ``get_name`` and ``flat_name``.

        It is the query dataset filename (prefix "q"), then ``k``, the
        thread count, and "rec." or "knn." depending on ``reconstruct``.
        """
        mode = "rec." if self.reconstruct else "knn."
        return (
            self.query_dataset.get_filename("q")
            + f"k_{self.k}."
            + f"t_{self.num_threads}."
            + mode
        )

    def get_name(self):
        """Return the experiment name; it is recomputed on every call."""
        pieces = [
            self.index_desc.get_name(),
            IndexBaseDescriptor.param_dict_to_name(self.search_params),
            self._query_suffix(),
        ]
        return "".join(pieces)

    def flat_name(self):
        """Return (and cache) the name of the matching Flat-index experiment."""
        if self.flat_desc_name is None:
            self.flat_desc_name = (
                self.index_desc.flat_name() + self._query_suffix()
            )
        return self.flat_desc_name
|
||||
|
|
|
@ -13,6 +13,7 @@ from typing import ClassVar, Dict, List, Optional
|
|||
|
||||
import faiss # @manual=//faiss/python:pyfaiss_gpu
|
||||
import numpy as np
|
||||
from faiss.benchs.bench_fw.descriptors import IndexBaseDescriptor
|
||||
|
||||
from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
|
||||
knn_intersection_measure,
|
||||
|
@ -49,35 +50,6 @@ class IndexBase:
|
|||
def set_io(self, benchmark_io):
|
||||
self.io = benchmark_io
|
||||
|
||||
@staticmethod
|
||||
def param_dict_list_to_name(param_dict_list):
|
||||
if not param_dict_list:
|
||||
return ""
|
||||
l = 0
|
||||
n = ""
|
||||
for param_dict in param_dict_list:
|
||||
n += IndexBase.param_dict_to_name(param_dict, f"cp{l}")
|
||||
l += 1
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def param_dict_to_name(param_dict, prefix="sp"):
|
||||
if not param_dict:
|
||||
return ""
|
||||
n = prefix
|
||||
for name, val in param_dict.items():
|
||||
if name == "snap":
|
||||
continue
|
||||
if name == "lsq_gpu" and val == 0:
|
||||
continue
|
||||
if name == "use_beam_LUT" and val == 0:
|
||||
continue
|
||||
n += f"_{name}_{val}"
|
||||
if n == prefix:
|
||||
return ""
|
||||
n += "."
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def set_index_param_dict_list(index, param_dict_list, assert_same=False):
|
||||
if not param_dict_list:
|
||||
|
@ -282,7 +254,7 @@ class IndexBase:
|
|||
reconstruct: bool = False,
|
||||
):
|
||||
name = self.get_index_name()
|
||||
name += Index.param_dict_to_name(search_parameters)
|
||||
name += IndexBaseDescriptor.param_dict_to_name(search_parameters)
|
||||
name += query_vectors.get_filename("q")
|
||||
name += f"k_{k}."
|
||||
name += f"t_{self.num_threads}."
|
||||
|
@ -582,14 +554,21 @@ class Index(IndexBase):
|
|||
num_threads: int
|
||||
d: int
|
||||
metric: str
|
||||
database_vectors: DatasetDescriptor
|
||||
construction_params: List[Dict[str, int]]
|
||||
search_params: Dict[str, int]
|
||||
codec_name: Optional[str] = None
|
||||
index_name: Optional[str] = None
|
||||
database_vectors: Optional[DatasetDescriptor] = None
|
||||
construction_params: Optional[List[Dict[str, int]]] = None
|
||||
search_params: Optional[Dict[str, int]] = None
|
||||
serialize_full_index: bool = False
|
||||
|
||||
bucket: Optional[str] = None
|
||||
index_path: Optional[str] = None
|
||||
|
||||
cached_codec: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
|
||||
cached_index: ClassVar[OrderedDict[str, faiss.Index]] = OrderedDict()
|
||||
|
||||
def __post_init__(self):
|
||||
logger.info(f"Initializing metric_type to {self.metric}")
|
||||
if isinstance(self.metric, str):
|
||||
if self.metric == "IP":
|
||||
self.metric_type = faiss.METRIC_INNER_PRODUCT
|
||||
|
@ -628,13 +607,31 @@ class Index(IndexBase):
|
|||
Index.cached_codec.popitem(last=False)
|
||||
return Index.cached_codec[codec_name]
|
||||
|
||||
def get_index_name(self):
|
||||
name = self.get_codec_name()
|
||||
assert self.database_vectors is not None
|
||||
name += self.database_vectors.get_filename("xb")
|
||||
return name
|
||||
def get_codec_name(self) -> Optional[str]:
    """Return the ``codec_name`` stored on this index (``None`` when unset)."""
    return self.codec_name
|
||||
|
||||
def get_index_name(self) -> Optional[str]:
    """Return the ``index_name`` stored on this index (``None`` when unset)."""
    return self.index_name
|
||||
|
||||
def fetch_index(self):
|
||||
# read index from file if it is already available
|
||||
if self.index_path:
|
||||
index_filename = os.path.basename(self.index_path)
|
||||
else:
|
||||
index_filename = self.index_name + "index"
|
||||
if self.io.file_exist(index_filename):
|
||||
if self.index_path:
|
||||
index = self.io.read_index(
|
||||
index_filename,
|
||||
self.bucket,
|
||||
os.path.dirname(self.index_path),
|
||||
)
|
||||
else:
|
||||
index = self.io.read_index(index_filename)
|
||||
assert self.d == index.d
|
||||
assert self.metric_type == index.metric_type
|
||||
return index, 0
|
||||
|
||||
index = self.get_codec()
|
||||
index.reset()
|
||||
assert index.ntotal == 0
|
||||
|
@ -664,10 +661,15 @@ class Index(IndexBase):
|
|||
)
|
||||
assert index.ntotal == xb.shape[0] or index_ivf.ntotal == xb.shape[0]
|
||||
logger.info("Added vectors to index")
|
||||
if self.serialize_full_index:
|
||||
codec_size = self.io.write_index(index, index_filename)
|
||||
assert codec_size is not None
|
||||
|
||||
return index, t
|
||||
|
||||
def get_index(self):
|
||||
index_name = self.get_index_name()
|
||||
index_name = self.index_name
|
||||
# TODO(kuarora) : retrieve file from bucket and path.
|
||||
if index_name not in Index.cached_index:
|
||||
Index.cached_index[index_name], _ = self.fetch_index()
|
||||
if len(Index.cached_index) > 3:
|
||||
|
@ -784,8 +786,12 @@ class Index(IndexBase):
|
|||
# are used to wrap pre-trained Faiss indices (codecs)
|
||||
@dataclass
|
||||
class IndexFromCodec(Index):
|
||||
path: str
|
||||
bucket: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
    """Run the parent initialisation, then require that ``path`` is set.

    Raises:
        ValueError: if ``path`` is ``None``.
    """
    super().__post_init__()
    if self.path is not None:
        return
    raise ValueError("path is not set")
|
||||
|
||||
def get_quantizer(self):
|
||||
if not self.is_ivf():
|
||||
|
@ -804,12 +810,6 @@ class IndexFromCodec(Index):
|
|||
def get_model_name(self):
|
||||
return os.path.basename(self.path)
|
||||
|
||||
def get_codec_name(self):
|
||||
assert self.path is not None
|
||||
name = os.path.basename(self.path)
|
||||
name += Index.param_dict_list_to_name(self.construction_params)
|
||||
return name
|
||||
|
||||
def fetch_meta(self, dry_run=False):
|
||||
return None, None
|
||||
|
||||
|
@ -871,20 +871,15 @@ class IndexFromPreTransform(IndexBase):
|
|||
# IndexFromFactory is for creating and training indices from scratch
|
||||
@dataclass
|
||||
class IndexFromFactory(Index):
|
||||
factory: str
|
||||
training_vectors: DatasetDescriptor
|
||||
factory: Optional[str] = None
|
||||
training_vectors: Optional[DatasetDescriptor] = None
|
||||
|
||||
def get_codec_name(self):
|
||||
assert self.factory is not None
|
||||
name = f"{self.factory.replace(',', '_')}."
|
||||
assert self.d is not None
|
||||
assert self.metric is not None
|
||||
name += f"d_{self.d}.{self.metric.upper()}."
|
||||
if self.factory != "Flat":
|
||||
assert self.training_vectors is not None
|
||||
name += self.training_vectors.get_filename("xt")
|
||||
name += Index.param_dict_list_to_name(self.construction_params)
|
||||
return name
|
||||
def __post_init__(self):
    """Run the parent initialisation, then validate the factory settings.

    Raises:
        ValueError: if ``factory`` is ``None``, or if a factory other than
            "Flat" has no ``training_vectors``.
    """
    super().__post_init__()
    if self.factory is None:
        raise ValueError("factory is not set")
    needs_training = self.factory != "Flat"
    if needs_training and self.training_vectors is None:
        raise ValueError(f"training_vectors is not set for {self.factory}")
|
||||
|
||||
def fetch_meta(self, dry_run=False):
|
||||
meta_filename = self.get_codec_name() + "json"
|
||||
|
|
|
@ -14,7 +14,7 @@ import faiss # @manual=//faiss/python:pyfaiss_gpu
|
|||
# )
|
||||
|
||||
from .benchmark import Benchmark
|
||||
from .descriptors import DatasetDescriptor, IndexDescriptor
|
||||
from .descriptors import DatasetDescriptor, IndexDescriptorClassic
|
||||
from .utils import dict_merge, filter_results, ParetoMetric, ParetoMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -78,7 +78,7 @@ class Optimizer:
|
|||
)
|
||||
assert filtered
|
||||
index_descs = [
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=v["factory"],
|
||||
construction_params=v["construction_params"],
|
||||
search_params=v["search_params"],
|
||||
|
@ -103,8 +103,8 @@ class Optimizer:
|
|||
dry_run=False,
|
||||
)
|
||||
|
||||
descs = [IndexDescriptor(factory="Flat"),] + [
|
||||
IndexDescriptor(
|
||||
descs = [IndexDescriptorClassic(factory="Flat"),] + [
|
||||
IndexDescriptorClassic(
|
||||
factory="HNSW32",
|
||||
construction_params=[{"efConstruction": 2**i}],
|
||||
)
|
||||
|
@ -131,7 +131,7 @@ class Optimizer:
|
|||
training_vectors: DatasetDescriptor,
|
||||
database_vectors: DatasetDescriptor,
|
||||
query_vectors: DatasetDescriptor,
|
||||
quantizers: Dict[int, List[IndexDescriptor]],
|
||||
quantizers: Dict[int, List[IndexDescriptorClassic]],
|
||||
codecs: List[Tuple[str, str]],
|
||||
min_accuracy: float,
|
||||
):
|
||||
|
@ -159,7 +159,7 @@ class Optimizer:
|
|||
quantizer_desc.search_params,
|
||||
)
|
||||
ivf_descs.append(
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"{pretransform}IVF{nlist}({quantizer_desc.factory}),{fine_ivf}",
|
||||
construction_params=construction_params,
|
||||
)
|
||||
|
@ -188,7 +188,7 @@ class Optimizer:
|
|||
):
|
||||
_, results = self.benchmark_and_filter_candidates(
|
||||
index_descs=[
|
||||
IndexDescriptor(factory=f"IVF{nlist}(Flat),Flat"),
|
||||
IndexDescriptorClassic(factory=f"IVF{nlist}(Flat),Flat"),
|
||||
],
|
||||
training_vectors=training_vectors,
|
||||
database_vectors=database_vectors,
|
||||
|
@ -255,7 +255,7 @@ class Optimizer:
|
|||
|
||||
_, filtered = self.benchmark_and_filter_candidates(
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{nlist},{pq}"
|
||||
if opq is None
|
||||
else f"{opq},IVF{nlist},{pq}",
|
||||
|
|
|
@ -7,10 +7,10 @@ import logging
|
|||
import argparse
|
||||
import os
|
||||
|
||||
from bench_fw.benchmark import Benchmark
|
||||
from bench_fw.benchmark_io import BenchmarkIO
|
||||
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor
|
||||
from bench_fw.index import IndexFromFactory
|
||||
from faiss.benchs.bench_fw.benchmark import Benchmark
|
||||
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO
|
||||
from faiss.benchs.bench_fw.descriptors import DatasetDescriptor, IndexDescriptorClassic
|
||||
from faiss.benchs.bench_fw.index import IndexFromFactory
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
@ -107,7 +107,7 @@ def run_local(rp):
|
|||
database_vectors=database_vectors,
|
||||
query_vectors=query_vectors,
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=factory,
|
||||
construction_params=construction_params,
|
||||
training_size=training_size,
|
||||
|
|
|
@ -11,7 +11,7 @@ from faiss.benchs.bench_fw.benchmark import Benchmark
|
|||
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO
|
||||
from faiss.benchs.bench_fw.descriptors import (
|
||||
DatasetDescriptor,
|
||||
IndexDescriptor,
|
||||
IndexDescriptorClassic,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
@ -30,7 +30,7 @@ def sift1M(bio):
|
|||
namespace="std_q", tablename="sift1M"
|
||||
),
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{2 ** nlist},Flat",
|
||||
)
|
||||
for nlist in range(8, 15)
|
||||
|
@ -38,8 +38,8 @@ def sift1M(bio):
|
|||
k=1,
|
||||
distance_metric="L2",
|
||||
)
|
||||
benchmark.set_io(bio)
|
||||
benchmark.benchmark(result_file="result.json", local=False, train=True, reconstruct=False, knn=True, range=False)
|
||||
benchmark.io = bio
|
||||
benchmark.benchmark(result_file="result.json", local=True, train=True, reconstruct=False, knn=True, range=False)
|
||||
|
||||
|
||||
def bigann(bio):
|
||||
|
@ -56,11 +56,11 @@ def bigann(bio):
|
|||
namespace="std_q", tablename="bigann1M"
|
||||
),
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{2 ** nlist},Flat",
|
||||
) for nlist in range(11, 19)
|
||||
] + [
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{2 ** nlist}_HNSW32,Flat",
|
||||
construction_params=[None, {"efConstruction": 200, "efSearch": 40}],
|
||||
) for nlist in range(11, 19)
|
||||
|
@ -84,18 +84,18 @@ def ssnpp(bio):
|
|||
tablename="ssnpp_queries_10K.npy"
|
||||
),
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{2 ** nlist},PQ256x4fs,Refine(SQfp16)",
|
||||
) for nlist in range(9, 16)
|
||||
] + [
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"IVF{2 ** nlist},Flat",
|
||||
) for nlist in range(9, 16)
|
||||
] + [
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"PQ256x4fs,Refine(SQfp16)",
|
||||
),
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory=f"HNSW32",
|
||||
),
|
||||
],
|
||||
|
|
|
@ -7,9 +7,9 @@ import argparse
|
|||
import logging
|
||||
import os
|
||||
|
||||
from bench_fw.benchmark_io import BenchmarkIO
|
||||
from bench_fw.descriptors import DatasetDescriptor
|
||||
from bench_fw.optimize import Optimizer
|
||||
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO
|
||||
from faiss.benchs.bench_fw.descriptors import DatasetDescriptor
|
||||
from faiss.benchs.bench_fw.optimize import Optimizer
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
|
|
@ -3,28 +3,29 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
from bench_fw.benchmark import Benchmark
|
||||
from bench_fw.benchmark_io import BenchmarkIO
|
||||
from bench_fw.descriptors import DatasetDescriptor, IndexDescriptor
|
||||
from faiss.benchs.bench_fw.benchmark import Benchmark
|
||||
from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO
|
||||
from faiss.benchs.bench_fw.descriptors import DatasetDescriptor, IndexDescriptorClassic
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def ssnpp(bio):
|
||||
benchmark = Benchmark(
|
||||
num_threads=32,
|
||||
training_vectors=DatasetDescriptor(
|
||||
tablename="ssnpp_training_5M.npy",
|
||||
tablename="training.npy",
|
||||
),
|
||||
database_vectors=DatasetDescriptor(
|
||||
tablename="ssnpp_xb_range_filtered_119201.npy",
|
||||
tablename="database.npy",
|
||||
),
|
||||
query_vectors=DatasetDescriptor(tablename="ssnpp_xq_range_filtered_33615.npy"),
|
||||
query_vectors=DatasetDescriptor(tablename="query.npy"),
|
||||
index_descs=[
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory="Flat",
|
||||
range_metrics={
|
||||
"weighted": [
|
||||
|
@ -56,7 +57,7 @@ def ssnpp(bio):
|
|||
]
|
||||
},
|
||||
),
|
||||
IndexDescriptor(
|
||||
IndexDescriptorClassic(
|
||||
factory="IVF262144(PQ256x4fs),PQ32",
|
||||
),
|
||||
],
|
||||
|
@ -67,6 +68,7 @@ def ssnpp(bio):
|
|||
benchmark.set_io(bio)
|
||||
benchmark.benchmark("result.json", local=False, train=True, reconstruct=False, knn=False, range=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('experiment')
|
||||
|
|
Loading…
Reference in New Issue