faiss/benchs/bench_fw_notebook.ipynb
Gergely Szilvasy 1d0e8d489f index optimizer (#3154)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3154

Using the benchmark to find Pareto-optimal indices, with BigANN as the example dataset.

The coarse quantizer and the vector codec are optimized separately, and Pareto-optimal configurations of the two are combined into IVF indices, which are then retested at various database scales. See `optimize()` in `optimize.py`, the main function driving the process.

The results can be interpreted with `bench_fw_notebook.ipynb` (a usage sketch follows below), which allows:
* filtering by maximum code size
* filtering by maximum search time
* filtering by minimum accuracy
* restricting to space- or time-Pareto-optimal operating points
* visualizing the results and outputting them as a table

This version is intentionally limited to IVF(Flat|HNSW),PQ|SQ indices...
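
A minimal sketch of that flow (load one result file, filter, plot), mirroring the notebook's own cells. The path and filename are placeholders; the keyword arguments to `filter_results` and the shape of its output tuples (`(accuracy, space, time, key, experiment)` with an `experiment["index"]` field) are taken from the `plot_metric` helper defined in the notebook.

```python
import matplotlib.pyplot as plt

from bench_fw.benchmark_io import BenchmarkIO as BIO
from bench_fw.utils import filter_results, ParetoMode, ParetoMetric

# load one benchmark result file (placeholder path and filename)
results = BIO("/path/to/bench_fw/optimize/bigann").read_json("result_std_d_bigann10M.json")

# keep knn experiments with at least 95% knn intersection and reduce each index
# to its time-Pareto-optimal operating points
fr = filter_results(
    results,
    evaluation="knn",
    accuracy_metric="knn_intersection",
    min_accuracy=0.95,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)

# scatter accuracy vs. search time, one series per index type
series = {}
for accuracy, space, time, key, exp in fr:
    series.setdefault(exp["index"], []).append((accuracy, time))
for name, points in series.items():
    xs, ys = zip(*points)
    plt.plot(xs, ys, "o", label=name)
plt.xlabel("knn intersection")
plt.ylabel("time (seconds)")
plt.yscale("log")
plt.legend(bbox_to_anchor=(1, 1), loc="upper left")
plt.show()
```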

Reviewed By: mdouze

Differential Revision: D51781670

fbshipit-source-id: 2c0f800d374ea845255934f519cc28095c00a51f
2024-01-30 10:58:13 -08:00

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "be081589-e1b2-4569-acb7-44203e273899",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import itertools\n",
"from faiss.contrib.evaluation import OperatingPoints\n",
"from enum import Enum\n",
"from bench_fw.benchmark_io import BenchmarkIO as BIO\n",
"from bench_fw.utils import filter_results, ParetoMode, ParetoMetric\n",
"from copy import copy\n",
"import numpy as np\n",
"import datetime\n",
"import glob\n",
"import io\n",
"import json\n",
"from zipfile import ZipFile\n",
"import tabulate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6492e95-24c7-4425-bf0a-27e10e879ca6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"root = \"/checkpoint/gsz/bench_fw/optimize/bigann\"\n",
"results = BIO(root).read_json(\"result_std_d_bigann10M.json\")\n",
"results.keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0875d269-aef4-426d-83dd-866970f43777",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"results['experiments']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f080a6e2-1565-418b-8732-4adeff03a099",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n",
" if plot is None:\n",
" plot = plt.subplot()\n",
" x = {}\n",
" y = {}\n",
" for accuracy, space, time, k, v in experiments:\n",
" idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n",
" if idx_name not in x:\n",
" x[idx_name] = []\n",
" y[idx_name] = []\n",
" x[idx_name].append(accuracy)\n",
" if plot_space:\n",
" y[idx_name].append(space)\n",
" else:\n",
" y[idx_name].append(time)\n",
"\n",
" #plt.figure(figsize=(10,6))\n",
" #plt.title(accuracy_title)\n",
" plot.set_xlabel(accuracy_title)\n",
" plot.set_ylabel(cost_title)\n",
" plot.set_yscale(\"log\")\n",
" marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n",
" for index in x.keys():\n",
" plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n",
" plot.legend(bbox_to_anchor=(1, 1), loc='upper left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61007155-5edc-449e-835e-c141a01a2ae5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# index local optima\n",
"accuracy_metric = \"knn_intersection\"\n",
"fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)\n",
"plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# global optima\n",
"accuracy_metric = \"knn_intersection\"\n",
"fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
"plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c10f587-26ef-49ec-83a9-88f6a2a433e8",
"metadata": {},
"outputs": [],
"source": [
"def pretty_params(p):\n",
" p = copy(p)\n",
" if 'snap' in p and p['snap'] == 0:\n",
" del p['snap']\n",
" return p\n",
" \n",
"tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params'])) \n",
" for accuracy, space, time, k, v in fr],\n",
" tablefmt=\"html\",\n",
" headers=[\"accuracy\",\"space\", \"time\", \"factory\", \"quantizer cfg\", \"search cfg\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36e82084-18f6-4546-a717-163eb0224ee8",
"metadata": {},
"outputs": [],
"source": [
"# index local optima @ precision 0.8\n",
"precision = 0.8\n",
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aff79376-39f7-47c0-8b83-1efe5192bb7e",
"metadata": {},
"outputs": [],
"source": [
"# index local optima @ precision 0.2\n",
"precision = 0.2\n",
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1",
"metadata": {},
"outputs": [],
"source": [
"# global optima @ precision 0.8\n",
"precision = 0.8\n",
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9aead830-6209-4956-b7ea-4a5e0029d616",
"metadata": {},
"outputs": [],
"source": [
"def plot_range_search_pr_curves(experiments):\n",
" x = {}\n",
" y = {}\n",
" show = {\n",
" 'Flat': None,\n",
" }\n",
" for _, _, _, k, v in fr:\n",
" if \".weighted\" in k: # and v['index'] in show:\n",
" x[k] = v['range_search_pr']['recall']\n",
" y[k] = v['range_search_pr']['precision']\n",
" \n",
" plt.title(\"range search recall\")\n",
" plt.xlabel(\"recall\")\n",
" plt.ylabel(\"precision\")\n",
" for index in x.keys():\n",
" plt.plot(x[index], y[index], '.', label=index)\n",
" plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92e45502-7a31-4a15-90df-fa3032d7d350",
"metadata": {},
"outputs": [],
"source": [
"precision = 0.8\n",
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n",
"plot_range_search_pr_curves(fr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n",
"scales = [1, 2, 5, 10, 20, 50]\n",
"fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n",
"fig.tight_layout()\n",
"for plot, scale in zip(plots, scales, strict=True):\n",
" results = BIO(root).read_json(f\"result{scale}.json\")\n",
" accuracy_metric = \"knn_intersection\"\n",
" fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
" plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e503828c-ee61-45f7-814b-cce6461109bc",
"metadata": {},
"outputs": [],
"source": [
"x = {}\n",
"y = {}\n",
"accuracy=0.9\n",
"root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n",
"scales = [1, 2, 5, 10, 20, 50]\n",
"#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n",
"#fig.tight_layout()\n",
"for scale in scales:\n",
" results = BIO(root).read_json(f\"result{scale}.json\")\n",
" scale *= 1_000_000\n",
" accuracy_metric = \"knn_intersection\"\n",
" fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
" seen = set()\n",
" print(scale)\n",
" for _, _, _, _, exp in fr:\n",
" fact = exp[\"factory\"]\n",
" # \"HNSW\" in fact or \n",
" if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n",
" continue\n",
" seen.add(fact)\n",
" if fact not in x:\n",
" x[fact] = []\n",
" y[fact] = []\n",
" x[fact].append(scale)\n",
" y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n",
" if (exp[\"knn_intersection\"] > 0.92):\n",
" print(fact)\n",
" print(exp[\"search_params\"])\n",
" print(exp[\"knn_intersection\"])\n",
"\n",
" #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n",
" \n",
"plt.title(f\"recall @ 1 = {accuracy*100}%\")\n",
"plt.xlabel(\"database size\")\n",
"plt.ylabel(\"time\")\n",
"plt.xscale(\"log\")\n",
"plt.yscale(\"log\")\n",
"\n",
"marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n",
"for index in x.keys():\n",
" if \"HNSW\" in index:\n",
" plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n",
" else:\n",
" plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n",
"plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37a99bb2-f998-461b-a345-7cc6e702cb3a",
"metadata": {},
"outputs": [],
"source": [
"# global optima\n",
"accuracy_metric = \"sym_recall\"\n",
"fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n",
"plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"space\", plot_space=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c973ce4e-3566-4f02-bd93-f113e3e0c791",
"metadata": {},
"outputs": [],
"source": [
"def pretty_time(s):\n",
" if s is None:\n",
" return \"None\"\n",
" s = int(s * 1000) / 1000\n",
" m, s = divmod(s, 60)\n",
" h, m = divmod(m, 60)\n",
" d, h = divmod(h, 24)\n",
" r = \"\"\n",
" if d > 0:\n",
" r += f\"{int(d)}d \"\n",
" if h > 0:\n",
" r += f\"{int(h)}h \"\n",
" if m > 0:\n",
" r += f\"{int(m)}m \"\n",
" if s > 0 or len(r) == 0:\n",
" r += f\"{s:.3f}s\"\n",
" return r\n",
"\n",
"def pretty_size(s):\n",
" if s > 1024 * 1024:\n",
" return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n",
" if s > 1024:\n",
" return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n",
" return f\"{s}\"\n",
"\n",
"def pretty_mse(m):\n",
" if m is None:\n",
" return \"None\"\n",
" else:\n",
" return f\"{m:.6f}\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703",
"metadata": {},
"outputs": [],
"source": [
"data = {}\n",
"root = \"/checkpoint/gsz/bench_fw/bigann\"\n",
"scales = [1, 2, 5, 10, 20, 50]\n",
"for scale in scales:\n",
" results = BIO(root).read_json(f\"result{scale}.json\")\n",
" accuracy_metric = \"knn_intersection\"\n",
" fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
" d = {}\n",
" data[f\"{scale}M\"] = d\n",
" for _, _, _, _, exp in fr:\n",
" fact = exp[\"factory\"]\n",
" # \"HNSW\" in fact or \n",
" if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n",
" continue\n",
" if fact not in d:\n",
" d[fact] = []\n",
" d[fact].append({\n",
" \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n",
" \"recall\": exp[\"knn_intersection\"],\n",
" \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n",
" })\n",
"data\n",
"# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n",
"# json.dump(data, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510",
"metadata": {},
"outputs": [],
"source": [
"ds = \"deep1b\"\n",
"data = []\n",
"jss = []\n",
"root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n",
"results = BIO(root).read_json(f\"result.json\")\n",
"for k, e in results[\"experiments\"].items():\n",
" if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n",
" code_size = results['indices'][e['codec']]['sa_code_size']\n",
" codec_size = results['indices'][e['codec']]['codec_size']\n",
" training_time = results['indices'][e['codec']]['training_time']\n",
" # training_size = results['indices'][e['codec']]['training_size']\n",
" cpu = e['cpu'] if 'cpu' in e else \"\"\n",
" ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n",
" eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n",
" data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n",
" jss.append({\n",
" 'factory': e['factory'],\n",
" 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n",
" 'evaluation_params': e['reconstruct_params'],\n",
" 'code_size': code_size,\n",
" 'codec_size': codec_size,\n",
" 'training_time': training_time,\n",
" 'training_size': training_size,\n",
" 'mse': e['mse'],\n",
" 'sym_recall': e['sym_recall'],\n",
" 'asym_recall': e['asym_recall'],\n",
" 'encode_time': e['encode_time'],\n",
" 'decode_time': e['decode_time'],\n",
" 'cpu': cpu,\n",
" })\n",
"\n",
"print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n",
"print(\"|-|-|-|-|-|-|-|-|-|\")\n",
"data.sort()\n",
"for d in data:\n",
" print(d[1])\n",
"\n",
"with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n",
" json.dump(jss, f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1216733-9670-407c-b3d2-5f87bce0321c",
"metadata": {},
"outputs": [],
"source": [
"def read_file(filename: str, keys):\n",
" results = []\n",
" with ZipFile(filename, \"r\") as zip_file:\n",
" for key in keys:\n",
" with zip_file.open(key, \"r\") as f:\n",
" if key in [\"D\", \"I\", \"R\", \"lims\"]:\n",
" results.append(np.load(f))\n",
" elif key in [\"P\"]:\n",
" t = io.TextIOWrapper(f)\n",
" results.append(json.load(t))\n",
" else:\n",
" raise AssertionError()\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56de051e-22db-4bef-b242-1ddabc9e0bb9",
"metadata": {},
"outputs": [],
"source": [
"ds = \"contriever\"\n",
"data = []\n",
"jss = []\n",
"root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n",
"for lf in glob.glob(root + '/*rec*.zip'):\n",
" e, = read_file(lf, ['P'])\n",
" if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n",
" code_size = e['codec_meta']['sa_code_size']\n",
" codec_size = e['codec_meta']['codec_size']\n",
" training_time = e['codec_meta']['training_time']\n",
" training_size = None # e['codec_meta']['training_size']\n",
" cpu = e['cpu'] if 'cpu' in e else \"\"\n",
" ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n",
" eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n",
" if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n",
" eps = \" \"\n",
" data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n",
" eps = e['reconstruct_params']\n",
" del eps['snap']\n",
" params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n",
" for k, v in e['reconstruct_params'].items():\n",
" params[k] = v\n",
" jss.append({\n",
" 'factory': e['factory'],\n",
" 'params': params,\n",
" 'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n",
" 'evaluation_params': e['reconstruct_params'],\n",
" 'code_size': code_size,\n",
" 'codec_size': codec_size,\n",
" 'training_time': training_time,\n",
" # 'training_size': training_size,\n",
" 'mse': e['mse'],\n",
" 'sym_recall': e['sym_recall'],\n",
" 'asym_recall': e['asym_recall'],\n",
" 'encode_time': e['encode_time'],\n",
" 'decode_time': e['decode_time'],\n",
" 'cpu': cpu,\n",
" })\n",
"\n",
"print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n",
"print(\"|-|-|-|-|-|-|-|-|-|\")\n",
"data.sort()\n",
"# for d in data:\n",
"# print(d[1])\n",
"\n",
"print(len(data))\n",
"\n",
"with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n",
" json.dump(jss, f)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-faiss_from_source] *",
"language": "python",
"name": "conda-env-.conda-faiss_from_source-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}