# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Any, Dict, List, Optional

import numpy as np
import yaml


OIVF_TEST_ARGS: List[str] = [
    "--config",
    "--xb",
    "--xq",
    "--command",
    "--cluster_run",
    "--no_residuals",
]

_DEFAULT_INDEX_FACTORY = "OPQ4,IVF256,PQ4"


def get_test_parser(args: List[str]) -> argparse.ArgumentParser:
    """Build a parser registering every flag in *args* with default settings.

    Each flag is added via ``add_argument(flag)`` with no extra options, so it
    takes a single string value and is ``None`` when not given.
    """
    parser = argparse.ArgumentParser()
    for arg in args:
        parser.add_argument(arg)
    return parser


class TestDataCreator:
    """Write random ``.npy`` data files and a matching YAML config for tests.

    Files and the config (``config_test.yaml``) are written under ``tempdir``.
    """

    def __init__(
        self,
        tempdir: str,
        dimension: int,
        data_type: np.dtype,
        index_factory: Optional[List[str]] = None,
        training_sample: Optional[int] = 9984,
        index_shard_size: Optional[int] = 1000,
        query_batch_size: Optional[int] = 1000,
        evaluation_sample: Optional[int] = 100,
        num_files: Optional[int] = None,
        file_size: Optional[int] = None,
        file_sizes: Optional[List] = None,
        nprobe: Optional[int] = 64,
        k: Optional[int] = 10,
        metric: Optional[str] = "METRIC_L2",
        normalise: Optional[bool] = False,
        with_queries_ds: Optional[bool] = False,
        evaluate_by_margin: Optional[bool] = False,
    ) -> None:
        """Store the test configuration.

        Args:
            tempdir: directory receiving the data files and the config.
            dimension: number of columns of every generated array.
            data_type: anything accepted by ``np.dtype``; stored by name.
            index_factory: index factory strings; defaults to
                ``["OPQ4,IVF256,PQ4"]`` (a fresh list per instance).
            num_files, file_size: when both are truthy, ``num_files`` files
                of ``file_size`` rows each are generated (takes precedence).
            file_sizes: explicit per-file row counts, used otherwise.

        Raises:
            ValueError: if neither ``num_files``/``file_size`` nor
                ``file_sizes`` describes the files to create.
        """
        self.tempdir = tempdir
        self.dimension = dimension
        self.data_type = np.dtype(data_type).name
        # A fresh list per instance: a shared mutable default would leak
        # mutations between TestDataCreator objects.
        if index_factory is None:
            index_factory = [_DEFAULT_INDEX_FACTORY]
        self.index_factory = {"prod": list(index_factory)}
        if file_size and num_files:
            self.file_sizes = [file_size] * num_files
        elif file_sizes:
            self.file_sizes = file_sizes
        else:
            raise ValueError("no file sizes provided")
        self.num_files = len(self.file_sizes)
        self.training_sample = training_sample
        self.index_shard_size = index_shard_size
        self.query_batch_size = query_batch_size
        self.evaluation_sample = evaluation_sample
        self.nprobe = {"prod": [nprobe]}
        self.k = k
        self.metric = metric
        self.normalise = normalise
        self.config_file = self.tempdir + "/config_test.yaml"
        self.ds_name = "my_test_data"
        self.qs_name = "my_queries_data"
        self.evaluate_by_margin = evaluate_by_margin
        self.with_queries_ds = with_queries_ds

    def create_test_data(self) -> None:
        """Write the database files, optional query files and the YAML config.

        Query files (``my_q_data*.npy``) are generated only when
        ``with_queries_ds`` is set, since only then are they referenced.
        """
        datafiles = self._create_data_files()
        config_for_yaml: Dict[str, Any] = {
            "d": self.dimension,
            "output": self.tempdir,
            "index": self.index_factory,
            "nprobe": self.nprobe,
            "k": self.k,
            "normalise": self.normalise,
            "metric": self.metric,
            "training_sample": self.training_sample,
            "evaluation_sample": self.evaluation_sample,
            "index_shard_size": self.index_shard_size,
            "query_batch_size": self.query_batch_size,
            "datasets": {self.ds_name: self._dataset_entry(datafiles)},
        }
        if self.evaluate_by_margin:
            config_for_yaml["evaluate_by_margin"] = self.evaluate_by_margin
        if self.with_queries_ds:
            q_datafiles = self._create_data_files("my_q_data")
            config_for_yaml["datasets"][self.qs_name] = self._dataset_entry(
                q_datafiles
            )
        self._create_config_yaml(config_for_yaml)

    def _dataset_entry(self, filenames: List[str]) -> Dict[str, Any]:
        """Describe a dataset made of *filenames*, sized by ``file_sizes``."""
        return {
            "root": self.tempdir,
            "size": sum(self.file_sizes),
            "files": [
                {
                    "dtype": self.data_type,
                    "format": "npy",
                    "name": name,
                    "size": size,
                }
                for name, size in zip(filenames, self.file_sizes)
            ],
        }

    def setup_cli(self, command="consistency_check") -> argparse.Namespace:
        """Parse a CLI invocation targeting the generated config and datasets.

        ``--xq`` is added only when a separate queries dataset exists.
        """
        parser = get_test_parser(OIVF_TEST_ARGS)
        cli_args = [
            "--xb",
            self.ds_name,
            "--config",
            self.config_file,
            "--command",
            command,
        ]
        if self.with_queries_ds:
            cli_args += ["--xq", self.qs_name]
        return parser.parse_args(cli_args)

    def _create_data_files(self, name_of_file="my_data") -> List[str]:
        """Write one random ``.npy`` array per entry of ``file_sizes``.

        File ``i`` is named ``name_of_file`` + two-digit zero-padded ``i`` +
        ``.npy`` (e.g. ``my_data00.npy``), holds a ``(file_size, dimension)``
        array of ``np.random.random`` values cast to ``data_type``, and is
        saved under ``tempdir``.

        Returns:
            The file names (relative to ``tempdir``), in order.
        """
        filenames = []
        for i, file_size in enumerate(self.file_sizes):
            db_vectors = np.random.random((file_size, self.dimension)).astype(
                self.data_type
            )
            filename = f"{name_of_file}{i:02}.npy"
            filenames.append(filename)
            np.save(self.tempdir + "/" + filename, db_vectors)
        return filenames

    def _create_config_yaml(self, dict_file: Dict[str, Any]) -> None:
        """Dump *dict_file* as block-style YAML to ``self.config_file``."""
        # Use the same path setup_cli() hands to --config so they never diverge.
        with open(self.config_file, "w", encoding="utf-8") as file:
            yaml.dump(dict_file, file, default_flow_style=False)