In [None]:
import datetime
import glob
import io
import itertools
import json
from copy import copy
from enum import Enum
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import tabulate
from faiss.contrib.evaluation import OperatingPoints

from bench_fw.benchmark_io import BenchmarkIO as BIO
from bench_fw.utils import filter_results, ParetoMode, ParetoMetric
from bench_fw.utils import range_search_recall_at_precision

In [None]:
# Read result_std_d_bigann10M.json from the benchmark root directory and show
# its top-level keys. `results` is reused by the cells that follow.
root = "/checkpoint/gsz/bench_fw/optimize/bigann"
results = BIO(root).read_json("result_std_d_bigann10M.json")
results.keys()

In [None]:
# Display the 'experiments' entry of the loaded results.
results['experiments']

In [None]:
def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):
    """Scatter accuracy against a cost (time or space), one series per index.

    Args:
        experiments: iterable of ``(accuracy, space, time, key, experiment)``
            tuples. ``experiment`` is a dict with an ``'index'`` entry and an
            optional ``'search_params'`` dict.
        accuracy_title: label of the x axis.
        cost_title: label of the y axis.
        plot_space: when true the y value is ``space``, otherwise ``time``.
        plot: object to draw on; it must provide ``set_xlabel``,
            ``set_ylabel``, ``set_yscale``, ``plot`` and ``legend``.
            Defaults to ``plt.subplot()``.
    """
    if plot is None:
        plot = plt.subplot()

    x = {}
    y = {}
    for accuracy, space, time, _key, exp in experiments:
        # "snap" is optional: search params may be missing, None, or simply
        # not carry the key. Only an explicit snap == 1 gets its own series.
        search_params = exp.get('search_params') or {}
        idx_name = exp['index'] + ("snap" if search_params.get("snap") == 1 else "")
        x.setdefault(idx_name, []).append(accuracy)
        y.setdefault(idx_name, []).append(space if plot_space else time)

    plot.set_xlabel(accuracy_title)
    plot.set_ylabel(cost_title)
    plot.set_yscale("log")
    # Cycle through marker shapes so that neighbouring series stay distinguishable.
    marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
    for index in x.keys():
        # linewidth=0: draw the points only, without connecting them.
        plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)
    plot.legend(bbox_to_anchor=(1, 1), loc='upper left')

In [None]:
# Index local optima: knn_intersection versus search time, keeping only
# operating points with an accuracy of at least 0.95.
accuracy_metric = "knn_intersection"
fr = filter_results(
    results,
    evaluation="knn",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
    min_accuracy=0.95,
)
plot_metric(
    fr,
    accuracy_title="knn intersection",
    cost_title="time (seconds, 32 cores)",
    plot_space=False,
)

In [None]:
# Global optima: knn_intersection versus search time over every index whose
# name does not start with "Flat", with accuracy >= 0.90 and max_space=64.
accuracy_metric = "knn_intersection"
fr = filter_results(
    results,
    evaluation="knn",
    accuracy_metric=accuracy_metric,
    min_accuracy=0.90,
    max_space=64,
    max_time=0,
    name_filter=lambda n: not n.startswith("Flat"),
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title="knn intersection",
    cost_title="time (seconds, 32 cores)",
    plot_space=False,
)

In [None]:
def pretty_params(p):
    """Return a shallow copy of ``p`` without a ``'snap'`` entry equal to 0.

    The argument itself is never modified; any other ``'snap'`` value is kept.
    """
    trimmed = copy(p)
    snap_is_zero = 'snap' in trimmed and trimmed['snap'] == 0
    if snap_is_zero:
        del trimmed['snap']
    return trimmed
 
# Render the current `fr` (output of the preceding filter_results call) as an
# HTML table, one row per retained operating point.
tabulate.tabulate(
    [
        (
            accuracy,
            space,
            time,
            exp['factory'],
            pretty_params(exp['construction_params'][1]),
            pretty_params(exp['search_params']),
        )
        for accuracy, space, time, _key, exp in fr
    ],
    headers=["accuracy", "space", "time", "factory", "quantizer cfg", "search cfg"],
    tablefmt="html",
)

