Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
4-GPU.cpp
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdio>
10 #include <cstdlib>
11 #include <cassert>
12 
13 #include <faiss/IndexFlat.h>
14 #include <faiss/gpu/GpuIndexFlat.h>
15 #include <faiss/gpu/GpuIndexIVFFlat.h>
16 #include <faiss/gpu/StandardGpuResources.h>
17 
18 
19 int main() {
20  int d = 64; // dimension
21  int nb = 100000; // database size
22  int nq = 10000; // nb of queries
23 
24  float *xb = new float[d * nb];
25  float *xq = new float[d * nq];
26 
27  for(int i = 0; i < nb; i++) {
28  for(int j = 0; j < d; j++)
29  xb[d * i + j] = drand48();
30  xb[d * i] += i / 1000.;
31  }
32 
33  for(int i = 0; i < nq; i++) {
34  for(int j = 0; j < d; j++)
35  xq[d * i + j] = drand48();
36  xq[d * i] += i / 1000.;
37  }
38 
40 
41  // Using a flat index
42 
43  faiss::gpu::GpuIndexFlatL2 index_flat(&res, d);
44 
45  printf("is_trained = %s\n", index_flat.is_trained ? "true" : "false");
46  index_flat.add(nb, xb); // add vectors to the index
47  printf("ntotal = %ld\n", index_flat.ntotal);
48 
49  int k = 4;
50 
51  { // search xq
52  long *I = new long[k * nq];
53  float *D = new float[k * nq];
54 
55  index_flat.search(nq, xq, k, D, I);
56 
57  // print results
58  printf("I (5 first results)=\n");
59  for(int i = 0; i < 5; i++) {
60  for(int j = 0; j < k; j++)
61  printf("%5ld ", I[i * k + j]);
62  printf("\n");
63  }
64 
65  printf("I (5 last results)=\n");
66  for(int i = nq - 5; i < nq; i++) {
67  for(int j = 0; j < k; j++)
68  printf("%5ld ", I[i * k + j]);
69  printf("\n");
70  }
71 
72  delete [] I;
73  delete [] D;
74  }
75 
76  // Using an IVF index
77 
78  int nlist = 100;
79  faiss::gpu::GpuIndexIVFFlat index_ivf(&res, d, nlist, faiss::METRIC_L2);
80  // here we specify METRIC_L2, by default it performs inner-product search
81 
82  assert(!index_ivf.is_trained);
83  index_ivf.train(nb, xb);
84  assert(index_ivf.is_trained);
85  index_ivf.add(nb, xb); // add vectors to the index
86 
87  printf("is_trained = %s\n", index_ivf.is_trained ? "true" : "false");
88  printf("ntotal = %ld\n", index_ivf.ntotal);
89 
90  { // search xq
91  long *I = new long[k * nq];
92  float *D = new float[k * nq];
93 
94  index_ivf.search(nq, xq, k, D, I);
95 
96  // print results
97  printf("I (5 first results)=\n");
98  for(int i = 0; i < 5; i++) {
99  for(int j = 0; j < k; j++)
100  printf("%5ld ", I[i * k + j]);
101  printf("\n");
102  }
103 
104  printf("I (5 last results)=\n");
105  for(int i = nq - 5; i < nq; i++) {
106  for(int j = 0; j < k; j++)
107  printf("%5ld ", I[i * k + j]);
108  printf("\n");
109  }
110 
111  delete [] I;
112  delete [] D;
113  }
114 
115 
116  delete [] xb;
117  delete [] xq;
118 
119  return 0;
120 }