benchmark view results (#3144)
Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3144 Visualize results of running the benchmark with Pareto optima filtering: 1. per index or across indices 2. for space, time or space & time 3. knn or range search, the latter @ specific precision Reviewed By: mdouze Differential Revision: D51552775 fbshipit-source-id: d4f29e3d46ef044e71b54439b3972548c86af5a7pull/3164/head
parent
9519a19f42
commit
4c83965d2b
|
@ -0,0 +1,289 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "be081589-e1b2-4569-acb7-44203e273899",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import itertools\n",
|
||||
"from faiss.contrib.evaluation import OperatingPoints\n",
|
||||
"from enum import Enum\n",
|
||||
"from bench_fw.benchmark_io import BenchmarkIO as BIO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a6492e95-24c7-4425-bf0a-27e10e879ca6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"root = \"/checkpoint\"\n",
|
||||
"results = BIO(root).read_json(\"result.json\")\n",
|
||||
"results.keys()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0875d269-aef4-426d-83dd-866970f43777",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"results['indices']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a7ff7078-29c7-407c-a079-201877b764ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Cost:\n",
|
||||
" def __init__(self, values):\n",
|
||||
" self.values = values\n",
|
||||
"\n",
|
||||
" def __le__(self, other):\n",
|
||||
" return all(v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True))\n",
|
||||
"\n",
|
||||
" def __lt__(self, other):\n",
|
||||
" return all(v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True))\n",
|
||||
"\n",
|
||||
"class ParetoMode(Enum):\n",
|
||||
" DISABLE = 1 # no Pareto filtering\n",
|
||||
" INDEX = 2 # index-local optima\n",
|
||||
" GLOBAL = 3 # global optima\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ParetoMetric(Enum):\n",
|
||||
" TIME = 0 # time vs accuracy\n",
|
||||
" SPACE = 1 # space vs accuracy\n",
|
||||
" TIME_SPACE = 2 # (time, space) vs accuracy\n",
|
||||
"\n",
|
||||
"def range_search_recall_at_precision(experiment, precision):\n",
|
||||
" return round(max(r for r, p in zip(experiment['range_search_pr']['recall'], experiment['range_search_pr']['precision']) if p > precision), 6)\n",
|
||||
"\n",
|
||||
"def filter_results(\n",
|
||||
" results,\n",
|
||||
" evaluation,\n",
|
||||
" accuracy_metric, # str or func\n",
|
||||
" time_metric=None, # func or None -> use default\n",
|
||||
" space_metric=None, # func or None -> use default\n",
|
||||
" min_accuracy=0,\n",
|
||||
" max_space=0,\n",
|
||||
" max_time=0,\n",
|
||||
" scaling_factor=1.0,\n",
|
||||
" \n",
|
||||
" pareto_mode=ParetoMode.DISABLE,\n",
|
||||
" pareto_metric=ParetoMetric.TIME,\n",
|
||||
"):\n",
|
||||
" if isinstance(accuracy_metric, str):\n",
|
||||
" accuracy_key = accuracy_metric\n",
|
||||
" accuracy_metric = lambda v: v[accuracy_key]\n",
|
||||
"\n",
|
||||
" if time_metric is None:\n",
|
||||
" time_metric = lambda v: v['time'] * scaling_factor + (v['quantizer']['time'] if 'quantizer' in v else 0)\n",
|
||||
"\n",
|
||||
" if space_metric is None:\n",
|
||||
" space_metric = lambda v: results['indices'][v['codec']]['code_size']\n",
|
||||
" \n",
|
||||
" fe = []\n",
|
||||
" ops = {}\n",
|
||||
" if pareto_mode == ParetoMode.GLOBAL:\n",
|
||||
" op = OperatingPoints()\n",
|
||||
" ops[\"global\"] = op\n",
|
||||
" for k, v in results['experiments'].items():\n",
|
||||
" if f\".{evaluation}\" in k:\n",
|
||||
" accuracy = accuracy_metric(v)\n",
|
||||
" if min_accuracy > 0 and accuracy < min_accuracy:\n",
|
||||
" continue\n",
|
||||
" space = space_metric(v)\n",
|
||||
" if max_space > 0 and space > max_space:\n",
|
||||
" continue\n",
|
||||
" time = time_metric(v)\n",
|
||||
" if max_time > 0 and time > max_time:\n",
|
||||
" continue\n",
|
||||
" idx_name = v['index']\n",
|
||||
" experiment = (accuracy, space, time, k, v)\n",
|
||||
" if pareto_mode == ParetoMode.DISABLE:\n",
|
||||
" fe.append(experiment)\n",
|
||||
" continue\n",
|
||||
" if pareto_mode == ParetoMode.INDEX:\n",
|
||||
" if idx_name not in ops:\n",
|
||||
" ops[idx_name] = OperatingPoints()\n",
|
||||
" op = ops[idx_name]\n",
|
||||
" if pareto_metric == ParetoMetric.TIME:\n",
|
||||
" op.add_operating_point(experiment, accuracy, time)\n",
|
||||
" elif pareto_metric == ParetoMetric.SPACE:\n",
|
||||
" op.add_operating_point(experiment, accuracy, space)\n",
|
||||
" else:\n",
|
||||
" op.add_operating_point(experiment, accuracy, Cost([time, space]))\n",
|
||||
"\n",
|
||||
" if ops:\n",
|
||||
" for op in ops.values():\n",
|
||||
" for v, _, _ in op.operating_points:\n",
|
||||
" fe.append(v)\n",
|
||||
"\n",
|
||||
" fe.sort()\n",
|
||||
" return fe"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f080a6e2-1565-418b-8732-4adeff03a099",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_metric(experiments, accuracy_title, cost_title, plot_space=False):\n",
|
||||
" x = {}\n",
|
||||
" y = {}\n",
|
||||
" for accuracy, space, time, k, v in experiments:\n",
|
||||
" idx_name = v['index']\n",
|
||||
" if idx_name not in x:\n",
|
||||
" x[idx_name] = []\n",
|
||||
" y[idx_name] = []\n",
|
||||
" x[idx_name].append(accuracy)\n",
|
||||
" if plot_space:\n",
|
||||
" y[idx_name].append(space)\n",
|
||||
" else:\n",
|
||||
" y[idx_name].append(time)\n",
|
||||
"\n",
|
||||
" #plt.figure(figsize=(10,6))\n",
|
||||
" plt.yscale(\"log\")\n",
|
||||
" plt.title(accuracy_title)\n",
|
||||
" plt.xlabel(accuracy_title)\n",
|
||||
" plt.ylabel(cost_title)\n",
|
||||
" marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n",
|
||||
" for index in x.keys():\n",
|
||||
" plt.plot(x[index], y[index], marker=next(marker), label=index)\n",
|
||||
" plt.legend(bbox_to_anchor=(1, 1), loc='upper left')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "61007155-5edc-449e-835e-c141a01a2ae5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"accuracy_metric = \"knn_intersection\"\n",
|
||||
"fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
|
||||
"plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 16 cores)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "36e82084-18f6-4546-a717-163eb0224ee8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"precision = 0.8\n",
|
||||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
|
||||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
|
||||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aff79376-39f7-47c0-8b83-1efe5192bb7e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# index local optima\n",
|
||||
"precision = 0.2\n",
|
||||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
|
||||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
|
||||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# global optima\n",
|
||||
"precision = 0.8\n",
|
||||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
|
||||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n",
|
||||
"plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9aead830-6209-4956-b7ea-4a5e0029d616",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_range_search_pr_curves(experiments):\n",
|
||||
" x = {}\n",
|
||||
" y = {}\n",
|
||||
" show = {\n",
|
||||
" 'Flat': None,\n",
|
||||
" }\n",
|
||||
" for _, _, _, k, v in fr:\n",
|
||||
" if \".weighted\" in k: # and v['index'] in show:\n",
|
||||
" x[k] = v['range_search_pr']['recall']\n",
|
||||
" y[k] = v['range_search_pr']['precision']\n",
|
||||
" \n",
|
||||
" plt.title(\"range search recall\")\n",
|
||||
" plt.xlabel(\"recall\")\n",
|
||||
" plt.ylabel(\"precision\")\n",
|
||||
" for index in x.keys():\n",
|
||||
" plt.plot(x[index], y[index], '.', label=index)\n",
|
||||
" plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "92e45502-7a31-4a15-90df-fa3032d7d350",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"precision = 0.8\n",
|
||||
"accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n",
|
||||
"fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n",
|
||||
"plot_range_search_pr_curves(fr)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python [conda env:faiss_cpu_from_source] *",
|
||||
"language": "python",
|
||||
"name": "conda-env-faiss_cpu_from_source-py"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Loading…
Reference in New Issue