In [None]:
# Index local optima @ precision 0.8: range-search recall at the given
# precision versus search time.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
# Index local optima @ precision 0.2: range-search recall at the given
# precision versus search time.
precision = 0.2
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.INDEX,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
# Global optima @ precision 0.8: range-search recall at the given precision
# versus search time.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME,
    scaling_factor=1,
)
plot_metric(
    fr,
    accuracy_title=f"range recall @ precision {precision}",
    cost_title="time (seconds, 16 cores)",
)

In [None]:
def plot_range_search_pr_curves(experiments):
    """Plot one precision/recall point cloud per ".weighted" experiment.

    Args:
        experiments: iterable of 5-tuples whose last two items are the
            experiment key and the experiment dict. For keys containing
            ".weighted", the dict must provide
            ``['range_search_pr']['recall']`` and
            ``['range_search_pr']['precision']``.
    """
    curves = {}
    # Iterate over the argument. The previous version silently read the
    # notebook-global `fr`, so whatever was passed in had no effect.
    for _, _, _, key, exp in experiments:
        if ".weighted" in key:
            pr = exp['range_search_pr']
            curves[key] = (pr['recall'], pr['precision'])

    plt.title("range search recall")
    plt.xlabel("recall")
    plt.ylabel("precision")
    for key, (recall, precision) in curves.items():
        plt.plot(recall, precision, '.', label=key)
    plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')

In [None]:
# Precision/recall curves of the global time/space optima @ precision 0.8.
precision = 0.8
accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)
fr = filter_results(
    results,
    evaluation="weighted",
    accuracy_metric=accuracy_metric,
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.TIME_SPACE,
    scaling_factor=1,
)
plot_range_search_pr_curves(fr)

In [None]:
# One subplot per result file (result1.json ... result50.json): global optima
# of knn_intersection versus search time with accuracy >= 0.9.
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5, 25))
fig.tight_layout()
for plot, scale in zip(plots, scales, strict=True):
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(
        results,
        evaluation="knn",
        accuracy_metric=accuracy_metric,
        min_accuracy=0.9,
        pareto_mode=ParetoMode.GLOBAL,
        pareto_metric=ParetoMetric.TIME,
        scaling_factor=1,
    )
    plot_metric(
        fr,
        accuracy_title="knn intersection",
        cost_title="time (seconds, 64 cores)",
        plot=plot,
    )

In [None]:
# Search time versus database size, one line per index factory. For each
# result file, the first retained experiment of every factory (accuracy >= 0.9)
# contributes one point.
x = {}
y = {}
accuracy = 0.9
root = "/checkpoint/gsz/bench_fw/ivf/bigann"
scales = [1, 2, 5, 10, 20, 50]
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    # The x axis is labelled "database size": turn the file suffix into a
    # vector count.
    scale *= 1_000_000
    accuracy_metric = "knn_intersection"
    fr = filter_results(
        results,
        evaluation="knn",
        accuracy_metric=accuracy_metric,
        min_accuracy=accuracy,
        pareto_mode=ParetoMode.INDEX,
        pareto_metric=ParetoMetric.TIME,
        scaling_factor=1,
    )
    seen = set()
    print(scale)
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        if fact not in seen and fact not in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            seen.add(fact)
            x.setdefault(fact, []).append(scale)
            # Total time = search time plus the quantizer's time.
            y.setdefault(fact, []).append(exp["time"] + exp["quantizer"]["time"])
            if exp["knn_intersection"] > 0.92:
                print(fact, exp["search_params"], exp["knn_intersection"], sep="\n")

plt.title(f"recall @ 1 = {accuracy*100}%")
plt.xlabel("database size")
plt.ylabel("time")
plt.xscale("log")
plt.yscale("log")

marker = itertools.cycle(("o", "v", "^", "<", ">", "s", "p", "P", "*", "h", "X", "D"))
for index in x.keys():
    # HNSW factories are drawn dashed, everything else with the default line.
    line_style = {"linestyle": "dashed"} if "HNSW" in index else {}
    plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), **line_style)
plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')

In [None]:
# Global optima: sym_recall versus space, using encode_time as the time metric.
# NOTE(review): `results` is whatever an earlier cell loaded last; confirm it
# actually contains "rec" experiments before trusting this plot.
accuracy_metric = "sym_recall"
fr = filter_results(
    results,
    evaluation="rec",
    accuracy_metric=accuracy_metric,
    time_metric=lambda e: e['encode_time'],
    min_accuracy=0.9,
    pareto_mode=ParetoMode.GLOBAL,
    pareto_metric=ParetoMetric.SPACE,
    scaling_factor=1,
)
# The plotted accuracy is sym_recall, so label the axis accordingly (it used
# to read "knn intersection").
plot_metric(fr, accuracy_title="sym recall", cost_title="space", plot_space=True)

In [None]:
def pretty_time(s):
    """Format a duration given in seconds, e.g. ``"1d 1h 1m 1.250s"``.

    ``None`` yields ``"None"``. The value is truncated to milliseconds; day,
    hour and minute parts appear only when positive, and the seconds part is
    omitted when it is zero unless nothing else was emitted.
    """
    if s is None:
        return "None"
    # Truncate (not round) to millisecond resolution.
    seconds = int(s * 1000) / 1000
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    days, hours = divmod(hours, 24)
    text = "".join(
        f"{int(amount)}{unit} "
        for amount, unit in ((days, "d"), (hours, "h"), (minutes, "m"))
        if amount > 0
    )
    if seconds > 0 or not text:
        text += f"{seconds:.3f}s"
    return text

def pretty_size(s):
    """Format a byte count as ``"<n>MB"``, ``"<n>KB"`` or the plain number.

    ``None`` yields ``"None"``, matching ``pretty_time`` and ``pretty_mse``.
    Values strictly above 1024*1024 are shown in MB and values strictly above
    1024 in KB, with at most one decimal and a trailing ".0" removed.
    """
    if s is None:
        return "None"
    if s > 1024 * 1024:
        return f"{s / 1024 / 1024:.1f}".rstrip('0').rstrip('.') + "MB"
    if s > 1024:
        return f"{s / 1024:.1f}".rstrip('0').rstrip('.') + "KB"
    return f"{s}"

def pretty_mse(m):
    """Format ``m`` with six decimals, or return ``"None"`` when it is None."""
    return "None" if m is None else f"{m:.6f}"

In [None]:
# For every result file, group the retained knn experiments by index factory
# and record nprobe, recall and total time of each one.
data = {}
root = "/checkpoint/gsz/bench_fw/bigann"
scales = [1, 2, 5, 10, 20, 50]
for scale in scales:
    results = BIO(root).read_json(f"result{scale}.json")
    accuracy_metric = "knn_intersection"
    fr = filter_results(
        results,
        evaluation="knn",
        accuracy_metric=accuracy_metric,
        min_accuracy=0,
        pareto_mode=ParetoMode.INDEX,
        pareto_metric=ParetoMetric.TIME,
        scaling_factor=1,
    )
    d = {}
    data[f"{scale}M"] = d
    for _, _, _, _, exp in fr:
        fact = exp["factory"]
        if fact in ["Flat", "IVF512,Flat", "IVF1024,Flat", "IVF2048,Flat"]:
            continue
        d.setdefault(fact, []).append({
            "nprobe": exp["search_params"]["nprobe"],
            "recall": exp["knn_intersection"],
            # Search time plus the quantizer's time.
            "time": exp["time"] + exp["quantizer"]["time"],
        })
data
# with open("/checkpoint/gsz/bench_fw/codecs.json", "w") as f:
# json.dump(data, f)

In [None]:
# Print a markdown table of the reconstruction ("rec") experiments stored in
# result.json for one dataset, and dump the same records to a JSON file.
ds = "deep1b"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
results = BIO(root).read_json("result.json")
for k, e in results["experiments"].items():
    if "rec" in k and e['factory'] != 'Flat':
        codec_meta = results['indices'][e['codec']]
        code_size = codec_meta['sa_code_size']
        codec_size = codec_meta['codec_size']
        training_time = codec_meta['training_time']
        # This name used to be read without being assigned in this cell, so it
        # either raised NameError or picked up a stale value from another
        # cell. .get() keeps the cell working when the key is absent.
        training_size = codec_meta.get('training_size')
        cpu = e['cpu'] if 'cpu' in e else ""
        ps = ', '.join([f"{pk}={pv}" for pk, pv in e['construction_params'][0].items()]) if e['construction_params'] else " "
        eps = ', '.join([f"{pk}={pv}" for pk, pv in e['reconstruct_params'].items() if pk != "snap"]) if e['reconstruct_params'] else " "
        # The code size leads the tuple so that data.sort() orders rows by it.
        data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
        jss.append({
            'factory': e['factory'],
            'parameters': e['construction_params'][0] if e['construction_params'] else "",
            'evaluation_params': e['reconstruct_params'],
            'code_size': code_size,
            'codec_size': codec_size,
            'training_time': training_time,
            'training_size': training_size,
            'mse': e['mse'],
            'sym_recall': e['sym_recall'],
            'asym_recall': e['asym_recall'],
            'encode_time': e['encode_time'],
            'decode_time': e['decode_time'],
            'cpu': cpu,
        })

