faiss/benchs/bench_all_ivf/parse_bench_all_ivf.py

503 lines
15 KiB
Python
Raw Normal View History

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from collections import defaultdict
2018-12-20 21:43:36 +08:00
from matplotlib import pyplot
import re
from argparse import Namespace
from faiss.contrib.factory_tools import get_code_size as unitsize
2018-12-20 21:43:36 +08:00
def dbsize_from_name(dbname):
    """Return the database size encoded as the suffix of a dataset name.

    For example 'deep1B' -> 10**9 and 'bigann100M' -> 10**8.

    Args:
        dbname: dataset name ending in one of '1B', '100M', '10M', '1M'.

    Returns:
        The number of database vectors, as an int.

    Raises:
        ValueError: if the name does not end with a known size suffix.
    """
    sufs = {
        '1B': 10**9,
        '100M': 10**8,
        '10M': 10**7,
        '1M': 10**6,
    }
    for suffix, size in sufs.items():
        if dbname.endswith(suffix):
            return size
    # An explicit exception rather than `assert False`: asserts are stripped
    # under `python -O`, which would make this silently return None.
    raise ValueError("cannot infer database size from name %r" % dbname)
def keep_latest_stdout(fnames):
    """Keep only the latest version of each '.stdout' log name.

    Names are expected to look like '<key>.<v>.stdout', where <v> is a
    one-letter version. After sorting, a name is dropped when the next name
    shares everything but its last 8 characters ('<v>.stdout'), so only the
    last (highest) version of each run remains. Names not ending in
    '.stdout' are discarded. The input list is left unmodified.
    """
    logs = sorted(name for name in fnames if name.endswith('.stdout'))
    successors = logs[1:] + [None]
    return [
        cur for cur, nxt in zip(logs, successors)
        if nxt is None or nxt[:-8] != cur[:-8]
    ]
def parse_result_file(fname):
    """Parse one bench_all_ivf log file.

    The parser is a small state machine over the lines of the log:
      st == 0: preamble. Collects dataset / index statistics and the index
               key (from the logged `args:` line) until the result-table
               header, i.e. a line containing "time(ms/q)".
      st == 1: skips the single line that follows the header.
      st == 2: every following line is a result row: a parameter string
               (possibly empty) followed by numeric columns.

    Args:
        fname: path of the log. The character at position -8 (the version
            letter in '<key>.<v>.stdout') is stored as stats['run_version'].

    Returns:
        (indexkey, res, keys, stats): the index key from the logged args, an
        array with the numeric columns of each result row, the parameter
        string of each row, and a dict of the parsed statistics.

    Raises:
        RuntimeError: if the log shows a crash ("srun:") before the args line.
        ValueError: if the result header uses an unrecognized measure.
    """
    st = 0
    res = []
    keys = []
    stats = {}
    stats['run_version'] = fname[-8]
    indexkey = None

    # `with` guarantees the file is closed, also when parsing raises.
    with open(fname) as f:
        for l in f:
            if l.startswith("srun:"):
                # looks like a crash...
                if indexkey is None:
                    raise RuntimeError("instant crash")
                break
            elif st == 0:
                if l.startswith("dataset in dimension"):
                    fi = l.split()
                    stats["d"] = int(fi[3][:-1])  # strip trailing comma
                    stats["nq"] = int(fi[9])
                    stats["nb"] = int(fi[11])
                    stats["nt"] = int(fi[13])

                if l.startswith('index size on disk:'):
                    stats['index_size'] = int(l.split()[-1])
                if l.startswith('current RSS:'):
                    stats['RSS'] = int(l.split()[-1])
                if l.startswith('precomputed tables size:'):
                    stats['tables_size'] = int(l.split()[-1])
                if l.startswith('Setting nb of threads to'):
                    stats['n_threads'] = int(l.split()[-1])
                if l.startswith(' add in'):
                    stats['add_time'] = float(l.split()[-2])
                if l.startswith("vector code_size"):
                    stats['code_size'] = float(l.split()[-1])

                if l.startswith('args:'):
                    # NOTE(review): eval() of the logged argparse Namespace
                    # repr. Acceptable only because the logs are written by
                    # the benchmark itself; never run this on untrusted files.
                    args = eval(l[l.find(' '):])
                    indexkey = args.indexkey
                elif "time(ms/q)" in l:
                    # Result header. Collapse runs of whitespace so the
                    # checks do not depend on the column alignment.
                    header = " ".join(l.split())
                    if 'R@1 R@10 R@100' in header:
                        stats["measure"] = "recall"
                        stats["ranks"] = [1, 10, 100]
                    elif 'I@1 I@10 I@100' in header:
                        stats["measure"] = "inter"
                        stats["ranks"] = [1, 10, 100]
                    elif 'inter@' in l:
                        stats["measure"] = "inter"
                        fi = l.split()
                        # rank is either a separate token or glued: inter@10
                        if fi[1] == "inter@":
                            rank = int(fi[2])
                        else:
                            rank = int(fi[1][len("inter@"):])
                        stats["ranks"] = [rank]
                    else:
                        raise ValueError("unknown result header: %r" % l)
                    st = 1
                elif 'index size on disk:' in l:
                    stats["index_size"] = int(l.split()[-1])

            elif st == 1:
                # skip the line just after the header
                st = 2
            elif st == 2:
                fi = l.split()
                if l[0] == " ":
                    # means there are 0 parameters
                    fi = [""] + fi

                keys.append(fi[0])
                res.append([float(x) for x in fi[1:]])
    return indexkey, np.array(res), keys, stats
# The directory used in run_on_cluster.bash. It can be overridden with the
# BENCH_ALL_IVF_BASEDIR environment variable, so the parser also works on
# machines that do not have the original checkpoint path.
basedir = os.environ.get(
    "BENCH_ALL_IVF_BASEDIR", "/checkpoint/matthijs/bench_all_ivf/")
# Log files are opened as logdir + fname, so keep the trailing separator even
# when the override is given without one.
logdir = os.path.join(basedir, 'logs', '')
def collect_results_for(db='deep1M', prefix="autotune."):
    """Parse the logs in `logdir` for one database and log-name prefix.

    Filenames are in the form <key>.x.stdout where x is a version letter
    (from a to z); only the latest version of each name is kept (see
    keep_latest_stdout). A log is used when its name starts with `prefix`
    and contains 'db' + db. Logs that crash or fail to parse are reported
    on stdout and skipped.

    Args:
        db: database name as it appears after 'db' in the log names.
        prefix: required prefix of the log names.

    Returns:
        (allres, allstats): dicts keyed by index key, holding the result
        array and the stats dict returned by parse_result_file. When several
        logs share an index key, the one with the highest run_version wins.
    """
    allres = {}
    allstats = {}

    for fname in keep_latest_stdout(os.listdir(logdir)):
        if not (
            'db' + db in fname and
            fname.startswith(prefix) and
            fname.endswith('.stdout')
        ):
            continue
        print("parse", fname, end=" ", flush=True)
        try:
            indexkey, res, _, stats = parse_result_file(logdir + fname)
        except RuntimeError as e:
            print("FAIL %s" % e)
            res = np.zeros((2, 0))
        except Exception as e:
            # %-formatting: `"..." + e` raised TypeError inside the handler
            # and aborted the whole collection.
            print("PARSE ERROR %s" % e)
            res = np.zeros((2, 0))
        else:
            print(len(res), "results")

        if res.size == 0:
            # no usable results in this log
            continue
        if (indexkey in allres and
                allstats[indexkey]['run_version'] > stats['run_version']):
            # a later run of the same index key was already collected
            continue
        allres[indexkey] = res
        allstats[indexkey] = stats

    return allres, allstats
def extract_pareto_optimal(allres, keys, recall_idx=0, times_idx=3):
    """Select the operating points on the accuracy / search-time Pareto front.

    An operating point is kept when no other point (of any key) is at least
    as accurate and strictly faster.

    Args:
        allres: dict mapping index key -> 2D array with one row per operating
            point. Column `recall_idx` is the accuracy measure (higher is
            better) and column `times_idx` the search time (lower is better).
        keys: the index keys of `allres` to consider.
        recall_idx: column of the accuracy measure.
        times_idx: column of the search time.

    Returns:
        (selected_keys, ops): the keys that own at least one optimal point,
        in the order of `keys`, and a (3, n) array whose rows are the key
        index, the accuracy and the time of each optimal point, sorted by
        increasing accuracy.
    """
    bigtab = []
    for i, k in enumerate(keys):
        v = allres[k]
        perf = v[:, recall_idx]
        times = v[:, times_idx]
        bigtab.append(np.vstack((np.full(times.size, i, dtype=float),
                                 perf, times)))
    if not bigtab:
        return [], np.zeros((3, 0))

    bigtab = np.hstack(bigtab)
    # Sort by increasing perf; among equal perf, by decreasing time, so the
    # fastest point of a tie comes last and dominates the slower ones. (A
    # plain argsort on perf leaves ties in arbitrary order and could keep a
    # dominated point.)
    perm = np.lexsort((-bigtab[2, :], bigtab[1, :]))
    bigtab_sorted = bigtab[:, perm]
    # best_times[j] = fastest time among the points at least as accurate as j
    best_times = np.minimum.accumulate(bigtab_sorted[2, ::-1])[::-1]
    selection, = np.where(bigtab_sorted[2, :] == best_times)
    selected_keys = [
        keys[i] for i in
        np.unique(bigtab_sorted[0, selection].astype(int))
    ]
    ops = bigtab_sorted[:, selection]
    return selected_keys, ops
