Use print() function in both Python 2 and Python 3 (#1443)

Summary:
Legacy `print` statements are syntax errors in Python 3, but the `print()` function works as expected in both Python 2 and Python 3.
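
As an editorial illustration (not part of the patch): the `from __future__ import print_function` line added near the top of each file below is what makes the converted calls safe on Python 2 as well. A minimal sketch, assuming nothing beyond the standard library:

    # Sketch only: behavior of the two forms.
    from __future__ import print_function  # no-op on Python 3; on Python 2 it
                                           # replaces the print statement with
                                           # the print() builtin

    # Old statement form -- a SyntaxError on Python 3:
    #     print "load", fname
    # Function form, identical output on both interpreters:
    print("load", "mnist8m-patterns-idx3-ubyte")

    # Caution: without the __future__ import, Python 2 parses the call as a
    # print statement applied to a tuple and outputs:
    #     ('load', 'mnist8m-patterns-idx3-ubyte')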

Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1443

Reviewed By: LowikC

Differential Revision: D24157415

Pulled By: mdouze

fbshipit-source-id: 4ec637aa26b61272e5337d47b7796a330ce25bad
pull/1449/head
cclauss 2020-10-08 00:26:02 -07:00 committed by Facebook GitHub Bot
parent 9b007c7418
commit efa1e3f64f
5 changed files with 74 additions and 71 deletions

View File

@@ -5,6 +5,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import print_function
import os
import numpy as np
from matplotlib import pyplot
@@ -137,7 +138,7 @@ for fname in fnames:
            errorline = errorline[-1]
        else:
            errorline = 'NO STDERR'
-        print fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline
+        print(fname, stats['CHRONOS_JOB_INSTANCE_ID'], errorline)
    else:
        if indexkey in allres:
@@ -186,7 +187,7 @@ def plot_tradeoffs(allres, code_size, recall_rank):
        np.unique(bigtab[0, selection].astype(int))]
    not_selected = list(set(names) - set(selected_methods))
-    print "methods without an optimal OP: ", not_selected
+    print("methods without an optimal OP: ", not_selected)
    nq = 10000
    pyplot.title('database ' + db + ' code_size=%d' % code_size)

View File

@@ -5,6 +5,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import print_function
import numpy as np
import time
import faiss
@@ -19,24 +20,26 @@ ngpu = int(sys.argv[2])

# Load Leon's file format
def load_mnist(fname):
-    print "load", fname
+    print("load", fname)
    f = open(fname)
    header = np.fromfile(f, dtype='int8', count=4*4)
    header = header.reshape(4, 4)[:, ::-1].copy().view('int32')
-    print header
+    print(header)
    nim, xd, yd = [int(x) for x in header[1:]]
    data = np.fromfile(f, count=nim * xd * yd,
                       dtype='uint8')
-    print data.shape, nim, xd, yd
+    print(data.shape, nim, xd, yd)
    data = data.reshape(nim, xd, yd)
    return data

basedir = "/path/to/mnist/data"

x = load_mnist(basedir + 'mnist8m/mnist8m-patterns-idx3-ubyte')

-print "reshape"
+print("reshape")
x = x.reshape(x.shape[0], -1).astype('float32')
@@ -74,13 +77,13 @@ def train_kmeans(x, k, ngpu):
    centroids = faiss.vector_float_to_array(clus.centroids)
    obj = faiss.vector_float_to_array(clus.obj)
-    print "final objective: %.4g" % obj[-1]
+    print("final objective: %.4g" % obj[-1])

    return centroids.reshape(k, d)

-print "run"
+print("run")
t0 = time.time()
train_kmeans(x, k, ngpu)
t1 = time.time()
-print "total runtime: %.3f s" % (t1 - t0)
+print("total runtime: %.3f s" % (t1 - t0))

View File

@@ -1,10 +1,9 @@
-#!/usr/bin/env python2
-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import print_function

import os
import sys
import time
@@ -73,7 +72,7 @@ aa('--efSearch', default='', type=str,
args = parser.parse_args()

-print "args:", args
+print("args:", args)
######################################################
@@ -93,7 +92,7 @@ nb, d = xb.shape
if os.path.exists(args.indexfile):
-    print "reading", args.indexfile
+    print("reading", args.indexfile)
    index = faiss.read_index(args.indexfile)

    if isinstance(index, faiss.IndexPreTransform):
@@ -108,7 +107,7 @@ if os.path.exists(args.indexfile):
else:
-    print "build index, key=", args.indexkey
+    print("build index, key=", args.indexkey)
    index = faiss.index_factory(d, args.indexkey)
@@ -127,29 +126,29 @@ else:
        index_hnsw.storage.verbose = True

    if args.M0 != -1:
-        print "set level 0 nb of neighbors to", args.M0
+        print("set level 0 nb of neighbors to", args.M0)
        hnsw.set_nb_neighbors(0, args.M0)

    xt2 = sanitize(xt[:args.maxtrain])
    assert np.all(np.isfinite(xt2))
-    print "train, size", xt.shape
+    print("train, size", xt.shape)
    t0 = time.time()
    index.train(xt2)
-    print " train in %.3f s" % (time.time() - t0)
+    print(" train in %.3f s" % (time.time() - t0))

-    print "adding"
+    print("adding")
    t0 = time.time()
    if args.add_bs == -1:
        index.add(sanitize(xb))
    else:
        for i0 in range(0, nb, args.add_bs):
            i1 = min(nb, i0 + args.add_bs)
-            print " adding %d:%d / %d" % (i0, i1, nb)
+            print(" adding %d:%d / %d" % (i0, i1, nb))
            index.add(sanitize(xb[i0:i1]))

-    print " add in %.3f s" % (time.time() - t0)
-    print "storing", args.indexfile
+    print(" add in %.3f s" % (time.time() - t0))
+    print("storing", args.indexfile)
    faiss.write_index(index, args.indexfile)
@@ -158,18 +157,18 @@ else:
######################################################

if args.beta_centroids:
-    print "reordering links"
+    print("reordering links")
    index_hnsw.reorder_links()

    if os.path.exists(args.beta_centroids):
-        print "load", args.beta_centroids
+        print("load", args.beta_centroids)
        beta_centroids = np.load(args.beta_centroids)
        nsq, k, M1 = beta_centroids.shape
        assert M1 == hnsw.nb_neighbors(0) + 1

        rfn = faiss.ReconstructFromNeighbors(index_hnsw, k, nsq)
    else:
-        print "train beta centroids"
+        print("train beta centroids")
        rfn = faiss.ReconstructFromNeighbors(
            index_hnsw, args.beta_k, args.beta_nsq)
@@ -178,7 +177,7 @@ if args.beta_centroids:
        beta_centroids = neighbor_codec.train_beta_codebook(
            rfn, xb_full, niter=args.beta_niter)

-        print " storing", args.beta_centroids
+        print(" storing", args.beta_centroids)
        np.save(args.beta_centroids, beta_centroids)
@@ -189,28 +188,28 @@ if args.beta_centroids:
    if rfn.k == 1:
        pass # no codes to take care of
    elif os.path.exists(args.neigh_recons_codes):
-        print "loading neigh codes", args.neigh_recons_codes
+        print("loading neigh codes", args.neigh_recons_codes)
        codes = np.load(args.neigh_recons_codes)
        assert codes.size == rfn.code_size * index.ntotal
        faiss.copy_array_to_vector(codes.astype('uint8'),
                                   rfn.codes)
        rfn.ntotal = index.ntotal
    else:
-        print "encoding neigh codes"
+        print("encoding neigh codes")
        t0 = time.time()

        bs = 1000000 if args.add_bs == -1 else args.add_bs

        for i0 in range(0, nb, bs):
            i1 = min(i0 + bs, nb)
-            print " encode %d:%d / %d [%.3f s]\r" % (
-                i0, i1, nb, time.time() - t0),
+            print(" encode %d:%d / %d [%.3f s]\r" % (
+                i0, i1, nb, time.time() - t0), end=' ')
            sys.stdout.flush()
            xbatch = vec_transform(sanitize(xb[i0:i1]))
            rfn.add_codes(i1 - i0, faiss.swig_ptr(xbatch))

-        print
-        print "storing %s" % args.neigh_recons_codes
+        print()
+        print("storing %s" % args.neigh_recons_codes)
        codes = faiss.vector_to_array(rfn.codes)
        np.save(args.neigh_recons_codes, codes)
@@ -219,13 +218,13 @@ if args.beta_centroids:
######################################################

if args.exhaustive:
-    print "exhaustive evaluation"
+    print("exhaustive evaluation")
    xq_tr = vec_transform(sanitize(xq))
    index2 = faiss.IndexFlatL2(index_hnsw.d)
    accu_recons_error = 0.0

    if faiss.get_num_gpus() > 0:
-        print "do eval on GPU"
+        print("do eval on GPU")
        co = faiss.GpuMultipleClonerOptions()
        co.shard = False
        index2 = faiss.index_cpu_to_all_gpus(index2, co)
@@ -236,7 +235,7 @@ if args.exhaustive:
    bs = 500000
    for i0 in range(0, nb, bs):
        i1 = min(nb, i0 + bs)
-        print ' handling batch %d:%d' % (i0, i1)
+        print(' handling batch %d:%d' % (i0, i1))

        xb_recons = np.empty(
            (i1 - i0, index_hnsw.d), dtype='float32')
@@ -254,9 +253,9 @@ if args.exhaustive:
    rh.finalize()
    del index2
    t1 = time.time()
-    print "done in %.3f s" % (t1 - t0)
-    print "total reconstruction error: ", accu_recons_error
-    print "eval retrieval:"
+    print("done in %.3f s" % (t1 - t0))
+    print("total reconstruction error: ", accu_recons_error)
+    print("eval retrieval:")
    datasets.evaluate_DI(rh.D, rh.I, gt)
@@ -276,12 +275,12 @@ def get_neighbors(hnsw, i, level):
xq = sanitize(xq)

if args.searchthreads != -1:
-    print "Setting nb of threads to", args.searchthreads
+    print("Setting nb of threads to", args.searchthreads)
    faiss.omp_set_num_threads(args.searchthreads)

if gt is None:
-    print "no valid groundtruth -- exit"
+    print("no valid groundtruth -- exit")
    sys.exit()
@@ -292,13 +291,13 @@ efSearchs = [int(x) for x in args.efSearch.split(',')]
for k_reorder in k_reorders:

    if index_hnsw.reconstruct_from_neighbors:
-        print "setting k_reorder=%d" % k_reorder
+        print("setting k_reorder=%d" % k_reorder)
        index_hnsw.reconstruct_from_neighbors.k_reorder = k_reorder

    for efSearch in efSearchs:
-        print "efSearch=%-4d" % efSearch,
+        print("efSearch=%-4d" % efSearch, end=' ')
        hnsw.efSearch = efSearch
        hnsw_stats.reset()
        datasets.evaluate(xq, gt, index, k=args.k, endl=False)
-        print "ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder)
+        print("ndis %d nreorder %d" % (hnsw_stats.ndis, hnsw_stats.nreorder))

View File

@@ -8,6 +8,7 @@
"""
Common functions to load datasets and compute their ground-truth
"""
+from __future__ import print_function

import time
import numpy as np
@@ -102,7 +103,7 @@ class ResultHeap:

def compute_GT_sliced(xb, xq, k):
-    print "compute GT"
+    print("compute GT")
    t0 = time.time()
    nb, d = xb.shape
    nq, d = xq.shape
@@ -121,24 +122,24 @@ def compute_GT_sliced(xb, xq, k):
        D, I = db_gt.search(xqs, k)
        rh.add_batch_result(D, I, i0)
        db_gt.reset()
-        print "\r %d/%d, %.3f s" % (i0, nb, time.time() - t0),
+        print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ')
        sys.stdout.flush()
-    print
+    print()
    rh.finalize()
    gt_I = rh.I

-    print "GT time: %.3f s" % (time.time() - t0)
+    print("GT time: %.3f s" % (time.time() - t0))
    return gt_I


def do_compute_gt(xb, xq, k):
-    print "computing GT"
+    print("computing GT")
    nb, d = xb.shape
    index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(d))
    if nb < 100 * 1000:
-        print " add"
+        print(" add")
        index.add(np.ascontiguousarray(xb, dtype='float32'))
-        print " search"
+        print(" search")
        D, I = index.search(np.ascontiguousarray(xq, dtype='float32'), k)
    else:
        I = compute_GT_sliced(xb, xq, k)
@@ -148,7 +149,7 @@ def do_compute_gt(xb, xq, k):

def load_data(dataset='deep1M', compute_gt=False):
-    print "load data", dataset
+    print("load data", dataset)

    if dataset == 'sift1M':
        basedir = simdir + 'sift1M/'
@@ -190,7 +191,7 @@ def load_data(dataset='deep1M', compute_gt=False):
        gt_fname = basedir + "%s_groundtruth.ivecs" % dataset
        if compute_gt:
            gt = do_compute_gt(xb, xq, 100)
-            print "store", gt_fname
+            print("store", gt_fname)
            ivecs_write(gt_fname, gt)

        gt = ivecs_read(gt_fname)
@@ -198,8 +199,8 @@ def load_data(dataset='deep1M', compute_gt=False):
    else:
        assert False

-    print "dataset %s sizes: B %s Q %s T %s" % (
-        dataset, xb.shape, xq.shape, xt.shape)
+    print("dataset %s sizes: B %s Q %s T %s" % (
+        dataset, xb.shape, xq.shape, xt.shape))

    return xt, xb, xq, gt
@@ -214,7 +215,7 @@ def evaluate_DI(D, I, gt):
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
-        print "R@%d: %.4f" % (rank, recall),
+        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
@@ -223,13 +224,13 @@ def evaluate(xq, gt, index, k=100, endl=True):
    D, I = index.search(xq, k)
    t1 = time.time()
    nq = xq.shape[0]
-    print "\t %8.4f ms per query, " % (
-        (t1 - t0) * 1000.0 / nq),
+    print("\t %8.4f ms per query, " % (
+        (t1 - t0) * 1000.0 / nq), end=' ')
    rank = 1
    while rank <= k:
        recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
-        print "R@%d: %.4f" % (rank, recall),
+        print("R@%d: %.4f" % (rank, recall), end=' ')
        rank *= 10
    if endl:
-        print
+        print()
    return D, I
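
The progress reporting in `compute_GT_sliced` above combines a carriage return with the no-newline form so each batch overwrites the previous status line. A self-contained sketch (the loop bounds are invented):

    from __future__ import print_function
    import sys
    import time

    t0 = time.time()
    nb, bs = 100, 10                      # made-up sizes for illustration
    for i0 in range(0, nb, bs):
        # '\r' moves the cursor to the start of the line; end=' ' (the
        # replacement for the Python 2 trailing comma) suppresses the newline
        print("\r %d/%d, %.3f s" % (i0, nb, time.time() - t0), end=' ')
        sys.stdout.flush()                # force the partial line to appear
    print()                               # a bare print() finally ends the line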

View File

@@ -1,5 +1,3 @@
-#! /usr/bin/env python2
-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
@@ -10,6 +8,7 @@ This is the training code for the link and code. Especially the
neighbors_kmeans function implements the EM-algorithm to find the
appropriate weightings and cluster them.
"""
+from __future__ import print_function

import time
import numpy as np
@@ -57,7 +56,7 @@ def train_kmeans(x, k, ngpu, max_points_per_centroid=256):
    centroids = faiss.vector_float_to_array(clus.centroids)
    obj = faiss.vector_float_to_array(clus.obj)
-    print "final objective: %.4g" % obj[-1]
+    print("final objective: %.4g" % obj[-1])

    return centroids.reshape(k, d)
@@ -91,7 +90,7 @@ def regress_from_neighbors (x, x_coded, Inn):
    (N, knn) = get_Inn_shape(Inn)
    betas = np.zeros((N,knn))
    t0 = time.time()
-    for i in xrange (N):
+    for i in range (N):
        xi = x[i,:]
        NNi = get_neighbor_table(x_coded, Inn, i)
        betas[i,:] = np.linalg.lstsq(NNi.transpose(), xi, rcond=0.01)[0]
@@ -109,7 +108,7 @@ def regress_opt_beta (x, x_coded, Inn):
    # construct the linear system to be solved
    X = np.zeros ((d*N))
    Y = np.zeros ((d*N, knn))
-    for i in xrange (N):
+    for i in range (N):
        X[i*d:(i+1)*d] = x[i,:]
        neighbor_table = get_neighbor_table(x_coded, Inn, i)
        Y[i*d:(i+1)*d, :] = neighbor_table.transpose()
@@ -125,7 +124,7 @@ def assign_beta (beta_centroids, x, x_coded, Inn, verbose=True):
    (N, knn) = Inn.shape
    x_ibeta = np.zeros ((N), dtype='int32')
    t0= time.time()
-    for i in xrange (N):
+    for i in range (N):
        NNi = x_coded[Inn[i,:]]
        # Consider all possible betas for the encoding and compute the
        # encoding error
@@ -145,7 +144,7 @@ def recons_from_neighbors (beta_centroids, x_ibeta, x_coded, Inn):
    (N, knn) = Inn.shape
    x_rec = np.zeros(x_coded.shape)
    t0= time.time()
-    for i in xrange (N):
+    for i in range (N):
        NNi = x_coded[Inn[i,:]]
        x_rec[i, :] = np.dot (beta_centroids[x_ibeta[i]], NNi)
        if i % (N / 10) == 0:
@@ -165,16 +164,16 @@ def neighbors_kmeans (x, x_coded, Inn, K, ngpus=1, niter=5):
    rs = np.random.RandomState()
    for iter in range(niter):
-        print 'iter', iter
+        print('iter', iter)
        idx = assign_beta (beta_centroids, x, x_coded, Inn, verbose=False)
        hist = np.bincount(idx)
        for cl0 in np.where(hist == 0)[0]:
-            print " cluster %d empty, split" % cl0,
+            print(" cluster %d empty, split" % cl0, end=' ')
            cl1 = idx[np.random.randint(idx.size)]
            pos = np.nonzero (idx == cl1)[0]
            pos = rs.choice(pos, pos.size / 2)
-            print " cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos))
+            print(" cl %d -> %d + %d" % (cl1, len(pos), hist[cl1] - len(pos)))
            idx[pos] = cl0
        hist = np.bincount(idx)
@@ -194,7 +193,7 @@ def neighbors_kmeans (x, x_coded, Inn, K, ngpus=1, niter=5):
            if residuals.size > 0:
                tot_err += residuals.sum()
            beta_centroids[k, :] = sol
-        print ' err=%g' % tot_err
+        print(' err=%g' % tot_err)
    return beta_centroids
@@ -226,8 +225,8 @@ def train_beta_codebook(rfn, xb_full, niter=10):
    beta_centroids = []

    for sq in range(rfn.nsq):
        d0, d1 = sq * rfn.dsub, (sq + 1) * rfn.dsub
-        print "training subquantizer %d/%d on dimensions %d:%d" % (
-            sq, rfn.nsq, d0, d1)
+        print("training subquantizer %d/%d on dimensions %d:%d" % (
+            sq, rfn.nsq, d0, d1))
        beta_centroids_i = neighbors_kmeans(
            xb_full[:, d0:d1], rfn, (xb_full.shape[0], rfn.M + 1, sq),
            rfn.k,
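
Beyond `print`, the hunks in this last file also swap `xrange` for `range`: `xrange` no longer exists in Python 3, where `range` itself returns a lazy sequence much like Python 2's `xrange`. A sketch of the trade-off (illustrative only, not from the patch):

    # Python 2: range(N) materializes a full list, xrange(N) is lazy.
    # Python 3: xrange is a NameError; range(N) is lazy.
    N = 10 ** 6
    for i in range(N):        # valid on both; builds a full list on Python 2
        if i > 2:
            break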