faiss/demos/offline_ivf/run.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from utils import (
    load_config,
    add_group_args,
)
from offline_ivf import OfflineIVF
import faiss
from typing import Any, Callable, Dict, List
import submitit


def join_lists_in_dict(poss: Dict[str, List[str]]) -> List[str]:
    """
    Joins the "non-prod" and "prod" lists of values into a single list,
    appending the "prod" values unless the last one is already present in
    the "non-prod" list. If there is no "non-prod" list, returns the "prod"
    list unchanged.
    """
    if "non-prod" in poss.keys():
        all_poss = poss["non-prod"]
        if poss["prod"][-1] not in poss["non-prod"]:
            all_poss += poss["prod"]
        return all_poss
    else:
        return poss["prod"]
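
# Example with hypothetical index factory strings:
#   join_lists_in_dict({"prod": ["IVF8192,PQ4"], "non-prod": ["IVF16384,PQ4"]})
# returns ["IVF16384,PQ4", "IVF8192,PQ4"], while
#   join_lists_in_dict({"prod": ["IVF8192,PQ4"]})
# returns ["IVF8192,PQ4"].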


def main(
    args: argparse.Namespace,
    cfg: Dict[str, Any],
    nprobe: int,
    index_factory_str: str,
) -> None:
    oivf = OfflineIVF(cfg, args, nprobe, index_factory_str)
    # Dispatch to the OfflineIVF method named by --command.
    getattr(oivf, args.command)()


def process_options_and_run_jobs(args: argparse.Namespace) -> None:
    """
    If --cluster_run is set, launches an array of jobs on the cluster using
    the submitit library, one per index string (and, in the case of evaluate,
    one per index string and nprobe pair). Otherwise, launches a single job
    that is run locally with the prod values for the index string and nprobe.
    The expected config layout is sketched below.
    """
    cfg = load_config(args.config)
    index_strings = cfg["index"]
    nprobes = cfg["nprobe"]
    if args.command == "evaluate":
        if args.cluster_run:
            all_nprobes = join_lists_in_dict(nprobes)
            all_index_strings = join_lists_in_dict(index_strings)
            for index_factory_str in all_index_strings:
                for nprobe in all_nprobes:
                    launch_job(main, args, cfg, nprobe, index_factory_str)
        else:
            launch_job(
                main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
            )
    else:
        if args.cluster_run:
            all_index_strings = join_lists_in_dict(index_strings)
            for index_factory_str in all_index_strings:
                launch_job(
                    main, args, cfg, nprobes["prod"][-1], index_factory_str
                )
        else:
            launch_job(
                main, args, cfg, nprobes["prod"][-1], index_strings["prod"][-1]
            )
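
# The function above expects the config yaml to provide "index" and "nprobe"
# entries, each with a "prod" list and an optional "non-prod" list, along the
# lines of this hypothetical snippet:
#
#   index:
#     prod:
#       - 'IVF8192,PQ4'
#     non-prod:
#       - 'IVF16384,PQ4'
#   nprobe:
#     prod:
#       - 64
#     non-prod:
#       - 128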


def launch_job(
    func: Callable,
    args: argparse.Namespace,
    cfg: Dict[str, Any],
    n_probe: int,
    index_str: str,
) -> None:
    """
    Submits a single slurm job to the cluster using the submitit library if
    --cluster_run is set; otherwise calls func directly in the local process.
    """
    if args.cluster_run:
        assert args.num_nodes >= 1
        executor = submitit.AutoExecutor(folder=args.logs_dir)
        executor.update_parameters(
            nodes=args.num_nodes,
            gpus_per_node=args.gpus_per_node,
            cpus_per_task=args.cpus_per_task,
            tasks_per_node=args.tasks_per_node,
            name=args.job_name,
            slurm_partition=args.partition,
            slurm_time=70 * 60,  # in minutes
        )
        if args.slurm_constraint:
            executor.update_parameters(slurm_constraint=args.slurm_constraint)
        job = executor.submit(func, args, cfg, n_probe, index_str)
        print(f"Job id: {job.job_id}")
    else:
        func(args, cfg, n_probe, index_str)
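
# A minimal sketch of tracking a submitted job with submitit, assuming the
# executor set up above:
#
#   job = executor.submit(func, args, cfg, n_probe, index_str)
#   print(job.job_id)      # the slurm job id, as printed above
#   result = job.result()  # blocks until the job finishes, re-raising failures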


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("general")
    add_group_args(group, "--command", required=True, help="command to run")
    add_group_args(
        group,
        "--config",
        required=True,
        help="config yaml with the dataset specs",
    )
    add_group_args(
        group, "--nt", type=int, default=96, help="number of search threads"
    )
    add_group_args(
        group,
        "--no_residuals",
        action="store_false",
        help="set index.by_residual to False during train index.",
    )
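    # Note: with action="store_false", args.no_residuals defaults to True and
    # is set to False when --no_residuals is passed, so the stored value reads
    # inverted relative to the flag name.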
    group = parser.add_argument_group("slurm_job")
    add_group_args(
        group,
        "--cluster_run",
        action="store_true",
        help="if set, run the jobs on the cluster",
    )
    add_group_args(
        group,
        "--job_name",
        type=str,
        default="oivf",
        help="cluster job name",
    )
    add_group_args(
        group,
        "--num_nodes",
        type=int,
        default=1,
        help="number of nodes per job",
    )
    add_group_args(
        group,
        "--tasks_per_node",
        type=int,
        default=1,
        help="tasks per job",
    )
    add_group_args(
        group,
        "--gpus_per_node",
        type=int,
        default=8,
        help="number of gpus per node",
    )
    add_group_args(
        group,
        "--cpus_per_task",
        type=int,
        default=80,
        help="number of cpus per task",
    )
    add_group_args(
        group,
        "--logs_dir",
        type=str,
        default="/checkpoint/marialomeli/offline_faiss/logs",
        help="directory for the slurm job logs",
    )
    add_group_args(
        group,
        "--slurm_constraint",
        type=str,
        default=None,
        help="can be volta32gb for the fair cluster",
    )
    add_group_args(
        group,
        "--partition",
        type=str,
        default="learnlab",
        help="which partition to use when running on the cluster",
        choices=[
            "learnfair",
            "devlab",
            "scavenge",
            "learnlab",
            "nllb",
            "seamless",
            "seamless_medium",
            "learnaccel",
            "onellm_low",
            "learn",
        ],
    )
    group = parser.add_argument_group("dataset")
    add_group_args(group, "--xb", required=True, help="database vectors")
    add_group_args(group, "--xq", help="query vectors")
    args = parser.parse_args()
    print("args:", args)
    faiss.omp_set_num_threads(args.nt)
    process_options_and_run_jobs(args=args)
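
# Example invocations (hypothetical config and dataset names, following the
# config structure sketched above):
#
#   python run.py --command train_index --config config_ssnpp.yaml --xb ssnpp_1B
#   python run.py --command search --config config_ssnpp.yaml --xb ssnpp_1B \
#       --cluster_run --partition learnlab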