def plot_subset(
        allres, allstats, selected_methods, recall_idx, times_idx=3,
        report=("overhead", "build time")):
    """Plot QPS vs. accuracy curves for the selected index keys.

    Each method is drawn with pyplot.semilogy. Its label is the index key,
    followed by the annotations requested in `report`:
      "code_size":  code size in bytes, as computed by `unitsize`;
      "overhead":   on-disk index + precomputed tables size relative to the
                    tight estimate (code size + 8-byte id) * nb;
      "build time": add time from the log.
    The axis labels are taken from the stats of the last plotted method.

    Args:
        allres, allstats: result arrays and stats dicts keyed by index key,
            as returned by collect_results_for.
        selected_methods: index keys to plot. Nothing is drawn when empty.
        recall_idx: column of the accuracy measure in the result arrays.
        times_idx: column of the search time (ms per query).
        report: container of the annotation names listed above.

    Raises:
        ValueError: if the stats carry an unknown measure.
    """
    if not selected_methods:
        # Nothing to plot. Returning here also avoids reading `stats` below,
        # which is only bound inside the loop.
        return

    id_size = 8  # 64 bit ids

    # important methods
    for k in selected_methods:
        v = allres[k]
        stats = allstats[k]
        d = stats["d"]
        dbsize = stats["nb"]
        if "index_size" in stats and "tables_size" in stats:
            tot_size = stats['index_size'] + stats['tables_size']
        else:
            tot_size = -1  # unknown

        addt = ''
        if 'add_time' in stats:
            add_time = stats['add_time']
            if add_time > 7200:
                # more than 2 hours: hours and minutes
                add_min = add_time / 60
                addt = ', %dh%02d' % (add_min / 60, add_min % 60)
            else:
                # minutes and seconds
                add_sec = int(add_time)
                addt = ', %dm%02d' % (add_sec / 60, add_sec % 60)

        code_size = unitsize(d, k)

        label = k
        if "code_size" in report:
            label += " %d bytes" % code_size

        # tight estimate of the index size: one code + one id per vector
        tight_size = (code_size + id_size) * dbsize
        if tot_size < 0 or "overhead" not in report:
            pass  # don't know what the index size is, or not requested
        elif tot_size > 10 * tight_size:
            label += " overhead x%.1f" % (tot_size / tight_size)
        else:
            label += " overhead+%.1f%%" % (
                tot_size / tight_size * 100 - 100)

        if "build time" in report:
            label += " " + addt

        linestyle = (':' if 'Refine' in k or 'RFlat' in k else
                     '-.' if 'SQ' in k else
                     '-')
        print(k, linestyle)
        # times are in ms per query, so 1000 / time is queries per second
        pyplot.semilogy(v[:, recall_idx], 1000 / v[:, times_idx], label=label,
                        linestyle=linestyle,
                        marker='o' if '4fs' in k else '+')

    # axis labels use the stats of the last plotted method
    recall_rank = stats["ranks"][recall_idx]
    if stats["measure"] == "recall":
        pyplot.xlabel('1-recall at %d' % recall_rank)
    elif stats["measure"] == "inter":
        pyplot.xlabel('inter @ %d' % recall_rank)
    else:
        raise ValueError("unknown measure %r" % stats["measure"])
    pyplot.ylabel('QPS (%d threads)' % stats["n_threads"])
def plot_tradeoffs(db, allres, allstats, code_size, recall_rank):
    """Plot the speed / accuracy tradeoffs of one database's methods.

    Only methods whose code size (from `unitsize`) is in the requested range
    are considered. Methods with at least one Pareto-optimal operating point
    are plotted with plot_subset. The others are drawn as thin gray lines
    and listed in the x-axis label.

    Args:
        db: database name, used in the title.
        allres, allstats: as returned by collect_results_for (non-empty).
        code_size: 0 for any code size, a positive int for an exact code
            size, or a (min, max) pair (tuple or list) for an inclusive range.
        recall_rank: rank of the accuracy measure to plot; must be one of
            the ranks reported in the logs.

    Returns:
        (selected_methods, not_selected): lists of index keys. not_selected
        is ordered by increasing top-line accuracy.

    Raises:
        TypeError: if code_size is neither an int nor a pair.
    """
    stat0 = next(iter(allstats.values()))
    d = stat0["d"]
    n_threads = stat0["n_threads"]
    recall_idx = stat0["ranks"].index(recall_rank)
    # times come after the perf measure
    times_idx = len(stat0["ranks"])

    if isinstance(code_size, (int, np.integer)):
        if code_size == 0:
            code_range = (0, 1e50)
            code_size_name = "any code size"
        else:
            code_range = (code_size, code_size)
            code_size_name = "code_size=%d" % code_size
    elif isinstance(code_size, (tuple, list)) and len(code_size) == 2:
        code_range = tuple(code_size)
        code_size_name = "code_size in [%d, %d]" % code_range
    else:
        raise TypeError(
            "code_size must be an int or a (min, max) pair, got %r"
            % (code_size,))
    lo, hi = code_range

    names_maxperf = []
    for k in sorted(allres):
        v = allres[k]
        if v.ndim != 2:
            continue
        if not lo <= unitsize(d, k) <= hi:
            continue
        names_maxperf.append((v[-1, recall_idx], k))
    # sort from lowest to highest topline accuracy
    names_maxperf.sort()
    names = [name for _, name in names_maxperf]

    selected_methods, _ = extract_pareto_optimal(
        allres, names, recall_idx, times_idx)
    selected = set(selected_methods)
    # Preserve the order of `names`: a set difference would give an order
    # that depends on (randomized) string hashing and changes between runs.
    not_selected = [k for k in names if k not in selected]

    print("methods without an optimal OP: ", not_selected)

    pyplot.title('database ' + db + ' ' + code_size_name)

    # grayed out lines (`names` is already filtered by ndim and code size)
    for k in not_selected:
        v = allres[k]
        linestyle = (':' if 'PQ' in k else
                     '-.' if 'SQ4' in k else
                     '--' if 'SQ8' in k else '-')
        pyplot.semilogy(v[:, recall_idx], 1000 / v[:, times_idx], label=None,
                        linestyle=linestyle,
                        marker='o' if 'HNSW' in k else '+',
                        color='#cccccc', linewidth=0.2)

    plot_subset(allres, allstats, selected_methods, recall_idx, times_idx)

    # list of omitted methods, wrapped at roughly 80 characters per line
    om = ''
    if not_selected:
        om = '\nomitted:'
        nc = len(om)
        for m in not_selected:
            if nc > 80:
                om += '\n'
                nc = 0
            om += ' ' + m
            nc += len(m) + 1

    # same axis naming as plot_subset, which this label overrides
    if stat0.get("measure") == "inter":
        xlabel_fmt = 'inter @ %d %s'
    else:
        xlabel_fmt = '1-recall at %d %s'
    pyplot.xlabel(xlabel_fmt % (recall_rank, om))
    pyplot.ylabel('QPS (%d threads)' % n_threads)

    pyplot.legend()
    pyplot.grid()
    return selected_methods, not_selected
if __name__ == "__main__xx":
    # Tests on centroids indexing (v1): for each k, one figure with a 2x2
    # grid of panels, one panel per number of centroids.
    for k in 1, 32, 128:
        pyplot.gcf().set_size_inches(15, 10)
        centroid_counts = (65536, 262144, 1048576, 4194304)
        for panel, ncent in enumerate(centroid_counts, 1):
            allres, allstats = collect_results_for(
                db=f'deep_centroids_{ncent}.k{k}.', prefix="cent_index.")
            pyplot.subplot(2, 2, panel)
            plot_subset(
                allres, allstats, list(allres.keys()),
                recall_idx=0, times_idx=1, report=["code_size"])
            pyplot.title(f"{ncent} centroids")
            pyplot.legend()
            pyplot.xlim([0.95, 1])
            pyplot.grid()
        pyplot.savefig('figs/deep1B_centroids_k%d.png' % k)
