{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "be081589-e1b2-4569-acb7-44203e273899", "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import itertools\n", "from faiss.contrib.evaluation import OperatingPoints\n", "from enum import Enum\n", "from bench_fw.benchmark_io import BenchmarkIO as BIO\n", "from bench_fw.utils import filter_results, ParetoMode, ParetoMetric\n", "from copy import copy\n", "import numpy as np\n", "import datetime\n", "import glob\n", "import io\n", "import json\n", "from zipfile import ZipFile\n", "import tabulate" ] }, { "cell_type": "code", "execution_count": null, "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", "metadata": { "tags": [] }, "outputs": [], "source": [ "root = \"/checkpoint/gsz/bench_fw/optimize/bigann\"\n", "results = BIO(root).read_json(\"result_std_d_bigann10M.json\")\n", "results.keys()" ] }, { "cell_type": "code", "execution_count": null, "id": "0875d269-aef4-426d-83dd-866970f43777", "metadata": { "tags": [] }, "outputs": [], "source": [ "results['experiments']" ] }, { "cell_type": "code", "execution_count": null, "id": "f080a6e2-1565-418b-8732-4adeff03a099", "metadata": { "tags": [] }, "outputs": [], "source": [ "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n", " if plot is None:\n", " plot = plt.subplot()\n", " x = {}\n", " y = {}\n", " for accuracy, space, time, k, v in experiments:\n", " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", " if idx_name not in x:\n", " x[idx_name] = []\n", " y[idx_name] = []\n", " x[idx_name].append(accuracy)\n", " if plot_space:\n", " y[idx_name].append(space)\n", " else:\n", " y[idx_name].append(time)\n", "\n", " #plt.figure(figsize=(10,6))\n", " #plt.title(accuracy_title)\n", " plot.set_xlabel(accuracy_title)\n", " plot.set_ylabel(cost_title)\n", " plot.set_yscale(\"log\")\n", " marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", " for index in x.keys():\n", " plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n", " plot.legend(bbox_to_anchor=(1, 1), loc='upper left')" ] }, { "cell_type": "code", "execution_count": null, "id": "61007155-5edc-449e-835e-c141a01a2ae5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# index local optima\n", "accuracy_metric = \"knn_intersection\"\n", "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)\n", "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1", "metadata": { "tags": [] }, "outputs": [], "source": [ "# global optima\n", "accuracy_metric = \"knn_intersection\"\n", "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "0c10f587-26ef-49ec-83a9-88f6a2a433e8", "metadata": {}, "outputs": [], "source": [ "def pretty_params(p):\n", " p = copy(p)\n", " if 
{ "cell_type": "code", "execution_count": null, "id": "36e82084-18f6-4546-a717-163eb0224ee8", "metadata": {}, "outputs": [], "source": [ "# index local optima @ precision 0.8\n", "precision = 0.8\n", "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", "metadata": {}, "outputs": [], "source": [ "# index local optima @ precision 0.2\n", "precision = 0.2\n", "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", "metadata": {}, "outputs": [], "source": [ "# global optima @ precision 0.8\n", "precision = 0.8\n", "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "9aead830-6209-4956-b7ea-4a5e0029d616", "metadata": {}, "outputs": [], "source": [ "def plot_range_search_pr_curves(experiments):\n", "    # Plot the precision/recall curve of each weighted range-search experiment.\n", "    x = {}\n", "    y = {}\n", "    show = {\n", "        'Flat': None,\n", "    }\n", "    for _, _, _, k, v in experiments:\n", "        if \".weighted\" in k:  # and v['index'] in show:\n", "            x[k] = v['range_search_pr']['recall']\n", "            y[k] = v['range_search_pr']['precision']\n", "\n", "    plt.title(\"range search recall\")\n", "    plt.xlabel(\"recall\")\n", "    plt.ylabel(\"precision\")\n", "    for index in x.keys():\n", "        plt.plot(x[index], y[index], '.', label=index)\n", "    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" ] },
{ "cell_type": "code", "execution_count": null, "id": "92e45502-7a31-4a15-90df-fa3032d7d350", "metadata": {}, "outputs": [], "source": [ "precision = 0.8\n", "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", "plot_range_search_pr_curves(fr)" ] },
{ "cell_type": "code", "execution_count": null, "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", "metadata": { "tags": [] }, "outputs": [], "source": [ "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", "scales = [1, 2, 5, 10, 20, 50]\n", "fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n",
"fig.tight_layout()\n", "for plot, scale in zip(plots, scales, strict=True):\n", "    results = BIO(root).read_json(f\"result{scale}.json\")\n", "    accuracy_metric = \"knn_intersection\"\n", "    fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "    plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)" ] },
{ "cell_type": "code", "execution_count": null, "id": "e503828c-ee61-45f7-814b-cce6461109bc", "metadata": {}, "outputs": [], "source": [ "x = {}\n", "y = {}\n", "accuracy = 0.9\n", "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", "scales = [1, 2, 5, 10, 20, 50]\n", "#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", "#fig.tight_layout()\n", "for scale in scales:\n", "    results = BIO(root).read_json(f\"result{scale}.json\")\n", "    scale *= 1_000_000\n", "    accuracy_metric = \"knn_intersection\"\n", "    fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", "    seen = set()\n", "    print(scale)\n", "    for _, _, _, _, exp in fr:\n", "        fact = exp[\"factory\"]\n", "        # \"HNSW\" in fact or\n", "        if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", "            continue\n", "        seen.add(fact)\n", "        if fact not in x:\n", "            x[fact] = []\n", "            y[fact] = []\n", "        x[fact].append(scale)\n", "        y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n", "        if exp[\"knn_intersection\"] > 0.92:\n", "            print(fact)\n", "            print(exp[\"search_params\"])\n", "            print(exp[\"knn_intersection\"])\n", "\n", "    #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n", "\n", "plt.title(f\"recall @ 1 = {accuracy*100}%\")\n", "plt.xlabel(\"database size\")\n", "plt.ylabel(\"time\")\n", "plt.xscale(\"log\")\n", "plt.yscale(\"log\")\n", "\n", "marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\"))\n", "for index in x.keys():\n", "    if \"HNSW\" in index:\n", "        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n", "    else:\n", "        plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n", "plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" ] },
{ "cell_type": "code", "execution_count": null, "id": "37a99bb2-f998-461b-a345-7cc6e702cb3a", "metadata": {}, "outputs": [], "source": [ "# global optima\n", "accuracy_metric = \"sym_recall\"\n", "fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e: e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n", "plot_metric(fr, accuracy_title=\"sym recall\", cost_title=\"space\", plot_space=True)" ] },
{ "cell_type": "code", "execution_count": null, "id": "c973ce4e-3566-4f02-bd93-f113e3e0c791", "metadata": {}, "outputs": [], "source": [ "def pretty_time(s):\n", "    if s is None:\n", "        return \"None\"\n", "    s = int(s * 1000) / 1000\n", "    m, s = divmod(s, 60)\n", "    h, m = divmod(m, 60)\n", "    d, h = divmod(h, 24)\n", "    r = \"\"\n", "    if d > 0:\n", "        r += f\"{int(d)}d \"\n", "    if h > 0:\n", "        r += f\"{int(h)}h \"\n", "    if m > 0:\n", "        r += f\"{int(m)}m \"\n", "    if s > 0 or len(r) == 0:\n", "        r += f\"{s:.3f}s\"\n",
f\"{s:.3f}s\"\n", " return r\n", "\n", "def pretty_size(s):\n", " if s > 1024 * 1024:\n", " return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n", " if s > 1024:\n", " return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n", " return f\"{s}\"\n", "\n", "def pretty_mse(m):\n", " if m is None:\n", " return \"None\"\n", " else:\n", " return f\"{m:.6f}\"" ] }, { "cell_type": "code", "execution_count": null, "id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703", "metadata": {}, "outputs": [], "source": [ "data = {}\n", "root = \"/checkpoint/gsz/bench_fw/bigann\"\n", "scales = [1, 2, 5, 10, 20, 50]\n", "for scale in scales:\n", " results = BIO(root).read_json(f\"result{scale}.json\")\n", " accuracy_metric = \"knn_intersection\"\n", " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", " d = {}\n", " data[f\"{scale}M\"] = d\n", " for _, _, _, _, exp in fr:\n", " fact = exp[\"factory\"]\n", " # \"HNSW\" in fact or \n", " if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", " continue\n", " if fact not in d:\n", " d[fact] = []\n", " d[fact].append({\n", " \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n", " \"recall\": exp[\"knn_intersection\"],\n", " \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n", " })\n", "data\n", "# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n", "# json.dump(data, f)" ] }, { "cell_type": "code", "execution_count": null, "id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510", "metadata": {}, "outputs": [], "source": [ "ds = \"deep1b\"\n", "data = []\n", "jss = []\n", "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", "results = BIO(root).read_json(f\"result.json\")\n", "for k, e in results[\"experiments\"].items():\n", " if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", " code_size = results['indices'][e['codec']]['sa_code_size']\n", " codec_size = results['indices'][e['codec']]['codec_size']\n", " training_time = results['indices'][e['codec']]['training_time']\n", " # training_size = results['indices'][e['codec']]['training_size']\n", " cpu = e['cpu'] if 'cpu' in e else \"\"\n", " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", " jss.append({\n", " 'factory': e['factory'],\n", " 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n", " 'evaluation_params': e['reconstruct_params'],\n", " 'code_size': code_size,\n", " 'codec_size': codec_size,\n", " 'training_time': training_time,\n", " 'training_size': training_size,\n", " 'mse': e['mse'],\n", " 'sym_recall': e['sym_recall'],\n", " 'asym_recall': e['asym_recall'],\n", " 'encode_time': e['encode_time'],\n", " 'decode_time': e['decode_time'],\n", " 'cpu': cpu,\n", " })\n", "\n", "print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 
"print(\"|-|-|-|-|-|-|-|-|-|-|-|-|-|\")\n", "data.sort()\n", "for d in data:\n", "    print(d[1])\n", "\n", "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n", "    json.dump(jss, f)" ] },
{ "cell_type": "code", "execution_count": null, "id": "d1216733-9670-407c-b3d2-5f87bce0321c", "metadata": {}, "outputs": [], "source": [ "def read_file(filename: str, keys):\n", "    results = []\n", "    with ZipFile(filename, \"r\") as zip_file:\n", "        for key in keys:\n", "            with zip_file.open(key, \"r\") as f:\n", "                if key in [\"D\", \"I\", \"R\", \"lims\"]:\n", "                    results.append(np.load(f))\n", "                elif key in [\"P\"]:\n", "                    t = io.TextIOWrapper(f)\n", "                    results.append(json.load(t))\n", "                else:\n", "                    raise AssertionError()\n", "    return results" ] },
{ "cell_type": "code", "execution_count": null, "id": "56de051e-22db-4bef-b242-1ddabc9e0bb9", "metadata": {}, "outputs": [], "source": [ "ds = \"contriever\"\n", "data = []\n", "jss = []\n", "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", "for lf in glob.glob(root + '/*rec*.zip'):\n", "    e, = read_file(lf, ['P'])\n", "    if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", "        code_size = e['codec_meta']['sa_code_size']\n", "        codec_size = e['codec_meta']['codec_size']\n", "        training_time = e['codec_meta']['training_time']\n", "        training_size = None  # e['codec_meta']['training_size']\n", "        cpu = e['cpu'] if 'cpu' in e else \"\"\n", "        ps = ', '.join([f\"{k}={v}\" for k, v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", "        eps = ', '.join([f\"{k}={v}\" for k, v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", "        if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n", "            eps = \" \"\n", "        data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", "        eps = e['reconstruct_params']\n", "        eps.pop('snap', None)\n", "        params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n", "        for k, v in e['reconstruct_params'].items():\n", "            params[k] = v\n", "        jss.append({\n", "            'factory': e['factory'],\n", "            'params': params,\n", "            'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n", "            'evaluation_params': e['reconstruct_params'],\n", "            'code_size': code_size,\n", "            'codec_size': codec_size,\n", "            'training_time': training_time,\n", "            # 'training_size': training_size,\n", "            'mse': e['mse'],\n", "            'sym_recall': e['sym_recall'],\n", "            'asym_recall': e['asym_recall'],\n", "            'encode_time': e['encode_time'],\n", "            'decode_time': e['decode_time'],\n", "            'cpu': cpu,\n", "        })\n", "\n", "print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", "print(\"|-|-|-|-|-|-|-|-|-|-|-|-|\")\n", "data.sort()\n", "# for d in data:\n", "#     print(d[1])\n", "\n", "print(len(data))\n", "\n", "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n", "    json.dump(jss, f)" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:.conda-faiss_from_source] *", "language": "python", "name": "conda-env-.conda-faiss_from_source-py" }, "language_info": { "codemirror_mode": { "name":
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }