faiss/benchs/bench_all_ivf/parse_bench_all_ivf.py

270 lines
7.5 KiB
Python

#! /usr/bin/env python2
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import print_function
import os
import numpy as np
from matplotlib import pyplot
import re
from argparse import Namespace
# Root of the benchmark outputs (the directory used in run_on_cluster.bash).
basedir = '/mnt/vol/gfsai-east/ai-group/users/matthijs/bench_all_ivf/'
# Sub-directory holding the per-run log files.
logdir = os.path.join(basedir, 'logs/')

# Which plot to output: the database name and the per-vector code size
# that index keys must match to be drawn.
db, code_size = 'bigann1B', 8
def unitsize(indexkey):
    """Return the size of one vector in the index described by `indexkey`.

    Two families of keys are understood:

    * keys containing ``,PQ<m>``: the number ``<m>`` is returned as-is;
    * keys ending in ``SQ8``, ``SQ4`` or ``SQfp16`` (8, 4 or 16 bits per
      dimension) that start with ``PCAR<d>,``, ``OPQ<m>_<d>,`` or
      ``RR<d>,``: the result is ``bits_per_d * d // 8``.

    The result is always an int (floor division), so it behaves the same
    under Python 2 and Python 3.

    Raises:
        ValueError: if `indexkey` matches neither family.
    """
    mo = re.match(r'.*,PQ(\d+)', indexkey)
    if mo:
        return int(mo.group(1))

    # scalar quantizers: the key suffix gives the number of bits per dimension
    for suffix, bits_per_d in (('SQ8', 8), ('SQ4', 4), ('SQfp16', 16)):
        if indexkey.endswith(suffix):
            break
    else:
        raise ValueError(
            'cannot determine the encoding of index key %r' % (indexkey,))

    # the leading transform gives the dimension of the encoded vectors
    for pattern in (r'PCAR(\d+),.*', r'OPQ\d+_(\d+),.*', r'RR(\d+),.*'):
        mo = re.match(pattern, indexkey)
        if mo:
            return bits_per_d * int(mo.group(1)) // 8

    raise ValueError(
        'cannot determine the dimension of index key %r' % (indexkey,))
def dbsize_from_name(dbname):
    """Return the number of vectors implied by the suffix of `dbname`.

    Recognized suffixes are ``1B``, ``100M``, ``10M`` and ``1M``.

    Raises:
        ValueError: if `dbname` ends with none of the known suffixes.
    """
    sufs = (
        ('1B', 10**9),
        ('100M', 10**8),
        ('10M', 10**7),
        ('1M', 10**6),
    )
    for suffix, size in sufs:
        if dbname.endswith(suffix):
            return size
    # explicit error instead of `assert False`, which is stripped by -O
    raise ValueError(
        'cannot infer the database size from name %r' % (dbname,))
def keep_latest_stdout(fnames):
    """Return the sorted '.stdout' names, keeping one file per run key.

    Names are compared after dropping their last 8 characters (the version
    character plus '.stdout'); when consecutive sorted names share that
    prefix, only the last one of the group is kept.
    """
    candidates = sorted(name for name in fnames if name.endswith('.stdout'))
    latest = []
    # pair every name with its successor in sorted order (None for the last)
    for current, following in zip(candidates, candidates[1:] + [None]):
        superseded = (following is not None and
                      following[:-8] == current[:-8])
        if not superseded:
            latest.append(current)
    return latest
def parse_result_file(fname):
    """Parse one benchmark log file.

    The file is read line by line with a small state machine:
    state 0 collects statistics until the results-table header is seen,
    state 1 skips the single line that follows the header, and
    state 2 reads one table row per line.

    Args:
        fname: path of the log; the character 8 positions from its end is
            recorded as the run version.

    Returns:
        A tuple ``(indexkey, res, keys, stats)`` where
        ``indexkey`` is the `indexkey` attribute of the logged ``args:``
        value (None when the log has no such line),
        ``res`` is an np.array with one row of floats per table row,
        ``keys`` is the list of first tokens of the table rows, and
        ``stats`` is a dict of the statistics found in the preamble
        (always contains 'run_version').
    """
    # (key in stats, line prefix, converter, index of the token to convert)
    preamble_fields = (
        ('CHRONOS_JOB_INSTANCE_ID', 'CHRONOS_JOB_INSTANCE_ID', str, -1),
        ('index_size', 'index size on disk:', int, -1),
        ('RSS', 'current RSS:', int, -1),
        ('tables_size', 'precomputed tables size:', int, -1),
        ('n_threads', 'Setting nb of threads to', int, -1),
        ('add_time', ' add in', float, -2),
    )
    st = 0
    res = []
    keys = []
    indexkey = None
    stats = {'run_version': fname[-8]}
    with open(fname) as f:
        for l in f:
            if st == 0:
                for key, prefix, convert, token in preamble_fields:
                    if l.startswith(prefix):
                        stats[key] = convert(l.split()[token])
                if l.startswith('args:'):
                    # SECURITY NOTE: eval() executes whatever the log line
                    # contains; only run this on logs you produced yourself.
                    args = eval(l[l.find(' '):])
                    indexkey = args.indexkey
                elif 'R@1 R@10 R@100' in l:
                    st = 1
            elif st == 1:
                # the line right after the table header carries no data
                st = 2
            elif st == 2:
                fi = l.split()
                if not fi:
                    # tolerate blank lines inside / after the table
                    continue
                keys.append(fi[0])
                res.append([float(x) for x in fi[1:]])
    return indexkey, np.array(res), keys, stats
# run parsing
allres = {}      # indexkey -> results array of the retained run
allstats = {}    # indexkey -> stats dict of the retained run
nts = []         # thread count of every retained run
missing = []     # log files that contain no results
versions = {}

# filenames are in the form <key>.x.stdout
# where x is a version number (from a to z)
# keep only latest version of each name
fnames = keep_latest_stdout(os.listdir(logdir))

for fname in fnames:
    if not ('db' + db in fname and fname.endswith('.stdout')):
        continue
    indexkey, res, _, stats = parse_result_file(logdir + fname)
    if res.size == 0:
        # no results table: report the run with the last line of its stderr
        missing.append(fname)
        try:
            with open(logdir + fname.replace('.stdout', '.stderr')) as f:
                errorlines = f.readlines()
        except IOError:
            # a missing stderr file is reported like an empty one
            errorlines = []
        errorline = errorlines[-1] if errorlines else 'NO STDERR'
        print(fname, stats.get('CHRONOS_JOB_INSTANCE_ID', 'unknown'),
              errorline)
        continue
    if (indexkey in allres and
            allstats[indexkey]['run_version'] > stats['run_version']):
        # don't use this run: a later version is already recorded
        continue
    nts.append(stats.get('n_threads', 1))
    allres[indexkey] = res
    allstats[indexkey] = stats

# all retained runs must have used the same number of threads, because a
# single value is printed on the plot
if len(set(nts)) != 1:
    raise ValueError(
        'expected exactly one thread count among the parsed runs, '
        'got %s' % sorted(set(nts)))
n_threads = nts[0]
def _curve_linestyle(indexkey):
    """Line style of a curve: ':' for PQ, '-.' for SQ4, '--' for SQ8, else '-'."""
    if 'PQ' in indexkey:
        return ':'
    if 'SQ4' in indexkey:
        return '-.'
    if 'SQ8' in indexkey:
        return '--'
    return '-'


def _format_add_time(add_time):
    """Format a duration in seconds as ', <h>h<mm>' above 7200 s, else ', <m>m<ss>'."""
    if add_time > 7200:
        hours, minutes = divmod(int(add_time // 60), 60)
        return ', %dh%02d' % (hours, minutes)
    minutes, seconds = divmod(int(add_time), 60)
    return ', %dm%02d' % (minutes, seconds)


def plot_tradeoffs(allres, code_size, recall_rank):
    """Plot search time against recall for all indexes of a given code size.

    Only entries of `allres` that are 2-D arrays and whose `unitsize` equals
    `code_size` are considered. An index is "selected" when at least one of
    its operating points is optimal, i.e. no point of any considered index
    reaches an equal-or-higher recall in less time. Selected indexes are
    drawn with a legend entry, the others as thin gray curves and listed
    under the x axis. The figure is saved under ``figs/``.

    Args:
        allres: dict mapping index key to a results array; column
            ``int(log10(recall_rank))`` is used as recall and column 3 as
            search time.
        code_size: size that `unitsize(key)` must match.
        recall_rank: rank of the recall measure to plot (1, 10 or 100).

    Returns:
        ``(selected_methods, not_selected)``: two lists of index keys.

    Raises:
        ValueError: if no entry of `allres` matches `code_size`.
    """
    dbsize = dbsize_from_name(db)
    recall_idx = int(np.log10(recall_rank))

    # one (3, n_points) table per index: index number, recall, search time
    bigtab = []
    names = []
    for k, v in sorted(allres.items()):
        if v.ndim != 2 or unitsize(k) != code_size:
            continue
        perf = v[:, recall_idx]
        times = v[:, 3]
        bigtab.append(
            np.vstack((
                np.ones(times.size, dtype=int) * len(names),
                perf, times
            ))
        )
        names.append(k)

    if not bigtab:
        raise ValueError(
            'no results with code_size=%d for database %s' % (code_size, db))
    bigtab = np.hstack(bigtab)

    # sort all operating points by increasing recall
    perm = np.argsort(bigtab[1, :])
    bigtab = bigtab[:, perm]

    # best time among the points with at least this recall; a point is
    # optimal when its own time equals that bound
    times = np.minimum.accumulate(bigtab[2, ::-1])[::-1]
    optimal = bigtab[2, :] == times

    selected_methods = [names[i] for i in
                        np.unique(bigtab[0, optimal].astype(int))]
    # sorted for a reproducible plot and caption
    not_selected = sorted(set(names) - set(selected_methods))

    print("methods without an optimal OP: ", not_selected)

    pyplot.title('database ' + db + ' code_size=%d' % code_size)

    # grayed out lines (keys come from `names`, hence already filtered)
    for k in not_selected:
        v = allres[k]
        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=None,
                        linestyle=_curve_linestyle(k),
                        marker='o' if 'HNSW' in k else '+',
                        color='#cccccc', linewidth=0.2)

    # important methods
    id_size = 8  # 64 bit
    for k in selected_methods:
        v = allres[k]
        stats = allstats[k]
        # logs without a precomputed-tables line count as 0
        tot_size = stats['index_size'] + stats.get('tables_size', 0)

        addt = ''
        if 'add_time' in stats:
            addt = _format_add_time(stats['add_time'])

        # size overhead relative to storing only codes and ids
        label = k + ' (size+%.1f%%%s)' % (
            tot_size / float((code_size + id_size) * dbsize) * 100 - 100,
            addt)

        pyplot.semilogy(v[:, recall_idx], v[:, 3], label=label,
                        linestyle=_curve_linestyle(k),
                        marker='o' if 'HNSW' in k else '+')

    # list the omitted methods under the axis label, wrapping long lines
    if not not_selected:
        om = ''
    else:
        om = '\nomitted:'
        nc = len(om)
        for m in not_selected:
            if nc > 80:
                om += '\n'
                nc = 0
            om += ' ' + m
            nc += len(m) + 1

    pyplot.xlabel('1-recall at %d %s' % (recall_rank, om))
    pyplot.ylabel('search time per query (ms, %d threads)' % n_threads)
    pyplot.legend()
    pyplot.grid()

    # make sure the output directory exists before saving
    if not os.path.isdir('figs'):
        os.makedirs('figs')
    pyplot.savefig('figs/tradeoffs_%s_cs%d_r%d.png' % (
        db, code_size, recall_rank))
    return selected_methods, not_selected
# Size the current figure to 15x10 inches, then draw the plot for recall
# rank 1 using the code size configured at the top of the file.
_figure = pyplot.gcf()
_figure.set_size_inches(15, 10)
plot_tradeoffs(allres, code_size, 1)