if __name__ == "__main__xx":
    # Centroids plot per k: for each number of centroids, the QPS of the
    # first operating point with column 0 >= 0.99, as a function of k.
    pyplot.gcf().set_size_inches(15, 10)
    centroid_counts = (65536, 262144, 1048576, 4194304)
    for panel, ncent in enumerate(centroid_counts, 1):
        qps_per_key = defaultdict(list)
        for k in 1, 4, 8, 16, 32, 64, 128, 256:
            allres, allstats = collect_results_for(
                db=f'deep_centroids_{ncent}.k{k}.', prefix="cent_index.")
            for indexkey, res in allres.items():
                above, = np.where(res[:, 0] >= 0.99)
                if above.size > 0:
                    qps = 1000 / res[above[0], 1]
                    qps_per_key[indexkey].append((k, qps))
        pyplot.subplot(2, 2, panel)
        for indexkey, points in qps_per_key.items():
            points = np.array(points)
            pyplot.loglog(points[:, 0], points[:, 1], 'o-', label=indexkey)
        pyplot.title(f"{ncent} centroids")
        pyplot.xlabel("k")
        ticks = 2**np.arange(9)
        pyplot.xticks(ticks, ["%d" % t for t in ticks])
        pyplot.ylabel("QPS (32 threads)")
        pyplot.legend()
        pyplot.grid()
    pyplot.savefig('../plots/deep1B_centroids_min99.png')
if __name__ == "__main__xx":
    # Main indexing plots: one figure per (database, code size budget).
    fig_no = 0
    databases = ('bigann10M', 'deep10M', 'bigann100M', 'deep100M',
                 'deep1B', 'bigann1B')
    for db in databases:
        allres, allstats = collect_results_for(db=db, prefix="autotune.")
        for cs in 8, 16, 32, 64:
            pyplot.figure(fig_no)
            fig_no += 1
            pyplot.gcf().set_size_inches(15, 10)
            # inclusive code size ranges: [0, 8], then [cs // 2 + 1, cs]
            if cs == 8:
                cs_range = (0, 8)
            else:
                cs_range = (cs // 2 + 1, cs)
            plot_tradeoffs(
                db, allres, allstats, code_size=cs_range, recall_rank=1)
            pyplot.savefig('../plots/tradeoffs_%s_cs%d_r1.png' % (db, cs))
if __name__ == "__main__":
    # 1M indexes. Each plot gets its own 15x10 inch figure.
    fig_no = 0

    # glove / music-100: all code sizes on one plot
    for db in "glove", "music-100":
        pyplot.figure(fig_no)
        fig_no += 1
        pyplot.gcf().set_size_inches(15, 10)
        allres, allstats = collect_results_for(db=db, prefix="autotune.")
        plot_tradeoffs(db, allres, allstats, code_size=0, recall_rank=1)
        pyplot.savefig('../plots/1M_tradeoffs_' + db + ".png")

    # sift1M / deep1M: small (<= 64 bytes) and large codes plotted apart,
    # both from a single parse of the logs
    size_splits = (((0, 64), "_small.png"), ((65, 10000), "_large.png"))
    for db in "sift1M", "deep1M":
        allres, allstats = collect_results_for(db=db, prefix="autotune.")
        for cs_range, suffix in size_splits:
            pyplot.figure(fig_no)
            fig_no += 1
            pyplot.gcf().set_size_inches(15, 10)
            plot_tradeoffs(db, allres, allstats, code_size=cs_range,
                           recall_rank=1)
            pyplot.savefig('../plots/1M_tradeoffs_' + db + suffix)
if __name__ == "__main__xx":
    # IVF1024 variants on sift1M: one plot of the code variants, one of the
    # re-ranking variants (drawn in figure 2).
    db = 'sift1M'
    allres, allstats = collect_results_for(db=db, prefix="autotune.")
    variant_plots = [
        (None, [
            "IVF1024,PQ32x8",
            "IVF1024,PQ64x4",
            "IVF1024,PQ64x4fs",
            "IVF1024,PQ64x4fsr",
            "IVF1024,SQ4",
            "IVF1024,SQ8",
        ], '../plots/ivf1024_variants.png'),
        (2, [
            "HNSW32",
            "IVF1024,PQ64x4fs",
            "IVF1024,PQ64x4fsr",
            "IVF1024,PQ64x4fs,RFlat",
            "IVF1024,PQ64x4fs,Refine(SQfp16)",
            "IVF1024,PQ64x4fs,Refine(SQ8)",
        ], '../plots/ivf1024_rerank.png'),
    ]
    for fig_no, keys, out_png in variant_plots:
        if fig_no is not None:
            # the first plot uses whatever figure is current
            pyplot.figure(fig_no)
        pyplot.gcf().set_size_inches(15, 10)
        plot_subset(allres, allstats, keys, recall_idx=0,
                    report=["code_size"])
        pyplot.legend()
        pyplot.title(db)
        pyplot.xlabel("1-recall@1")
        pyplot.ylabel("QPS (32 threads)")
        pyplot.grid()
        pyplot.savefig(out_png)