Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
PerfSelect.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 
10 #include "../utils/DeviceUtils.h"
11 #include "../utils/BlockSelectKernel.cuh"
12 #include "../utils/WarpSelectKernel.cuh"
13 #include "../utils/HostTensor.cuh"
14 #include "../utils/DeviceTensor.cuh"
15 #include "../test/TestUtils.h"
16 #include <algorithm>
17 #include <gflags/gflags.h>
18 #include <gtest/gtest.h>
19 #include <sstream>
20 #include <unordered_map>
21 #include <vector>
22 
// Command-line flags for the k-selection micro-benchmark.

// Input matrix dimensions (rows x cols of random floats).
DEFINE_int32(rows, 10000, "rows in matrix");
DEFINE_int32(cols, 40000, "cols in matrix");

// How many entries to select per row. When k_powers is set, main() ignores
// k and sweeps k = 1, 2, 4, ..., 1024 instead.
DEFINE_int32(k, 100, "k");
DEFINE_bool(k_powers, false, "test k powers of 2 from 1 -> 1024");

// Passed straight through as the `dir` argument of the select routines.
DEFINE_bool(dir, false, "direction of sort");

// Which kernel to exercise (warp select vs. block select) and how many
// times to repeat each selection.
DEFINE_bool(warp, false, "warp select");
DEFINE_int32(iter, 5, "iterations to run");
30 
#include <cstdio>

/// Micro-benchmark driver for the GPU k-selection kernels.
///
/// Builds a rows x cols matrix of random floats on the host and copies it to
/// the device. It then runs either runWarpSelect (--warp) or runBlockSelect
/// --iter times, selecting k entries per row. With --k_powers it instead
/// sweeps every power of two k = 1 .. 1024. The program prints no timing; it
/// is presumably meant to be run under an external profiler (TODO confirm).
///
/// Returns 0 on success. Returns 1 when the flags are invalid, or when CUDA
/// reports an error at the final device synchronization.
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  // The output tensors below have this many columns, so no k may exceed it.
  constexpr int kMaxK = 1024;

  if (FLAGS_rows <= 0 || FLAGS_cols <= 0) {
    std::fprintf(stderr, "--rows (%d) and --cols (%d) must be positive\n",
                 FLAGS_rows, FLAGS_cols);
    return 1;
  }

  // Validating k matters for two reasons. Too large a k overruns the output
  // rows. A k <= 0 would make the `k *= 2` loop below never terminate.
  if (!FLAGS_k_powers && (FLAGS_k < 1 || FLAGS_k > kMaxK)) {
    std::fprintf(stderr, "--k must be in [1, %d], got %d\n", kMaxK, FLAGS_k);
    return 1;
  }

  std::vector<float> v = faiss::gpu::randVecs(FLAGS_rows, FLAGS_cols);
  faiss::gpu::HostTensor<float, 2, true> hostVal({FLAGS_rows, FLAGS_cols});

  for (int r = 0; r < FLAGS_rows; ++r) {
    // Compute the row offset in size_t: rows * cols can exceed INT_MAX.
    const size_t rowOffset = static_cast<size_t>(r) * FLAGS_cols;
    for (int c = 0; c < FLAGS_cols; ++c) {
      hostVal[r][c] = v[rowOffset + c];
    }
  }

  // Select top-k on GPU: upload the input matrix to the device (stream 0).
  faiss::gpu::DeviceTensor<float, 2, true> gpuVal(hostVal, 0);

  // enough space for any k
  faiss::gpu::DeviceTensor<float, 2, true> gpuOutVal({FLAGS_rows, kMaxK});
  faiss::gpu::DeviceTensor<int, 2, true> gpuOutInd({FLAGS_rows, kMaxK});

  int startK = FLAGS_k;
  int limitK = FLAGS_k;

  if (FLAGS_k_powers) {
    startK = 1;
    limitK = kMaxK;
  }

  // Without --k_powers this runs exactly once (k == FLAGS_k), because
  // doubling k immediately exceeds limitK.
  for (int k = startK; k <= limitK; k *= 2) {
    for (int i = 0; i < FLAGS_iter; ++i) {
      if (FLAGS_warp) {
        faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd,
                                  FLAGS_dir, k, 0);
      } else {
        faiss::gpu::runBlockSelect(gpuVal, gpuOutVal, gpuOutInd,
                                   FLAGS_dir, k, 0);
      }
    }
  }

  // The kernel launches are asynchronous, so execution faults only surface
  // here. Report them instead of silently exiting with status 0.
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    return 1;
  }

  return 0;
}