Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
PerfSelect.cu
1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 
#include "../utils/DeviceDefs.cuh"
#include "../utils/DeviceUtils.h"
#include "../utils/BlockSelectKernel.cuh"
#include "../utils/WarpSelectKernel.cuh"
#include "../utils/HostTensor.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../test/TestUtils.h"
#include <algorithm>
#include <climits>
#include <cstdio>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <sstream>
#include <unordered_map>
#include <vector>
22 
// ---- Problem size: the benchmark input is a rows x cols matrix ----
DEFINE_int32(rows, 10000, "rows in matrix");
DEFINE_int32(cols, 40000, "cols in matrix");

// ---- Selection parameters ----
// Number of elements selected per row; ignored when --k_powers is set.
DEFINE_int32(k, 100, "k");
// Forwarded as the `dir` argument of runBlockSelect / runWarpSelect.
// NOTE(review): presumably true selects the largest values -- confirm.
DEFINE_bool(dir, false, "direction of sort");
// When set, runWarpSelect is benchmarked instead of runBlockSelect.
DEFINE_bool(warp, false, "warp select");

// ---- Benchmark control ----
// Number of select calls issued for each value of k.
DEFINE_int32(iter, 5, "iterations to run");
// Sweep k = 1, 2, 4, ... up to GPU_MAX_SELECTION_K instead of using --k.
DEFINE_bool(k_powers, false, "test k powers of 2 from 1 -> max k");
30 
// Benchmark driver for the GPU k-selection kernels.
//
// Builds a FLAGS_rows x FLAGS_cols matrix from faiss::gpu::randVecs and
// copies it to the device. For each k under test, it then calls either
// runWarpSelect (--warp) or runBlockSelect FLAGS_iter times on stream 0.
// k is either FLAGS_k alone, or the powers of two 1..GPU_MAX_SELECTION_K
// when --k_powers is given. No timing is done here; presumably the binary is
// run under an external profiler -- confirm.
//
// Returns 0 on success, 1 on invalid flags or on a CUDA execution error.
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  if (FLAGS_rows <= 0 || FLAGS_cols <= 0) {
    fprintf(stderr, "--rows and --cols must be positive (got %d x %d)\n",
            FLAGS_rows, FLAGS_cols);
    return 1;
  }

  // The flat index r * cols + c below (and the int-indexed tensors) must not
  // overflow `int`.
  if (static_cast<long long>(FLAGS_rows) * FLAGS_cols > INT_MAX) {
    fprintf(stderr, "--rows * --cols must not exceed %d (got %d x %d)\n",
            INT_MAX, FLAGS_rows, FLAGS_cols);
    return 1;
  }

  // With k <= 0 the `k *= 2` loop below never advances and spins forever.
  // GPU_MAX_SELECTION_K is the same upper bound used by --k_powers.
  if (!FLAGS_k_powers && (FLAGS_k < 1 || FLAGS_k > GPU_MAX_SELECTION_K)) {
    fprintf(stderr, "--k must be in [1, %d] (got %d)\n",
            GPU_MAX_SELECTION_K, FLAGS_k);
    return 1;
  }

  std::vector<float> v = faiss::gpu::randVecs(FLAGS_rows, FLAGS_cols);
  faiss::gpu::HostTensor<float, 2, true> hostVal({FLAGS_rows, FLAGS_cols});

  // Row-major copy of the random data into the host tensor.
  for (int r = 0; r < FLAGS_rows; ++r) {
    for (int c = 0; c < FLAGS_cols; ++c) {
      hostVal[r][c] = v[r * FLAGS_cols + c];
    }
  }

  // Select top-k on GPU: copy the input matrix to the device (stream 0).
  faiss::gpu::DeviceTensor<float, 2, true> gpuVal(hostVal, 0);

  int startK = FLAGS_k;
  int limitK = FLAGS_k;

  if (FLAGS_k_powers) {
    startK = 1;
    limitK = GPU_MAX_SELECTION_K;
  }

  for (int k = startK; k <= limitK; k *= 2) {
    // Per-k output buffers: k selected values / indices per input row.
    faiss::gpu::DeviceTensor<float, 2, true> gpuOutVal({FLAGS_rows, k});
    faiss::gpu::DeviceTensor<int, 2, true> gpuOutInd({FLAGS_rows, k});

    for (int i = 0; i < FLAGS_iter; ++i) {
      if (FLAGS_warp) {
        faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd,
                                  FLAGS_dir, k, 0);
      } else {
        faiss::gpu::runBlockSelect(gpuVal, gpuOutVal, gpuOutInd,
                                   FLAGS_dir, k, 0);
      }
    }
  }

  // Kernel execution is asynchronous: wait for it and report any fault
  // instead of silently discarding the returned error code.
  cudaError_t err = cudaDeviceSynchronize();
  if (err != cudaSuccess) {
    fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
    return 1;
  }

  return 0;
}