header = "|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|"
print(header)
# A markdown table needs one delimiter cell per header column; derive the
# count from the header instead of hard-coding it.
print("|" + "-|" * (header.count("|") - 1))
data.sort()
for d in data:
    print(d[1])

with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json", "w") as f:
    json.dump(jss, f)

In [None]:
def read_file(filename: str, keys):
    """Load the requested members of a zip archive.

    Members named "D", "I", "R" or "lims" are read with ``np.load``; the
    member "P" is parsed as JSON. Any other key raises ``AssertionError``
    once the member has been opened.

    Returns:
        A list with one loaded object per entry of ``keys``, in the same order.
    """
    array_members = ("D", "I", "R", "lims")
    loaded = []
    with ZipFile(filename, "r") as archive:
        for key in keys:
            with archive.open(key, "r") as member:
                if key in array_members:
                    loaded.append(np.load(member))
                elif key == "P":
                    # json.load wants text, the zip member yields bytes.
                    loaded.append(json.load(io.TextIOWrapper(member)))
                else:
                    raise AssertionError()
    return loaded

In [None]:
# Read the parameter record ("P") of every "*rec*.zip" file of one dataset,
# build markdown rows for the non-Flat codecs and dump the records to JSON.
ds = "contriever"
data = []
jss = []
root = f"/checkpoint/gsz/bench_fw/codecs/{ds}"
for lf in glob.glob(root + '/*rec*.zip'):
    e, = read_file(lf, ['P'])
    if e['factory'] != 'Flat':
        code_size = e['codec_meta']['sa_code_size']
        codec_size = e['codec_meta']['codec_size']
        training_time = e['codec_meta']['training_time']
        training_size = None  # e['codec_meta']['training_size']; not emitted below
        cpu = e['cpu'] if 'cpu' in e else ""
        construction = e['construction_params'][0] if e['construction_params'] else {}
        # Evaluation parameters without "snap", built as a new dict. The old
        # code ran `del` on the loaded record, which raised when "snap" was
        # absent or when reconstruct_params was empty/None.
        evaluation = {pk: pv for pk, pv in (e['reconstruct_params'] or {}).items() if pk != "snap"}
        ps = ', '.join([f"{pk}={pv}" for pk, pv in construction.items()]) if e['construction_params'] else " "
        eps = ', '.join([f"{pk}={pv}" for pk, pv in evaluation.items()]) if e['reconstruct_params'] else " "
        # Blank the evaluation column when it only repeats text already shown
        # in the construction column, except for the two settings below.
        if eps in ps and eps != "encode_ils_iters=16" and eps != "max_beam_size=32":
            eps = " "
        # The code size leads the tuple so that data.sort() orders rows by it.
        data.append((code_size, f"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|"))
        # Evaluation parameters override construction parameters of the same name.
        params = {**construction, **evaluation}
        jss.append({
            'factory': e['factory'],
            'params': params,
            'construction_params': construction,
            'evaluation_params': evaluation,
            'code_size': code_size,
            'codec_size': codec_size,
            'training_time': training_time,
            # 'training_size': training_size,
            'mse': e['mse'],
            'sym_recall': e['sym_recall'],
            'asym_recall': e['asym_recall'],
            'encode_time': e['encode_time'],
            'decode_time': e['decode_time'],
            'cpu': cpu,
        })

header = "|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|"
print(header)
# A markdown table needs one delimiter cell per header column; derive the
# count from the header instead of hard-coding it.
print("|" + "-|" * (header.count("|") - 1))
data.sort()
# for d in data:
#     print(d[1])

print(len(data))

with open(f"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json", "w") as f:
    json.dump(jss, f)