Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
PerfSelect.cu
1 /**
2  * Copyright (c) 2015-present, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD+Patents license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Copyright 2004-present Facebook. All Rights Reserved.
10 
#include "../utils/DeviceUtils.h"
#include "../utils/BlockSelectKernel.cuh"
#include "../utils/WarpSelectKernel.cuh"
#include "../utils/HostTensor.cuh"
#include "../utils/DeviceTensor.cuh"
#include "../test/TestUtils.h"
#include <algorithm>
#include <cstdio>
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <sstream>
#include <unordered_map>
#include <vector>
23 
24 DEFINE_int32(rows, 10000, "rows in matrix");
25 DEFINE_int32(cols, 40000, "cols in matrix");
26 DEFINE_int32(k, 100, "k");
27 DEFINE_bool(dir, false, "direction of sort");
28 DEFINE_bool(warp, false, "warp select");
29 DEFINE_int32(iter, 5, "iterations to run");
30 DEFINE_bool(k_powers, false, "test k powers of 2 from 1 -> 1024");
31 
32 int main(int argc, char** argv) {
33  gflags::ParseCommandLineFlags(&argc, &argv, true);
34 
35  std::vector<float> v = faiss::gpu::randVecs(FLAGS_rows, FLAGS_cols);
36  faiss::gpu::HostTensor<float, 2, true> hostVal({FLAGS_rows, FLAGS_cols});
37 
38  for (int r = 0; r < FLAGS_rows; ++r) {
39  for (int c = 0; c < FLAGS_cols; ++c) {
40  hostVal[r][c] = v[r * FLAGS_cols + c];
41  }
42  }
43 
44  // Select top-k on GPU
46 
47  // enough space for any k
48  faiss::gpu::DeviceTensor<float, 2, true> gpuOutVal({FLAGS_rows, 1024});
49  faiss::gpu::DeviceTensor<int, 2, true> gpuOutInd({FLAGS_rows, 1024});
50 
51  int startK = FLAGS_k;
52  int limitK = FLAGS_k;
53 
54  if (FLAGS_k_powers) {
55  startK = 1;
56  limitK = 1024;
57  }
58 
59  for (int k = startK; k <= limitK; k *= 2) {
60  for (int i = 0; i < FLAGS_iter; ++i) {
61  if (FLAGS_warp) {
62  faiss::gpu::runWarpSelect(gpuVal, gpuOutVal, gpuOutInd,
63  FLAGS_dir, k, 0);
64  } else {
65  faiss::gpu::runBlockSelect(gpuVal, gpuOutVal, gpuOutInd,
66  FLAGS_dir, k, 0);
67  }
68  }
69  }
70 
71  cudaDeviceSynchronize();
72 }